Skip to content

Commit d71b452

Browse files
authored
Add to_df() to convert get_online_feature response to pandas dataframe (#1623)
* Add to_df() to convert get_online_feature response to pandas dataframe
  Signed-off-by: ted chang <htchang@us.ibm.com>
* Add dataframe column and row consistency test
  Signed-off-by: ted chang <htchang@us.ibm.com>
1 parent b32e766 commit d71b452

File tree

3 files changed

+172
-1
lines changed

3 files changed

+172
-1
lines changed

sdk/python/feast/feature_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def get_online_features(
474474
475475
Note: This method will download the full feature registry the first time it is run. If you are using a
476476
remote registry like GCS or S3 then that may take a few seconds. The registry remains cached up to a TTL
477-
duration (which can be set to infinitey). If the cached registry is stale (more time than the TTL has
477+
duration (which can be set to infinity). If the cached registry is stale (more time than the TTL has
478478
passed), then a new registry will be downloaded synchronously by this method. This download may
479479
introduce latency to online feature retrieval. In order to avoid synchronous downloads, please call
480480
refresh_registry() prior to the TTL being reached. Remember it is possible to set the cache TTL to

sdk/python/feast/online_response.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from typing import Any, Dict, List, cast
1616

17+
import pandas as pd
18+
1719
from feast.protos.feast.serving.ServingService_pb2 import (
1820
GetOnlineFeaturesRequestV2,
1921
GetOnlineFeaturesResponse,
@@ -61,6 +63,13 @@ def to_dict(self) -> Dict[str, Any]:
6163

6264
return features_dict
6365

66+
def to_df(self) -> pd.DataFrame:
    """
    Return the online features of this response as a pandas DataFrame.

    The frame is built directly from the mapping produced by ``to_dict()``:
    each key becomes a column and its values become that column's rows.
    """
    features_by_name = self.to_dict()
    return pd.DataFrame(features_by_name)
72+
6473

6574
def _infer_online_entity_rows(
6675
entity_rows: List[Dict[str, Any]]

sdk/python/tests/test_online_retrieval.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import time
33
from datetime import datetime
44

5+
import pandas as pd
56
import pytest
7+
from pandas.testing import assert_frame_equal
68

79
from feast import FeatureStore, RepoConfig
810
from feast.errors import FeatureViewNotFoundException
@@ -234,3 +236,163 @@ def test_online() -> None:
234236

235237
# Restore registry.db so that teardown works
236238
os.rename(store.config.registry + "_fake", store.config.registry)
239+
240+
241+
def test_online_to_df():
    """
    Check the DataFrame produced by ``get_online_features(...).to_df()``.

    Three feature views are populated in the online store, the features are
    requested with the entity rows in reverse order, and the resulting frame
    (with its columns selected in the requested order) is compared against
    an expected frame whose rows follow the request order.
    """
    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name = "foo"
    lon_multiply = 1.0
    lat_multiply = 0.1
    age_multiply = 10
    avg_order_day_multiply = 1.0

    runner = CliRunner()
    repo_path = get_example_repo("example_feature_repo_1.py")
    with runner.local_repo(repo_path, "bigquery") as store:
        # Look up the three tables that get written to the online store.
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined"
        )
        provider = store._get_provider()

        def write_row(table, entity_key, values):
            # Write a single row for ``entity_key`` into ``table``; the two
            # timestamp slots of the row tuple are both set to "now".
            provider.online_write_batch(
                config=store.config,
                table=table,
                data=[
                    (
                        entity_key,
                        values,
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

        for d, c in zip(driver_ids, customer_ids):
            # driver table:
            # driver  driver_locations__lon  driver_locations__lat
            #      1                    1.0                    0.1
            #      2                    2.0                    0.2
            #      3                    3.0                    0.3
            driver_key = EntityKeyProto(
                join_keys=["driver"], entity_values=[ValueProto(int64_val=d)]
            )
            write_row(
                driver_locations_fv,
                driver_key,
                {
                    "lat": ValueProto(double_val=d * lat_multiply),
                    "lon": ValueProto(string_val=str(d * lon_multiply)),
                },
            )

            # customer table:
            # customer  customer_profile__avg_orders_day  customer_profile__name  customer_profile__age
            #        4                               4.0                    foo4                     40
            #        5                               5.0                    foo5                     50
            #        6                               6.0                    foo6                     60
            customer_key = EntityKeyProto(
                join_keys=["customer"], entity_values=[ValueProto(int64_val=c)]
            )
            write_row(
                customer_profile_fv,
                customer_key,
                {
                    "avg_orders_day": ValueProto(
                        float_val=c * avg_order_day_multiply
                    ),
                    "name": ValueProto(string_val=name + str(c)),
                    "age": ValueProto(int64_val=c * age_multiply),
                },
            )

            # customer_driver_combined table:
            # customer  driver  customer_driver_combined__trips
            #        4       1                                4
            #        5       2                               10
            #        6       3                               18
            combo_keys = EntityKeyProto(
                join_keys=["customer", "driver"],
                entity_values=[ValueProto(int64_val=c), ValueProto(int64_val=d)],
            )
            write_row(
                customer_driver_combined_fv,
                combo_keys,
                {"trips": ValueProto(int64_val=c * d)},
            )

        # Request the entity rows in reverse order relative to how they were
        # written.
        reversed_entity_rows = [
            {"driver": d, "customer": c}
            for d, c in zip(driver_ids[::-1], customer_ids[::-1])
        ]
        online_response = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            entity_rows=reversed_entity_rows,
        )
        result_df = online_response.to_df()

        # Expected frame, rows in the (reversed) request order:
        # driver  customer  driver_locations__lon  driver_locations__lat  customer_profile__avg_orders_day  customer_profile__name  customer_profile__age  customer_driver_combined__trips
        #      3         6                    3.0                    0.3                               6.0                    foo6                     60                               18
        #      2         5                    2.0                    0.2                               5.0                    foo5                     50                               10
        #      1         4                    1.0                    0.1                               4.0                    foo4                     40                                4
        columns_in_write_order = {
            "driver": driver_ids,
            "customer": customer_ids,
            "driver_locations__lon": [str(d * lon_multiply) for d in driver_ids],
            "driver_locations__lat": [d * lat_multiply for d in driver_ids],
            "customer_profile__avg_orders_day": [
                c * avg_order_day_multiply for c in customer_ids
            ],
            "customer_profile__name": [name + str(c) for c in customer_ids],
            "customer_profile__age": [c * age_multiply for c in customer_ids],
            "customer_driver_combined__trips": [
                d * c for d, c in zip(driver_ids, customer_ids)
            ],
        }
        # Requested column order
        requested_columns = [
            "driver",
            "customer",
            "driver_locations__lon",
            "driver_locations__lat",
            "customer_profile__avg_orders_day",
            "customer_profile__name",
            "customer_profile__age",
            "customer_driver_combined__trips",
        ]
        expected_df = pd.DataFrame(
            {
                column: values[::-1]
                for column, values in columns_in_write_order.items()
            }
        )
        assert_frame_equal(result_df[requested_columns], expected_df)

0 commit comments

Comments
 (0)