diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py
index 98247c6c0e0..b7b3619c15a 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py
@@ -1,7 +1,7 @@
 import logging
 import os
 import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
@@ -1197,6 +1197,171 @@ def schema(self) -> pa.Schema:
         return pa.Table.from_pandas(df).schema
 
 
+def _compute_non_entity_dates_ray(
+    feature_views: List[FeatureView],
+    start_date_opt: Optional[datetime],
+    end_date_opt: Optional[datetime],
+) -> Tuple[datetime, datetime]:
+    # Why: derive bounded time window when no entity_df is provided using explicit dates or max TTL fallback
+    end_date = (
+        make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow())
+    )
+    if start_date_opt is None:
+        max_ttl_seconds = 0
+        for fv in feature_views:
+            if getattr(fv, "ttl", None):
+                try:
+                    ttl_val = fv.ttl
+                    if isinstance(ttl_val, timedelta):
+                        max_ttl_seconds = max(
+                            max_ttl_seconds, int(ttl_val.total_seconds())
+                        )
+                except Exception:
+                    pass
+        start_date = (
+            end_date - timedelta(seconds=max_ttl_seconds)
+            if max_ttl_seconds > 0
+            else end_date - timedelta(days=30)
+        )
+    else:
+        start_date = make_tzaware(start_date_opt)
+    return start_date, end_date
+
+
+def _make_filter_range(timestamp_field: str, start_date: datetime, end_date: datetime):
+    # Why: factory function for time-range filtering in Ray map_batches
+    def _filter_range(batch: pd.DataFrame) -> pd.Series:
+        ts = pd.to_datetime(batch[timestamp_field], utc=True)
+        return (ts >= start_date) & (ts <= end_date)
+
+    return _filter_range
+
+
+def _make_select_distinct_entity_timestamps(join_keys: List[str], timestamp_field: str):
+    # Why: factory function for distinct (entity_keys, event_timestamp) projection in Ray map_batches
+    # This preserves multiple transactions per entity ID with different timestamps for proper PIT joins
+    def _select_distinct_entity_timestamps(batch: pd.DataFrame) -> pd.DataFrame:
+        cols = [c for c in join_keys if c in batch.columns]
+        if timestamp_field in batch.columns:
+            # Rename timestamp to standardized event_timestamp
+            batch = batch.copy()
+            if timestamp_field != "event_timestamp":
+                batch["event_timestamp"] = batch[timestamp_field]
+            cols = cols + ["event_timestamp"]
+        if not cols:
+            return pd.DataFrame(columns=join_keys + ["event_timestamp"])
+        return batch[cols].drop_duplicates().reset_index(drop=True)
+
+    return _select_distinct_entity_timestamps
+
+
+def _distinct_entities_for_feature_view_ray(
+    store: "RayOfflineStore",
+    config: RepoConfig,
+    fv: FeatureView,
+    registry: BaseRegistry,
+    project: str,
+    start_date: datetime,
+    end_date: datetime,
+) -> Tuple[Dataset, List[str]]:
+    # Why: read minimal columns, filter by time, and project distinct (join_keys, event_timestamp) per FeatureView
+    # This preserves multiple transactions per entity ID for proper point-in-time joins
+    ray_wrapper = get_ray_wrapper()
+    entities = fv.entities or []
+    entity_objs = [registry.get_entity(e, project) for e in entities]
+    original_join_keys, _rev_feats, timestamp_field, _created_col = _get_column_names(
+        fv, entity_objs
+    )
+
+    source_info = resolve_feature_view_source_with_fallback(
+        fv, config, is_materialization=False
+    )
+    source_path = store._get_source_path(source_info.data_source, config)
+    required_columns = list(set(original_join_keys + [timestamp_field]))
+    ds = ray_wrapper.read_parquet(source_path, columns=required_columns)
+
+    field_mapping = getattr(fv.batch_source, "field_mapping", None)
+    if field_mapping:
+        ds = apply_field_mapping(ds, field_mapping)
+        original_join_keys = [field_mapping.get(k, k) for k in original_join_keys]
+        timestamp_field = field_mapping.get(timestamp_field, timestamp_field)
+
+    if fv.projection.join_key_map:
+        join_keys = [
+            fv.projection.join_key_map.get(key, key) for key in original_join_keys
+        ]
+    else:
+        join_keys = original_join_keys
+
+    ds = ensure_timestamp_compatibility(ds, [timestamp_field])
+    ds = ds.filter(_make_filter_range(timestamp_field, start_date, end_date))
+    # Extract distinct (entity_keys, event_timestamp) combinations - not just entity_keys
+    ds = ds.map_batches(
+        _make_select_distinct_entity_timestamps(join_keys, timestamp_field),
+        batch_format="pandas",
+    )
+    return ds, join_keys
+
+
+def _make_align_columns(all_join_keys: List[str], include_timestamp: bool = False):
+    # Why: factory function for schema alignment in Ray map_batches
+    # When include_timestamp=True, also aligns event_timestamp column for proper PIT joins
+    def _align_columns(batch: pd.DataFrame) -> pd.DataFrame:
+        batch = batch.copy()
+        output_cols = list(all_join_keys)
+        if include_timestamp:
+            output_cols = output_cols + ["event_timestamp"]
+        for k in output_cols:
+            if k not in batch.columns:
+                batch[k] = pd.NA
+        return batch[output_cols]
+
+    return _align_columns
+
+
+def _make_distinct_by_keys(keys: List[str], include_timestamp: bool = False):
+    # Why: factory function for deduplication in Ray map_batches
+    # When include_timestamp=True, deduplicates on (keys + event_timestamp) for proper PIT joins
+    def _distinct(batch: pd.DataFrame) -> pd.DataFrame:
+        subset = list(keys)
+        if include_timestamp and "event_timestamp" in batch.columns:
+            subset = subset + ["event_timestamp"]
+        return batch.drop_duplicates(subset=subset).reset_index(drop=True)
+
+    return _distinct
+
+
+def _align_and_union_entities_ray(
+    datasets: List[Dataset],
+    all_join_keys: List[str],
+    include_timestamp: bool = False,
+) -> Dataset:
+    # Why: align schemas across FeatureViews and union to a unified entity set
+    # When include_timestamp=True, preserves distinct (entity_keys, event_timestamp) combinations
+    # for proper point-in-time joins with multiple transactions per entity
+    ray_wrapper = get_ray_wrapper()
+    output_cols = list(all_join_keys)
+    if include_timestamp:
+        output_cols = output_cols + ["event_timestamp"]
+    if not datasets:
+        return ray_wrapper.from_pandas(pd.DataFrame(columns=output_cols))
+
+    aligned = [
+        ds.map_batches(
+            _make_align_columns(all_join_keys, include_timestamp=include_timestamp),
+            batch_format="pandas",
+        )
+        for ds in datasets
+    ]
+    entity_ds = aligned[0]
+    for ds in aligned[1:]:
+        entity_ds = entity_ds.union(ds)
+    return entity_ds.map_batches(
+        _make_distinct_by_keys(all_join_keys, include_timestamp=include_timestamp),
+        batch_format="pandas",
+    )
+
+
 class RayOfflineStore(OfflineStore):
     def __init__(self) -> None:
         self._staging_location: Optional[str] = None
