Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ def get_historical_features(
if end_date is not None:
api_parameters["end_date"] = end_date.isoformat()

if isinstance(entity_df, str):
api_parameters["entity_df_sql"] = entity_df
entity_df = None

return RemoteRetrievalJob(
client=client,
api=OfflineStore.get_historical_features.__name__,
Expand Down Expand Up @@ -470,12 +474,12 @@ def get_table_column_names_and_types_from_data_source(


def _create_retrieval_metadata(
feature_refs: List[str], entity_df: Optional[pd.DataFrame] = None
feature_refs: List[str], entity_df: Optional[Union[pd.DataFrame, str]] = None
):
if entity_df is None:
if entity_df is None or isinstance(entity_df, str):
return RetrievalMetadata(
features=feature_refs,
keys=[], # No entity keys when no entity_df provided
keys=[],
min_event_timestamp=None,
max_event_timestamp=None,
)
Expand Down
4 changes: 4 additions & 0 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,10 @@ def get_historical_features(self, command: dict, key: Optional[str] = None):
if len(entity_df.columns) == 1 and "key" in entity_df.columns:
entity_df = None

# If the client sent a SQL string, use it directly
if entity_df is None and "entity_df_sql" in command:
entity_df = command["entity_df_sql"]

feature_view_names = command["feature_view_names"]
name_aliases = command["name_aliases"]
feature_refs = command["feature_refs"]
Expand Down
84 changes: 83 additions & 1 deletion sdk/python/tests/unit/test_offline_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import tempfile
from datetime import datetime, timedelta
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import assertpy
import pandas as pd
Expand All @@ -15,6 +15,7 @@
from feast.infra.offline_stores.remote import (
RemoteOfflineStore,
RemoteOfflineStoreConfig,
_create_retrieval_metadata,
)
from feast.offline_server import OfflineServer, _init_auth_manager
from feast.repo_config import RepoConfig
Expand Down Expand Up @@ -348,6 +349,87 @@ def _test_pull_all_from_table_or_query(temp_dir, fs: FeatureStore):
).to_df()


def test_create_retrieval_metadata_with_sql_string():
"""SQL string entity_df should produce a stub with empty keys and no timestamps."""
sql = "SELECT driver_id, event_timestamp FROM driver_stats"
metadata = _create_retrieval_metadata(
feature_refs=["driver_hourly_stats:conv_rate"], entity_df=sql
)
assertpy.assert_that(metadata.features).is_equal_to(
["driver_hourly_stats:conv_rate"]
)
assertpy.assert_that(list(metadata.keys)).is_empty()
assertpy.assert_that(metadata.min_event_timestamp).is_none()
assertpy.assert_that(metadata.max_event_timestamp).is_none()


def test_remote_offline_store_sql_entity_df_routing():
"""RemoteOfflineStore.get_historical_features must move a SQL string into
api_parameters['entity_df_sql'] and pass entity_df=None to RemoteRetrievalJob."""
sql = "SELECT driver_id, event_timestamp FROM driver_stats"

mock_client = MagicMock()
with patch(
"feast.infra.offline_stores.remote.build_arrow_flight_client",
return_value=mock_client,
):
job = RemoteOfflineStore.get_historical_features(
config=MagicMock(
offline_store=RemoteOfflineStoreConfig(
type="remote", host="localhost", port=8815
),
auth_config=MagicMock(type="no_auth"),
),
feature_views=[],
feature_refs=["driver_hourly_stats:conv_rate"],
entity_df=sql,
registry=MagicMock(),
project="test",
full_feature_names=False,
)

assertpy.assert_that(job.entity_df).is_none()
assertpy.assert_that(job.api_parameters).contains_key("entity_df_sql")
assertpy.assert_that(job.api_parameters["entity_df_sql"]).is_equal_to(sql)


def test_offline_server_get_historical_features_passes_sql_to_store():
"""OfflineServer.get_historical_features must forward entity_df_sql from the
command dict as a SQL string to the backing offline store."""
sql = "SELECT driver_id, event_timestamp FROM driver_stats"

mock_job = MagicMock()
mock_offline_store = MagicMock()
mock_offline_store.get_historical_features.return_value = mock_job

mock_store = MagicMock()
mock_store.config.project = "test"

server = MagicMock(spec=OfflineServer)
server.offline_store = mock_offline_store
server.store = mock_store
server.flights = {}
server.list_feature_views_by_name.return_value = []

command = {
"api": "get_historical_features",
"command_id": "abc",
"feature_view_names": [],
"name_aliases": [],
"feature_refs": ["driver_hourly_stats:conv_rate"],
"project": "test",
"full_feature_names": False,
"entity_df_sql": sql,
}

# Call the real method with the mock server as self
result = OfflineServer.get_historical_features(server, command, key=None)

assertpy.assert_that(result).is_equal_to(mock_job)
_, kwargs = mock_offline_store.get_historical_features.call_args
assertpy.assert_that(kwargs["entity_df"]).is_equal_to(sql)


def test_get_feature_view_by_name_propagates_transient_errors():
"""Transient registry errors must not be swallowed and misreported as
FeatureViewNotFoundException."""
Expand Down
Loading