From a03b8c49c62bd8085f6b9fd68bc540f5acf51696 Mon Sep 17 00:00:00 2001 From: yassinnouh21 Date: Wed, 9 Apr 2025 01:05:48 +0200 Subject: [PATCH 01/11] feat: add online document retrieval with hybrid search capabilities Signed-off-by: yassinnouh21 --- sdk/python/feast/feature_store.py | 42 ++- .../postgres_online_store/postgres.py | 289 +++++++++++++++++- 2 files changed, 318 insertions(+), 13 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 1578a91574e..d13fd85b4f3 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -868,7 +868,8 @@ def apply( views_to_update = [ ob for ob in objects - if ( + if + ( # BFVs are not handled separately from FVs right now. (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) and not isinstance(ob, StreamFeatureView) @@ -1955,9 +1956,9 @@ def retrieve_online_documents_v2( distance_metric: The distance metric to use for retrieval. query_string: The query string to retrieve the closest document features using keyword search (bm25). """ - assert query is not None or query_string is not None, ( - "Either query or query_string must be provided." - ) + assert ( + query is not None or query_string is not None + ), "Either query or query_string must be provided." ( available_feature_views, @@ -2097,15 +2098,34 @@ def _retrieve_from_online_store_v2( entity_key_dict[key] = [] entity_key_dict[key].append(python_value) - table_entity_values, idxs, output_len = utils._get_unique_entities_from_values( - entity_key_dict, - ) - features_to_request: List[str] = [] if requested_features: features_to_request = requested_features + ["distance"] + # Add text_rank for text search queries + if query_string is not None: + features_to_request.append("text_rank") else: features_to_request = ["distance"] + # Add text_rank for text search queries + if query_string is not None: + features_to_request.append("text_rank") + + if not datevals: + online_features_response = GetOnlineFeaturesResponse(results=[]) + for feature in features_to_request: + field = online_features_response.results.add() + field.values.extend([]) + field.statuses.extend([]) + field.event_timestamps.extend([]) + online_features_response.metadata.feature_names.val.extend( + features_to_request + ) + return OnlineResponse(online_features_response) + + table_entity_values, idxs, output_len = utils._get_unique_entities_from_values( + entity_key_dict, + ) + feature_data = utils._convert_rows_to_protobuf( requested_features=features_to_request, read_rows=list(zip(datevals, list_of_feature_dicts)), @@ -2238,9 +2258,9 @@ def write_logged_features( if not isinstance(source, FeatureService): raise ValueError("Only feature service is currently supported as a source") - assert source.logging_config is not None, ( - "Feature service must be configured with logging config in order to use this functionality" - ) + assert ( + source.logging_config is not None + ), "Feature service must be configured with logging config in order to use this functionality" assert isinstance(logs, (pa.Table, Path)) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index b5c1dd05f3a..bb79390a330 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -119,6 +119,12 @@ def online_write_batch( for feature_name, val in values.items(): vector_val = None + value_text = None + + # Check if the feature type is STRING + if val.WhichOneof("val") == "string_val": + value_text = val.string_val + if config.online_store.vector_enabled: vector_val = get_list_val_str(val) insert_values.append( @@ -126,6 +132,7 @@ def online_write_batch( entity_key_bin, feature_name, val.SerializeToString(), + value_text, vector_val, timestamp, created_ts, @@ -136,11 +143,12 @@ def online_write_batch( sql_query = sql.SQL( """ INSERT INTO {} - (entity_key, feature_name, value, vector_value, event_ts, created_ts) - VALUES (%s, %s, %s, %s, %s, %s) + (entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts) + VALUES (%s, %s, %s, %s, %s, %s, %s) ON CONFLICT (entity_key, feature_name) DO UPDATE SET value = EXCLUDED.value, + value_text = EXCLUDED.value_text, vector_value = EXCLUDED.vector_value, event_ts = EXCLUDED.event_ts, created_ts = EXCLUDED.created_ts; @@ -308,6 +316,11 @@ def update( else: # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility vector_value_type = "BYTEA" + + has_string_features = any( + f.dtype.to_value_type().value == 2 for f in table.features + ) # 2 is STRING in ValueType + cur.execute( sql.SQL( """ @@ -316,6 +329,7 @@ def update( entity_key BYTEA, feature_name TEXT, value BYTEA, + value_text TEXT NULL, -- Added for FTS vector_value {} NULL, event_ts TIMESTAMPTZ, created_ts TIMESTAMPTZ, @@ -331,6 +345,16 @@ def update( ) ) + if has_string_features: + cur.execute( + sql.SQL( + """CREATE INDEX IF NOT EXISTS {} ON {} USING GIN (to_tsvector('english', value_text));""" + ).format( + sql.Identifier(f"{table_name}_fts_idx"), + sql.Identifier(table_name), + ) + ) + conn.commit() def teardown( @@ -456,6 +480,267 @@ def retrieve_online_documents( return result + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + embedding: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Retrieve documents using vector similarity search or keyword search in PostgreSQL. + + Args: + config: Feast configuration object + table: FeatureView object as the table to search + requested_features: List of requested features to retrieve + embedding: Query embedding to search for (optional) + top_k: Number of items to return + distance_metric: Distance metric to use (optional) + query_string: The query string to search for using keyword search (optional) + + Returns: + List of tuples containing the event timestamp, entity key, and feature values + """ + if not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config") + + if embedding is None and query_string is None: + raise ValueError("Either embedding or query_string must be provided") + + distance_metric = distance_metric or "L2" + + if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT: + raise ValueError( + f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}" + ) + + distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric] + + string_fields = [] + for feature in table.features: + if ( + feature.dtype.to_value_type().value == 2 + and feature.name in requested_features + ): # 2 is STRING + string_fields.append(feature.name) + + table_name = _table_id(config.project, table) + + with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: + # Case 1: Hybrid Search (vector + text) + if embedding is not None and query_string is not None and string_fields: + tsquery_str = " & ".join(query_string.split()) + + query = sql.SQL( + """ + SELECT + entity_key, + feature_name, + value, + vector_value, + vector_value {distance_metric_sql} %s::vector as distance, + ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank, + event_ts, + created_ts + FROM {table_name} + WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) + ORDER BY distance + LIMIT {top_k} + """ + ).format( + distance_metric_sql=sql.SQL(distance_metric_sql), + table_name=sql.Identifier(table_name), + top_k=sql.Literal(top_k), + ) + + cur.execute(query, (embedding, tsquery_str, string_fields, tsquery_str)) + rows = cur.fetchall() + + # Case 2: Vector Search Only + elif embedding is not None: + query = sql.SQL( + """ + SELECT + entity_key, + feature_name, + value, + vector_value, + vector_value {distance_metric_sql} %s::vector as distance, + NULL as text_rank, -- Keep consistent columns + event_ts, + created_ts + FROM {table_name} + ORDER BY distance + LIMIT {top_k} + """ + ).format( + distance_metric_sql=sql.SQL(distance_metric_sql), + table_name=sql.Identifier(table_name), + top_k=sql.Literal(top_k), + ) + + cur.execute(query, (embedding,)) + rows = cur.fetchall() + + # Case 3: Text Search Only + elif query_string is not None and string_fields: + tsquery_str = " & ".join(query_string.split()) + query = sql.SQL( + """ + WITH text_matches AS ( + SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank + FROM {table_name} + WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) + ORDER BY text_rank DESC + LIMIT {top_k} + ) + SELECT + t1.entity_key, + t1.feature_name, + t1.value, + t1.vector_value, + NULL as distance, + t2.text_rank, + t1.event_ts, + t1.created_ts + FROM {table_name} t1 + INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key + WHERE t1.feature_name = ANY(%s) + ORDER BY t2.text_rank DESC + """ + ).format( + table_name=sql.Identifier(table_name), + top_k=sql.Literal(top_k), + ) + + cur.execute( + query, (tsquery_str, string_fields, tsquery_str, requested_features) + ) + rows = cur.fetchall() + + else: + raise ValueError( + "Either vector_enabled must be True for embedding search or string fields must be available for query_string search" + ) + + # Group by entity_key to build feature records + entities_dict: Dict[str, Dict[str, Any]] = defaultdict( + lambda: { + "features": {}, + "timestamp": None, + "entity_key_proto": None, + "vector_distance": float("inf"), + "text_rank": 0.0, + } + ) + + for ( + entity_key_bytes, + feature_name, + feature_val_bytes, + vector_val, + distance, + text_rank, + event_ts, + created_ts, + ) in rows: + entity_key_proto = None + if entity_key_bytes: + from feast.infra.key_encoding_utils import deserialize_entity_key + + entity_key_proto = deserialize_entity_key(entity_key_bytes) + + key = entity_key_bytes.hex() if entity_key_bytes else None + + if key is None: + continue + + entities_dict[key]["entity_key_proto"] = entity_key_proto + + if ( + entities_dict[key]["timestamp"] is None + or event_ts > entities_dict[key]["timestamp"] + ): + entities_dict[key]["timestamp"] = event_ts + + val = ValueProto() + if feature_val_bytes: + val.ParseFromString(feature_val_bytes) + + entities_dict[key]["features"][feature_name] = val + + if distance is not None: + entities_dict[key]["vector_distance"] = min( + entities_dict[key]["vector_distance"], float(distance) + ) + if text_rank is not None: + entities_dict[key]["text_rank"] = max( + entities_dict[key]["text_rank"], float(text_rank) + ) + + if embedding is not None and query_string is not None: + + def sort_key(x): + return x["vector_distance"] + elif embedding is not None: + + def sort_key(x): + return x["vector_distance"] + else: # Text only + + def sort_key(x): + return x["text_rank"] + + sorted_entities = sorted( + entities_dict.values(), key=sort_key, reverse=(embedding is None) + )[:top_k] + + result: List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ] = [] + for entity_data in sorted_entities: + features = ( + entity_data["features"].copy() + if isinstance(entity_data["features"], dict) + else None + ) + + if features is not None: + if "vector_distance" in entity_data and entity_data[ + "vector_distance" + ] != float("inf"): + dist_val = ValueProto() + dist_val.double_val = entity_data["vector_distance"] + features["distance"] = dist_val + + if embedding is None or query_string is not None: + rank_val = ValueProto() + rank_val.double_val = entity_data["text_rank"] + features["text_rank"] = rank_val + + result.append( + ( + entity_data["timestamp"], + entity_data["entity_key_proto"], + features, + ) + ) + return result + def _table_id(project: str, table: FeatureView) -> str: return f"{project}_{table.name}" From 75233cfc435a2c3b27cfb8dcc7bf943710b6e1c7 Mon Sep 17 00:00:00 2001 From: yassinnouh21 Date: Wed, 9 Apr 2025 01:06:43 +0200 Subject: [PATCH 02/11] test: add integration tests for hybrid search and document retrieval Signed-off-by: yassinnouh21 --- .../online_store/test_universal_online.py | 177 +++++++++++++++++- 1 file changed, 173 insertions(+), 4 deletions(-) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index b563f00bfd1..a547e36699c 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -1,8 +1,8 @@ -import datetime import os +import random import time import unittest -from datetime import timedelta +from datetime import datetime, timedelta from typing import Any, Dict, List, Tuple, Union import assertpy @@ -18,9 +18,10 @@ from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.field import Field +from feast.infra.offline_stores.file_source import FileSource from feast.infra.utils.postgres.postgres_config import ConnectionType from feast.online_response import TIMESTAMP_POSTFIX -from feast.types import Float32, Int32, String +from feast.types import Array, Float32, Int32, Int64, String, ValueType from feast.utils import _utc_now from feast.wait import wait_retry_backoff from tests.integration.feature_repos.repo_configuration import ( @@ -861,7 +862,7 @@ def assert_feature_service_entity_mapping_correctness( @pytest.mark.integration -@pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch", "qdrant"]) +@pytest.mark.universal_online_stores(only=["pgvector"]) def test_retrieve_online_documents(environment, fake_document_data): fs = environment.feature_store df, data_source = fake_document_data @@ -919,3 +920,171 @@ def test_retrieve_online_milvus_documents(environment, fake_document_data): assert len(documents["item_id"]) == 2 assert documents["item_id"] == [2, 3] + + +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["pgvector"]) +def test_postgres_retrieve_online_documents_v2(environment, fake_document_data): + """Test retrieval of documents using PostgreSQL vector store capabilities.""" + fs = environment.feature_store + + n_rows = 20 + vector_dim = 2 + random.seed(42) + + df = pd.DataFrame( + { + "item_id": list(range(n_rows)), + "embedding": [list(np.random.random(vector_dim)) for _ in range(n_rows)], + "text_field": [ + f"Document text content {i} with searchable keywords" + for i in range(n_rows) + ], + "category": [f"Category-{i%5}" for i in range(n_rows)], + "event_timestamp": [datetime.now() for _ in range(n_rows)], + } + ) + + data_source = FileSource( + path="dummy_path.parquet", timestamp_field="event_timestamp" + ) + + item = Entity( + name="item_id", + join_keys=["item_id"], + value_type=ValueType.INT64, + ) + + item_embeddings_fv = FeatureView( + name="item_embeddings", + entities=[item], + schema=[ + Field(name="embedding", dtype=Array(Float32), vector_index=True), + Field(name="text_field", dtype=String), + Field(name="category", dtype=String), + Field(name="item_id", dtype=Int64), + ], + source=data_source, + ) + + fs.apply([item_embeddings_fv, item]) + fs.write_to_online_store("item_embeddings", df) + + # Test 1: Vector similarity search + query_embedding = list(np.random.random(vector_dim)) + vector_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=5, + distance_metric="L2", + ).to_dict() + + assert len(vector_results["embedding"]) == 5 + assert len(vector_results["distance"]) == 5 + assert len(vector_results["text_field"]) == 5 + assert len(vector_results["category"]) == 5 + + # Test 2: Vector similarity search with Cosine distance + vector_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + top_k=5, + distance_metric="cosine", + ).to_dict() + + assert len(vector_results["embedding"]) == 5 + assert len(vector_results["distance"]) == 5 + assert len(vector_results["text_field"]) == 5 + assert len(vector_results["category"]) == 5 + + # Test 3: Full text search + text_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query_string="searchable keywords", + top_k=5, + ).to_dict() + + # Verify text search results + assert len(text_results["text_field"]) == 5 + assert len(text_results["text_rank"]) == 5 + assert len(text_results["category"]) == 5 + assert len(text_results["item_id"]) == 5 + + # Verify text rank values are between 0 and 1 + assert all(0 <= rank <= 1 for rank in text_results["text_rank"]) + + # Verify results are sorted by text rank in descending order + text_ranks = text_results["text_rank"] + assert all(text_ranks[i] >= text_ranks[i + 1] for i in range(len(text_ranks) - 1)) + + # Test 4: Hybrid search (vector + text) + hybrid_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + query_string="searchable keywords", + top_k=5, + distance_metric="L2", + ).to_dict() + + # Verify hybrid search results + assert len(hybrid_results["embedding"]) == 5 + assert len(hybrid_results["distance"]) == 5 + assert len(hybrid_results["text_field"]) == 5 + assert len(hybrid_results["text_rank"]) == 5 + assert len(hybrid_results["category"]) == 5 + assert len(hybrid_results["item_id"]) == 5 + + # Test 5: Hybrid search with different text query + hybrid_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query=query_embedding, + query_string="Category-1", + top_k=5, + distance_metric="L2", + ).to_dict() + + # Verify results contain only documents from Category-1 + assert all(cat == "Category-1" for cat in hybrid_results["category"]) + + # Test 6: Full text search with no matches + no_match_results = fs.retrieve_online_documents_v2( + features=[ + "item_embeddings:embedding", + "item_embeddings:text_field", + "item_embeddings:category", + "item_embeddings:item_id", + ], + query_string="nonexistent keyword", + top_k=5, + ).to_dict() + + # Verify no results are returned for non-matching query + assert "text_field" in no_match_results + assert len(no_match_results["text_field"]) == 0 + assert "text_rank" in no_match_results + assert len(no_match_results["text_rank"]) == 0 From 776c3271f6db482300296cb9cdb1906ce6b1b98e Mon Sep 17 00:00:00 2001 From: yassinnouh21 Date: Wed, 9 Apr 2025 01:14:04 +0200 Subject: [PATCH 03/11] fix formatting Signed-off-by: yassinnouh21 --- sdk/python/feast/feature_store.py | 15 +++++++-------- .../online_store/test_universal_online.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index d13fd85b4f3..bd6f1873874 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -868,8 +868,7 @@ def apply( views_to_update = [ ob for ob in objects - if - ( + if ( # BFVs are not handled separately from FVs right now. (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) and not isinstance(ob, StreamFeatureView) @@ -1956,9 +1955,9 @@ def retrieve_online_documents_v2( distance_metric: The distance metric to use for retrieval. query_string: The query string to retrieve the closest document features using keyword search (bm25). """ - assert ( - query is not None or query_string is not None - ), "Either query or query_string must be provided." + assert query is not None or query_string is not None, ( + "Either query or query_string must be provided." + ) ( available_feature_views, @@ -2258,9 +2257,9 @@ def write_logged_features( if not isinstance(source, FeatureService): raise ValueError("Only feature service is currently supported as a source") - assert ( - source.logging_config is not None - ), "Feature service must be configured with logging config in order to use this functionality" + assert source.logging_config is not None, ( + "Feature service must be configured with logging config in order to use this functionality" + ) assert isinstance(logs, (pa.Table, Path)) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index a547e36699c..c50801dd745 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -940,7 +940,7 @@ def test_postgres_retrieve_online_documents_v2(environment, fake_document_data): f"Document text content {i} with searchable keywords" for i in range(n_rows) ], - "category": [f"Category-{i%5}" for i in range(n_rows)], + "category": [f"Category-{i % 5}" for i in range(n_rows)], "event_timestamp": [datetime.now() for _ in range(n_rows)], } ) From e1f0cae70ed7248d7ec947805493ae54210ceec6 Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:40:48 +0200 Subject: [PATCH 04/11] fix: Refactor string_fields assignment to filter features by dtype and requested features Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../infra/online_stores/postgres_online_store/postgres.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index bb79390a330..cdd081488a2 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -526,7 +526,13 @@ def retrieve_online_documents_v2( distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric] - string_fields = [] + string_fields = [ + feature.name + for feature in table.features + if feature.dtype.to_value_type().value == 2 + and feature.name in requested_features + ] + for feature in table.features: if ( feature.dtype.to_value_type().value == 2 From 246d6a671848a5ae7d15ca1e6c58eb2c9821c8aa Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:50:59 +0200 Subject: [PATCH 05/11] fix: improve query execution logic in postgres.py Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../postgres_online_store/postgres.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index cdd081488a2..4fe54f88723 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -533,20 +533,12 @@ def retrieve_online_documents_v2( and feature.name in requested_features ] - for feature in table.features: - if ( - feature.dtype.to_value_type().value == 2 - and feature.name in requested_features - ): # 2 is STRING - string_fields.append(feature.name) - table_name = _table_id(config.project, table) with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: - # Case 1: Hybrid Search (vector + text) if embedding is not None and query_string is not None and string_fields: + # Case 1: Hybrid Search (vector + text) tsquery_str = " & ".join(query_string.split()) - query = sql.SQL( """ SELECT @@ -568,12 +560,10 @@ def retrieve_online_documents_v2( table_name=sql.Identifier(table_name), top_k=sql.Literal(top_k), ) + params = (embedding, tsquery_str, string_fields, tsquery_str) - cur.execute(query, (embedding, tsquery_str, string_fields, tsquery_str)) - rows = cur.fetchall() - - # Case 2: Vector Search Only elif embedding is not None: + # Case 2: Vector Search Only query = sql.SQL( """ SELECT @@ -594,12 +584,10 @@ def retrieve_online_documents_v2( table_name=sql.Identifier(table_name), top_k=sql.Literal(top_k), ) + params = (embedding,) - cur.execute(query, (embedding,)) - rows = cur.fetchall() - - # Case 3: Text Search Only elif query_string is not None and string_fields: + # Case 3: Text Search Only tsquery_str = " & ".join(query_string.split()) query = sql.SQL( """ @@ -628,17 +616,16 @@ def retrieve_online_documents_v2( table_name=sql.Identifier(table_name), top_k=sql.Literal(top_k), ) - - cur.execute( - query, (tsquery_str, string_fields, tsquery_str, requested_features) - ) - rows = cur.fetchall() + params = (tsquery_str, string_fields, tsquery_str, requested_features) else: raise ValueError( "Either vector_enabled must be True for embedding search or string fields must be available for query_string search" ) + cur.execute(query, params) + rows = cur.fetchall() + # Group by entity_key to build feature records entities_dict: Dict[str, Dict[str, Any]] = defaultdict( lambda: { From 6e3413cd68d2bb201623c216bb286103dd05f80c Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:54:40 +0200 Subject: [PATCH 06/11] fix linter Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../infra/online_stores/postgres_online_store/postgres.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index 4fe54f88723..3c7469b6adf 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -536,6 +536,9 @@ def retrieve_online_documents_v2( table_name = _table_id(config.project, table) with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: + query = None + params: Any = None + if embedding is not None and query_string is not None and string_fields: # Case 1: Hybrid Search (vector + text) tsquery_str = " & ".join(query_string.split()) From 0169cb6328d86993f2d3c5e83c19d4965fb1c6bc Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 11:03:58 +0200 Subject: [PATCH 07/11] fix: simplify sorting logic in query execution Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../postgres_online_store/postgres.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index 3c7469b6adf..f4e40eb3d82 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -684,21 +684,10 @@ def retrieve_online_documents_v2( entities_dict[key]["text_rank"], float(text_rank) ) - if embedding is not None and query_string is not None: - - def sort_key(x): - return x["vector_distance"] - elif embedding is not None: - - def sort_key(x): - return x["vector_distance"] - else: # Text only - - def sort_key(x): - return x["text_rank"] - sorted_entities = sorted( - entities_dict.values(), key=sort_key, reverse=(embedding is None) + entities_dict.values(), + key=lambda x: x["vector_distance"] if embedding is not None else x["text_rank"], + reverse=(embedding is None) )[:top_k] result: List[ From d35e0ba1baf900aec06bf5314133b7e7b43d48f4 Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 11:05:10 +0200 Subject: [PATCH 08/11] fix formatting Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../infra/online_stores/postgres_online_store/postgres.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index f4e40eb3d82..13c4e4fad3e 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -685,9 +685,11 @@ def retrieve_online_documents_v2( ) sorted_entities = sorted( - entities_dict.values(), - key=lambda x: x["vector_distance"] if embedding is not None else x["text_rank"], - reverse=(embedding is None) + entities_dict.values(), + key=lambda x: x["vector_distance"] + if embedding is not None + else x["text_rank"], + reverse=(embedding is None), )[:top_k] result: List[ From 86d40b00064073d28c119dd091c74a5c734d8528 Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 18:07:24 +0200 Subject: [PATCH 09/11] fix: update string feature check to use ValueType enumeration Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../infra/online_stores/postgres_online_store/postgres.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index 13c4e4fad3e..f6b3d592104 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -20,8 +20,7 @@ from psycopg.connection import Connection from psycopg_pool import AsyncConnectionPool, ConnectionPool -from feast import Entity -from feast.feature_view import FeatureView +from feast import Entity, FeatureView, ValueType from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.helpers import _to_naive_utc from feast.infra.online_stores.online_store import OnlineStore @@ -318,8 +317,9 @@ def update( vector_value_type = "BYTEA" has_string_features = any( - f.dtype.to_value_type().value == 2 for f in table.features - ) # 2 is STRING in ValueType + f.dtype.to_value_type() == ValueType.STRING + for f in table.features + ) cur.execute( sql.SQL( From 55dec54221be9f54506ba7d3d68e686cfb0a801a Mon Sep 17 00:00:00 2001 From: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> Date: Wed, 9 Apr 2025 18:08:55 +0200 Subject: [PATCH 10/11] formatting Signed-off-by: Yassin Nouh <70436855+YassinNouh21@users.noreply.github.com> --- .../infra/online_stores/postgres_online_store/postgres.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index f6b3d592104..daa128264df 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -317,8 +317,7 @@ def update( vector_value_type = "BYTEA" has_string_features = any( - f.dtype.to_value_type() == ValueType.STRING - for f in table.features + f.dtype.to_value_type() == ValueType.STRING for f in table.features ) cur.execute( From 15dcabf30bdee8271f695cf0f776afa28a8b4ecd Mon Sep 17 00:00:00 2001 From: yassinnouh21 Date: Thu, 10 Apr 2025 07:36:13 +0200 Subject: [PATCH 11/11] fix datetime Signed-off-by: yassinnouh21 --- .../tests/integration/online_store/test_universal_online.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index c50801dd745..523cf700d46 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -220,7 +220,7 @@ def test_write_to_online_store_event_check(environment): # writes to online store via datasource (dataframe_source) materialization fs.materialize( - start_date=datetime.datetime.now() - timedelta(hours=12), + start_date=datetime.now() - timedelta(hours=12), end_date=_utc_now(), )