Skip to content

Commit 99ee2ce

Browse files
authored
Add validations when materializing from file sources (#1615)
* Validate join keys when materializing from file sources Signed-off-by: Achal Shah <achals@gmail.com> * Dedupe columns when extracting from the dataframe Signed-off-by: Achal Shah <achals@gmail.com>
1 parent 09fe2a6 commit 99ee2ce

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,13 @@ def pull_latest_from_table_or_query(
212212
created_timestamp_column
213213
].apply(lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc))
214214

215+
source_columns = set(source_df.columns)
216+
if not set(join_key_columns).issubset(source_columns):
217+
raise ValueError(
218+
f"The DataFrame must have at least {set(join_key_columns)} columns present, "
219+
f"but these were missing: {set(join_key_columns)- source_columns} "
220+
)
221+
215222
ts_columns = (
216223
[event_timestamp_column, created_timestamp_column]
217224
if created_timestamp_column
@@ -229,8 +236,7 @@ def pull_latest_from_table_or_query(
229236
# make driver_id a normal column again
230237
last_values_df.reset_index(inplace=True)
231238

232-
table = pyarrow.Table.from_pandas(
233-
last_values_df[join_key_columns + feature_name_columns + ts_columns]
234-
)
239+
columns_to_extract = set(join_key_columns + feature_name_columns + ts_columns)
240+
table = pyarrow.Table.from_pandas(last_values_df[columns_to_extract])
235241

236242
return table

sdk/python/feast/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Enti
132132
for entity_proto in registry_proto.entities:
133133
if entity_proto.spec.name == name and entity_proto.spec.project == project:
134134
return Entity.from_proto(entity_proto)
135-
raise EntityNotFoundException(project, name)
135+
raise EntityNotFoundException(name, project=project)
136136

137137
def apply_feature_table(self, feature_table: FeatureTable, project: str):
138138
"""

0 commit comments

Comments
 (0)