@@ -1874,17 +2039,41 @@ def get_historical_features(
         config: RepoConfig,
         feature_views: List[FeatureView],
         feature_refs: List[str],
-        entity_df: Union[pd.DataFrame, str],
+        entity_df: Optional[Union[pd.DataFrame, str]],
         registry: BaseRegistry,
         project: str,
         full_feature_names: bool = False,
+        **kwargs: Any,
     ) -> RetrievalJob:
         store = RayOfflineStore()
         store._init_ray(config)
 
-        # Load entity_df as Ray dataset for distributed processing
+        # Load or derive entity dataset for distributed processing
         ray_wrapper = get_ray_wrapper()
-        if isinstance(entity_df, str):
+        if entity_df is None:
+            # Non-entity mode: derive entity set from feature sources within a bounded time window
+            # Preserves distinct (entity_keys, event_timestamp) combinations for proper PIT joins
+            # This handles cases where multiple transactions per entity ID exist
+            start_date, end_date = _compute_non_entity_dates_ray(
+                feature_views, kwargs.get("start_date"), kwargs.get("end_date")
+            )
+            per_view_entity_ds: List[Dataset] = []
+            all_join_keys: List[str] = []
+            for fv in feature_views:
+                ds, join_keys = _distinct_entities_for_feature_view_ray(
+                    store, config, fv, registry, project, start_date, end_date
+                )
+                per_view_entity_ds.append(ds)
+                for k in join_keys:
+                    if k not in all_join_keys:
+                        all_join_keys.append(k)
+            # Use include_timestamp=True to preserve actual event_timestamp from data
+            # instead of assigning a fixed end_date to all entities
+            entity_ds = _align_and_union_entities_ray(
+                per_view_entity_ds, all_join_keys, include_timestamp=True
+            )
+            entity_df_sample = entity_ds.limit(1000).to_pandas()
+        elif isinstance(entity_df, str):
             entity_ds = ray_wrapper.read_csv(entity_df)
             entity_df_sample = entity_ds.limit(1000).to_pandas()
         else:
diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py
index 5ab82f8ef47..0420054b8fb 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py
@@ -1,3 +1,5 @@
+from datetime import timedelta
+
 import pandas as pd
 import pytest
 
@@ -144,3 +146,126 @@ def test_ray_offline_store_persist(environment, universal_data_sources):
     import os
 
     assert os.path.exists(saved_path)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+def test_ray_offline_store_non_entity_mode_basic(environment, universal_data_sources):
+    """Test historical features retrieval without entity_df (non-entity mode).
+
+    This tests the basic functionality where entity_df=None and start_date/end_date
+    are provided to retrieve all features within the time range.
+    """
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    store.apply(
+        [
+            driver(),
+            feature_views.driver,
+        ]
+    )
+
+    # Use the environment's start and end dates for the query
+    start_date = environment.start_date
+    end_date = environment.end_date
+
+    # Non-entity mode: entity_df=None with start_date and end_date
+    result_df = store.get_historical_features(
+        entity_df=None,
+        features=[
+            "driver_stats:conv_rate",
+            "driver_stats:acc_rate",
+            "driver_stats:avg_daily_trips",
+        ],
+        full_feature_names=False,
+        start_date=start_date,
+        end_date=end_date,
+    ).to_df()
+
+    # Verify data was retrieved
+    assert len(result_df) > 0, "Non-entity mode should return data"
+    assert "conv_rate" in result_df.columns
+    assert "acc_rate" in result_df.columns
+    assert "avg_daily_trips" in result_df.columns
+    assert "event_timestamp" in result_df.columns
+    assert "driver_id" in result_df.columns
+
+    # Verify timestamps are within the requested range
+    result_df["event_timestamp"] = pd.to_datetime(
+        result_df["event_timestamp"], utc=True
+    )
+    assert (result_df["event_timestamp"] >= start_date).all()
+    assert (result_df["event_timestamp"] <= end_date).all()
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+def test_ray_offline_store_non_entity_mode_preserves_multiple_timestamps(
+    environment, universal_data_sources
+):
+    """Test that non-entity mode preserves multiple transactions per entity ID.
+
+    This is a regression test for the fix that ensures distinct (entity_key, event_timestamp)
+    combinations are preserved, not just distinct entity keys. This is critical for
+    proper point-in-time joins when an entity has multiple transactions.
+    """
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    store.apply(
+        [
+            driver(),
+            feature_views.driver,
+        ]
+    )
+
+    now = _utc_now()
+    ts1 = pd.Timestamp(now - timedelta(hours=2)).round("ms")
+    ts2 = pd.Timestamp(now - timedelta(hours=1)).round("ms")
+    ts3 = pd.Timestamp(now).round("ms")
+
+    # Write data with multiple timestamps for the same entity (driver_id=9001)
+    df_to_write = pd.DataFrame.from_dict(
+        {
+            "event_timestamp": [ts1, ts2, ts3],
+            "driver_id": [9001, 9001, 9001],  # Same entity, different timestamps
+            "conv_rate": [0.1, 0.2, 0.3],
+            "acc_rate": [0.9, 0.8, 0.7],
+            "avg_daily_trips": [10, 20, 30],
+            "created": [ts1, ts2, ts3],
+        },
+    )
+
+    store.write_to_offline_store(
+        feature_views.driver.name, df_to_write, allow_registry_cache=False
+    )
+
+    # Query without entity_df - should get all 3 rows for driver_id=9001
+    result_df = store.get_historical_features(
+        entity_df=None,
+        features=[
+            "driver_stats:conv_rate",
+            "driver_stats:acc_rate",
+        ],
+        full_feature_names=False,
+        start_date=ts1 - timedelta(minutes=1),
+        end_date=ts3 + timedelta(minutes=1),
+    ).to_df()
+
+    # Filter to just our test entity
+    result_df = result_df[result_df["driver_id"] == 9001]
+
+    # Verify we got all 3 rows with different timestamps (not just 1 row)
+    assert len(result_df) == 3, (
+        f"Expected 3 rows for driver_id=9001 (one per timestamp), got {len(result_df)}"
+    )
+
+    # Verify the feature values are correct for each timestamp
+    result_df = result_df.sort_values("event_timestamp").reset_index(drop=True)
+    assert list(result_df["conv_rate"]) == [0.1, 0.2, 0.3]
+    assert list(result_df["acc_rate"]) == [0.9, 0.8, 0.7]