diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 8cc75ade445..ea09362299d 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -13,6 +13,7 @@ # limitations under the License. import json import logging +from collections import defaultdict from datetime import datetime from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple @@ -20,12 +21,16 @@ from pydantic import StrictStr from feast import Entity, FeatureView, RepoConfig +from feast.infra.online_stores.helpers import _to_naive_utc from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel from feast.rest_error_handler import rest_error_handling_decorator -from feast.type_map import python_values_to_proto_values +from feast.type_map import ( + feast_value_type_to_python_type, + python_values_to_proto_values, +) from feast.value_type import ValueType logger = logging.getLogger(__name__) @@ -60,7 +65,55 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: - raise NotImplementedError + """ + Writes a batch of feature rows to the remote online store via the remote API. + """ + assert isinstance(config.online_store, RemoteOnlineStoreConfig) + config.online_store.__class__ = RemoteOnlineStoreConfig + + columnar_data: Dict[str, List[Any]] = defaultdict(list) + + # Iterate through each row to populate columnar data directly + for entity_key_proto, feature_values_proto, event_ts, created_ts in data: + # Populate entity key values + for join_key, entity_value_proto in zip( + entity_key_proto.join_keys, entity_key_proto.entity_values + ): + columnar_data[join_key].append( + feast_value_type_to_python_type(entity_value_proto) + ) + + # Populate feature values + for feature_name, feature_value_proto in feature_values_proto.items(): + columnar_data[feature_name].append( + feast_value_type_to_python_type(feature_value_proto) + ) + + # Populate timestamps + columnar_data["event_timestamp"].append(_to_naive_utc(event_ts).isoformat()) + columnar_data["created"].append( + _to_naive_utc(created_ts).isoformat() if created_ts else None + ) + + req_body = { + "feature_view_name": table.name, + "df": columnar_data, + "allow_registry_cache": False, + } + + response = post_remote_online_write(config=config, req_body=req_body) + + if response.status_code != 200: + error_msg = f"Unable to write online store data using feature server API. Error_code={response.status_code}, error_message={response.text}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + if progress: + data_length = len(data) + logger.info( + f"Writing {data_length} rows to the remote store for feature view {table.name}." + ) + progress(data_length) def online_read( self, @@ -184,3 +237,14 @@ def get_remote_online_features( return session.post( f"{config.online_store.path}/get-online-features", data=req_body ) + + +@rest_error_handling_decorator +def post_remote_online_write( + session: requests.Session, config: RepoConfig, req_body: dict +) -> requests.Response: + url = f"{config.online_store.path}/write-to-online-store" + if config.online_store.cert: + return session.post(url, json=req_body, verify=config.online_store.cert) + else: + return session.post(url, json=req_body) diff --git a/sdk/python/tests/integration/online_store/test_remote_online_store.py b/sdk/python/tests/integration/online_store/test_remote_online_store.py index eb03fd0c3c5..3b5b707dcb7 100644 --- a/sdk/python/tests/integration/online_store/test_remote_online_store.py +++ b/sdk/python/tests/integration/online_store/test_remote_online_store.py @@ -1,15 +1,28 @@ import logging import os import tempfile +from datetime import timedelta from textwrap import dedent +import pandas as pd import pytest -from feast import FeatureView, OnDemandFeatureView, StreamFeatureView +from feast import ( + Entity, + FeatureView, + Field, + FileSource, + OnDemandFeatureView, + PushSource, + StreamFeatureView, +) +from feast.data_source import PushMode from feast.feature_store import FeatureStore from feast.permissions.action import AuthzedAction from feast.permissions.permission import Permission from feast.permissions.policy import RoleBasedPolicy +from feast.types import Float32, Int64 +from feast.utils import _utc_now from tests.utils.auth_permissions_util import ( PROJECT_NAME, default_store, @@ -235,7 +248,6 @@ def _create_remote_client_feature_store( if is_tls_mode and ca_trust_store_path: # configure trust store path only when is_tls_mode and ca_trust_store_path exists. os.environ["FEAST_CA_CERT_FILE_PATH"] = ca_trust_store_path - return FeatureStore(repo_path=repo_path) @@ -265,3 +277,139 @@ def _overwrite_remote_client_feature_store_yaml( with open(repo_config, "w") as repo_config_file: repo_config_file.write(config_content) + + +@pytest.mark.integration +@pytest.mark.rbac_remote_integration_test +@pytest.mark.parametrize( + "tls_mode", [("True", "True"), ("True", "False"), ("False", "")], indirect=True +) +def test_remote_online_store_read_write(auth_config, tls_mode): + with ( + tempfile.TemporaryDirectory() as remote_server_tmp_dir, + tempfile.TemporaryDirectory() as remote_client_tmp_dir, + ): + permissions_list = [ + Permission( + name="online_list_fv_perm", + types=FeatureView, + policy=RoleBasedPolicy(roles=["reader"]), + actions=[AuthzedAction.READ_ONLINE], + ), + Permission( + name="online_list_odfv_perm", + types=OnDemandFeatureView, + policy=RoleBasedPolicy(roles=["reader"]), + actions=[AuthzedAction.READ_ONLINE], + ), + Permission( + name="online_list_sfv_perm", + types=StreamFeatureView, + policy=RoleBasedPolicy(roles=["reader"]), + actions=[AuthzedAction.READ_ONLINE], + ), + Permission( + name="online_write_fv_perm", + types=FeatureView, + policy=RoleBasedPolicy(roles=["writer"]), + actions=[AuthzedAction.WRITE_ONLINE], + ), + Permission( + name="online_write_odfv_perm", + types=OnDemandFeatureView, + policy=RoleBasedPolicy(roles=["writer"]), + actions=[AuthzedAction.WRITE_ONLINE], + ), + Permission( + name="online_write_sfv_perm", + types=StreamFeatureView, + policy=RoleBasedPolicy(roles=["writer"]), + actions=[AuthzedAction.WRITE_ONLINE], + ), + ] + server_store, server_url, registry_path = ( + _create_server_store_spin_feature_server( + temp_dir=remote_server_tmp_dir, + auth_config=auth_config, + permissions_list=permissions_list, + tls_mode=tls_mode, + ) + ) + assert None not in (server_store, server_url, registry_path) + + client_store = _create_remote_client_feature_store( + temp_dir=remote_client_tmp_dir, + server_registry_path=str(registry_path), + feature_server_url=server_url, + auth_config=auth_config, + tls_mode=tls_mode, + ) + assert client_store is not None + + # Define a simple FeatureView for testing write operations + driver = Entity(name="driver_id", description="Drivers id") + + driver_hourly_stats_source = FileSource( + path="data/driver_stats.parquet", # Path is not used for online writes in this context + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + + PushSource( + name="driver_stats_push_source", + batch_source=driver_hourly_stats_source, + ) + + driver_hourly_stats_fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=1), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + source=driver_hourly_stats_source, + tags={}, + ) + + # Apply the feature view to the client store + client_store.apply([driver, driver_hourly_stats_fv]) + event_df = pd.DataFrame( + { + "driver_id": [1000, 1001], + "conv_rate": [0.56, 0.74], + "acc_rate": [0.95, 0.93], + "avg_daily_trips": [50, 45], + "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")] * 2, + "created": [pd.Timestamp(_utc_now()).round("ms")] * 2, + } + ) + + # Perform the online write + client_store.push( + push_source_name="driver_stats_push_source", df=event_df, to=PushMode.ONLINE + ) + + # Verify the data by reading it back + # read_entity_keys = [entity_key_1, entity_key_2] + read_features = [ + "driver_hourly_stats_fresh:conv_rate", + "driver_hourly_stats_fresh:acc_rate", + "driver_hourly_stats_fresh:avg_daily_trips", + ] + online_features = client_store.get_online_features( + features=read_features, + entity_rows=[{"driver_id": 1000}, {"driver_id": 1001}], + ).to_dict() + + # Assertions for read data + assert online_features is not None + assert len(online_features["driver_id"]) == 2 + assert online_features["driver_id"] == [1000, 1001] + assert [round(val, 2) for val in online_features["conv_rate"]] == [0.56, 0.74] + assert [round(val, 2) for val in online_features["acc_rate"]] == [0.95, 0.93] + assert online_features["avg_daily_trips"] == [50, 45] + + # Clean up the applied feature view from the server store to avoid interference with other tests + server_store.teardown()