diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index fc4517281d..7cec0b6747 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1267,6 +1267,7 @@ def get_historical_features( job = provider.get_historical_features( self.config, feature_views, + on_demand_feature_views, _feature_refs, entity_df, self._registry, diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 2f960a0282..5e0fc413cd 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -459,6 +459,7 @@ def get_historical_features( self, config: RepoConfig, feature_views: List[Union[FeatureView, OnDemandFeatureView]], + on_demand_feature_views: List[OnDemandFeatureView], feature_refs: List[str], entity_df: Optional[Union[pd.DataFrame, str]], registry: BaseRegistry, @@ -466,6 +467,7 @@ def get_historical_features( full_feature_names: bool, **kwargs, ) -> RetrievalJob: + del on_demand_feature_views job = self.offline_store.get_historical_features( config=config, feature_views=feature_views, diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index c2879c1e2d..58d44ccaeb 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -250,6 +250,7 @@ def get_historical_features( self, config: RepoConfig, feature_views: List[Union[FeatureView, OnDemandFeatureView]], + on_demand_feature_views: List[OnDemandFeatureView], feature_refs: List[str], entity_df: Optional[Union[pd.DataFrame, str]], registry: BaseRegistry, @@ -263,6 +264,7 @@ def get_historical_features( Args: config: The config for the current feature store. feature_views: A list containing all feature views that are referenced in the entity rows. + on_demand_feature_views: The on demand feature views requested as part of the retrieval. feature_refs: The features to be retrieved. entity_df: A collection of rows containing all entity columns on which features need to be joined, as well as the timestamp column used for point-in-time joins. Either a pandas dataframe can be diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index a04ff3cc45..a528013056 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -26,6 +26,7 @@ ProviderAsyncMethods, SupportedAsyncMethods, ) +from feast.on_demand_feature_view import OnDemandFeatureView from feast.online_response import OnlineResponse from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import RepeatedValue @@ -98,12 +99,14 @@ def get_historical_features( self, config: RepoConfig, feature_views: List[FeatureView], + on_demand_feature_views: List[OnDemandFeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: + del on_demand_feature_views return RetrievalJob() def online_read( diff --git a/sdk/python/tests/unit/test_feature_store_passes_odfvs_to_provider.py b/sdk/python/tests/unit/test_feature_store_passes_odfvs_to_provider.py new file mode 100644 index 0000000000..bdcf404c5b --- /dev/null +++ b/sdk/python/tests/unit/test_feature_store_passes_odfvs_to_provider.py @@ -0,0 +1,63 @@ +from types import SimpleNamespace + +from feast.feature_store import FeatureStore +from feast.infra.offline_stores.offline_store import RetrievalJob + + +class _CapturingProvider: + def __init__(self): + self.captured_on_demand_feature_views = None + + def get_historical_features( + self, + config, + feature_views, + on_demand_feature_views, + feature_refs, + entity_df, + registry, + project, + full_feature_names, + **kwargs, + ): + self.captured_on_demand_feature_views = on_demand_feature_views + return RetrievalJob() + + +def test_feature_store_passes_on_demand_feature_views_to_provider(monkeypatch): + provider = _CapturingProvider() + + store = FeatureStore.__new__(FeatureStore) + store.config = SimpleNamespace(project="test_project", coerce_tz_aware=False) + store._registry = object() + store._get_provider = lambda: provider + + fv = object() + odfv = object() + + monkeypatch.setattr( + "feast.utils._get_features", + lambda registry, project, features, allow_cache=False: ["odfv:feat"], + ) + monkeypatch.setattr( + "feast.utils._get_feature_views_to_use", + lambda registry, project, features, allow_cache=False, hide_dummy_entity=True: ( + [fv], + [odfv], + ), + ) + monkeypatch.setattr( + "feast.utils._group_feature_refs", + lambda feature_refs, feature_views, on_demand_feature_views: ( + [(fv, [])], + [(odfv, [])], + ), + ) + monkeypatch.setattr( + "feast.utils._validate_feature_refs", + lambda *args, **kwargs: None, + ) + + store.get_historical_features(entity_df="SELECT 1", features=["odfv:feat"]) + + assert provider.captured_on_demand_feature_views == [odfv]