diff --git a/python/feast_spark/pyspark/historical_feature_retrieval_job.py b/python/feast_spark/pyspark/historical_feature_retrieval_job.py index 3dc642a5..82cb9579 100644 --- a/python/feast_spark/pyspark/historical_feature_retrieval_job.py +++ b/python/feast_spark/pyspark/historical_feature_retrieval_job.py @@ -579,16 +579,18 @@ def filter_feature_table_by_time_range( time_range_filtered_df = feature_table_df.filter(feature_table_timestamp_filter) + entities_projected = ( + entity_df.withColumnRenamed( + entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS + ) + .select(feature_table.entity_names + [ENTITY_EVENT_TIMESTAMP_ALIAS]) + .distinct() + ) + time_range_filtered_df = ( time_range_filtered_df.repartition(200) .join( - broadcast( - entity_df.withColumnRenamed( - entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS - ) - ), - on=feature_table.entity_names, - how="inner", + broadcast(entities_projected), on=feature_table.entity_names, how="inner", ) .withColumn( "distance", @@ -605,7 +607,6 @@ def filter_feature_table_by_time_range( ), ) .where(col("distance") == col("min_distance")) - .select(time_range_filtered_df.columns + [ENTITY_EVENT_TIMESTAMP_ALIAS]) ) if SparkContext._active_spark_context._jsc.sc().getCheckpointDir().nonEmpty(): time_range_filtered_df = time_range_filtered_df.checkpoint()