diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py
index f83500cbc9b..00e94f612ee 100644
--- a/sdk/python/feast/on_demand_feature_view.py
+++ b/sdk/python/feast/on_demand_feature_view.py
@@ -300,10 +300,23 @@ def from_proto(
             == "user_defined_function"
             and on_demand_feature_view_proto.spec.feature_transformation.user_defined_function.body_text
             != ""
+            and on_demand_feature_view_proto.spec.mode == "pandas"
         ):
             transformation = PandasTransformation.from_proto(
                 on_demand_feature_view_proto.spec.feature_transformation.user_defined_function
             )
+        elif (
+            on_demand_feature_view_proto.spec.feature_transformation.WhichOneof(
+                "transformation"
+            )
+            == "user_defined_function"
+            and on_demand_feature_view_proto.spec.feature_transformation.user_defined_function.body_text
+            != ""
+            and on_demand_feature_view_proto.spec.mode == "python"
+        ):
+            transformation = PythonTransformation.from_proto(
+                on_demand_feature_view_proto.spec.feature_transformation.user_defined_function
+            )
         elif (
             on_demand_feature_view_proto.spec.feature_transformation.WhichOneof(
                 "transformation"
diff --git a/sdk/python/tests/example_repos/example_feature_repo_1.py b/sdk/python/tests/example_repos/example_feature_repo_1.py
index eca9aee57c9..fbf1fbb9b07 100644
--- a/sdk/python/tests/example_repos/example_feature_repo_1.py
+++ b/sdk/python/tests/example_repos/example_feature_repo_1.py
@@ -1,6 +1,9 @@
 from datetime import timedelta
 
+import pandas as pd
+
 from feast import Entity, FeatureService, FeatureView, Field, FileSource, PushSource
+from feast.on_demand_feature_view import on_demand_feature_view
 from feast.types import Float32, Int64, String
 
 # Note that file source paths are not validated, so there doesn't actually need to be any data
@@ -99,6 +102,18 @@
 )
 
 
+# Pandas-mode on demand feature view: on_demand_age is the "age" input plus one.
+@on_demand_feature_view(
+    sources=[customer_profile],
+    schema=[Field(name="on_demand_age", dtype=Int64)],
+    mode="pandas",
+)
+def customer_profile_pandas_odfv(inputs: pd.DataFrame) -> pd.DataFrame:
+    outputs = pd.DataFrame()
+    outputs["on_demand_age"] = inputs["age"] + 1
+    return outputs
+
+
 all_drivers_feature_service = FeatureService(
     name="driver_locations_service",
     features=[driver_locations],
diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py
index 926c7226fc8..6b8c8b0b981 100644
--- a/sdk/python/tests/unit/online_store/test_online_retrieval.py
+++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py
@@ -124,6 +124,17 @@ def test_online() -> None:
 
         assert "trips" in result
 
+        result = store.get_online_features(
+            features=["customer_profile_pandas_odfv:on_demand_age"],
+            entity_rows=[{"driver_id": 1, "customer_id": "5"}],
+            full_feature_names=False,
+        ).to_dict()
+
+        assert "on_demand_age" in result
+        assert result["driver_id"] == [1]
+        assert result["customer_id"] == ["5"]
+        assert result["on_demand_age"] == [4]
+
         # invalid table reference
         with pytest.raises(FeatureViewNotFoundException):
             store.get_online_features(
diff --git a/sdk/python/tests/unit/test_on_demand_pandas_transformation.py b/sdk/python/tests/unit/test_on_demand_pandas_transformation.py
new file mode 100644
index 00000000000..c5f066dd83d
--- /dev/null
+++ b/sdk/python/tests/unit/test_on_demand_pandas_transformation.py
@@ -0,0 +1,94 @@
+import os
+import tempfile
+from datetime import datetime, timedelta
+
+import pandas as pd
+
+from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
+from feast.driver_test_data import create_driver_hourly_stats_df
+from feast.field import Field
+from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
+from feast.on_demand_feature_view import on_demand_feature_view
+from feast.types import Float32, Float64, Int64
+
+
+def test_pandas_transformation():
+    """Check a pandas-mode on demand feature view against the online store."""
+    with tempfile.TemporaryDirectory() as data_dir:
+        store = FeatureStore(
+            config=RepoConfig(
+                project="test_on_demand_python_transformation",
+                registry=os.path.join(data_dir, "registry.db"),
+                provider="local",
+                entity_key_serialization_version=2,
+                online_store=SqliteOnlineStoreConfig(
+                    path=os.path.join(data_dir, "online.db")
+                ),
+            )
+        )
+
+        # Generate test data.
+        end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
+        start_date = end_date - timedelta(days=15)
+
+        driver_entities = [1001, 1002, 1003, 1004, 1005]
+        driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
+        driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
+        driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)
+
+        driver = Entity(name="driver", join_keys=["driver_id"])
+
+        driver_stats_source = FileSource(
+            name="driver_hourly_stats_source",
+            path=driver_stats_path,
+            timestamp_field="event_timestamp",
+            created_timestamp_column="created",
+        )
+
+        driver_stats_fv = FeatureView(
+            name="driver_hourly_stats",
+            entities=[driver],
+            ttl=timedelta(days=0),
+            schema=[
+                Field(name="conv_rate", dtype=Float32),
+                Field(name="acc_rate", dtype=Float32),
+                Field(name="avg_daily_trips", dtype=Int64),
+            ],
+            online=True,
+            source=driver_stats_source,
+        )
+
+        @on_demand_feature_view(
+            sources=[driver_stats_fv],
+            schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
+            mode="pandas",
+        )
+        def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
+            df = pd.DataFrame()
+            df["conv_rate_plus_acc"] = inputs["conv_rate"] + inputs["acc_rate"]
+            return df
+
+        store.apply([driver, driver_stats_source, driver_stats_fv, pandas_view])
+
+        entity_rows = [
+            {
+                "driver_id": 1001,
+            }
+        ]
+        store.write_to_online_store(
+            feature_view_name="driver_hourly_stats", df=driver_df
+        )
+
+        online_response = store.get_online_features(
+            entity_rows=entity_rows,
+            features=[
+                "driver_hourly_stats:conv_rate",
+                "driver_hourly_stats:acc_rate",
+                "driver_hourly_stats:avg_daily_trips",
+                "pandas_view:conv_rate_plus_acc",
+            ],
+        ).to_df()
+
+        assert online_response["conv_rate_plus_acc"].equals(
+            online_response["conv_rate"] + online_response["acc_rate"]
+        )
diff --git a/sdk/python/tests/unit/test_on_demand_python_transformation.py b/sdk/python/tests/unit/test_on_demand_python_transformation.py
new file mode 100644
index 00000000000..4913b6c1b1d
--- /dev/null
+++ b/sdk/python/tests/unit/test_on_demand_python_transformation.py
@@ -0,0 +1,173 @@
+import os
+import tempfile
+import unittest
+from datetime import datetime, timedelta
+from typing import Any, Dict
+
+import pandas as pd
+import pytest
+
+from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
+from feast.driver_test_data import create_driver_hourly_stats_df
+from feast.field import Field
+from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
+from feast.on_demand_feature_view import on_demand_feature_view
+from feast.types import Float32, Float64, Int64
+
+
+class TestOnDemandPythonTransformation(unittest.TestCase):
+    """Compare python-mode and pandas-mode on demand feature views online."""
+
+    def setUp(self):
+        """Apply the feature views and write driver stats to the online store."""
+        # The registry and the SQLite online store live in this directory, so it
+        # has to outlive setUp; it is removed only after the test has finished.
+        tmp_dir = tempfile.TemporaryDirectory()
+        self.addCleanup(tmp_dir.cleanup)
+        data_dir = tmp_dir.name
+
+        self.store = FeatureStore(
+            config=RepoConfig(
+                project="test_on_demand_python_transformation",
+                registry=os.path.join(data_dir, "registry.db"),
+                provider="local",
+                entity_key_serialization_version=2,
+                online_store=SqliteOnlineStoreConfig(
+                    path=os.path.join(data_dir, "online.db")
+                ),
+            )
+        )
+
+        # Generate test data.
+        end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
+        start_date = end_date - timedelta(days=15)
+
+        driver_entities = [1001, 1002, 1003, 1004, 1005]
+        driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
+        driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
+        driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)
+
+        driver = Entity(name="driver", join_keys=["driver_id"])
+
+        driver_stats_source = FileSource(
+            name="driver_hourly_stats_source",
+            path=driver_stats_path,
+            timestamp_field="event_timestamp",
+            created_timestamp_column="created",
+        )
+
+        driver_stats_fv = FeatureView(
+            name="driver_hourly_stats",
+            entities=[driver],
+            ttl=timedelta(days=0),
+            schema=[
+                Field(name="conv_rate", dtype=Float32),
+                Field(name="acc_rate", dtype=Float32),
+                Field(name="avg_daily_trips", dtype=Int64),
+            ],
+            online=True,
+            source=driver_stats_source,
+        )
+
+        @on_demand_feature_view(
+            sources=[driver_stats_fv],
+            schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)],
+            mode="pandas",
+        )
+        def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
+            df = pd.DataFrame()
+            df["conv_rate_plus_acc_pandas"] = inputs["conv_rate"] + inputs["acc_rate"]
+            return df
+
+        @on_demand_feature_view(
+            sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
+            schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)],
+            mode="python",
+        )
+        def python_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
+            output: Dict[str, Any] = {
+                "conv_rate_plus_acc_python": [
+                    conv_rate + acc_rate
+                    for conv_rate, acc_rate in zip(
+                        inputs["conv_rate"], inputs["acc_rate"]
+                    )
+                ]
+            }
+            return output
+
+        @on_demand_feature_view(
+            sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
+            schema=[Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64)],
+            mode="python",
+        )
+        def python_singleton_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
+            output: Dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf"))
+            output["conv_rate_plus_acc_python_singleton"] = (
+                inputs["conv_rate"] + inputs["acc_rate"]
+            )
+            return output
+
+        with pytest.raises(TypeError):
+            # Note the singleton view will fail as the type is
+            # expected to be a List which can be confirmed in _infer_features_dict
+            self.store.apply(
+                [
+                    driver,
+                    driver_stats_source,
+                    driver_stats_fv,
+                    pandas_view,
+                    python_view,
+                    python_singleton_view,
+                ]
+            )
+
+        self.store.apply(
+            [driver, driver_stats_source, driver_stats_fv, pandas_view, python_view]
+        )
+        self.store.write_to_online_store(
+            feature_view_name="driver_hourly_stats", df=driver_df
+        )
+
+    def test_python_pandas_parity(self):
+        """Check both views return conv_rate + acc_rate for the same driver."""
+        entity_rows = [
+            {
+                "driver_id": 1001,
+            }
+        ]
+
+        online_python_response = self.store.get_online_features(
+            entity_rows=entity_rows,
+            features=[
+                "driver_hourly_stats:conv_rate",
+                "driver_hourly_stats:acc_rate",
+                "python_view:conv_rate_plus_acc_python",
+            ],
+        ).to_dict()
+
+        online_pandas_response = self.store.get_online_features(
+            entity_rows=entity_rows,
+            features=[
+                "driver_hourly_stats:conv_rate",
+                "driver_hourly_stats:acc_rate",
+                "pandas_view:conv_rate_plus_acc_pandas",
+            ],
+        ).to_df()
+
+        assert len(online_python_response) == 4
+        assert all(
+            key in online_python_response.keys()
+            for key in [
+                "driver_id",
+                "acc_rate",
+                "conv_rate",
+                "conv_rate_plus_acc_python",
+            ]
+        )
+        assert len(online_python_response["conv_rate_plus_acc_python"]) == 1
+        assert (
+            online_python_response["conv_rate_plus_acc_python"][0]
+            == online_pandas_response["conv_rate_plus_acc_pandas"][0]
+            == online_python_response["conv_rate"][0]
+            + online_python_response["acc_rate"][0]
+        )