diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index bfb8a59b2bb..83aaafd6863 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -2037,11 +2037,10 @@ def _augment_response_with_on_demand_transforms( proto_values = [] for selected_feature in selected_subset: - if odfv.mode in ["python", "pandas"]: - feature_vector = transformed_features[selected_feature] - proto_values.append( - python_values_to_proto_values(feature_vector, ValueType.UNKNOWN) - ) + feature_vector = transformed_features[selected_feature] + proto_values.append( + python_values_to_proto_values(feature_vector, ValueType.UNKNOWN) + ) odfv_result_names |= set(selected_subset) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index aaed78dd459..6c16ef26439 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -128,7 +128,7 @@ def to_arrow( features_df = self._to_df_internal(timeout=timeout) if self.on_demand_feature_views: for odfv in self.on_demand_feature_views: - if odfv.mode != "pandas": + if odfv.mode not in {"pandas", "substrait"}: raise Exception( f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.' ) diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index 8d51edbe587..f83500cbc9b 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -465,7 +465,9 @@ def get_transformed_features( return self.get_transformed_features_dict( feature_dict=features, ) - elif self.mode == "pandas" and isinstance(features, pd.DataFrame): + elif self.mode in {"pandas", "substrait"} and isinstance( + features, pd.DataFrame + ): return self.get_transformed_features_df( df_with_features=features, full_feature_names=full_feature_names, diff --git a/sdk/python/tests/unit/test_on_demand_substrait_transformation.py b/sdk/python/tests/unit/test_substrait_transformation.py similarity index 73% rename from sdk/python/tests/unit/test_on_demand_substrait_transformation.py rename to sdk/python/tests/unit/test_substrait_transformation.py index 378aa7ce3bd..28ab68c70be 100644 --- a/sdk/python/tests/unit/test_on_demand_substrait_transformation.py +++ b/sdk/python/tests/unit/test_substrait_transformation.py @@ -60,6 +60,7 @@ def test_ibis_pandas_parity(): @on_demand_feature_view( sources=[driver_stats_fv], schema=[Field(name="conv_rate_plus_acc", dtype=Float64)], + mode="pandas", ) def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame: df = pd.DataFrame() @@ -84,30 +85,50 @@ def substrait_view(inputs: Table) -> Table: [driver, driver_stats_source, driver_stats_fv, substrait_view, pandas_view] ) + store.materialize( + start_date=start_date, + end_date=end_date, + ) + entity_df = pd.DataFrame.from_dict( { # entity's join key -> entity values "driver_id": [1001, 1002, 1003], # "event_timestamp" (reserved key) -> timestamps "event_timestamp": [ - datetime(2021, 4, 12, 10, 59, 42), - datetime(2021, 4, 12, 8, 12, 10), - datetime(2021, 4, 12, 16, 40, 26), + start_date + timedelta(days=4), + start_date + timedelta(days=5), + start_date + timedelta(days=6), ], } ) + requested_features = [ + "driver_hourly_stats:conv_rate", + "driver_hourly_stats:acc_rate", + "driver_hourly_stats:avg_daily_trips", + "substrait_view:conv_rate_plus_acc_substrait", + "pandas_view:conv_rate_plus_acc", + ] + training_df = store.get_historical_features( - entity_df=entity_df, - features=[ - "driver_hourly_stats:conv_rate", - "driver_hourly_stats:acc_rate", - "driver_hourly_stats:avg_daily_trips", - "substrait_view:conv_rate_plus_acc_substrait", - "pandas_view:conv_rate_plus_acc", - ], - ).to_df() + entity_df=entity_df, features=requested_features + ) + + assert training_df.to_df()["conv_rate_plus_acc"].equals( + training_df.to_df()["conv_rate_plus_acc_substrait"] + ) + + assert training_df.to_arrow()["conv_rate_plus_acc"].equals( + training_df.to_arrow()["conv_rate_plus_acc_substrait"] + ) + + online_response = store.get_online_features( + features=requested_features, + entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}], + ) - assert training_df["conv_rate_plus_acc"].equals( - training_df["conv_rate_plus_acc_substrait"] + assert ( + online_response.to_dict()["conv_rate_plus_acc"] + == online_response.to_dict()["conv_rate_plus_acc_substrait"] )