Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def get_online_features(

Note: This method will download the full feature registry the first time it is run. If you are using a
remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
duration (which can be set to infinitey). If the cached registry is stale (more time than the TTL has
duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
passed), then a new registry will be downloaded synchronously by this method. This download may
introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to
Expand Down
9 changes: 9 additions & 0 deletions sdk/python/feast/online_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import Any, Dict, List, cast

import pandas as pd

from feast.protos.feast.serving.ServingService_pb2 import (
GetOnlineFeaturesRequestV2,
GetOnlineFeaturesResponse,
Expand Down Expand Up @@ -61,6 +63,13 @@ def to_dict(self) -> Dict[str, Any]:

return features_dict

def to_df(self) -> pd.DataFrame:
    """
    Converts the GetOnlineFeaturesResponse features into a pandas DataFrame.

    Column order and row order follow the dict produced by ``to_dict()``.
    """

    feature_values = self.to_dict()
    return pd.DataFrame(feature_values)


def _infer_online_entity_rows(
entity_rows: List[Dict[str, Any]]
Expand Down
162 changes: 162 additions & 0 deletions sdk/python/tests/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import time
from datetime import datetime

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from feast import FeatureStore, RepoConfig
from feast.errors import FeatureViewNotFoundException
Expand Down Expand Up @@ -234,3 +236,163 @@ def test_online() -> None:

# Restore registry.db so that teardown works
os.rename(store.config.registry + "_fake", store.config.registry)


def test_online_to_df():
    """
    Test dataframe conversion. Make sure the response columns and rows are
    the same order as the request.
    """

    # Parallel entity id lists: row i pairs driver_ids[i] with customer_ids[i].
    # The multipliers below derive deterministic feature values from the ids,
    # so the expected dataframe can be reconstructed arithmetically.
    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name = "foo"
    lon_multiply = 1.0
    lat_multiply = 0.1
    age_multiply = 10
    avg_order_day_multiply = 1.0

    runner = CliRunner()
    with runner.local_repo(
        get_example_repo("example_feature_repo_1.py"), "bigquery"
    ) as store:
        # Write three tables to online store
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined"
        )
        provider = store._get_provider()

        # Write one row per (driver, customer) pair into each of the three
        # feature views via the provider's online store.
        for (d, c) in zip(driver_ids, customer_ids):
            """
            driver table:
            driver driver_locations__lon driver_locations__lat
            1 1.0 0.1
            2 2.0 0.2
            3 3.0 0.3
            """
            driver_key = EntityKeyProto(
                join_keys=["driver"], entity_values=[ValueProto(int64_val=d)]
            )
            provider.online_write_batch(
                config=store.config,
                table=driver_locations_fv,
                data=[
                    (
                        driver_key,
                        {
                            # NOTE: lat is written as a double but lon as a
                            # *string*; the expected dataframe below mirrors
                            # this with str(d * lon_multiply).
                            "lat": ValueProto(double_val=d * lat_multiply),
                            "lon": ValueProto(string_val=str(d * lon_multiply)),
                        },
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

            """
            customer table
            customer customer_profile__avg_orders_day customer_profile__name customer_profile__age
            4 4.0 foo4 40
            5 5.0 foo5 50
            6 6.0 foo6 60
            """
            customer_key = EntityKeyProto(
                join_keys=["customer"], entity_values=[ValueProto(int64_val=c)]
            )
            provider.online_write_batch(
                config=store.config,
                table=customer_profile_fv,
                data=[
                    (
                        customer_key,
                        {
                            "avg_orders_day": ValueProto(
                                float_val=c * avg_order_day_multiply
                            ),
                            "name": ValueProto(string_val=name + str(c)),
                            "age": ValueProto(int64_val=c * age_multiply),
                        },
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )
            """
            customer_driver_combined table
            customer driver customer_driver_combined__trips
            4 1 4
            5 2 10
            6 3 18
            """
            # Composite entity key: this view is joined on both customer and
            # driver.
            combo_keys = EntityKeyProto(
                join_keys=["customer", "driver"],
                entity_values=[ValueProto(int64_val=c), ValueProto(int64_val=d)],
            )
            provider.online_write_batch(
                config=store.config,
                table=customer_driver_combined_fv,
                data=[
                    (
                        combo_keys,
                        {"trips": ValueProto(int64_val=c * d)},
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

        # Get online features in dataframe
        result_df = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            # Reverse the row order
            entity_rows=[
                {"driver": d, "customer": c}
                for (d, c) in zip(reversed(driver_ids), reversed(customer_ids))
            ],
        ).to_df()
        """
        Construct the expected dataframe with reversed row order like so:
        driver customer driver_locations__lon driver_locations__lat customer_profile__avg_orders_day customer_profile__name customer_profile__age customer_driver_combined__trips
        3 6 3.0 0.3 6.0 foo6 60 18
        2 5 2.0 0.2 5.0 foo5 50 10
        1 4 1.0 0.1 4.0 foo4 40 4
        """
        # Columns are built in ascending id order here; reversed() below flips
        # each column so the expected rows match the reversed entity_rows order
        # used in the request.
        df_dict = {
            "driver": driver_ids,
            "customer": customer_ids,
            "driver_locations__lon": [str(d * lon_multiply) for d in driver_ids],
            "driver_locations__lat": [d * lat_multiply for d in driver_ids],
            "customer_profile__avg_orders_day": [
                c * avg_order_day_multiply for c in customer_ids
            ],
            "customer_profile__name": [name + str(c) for c in customer_ids],
            "customer_profile__age": [c * age_multiply for c in customer_ids],
            "customer_driver_combined__trips": [
                d * c for (d, c) in zip(driver_ids, customer_ids)
            ],
        }
        # Requested column order
        ordered_column = [
            "driver",
            "customer",
            "driver_locations__lon",
            "driver_locations__lat",
            "customer_profile__avg_orders_day",
            "customer_profile__name",
            "customer_profile__age",
            "customer_driver_combined__trips",
        ]
        expected_df = pd.DataFrame({k: reversed(v) for (k, v) in df_dict.items()})
        # Reindex the result to the requested column order before comparing,
        # so the assertion checks both values and row ordering.
        assert_frame_equal(result_df[ordered_column], expected_df)