From 6ae5adaf417086518eb74b42ee6ba56491df15ea Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Thu, 13 Nov 2025 14:19:37 +0530 Subject: [PATCH 1/3] feat: Offline Store historical features retrieval based on datetime range for spark Signed-off-by: Aniket Paluskar --- sdk/python/feast/arrow_error_handler.py | 3 + .../contrib/spark_offline_store/spark.py | 155 ++++++++++++++---- .../spark_offline_store/spark_source.py | 3 +- .../test_spark_non_entity.py | 79 +++++++++ 4 files changed, 208 insertions(+), 32 deletions(-) create mode 100644 sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py diff --git a/sdk/python/feast/arrow_error_handler.py b/sdk/python/feast/arrow_error_handler.py index e873592bd5d..e4862bb0982 100644 --- a/sdk/python/feast/arrow_error_handler.py +++ b/sdk/python/feast/arrow_error_handler.py @@ -30,6 +30,9 @@ def wrapper(*args, **kwargs): except Exception as e: if isinstance(e, FeastError): raise fl.FlightError(e.to_error_detail()) + # Re-raise non-Feast exceptions so Arrow Flight returns a proper error + # instead of allowing the server method to return None. + raise e return wrapper diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index f1ba4baa939..173c57f3df1 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -3,12 +3,13 @@ import uuid import warnings from dataclasses import asdict, dataclass -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import ( TYPE_CHECKING, Any, Callable, Dict, + KeysView, List, Optional, Tuple, @@ -151,10 +152,11 @@ def get_historical_features( config: RepoConfig, feature_views: List[FeatureView], feature_refs: List[str], - entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame], + entity_df: Optional[Union[pandas.DataFrame, str, pyspark.sql.DataFrame]], registry: BaseRegistry, project: str, full_feature_names: bool = False, + **kwargs, ) -> RetrievalJob: assert isinstance(config.offline_store, SparkOfflineStoreConfig) date_partition_column_formats = [] @@ -175,33 +177,124 @@ def get_historical_features( ) tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name() - entity_schema = _get_entity_schema( - spark_session=spark_session, - entity_df=entity_df, - ) - event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema=entity_schema, - ) - entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - event_timestamp_col, - spark_session, - ) - _upload_entity_df( - spark_session=spark_session, - table_name=tmp_entity_df_table_name, - entity_df=entity_df, - event_timestamp_col=event_timestamp_col, - ) + # Non-entity mode: synthesize a left table and timestamp range from start/end dates to avoid requiring entity_df. + # This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time. 
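+        # If start_date/end_date are not supplied, end_date defaults to the current UTC time and
+        # start_date to end_date minus the largest feature-view TTL (or 30 days when no TTL is set).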
+ non_entity_mode = entity_df is None + if non_entity_mode: + start_date: Optional[datetime] = kwargs.get("start_date") + end_date: Optional[datetime] = kwargs.get("end_date") + + end_date = end_date or datetime.now(timezone.utc) + if start_date is None: + max_ttl_seconds = 0 + for fv in feature_views: + if fv.ttl and isinstance(fv.ttl, timedelta): + max_ttl_seconds = max( + max_ttl_seconds, int(fv.ttl.total_seconds()) + ) + start_date = ( + end_date - timedelta(seconds=max_ttl_seconds) + if max_ttl_seconds > 0 + else end_date - timedelta(days=30) + ) - expected_join_keys = offline_utils.get_expected_join_keys( - project=project, feature_views=feature_views, registry=registry - ) - offline_utils.assert_expected_columns_in_entity_df( - entity_schema=entity_schema, - join_keys=expected_join_keys, - entity_df_event_timestamp_col=event_timestamp_col, - ) + # Build query contexts so we can reuse entity names and per-view table info consistently. + entity_df_event_timestamp_range = (start_date, end_date) + fv_query_contexts = offline_utils.get_feature_view_query_context( + feature_refs, + feature_views, + registry, + project, + entity_df_event_timestamp_range, + ) + + # Collect the union of entity columns required across all feature views. + all_entities: List[str] = [] + for ctx in fv_query_contexts: + for e in ctx.entities: + if e not in all_entities: + all_entities.append(e) + + # Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned. + start_date_str = _format_datetime(start_date) + end_date_str = _format_datetime(end_date) + per_view_selects: List[str] = [] + for fv, ctx, date_format in zip( + feature_views, fv_query_contexts, date_partition_column_formats + ): + from_expression = fv.batch_source.get_table_query_string() + timestamp_field = fv.batch_source.timestamp_field or "event_timestamp" + date_partition_column = fv.batch_source.date_partition_column + partition_clause = "" + if date_partition_column: + partition_clause = ( + f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'" + f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'" + ) + # Select all required entity columns, filling missing ones with NULL to keep UNION schemas aligned. + select_entities = [] + ctx_entities_set = set(ctx.entities) + for col in all_entities: + if col in ctx_entities_set: + # Cast entity columns to STRING to guarantee UNION schema alignment across sources. + select_entities.append(f"CAST({col} AS STRING) AS {col}") + else: + select_entities.append(f"CAST(NULL AS STRING) AS {col}") + + per_view_selects.append( + f""" + SELECT DISTINCT {", ".join(select_entities)} + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause} + """ + ) + + union_query = "\nUNION DISTINCT\n".join( + [s.strip() for s in per_view_selects] + ) + spark_session.sql( + f"CREATE OR REPLACE TEMPORARY VIEW {tmp_entity_df_table_name} AS {union_query}" + ) + + # Add a stable as-of timestamp column for PIT joins. + left_table_query_string = f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS entity_ts FROM {tmp_entity_df_table_name})" + event_timestamp_col = "entity_ts" + # Why: Keep type consistent with entity_df branch (dict KeysView[str]) to satisfy typing and downstream usage. 
+ entity_schema_keys = cast( + KeysView[str], + {k: None for k in (all_entities + [event_timestamp_col])}.keys(), + ) + else: + entity_schema = _get_entity_schema( + spark_session=spark_session, + entity_df=entity_df, + ) + event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, + event_timestamp_col, + spark_session, + ) + _upload_entity_df( + spark_session=spark_session, + table_name=tmp_entity_df_table_name, + entity_df=entity_df, + event_timestamp_col=event_timestamp_col, + ) + left_table_query_string = tmp_entity_df_table_name + entity_schema_keys = cast(KeysView[str], entity_schema.keys()) + + if not non_entity_mode: + expected_join_keys = offline_utils.get_expected_join_keys( + project=project, feature_views=feature_views, registry=registry + ) + offline_utils.assert_expected_columns_in_entity_df( + entity_schema=entity_schema, + join_keys=expected_join_keys, + entity_df_event_timestamp_col=event_timestamp_col, + ) query_context = offline_utils.get_feature_view_query_context( feature_refs, @@ -232,9 +325,9 @@ def get_historical_features( feature_view_query_contexts=cast( List[offline_utils.FeatureViewQueryContext], spark_query_context ), - left_table_query_string=tmp_entity_df_table_name, + left_table_query_string=left_table_query_string, entity_df_event_timestamp_col=event_timestamp_col, - entity_df_columns=entity_schema.keys(), + entity_df_columns=entity_schema_keys, query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, full_feature_names=full_feature_names, ) @@ -248,7 +341,7 @@ def get_historical_features( ), metadata=RetrievalMetadata( features=feature_refs, - keys=list(set(entity_schema.keys()) - {event_timestamp_col}), + keys=list(set(entity_schema_keys) - {event_timestamp_col}), min_event_timestamp=entity_df_event_timestamp_range[0], max_event_timestamp=entity_df_event_timestamp_range[1], ), diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 6f2af7054b4..cd41921e56a 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -229,7 +229,8 @@ def get_table_query_string(self) -> str: # If both the table query string and the actual query are null, we can load from file. spark_session = SparkSession.getActiveSession() if spark_session is None: - raise AssertionError("Could not find an active spark session.") + # Remote mode may not have an active session bound to the thread; create one on demand. 
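+            # getOrCreate() reuses the JVM's default SparkSession if one exists, otherwise it
+            # creates a new session from the current builder configuration.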
+ spark_session = SparkSession.builder.getOrCreate() try: df = self._load_dataframe_from_path(spark_session) except Exception: diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py new file mode 100644 index 00000000000..3e2465ddb43 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py @@ -0,0 +1,79 @@ +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +from feast.entity import Entity +from feast.feature_view import FeatureView, Field +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStore, + SparkOfflineStoreConfig, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) +from feast.repo_config import RepoConfig +from feast.types import Float32, ValueType + + +def _mock_spark_offline_store_config(): + return SparkOfflineStoreConfig(type="spark") + + +def _mock_entity(): + return [ + Entity( + name="user_id", + join_keys=["user_id"], + description="User ID", + value_type=ValueType.INT64, + ) + ] + + +def _mock_feature_view(): + return FeatureView( + name="user_stats", + entities=_mock_entity(), + schema=[ + Field(name="metric", dtype=Float32), + ], + source=SparkSource( + name="user_stats_source", + table="default.user_stats", + timestamp_field="event_timestamp", + date_partition_column="ds", + date_partition_column_format="%Y-%m-%d", + ), + ) + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def test_spark_non_entity_historical_retrieval_accepts_dates(mock_get_spark_session): + # Why: Avoid executing real Spark SQL against non-existent tables during unit tests. 
+ mock_spark_session = MagicMock() + mock_get_spark_session.return_value = mock_spark_session + repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=_mock_spark_offline_store_config(), + ) + + fv = _mock_feature_view() + + retrieval_job = SparkOfflineStore.get_historical_features( + config=repo_config, + feature_views=[fv], + feature_refs=["user_stats:metric"], + entity_df=None, # start/end-only mode + registry=MagicMock(), + project="test_project", + full_feature_names=False, + start_date=datetime(2023, 1, 1, tzinfo=timezone.utc), + end_date=datetime(2023, 1, 2, tzinfo=timezone.utc), + ) + + from feast.infra.offline_stores.offline_store import RetrievalJob + + assert isinstance(retrieval_job, RetrievalJob) From 4fbf122ed3cdba4cb85aa11e9405f4c894001e7f Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Sun, 23 Nov 2025 21:04:25 +0530 Subject: [PATCH 2/3] Restructured code, extended existing test cases Signed-off-by: Aniket Paluskar --- .../contrib/spark_offline_store/spark.py | 186 +++++++++++------- .../contrib/spark_offline_store/test_spark.py | 165 ++++++++++++++++ .../test_spark_non_entity.py | 79 -------- 3 files changed, 285 insertions(+), 145 deletions(-) delete mode 100644 sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 173c57f3df1..7f20e937e3b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -181,25 +181,11 @@ def get_historical_features( # This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time. non_entity_mode = entity_df is None if non_entity_mode: - start_date: Optional[datetime] = kwargs.get("start_date") - end_date: Optional[datetime] = kwargs.get("end_date") - - end_date = end_date or datetime.now(timezone.utc) - if start_date is None: - max_ttl_seconds = 0 - for fv in feature_views: - if fv.ttl and isinstance(fv.ttl, timedelta): - max_ttl_seconds = max( - max_ttl_seconds, int(fv.ttl.total_seconds()) - ) - start_date = ( - end_date - timedelta(seconds=max_ttl_seconds) - if max_ttl_seconds > 0 - else end_date - timedelta(days=30) - ) + # Why: derive bounded time window without requiring entities; uses max TTL fallback to constrain scans. + start_date, end_date = _compute_non_entity_dates(feature_views, kwargs) + entity_df_event_timestamp_range = (start_date, end_date) # Build query contexts so we can reuse entity names and per-view table info consistently. - entity_df_event_timestamp_range = (start_date, end_date) fv_query_contexts = offline_utils.get_feature_view_query_context( feature_refs, feature_views, @@ -209,60 +195,25 @@ def get_historical_features( ) # Collect the union of entity columns required across all feature views. - all_entities: List[str] = [] - for ctx in fv_query_contexts: - for e in ctx.entities: - if e not in all_entities: - all_entities.append(e) + all_entities = _gather_all_entities(fv_query_contexts) # Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned. 
- start_date_str = _format_datetime(start_date) - end_date_str = _format_datetime(end_date) - per_view_selects: List[str] = [] - for fv, ctx, date_format in zip( - feature_views, fv_query_contexts, date_partition_column_formats - ): - from_expression = fv.batch_source.get_table_query_string() - timestamp_field = fv.batch_source.timestamp_field or "event_timestamp" - date_partition_column = fv.batch_source.date_partition_column - partition_clause = "" - if date_partition_column: - partition_clause = ( - f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'" - f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'" - ) - # Select all required entity columns, filling missing ones with NULL to keep UNION schemas aligned. - select_entities = [] - ctx_entities_set = set(ctx.entities) - for col in all_entities: - if col in ctx_entities_set: - # Cast entity columns to STRING to guarantee UNION schema alignment across sources. - select_entities.append(f"CAST({col} AS STRING) AS {col}") - else: - select_entities.append(f"CAST(NULL AS STRING) AS {col}") - - per_view_selects.append( - f""" - SELECT DISTINCT {", ".join(select_entities)} - FROM {from_expression} - WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause} - """ - ) - - union_query = "\nUNION DISTINCT\n".join( - [s.strip() for s in per_view_selects] - ) - spark_session.sql( - f"CREATE OR REPLACE TEMPORARY VIEW {tmp_entity_df_table_name} AS {union_query}" + _create_temp_entity_union_view( + spark_session=spark_session, + tmp_view_name=tmp_entity_df_table_name, + feature_views=feature_views, + fv_query_contexts=fv_query_contexts, + start_date=start_date, + end_date=end_date, + date_partition_column_formats=date_partition_column_formats, ) # Add a stable as-of timestamp column for PIT joins. - left_table_query_string = f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS entity_ts FROM {tmp_entity_df_table_name})" - event_timestamp_col = "entity_ts" - # Why: Keep type consistent with entity_df branch (dict KeysView[str]) to satisfy typing and downstream usage. - entity_schema_keys = cast( - KeysView[str], - {k: None for k in (all_entities + [event_timestamp_col])}.keys(), + left_table_query_string, event_timestamp_col = _make_left_table_query( + end_date=end_date, tmp_view_name=tmp_entity_df_table_name + ) + entity_schema_keys = _entity_schema_keys_from( + all_entities=all_entities, event_timestamp_col=event_timestamp_col ) else: entity_schema = _get_entity_schema( @@ -633,6 +584,109 @@ def get_spark_session_or_start_new_with_repoconfig( return spark_session +def _compute_non_entity_dates( + feature_views: List[FeatureView], kwargs: Dict[str, Any] +) -> Tuple[datetime, datetime]: + # Why: bounds the scan window when no entity_df is provided using explicit dates or max TTL fallback. 
+ start_date: Optional[datetime] = kwargs.get("start_date") + end_date: Optional[datetime] = kwargs.get("end_date") or datetime.now(timezone.utc) + + if start_date is None: + max_ttl_seconds = 0 + for fv in feature_views: + if fv.ttl and isinstance(fv.ttl, timedelta): + max_ttl_seconds = max(max_ttl_seconds, int(fv.ttl.total_seconds())) + start_date = ( + end_date - timedelta(seconds=max_ttl_seconds) + if max_ttl_seconds > 0 + else end_date - timedelta(days=30) + ) + return start_date, end_date + + +def _gather_all_entities( + fv_query_contexts: List[offline_utils.FeatureViewQueryContext], +) -> List[str]: + # Why: ensure a unified entity set across feature views to align UNION schemas. + all_entities: List[str] = [] + for ctx in fv_query_contexts: + for e in ctx.entities: + if e not in all_entities: + all_entities.append(e) + return all_entities + + +def _create_temp_entity_union_view( + spark_session: SparkSession, + tmp_view_name: str, + feature_views: List[FeatureView], + fv_query_contexts: List[offline_utils.FeatureViewQueryContext], + start_date: datetime, + end_date: datetime, + date_partition_column_formats: List[Optional[str]], +) -> None: + # Why: derive distinct entity keys observed in the time window without requiring an entity_df upfront. + start_date_str = _format_datetime(start_date) + end_date_str = _format_datetime(end_date) + + # Compute the unified entity set to align schemas in the UNION. + all_entities = _gather_all_entities(fv_query_contexts) + + per_view_selects: List[str] = [] + for fv, ctx, date_format in zip( + feature_views, fv_query_contexts, date_partition_column_formats + ): + assert isinstance(fv.batch_source, SparkSource) + from_expression = fv.batch_source.get_table_query_string() + timestamp_field = fv.batch_source.timestamp_field or "event_timestamp" + date_partition_column = fv.batch_source.date_partition_column + partition_clause = "" + if date_partition_column and date_format: + partition_clause = ( + f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'" + f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'" + ) + + # Fill missing entity columns with NULL and cast to STRING to keep UNION schemas aligned. + select_entities: List[str] = [] + ctx_entities_set = set(ctx.entities) + for col in all_entities: + if col in ctx_entities_set: + select_entities.append(f"CAST({col} AS STRING) AS {col}") + else: + select_entities.append(f"CAST(NULL AS STRING) AS {col}") + + per_view_selects.append( + f""" + SELECT DISTINCT {", ".join(select_entities)} + FROM {from_expression} + WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause} + """ + ) + + union_query = "\nUNION DISTINCT\n".join([s.strip() for s in per_view_selects]) + spark_session.sql( + f"CREATE OR REPLACE TEMPORARY VIEW {tmp_view_name} AS {union_query}" + ) + + +def _make_left_table_query(end_date: datetime, tmp_view_name: str) -> Tuple[str, str]: + # Why: use a stable as-of timestamp for PIT joins when no entity timestamps are provided. + event_timestamp_col = "entity_ts" + left_table_query_string = ( + f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS {event_timestamp_col} " + f"FROM {tmp_view_name})" + ) + return left_table_query_string, event_timestamp_col + + +def _entity_schema_keys_from( + all_entities: List[str], event_timestamp_col: str +) -> KeysView[str]: + # Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing. 
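+    # The dict is a throwaway; only its ordered key view (entities first, then the
+    # timestamp column) is consumed downstream.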
+ return cast(KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys()) + + def _get_entity_df_event_timestamp_range( entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py index 938514a2ca0..22c75ebf387 100644 --- a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py @@ -339,3 +339,168 @@ def _mock_entity(): value_type=ValueType.INT64, ) ] + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def test_get_historical_features_non_entity_with_date_range(mock_get_spark_session): + mock_spark_session = MagicMock() + # Return a DataFrame for any sql call; last call is used by RetrievalJob + final_df = MagicMock() + expected_pdf = pd.DataFrame([{"feature1": 1.0, "feature2": 2.0}]) + final_df.toPandas.return_value = expected_pdf + mock_spark_session.sql.return_value = final_df + mock_get_spark_session.return_value = mock_spark_session + + test_repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + test_data_source1 = SparkSource( + name="test_nested_batch_source1", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name1", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + date_partition_column_format="%Y%m%d", + ) + + test_data_source2 = SparkSource( + name="test_nested_batch_source2", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name2", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + ) + + test_feature_view1 = FeatureView( + name="test_feature_view1", + entities=_mock_entity(), + schema=[ + Field(name="feature1", dtype=Float32), + ], + source=test_data_source1, + ) + + test_feature_view2 = FeatureView( + name="test_feature_view2", + entities=_mock_entity(), + schema=[ + Field(name="feature2", dtype=Float32), + ], + source=test_data_source2, + ) + + mock_registry = MagicMock() + start_date = datetime(2021, 1, 1) + end_date = datetime(2021, 1, 2) + retrieval_job = SparkOfflineStore.get_historical_features( + config=test_repo_config, + feature_views=[test_feature_view2, test_feature_view1], + feature_refs=["test_feature_view2:feature2", "test_feature_view1:feature1"], + entity_df=None, + registry=mock_registry, + project="test_project", + start_date=start_date, + end_date=end_date, + ) + + # Verify query bounded by end_date correctly in both date formats from the two sources + query = retrieval_job.query + assert "effective_date <= '2021-01-02'" in query + assert "effective_date <= '20210102'" in query + + # Verify data: the mocked Spark DataFrame flows through to Pandas + pdf = retrieval_job._to_df_internal() + assert pdf.equals(expected_pdf) + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" +) +def 
test_get_historical_features_non_entity_with_only_end_date(mock_get_spark_session): + mock_spark_session = MagicMock() + final_df = MagicMock() + expected_pdf = pd.DataFrame([{"feature1": 10.0, "feature2": 20.0}]) + final_df.toPandas.return_value = expected_pdf + mock_spark_session.sql.return_value = final_df + mock_get_spark_session.return_value = mock_spark_session + + test_repo_config = RepoConfig( + project="test_project", + registry="test_registry", + provider="local", + offline_store=SparkOfflineStoreConfig(type="spark"), + ) + + test_data_source1 = SparkSource( + name="test_nested_batch_source1", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name1", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + date_partition_column_format="%Y%m%d", + ) + + test_data_source2 = SparkSource( + name="test_nested_batch_source2", + description="test_nested_batch_source", + table="offline_store_database_name.offline_store_table_name2", + timestamp_field="nested_timestamp", + field_mapping={ + "event_header.event_published_datetime_utc": "nested_timestamp", + }, + date_partition_column="effective_date", + ) + + test_feature_view1 = FeatureView( + name="test_feature_view1", + entities=_mock_entity(), + schema=[ + Field(name="feature1", dtype=Float32), + ], + source=test_data_source1, + ) + + test_feature_view2 = FeatureView( + name="test_feature_view2", + entities=_mock_entity(), + schema=[ + Field(name="feature2", dtype=Float32), + ], + source=test_data_source2, + ) + + mock_registry = MagicMock() + end_date = datetime(2021, 1, 2) + retrieval_job = SparkOfflineStore.get_historical_features( + config=test_repo_config, + feature_views=[test_feature_view2, test_feature_view1], + feature_refs=["test_feature_view2:feature2", "test_feature_view1:feature1"], + entity_df=None, + registry=mock_registry, + project="test_project", + end_date=end_date, + ) + + # Verify query bounded by end_date correctly for both sources + query = retrieval_job.query + assert "effective_date <= '2021-01-02'" in query + assert "effective_date <= '20210102'" in query + + # Verify data: mocked DataFrame flows to Pandas + pdf = retrieval_job._to_df_internal() + assert pdf.equals(expected_pdf) diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py deleted file mode 100644 index 3e2465ddb43..00000000000 --- a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_non_entity.py +++ /dev/null @@ -1,79 +0,0 @@ -from datetime import datetime, timezone -from unittest.mock import MagicMock, patch - -from feast.entity import Entity -from feast.feature_view import FeatureView, Field -from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( - SparkOfflineStore, - SparkOfflineStoreConfig, -) -from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( - SparkSource, -) -from feast.repo_config import RepoConfig -from feast.types import Float32, ValueType - - -def _mock_spark_offline_store_config(): - return SparkOfflineStoreConfig(type="spark") - - -def _mock_entity(): - return [ - Entity( - name="user_id", - join_keys=["user_id"], - description="User ID", - value_type=ValueType.INT64, - ) - ] - - -def _mock_feature_view(): - return FeatureView( - 
name="user_stats", - entities=_mock_entity(), - schema=[ - Field(name="metric", dtype=Float32), - ], - source=SparkSource( - name="user_stats_source", - table="default.user_stats", - timestamp_field="event_timestamp", - date_partition_column="ds", - date_partition_column_format="%Y-%m-%d", - ), - ) - - -@patch( - "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig" -) -def test_spark_non_entity_historical_retrieval_accepts_dates(mock_get_spark_session): - # Why: Avoid executing real Spark SQL against non-existent tables during unit tests. - mock_spark_session = MagicMock() - mock_get_spark_session.return_value = mock_spark_session - repo_config = RepoConfig( - project="test_project", - registry="test_registry", - provider="local", - offline_store=_mock_spark_offline_store_config(), - ) - - fv = _mock_feature_view() - - retrieval_job = SparkOfflineStore.get_historical_features( - config=repo_config, - feature_views=[fv], - feature_refs=["user_stats:metric"], - entity_df=None, # start/end-only mode - registry=MagicMock(), - project="test_project", - full_feature_names=False, - start_date=datetime(2023, 1, 1, tzinfo=timezone.utc), - end_date=datetime(2023, 1, 2, tzinfo=timezone.utc), - ) - - from feast.infra.offline_stores.offline_store import RetrievalJob - - assert isinstance(retrieval_job, RetrievalJob) From f39ece84423d4fdec1a389b1e533f2f81d512c76 Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Sun, 23 Nov 2025 21:09:24 +0530 Subject: [PATCH 3/3] Fixed lint issues Signed-off-by: Aniket Paluskar --- .../contrib/spark_offline_store/spark.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 7f20e937e3b..47e76a014f0 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -588,20 +588,23 @@ def _compute_non_entity_dates( feature_views: List[FeatureView], kwargs: Dict[str, Any] ) -> Tuple[datetime, datetime]: # Why: bounds the scan window when no entity_df is provided using explicit dates or max TTL fallback. - start_date: Optional[datetime] = kwargs.get("start_date") - end_date: Optional[datetime] = kwargs.get("end_date") or datetime.now(timezone.utc) + start_date_opt = cast(Optional[datetime], kwargs.get("start_date")) + end_date_opt = cast(Optional[datetime], kwargs.get("end_date")) + end_date: datetime = end_date_opt or datetime.now(timezone.utc) - if start_date is None: + if start_date_opt is None: max_ttl_seconds = 0 for fv in feature_views: if fv.ttl and isinstance(fv.ttl, timedelta): max_ttl_seconds = max(max_ttl_seconds, int(fv.ttl.total_seconds())) - start_date = ( + start_date: datetime = ( end_date - timedelta(seconds=max_ttl_seconds) if max_ttl_seconds > 0 else end_date - timedelta(days=30) ) - return start_date, end_date + else: + start_date = start_date_opt + return (start_date, end_date) def _gather_all_entities( @@ -684,7 +687,9 @@ def _entity_schema_keys_from( all_entities: List[str], event_timestamp_col: str ) -> KeysView[str]: # Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing. 
- return cast(KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys()) + return cast( + KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys() + ) def _get_entity_df_event_timestamp_range(
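
Usage sketch for the new non-entity mode (illustrative, not part of the patch): the call below mirrors the pattern exercised by the new unit tests. The repo path, the "user_stats" view, and its "metric" feature are illustrative assumptions; a real run still needs a Spark session reachable from the configured offline store and a registered feature view backed by an existing table.

    from datetime import datetime, timezone

    from feast import FeatureStore
    from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
        SparkOfflineStore,
    )

    # Assumes a feature repo configured with the Spark offline store.
    store = FeatureStore(repo_path=".")

    job = SparkOfflineStore.get_historical_features(
        config=store.config,
        feature_views=[store.get_feature_view("user_stats")],  # illustrative view name
        feature_refs=["user_stats:metric"],                    # illustrative feature ref
        entity_df=None,  # non-entity mode: no entity dataframe required
        registry=store.registry,
        project=store.project,
        full_feature_names=False,
        # Bound the scan window. If start_date is omitted it defaults to end_date minus
        # the largest feature-view TTL (or 30 days when no view defines a TTL); if
        # end_date is omitted it defaults to the current UTC time.
        start_date=datetime(2024, 1, 1, tzinfo=timezone.utc),
        end_date=datetime(2024, 1, 2, tzinfo=timezone.utc),
    )
    df = job.to_df()

Internally, the store synthesizes the left table as a UNION DISTINCT of the entity keys observed in each source within [start_date, end_date], stamps every row with end_date as the as-of timestamp, and then runs the usual point-in-time join, so the result holds each observed entity's latest feature values as of end_date (subject to each view's TTL).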