diff --git a/docs/reference/offline-stores/postgres.md b/docs/reference/offline-stores/postgres.md
index 094ab4885f4..321ddcf25e7 100644
--- a/docs/reference/offline-stores/postgres.md
+++ b/docs/reference/offline-stores/postgres.md
@@ -32,6 +32,7 @@ offline_store:
   sslkey_path: /path/to/client-key.pem
   sslcert_path: /path/to/client-cert.pem
   sslrootcert_path: /path/to/server-ca.pem
+  entity_select_mode: temp_table
 online_store:
   path: data/online_store.db
 ```
@@ -40,6 +41,8 @@ online_store:
 
 Note that `sslmode`, `sslkey_path`, `sslcert_path`, and `sslrootcert_path` are optional parameters. The full set of configuration options is available in [PostgreSQLOfflineStoreConfig](https://rtd.feast.dev/en/master/#feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStoreConfig).
 
+Additionally, the optional parameter `entity_select_mode` controls how PostgreSQL loads the entity data. By default (`temp_table`), a temporary table is created and the entity DataFrame or SQL query is loaded into that table. Setting it to `embed_query` instead embeds the entity SQL query directly in a CTE, which improves performance by skipping the CREATE and DROP of the temporary table.
+
 ## Functionality Matrix
 
 The set of functionality supported by offline stores is described in detail [here](overview.md#functionality).
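
For illustration (not part of the diff): a minimal sketch of exercising `embed_query` from the SDK. It assumes a repo whose `feature_store.yaml` sets `entity_select_mode: embed_query`; the repo path, schema, table, and feature reference are hypothetical.

```python
from feast import FeatureStore

# Hypothetical repo: its feature_store.yaml configures the postgres offline
# store with `entity_select_mode: embed_query`, so entity rows must be
# provided as a SQL string rather than a DataFrame.
store = FeatureStore(repo_path=".")

# Illustrative entity query; the table and feature reference are made up.
entity_sql = """
    SELECT event_timestamp, driver_id
    FROM my_schema.driver_trips
    WHERE event_timestamp BETWEEN '2021-01-01' AND '2021-01-02'
"""

training_df = store.get_historical_features(
    entity_df=entity_sql,
    features=["driver_hourly_stats:conv_rate"],
).to_df()
```
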
diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
index ec6b713941c..2b757019543 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
@@ -1,6 +1,7 @@
 import contextlib
 from dataclasses import asdict
 from datetime import datetime, timezone
+from enum import Enum
 from typing import (
     Any,
     Callable,
@@ -48,8 +49,16 @@
 from .postgres_source import PostgreSQLSource
 
 
+class EntitySelectMode(Enum):
+    temp_table = "temp_table"
+    """ Use a temporary table to store the entity DataFrame or SQL query when querying feature data """
+    embed_query = "embed_query"
+    """ Use the entity SQL query directly when querying feature data """
+
+
 class PostgreSQLOfflineStoreConfig(PostgreSQLConfig):
     type: Literal["postgres"] = "postgres"
+    entity_select_mode: EntitySelectMode = EntitySelectMode.temp_table
 
 
 class PostgreSQLOfflineStore(OfflineStore):
@@ -134,7 +143,17 @@ def get_historical_features(
         def query_generator() -> Iterator[str]:
             table_name = offline_utils.get_temp_entity_table_name()
 
-            _upload_entity_df(config, entity_df, table_name)
+            # If using CTE and entity_df is a SQL query, we don't need a table
+            if config.offline_store.entity_select_mode == EntitySelectMode.embed_query:
+                if isinstance(entity_df, str):
+                    left_table_query_string = entity_df
+                else:
+                    raise ValueError(
+                        f"Invalid entity select mode: {config.offline_store.entity_select_mode} cannot be used with entity_df as a DataFrame"
+                    )
+            else:
+                left_table_query_string = table_name
+                _upload_entity_df(config, entity_df, table_name)
 
             expected_join_keys = offline_utils.get_expected_join_keys(
                 project, feature_views, registry
@@ -163,14 +182,19 @@ def query_generator() -> Iterator[str]:
             try:
                 yield build_point_in_time_query(
                     query_context_dict,
-                    left_table_query_string=table_name,
+                    left_table_query_string=left_table_query_string,
                    entity_df_event_timestamp_col=entity_df_event_timestamp_col,
                     entity_df_columns=entity_schema.keys(),
                     query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
                     full_feature_names=full_feature_names,
+                    entity_select_mode=config.offline_store.entity_select_mode,
                 )
             finally:
-                if table_name:
+                # Only clean up if we created a table
+                if (
+                    config.offline_store.entity_select_mode
+                    == EntitySelectMode.temp_table
+                ):
                     with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
                         cur.execute(
                             sql.SQL(
@@ -362,6 +386,7 @@ def build_point_in_time_query(
     entity_df_columns: KeysView[str],
     query_template: str,
     full_feature_names: bool = False,
+    entity_select_mode: EntitySelectMode = EntitySelectMode.temp_table,
 ) -> str:
     """Build point-in-time query between each feature view table and the entity dataframe for PostgreSQL"""
     template = Environment(loader=BaseLoader()).from_string(source=query_template)
@@ -389,6 +414,7 @@
         "featureviews": feature_view_query_contexts,
         "full_feature_names": full_feature_names,
         "final_output_feature_names": final_output_feature_names,
+        "entity_select_mode": entity_select_mode.value,
     }
 
     query = template.render(template_context)
@@ -429,11 +455,15 @@ def _get_entity_schema(
 
 # https://github.com/feast-dev/feast/blob/master/sdk/python/feast/infra/offline_stores/redshift.py
 MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
+WITH
+{% if entity_select_mode == "embed_query" %}
+    entity_query AS ({{ left_table_query_string }}),
+{% endif %}
 /*
  Compute a deterministic hash for the `left_table_query_string` that will be used throughout
  all the logic as the field to GROUP BY the data
 */
-WITH entity_dataframe AS (
+entity_dataframe AS (
     SELECT *,
         {{entity_df_event_timestamp_col}} AS entity_timestamp
         {% for featureview in featureviews %}
@@ -448,7 +478,12 @@
             ,CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR) AS "{{featureview.name}}__entity_row_unique_id"
             {% endif %}
         {% endfor %}
-    FROM {{ left_table_query_string }}
+    FROM
+    {% if entity_select_mode == "embed_query" %}
+        entity_query
+    {% else %}
+        {{ left_table_query_string }}
+    {% endif %}
 ),
 
 {% for featureview in featureviews %}
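
To make the template change concrete, here is a small self-contained sketch (not part of the diff) of the two rendering paths, using a trimmed-down stand-in for `MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN`; the temp-table name and entity SQL are illustrative. As in the production code, the mode is passed to the template as its string value and compared against `"embed_query"`.

```python
from jinja2 import BaseLoader, Environment

# Trimmed-down stand-in for the real template, keeping only the branch this
# PR adds: embed_query wraps the entity SQL in a CTE, while temp_table
# selects from the uploaded temporary table by name.
TEMPLATE = """WITH
{% if entity_select_mode == "embed_query" %}
    entity_query AS ({{ left_table_query_string }}),
{% endif %}
entity_dataframe AS (
    SELECT *
    FROM
    {% if entity_select_mode == "embed_query" %}
        entity_query
    {% else %}
        {{ left_table_query_string }}
    {% endif %}
)
SELECT * FROM entity_dataframe
"""

env = Environment(loader=BaseLoader())
for mode, left in [
    ("temp_table", "feast_entity_df_1a2b3c"),                # hypothetical temp table name
    ("embed_query", "SELECT * FROM my_schema.my_entities"),  # hypothetical entity SQL
]:
    print(f"-- entity_select_mode = {mode}")
    print(env.from_string(TEMPLATE).render(
        entity_select_mode=mode, left_table_query_string=left
    ))
```

Under `temp_table` the rendered query selects from the uploaded table by name; under `embed_query` the entity SQL is wrapped in the `entity_query` CTE, so no table is ever created or dropped.
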
diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/postgres_offline_store/test_postgres.py b/sdk/python/tests/unit/infra/offline_stores/contrib/postgres_offline_store/test_postgres.py
new file mode 100644
index 00000000000..dd9259c4cf5
--- /dev/null
+++ b/sdk/python/tests/unit/infra/offline_stores/contrib/postgres_offline_store/test_postgres.py
@@ -0,0 +1,390 @@
+import logging
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+
+from feast.entity import Entity
+from feast.feature_view import FeatureView, Field
+from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
+    PostgreSQLOfflineStore,
+    PostgreSQLOfflineStoreConfig,
+)
+from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
+    PostgreSQLSource,
+)
+from feast.infra.offline_stores.offline_store import RetrievalJob
+from feast.repo_config import RepoConfig
+from feast.types import Float32
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
+def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_conn):
+    mock_conn = MagicMock()
+    mock_get_conn.return_value.__enter__.return_value = mock_conn
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=PostgreSQLOfflineStoreConfig(
+            type="postgres",
+            host="localhost",
+            port=5432,
+            database="test_db",
+            db_schema="public",
+            user="test_user",
+            password="test_password",
+        ),
+    )
+
+    test_data_source = PostgreSQLSource(
+        name="test_nested_batch_source",
+        description="test_nested_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="nested_timestamp",
+        field_mapping={
+            "event_header.event_published_datetime_utc": "nested_timestamp",
+        },
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
+    timestamp_field = "event_header.event_published_datetime_utc"
+    created_timestamp_column = "created_timestamp"
+    start_date = datetime(2021, 1, 1, tzinfo=timezone.utc)
+    end_date = datetime(2021, 1, 2, tzinfo=timezone.utc)
+
+    # Call the method
+    retrieval_job = PostgreSQLOfflineStore.pull_latest_from_table_or_query(
+        config=test_repo_config,
+        data_source=test_data_source,
+        join_key_columns=join_key_columns,
+        feature_name_columns=feature_name_columns,
+        timestamp_field=timestamp_field,
+        created_timestamp_column=created_timestamp_column,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    actual_query = retrieval_job.to_sql().strip()
+    logger.debug("Actual query:\n%s", actual_query)
+
+    expected_query = """SELECT
+                b."key1", b."key2", b."feature1", b."feature2", b."event_header.event_published_datetime_utc", b."created_timestamp"
+                
+            FROM (
+                SELECT a."key1", a."key2", a."feature1", a."feature2", a."event_header.event_published_datetime_utc", a."created_timestamp",
+                ROW_NUMBER() OVER(PARTITION BY a."key1", a."key2" ORDER BY a."event_header.event_published_datetime_utc" DESC, a."created_timestamp" DESC) AS _feast_row
+                FROM offline_store_database_name.offline_store_table_name a
+                WHERE a."event_header.event_published_datetime_utc" BETWEEN '2021-01-01 00:00:00+00:00'::timestamptz AND '2021-01-02 00:00:00+00:00'::timestamptz
+            ) b
+            WHERE _feast_row = 1""" # noqa: W293
+
+    logger.debug("Expected query:\n%s", expected_query)
+
+    assert isinstance(retrieval_job, RetrievalJob)
+    assert actual_query == expected_query
+
+
+@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
+def test_pull_latest_from_table_without_nested_timestamp_or_query(mock_get_conn):
+    mock_conn = MagicMock()
+    mock_get_conn.return_value.__enter__.return_value = mock_conn
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=PostgreSQLOfflineStoreConfig(
+            type="postgres",
+            host="localhost",
+            port=5432,
+            database="test_db",
+            db_schema="public",
+            user="test_user",
+            password="test_password",
+        ),
+    )
+
+    test_data_source = PostgreSQLSource(
+        name="test_batch_source",
+        description="test_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="event_published_datetime_utc",
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
+    timestamp_field = "event_published_datetime_utc"
+    created_timestamp_column = "created_timestamp"
+    start_date = datetime(2021, 1, 1, tzinfo=timezone.utc)
+    end_date = datetime(2021, 1, 2, tzinfo=timezone.utc)
+
+    # Call the method
+    retrieval_job = PostgreSQLOfflineStore.pull_latest_from_table_or_query(
+        config=test_repo_config,
+        data_source=test_data_source,
+        join_key_columns=join_key_columns,
+        feature_name_columns=feature_name_columns,
+        timestamp_field=timestamp_field,
+        created_timestamp_column=created_timestamp_column,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    actual_query = retrieval_job.to_sql().strip()
+    logger.debug("Actual query:\n%s", actual_query)
+
+    expected_query = """SELECT
+                b."key1", b."key2", b."feature1", b."feature2", b."event_published_datetime_utc", b."created_timestamp"
+                
+            FROM (
+                SELECT a."key1", a."key2", a."feature1", a."feature2", a."event_published_datetime_utc", a."created_timestamp",
+                ROW_NUMBER() OVER(PARTITION BY a."key1", a."key2" ORDER BY a."event_published_datetime_utc" DESC, a."created_timestamp" DESC) AS _feast_row
+                FROM offline_store_database_name.offline_store_table_name a
+                WHERE a."event_published_datetime_utc" BETWEEN '2021-01-01 00:00:00+00:00'::timestamptz AND '2021-01-02 00:00:00+00:00'::timestamptz
+            ) b
+            WHERE _feast_row = 1""" # noqa: W293
+
+    logger.debug("Expected query:\n%s", expected_query)
+
+    assert isinstance(retrieval_job, RetrievalJob)
+    assert actual_query == expected_query
+
+
+@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
+def test_pull_all_from_table_or_query(mock_get_conn):
+    mock_conn = MagicMock()
+    mock_get_conn.return_value.__enter__.return_value = mock_conn
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=PostgreSQLOfflineStoreConfig(
+            type="postgres",
+            host="localhost",
+            port=5432,
+            database="test_db",
+            db_schema="public",
+            user="test_user",
+            password="test_password",
+        ),
+    )
+
+    test_data_source = PostgreSQLSource(
+        name="test_batch_source",
+        description="test_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="event_published_datetime_utc",
+    )
+
+    # Define the parameters for the method
+    join_key_columns = ["key1", "key2"]
+    feature_name_columns = ["feature1", "feature2"]
+    timestamp_field = "event_published_datetime_utc"
+    start_date = datetime(2021, 1, 1, tzinfo=timezone.utc)
+    end_date = datetime(2021, 1, 2, tzinfo=timezone.utc)
+
+    # Call the method
+    retrieval_job = PostgreSQLOfflineStore.pull_all_from_table_or_query(
+        config=test_repo_config,
+        data_source=test_data_source,
+        join_key_columns=join_key_columns,
+        feature_name_columns=feature_name_columns,
+        timestamp_field=timestamp_field,
+        start_date=start_date,
+        end_date=end_date,
+    )
+
+    actual_query = retrieval_job.to_sql().strip()
+    logger.debug("Actual query:\n%s", actual_query)
+
+    expected_query = """SELECT key1, key2, feature1, feature2, event_published_datetime_utc
+            FROM offline_store_database_name.offline_store_table_name AS paftoq_alias
+            WHERE "event_published_datetime_utc" BETWEEN '2021-01-01 00:00:00+00:00'::timestamptz AND '2021-01-02 00:00:00+00:00'::timestamptz""" # noqa: W293
+
+    logger.debug("Expected query:\n%s", expected_query)
+
+    assert isinstance(retrieval_job, RetrievalJob)
+    assert actual_query == expected_query
+
+
+@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
+@patch(
+    "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"
+)
+@patch(
+    "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.get_query_schema"
+)
+def test_get_historical_features_entity_select_modes(
+    mock_get_query_schema, mock_df_to_postgres_table, mock_get_conn
+):
+    mock_conn = MagicMock()
+    mock_get_conn.return_value.__enter__.return_value = mock_conn
+
+    # Mock the query schema to return a simple schema
+    mock_get_query_schema.return_value = {
+        "event_timestamp": pd.Timestamp,
+        "driver_id": pd.Int64Dtype(),
+    }
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=PostgreSQLOfflineStoreConfig(
+            type="postgres",
+            host="localhost",
+            port=5432,
+            database="test_db",
+            db_schema="public",
+            user="test_user",
+            password="test_password",
+        ),
+    )
+
+    test_data_source = PostgreSQLSource(
+        name="test_batch_source",
+        description="test_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="event_published_datetime_utc",
+    )
+
+    test_feature_view = FeatureView(
+        name="test_feature_view",
+        entities=[
+            Entity(
+                name="driver_id",
+                join_keys=["driver_id"],
+                description="Driver ID",
+            )
+        ],
+        schema=[
+            Field(name="feature1", dtype=Float32),
+        ],
+        source=test_data_source,
+    )
+
+    mock_registry = MagicMock()
+    mock_registry.get_feature_view.return_value = test_feature_view
+
+    # Create a DataFrame with the required event_timestamp column
+    entity_df = pd.DataFrame(
+        {"event_timestamp": [datetime(2021, 1, 1)], "driver_id": [1]}
+    )
+
+    retrieval_job = PostgreSQLOfflineStore.get_historical_features(
+        config=test_repo_config,
+        feature_views=[test_feature_view],
+        feature_refs=["test_feature_view:feature1"],
+        entity_df=entity_df,
+        registry=mock_registry,
+        project="test_project",
+    )
+
+    actual_query = retrieval_job.to_sql().strip()
+    logger.debug("Actual query:\n%s", actual_query)
+
+    # Check that the query starts with WITH and contains the expected comment block
+    assert actual_query.startswith("""WITH
+
+/*
+ Compute a deterministic hash for the `left_table_query_string` that will be used throughout
+ all the logic as the field to GROUP BY the data
+*/""")
+
+
+@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
+@patch(
+    "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"
+)
+@patch(
+    "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.get_query_schema"
+)
+def test_get_historical_features_entity_select_modes_embed_query(
+    mock_get_query_schema, mock_df_to_postgres_table, mock_get_conn
+):
+    mock_conn = MagicMock()
+    mock_get_conn.return_value.__enter__.return_value = mock_conn
+
+    # Mock the query schema to return a simple schema
+    mock_get_query_schema.return_value = {
+        "event_timestamp": pd.Timestamp,
+        "driver_id": pd.Int64Dtype(),
+    }
+
+    test_repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=PostgreSQLOfflineStoreConfig(
+            type="postgres",
+            host="localhost",
+            port=5432,
+            database="test_db",
+            db_schema="public",
+            user="test_user",
+            password="test_password",
+            entity_select_mode="embed_query",
+        ),
+    )
+
+    test_data_source = PostgreSQLSource(
+        name="test_batch_source",
+        description="test_batch_source",
+        table="offline_store_database_name.offline_store_table_name",
+        timestamp_field="event_published_datetime_utc",
+    )
+
+    test_feature_view = FeatureView(
+        name="test_feature_view",
+        entities=[
+            Entity(
+                name="driver_id",
+                join_keys=["driver_id"],
+                description="Driver ID",
+            )
+        ],
+        schema=[
+            Field(name="feature1", dtype=Float32),
+        ],
+        source=test_data_source,
+    )
+
+    mock_registry = MagicMock()
+    mock_registry.get_feature_view.return_value = test_feature_view
+
+    # Use a SQL query string instead of a DataFrame for embed_query mode
+    entity_df = """
+    SELECT
+        event_timestamp,
+        driver_id
+    FROM (
+        VALUES
+            ('2021-01-01'::timestamp, 1)
+    ) AS t(event_timestamp, driver_id)
+    """
+
+    retrieval_job = PostgreSQLOfflineStore.get_historical_features(
+        config=test_repo_config,
+        feature_views=[test_feature_view],
+        feature_refs=["test_feature_view:feature1"],
+        entity_df=entity_df,
+        registry=mock_registry,
+        project="test_project",
+    )
+
+    actual_query = retrieval_job.to_sql().strip()
+    logger.debug("Actual query:\n%s", actual_query)
+
+    # Check that the query starts with WITH and embeds the entity query in a CTE
+    assert actual_query.startswith("""WITH
+
+    entity_query AS (""")
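
One guard in `query_generator()` that the new tests leave uncovered is the `ValueError` raised when `embed_query` is combined with a DataFrame `entity_df`. A hedged sketch of such a test follows (not part of the diff): the test name is hypothetical, the setup mirrors the `embed_query` test above, and it assumes the error surfaces when `to_sql()` renders the lazily built query.

```python
from datetime import datetime
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from feast.entity import Entity
from feast.feature_view import FeatureView, Field
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
    PostgreSQLOfflineStore,
    PostgreSQLOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
    PostgreSQLSource,
)
from feast.repo_config import RepoConfig
from feast.types import Float32


@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
def test_embed_query_rejects_dataframe_entity_df(mock_get_conn):
    mock_get_conn.return_value.__enter__.return_value = MagicMock()

    test_repo_config = RepoConfig(
        project="test_project",
        registry="test_registry",
        provider="local",
        offline_store=PostgreSQLOfflineStoreConfig(
            type="postgres",
            host="localhost",
            port=5432,
            database="test_db",
            db_schema="public",
            user="test_user",
            password="test_password",
            entity_select_mode="embed_query",
        ),
    )

    test_data_source = PostgreSQLSource(
        name="test_batch_source",
        description="test_batch_source",
        table="offline_store_database_name.offline_store_table_name",
        timestamp_field="event_published_datetime_utc",
    )

    test_feature_view = FeatureView(
        name="test_feature_view",
        entities=[
            Entity(name="driver_id", join_keys=["driver_id"], description="Driver ID")
        ],
        schema=[Field(name="feature1", dtype=Float32)],
        source=test_data_source,
    )

    mock_registry = MagicMock()
    mock_registry.get_feature_view.return_value = test_feature_view

    # A DataFrame is not allowed under embed_query; only a SQL string is.
    entity_df = pd.DataFrame(
        {"event_timestamp": [datetime(2021, 1, 1)], "driver_id": [1]}
    )

    retrieval_job = PostgreSQLOfflineStore.get_historical_features(
        config=test_repo_config,
        feature_views=[test_feature_view],
        feature_refs=["test_feature_view:feature1"],
        entity_df=entity_df,
        registry=mock_registry,
        project="test_project",
    )

    # Rendering the job iterates query_generator(), which raises before any
    # SQL is built or any temp table is touched.
    with pytest.raises(ValueError, match="Invalid entity select mode"):
        retrieval_job.to_sql()
```
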