|
2 | 2 | import time |
3 | 3 | from datetime import datetime |
4 | 4 |
|
| 5 | +import pandas as pd |
5 | 6 | import pytest |
| 7 | +from pandas.testing import assert_frame_equal |
6 | 8 |
|
7 | 9 | from feast import FeatureStore, RepoConfig |
8 | 10 | from feast.errors import FeatureViewNotFoundException |
@@ -234,3 +236,163 @@ def test_online() -> None: |
234 | 236 |
|
235 | 237 | # Restore registry.db so that teardown works |
236 | 238 | os.rename(store.config.registry + "_fake", store.config.registry) |
| 239 | + |
| 240 | + |
def test_online_to_df():
    """
    Verify that online features converted with ``to_df()`` come back with one
    row per requested entity row, in the order the entity rows were requested.

    Three feature views are populated through the provider, then read back
    with the entity rows reversed relative to the order they were written in.
    """

    driver_ids = [1, 2, 3]
    customer_ids = [4, 5, 6]
    name_prefix = "foo"
    lon_factor = 1.0
    lat_factor = 0.1
    age_factor = 10
    avg_orders_factor = 1.0

    cli = CliRunner()
    repo_definition = get_example_repo("example_feature_repo_1.py")
    with cli.local_repo(repo_definition, "bigquery") as store:
        # Look up the three feature views that will be written to the online store.
        driver_locations_fv = store.get_feature_view(name="driver_locations")
        customer_profile_fv = store.get_feature_view(name="customer_profile")
        customer_driver_combined_fv = store.get_feature_view(
            name="customer_driver_combined"
        )
        provider = store._get_provider()

        def write_row(feature_view, entity_key, feature_values):
            # Write a single row; the two timestamps are taken separately,
            # one utcnow() call each.
            provider.online_write_batch(
                config=store.config,
                table=feature_view,
                data=[
                    (
                        entity_key,
                        feature_values,
                        datetime.utcnow(),
                        datetime.utcnow(),
                    )
                ],
                progress=None,
            )

        for driver_id, customer_id in zip(driver_ids, customer_ids):
            # driver table (note: lon is written as a string value):
            #   driver  driver_locations__lon  driver_locations__lat
            #   1       1.0                    0.1
            #   2       2.0                    0.2
            #   3       3.0                    0.3
            driver_key = EntityKeyProto(
                join_keys=["driver"],
                entity_values=[ValueProto(int64_val=driver_id)],
            )
            write_row(
                driver_locations_fv,
                driver_key,
                {
                    "lat": ValueProto(double_val=driver_id * lat_factor),
                    "lon": ValueProto(string_val=str(driver_id * lon_factor)),
                },
            )

            # customer table:
            #   customer  customer_profile__avg_orders_day  customer_profile__name  customer_profile__age
            #   4         4.0                               foo4                    40
            #   5         5.0                               foo5                    50
            #   6         6.0                               foo6                    60
            customer_key = EntityKeyProto(
                join_keys=["customer"],
                entity_values=[ValueProto(int64_val=customer_id)],
            )
            write_row(
                customer_profile_fv,
                customer_key,
                {
                    "avg_orders_day": ValueProto(
                        float_val=customer_id * avg_orders_factor
                    ),
                    "name": ValueProto(string_val=name_prefix + str(customer_id)),
                    "age": ValueProto(int64_val=customer_id * age_factor),
                },
            )

            # customer_driver_combined table:
            #   customer  driver  customer_driver_combined__trips
            #   4         1       4
            #   5         2       10
            #   6         3       18
            combined_key = EntityKeyProto(
                join_keys=["customer", "driver"],
                entity_values=[
                    ValueProto(int64_val=customer_id),
                    ValueProto(int64_val=driver_id),
                ],
            )
            write_row(
                customer_driver_combined_fv,
                combined_key,
                {"trips": ValueProto(int64_val=customer_id * driver_id)},
            )

        # Request the entities in the opposite order to the one they were
        # written in, so that a response in write order would be detected.
        reversed_entity_rows = [
            {"driver": driver_id, "customer": customer_id}
            for driver_id, customer_id in zip(
                reversed(driver_ids), reversed(customer_ids)
            )
        ]
        online_response = store.get_online_features(
            feature_refs=[
                "driver_locations:lon",
                "driver_locations:lat",
                "customer_profile:avg_orders_day",
                "customer_profile:name",
                "customer_profile:age",
                "customer_driver_combined:trips",
            ],
            entity_rows=reversed_entity_rows,
        )
        result_df = online_response.to_df()

        # Expected frame, rows in reversed (i.e. requested) order:
        #   driver  customer  ..__lon  ..__lat  ..__avg_orders_day  ..__name  ..__age  ..__trips
        #   3       6         3.0      0.3      6.0                 foo6      60       18
        #   2       5         2.0      0.2      5.0                 foo5      50       10
        #   1       4         1.0      0.1      4.0                 foo4      40       4
        # The columns below are listed in write order and reversed when the
        # frame is built.
        columns_in_write_order = {
            "driver": driver_ids,
            "customer": customer_ids,
            "driver_locations__lon": [
                str(driver_id * lon_factor) for driver_id in driver_ids
            ],
            "driver_locations__lat": [
                driver_id * lat_factor for driver_id in driver_ids
            ],
            "customer_profile__avg_orders_day": [
                customer_id * avg_orders_factor for customer_id in customer_ids
            ],
            "customer_profile__name": [
                name_prefix + str(customer_id) for customer_id in customer_ids
            ],
            "customer_profile__age": [
                customer_id * age_factor for customer_id in customer_ids
            ],
            "customer_driver_combined__trips": [
                driver_id * customer_id
                for driver_id, customer_id in zip(driver_ids, customer_ids)
            ],
        }
        # Entity columns first, then features in the order they were requested.
        ordered_column = [
            "driver",
            "customer",
            "driver_locations__lon",
            "driver_locations__lat",
            "customer_profile__avg_orders_day",
            "customer_profile__name",
            "customer_profile__age",
            "customer_driver_combined__trips",
        ]
        expected_df = pd.DataFrame(
            {
                column: reversed(values)
                for column, values in columns_in_write_order.items()
            }
        )
        # NOTE(review): indexing result_df with ordered_column puts its columns
        # into that order before comparing, so this checks row order and values
        # but not the raw column order of the response.
        assert_frame_equal(result_df[ordered_column], expected_df)
0 commit comments