From 99aec1870bd99685d091deaf939c5c22d8a7615e Mon Sep 17 00:00:00 2001
From: ntkathole
Date: Tue, 22 Jul 2025 00:40:51 +0530
Subject: [PATCH] fix: Fixed ODFV on-write transformations

Signed-off-by: ntkathole
---
 sdk/python/feast/feature_store.py      | 139 ++++++++
 sdk/python/feast/infra/provider.py     |   7 +-
 sdk/python/feast/utils.py              |   6 +-
 .../test_universal_materialization.py  | 310 +++++++++++++++++-
 4 files changed, 447 insertions(+), 15 deletions(-)

diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 255df3db41f..d8229ea1e56 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -67,6 +67,9 @@
     update_feature_views_with_inferred_features_and_entities,
 )
 from feast.infra.infra_object import Infra
+from feast.infra.offline_stores.offline_utils import (
+    DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
+)
 from feast.infra.provider import Provider, RetrievalJob, get_provider
 from feast.infra.registry.base_registry import BaseRegistry
 from feast.infra.registry.registry import Registry
@@ -1287,6 +1290,115 @@ def get_saved_dataset(self, name: str) -> SavedDataset:
         )
         return dataset.with_retrieval_job(retrieval_job)
 
+    def _materialize_odfv(
+        self,
+        feature_view: OnDemandFeatureView,
+        start_date: datetime,
+        end_date: datetime,
+    ):
+        """Helper to materialize a single OnDemandFeatureView."""
+        if not feature_view.source_feature_view_projections:
+            print(
+                f"[WARNING] ODFV {feature_view.name} materialization: No source feature views found."
+            )
+            return
+        start_date = utils.make_tzaware(start_date)
+        end_date = utils.make_tzaware(end_date)
+
+        source_features_from_projections = []
+        all_join_keys = set()
+        entity_timestamp_col_names = set()
+        source_fvs = {
+            self._get_feature_view(p.name)
+            for p in feature_view.source_feature_view_projections.values()
+        }
+
+        for source_fv in source_fvs:
+            all_join_keys.update(source_fv.entities)
+            if source_fv.batch_source:
+                entity_timestamp_col_names.add(source_fv.batch_source.timestamp_field)
+
+        for proj in feature_view.source_feature_view_projections.values():
+            source_features_from_projections.extend(
+                [f"{proj.name}:{f.name}" for f in proj.features]
+            )
+
+        all_join_keys = {key for key in all_join_keys if key}
+
+        if not all_join_keys:
+            print(
+                f"[WARNING] ODFV {feature_view.name} materialization: No join keys found in source views. Cannot create entity_df. Skipping."
+            )
+            return
+
+        if len(entity_timestamp_col_names) > 1:
+            print(
+                f"[WARNING] ODFV {feature_view.name} materialization: Found multiple timestamp columns in sources ({entity_timestamp_col_names}). This is not supported. Skipping."
+            )
+            return
+
+        if not entity_timestamp_col_names:
+            print(
+                f"[WARNING] ODFV {feature_view.name} materialization: No batch sources with timestamp columns found. Skipping."
+            )
+            return
+
+        event_timestamp_col = list(entity_timestamp_col_names)[0]
+        all_source_dfs = []
+        provider = self._get_provider()
+
+        for source_fv in source_fvs:
+            if not source_fv.batch_source:
+                continue
+
+            job = provider.offline_store.pull_latest_from_table_or_query(
+                config=self.config,
+                data_source=source_fv.batch_source,
+                join_key_columns=source_fv.entities,
+                feature_name_columns=[f.name for f in source_fv.features],
+                timestamp_field=source_fv.batch_source.timestamp_field,
+                created_timestamp_column=getattr(
+                    source_fv.batch_source, "created_timestamp_column", None
+                ),
+                start_date=start_date,
+                end_date=end_date,
+            )
+            df = job.to_df()
+            if not df.empty:
+                all_source_dfs.append(df)
+
+        if not all_source_dfs:
+            print(
+                f"No source data found for ODFV {feature_view.name} in the given time range. Skipping materialization."
+            )
+            return
+
+        entity_df_cols = list(all_join_keys) + [event_timestamp_col]
+        all_sources_combined_df = pd.concat(all_source_dfs, ignore_index=True)
+        if all_sources_combined_df.empty:
+            return
+
+        entity_df = (
+            all_sources_combined_df[entity_df_cols]
+            .drop_duplicates()
+            .reset_index(drop=True)
+        )
+
+        if event_timestamp_col != DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL:
+            entity_df = entity_df.rename(
+                columns={event_timestamp_col: DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}
+            )
+
+        retrieval_job = self.get_historical_features(
+            entity_df=entity_df,
+            features=source_features_from_projections,
+        )
+        input_df = retrieval_job.to_df()
+        transformed_df = self._transform_on_demand_feature_view_df(
+            feature_view, input_df
+        )
+        self.write_to_online_store(feature_view.name, df=transformed_df)
+
     def materialize_incremental(
         self,
         end_date: datetime,
@@ -1332,7 +1444,27 @@ def materialize_incremental(
         # TODO paging large loads
         for feature_view in feature_views_to_materialize:
             if isinstance(feature_view, OnDemandFeatureView):
+                if feature_view.write_to_online_store:
+                    source_fvs = {
+                        self._get_feature_view(p.name)
+                        for p in feature_view.source_feature_view_projections.values()
+                    }
+                    max_ttl = timedelta(0)
+                    for fv in source_fvs:
+                        if fv.ttl and fv.ttl > max_ttl:
+                            max_ttl = fv.ttl
+
+                    if max_ttl.total_seconds() > 0:
+                        odfv_start_date = end_date - max_ttl
+                    else:
+                        odfv_start_date = end_date - timedelta(weeks=52)
+
+                    print(
+                        f"{Style.BRIGHT + Fore.GREEN}{feature_view.name}{Style.RESET_ALL}:"
+                    )
+                    self._materialize_odfv(feature_view, odfv_start_date, end_date)
                 continue
+
             start_date = feature_view.most_recent_end_time
             if start_date is None:
                 if feature_view.ttl is None:
@@ -1428,6 +1560,13 @@ def materialize(
         )
         # TODO paging large loads
         for feature_view in feature_views_to_materialize:
+            if isinstance(feature_view, OnDemandFeatureView):
+                if feature_view.write_to_online_store:
+                    print(
+                        f"{Style.BRIGHT + Fore.GREEN}{feature_view.name}{Style.RESET_ALL}:"
+                    )
+                    self._materialize_odfv(feature_view, start_date, end_date)
+                continue
             provider = self._get_provider()
             print(f"{Style.BRIGHT + Fore.GREEN}{feature_view.name}{Style.RESET_ALL}:")
 
diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py
index 4f7b0d4b5c1..c9150c542e4 100644
--- a/sdk/python/feast/infra/provider.py
+++ b/sdk/python/feast/infra/provider.py
@@ -25,7 +25,8 @@
 from feast.feature_view import FeatureView
 from feast.importer import import_class
 from feast.infra.infra_object import Infra
-from feast.infra.offline_stores.offline_store import RetrievalJob
+from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
+from feast.infra.online_stores.online_store import OnlineStore
 from feast.infra.registry.base_registry import BaseRegistry
 from feast.infra.supported_async_methods import ProviderAsyncMethods
 from feast.on_demand_feature_view import OnDemandFeatureView
@@ -52,6 +53,10 @@ class Provider(ABC):
     engine. It is configured through a RepoConfig object.
     """
 
+    repo_config: RepoConfig
+    offline_store: OfflineStore
+    online_store: OnlineStore
+
     @abstractmethod
     def __init__(self, config: RepoConfig):
         pass
diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py
index d8f075d16d4..c63dad6a6ab 100644
--- a/sdk/python/feast/utils.py
+++ b/sdk/python/feast/utils.py
@@ -257,7 +257,11 @@ def _convert_arrow_to_proto(
     join_keys: Dict[str, ValueType],
 ) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
     # This is a workaround for isinstance(feature_view, OnDemandFeatureView), which triggers a circular import
-    if getattr(feature_view, "source_request_sources", None):
+    # Check for source_request_sources or source_feature_view_projections attributes to identify ODFVs
+    if (
+        getattr(feature_view, "source_request_sources", None) is not None
+        or getattr(feature_view, "source_feature_view_projections", None) is not None
+    ):
         return _convert_arrow_odfv_to_proto(table, feature_view, join_keys)  # type: ignore[arg-type]
     else:
         return _convert_arrow_fv_to_proto(table, feature_view, join_keys)  # type: ignore[arg-type]
diff --git a/sdk/python/tests/integration/materialization/test_universal_materialization.py b/sdk/python/tests/integration/materialization/test_universal_materialization.py
index 37030b1bb30..860e9a5fc6c 100644
--- a/sdk/python/tests/integration/materialization/test_universal_materialization.py
+++ b/sdk/python/tests/integration/materialization/test_universal_materialization.py
@@ -1,33 +1,236 @@
 from datetime import timedelta
 
+import pandas as pd
 import pytest
 
-from feast.entity import Entity
-from feast.feature_view import FeatureView
-from feast.field import Field
-from feast.types import Float32
+from feast import (
+    Entity,
+    FeatureView,
+    Field,
+)
+from feast.on_demand_feature_view import on_demand_feature_view
+from feast.types import Float32, Float64
 from tests.data.data_creator import create_basic_driver_dataset
 from tests.utils.e2e_test_validation import validate_offline_online_store_consistency
 
 
+def _create_test_entities():
+    """Helper function to create standard test entities."""
+    customer = Entity(name="customer_id", join_keys=["customer_id"])
+    product = Entity(name="product_id", join_keys=["product_id"])
+    return customer, product
+
+
+def _create_test_dataframe(include_revenue=False):
+    """Helper function to create standard test DataFrame."""
+    data = {
+        "customer_id": [1, 2],
+        "product_id": [10, 20],
+        "price": [100.0, 200.0],
+        "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-01"]),
+        "created_timestamp": pd.to_datetime(["2024-01-01", "2024-01-01"]),
+    }
+    if include_revenue:
+        data["revenue"] = [5.0, 10.0]
+    return pd.DataFrame(data)
+
+
+def _create_revenue_dataframe():
+    """Helper function to create revenue test DataFrame."""
+    return pd.DataFrame(
+        {
+            "customer_id": [1, 2],
+            "product_id": [10, 20],
+            "revenue": [5.0, 7.0],
+            "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-01"]),
+            "created_timestamp": pd.to_datetime(["2024-01-01", "2024-01-01"]),
+        }
+    )
+
+
+def _create_feature_view(name, entities, schema_fields, source):
+    """Helper function to create a standard FeatureView."""
+    return FeatureView(
+        name=name,
+        entities=entities,
+        ttl=timedelta(days=1),
+        schema=schema_fields,
+        online=True,
+        source=source,
+    )
+
+
+def _materialize_and_assert(fs, df, feature_ref, entity_row, expected_value):
+    """Helper function to materialize and assert feature values."""
+    feature_view_name = feature_ref.split(":")[0]
+    fs.materialize(
+        start_date=df["event_timestamp"].min() - timedelta(days=1),
+        end_date=df["event_timestamp"].max() + timedelta(days=1),
+        feature_views=[feature_view_name],
+    )
+    resp = fs.get_online_features(
+        features=[feature_ref],
+        entity_rows=[entity_row],
+    ).to_dict()
+    feature_name = feature_ref.split(":")[-1]
+    assert resp[feature_name][0] == expected_value
+
+
+def _assert_online_features(
+    fs, feature_ref, entity_row, expected_value, message_prefix="Expected"
+):
+    """Helper function to assert online feature values."""
+    resp = fs.get_online_features(
+        features=[feature_ref],
+        entity_rows=[entity_row],
+    ).to_dict()
+    feature_name = feature_ref.split(":")[-1]
+    assert resp[feature_name][0] == expected_value, (
+        f"{message_prefix} {expected_value}, got {resp[feature_name][0]}"
+    )
+
+
+def _get_standard_entity_row():
+    """Helper function to get standard entity row for testing."""
+    return {"customer_id": 1, "product_id": 10}
+
+
+@pytest.mark.integration
+def test_odfv_materialization_single_source(environment):
+    fs = environment.feature_store
+    df = _create_test_dataframe()
+    ds = environment.data_source_creator.create_data_source(
+        df,
+        fs.project,
+        timestamp_field="event_timestamp",
+        created_timestamp_column="created_timestamp",
+    )
+    customer, product = _create_test_entities()
+
+    fv1 = _create_feature_view(
+        "fv1", [customer, product], [Field(name="price", dtype=Float32)], ds
+    )
+
+    @on_demand_feature_view(
+        entities=[customer, product],
+        sources=[fv1],
+        schema=[Field(name="price_plus_10", dtype=Float64)],
+        write_to_online_store=True,
+    )
+    def odfv_single(df: pd.DataFrame) -> pd.DataFrame:
+        df["price_plus_10"] = df["price"] + 10
+        return df
+
+    fs.apply([customer, product, fv1, odfv_single])
+    _materialize_and_assert(
+        fs, df, "odfv_single:price_plus_10", _get_standard_entity_row(), 110.0
+    )
+
+
+@pytest.mark.integration
+def test_odfv_materialization_multi_source(environment):
+    fs = environment.feature_store
+    df1 = _create_test_dataframe()
+    df2 = _create_revenue_dataframe()
+    ds1 = environment.data_source_creator.create_data_source(
+        df1,
+        fs.project,
+        timestamp_field="event_timestamp",
+        created_timestamp_column="created_timestamp",
+    )
+    ds2 = environment.data_source_creator.create_data_source(
+        df2,
+        fs.project,
+        timestamp_field="event_timestamp",
+        created_timestamp_column="created_timestamp",
+    )
+
+    customer, product = _create_test_entities()
+    fv1 = _create_feature_view(
+        "fv1", [customer, product], [Field(name="price", dtype=Float32)], ds1
+    )
+    fv2 = _create_feature_view(
+        "fv2", [customer, product], [Field(name="revenue", dtype=Float32)], ds2
+    )
+
+    @on_demand_feature_view(
+        entities=[customer, product],
+        sources=[fv1, fv2],
+        schema=[Field(name="price_plus_revenue", dtype=Float64)],
+        write_to_online_store=True,
+    )
+    def odfv_multi(df: pd.DataFrame) -> pd.DataFrame:
+        df["price_plus_revenue"] = df["price"] + df["revenue"]
+        return df
+
+    fs.apply([customer, product, fv1, fv2, odfv_multi])
+    _materialize_and_assert(
+        fs, df1, "odfv_multi:price_plus_revenue", _get_standard_entity_row(), 105.0
+    )
+
+
+@pytest.mark.integration
+def test_odfv_materialization_incremental_multi_source(environment):
+    fs = environment.feature_store
+    df1 = _create_test_dataframe()
+    df2 = _create_revenue_dataframe()
+    ds1 = environment.data_source_creator.create_data_source(
+        df1,
+        fs.project,
+        timestamp_field="event_timestamp",
+        created_timestamp_column="created_timestamp",
+    )
+    ds2 = environment.data_source_creator.create_data_source(
+        df2,
+        fs.project,
+        timestamp_field="event_timestamp",
+        created_timestamp_column="created_timestamp",
+    )
+
+    customer, product = _create_test_entities()
+    fv1 = _create_feature_view(
+        "fv1", [customer, product], [Field(name="price", dtype=Float32)], ds1
+    )
+    fv2 = _create_feature_view(
+        "fv2", [customer, product], [Field(name="revenue", dtype=Float32)], ds2
+    )
+
+    @on_demand_feature_view(
+        entities=[customer, product],
+        sources=[fv1, fv2],
+        schema=[Field(name="price_plus_revenue", dtype=Float64)],
+        write_to_online_store=True,
+    )
+    def odfv_multi(df: pd.DataFrame) -> pd.DataFrame:
+        df["price_plus_revenue"] = df["price"] + df["revenue"]
+        return df
+
+    fs.apply([customer, product, fv1, fv2, odfv_multi])
+    fs.materialize_incremental(
+        end_date=df1["event_timestamp"].max() + timedelta(days=1)
+    )
+
+    resp = fs.get_online_features(
+        features=["odfv_multi:price_plus_revenue"],
+        entity_rows=[_get_standard_entity_row()],
+    ).to_dict()
+    assert resp["price_plus_revenue"][0] == 105.0
+
+
 @pytest.mark.integration
 @pytest.mark.universal_offline_stores
 def test_universal_materialization_consistency(environment):
     fs = environment.feature_store
-
     df = create_basic_driver_dataset()
-
     ds = environment.data_source_creator.create_data_source(
         df,
         fs.project,
         field_mapping={"ts_1": "ts"},
     )
-
     driver = Entity(
         name="driver_id",
         join_keys=["driver_id"],
     )
-
     driver_stats_fv = FeatureView(
         name="driver_hourly_stats",
         entities=[driver],
@@ -35,11 +238,92 @@ def test_universal_materialization_consistency(environment):
         schema=[Field(name="value", dtype=Float32)],
         source=ds,
     )
-
     fs.apply([driver, driver_stats_fv])
-
-    # materialization is run in two steps and
-    # we use timestamp from generated dataframe as a split point
     split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)
-
     validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
+
+
+@pytest.mark.integration
+def test_odfv_write_methods(environment):
+    """
+    Regression test for ODFV on-write transformations not being persisted.
+    Exercises materialize() and write_to_online_store(), plus an on-read ODFV.
+ """ + fs = environment.feature_store + df = _create_test_dataframe(include_revenue=True) + ds = environment.data_source_creator.create_data_source( + df, + fs.project, + timestamp_field="event_timestamp", + created_timestamp_column="created_timestamp", + ) + + customer, product = _create_test_entities() + fv = _create_feature_view( + "price_revenue_fv", + [customer, product], + [Field(name="price", dtype=Float32), Field(name="revenue", dtype=Float32)], + ds, + ) + + @on_demand_feature_view( + entities=[customer, product], + sources=[fv], + schema=[Field(name="total_value", dtype=Float64)], + write_to_online_store=True, + ) + def total_value_odfv(df: pd.DataFrame) -> pd.DataFrame: + df["total_value"] = df["price"] + df["revenue"] + return df + + fs.apply([customer, product, fv, total_value_odfv]) + _materialize_and_assert( + fs, df, "total_value_odfv:total_value", _get_standard_entity_row(), 105.0 + ) + + new_data = pd.DataFrame( + { + "customer_id": [3], + "product_id": [30], + "price": [300.0], + "revenue": [15.0], + "event_timestamp": [pd.Timestamp.now()], + "created_timestamp": [pd.Timestamp.now()], + } + ) + + transformed_data = fs._transform_on_demand_feature_view_df( + total_value_odfv, new_data + ) + fs.write_to_online_store("total_value_odfv", df=transformed_data) + _assert_online_features( + fs, "total_value_odfv:total_value", {"customer_id": 3, "product_id": 30}, 315.0 + ) + + @on_demand_feature_view( + entities=[customer, product], + sources=[fv], + schema=[Field(name="price_doubled", dtype=Float64)], + write_to_online_store=False, # This is on-read only + ) + def price_doubled_odfv(df: pd.DataFrame) -> pd.DataFrame: + df["price_doubled"] = df["price"] * 2 + return df + + fs.apply([price_doubled_odfv]) + # Materialize the underlying feature view so the on-read ODFV can access the price feature + fs.materialize( + start_date=df["event_timestamp"].min() - timedelta(days=1), + end_date=df["event_timestamp"].max() + timedelta(days=1), + feature_views=["price_revenue_fv"], + ) + _assert_online_features( + fs, "price_doubled_odfv:price_doubled", _get_standard_entity_row(), 200.0 + ) + + resp = fs.get_online_features( + features=["total_value_odfv:total_value", "price_doubled_odfv:price_doubled"], + entity_rows=[_get_standard_entity_row()], + ).to_dict() + assert resp["total_value"][0] == 105.0, "On-write ODFV failed" + assert resp["price_doubled"][0] == 200.0, "On-read ODFV failed"