From 7f6d527bf2d49ac08cfc4c25cc9b0ec5ffc58018 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 00:34:50 -0700 Subject: [PATCH 01/11] enable write node Signed-off-by: HaoXuAI --- sdk/python/feast/infra/common/serializer.py | 41 +++++++++++++++++ .../compute_engines/local/feature_builder.py | 2 +- .../infra/compute_engines/local/nodes.py | 34 +++++++++++++- .../compute_engines/spark/feature_builder.py | 2 +- .../spark/{node.py => nodes.py} | 21 +++------ .../infra/compute_engines/spark/utils.py | 46 ++++++++++++++++++- .../spark/spark_materialization_engine.py | 46 +++---------------- .../infra/compute_engines/spark/test_nodes.py | 2 +- 8 files changed, 134 insertions(+), 60 deletions(-) create mode 100644 sdk/python/feast/infra/common/serializer.py rename sdk/python/feast/infra/compute_engines/spark/{node.py => nodes.py} (94%) diff --git a/sdk/python/feast/infra/common/serializer.py b/sdk/python/feast/infra/common/serializer.py new file mode 100644 index 00000000000..44e39c4d250 --- /dev/null +++ b/sdk/python/feast/infra/common/serializer.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass + +import dill +from infra.passthrough_provider import PassthroughProvider + +from feast import FeatureView +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto + + +@dataclass +class SerializedArtifacts: + """Class to assist with serializing unpicklable artifacts to be passed to the compute engine.""" + + feature_view_proto: str + repo_config_byte: str + + @classmethod + def serialize(cls, feature_view, repo_config): + # serialize to proto + feature_view_proto = feature_view.to_proto().SerializeToString() + + # serialize repo_config to disk. Will be used to instantiate the online store + repo_config_byte = dill.dumps(repo_config) + + return SerializedArtifacts( + feature_view_proto=feature_view_proto, repo_config_byte=repo_config_byte + ) + + def unserialize(self): + # unserialize + proto = FeatureViewProto() + proto.ParseFromString(self.feature_view_proto) + feature_view = FeatureView.from_proto(proto) + + # load + repo_config = dill.loads(self.repo_config_byte) + + provider = PassthroughProvider(repo_config) + online_store = provider.online_store + offline_store = provider.offline_store + return feature_view, online_store, offline_store, repo_config diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index aee245da21c..4f9dcc871d5 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -91,6 +91,6 @@ def build_validation_node(self, input_node): return node def build_output_nodes(self, input_node): - node = LocalOutputNode("output") + node = LocalOutputNode("output", self.feature_view) node.add_input(input_node) self.nodes.append(node) diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index aea83921351..c4e51075a07 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -2,6 +2,7 @@ from typing import Optional import pyarrow as pa +from utils import _convert_arrow_to_proto from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ExecutionContext @@ -207,11 +208,40 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalOutputNode(LocalNode): - def __init__(self, name: 
str): + def __init__(self, name: str, feature_view): super().__init__(name) + self.feature_view = feature_view def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data context.node_outputs[self.name] = input_table - # TODO: implement the logic to write to offline store + + if self.feature_view.online: + online_store = context.online_store + + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in self.feature_view.entity_columns + } + + rows_to_write = _convert_arrow_to_proto( + input_table, self.feature_view, join_key_to_value_type + ) + + online_store.online_write_batch( + config=context.repo_config, + table=self.feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + + if self.feature_view.offline: + offline_store = context.offline_store + offline_store.offline_write_batch( + config=context.repo_config, + feature_view=self.feature_view, + table=input_table, + progress=lambda x: None, + ) + return input_table diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index 944feccf903..f40f35a8360 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -5,7 +5,7 @@ from feast.infra.common.materialization_job import MaterializationTask from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.feature_builder import FeatureBuilder -from feast.infra.compute_engines.spark.node import ( +from feast.infra.compute_engines.spark.nodes import ( SparkAggregationNode, SparkDedupNode, SparkFilterNode, diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py similarity index 94% rename from sdk/python/feast/infra/compute_engines/spark/node.py rename to sdk/python/feast/infra/compute_engines/spark/nodes.py index 0c1c1476613..ad7e2b576fc 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -7,14 +7,12 @@ from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.data_source import DataSource +from feast.infra.common.serializer import SerializedArtifacts from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( - _map_by_partition, - _SparkSerializedArtifacts, -) +from feast.infra.compute_engines.spark.utils import map_in_arrow from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, _get_entity_schema, @@ -282,19 +280,14 @@ def __init__( def execute(self, context: ExecutionContext) -> DAGValue: spark_df: DataFrame = self.get_single_input_value(context).data - spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( + serialized_artifacts = SerializedArtifacts.serialize( feature_view=self.feature_view, repo_config=context.repo_config ) - # ✅ 1. Write to offline store (if enabled) - if self.feature_view.offline: - # TODO: Update _map_by_partition to be able to write to offline store - pass - - # ✅ 2. 
Write to online store (if enabled) - if self.feature_view.online: - spark_df.mapInPandas( - lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + # ✅ 1. Write to online or offline store (if enabled) + if self.feature_view.online or self.feature_view.offline: + spark_df.mapInArrow( + lambda x: map_in_arrow(x, serialized_artifacts), "status int" ).count() return DAGValue( diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 262876f9dbb..6b1840d54ea 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -1,7 +1,11 @@ -from typing import Dict, Optional +from typing import Dict, Iterable, Optional +import pyarrow as pa from pyspark import SparkConf from pyspark.sql import SparkSession +from utils import _convert_arrow_to_proto + +from feast.infra.common.serializer import SerializedArtifacts def get_or_create_new_spark_session( @@ -17,3 +21,43 @@ def get_or_create_new_spark_session( spark_session = spark_builder.getOrCreate() return spark_session + + +def map_in_arrow( + iterator: Iterable[pa.RecordBatch], + serialized_artifacts: "SerializedArtifacts", +): + table = pa.Table.from_batches(iterator) + + ( + feature_view, + online_store, + offline_store, + repo_config, + ) = serialized_artifacts.unserialize() + + if feature_view.online: + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + + rows_to_write = _convert_arrow_to_proto( + table, feature_view, join_key_to_value_type + ) + + online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + if feature_view.offline: + offline_store.offline_write_batch( + config=repo_config, + feature_view=feature_view, + table=table, + progress=lambda x: None, + ) + + yield table diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 246297cc1d5..744d3da5139 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import Callable, List, Literal, Optional, Sequence, Union, cast -import dill import pandas as pd import pyarrow from tqdm import tqdm @@ -15,6 +14,7 @@ MaterializationJobStatus, MaterializationTask, ) +from feast.infra.common.serializer import SerializedArtifacts from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, ) @@ -23,10 +23,8 @@ SparkRetrievalJob, ) from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.passthrough_provider import PassthroughProvider from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.stream_feature_view import StreamFeatureView from feast.utils import ( @@ -171,7 +169,7 @@ def _materialize_one( ), ) - spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( + serialized_artifacts = SerializedArtifacts.serialize( feature_view=feature_view, repo_config=self.repo_config ) @@ -182,7 +180,7 @@ def 
_materialize_one( ) spark_df.mapInPandas( - lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + lambda x: _map_by_partition(x, serialized_artifacts), "status int" ).count() # dummy action to force evaluation return SparkMaterializationJob( @@ -194,40 +192,7 @@ def _materialize_one( ) -@dataclass -class _SparkSerializedArtifacts: - """Class to assist with serializing unpicklable artifacts to the spark workers""" - - feature_view_proto: str - repo_config_byte: str - - @classmethod - def serialize(cls, feature_view, repo_config): - # serialize to proto - feature_view_proto = feature_view.to_proto().SerializeToString() - - # serialize repo_config to disk. Will be used to instantiate the online store - repo_config_byte = dill.dumps(repo_config) - - return _SparkSerializedArtifacts( - feature_view_proto=feature_view_proto, repo_config_byte=repo_config_byte - ) - - def unserialize(self): - # unserialize - proto = FeatureViewProto() - proto.ParseFromString(self.feature_view_proto) - feature_view = FeatureView.from_proto(proto) - - # load - repo_config = dill.loads(self.repo_config_byte) - - provider = PassthroughProvider(repo_config) - online_store = provider.online_store - return feature_view, online_store, repo_config - - -def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArtifacts): +def _map_by_partition(iterator, serialized_artifacts: SerializedArtifacts): for pdf in iterator: if pdf.shape[0] == 0: print("Skipping") @@ -238,8 +203,9 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti ( feature_view, online_store, + _, repo_config, - ) = spark_serialized_artifacts.unserialize() + ) = serialized_artifacts.unserialize() if feature_view.batch_source.field_mapping is not None: # Spark offline store does the field mapping in pull_latest_from_table_or_query() call diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py index ae69c0a6fcd..3f681017e89 100644 --- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -8,7 +8,7 @@ from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.compute_engines.spark.node import ( +from feast.infra.compute_engines.spark.nodes import ( SparkAggregationNode, SparkDedupNode, SparkJoinNode, From ff1f400b089645b5c28a011adbd934d31517c6d4 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 15:24:12 -0700 Subject: [PATCH 02/11] fix linting Signed-off-by: HaoXuAI --- sdk/python/feast/infra/common/serializer.py | 2 +- .../infra/compute_engines/spark/compute.py | 1 + .../infra/compute_engines/spark/nodes.py | 4 +- .../infra/compute_engines/spark/utils.py | 72 ++++++++++--------- 4 files changed, 42 insertions(+), 37 deletions(-) diff --git a/sdk/python/feast/infra/common/serializer.py b/sdk/python/feast/infra/common/serializer.py index 44e39c4d250..90e1be9234e 100644 --- a/sdk/python/feast/infra/common/serializer.py +++ b/sdk/python/feast/infra/common/serializer.py @@ -1,9 +1,9 @@ from dataclasses import dataclass import dill -from infra.passthrough_provider import PassthroughProvider from feast import FeatureView +from feast.infra.passthrough_provider import PassthroughProvider from feast.protos.feast.core.FeatureView_pb2 import FeatureView as 
FeatureViewProto diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 981e786cf7f..cd4f39ee4ed 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -53,6 +53,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: ) except Exception as e: + raise e # 🛑 Handle failure return SparkMaterializationJob( job_id=job_id, status=MaterializationJobStatus.ERROR, error=e diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index ad7e2b576fc..f5b83d34df6 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -286,8 +286,10 @@ def execute(self, context: ExecutionContext) -> DAGValue: # ✅ 1. Write to online or offline store (if enabled) if self.feature_view.online or self.feature_view.offline: + print("Spark DF count:", spark_df.count()) + print("Num partitions:", spark_df.rdd.getNumPartitions()) spark_df.mapInArrow( - lambda x: map_in_arrow(x, serialized_artifacts), "status int" + lambda x: map_in_arrow(x, serialized_artifacts), spark_df.schema ).count() return DAGValue( diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 6b1840d54ea..76ba50b6ac1 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -3,9 +3,9 @@ import pyarrow as pa from pyspark import SparkConf from pyspark.sql import SparkSession -from utils import _convert_arrow_to_proto from feast.infra.common.serializer import SerializedArtifacts +from feast.utils import _convert_arrow_to_proto def get_or_create_new_spark_session( @@ -20,6 +20,7 @@ def get_or_create_new_spark_session( ) spark_session = spark_builder.getOrCreate() + spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") return spark_session @@ -27,37 +28,38 @@ def map_in_arrow( iterator: Iterable[pa.RecordBatch], serialized_artifacts: "SerializedArtifacts", ): - table = pa.Table.from_batches(iterator) - - ( - feature_view, - online_store, - offline_store, - repo_config, - ) = serialized_artifacts.unserialize() - - if feature_view.online: - join_key_to_value_type = { - entity.name: entity.dtype.to_value_type() - for entity in feature_view.entity_columns - } - - rows_to_write = _convert_arrow_to_proto( - table, feature_view, join_key_to_value_type - ) - - online_store.online_write_batch( - config=repo_config, - table=feature_view, - data=rows_to_write, - progress=lambda x: None, - ) - if feature_view.offline: - offline_store.offline_write_batch( - config=repo_config, - feature_view=feature_view, - table=table, - progress=lambda x: None, - ) - - yield table + for batch in iterator: + table = pa.Table.from_batches([batch]) + + ( + feature_view, + online_store, + offline_store, + repo_config, + ) = serialized_artifacts.unserialize() + + if feature_view.online: + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + + rows_to_write = _convert_arrow_to_proto( + table, feature_view, join_key_to_value_type + ) + + online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + if feature_view.offline: + offline_store.offline_write_batch( + config=repo_config, + 
feature_view=feature_view, + table=table, + progress=lambda x: None, + ) + + yield batch From 91f41cbec4f03c137cbbc0a9355e1de10fe12d50 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 15:24:38 -0700 Subject: [PATCH 03/11] remove debug Signed-off-by: HaoXuAI --- sdk/python/feast/infra/compute_engines/spark/compute.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index cd4f39ee4ed..981e786cf7f 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -53,7 +53,6 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: ) except Exception as e: - raise e # 🛑 Handle failure return SparkMaterializationJob( job_id=job_id, status=MaterializationJobStatus.ERROR, error=e From 1f54560e86ea518e9f73c1118c414b9bea5db0be Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 15:26:36 -0700 Subject: [PATCH 04/11] rename module Signed-off-by: HaoXuAI --- sdk/python/feast/infra/common/{serializer.py => serde.py} | 0 sdk/python/feast/infra/compute_engines/spark/nodes.py | 2 +- sdk/python/feast/infra/compute_engines/spark/utils.py | 2 +- .../contrib/spark/spark_materialization_engine.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename sdk/python/feast/infra/common/{serializer.py => serde.py} (100%) diff --git a/sdk/python/feast/infra/common/serializer.py b/sdk/python/feast/infra/common/serde.py similarity index 100% rename from sdk/python/feast/infra/common/serializer.py rename to sdk/python/feast/infra/common/serde.py diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index f5b83d34df6..d9d68b39854 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -7,7 +7,7 @@ from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.data_source import DataSource -from feast.infra.common.serializer import SerializedArtifacts +from feast.infra.common.serde import SerializedArtifacts from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 76ba50b6ac1..2998207f51d 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -4,7 +4,7 @@ from pyspark import SparkConf from pyspark.sql import SparkSession -from feast.infra.common.serializer import SerializedArtifacts +from feast.infra.common.serde import SerializedArtifacts from feast.utils import _convert_arrow_to_proto diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 744d3da5139..c4809df3678 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -14,7 +14,7 @@ MaterializationJobStatus, MaterializationTask, ) -from feast.infra.common.serializer import SerializedArtifacts +from feast.infra.common.serde import SerializedArtifacts from 
feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, ) From 3d97656c5803ec6505dc703e2fd2190f58b5c62b Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 15:33:20 -0700 Subject: [PATCH 05/11] fix linting Signed-off-by: HaoXuAI --- sdk/python/feast/infra/compute_engines/local/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index c4e51075a07..3cd66da0cf2 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -2,8 +2,8 @@ from typing import Optional import pyarrow as pa -from utils import _convert_arrow_to_proto +from feast.utils import _convert_arrow_to_proto from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue From 145c58e2ec35a5c2c4c2458c86327c249f307736 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 15:34:52 -0700 Subject: [PATCH 06/11] fix linting Signed-off-by: HaoXuAI --- sdk/python/feast/infra/compute_engines/local/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index 3cd66da0cf2..c2965c00817 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -3,7 +3,6 @@ import pyarrow as pa -from feast.utils import _convert_arrow_to_proto from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue @@ -12,6 +11,7 @@ from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) +from feast.utils import _convert_arrow_to_proto ENTITY_TS_ALIAS = "__entity_event_timestamp" From 62e9b3d51c59eddf615d191366e0ca1b92389f15 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 20:14:01 -0700 Subject: [PATCH 07/11] fix linting Signed-off-by: HaoXuAI --- .../integration/compute_engines/spark/test_compute.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 15b6e850c65..e54dc42ccc8 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -33,6 +33,7 @@ from tests.integration.feature_repos.universal.online_store.redis import ( RedisOnlineStoreCreator, ) +from tests.utils.e2e_test_validation import _check_offline_and_online_features now = datetime.now() today = datetime.today() @@ -242,6 +243,16 @@ def tqdm_builder(length): spark_materialize_job = engine.materialize(task) assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED + + _check_offline_and_online_features( + fs=fs, + fv=driver_stats_fv, + driver_id=1, + event_timestamp=now, + expected_value=0.3, + full_feature_names=True, + check_offline_store=True, + ) finally: spark_environment.teardown() From a09b7e5df6d89087865524779116bb12b10deba9 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 23:02:30 -0700 Subject: [PATCH 08/11] fix fv offline Signed-off-by: HaoXuAI --- 
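A minimal sketch of what this commit ("fix fv offline") enables, assuming the FeatureView constructor accepts the new offline keyword (the updated __copy__ implies it does) and using a hypothetical entity and file source; this is an illustration by the editor, not code from the patch. It shows a view opting into both online and offline writes, and that the new flag survives the proto round-trip that SerializedArtifacts relies on to reach Spark workers.

from datetime import timedelta

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Float32

# Hypothetical entity and batch source, for illustration only.
driver = Entity(name="driver", join_keys=["driver_id"])
source = FileSource(path="data/driver_stats.parquet", timestamp_field="event_timestamp")

fv = FeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    ttl=timedelta(days=2),
    schema=[Field(name="conv_rate", dtype=Float32)],
    online=True,   # materialized rows go to the online store
    offline=True,  # new flag from this commit: also write to the offline store
    source=source,
)

# The flag round-trips through the FeatureView proto, which is how the
# serialized artifacts carry it to the write node on Spark workers.
assert FeatureView.from_proto(fv.to_proto()).offline is True
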
protos/feast/core/FeatureView.proto | 4 + sdk/python/feast/feature_view.py | 2 + .../infra/compute_engines/local/nodes.py | 7 +- .../compute_engines/spark/feature_builder.py | 3 +- .../infra/compute_engines/spark/nodes.py | 4 - .../infra/compute_engines/spark/utils.py | 2 + .../compute_engines/spark/test_compute.py | 73 +++++++++++++++++-- .../infra/compute_engines/local/test_nodes.py | 2 +- 8 files changed, 81 insertions(+), 16 deletions(-) diff --git a/protos/feast/core/FeatureView.proto b/protos/feast/core/FeatureView.proto index 3e9aa17256f..481ae00403f 100644 --- a/protos/feast/core/FeatureView.proto +++ b/protos/feast/core/FeatureView.proto @@ -74,7 +74,11 @@ message FeatureViewSpec { DataSource stream_source = 9; // Whether these features should be served online or not + // This is also used to determine whether the features should be written to the online store bool online = 8; + + // Whether these features should be written to the offline store + bool offline = 13; } message FeatureViewMeta { diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 5259d5d2b90..fea0c51a125 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -236,6 +236,7 @@ def __copy__(self): schema=self.schema, tags=self.tags, online=self.online, + offline=self.offline, ) # This is deliberately set outside of the FV initialization as we do not have the Entity objects. @@ -258,6 +259,7 @@ def __eq__(self, other): sorted(self.entities) != sorted(other.entities) or self.ttl != other.ttl or self.online != other.online + or self.offline != other.offline or self.batch_source != other.batch_source or self.stream_source != other.stream_source or sorted(self.entity_columns) != sorted(other.entity_columns) diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index c2965c00817..709b592f97c 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -1,8 +1,9 @@ from datetime import datetime, timedelta -from typing import Optional +from typing import Optional, Union import pyarrow as pa +from feast import BatchFeatureView, StreamFeatureView from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue @@ -208,7 +209,9 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalOutputNode(LocalNode): - def __init__(self, name: str, feature_view): + def __init__( + self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView] + ): super().__init__(name) self.feature_view = feature_view diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index f40f35a8360..df882bfc2c3 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -73,7 +73,8 @@ def build_transformation_node(self, input_node): return node def build_output_nodes(self, input_node): - node = SparkWriteNode("output", input_node, self.feature_view) + node = SparkWriteNode("output", self.feature_view) + node.add_input(input_node) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index d9d68b39854..7ec663b4b2f 100644 --- 
a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -271,11 +271,9 @@ class SparkWriteNode(DAGNode): def __init__( self, name: str, - input_node: DAGNode, feature_view: Union[BatchFeatureView, StreamFeatureView], ): super().__init__(name) - self.add_input(input_node) self.feature_view = feature_view def execute(self, context: ExecutionContext) -> DAGValue: @@ -286,8 +284,6 @@ def execute(self, context: ExecutionContext) -> DAGValue: # ✅ 1. Write to online or offline store (if enabled) if self.feature_view.online or self.feature_view.offline: - print("Spark DF count:", spark_df.count()) - print("Num partitions:", spark_df.rdd.getNumPartitions()) spark_df.mapInArrow( lambda x: map_in_arrow(x, serialized_artifacts), spark_df.schema ).count() diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 2998207f51d..700e1a7428d 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -37,6 +37,7 @@ def map_in_arrow( offline_store, repo_config, ) = serialized_artifacts.unserialize() + print("write_feature_view", feature_view) if feature_view.online: join_key_to_value_type = { @@ -55,6 +56,7 @@ def map_in_arrow( progress=lambda x: None, ) if feature_view.offline: + print("offline_to_write", table) offline_store.offline_write_batch( config=repo_config, feature_view=feature_view, diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index e54dc42ccc8..5fb8112c91e 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -33,7 +33,6 @@ from tests.integration.feature_repos.universal.online_store.redis import ( RedisOnlineStoreCreator, ) -from tests.utils.e2e_test_validation import _check_offline_and_online_features now = datetime.now() today = datetime.today() @@ -189,6 +188,20 @@ def transform_feature(df: DataFrame) -> DataFrame: @pytest.mark.integration def test_spark_compute_engine_materialize(): + """ + Test the SparkComputeEngine materialize method. + For the current feature view driver_hourly_stats, The below execution plan: + 1. feature data from create_feature_dataset + 2. filter by start_time and end_time, that is, the last 2 days + for the driver_id 1001, the data left is row 0 + for the driver_id 1002, the data left is row 2 + 3. apply the transform_feature function to the data + for all features, the value is multiplied by 2 + 4. 
write the data to the online store and offline store + + Returns: + + """ spark_environment = create_spark_environment() fs = spark_environment.feature_store registry = fs.registry @@ -213,7 +226,7 @@ def transform_feature(df: DataFrame) -> DataFrame: Field(name="driver_id", dtype=Int32), ], online=True, - offline=False, + offline=True, source=data_source, ) @@ -244,18 +257,62 @@ def tqdm_builder(length): assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED - _check_offline_and_online_features( + _check_online_features( fs=fs, - fv=driver_stats_fv, - driver_id=1, - event_timestamp=now, - expected_value=0.3, + driver_id=1001, + feature="driver_hourly_stats:conv_rate", + expected_value=1.6, full_feature_names=True, - check_offline_store=True, + ) + + entity_df = create_entity_df() + + _check_offline_features( + fs=fs, + feature="driver_hourly_stats:conv_rate", + entity_df=entity_df, + expected_value=1.6, ) finally: spark_environment.teardown() +def _check_online_features( + fs, + driver_id, + feature, + expected_value, + full_feature_names: bool = True, +): + online_response = fs.get_online_features( + features=[feature], + entity_rows=[{"driver_id": driver_id}], + full_feature_names=full_feature_names, + ).to_dict() + + feature_ref = "__".join(feature.split(":")) + + assert len(online_response["driver_id"]) == 1 + assert online_response["driver_id"][0] == driver_id + assert abs(online_response[feature_ref][0] - expected_value < 1e-6), ( + "Transformed result" + ) + + +def _check_offline_features( + fs, + feature, + entity_df, + expected_value, +): + offline_df = fs.get_historical_features( + entity_df=entity_df, + features=[feature], + ).to_df() + + assert len(offline_df) == 2 + assert offline_df["driver_id"].to_list() == [1001, 1002] + + if __name__ == "__main__": test_spark_compute_engine_get_historical_features() diff --git a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py index d5f6085a67c..c486b4148fc 100644 --- a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py @@ -194,7 +194,7 @@ def test_local_output_node(): context = create_context( node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} ) - node = LocalOutputNode("output") + node = LocalOutputNode("output", MagicMock()) node.add_input(MagicMock()) node.inputs[0].name = "source" result = node.execute(context) From 155d1ed1e0cf92b6da1ed4096ac3ebba24881495 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 20 Apr 2025 23:26:12 -0700 Subject: [PATCH 09/11] fix feature view proto Signed-off-by: HaoXuAI --- sdk/python/feast/feature_view.py | 2 ++ .../protos/feast/core/FeatureView_pb2.py | 20 +++++++++---------- .../protos/feast/core/FeatureView_pb2.pyi | 10 ++++++++-- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index fea0c51a125..2c2106f5a3e 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -365,6 +365,7 @@ def to_proto(self) -> FeatureViewProto: owner=self.owner, ttl=(ttl_duration if ttl_duration is not None else None), online=self.online, + offline=self.offline, batch_source=batch_source_proto, stream_source=stream_source_proto, ) @@ -414,6 +415,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto): tags=dict(feature_view_proto.spec.tags), owner=feature_view_proto.spec.owner, 
online=feature_view_proto.spec.online, + offline=feature_view_proto.spec.offline, ttl=( timedelta(days=0) if feature_view_proto.spec.ttl.ToNanoseconds() == 0 diff --git a/sdk/python/feast/protos/feast/core/FeatureView_pb2.py b/sdk/python/feast/protos/feast/core/FeatureView_pb2.py index 80d04c1ec3f..d1456cf9faf 100644 --- a/sdk/python/feast/protos/feast/core/FeatureView_pb2.py +++ b/sdk/python/feast/protos/feast/core/FeatureView_pb2.py @@ -18,7 +18,7 @@ from feast.protos.feast.core import Feature_pb2 as feast_dot_core_dot_Feature__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x61st/core/FeatureView.proto\x12\nfeast.core\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1b\x66\x65\x61st/core/DataSource.proto\x1a\x18\x66\x65\x61st/core/Feature.proto\"c\n\x0b\x46\x65\x61tureView\x12)\n\x04spec\x18\x01 \x01(\x0b\x32\x1b.feast.core.FeatureViewSpec\x12)\n\x04meta\x18\x02 \x01(\x0b\x32\x1b.feast.core.FeatureViewMeta\"\xbd\x03\n\x0f\x46\x65\x61tureViewSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07project\x18\x02 \x01(\t\x12\x10\n\x08\x65ntities\x18\x03 \x03(\t\x12+\n\x08\x66\x65\x61tures\x18\x04 \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x31\n\x0e\x65ntity_columns\x18\x0c \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\x12\x33\n\x04tags\x18\x05 \x03(\x0b\x32%.feast.core.FeatureViewSpec.TagsEntry\x12\r\n\x05owner\x18\x0b \x01(\t\x12&\n\x03ttl\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0c\x62\x61tch_source\x18\x07 \x01(\x0b\x32\x16.feast.core.DataSource\x12-\n\rstream_source\x18\t \x01(\x0b\x32\x16.feast.core.DataSource\x12\x0e\n\x06online\x18\x08 \x01(\x08\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xcc\x01\n\x0f\x46\x65\x61tureViewMeta\x12\x35\n\x11\x63reated_timestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x16last_updated_timestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x46\n\x19materialization_intervals\x18\x03 \x03(\x0b\x32#.feast.core.MaterializationInterval\"w\n\x17MaterializationInterval\x12.\n\nstart_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"@\n\x0f\x46\x65\x61tureViewList\x12-\n\x0c\x66\x65\x61tureviews\x18\x01 \x03(\x0b\x32\x17.feast.core.FeatureViewBU\n\x10\x66\x65\x61st.proto.coreB\x10\x46\x65\x61tureViewProtoZ/github.com/feast-dev/feast/go/protos/feast/coreb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x61st/core/FeatureView.proto\x12\nfeast.core\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1b\x66\x65\x61st/core/DataSource.proto\x1a\x18\x66\x65\x61st/core/Feature.proto\"c\n\x0b\x46\x65\x61tureView\x12)\n\x04spec\x18\x01 \x01(\x0b\x32\x1b.feast.core.FeatureViewSpec\x12)\n\x04meta\x18\x02 \x01(\x0b\x32\x1b.feast.core.FeatureViewMeta\"\xce\x03\n\x0f\x46\x65\x61tureViewSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07project\x18\x02 \x01(\t\x12\x10\n\x08\x65ntities\x18\x03 \x03(\t\x12+\n\x08\x66\x65\x61tures\x18\x04 \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x31\n\x0e\x65ntity_columns\x18\x0c \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\x12\x33\n\x04tags\x18\x05 \x03(\x0b\x32%.feast.core.FeatureViewSpec.TagsEntry\x12\r\n\x05owner\x18\x0b \x01(\t\x12&\n\x03ttl\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0c\x62\x61tch_source\x18\x07 
\x01(\x0b\x32\x16.feast.core.DataSource\x12-\n\rstream_source\x18\t \x01(\x0b\x32\x16.feast.core.DataSource\x12\x0e\n\x06online\x18\x08 \x01(\x08\x12\x0f\n\x07offline\x18\r \x01(\x08\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xcc\x01\n\x0f\x46\x65\x61tureViewMeta\x12\x35\n\x11\x63reated_timestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x16last_updated_timestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x46\n\x19materialization_intervals\x18\x03 \x03(\x0b\x32#.feast.core.MaterializationInterval\"w\n\x17MaterializationInterval\x12.\n\nstart_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"@\n\x0f\x46\x65\x61tureViewList\x12-\n\x0c\x66\x65\x61tureviews\x18\x01 \x03(\x0b\x32\x17.feast.core.FeatureViewBU\n\x10\x66\x65\x61st.proto.coreB\x10\x46\x65\x61tureViewProtoZ/github.com/feast-dev/feast/go/protos/feast/coreb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,13 +31,13 @@ _globals['_FEATUREVIEW']._serialized_start=164 _globals['_FEATUREVIEW']._serialized_end=263 _globals['_FEATUREVIEWSPEC']._serialized_start=266 - _globals['_FEATUREVIEWSPEC']._serialized_end=711 - _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_start=668 - _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_end=711 - _globals['_FEATUREVIEWMETA']._serialized_start=714 - _globals['_FEATUREVIEWMETA']._serialized_end=918 - _globals['_MATERIALIZATIONINTERVAL']._serialized_start=920 - _globals['_MATERIALIZATIONINTERVAL']._serialized_end=1039 - _globals['_FEATUREVIEWLIST']._serialized_start=1041 - _globals['_FEATUREVIEWLIST']._serialized_end=1105 + _globals['_FEATUREVIEWSPEC']._serialized_end=728 + _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_start=685 + _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_end=728 + _globals['_FEATUREVIEWMETA']._serialized_start=731 + _globals['_FEATUREVIEWMETA']._serialized_end=935 + _globals['_MATERIALIZATIONINTERVAL']._serialized_start=937 + _globals['_MATERIALIZATIONINTERVAL']._serialized_end=1056 + _globals['_FEATUREVIEWLIST']._serialized_start=1058 + _globals['_FEATUREVIEWLIST']._serialized_end=1122 # @@protoc_insertion_point(module_scope) diff --git a/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi b/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi index 57158fc2c6c..6abeb85e263 100644 --- a/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi +++ b/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi @@ -90,6 +90,7 @@ class FeatureViewSpec(google.protobuf.message.Message): BATCH_SOURCE_FIELD_NUMBER: builtins.int STREAM_SOURCE_FIELD_NUMBER: builtins.int ONLINE_FIELD_NUMBER: builtins.int + OFFLINE_FIELD_NUMBER: builtins.int name: builtins.str """Name of the feature view. Must be unique. 
Not updated.""" project: builtins.str @@ -124,7 +125,11 @@ class FeatureViewSpec(google.protobuf.message.Message): def stream_source(self) -> feast.core.DataSource_pb2.DataSource: """Streaming DataSource from where this view can consume "online" feature data.""" online: builtins.bool - """Whether these features should be served online or not""" + """Whether these features should be served online or not + This is also used to determine whether the features should be written to the online store + """ + offline: builtins.bool + """Whether these features should be written to the offline store""" def __init__( self, *, @@ -140,9 +145,10 @@ class FeatureViewSpec(google.protobuf.message.Message): batch_source: feast.core.DataSource_pb2.DataSource | None = ..., stream_source: feast.core.DataSource_pb2.DataSource | None = ..., online: builtins.bool = ..., + offline: builtins.bool = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "stream_source", b"stream_source", "ttl", b"ttl"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "description", b"description", "entities", b"entities", "entity_columns", b"entity_columns", "features", b"features", "name", b"name", "online", b"online", "owner", b"owner", "project", b"project", "stream_source", b"stream_source", "tags", b"tags", "ttl", b"ttl"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "description", b"description", "entities", b"entities", "entity_columns", b"entity_columns", "features", b"features", "name", b"name", "offline", b"offline", "online", b"online", "owner", b"owner", "project", b"project", "stream_source", b"stream_source", "tags", b"tags", "ttl", b"ttl"]) -> None: ... global___FeatureViewSpec = FeatureViewSpec From 3af968bb0e9c8574677500a254455b3c45bde34d Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Tue, 22 Apr 2025 00:08:47 -0700 Subject: [PATCH 10/11] fix write node Signed-off-by: HaoXuAI --- .../infra/compute_engines/spark/nodes.py | 27 +++++- .../infra/compute_engines/spark/utils.py | 9 +- .../contrib/spark_offline_store/spark.py | 95 +++++++++++++------ .../compute_engines/spark/test_compute.py | 6 +- 4 files changed, 94 insertions(+), 43 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index 7ec663b4b2f..7fe7fdb45e8 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -17,6 +17,9 @@ SparkRetrievalJob, _get_entity_schema, ) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) @@ -282,12 +285,30 @@ def execute(self, context: ExecutionContext) -> DAGValue: feature_view=self.feature_view, repo_config=context.repo_config ) - # ✅ 1. Write to online or offline store (if enabled) - if self.feature_view.online or self.feature_view.offline: + # ✅ 1. Write to online store if online enabled + if self.feature_view.online: spark_df.mapInArrow( - lambda x: map_in_arrow(x, serialized_artifacts), spark_df.schema + lambda x: map_in_arrow(x, serialized_artifacts, mode="online"), + spark_df.schema, ).count() + # ✅ 2. 
Write to offline store if offline enabled + if self.feature_view.offline: + if not isinstance(self.feature_view.batch_source, SparkSource): + spark_df.mapInArrow( + lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"), + spark_df.schema, + ).count() + # Directly write spark df to spark offline store without using mapInArrow + else: + dest_path = self.feature_view.batch_source.path + file_format = self.feature_view.batch_source.file_format + if not dest_path or not file_format: + raise ValueError( + "Destination path and file format must be specified for SparkSource." + ) + spark_df.write.format(file_format).mode("append").save(dest_path) + return DAGValue( data=spark_df, format=DAGFormat.SPARK, diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 700e1a7428d..7808ca0118a 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Literal, Optional import pyarrow as pa from pyspark import SparkConf @@ -27,6 +27,7 @@ def get_or_create_new_spark_session( def map_in_arrow( iterator: Iterable[pa.RecordBatch], serialized_artifacts: "SerializedArtifacts", + mode: Literal["online", "offline"] = "online", ): for batch in iterator: table = pa.Table.from_batches([batch]) @@ -37,9 +38,8 @@ def map_in_arrow( offline_store, repo_config, ) = serialized_artifacts.unserialize() - print("write_feature_view", feature_view) - if feature_view.online: + if mode == "online": join_key_to_value_type = { entity.name: entity.dtype.to_value_type() for entity in feature_view.entity_columns @@ -55,8 +55,7 @@ def map_in_arrow( data=rows_to_write, progress=lambda x: None, ) - if feature_view.offline: - print("offline_to_write", table) + if mode == "offline": offline_store.offline_write_batch( config=repo_config, feature_view=feature_view, diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 806610cae7e..ff0e7b3410e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -3,7 +3,7 @@ import uuid import warnings from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import pandas @@ -54,6 +54,8 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel): region: Optional[StrictStr] = None """ AWS Region if applicable for s3-based staging locations""" + mode: Optional[Literal["driver", "worker"]] = "driver" + class SparkOfflineStore(OfflineStore): @staticmethod @@ -218,6 +220,22 @@ def offline_write_batch( table: pyarrow.Table, progress: Optional[Callable[[int], Any]], ): + """ + Write pyarrow table to offline store. + This method supports two execution modes: + - "driver": Uses Spark to perform schema validation, type casting, and appending the data to the offline store. + This mode must run on the Spark driver and supports advanced functionality like schema enforcement. + - "worker": A simplified, worker-safe implementation that writes Arrow tables directly to storage. + This mode is designed for distributed execution within mapInArrow or other parallel contexts. 
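        An illustrative sketch of selecting the worker-safe path; this exact
        config construction is not part of the patch (and the final commit in
        this series reverts the mode switch), but the mode is read from the
        Spark offline store config added in this same commit:

            from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
                SparkOfflineStoreConfig,
            )

            # "worker" skips the Spark session and writes each Arrow table
            # directly to a new parquet file under the batch source path,
            # which is what mapInArrow partitions need.
            config = SparkOfflineStoreConfig(type="spark", mode="worker")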
+ + Args: + config: RepoConfig + feature_view: FeatureView + table: pyarrow.Table + progress: Callable[[int], Any] + mode: Literal["driver", "worker"], default is "driver" + + """ assert isinstance(config.offline_store, SparkOfflineStoreConfig) assert isinstance(feature_view.batch_source, SparkSource) @@ -230,38 +248,55 @@ def offline_write_batch( f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}." ) - spark_session = get_spark_session_or_start_new_with_repoconfig( - store_config=config.offline_store - ) + mode = config.offline_store.mode - if feature_view.batch_source.path: - # write data to disk so that it can be loaded into spark (for preserving column types) - with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file: - print(tmp_file.name) - pq.write_table(table, tmp_file.name) - - # load data - df_batch = spark_session.read.parquet(tmp_file.name) - - # load existing data to get spark table schema - df_existing = spark_session.read.format( - feature_view.batch_source.file_format - ).load(feature_view.batch_source.path) - - # cast columns if applicable - df_batch = _cast_data_frame(df_batch, df_existing) - - df_batch.write.format(feature_view.batch_source.file_format).mode( - "append" - ).save(feature_view.batch_source.path) - elif feature_view.batch_source.query: - raise NotImplementedError( - "offline_write_batch not implemented for batch sources specified by query" + if mode == "driver": + spark_session = get_spark_session_or_start_new_with_repoconfig( + store_config=config.offline_store ) + + if feature_view.batch_source.path: + # write data to disk so that it can be loaded into spark (for preserving column types) + with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file: + print(tmp_file.name) + pq.write_table(table, tmp_file.name) + + # load data + df_batch = spark_session.read.parquet(tmp_file.name) + + # load existing data to get spark table schema + df_existing = spark_session.read.format( + feature_view.batch_source.file_format + ).load(feature_view.batch_source.path) + + # cast columns if applicable + df_batch = _cast_data_frame(df_batch, df_existing) + + df_batch.write.format(feature_view.batch_source.file_format).mode( + "append" + ).save(feature_view.batch_source.path) + elif feature_view.batch_source.query: + raise NotImplementedError( + "offline_write_batch not implemented for batch sources specified by query" + ) + else: + raise NotImplementedError( + "offline_write_batch not implemented for batch sources specified by a table" + ) + elif mode == "worker": + # Safe worker-side Arrow write + if not feature_view.batch_source.path: + raise ValueError("Path is required for worker mode.") + + unique_name = f"batch_{uuid.uuid4().hex}.parquet" + output_path = os.path.join(feature_view.batch_source.path, unique_name) + + pq.write_table(table, output_path) + + if progress: + progress(table.num_rows) else: - raise NotImplementedError( - "offline_write_batch not implemented for batch sources specified by a table" - ) + raise ValueError(f"Unsupported mode: {mode}") @staticmethod def pull_all_from_table_or_query( diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 5fb8112c91e..621190643a4 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -271,7 +271,6 @@ def tqdm_builder(length): fs=fs, 
feature="driver_hourly_stats:conv_rate", entity_df=entity_df, - expected_value=1.6, ) finally: spark_environment.teardown() @@ -303,15 +302,12 @@ def _check_offline_features( fs, feature, entity_df, - expected_value, ): offline_df = fs.get_historical_features( entity_df=entity_df, features=[feature], ).to_df() - - assert len(offline_df) == 2 - assert offline_df["driver_id"].to_list() == [1001, 1002] + assert len(offline_df) == 4 if __name__ == "__main__": From 7d234380ad21e8bc772b077761216e7e1d1aa48a Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Tue, 22 Apr 2025 00:10:19 -0700 Subject: [PATCH 11/11] fix write node Signed-off-by: HaoXuAI --- .../contrib/spark_offline_store/spark.py | 95 ++++++------------- 1 file changed, 30 insertions(+), 65 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index ff0e7b3410e..806610cae7e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -3,7 +3,7 @@ import uuid import warnings from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import pandas @@ -54,8 +54,6 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel): region: Optional[StrictStr] = None """ AWS Region if applicable for s3-based staging locations""" - mode: Optional[Literal["driver", "worker"]] = "driver" - class SparkOfflineStore(OfflineStore): @staticmethod @@ -220,22 +218,6 @@ def offline_write_batch( table: pyarrow.Table, progress: Optional[Callable[[int], Any]], ): - """ - Write pyarrow table to offline store. - This method supports two execution modes: - - "driver": Uses Spark to perform schema validation, type casting, and appending the data to the offline store. - This mode must run on the Spark driver and supports advanced functionality like schema enforcement. - - "worker": A simplified, worker-safe implementation that writes Arrow tables directly to storage. - This mode is designed for distributed execution within mapInArrow or other parallel contexts. - - Args: - config: RepoConfig - feature_view: FeatureView - table: pyarrow.Table - progress: Callable[[int], Any] - mode: Literal["driver", "worker"], default is "driver" - - """ assert isinstance(config.offline_store, SparkOfflineStoreConfig) assert isinstance(feature_view.batch_source, SparkSource) @@ -248,55 +230,38 @@ def offline_write_batch( f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}." 
) - mode = config.offline_store.mode + spark_session = get_spark_session_or_start_new_with_repoconfig( + store_config=config.offline_store + ) - if mode == "driver": - spark_session = get_spark_session_or_start_new_with_repoconfig( - store_config=config.offline_store + if feature_view.batch_source.path: + # write data to disk so that it can be loaded into spark (for preserving column types) + with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file: + print(tmp_file.name) + pq.write_table(table, tmp_file.name) + + # load data + df_batch = spark_session.read.parquet(tmp_file.name) + + # load existing data to get spark table schema + df_existing = spark_session.read.format( + feature_view.batch_source.file_format + ).load(feature_view.batch_source.path) + + # cast columns if applicable + df_batch = _cast_data_frame(df_batch, df_existing) + + df_batch.write.format(feature_view.batch_source.file_format).mode( + "append" + ).save(feature_view.batch_source.path) + elif feature_view.batch_source.query: + raise NotImplementedError( + "offline_write_batch not implemented for batch sources specified by query" ) - - if feature_view.batch_source.path: - # write data to disk so that it can be loaded into spark (for preserving column types) - with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file: - print(tmp_file.name) - pq.write_table(table, tmp_file.name) - - # load data - df_batch = spark_session.read.parquet(tmp_file.name) - - # load existing data to get spark table schema - df_existing = spark_session.read.format( - feature_view.batch_source.file_format - ).load(feature_view.batch_source.path) - - # cast columns if applicable - df_batch = _cast_data_frame(df_batch, df_existing) - - df_batch.write.format(feature_view.batch_source.file_format).mode( - "append" - ).save(feature_view.batch_source.path) - elif feature_view.batch_source.query: - raise NotImplementedError( - "offline_write_batch not implemented for batch sources specified by query" - ) - else: - raise NotImplementedError( - "offline_write_batch not implemented for batch sources specified by a table" - ) - elif mode == "worker": - # Safe worker-side Arrow write - if not feature_view.batch_source.path: - raise ValueError("Path is required for worker mode.") - - unique_name = f"batch_{uuid.uuid4().hex}.parquet" - output_path = os.path.join(feature_view.batch_source.path, unique_name) - - pq.write_table(table, output_path) - - if progress: - progress(table.num_rows) else: - raise ValueError(f"Unsupported mode: {mode}") + raise NotImplementedError( + "offline_write_batch not implemented for batch sources specified by a table" + ) @staticmethod def pull_all_from_table_or_query(