From 1b5aebef4ce5d4794e3cb0675ce1450352f358e4 Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Wed, 26 Nov 2025 00:44:38 +0530 Subject: [PATCH 1/5] feat: Offline Store historical features retrieval based on datetime range in Ray Signed-off-by: Aniket Paluskar --- .../contrib/ray_offline_store/ray.py | 167 +++++++++++++++++- 1 file changed, 163 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 98247c6c0e0..378748f4ca4 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -1,7 +1,7 @@ import logging import os import uuid -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -1197,6 +1197,146 @@ def schema(self) -> pa.Schema: return pa.Table.from_pandas(df).schema +def _compute_non_entity_dates_ray( + feature_views: List[FeatureView], + start_date_opt: Optional[datetime], + end_date_opt: Optional[datetime], +) -> Tuple[datetime, datetime]: + # Why: derive bounded time window when no entity_df is provided using explicit dates or max TTL fallback + end_date = make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow()) + if start_date_opt is None: + max_ttl_seconds = 0 + for fv in feature_views: + if getattr(fv, "ttl", None): + try: + ttl_val = fv.ttl + if isinstance(ttl_val, timedelta): + max_ttl_seconds = max(max_ttl_seconds, int(ttl_val.total_seconds())) + except Exception: + pass + start_date = ( + end_date - timedelta(seconds=max_ttl_seconds) + if max_ttl_seconds > 0 + else end_date - timedelta(days=30) + ) + else: + start_date = make_tzaware(start_date_opt) + return start_date, end_date + + +def _make_filter_range(timestamp_field: str, start_date: datetime, end_date: datetime): + # Why: factory function for time-range filtering in Ray map_batches + def _filter_range(batch: pd.DataFrame) -> pd.Series: + ts = pd.to_datetime(batch[timestamp_field], utc=True) + return (ts >= start_date) & (ts <= end_date) + + return _filter_range + + +def _make_select_distinct_keys(join_keys: List[str]): + # Why: factory function for distinct key projection in Ray map_batches + def _select_distinct_keys(batch: pd.DataFrame) -> pd.DataFrame: + cols = [c for c in join_keys if c in batch.columns] + if not cols: + return pd.DataFrame(columns=join_keys) + return batch[cols].drop_duplicates().reset_index(drop=True) + + return _select_distinct_keys + + +def _distinct_entities_for_feature_view_ray( + store: "RayOfflineStore", + config: RepoConfig, + fv: FeatureView, + registry: BaseRegistry, + project: str, + start_date: datetime, + end_date: datetime, +) -> Tuple[Dataset, List[str]]: + # Why: read minimal columns, filter by time, and project distinct join keys per FeatureView + ray_wrapper = get_ray_wrapper() + entities = fv.entities or [] + entity_objs = [registry.get_entity(e, project) for e in entities] + original_join_keys, _rev_feats, timestamp_field, _created_col = _get_column_names( + fv, entity_objs + ) + + source_info = resolve_feature_view_source_with_fallback( + fv, config, is_materialization=False + ) + source_path = store._get_source_path(source_info.data_source, config) + required_columns = list(set(original_join_keys + [timestamp_field])) + ds = ray_wrapper.read_parquet(source_path, columns=required_columns) 
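+    # Note: the scan is pruned to just the join keys and the timestamp column
+    # (required_columns above), so this entity-derivation pre-pass stays cheap
+    # even on wide feature tables; Ray pushes the column projection down into
+    # the Parquet reader.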
+ + field_mapping = getattr(fv.batch_source, "field_mapping", None) + if field_mapping: + ds = apply_field_mapping(ds, field_mapping) + original_join_keys = [field_mapping.get(k, k) for k in original_join_keys] + timestamp_field = field_mapping.get(timestamp_field, timestamp_field) + + if fv.projection.join_key_map: + join_keys = [ + fv.projection.join_key_map.get(key, key) for key in original_join_keys + ] + else: + join_keys = original_join_keys + + ds = ensure_timestamp_compatibility(ds, [timestamp_field]) + ds = ds.filter(_make_filter_range(timestamp_field, start_date, end_date)) + ds = ds.map_batches(_make_select_distinct_keys(join_keys), batch_format="pandas") + return ds, join_keys + + +def _make_align_columns(all_join_keys: List[str]): + # Why: factory function for schema alignment in Ray map_batches + def _align_columns(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + for k in all_join_keys: + if k not in batch.columns: + batch[k] = pd.NA + return batch[all_join_keys] + + return _align_columns + + +def _make_distinct_by_keys(keys: List[str]): + # Why: factory function for deduplication in Ray map_batches + def _distinct(batch: pd.DataFrame) -> pd.DataFrame: + return batch.drop_duplicates(subset=keys).reset_index(drop=True) + + return _distinct + + +def _align_and_union_entities_ray( + datasets: List[Dataset], + all_join_keys: List[str], +) -> Dataset: + # Why: align schemas across FeatureViews and union to a unified entity set + ray_wrapper = get_ray_wrapper() + if not datasets: + return ray_wrapper.from_pandas(pd.DataFrame(columns=all_join_keys)) + + aligned = [ + ds.map_batches(_make_align_columns(all_join_keys), batch_format="pandas") + for ds in datasets + ] + entity_ds = aligned[0] + for ds in aligned[1:]: + entity_ds = entity_ds.union(ds) + return entity_ds.map_batches(_make_distinct_by_keys(all_join_keys), batch_format="pandas") + + +def _add_asof_ts_ray(ds: Dataset, end_date: datetime) -> Dataset: + # Why: use a stable as-of timestamp for PIT joins when deriving entities + def _add_asof_ts(batch: pd.DataFrame) -> pd.DataFrame: + batch = batch.copy() + batch["event_timestamp"] = end_date + return batch + + ds = ds.map_batches(_add_asof_ts, batch_format="pandas") + return ensure_timestamp_compatibility(ds, ["event_timestamp"]) + + class RayOfflineStore(OfflineStore): def __init__(self) -> None: self._staging_location: Optional[str] = None @@ -1874,17 +2014,36 @@ def get_historical_features( config: RepoConfig, feature_views: List[FeatureView], feature_refs: List[str], - entity_df: Union[pd.DataFrame, str], + entity_df: Optional[Union[pd.DataFrame, str]], registry: BaseRegistry, project: str, full_feature_names: bool = False, + **kwargs: Any, ) -> RetrievalJob: store = RayOfflineStore() store._init_ray(config) - # Load entity_df as Ray dataset for distributed processing + # Load or derive entity dataset for distributed processing ray_wrapper = get_ray_wrapper() - if isinstance(entity_df, str): + if entity_df is None: + # Non-entity mode: derive entity set from feature sources within a bounded time window + start_date, end_date = _compute_non_entity_dates_ray( + feature_views, kwargs.get("start_date"), kwargs.get("end_date") + ) + per_view_entity_ds: List[Dataset] = [] + all_join_keys: List[str] = [] + for fv in feature_views: + ds, join_keys = _distinct_entities_for_feature_view_ray( + store, config, fv, registry, project, start_date, end_date + ) + per_view_entity_ds.append(ds) + for k in join_keys: + if k not in all_join_keys: + all_join_keys.append(k) + 
entity_ds = _align_and_union_entities_ray(per_view_entity_ds, all_join_keys) + entity_ds = _add_asof_ts_ray(entity_ds, end_date) + entity_df_sample = entity_ds.limit(1000).to_pandas() + elif isinstance(entity_df, str): entity_ds = ray_wrapper.read_csv(entity_df) entity_df_sample = entity_ds.limit(1000).to_pandas() else: From de8b2c50e23e51722f8906a4570d151143dec54d Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Wed, 26 Nov 2025 00:49:36 +0530 Subject: [PATCH 2/5] Reforamatted code to fix lint issues Signed-off-by: Aniket Paluskar --- .../offline_stores/contrib/ray_offline_store/ray.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 378748f4ca4..903a5d86e50 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -1203,7 +1203,9 @@ def _compute_non_entity_dates_ray( end_date_opt: Optional[datetime], ) -> Tuple[datetime, datetime]: # Why: derive bounded time window when no entity_df is provided using explicit dates or max TTL fallback - end_date = make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow()) + end_date = ( + make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow()) + ) if start_date_opt is None: max_ttl_seconds = 0 for fv in feature_views: @@ -1211,7 +1213,9 @@ def _compute_non_entity_dates_ray( try: ttl_val = fv.ttl if isinstance(ttl_val, timedelta): - max_ttl_seconds = max(max_ttl_seconds, int(ttl_val.total_seconds())) + max_ttl_seconds = max( + max_ttl_seconds, int(ttl_val.total_seconds()) + ) except Exception: pass start_date = ( @@ -1323,7 +1327,9 @@ def _align_and_union_entities_ray( entity_ds = aligned[0] for ds in aligned[1:]: entity_ds = entity_ds.union(ds) - return entity_ds.map_batches(_make_distinct_by_keys(all_join_keys), batch_format="pandas") + return entity_ds.map_batches( + _make_distinct_by_keys(all_join_keys), batch_format="pandas" + ) def _add_asof_ts_ray(ds: Dataset, end_date: datetime) -> Dataset: From c25280c6722cb64a63f685666405c2530f76a4eb Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Fri, 26 Dec 2025 17:40:45 +0530 Subject: [PATCH 3/5] preserve event_timestamp in non-entity mode for correct point-in-time joins Signed-off-by: Aniket Paluskar --- .../contrib/ray_offline_store/ray.py | 82 ++++++++++++------- 1 file changed, 54 insertions(+), 28 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 903a5d86e50..308abea173b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -1237,15 +1237,24 @@ def _filter_range(batch: pd.DataFrame) -> pd.Series: return _filter_range -def _make_select_distinct_keys(join_keys: List[str]): - # Why: factory function for distinct key projection in Ray map_batches - def _select_distinct_keys(batch: pd.DataFrame) -> pd.DataFrame: +def _make_select_distinct_entity_timestamps( + join_keys: List[str], timestamp_field: str +): + # Why: factory function for distinct (entity_keys, event_timestamp) projection in Ray map_batches + # This preserves multiple transactions per entity ID with different timestamps for proper PIT joins + def 
_select_distinct_entity_timestamps(batch: pd.DataFrame) -> pd.DataFrame: cols = [c for c in join_keys if c in batch.columns] + if timestamp_field in batch.columns: + # Rename timestamp to standardized event_timestamp + batch = batch.copy() + if timestamp_field != "event_timestamp": + batch["event_timestamp"] = batch[timestamp_field] + cols = cols + ["event_timestamp"] if not cols: - return pd.DataFrame(columns=join_keys) + return pd.DataFrame(columns=join_keys + ["event_timestamp"]) return batch[cols].drop_duplicates().reset_index(drop=True) - return _select_distinct_keys + return _select_distinct_entity_timestamps def _distinct_entities_for_feature_view_ray( @@ -1257,7 +1266,8 @@ def _distinct_entities_for_feature_view_ray( start_date: datetime, end_date: datetime, ) -> Tuple[Dataset, List[str]]: - # Why: read minimal columns, filter by time, and project distinct join keys per FeatureView + # Why: read minimal columns, filter by time, and project distinct (join_keys, event_timestamp) per FeatureView + # This preserves multiple transactions per entity ID for proper point-in-time joins ray_wrapper = get_ray_wrapper() entities = fv.entities or [] entity_objs = [registry.get_entity(e, project) for e in entities] @@ -1287,26 +1297,38 @@ def _distinct_entities_for_feature_view_ray( ds = ensure_timestamp_compatibility(ds, [timestamp_field]) ds = ds.filter(_make_filter_range(timestamp_field, start_date, end_date)) - ds = ds.map_batches(_make_select_distinct_keys(join_keys), batch_format="pandas") + # Extract distinct (entity_keys, event_timestamp) combinations - not just entity_keys + ds = ds.map_batches( + _make_select_distinct_entity_timestamps(join_keys, timestamp_field), + batch_format="pandas", + ) return ds, join_keys -def _make_align_columns(all_join_keys: List[str]): +def _make_align_columns(all_join_keys: List[str], include_timestamp: bool = False): # Why: factory function for schema alignment in Ray map_batches + # When include_timestamp=True, also aligns event_timestamp column for proper PIT joins def _align_columns(batch: pd.DataFrame) -> pd.DataFrame: batch = batch.copy() - for k in all_join_keys: + output_cols = list(all_join_keys) + if include_timestamp: + output_cols = output_cols + ["event_timestamp"] + for k in output_cols: if k not in batch.columns: batch[k] = pd.NA - return batch[all_join_keys] + return batch[output_cols] return _align_columns -def _make_distinct_by_keys(keys: List[str]): +def _make_distinct_by_keys(keys: List[str], include_timestamp: bool = False): # Why: factory function for deduplication in Ray map_batches + # When include_timestamp=True, deduplicates on (keys + event_timestamp) for proper PIT joins def _distinct(batch: pd.DataFrame) -> pd.DataFrame: - return batch.drop_duplicates(subset=keys).reset_index(drop=True) + subset = list(keys) + if include_timestamp and "event_timestamp" in batch.columns: + subset = subset + ["event_timestamp"] + return batch.drop_duplicates(subset=subset).reset_index(drop=True) return _distinct @@ -1314,35 +1336,34 @@ def _distinct(batch: pd.DataFrame) -> pd.DataFrame: def _align_and_union_entities_ray( datasets: List[Dataset], all_join_keys: List[str], + include_timestamp: bool = False, ) -> Dataset: # Why: align schemas across FeatureViews and union to a unified entity set + # When include_timestamp=True, preserves distinct (entity_keys, event_timestamp) combinations + # for proper point-in-time joins with multiple transactions per entity ray_wrapper = get_ray_wrapper() + output_cols = list(all_join_keys) + if 
include_timestamp: + output_cols = output_cols + ["event_timestamp"] if not datasets: - return ray_wrapper.from_pandas(pd.DataFrame(columns=all_join_keys)) + return ray_wrapper.from_pandas(pd.DataFrame(columns=output_cols)) aligned = [ - ds.map_batches(_make_align_columns(all_join_keys), batch_format="pandas") + ds.map_batches( + _make_align_columns(all_join_keys, include_timestamp=include_timestamp), + batch_format="pandas", + ) for ds in datasets ] entity_ds = aligned[0] for ds in aligned[1:]: entity_ds = entity_ds.union(ds) return entity_ds.map_batches( - _make_distinct_by_keys(all_join_keys), batch_format="pandas" + _make_distinct_by_keys(all_join_keys, include_timestamp=include_timestamp), + batch_format="pandas", ) -def _add_asof_ts_ray(ds: Dataset, end_date: datetime) -> Dataset: - # Why: use a stable as-of timestamp for PIT joins when deriving entities - def _add_asof_ts(batch: pd.DataFrame) -> pd.DataFrame: - batch = batch.copy() - batch["event_timestamp"] = end_date - return batch - - ds = ds.map_batches(_add_asof_ts, batch_format="pandas") - return ensure_timestamp_compatibility(ds, ["event_timestamp"]) - - class RayOfflineStore(OfflineStore): def __init__(self) -> None: self._staging_location: Optional[str] = None @@ -2033,6 +2054,8 @@ def get_historical_features( ray_wrapper = get_ray_wrapper() if entity_df is None: # Non-entity mode: derive entity set from feature sources within a bounded time window + # Preserves distinct (entity_keys, event_timestamp) combinations for proper PIT joins + # This handles cases where multiple transactions per entity ID exist start_date, end_date = _compute_non_entity_dates_ray( feature_views, kwargs.get("start_date"), kwargs.get("end_date") ) @@ -2046,8 +2069,11 @@ def get_historical_features( for k in join_keys: if k not in all_join_keys: all_join_keys.append(k) - entity_ds = _align_and_union_entities_ray(per_view_entity_ds, all_join_keys) - entity_ds = _add_asof_ts_ray(entity_ds, end_date) + # Use include_timestamp=True to preserve actual event_timestamp from data + # instead of assigning a fixed end_date to all entities + entity_ds = _align_and_union_entities_ray( + per_view_entity_ds, all_join_keys, include_timestamp=True + ) entity_df_sample = entity_ds.limit(1000).to_pandas() elif isinstance(entity_df, str): entity_ds = ray_wrapper.read_csv(entity_df) From 0c20ac078adb25d2dedd251e22581d772cd1d4c4 Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Fri, 26 Dec 2025 17:45:18 +0530 Subject: [PATCH 4/5] Minor lint changes Signed-off-by: Aniket Paluskar --- .../infra/offline_stores/contrib/ray_offline_store/ray.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 308abea173b..b7b3619c15a 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -1237,9 +1237,7 @@ def _filter_range(batch: pd.DataFrame) -> pd.Series: return _filter_range -def _make_select_distinct_entity_timestamps( - join_keys: List[str], timestamp_field: str -): +def _make_select_distinct_entity_timestamps(join_keys: List[str], timestamp_field: str): # Why: factory function for distinct (entity_keys, event_timestamp) projection in Ray map_batches # This preserves multiple transactions per entity ID with different timestamps for proper PIT joins def _select_distinct_entity_timestamps(batch: pd.DataFrame) -> 
pd.DataFrame: From f589956b00efe8cc2ab37dbb4a891cf5b70a4274 Mon Sep 17 00:00:00 2001 From: Aniket Paluskar Date: Mon, 29 Dec 2025 12:38:13 +0530 Subject: [PATCH 5/5] Added test cases for datetime range based feature retrieval in Ray Signed-off-by: Aniket Paluskar --- .../tests/test_ray_integration.py | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py index 5ab82f8ef47..0420054b8fb 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py @@ -1,3 +1,5 @@ +from datetime import timedelta + import pandas as pd import pytest @@ -144,3 +146,126 @@ def test_ray_offline_store_persist(environment, universal_data_sources): import os assert os.path.exists(saved_path) + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_ray_offline_store_non_entity_mode_basic(environment, universal_data_sources): + """Test historical features retrieval without entity_df (non-entity mode). + + This tests the basic functionality where entity_df=None and start_date/end_date + are provided to retrieve all features within the time range. + """ + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + store.apply( + [ + driver(), + feature_views.driver, + ] + ) + + # Use the environment's start and end dates for the query + start_date = environment.start_date + end_date = environment.end_date + + # Non-entity mode: entity_df=None with start_date and end_date + result_df = store.get_historical_features( + entity_df=None, + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], + full_feature_names=False, + start_date=start_date, + end_date=end_date, + ).to_df() + + # Verify data was retrieved + assert len(result_df) > 0, "Non-entity mode should return data" + assert "conv_rate" in result_df.columns + assert "acc_rate" in result_df.columns + assert "avg_daily_trips" in result_df.columns + assert "event_timestamp" in result_df.columns + assert "driver_id" in result_df.columns + + # Verify timestamps are within the requested range + result_df["event_timestamp"] = pd.to_datetime( + result_df["event_timestamp"], utc=True + ) + assert (result_df["event_timestamp"] >= start_date).all() + assert (result_df["event_timestamp"] <= end_date).all() + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_ray_offline_store_non_entity_mode_preserves_multiple_timestamps( + environment, universal_data_sources +): + """Test that non-entity mode preserves multiple transactions per entity ID. + + This is a regression test for the fix that ensures distinct (entity_key, event_timestamp) + combinations are preserved, not just distinct entity keys. This is critical for + proper point-in-time joins when an entity has multiple transactions. 
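+
+    Concretely: three rows are written below for a single driver_id with
+    distinct timestamps, and the non-entity query must return all three rows
+    rather than one deduplicated entity row stamped with a single as-of
+    timestamp.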
+ """ + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + + store.apply( + [ + driver(), + feature_views.driver, + ] + ) + + now = _utc_now() + ts1 = pd.Timestamp(now - timedelta(hours=2)).round("ms") + ts2 = pd.Timestamp(now - timedelta(hours=1)).round("ms") + ts3 = pd.Timestamp(now).round("ms") + + # Write data with multiple timestamps for the same entity (driver_id=9001) + df_to_write = pd.DataFrame.from_dict( + { + "event_timestamp": [ts1, ts2, ts3], + "driver_id": [9001, 9001, 9001], # Same entity, different timestamps + "conv_rate": [0.1, 0.2, 0.3], + "acc_rate": [0.9, 0.8, 0.7], + "avg_daily_trips": [10, 20, 30], + "created": [ts1, ts2, ts3], + }, + ) + + store.write_to_offline_store( + feature_views.driver.name, df_to_write, allow_registry_cache=False + ) + + # Query without entity_df - should get all 3 rows for driver_id=9001 + result_df = store.get_historical_features( + entity_df=None, + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + ], + full_feature_names=False, + start_date=ts1 - timedelta(minutes=1), + end_date=ts3 + timedelta(minutes=1), + ).to_df() + + # Filter to just our test entity + result_df = result_df[result_df["driver_id"] == 9001] + + # Verify we got all 3 rows with different timestamps (not just 1 row) + assert len(result_df) == 3, ( + f"Expected 3 rows for driver_id=9001 (one per timestamp), got {len(result_df)}" + ) + + # Verify the feature values are correct for each timestamp + result_df = result_df.sort_values("event_timestamp").reset_index(drop=True) + assert list(result_df["conv_rate"]) == [0.1, 0.2, 0.3] + assert list(result_df["acc_rate"]) == [0.9, 0.8, 0.7]
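
---

Usage sketch: a minimal example of the non-entity retrieval mode added by this
series, mirroring the call shape exercised by the PATCH 5 tests. The repo path
and feature names are illustrative, and the example assumes a
feature_store.yaml configured with the Ray offline store.

    from datetime import datetime, timedelta, timezone

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # illustrative repo path

    end = datetime.now(timezone.utc)
    start = end - timedelta(days=7)

    # entity_df=None enables the non-entity mode: the entity set is derived
    # from the feature view sources within [start, end], keeping every distinct
    # (entity key, event_timestamp) pair for the point-in-time join. If
    # start_date is omitted, the window falls back to the largest feature view
    # TTL, or 30 days when no TTL is set.
    training_df = store.get_historical_features(
        entity_df=None,
        features=[
            "driver_stats:conv_rate",
            "driver_stats:acc_rate",
        ],
        full_feature_names=False,
        start_date=start,
        end_date=end,
    ).to_df()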