From f652bd4fd31f02d07d2506b837bf74d7a0fe1cbb Mon Sep 17 00:00:00 2001 From: tokoko Date: Thu, 21 Mar 2024 04:15:07 +0000 Subject: [PATCH 1/2] feat: refactor ibis point-in-time-join Signed-off-by: tokoko --- .../contrib/ibis_offline_store/ibis.py | 231 +++++++++--------- .../requirements/py3.10-ci-requirements.txt | 29 ++- .../requirements/py3.10-requirements.txt | 4 +- .../requirements/py3.9-ci-requirements.txt | 29 ++- .../requirements/py3.9-requirements.txt | 4 +- .../unit/infra/offline_stores/test_ibis.py | 88 +++++++ setup.py | 1 + 7 files changed, 248 insertions(+), 138 deletions(-) create mode 100644 sdk/python/tests/unit/infra/offline_stores/test_ibis.py diff --git a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py index 72e0d970c60..cb35cc083e0 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py @@ -72,112 +72,6 @@ def _get_entity_df_event_timestamp_range( return entity_df_event_timestamp_range - @staticmethod - def _get_historical_features_one( - feature_view: FeatureView, - entity_table: Table, - feature_refs: List[str], - full_feature_names: bool, - timestamp_range: Tuple, - acc_table: Table, - event_timestamp_col: str, - ) -> Table: - fv_table: Table = ibis.read_parquet(feature_view.batch_source.name) - - for old_name, new_name in feature_view.batch_source.field_mapping.items(): - if old_name in fv_table.columns: - fv_table = fv_table.rename({new_name: old_name}) - - timestamp_field = feature_view.batch_source.timestamp_field - - # TODO mutate only if tz-naive - fv_table = fv_table.mutate( - **{ - timestamp_field: fv_table[timestamp_field].cast( - dt.Timestamp(timezone="UTC") - ) - } - ) - - full_name_prefix = feature_view.projection.name_alias or feature_view.name - - feature_refs = [ - fr.split(":")[1] - for fr in feature_refs - if fr.startswith(f"{full_name_prefix}:") - ] - - timestamp_range_start_minus_ttl = ( - timestamp_range[0] - feature_view.ttl - if feature_view.ttl and feature_view.ttl > timedelta(0, 0, 0, 0, 0, 0, 0) - else timestamp_range[0] - ) - - timestamp_range_start_minus_ttl = ibis.literal( - timestamp_range_start_minus_ttl.strftime("%Y-%m-%d %H:%M:%S.%f") - ).cast(dt.Timestamp(timezone="UTC")) - - timestamp_range_end = ibis.literal( - timestamp_range[1].strftime("%Y-%m-%d %H:%M:%S.%f") - ).cast(dt.Timestamp(timezone="UTC")) - - fv_table = fv_table.filter( - ibis.and_( - fv_table[timestamp_field] <= timestamp_range_end, - fv_table[timestamp_field] >= timestamp_range_start_minus_ttl, - ) - ) - - # join_key_map = feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns} - # predicates = [fv_table[k] == entity_table[v] for k, v in join_key_map.items()] - - if feature_view.projection.join_key_map: - predicates = [ - fv_table[k] == entity_table[v] - for k, v in feature_view.projection.join_key_map.items() - ] - else: - predicates = [ - fv_table[e.name] == entity_table[e.name] - for e in feature_view.entity_columns - ] - - predicates.append( - fv_table[timestamp_field] <= entity_table[event_timestamp_col] - ) - - fv_table = fv_table.inner_join( - entity_table, predicates, lname="", rname="{name}_y" - ) - - fv_table = ( - fv_table.group_by(by="entity_row_id") - .order_by(ibis.desc(fv_table[timestamp_field])) - .mutate(rn=ibis.row_number()) - ) - - fv_table = fv_table.filter(fv_table["rn"] == ibis.literal(0)) - - 
select_cols = ["entity_row_id"] - select_cols.extend(feature_refs) - fv_table = fv_table.select(select_cols) - - if full_feature_names: - fv_table = fv_table.rename( - {f"{full_name_prefix}__{feature}": feature for feature in feature_refs} - ) - - acc_table = acc_table.left_join( - fv_table, - predicates=[fv_table.entity_row_id == acc_table.entity_row_id], - lname="", - rname="{name}_yyyy", - ) - - acc_table = acc_table.drop(s.endswith("_yyyy")) - - return acc_table - @staticmethod def _to_utc(entity_df: pd.DataFrame, event_timestamp_col): entity_df_event_timestamp = entity_df.loc[ @@ -228,9 +122,11 @@ def get_historical_features( entity_schema=entity_schema, ) + # TODO get range with ibis timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range( entity_df, event_timestamp_col ) + entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col) entity_table = ibis.memtable(entity_df) @@ -238,20 +134,55 @@ def get_historical_features( entity_table, feature_views, event_timestamp_col ) - res: Table = entity_table + def read_fv(feature_view, feature_refs, full_feature_names): + fv_table: Table = ibis.read_parquet(feature_view.batch_source.name) - for fv in feature_views: - res = IbisOfflineStore._get_historical_features_one( - fv, - entity_table, + for old_name, new_name in feature_view.batch_source.field_mapping.items(): + if old_name in fv_table.columns: + fv_table = fv_table.rename({new_name: old_name}) + + timestamp_field = feature_view.batch_source.timestamp_field + + # TODO mutate only if tz-naive + fv_table = fv_table.mutate( + **{ + timestamp_field: fv_table[timestamp_field].cast( + dt.Timestamp(timezone="UTC") + ) + } + ) + + full_name_prefix = feature_view.projection.name_alias or feature_view.name + + feature_refs = [ + fr.split(":")[1] + for fr in feature_refs + if fr.startswith(f"{full_name_prefix}:") + ] + + if full_feature_names: + fv_table = fv_table.rename( + {f"{full_name_prefix}__{feature}": feature for feature in feature_refs} + ) + + feature_refs = [f"{full_name_prefix}__{feature}" for feature in feature_refs] + + return ( + fv_table, + feature_view.batch_source.timestamp_field, + feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns}, feature_refs, - full_feature_names, - timestamp_range, - res, - event_timestamp_col, + feature_view.ttl ) - res = res.drop("entity_row_id") + res = point_in_time_join( + entity_table=entity_table, + feature_tables=[ + read_fv(feature_view, feature_refs, full_feature_names) + for feature_view in feature_views + ], + event_timestamp_col=event_timestamp_col + ) return IbisRetrievalJob( res, @@ -285,6 +216,10 @@ def pull_all_from_table_or_query( table = table.select(*fields) + # TODO get rid of this fix + if '__log_date' in table.columns: + table = table.drop('__log_date') + table = table.filter( ibis.and_( table[timestamp_field] >= ibis.literal(start_date), @@ -320,6 +255,7 @@ def write_logged_features( else: kwargs = {} + #TODO always write to directory table.to_parquet( f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs ) @@ -405,3 +341,66 @@ def persist( @property def metadata(self) -> Optional[RetrievalMetadata]: return self._metadata + + +def point_in_time_join( + entity_table: Table, + feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]], + event_timestamp_col = 'event_timestamp' +): + #TODO handle ttl + all_entities = [event_timestamp_col] + for feature_table, timestamp_field, join_key_map, _, _ in feature_tables: + 
all_entities.extend(join_key_map.values()) + + r = ibis.literal("") + + for e in set(all_entities): + r = r.concat(entity_table[e].cast("string")) # type: ignore + + entity_table = entity_table.mutate(entity_row_id=r) + + acc_table = entity_table + + for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables: + predicates = [feature_table[k] == entity_table[v] for k, v in join_key_map.items()] + + predicates.append( + feature_table[timestamp_field] <= entity_table[event_timestamp_col], + ) + + if ttl: + predicates.append( + feature_table[timestamp_field] >= entity_table[event_timestamp_col] - ibis.literal(ttl) + ) + + feature_table = feature_table.inner_join( + entity_table, predicates, lname="", rname="{name}_y" + ) + + feature_table = feature_table.drop(s.endswith("_y")) + + feature_table = ( + feature_table.group_by(by="entity_row_id") + .order_by(ibis.desc(feature_table[timestamp_field])) + .mutate(rn=ibis.row_number()) + ) + + feature_table = feature_table.filter(feature_table["rn"] == ibis.literal(0)).drop("rn") + + select_cols = ["entity_row_id"] + select_cols.extend(feature_refs) + feature_table = feature_table.select(select_cols) + + acc_table = acc_table.left_join( + feature_table, + predicates=[feature_table.entity_row_id == acc_table.entity_row_id], + lname="", + rname="{name}_yyyy", + ) + + acc_table = acc_table.drop(s.endswith("_yyyy")) + + acc_table = acc_table.drop('entity_row_id') + + return acc_table \ No newline at end of file diff --git a/sdk/python/requirements/py3.10-ci-requirements.txt b/sdk/python/requirements/py3.10-ci-requirements.txt index 8f0ef90d77e..737271eee16 100644 --- a/sdk/python/requirements/py3.10-ci-requirements.txt +++ b/sdk/python/requirements/py3.10-ci-requirements.txt @@ -61,11 +61,11 @@ black==22.12.0 # via feast (setup.py) bleach==6.1.0 # via nbconvert -boto3==1.34.65 +boto3==1.34.67 # via # feast (setup.py) # moto -botocore==1.34.65 +botocore==1.34.67 # via # boto3 # moto @@ -82,7 +82,7 @@ cachecontrol==0.14.0 # via firebase-admin cachetools==5.3.3 # via google-auth -cassandra-driver==3.29.0 +cassandra-driver==3.29.1 # via feast (setup.py) certifi==2024.2.2 # via @@ -164,6 +164,12 @@ docker==7.0.0 # testcontainers docutils==0.19 # via sphinx +duckdb==0.10.1 + # via + # duckdb-engine + # ibis-framework +duckdb-engine==0.11.2 + # via ibis-framework entrypoints==0.4 # via altair exceptiongroup==1.2.0 @@ -213,7 +219,7 @@ google-api-core[grpc]==2.17.1 # google-cloud-storage google-api-python-client==2.122.0 # via firebase-admin -google-auth==2.28.2 +google-auth==2.29.0 # via # google-api-core # google-api-python-client @@ -258,7 +264,7 @@ googleapis-common-protos[grpc]==1.63.0 # google-api-core # grpc-google-iam-v1 # grpcio-status -great-expectations==0.18.11 +great-expectations==0.18.12 # via feast (setup.py) greenlet==3.0.3 # via sqlalchemy @@ -310,7 +316,7 @@ httpx==0.27.0 # via # feast (setup.py) # jupyterlab -ibis-framework==8.0.0 +ibis-framework[duckdb]==8.0.0 # via # feast (setup.py) # ibis-substrait @@ -331,7 +337,7 @@ importlib-metadata==6.11.0 # via # dask # feast (setup.py) -importlib-resources==6.3.1 +importlib-resources==6.3.2 # via feast (setup.py) iniconfig==2.0.0 # via pytest @@ -459,7 +465,7 @@ moreorless==0.4.0 # via bowler moto==4.2.14 # via feast (setup.py) -msal==1.27.0 +msal==1.28.0 # via # azure-identity # msal-extensions @@ -844,8 +850,13 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx sqlalchemy[mypy]==1.4.52 # via + # duckdb-engine # feast (setup.py) + # ibis-framework # sqlalchemy + # 
sqlalchemy-views +sqlalchemy-views==0.3.2 + # via ibis-framework sqlalchemy2-stubs==0.0.2a38 # via sqlalchemy sqlglot==20.11.0 @@ -984,7 +995,7 @@ urllib3==1.26.18 # requests # responses # rockset -uvicorn[standard]==0.28.0 +uvicorn[standard]==0.29.0 # via feast (setup.py) uvloop==0.19.0 # via uvicorn diff --git a/sdk/python/requirements/py3.10-requirements.txt b/sdk/python/requirements/py3.10-requirements.txt index e17a588538e..240f43b57e4 100644 --- a/sdk/python/requirements/py3.10-requirements.txt +++ b/sdk/python/requirements/py3.10-requirements.txt @@ -62,7 +62,7 @@ importlib-metadata==6.11.0 # via # dask # feast (setup.py) -importlib-resources==6.3.1 +importlib-resources==6.3.2 # via feast (setup.py) jinja2==3.1.3 # via feast (setup.py) @@ -176,7 +176,7 @@ tzdata==2024.1 # via pandas urllib3==2.2.1 # via requests -uvicorn[standard]==0.28.0 +uvicorn[standard]==0.29.0 # via feast (setup.py) uvloop==0.19.0 # via uvicorn diff --git a/sdk/python/requirements/py3.9-ci-requirements.txt b/sdk/python/requirements/py3.9-ci-requirements.txt index dc96554431b..f2585a7978d 100644 --- a/sdk/python/requirements/py3.9-ci-requirements.txt +++ b/sdk/python/requirements/py3.9-ci-requirements.txt @@ -61,11 +61,11 @@ black==22.12.0 # via feast (setup.py) bleach==6.1.0 # via nbconvert -boto3==1.34.65 +boto3==1.34.67 # via # feast (setup.py) # moto -botocore==1.34.65 +botocore==1.34.67 # via # boto3 # moto @@ -82,7 +82,7 @@ cachecontrol==0.14.0 # via firebase-admin cachetools==5.3.3 # via google-auth -cassandra-driver==3.29.0 +cassandra-driver==3.29.1 # via feast (setup.py) certifi==2024.2.2 # via @@ -164,6 +164,12 @@ docker==7.0.0 # testcontainers docutils==0.19 # via sphinx +duckdb==0.10.1 + # via + # duckdb-engine + # ibis-framework +duckdb-engine==0.11.2 + # via ibis-framework entrypoints==0.4 # via altair exceptiongroup==1.2.0 @@ -213,7 +219,7 @@ google-api-core[grpc]==2.17.1 # google-cloud-storage google-api-python-client==2.122.0 # via firebase-admin -google-auth==2.28.2 +google-auth==2.29.0 # via # google-api-core # google-api-python-client @@ -258,7 +264,7 @@ googleapis-common-protos[grpc]==1.63.0 # google-api-core # grpc-google-iam-v1 # grpcio-status -great-expectations==0.18.11 +great-expectations==0.18.12 # via feast (setup.py) greenlet==3.0.3 # via sqlalchemy @@ -310,7 +316,7 @@ httpx==0.27.0 # via # feast (setup.py) # jupyterlab -ibis-framework==8.0.0 +ibis-framework[duckdb]==8.0.0 # via # feast (setup.py) # ibis-substrait @@ -339,7 +345,7 @@ importlib-metadata==6.11.0 # nbconvert # sphinx # typeguard -importlib-resources==6.3.1 +importlib-resources==6.3.2 # via feast (setup.py) iniconfig==2.0.0 # via pytest @@ -467,7 +473,7 @@ moreorless==0.4.0 # via bowler moto==4.2.14 # via feast (setup.py) -msal==1.27.0 +msal==1.28.0 # via # azure-identity # msal-extensions @@ -854,8 +860,13 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx sqlalchemy[mypy]==1.4.52 # via + # duckdb-engine # feast (setup.py) + # ibis-framework # sqlalchemy + # sqlalchemy-views +sqlalchemy-views==0.3.2 + # via ibis-framework sqlalchemy2-stubs==0.0.2a38 # via sqlalchemy sqlglot==20.11.0 @@ -998,7 +1009,7 @@ urllib3==1.26.18 # responses # rockset # snowflake-connector-python -uvicorn[standard]==0.28.0 +uvicorn[standard]==0.29.0 # via feast (setup.py) uvloop==0.19.0 # via uvicorn diff --git a/sdk/python/requirements/py3.9-requirements.txt b/sdk/python/requirements/py3.9-requirements.txt index f2228ade027..43b0191ed40 100644 --- a/sdk/python/requirements/py3.9-requirements.txt +++ 
b/sdk/python/requirements/py3.9-requirements.txt @@ -63,7 +63,7 @@ importlib-metadata==6.11.0 # dask # feast (setup.py) # typeguard -importlib-resources==6.3.1 +importlib-resources==6.3.2 # via feast (setup.py) jinja2==3.1.3 # via feast (setup.py) @@ -178,7 +178,7 @@ tzdata==2024.1 # via pandas urllib3==2.2.1 # via requests -uvicorn[standard]==0.28.0 +uvicorn[standard]==0.29.0 # via feast (setup.py) uvloop==0.19.0 # via uvicorn diff --git a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py new file mode 100644 index 00000000000..a73d4451a5a --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py @@ -0,0 +1,88 @@ +from datetime import datetime, timedelta +import ibis +import pyarrow as pa +from typing import List, Tuple, Dict +from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import point_in_time_join +from pprint import pprint + +def pa_datetime(year, month, day): + return pa.scalar(datetime(year, month, day), type=pa.timestamp('s', tz='UTC')) + +def customer_table(): + return pa.Table.from_arrays( + arrays=[ + pa.array([1, 1, 2]), + pa.array([pa_datetime(2024, 1, 1),pa_datetime(2024, 1, 2),pa_datetime(2024, 1, 1)]) + ], + names=['customer_id', 'event_timestamp'] + ) + +def features_table_1(): + return pa.Table.from_arrays( + arrays=[ + pa.array([1, 1, 1, 2]), + pa.array([pa_datetime(2023, 12, 31), pa_datetime(2024, 1, 2), pa_datetime(2024, 1, 3), pa_datetime(2023, 1, 3)]), + pa.array([11, 22, 33, 22]) + ], + names=['customer_id', 'event_timestamp', 'feature1'] + ) + +def point_in_time_join_brute( + entity_table: pa.Table, + feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]], + event_timestamp_col = 'event_timestamp' +): + ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names] + + from operator import itemgetter + ret = entity_table.to_pydict() + batch_dict = entity_table.to_pydict() + + for i, row_timestmap in enumerate(batch_dict[event_timestamp_col]): + for feature_table, timestamp_key, join_key_map, feature_refs, ttl in feature_tables: + if i == 0: + ret_fields.extend([feature_table.schema.field(f) for f in feature_table.schema.names if f not in join_key_map.values() and f != timestamp_key]) + + def check_equality(ft_dict, batch_dict, x, y): + return all([ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()]) + + ft_dict = feature_table.to_pydict() + found_matches = [ + (j, ft_dict[timestamp_key][j]) for j in range(entity_table.num_rows) + if check_equality(ft_dict, batch_dict, j, i) and + ft_dict[timestamp_key][j] <= row_timestmap and + ft_dict[timestamp_key][j] >= row_timestmap - ttl + ] + + index_found = max(found_matches, key=itemgetter(1))[0] if found_matches else None + for col in ft_dict.keys(): + if col not in feature_refs: + continue + + if col not in ret: + ret[col] = [] + + if index_found is not None: + ret[col].append(ft_dict[col][index_found]) + else: + ret[col].append(None) + + return pa.Table.from_pydict(ret, schema=pa.schema(ret_fields)) + + +def test_point_in_time_join(): + expected = point_in_time_join_brute( + customer_table(), + feature_tables=[ + (features_table_1(), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10)) + ] + ) + + actual = point_in_time_join( + ibis.memtable(customer_table()), + feature_tables=[ + (ibis.memtable(features_table_1()), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10)) + ] + ).to_pyarrow() + + assert 
actual.equals(expected) diff --git a/setup.py b/setup.py index b32d03ed77c..ca89b09bf66 100644 --- a/setup.py +++ b/setup.py @@ -211,6 +211,7 @@ + HAZELCAST_REQUIRED + IBIS_REQUIRED + GRPCIO_REQUIRED + + DUCKDB_REQUIRED ) DOCS_REQUIRED = CI_REQUIRED From fcfe3053b4cb45f1e2d61f4122ac36e84b2a3832 Mon Sep 17 00:00:00 2001 From: tokoko Date: Thu, 21 Mar 2024 04:39:45 +0000 Subject: [PATCH 2/2] fix formatting, linting Signed-off-by: tokoko --- .../contrib/ibis_offline_store/ibis.py | 51 ++++++---- .../unit/infra/offline_stores/test_ibis.py | 98 ++++++++++++++----- 2 files changed, 108 insertions(+), 41 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py index cb35cc083e0..8787d701581 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py @@ -162,26 +162,32 @@ def read_fv(feature_view, feature_refs, full_feature_names): if full_feature_names: fv_table = fv_table.rename( - {f"{full_name_prefix}__{feature}": feature for feature in feature_refs} + { + f"{full_name_prefix}__{feature}": feature + for feature in feature_refs + } ) - feature_refs = [f"{full_name_prefix}__{feature}" for feature in feature_refs] + feature_refs = [ + f"{full_name_prefix}__{feature}" for feature in feature_refs + ] return ( fv_table, feature_view.batch_source.timestamp_field, - feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns}, + feature_view.projection.join_key_map + or {e.name: e.name for e in feature_view.entity_columns}, feature_refs, - feature_view.ttl + feature_view.ttl, ) res = point_in_time_join( entity_table=entity_table, - feature_tables=[ + feature_tables=[ read_fv(feature_view, feature_refs, full_feature_names) for feature_view in feature_views ], - event_timestamp_col=event_timestamp_col + event_timestamp_col=event_timestamp_col, ) return IbisRetrievalJob( @@ -217,8 +223,8 @@ def pull_all_from_table_or_query( table = table.select(*fields) # TODO get rid of this fix - if '__log_date' in table.columns: - table = table.drop('__log_date') + if "__log_date" in table.columns: + table = table.drop("__log_date") table = table.filter( ibis.and_( @@ -255,7 +261,7 @@ def write_logged_features( else: kwargs = {} - #TODO always write to directory + # TODO always write to directory table.to_parquet( f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs ) @@ -346,9 +352,9 @@ def metadata(self) -> Optional[RetrievalMetadata]: def point_in_time_join( entity_table: Table, feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]], - event_timestamp_col = 'event_timestamp' + event_timestamp_col="event_timestamp", ): - #TODO handle ttl + # TODO handle ttl all_entities = [event_timestamp_col] for feature_table, timestamp_field, join_key_map, _, _ in feature_tables: all_entities.extend(join_key_map.values()) @@ -362,8 +368,16 @@ def point_in_time_join( acc_table = entity_table - for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables: - predicates = [feature_table[k] == entity_table[v] for k, v in join_key_map.items()] + for ( + feature_table, + timestamp_field, + join_key_map, + feature_refs, + ttl, + ) in feature_tables: + predicates = [ + feature_table[k] == entity_table[v] for k, v in join_key_map.items() + ] predicates.append( feature_table[timestamp_field] <= entity_table[event_timestamp_col], @@ 
-371,7 +385,8 @@ def point_in_time_join( if ttl: predicates.append( - feature_table[timestamp_field] >= entity_table[event_timestamp_col] - ibis.literal(ttl) + feature_table[timestamp_field] + >= entity_table[event_timestamp_col] - ibis.literal(ttl) ) feature_table = feature_table.inner_join( @@ -386,7 +401,9 @@ def point_in_time_join( .mutate(rn=ibis.row_number()) ) - feature_table = feature_table.filter(feature_table["rn"] == ibis.literal(0)).drop("rn") + feature_table = feature_table.filter( + feature_table["rn"] == ibis.literal(0) + ).drop("rn") select_cols = ["entity_row_id"] select_cols.extend(feature_refs) @@ -401,6 +418,6 @@ def point_in_time_join( acc_table = acc_table.drop(s.endswith("_yyyy")) - acc_table = acc_table.drop('entity_row_id') + acc_table = acc_table.drop("entity_row_id") - return acc_table \ No newline at end of file + return acc_table diff --git a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py index a73d4451a5a..5f105e2af70 100644 --- a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py +++ b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py @@ -1,67 +1,105 @@ from datetime import datetime, timedelta +from typing import Dict, List, Tuple + import ibis import pyarrow as pa -from typing import List, Tuple, Dict -from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import point_in_time_join -from pprint import pprint + +from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import ( + point_in_time_join, +) + def pa_datetime(year, month, day): - return pa.scalar(datetime(year, month, day), type=pa.timestamp('s', tz='UTC')) + return pa.scalar(datetime(year, month, day), type=pa.timestamp("s", tz="UTC")) + def customer_table(): return pa.Table.from_arrays( arrays=[ pa.array([1, 1, 2]), - pa.array([pa_datetime(2024, 1, 1),pa_datetime(2024, 1, 2),pa_datetime(2024, 1, 1)]) + pa.array( + [ + pa_datetime(2024, 1, 1), + pa_datetime(2024, 1, 2), + pa_datetime(2024, 1, 1), + ] + ), ], - names=['customer_id', 'event_timestamp'] + names=["customer_id", "event_timestamp"], ) + def features_table_1(): return pa.Table.from_arrays( arrays=[ pa.array([1, 1, 1, 2]), - pa.array([pa_datetime(2023, 12, 31), pa_datetime(2024, 1, 2), pa_datetime(2024, 1, 3), pa_datetime(2023, 1, 3)]), - pa.array([11, 22, 33, 22]) - ], - names=['customer_id', 'event_timestamp', 'feature1'] + pa.array( + [ + pa_datetime(2023, 12, 31), + pa_datetime(2024, 1, 2), + pa_datetime(2024, 1, 3), + pa_datetime(2023, 1, 3), + ] + ), + pa.array([11, 22, 33, 22]), + ], + names=["customer_id", "event_timestamp", "feature1"], ) + def point_in_time_join_brute( entity_table: pa.Table, feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]], - event_timestamp_col = 'event_timestamp' + event_timestamp_col="event_timestamp", ): ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names] from operator import itemgetter + ret = entity_table.to_pydict() batch_dict = entity_table.to_pydict() for i, row_timestmap in enumerate(batch_dict[event_timestamp_col]): - for feature_table, timestamp_key, join_key_map, feature_refs, ttl in feature_tables: + for ( + feature_table, + timestamp_key, + join_key_map, + feature_refs, + ttl, + ) in feature_tables: if i == 0: - ret_fields.extend([feature_table.schema.field(f) for f in feature_table.schema.names if f not in join_key_map.values() and f != timestamp_key]) + ret_fields.extend( + [ + feature_table.schema.field(f) + for f in feature_table.schema.names 
+ if f not in join_key_map.values() and f != timestamp_key + ] + ) def check_equality(ft_dict, batch_dict, x, y): - return all([ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()]) + return all( + [ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()] + ) ft_dict = feature_table.to_pydict() found_matches = [ - (j, ft_dict[timestamp_key][j]) for j in range(entity_table.num_rows) - if check_equality(ft_dict, batch_dict, j, i) and - ft_dict[timestamp_key][j] <= row_timestmap and - ft_dict[timestamp_key][j] >= row_timestmap - ttl + (j, ft_dict[timestamp_key][j]) + for j in range(entity_table.num_rows) + if check_equality(ft_dict, batch_dict, j, i) + and ft_dict[timestamp_key][j] <= row_timestmap + and ft_dict[timestamp_key][j] >= row_timestmap - ttl ] - index_found = max(found_matches, key=itemgetter(1))[0] if found_matches else None + index_found = ( + max(found_matches, key=itemgetter(1))[0] if found_matches else None + ) for col in ft_dict.keys(): if col not in feature_refs: continue if col not in ret: ret[col] = [] - + if index_found is not None: ret[col].append(ft_dict[col][index_found]) else: @@ -74,15 +112,27 @@ def test_point_in_time_join(): expected = point_in_time_join_brute( customer_table(), feature_tables=[ - (features_table_1(), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10)) - ] + ( + features_table_1(), + "event_timestamp", + {"customer_id": "customer_id"}, + ["feature1"], + timedelta(days=10), + ) + ], ) actual = point_in_time_join( ibis.memtable(customer_table()), feature_tables=[ - (ibis.memtable(features_table_1()), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10)) - ] + ( + ibis.memtable(features_table_1()), + "event_timestamp", + {"customer_id": "customer_id"}, + ["feature1"], + timedelta(days=10), + ) + ], ).to_pyarrow() assert actual.equals(expected)
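
Usage sketch (not part of the patch): the refactored helper takes the entity table plus one (table, timestamp_field, join_key_map, feature_refs, ttl) tuple per feature view, mirroring what read_fv builds and what the unit test above exercises. The table contents below are hypothetical, and the snippet assumes the duckdb-backed ibis default backend that this change adds as a dependency.

    from datetime import datetime, timedelta

    import ibis

    from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import (
        point_in_time_join,
    )

    # Hypothetical entity rows: one (customer_id, event_timestamp) pair per row.
    entities = ibis.memtable(
        {
            "customer_id": [1, 2],
            "event_timestamp": [datetime(2024, 1, 2), datetime(2024, 1, 1)],
        }
    )

    # Hypothetical feature rows, standing in for a feature view's parquet batch source.
    features = ibis.memtable(
        {
            "customer_id": [1, 1, 2],
            "event_timestamp": [
                datetime(2023, 12, 31),
                datetime(2024, 1, 2),
                datetime(2023, 12, 30),
            ],
            "feature1": [11, 22, 33],
        }
    )

    joined = point_in_time_join(
        entity_table=entities,
        feature_tables=[
            # (table, timestamp_field, join_key_map, feature_refs, ttl)
            (
                features,
                "event_timestamp",
                {"customer_id": "customer_id"},
                ["feature1"],
                timedelta(days=10),
            )
        ],
        event_timestamp_col="event_timestamp",
    )

    # Each entity row keeps the latest feature1 value at or before its own
    # event_timestamp and within the 10-day ttl: 22 for customer 1, 33 for customer 2.
    print(joined.to_pyarrow())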