diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index df882bfc2c3..a3059105950 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -29,7 +29,7 @@ def build_source_node(self): source = self.feature_view.batch_source start_time = self.task.start_time end_time = self.task.end_time - node = SparkReadNode("source", source, start_time, end_time) + node = SparkReadNode("source", source, self.spark_session, start_time, end_time) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index 8d00f124439..1ab454daa52 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -56,11 +56,13 @@ def __init__( self, name: str, source: DataSource, + spark_session: SparkSession, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ): super().__init__(name) self.source = source + self.spark_session = spark_session self.start_time = start_time self.end_time = end_time @@ -72,7 +74,10 @@ def execute(self, context: ExecutionContext) -> DAGValue: start_time=self.start_time, end_time=self.end_time, ) - spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + if isinstance(retrieval_job, SparkRetrievalJob): + spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + else: + spark_df = self.spark_session.createDataFrame(retrieval_job.to_arrow()) return DAGValue( data=spark_df,