From 1183a9a314cbdebf2c0d45b8b33f5d2add4a1162 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Fri, 4 Apr 2025 14:58:10 -0700 Subject: [PATCH 01/18] add compute engine Signed-off-by: HaoXuAI --- sdk/python/feast/batch_feature_view.py | 3 + sdk/python/feast/feature_view.py | 4 + .../feast/infra/compute_engines/base.py | 49 ++++ .../infra/compute_engines/dag/builder.py | 70 ++++++ .../feast/infra/compute_engines/dag/model.py | 27 +++ .../feast/infra/compute_engines/dag/node.py | 29 +++ .../feast/infra/compute_engines/dag/plan.py | 63 +++++ .../feast/infra/compute_engines/dag/value.py | 18 ++ .../infra/compute_engines/spark/compute.py | 81 +++++++ .../feast/infra/compute_engines/spark/node.py | 220 ++++++++++++++++++ .../spark/spark_dag_builder.py | 43 ++++ 11 files changed, 607 insertions(+) create mode 100644 sdk/python/feast/infra/compute_engines/base.py create mode 100644 sdk/python/feast/infra/compute_engines/dag/builder.py create mode 100644 sdk/python/feast/infra/compute_engines/dag/model.py create mode 100644 sdk/python/feast/infra/compute_engines/dag/node.py create mode 100644 sdk/python/feast/infra/compute_engines/dag/plan.py create mode 100644 sdk/python/feast/infra/compute_engines/dag/value.py create mode 100644 sdk/python/feast/infra/compute_engines/spark/compute.py create mode 100644 sdk/python/feast/infra/compute_engines/spark/node.py create mode 100644 sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 57d5aa1b07e..c3c6784b79c 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -55,6 +55,7 @@ class BatchFeatureView(FeatureView): entity_columns: List[Field] features: List[Field] online: bool + offline: bool description: str tags: Dict[str, str] owner: str @@ -74,6 +75,7 @@ def __init__( ttl: Optional[timedelta] = None, tags: Optional[Dict[str, str]] = None, online: bool = True, + offline: bool = True, description: str = "", owner: str = "", schema: Optional[List[Field]] = None, @@ -110,6 +112,7 @@ def __init__( ttl=ttl, tags=tags, online=online, + offline=offline, description=description, owner=owner, schema=schema, diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 49b74893451..70881588a71 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -93,6 +93,7 @@ class FeatureView(BaseFeatureView): entity_columns: List[Field] features: List[Field] online: bool + offline: bool description: str tags: Dict[str, str] owner: str @@ -107,6 +108,7 @@ def __init__( entities: Optional[List[Entity]] = None, ttl: Optional[timedelta] = timedelta(days=0), online: bool = True, + offline: bool = False, description: str = "", tags: Optional[Dict[str, str]] = None, owner: str = "", @@ -127,6 +129,8 @@ def __init__( can result in extremely computationally intensive queries. online (optional): A boolean indicating whether online retrieval is enabled for this feature view. + offline (optional): A boolean indicating whether write to offline store is enabled for + this feature view. description (optional): A human-readable description. tags (optional): A dictionary of key-value pairs to store arbitrary metadata. 
owner (optional): The owner of the feature view, typically the email of the diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py new file mode 100644 index 00000000000..b5120f9f1c9 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -0,0 +1,49 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Union, List + +import pandas as pd +import pyarrow as pa + +from feast import RepoConfig, BatchFeatureView, StreamFeatureView +from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob +from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.online_stores.online_store import OnlineStore +from feast.infra.registry.registry import Registry + + +@dataclass +class HistoricalRetrievalTask: + entity_df: Union[pd.DataFrame, str] + feature_views: List[Union[BatchFeatureView, StreamFeatureView]] + full_feature_names: bool + registry: Registry + config: RepoConfig + + +class ComputeEngine(ABC): + """ + The interface that Feast uses to control the compute system that handles materialization and get_historical_features. + """ + + def __init__( + self, + *, + registry: Registry, + repo_config: RepoConfig, + offline_store: OfflineStore, + online_store: OnlineStore, + **kwargs, + ): + self.registry = registry + self.repo_config = repo_config + self.offline_store = offline_store + self.online_store = online_store + + def materialize(self, + task: MaterializationTask) -> MaterializationJob: + raise NotImplementedError + + def get_historical_features(self, + task: HistoricalRetrievalTask) -> pa.Table: + raise NotImplementedError diff --git a/sdk/python/feast/infra/compute_engines/dag/builder.py b/sdk/python/feast/infra/compute_engines/dag/builder.py new file mode 100644 index 00000000000..7ba4329b26e --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/builder.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod +from typing import Union + +from feast import BatchFeatureView, StreamFeatureView +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import MaterializationTask + + +class DAGBuilder(ABC): + def __init__(self, + feature_view: Union[BatchFeatureView, StreamFeatureView], + task: Union[MaterializationTask, HistoricalRetrievalTask] + ): + self.feature_view = feature_view + self.task = task + self.nodes = [] + + @abstractmethod + def build_source_node(self): + raise NotImplementedError + + @abstractmethod + def build_aggregation_node(self, + input_node): + raise NotImplementedError + + @abstractmethod + def build_join_node(self, + input_node): + raise NotImplementedError + + @abstractmethod + def build_transformation_node(self, + input_node): + raise NotImplementedError + + @abstractmethod + def build_output_nodes(self, + input_node): + raise NotImplementedError + + @abstractmethod + def build_validation_node(self, + input_node): + raise + + def build(self) -> ExecutionPlan: + last_node = self.build_source_node() + + if getattr(self.feature_view.transformation, "requires_aggregation", False): + last_node = self.build_aggregation_node(last_node) + + if self._should_join(): + last_node = self.build_join_node(last_node) + + if self.feature_view.transformation: + last_node = self.build_transformation_node(last_node) + + if getattr(self.feature_view, "enable_validation", False): + 
last_node = self.build_validation_node(last_node) + + self.build_output_nodes(last_node) + return ExecutionPlan(self.nodes) + + def _should_join(self): + return ( + self.feature_view.compute_config.join_strategy == "engine" + or self.task.config.compute_engine.get("point_in_time_join") == "engine" + ) diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py new file mode 100644 index 00000000000..b42ee62dc9d --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -0,0 +1,27 @@ +from enum import Enum +from dataclasses import dataclass, field +import pandas as pd +from typing import Union, List, Dict + +from feast.entity import Entity +from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.online_stores.online_store import OnlineStore +from feast.repo_config import RepoConfig +from feast.infra.compute_engines.dag.value import DAGValue + + +class DAGFormat(str, Enum): + SPARK = "spark" + PANDAS = "pandas" + ARROW = "arrow" + + +@dataclass +class ExecutionContext: + project: str + repo_config: RepoConfig + offline_store: OfflineStore + online_store: OnlineStore + entity_defs: List[Entity] + entity_df: Union[pd.DataFrame, None] = None + node_outputs: Dict[str, DAGValue] = field(default_factory=dict) diff --git a/sdk/python/feast/infra/compute_engines/dag/node.py b/sdk/python/feast/infra/compute_engines/dag/node.py new file mode 100644 index 00000000000..47b31eba0c0 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/node.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import List + +from feast.infra.compute_engines.dag.model import ExecutionContext +from infra.compute_engines.dag.value import DAGValue + + +class DAGNode(ABC): + name: str + inputs: List["DAGNode"] + outputs: List["DAGNode"] + + def __init__(self, + name: str): + self.name = name + self.inputs = [] + self.outputs = [] + + def add_input(self, + node: "DAGNode"): + if node in self.inputs: + raise ValueError(f"Input node {node.name} already added to {self.name}") + self.inputs.append(node) + node.outputs.append(self) + + @abstractmethod + def execute(self, + context: ExecutionContext) -> DAGValue: + ... diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py new file mode 100644 index 00000000000..ebfd1240315 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -0,0 +1,63 @@ +from typing import List + +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.dag.model import ExecutionContext + + +class ExecutionPlan: + """ + ExecutionPlan represents an ordered sequence of DAGNodes that together define + a data processing pipeline for feature materialization or historical retrieval. + + This plan is constructed as a topological sort of the DAG — meaning that each + node appears after all its input dependencies. The plan is executed in order, + caching intermediate results (`DAGValue`) so that each node can reuse outputs + from upstream nodes without recomputation. + + Key Concepts: + - DAGNode: Each node performs a specific logical step (e.g., read, aggregate, join). + - DAGValue: Output of a node, includes data (e.g., Spark DataFrame) and metadata. + - ExecutionContext: Contains runtime information (config, registry, stores, entity_df). + - node_outputs: A cache of intermediate results keyed by node name. 
+ + Usage: + plan = ExecutionPlan(dag_nodes) + result = plan.execute(context) + + This design enables modular compute backends (e.g., Spark, Pandas, Arrow), where + each node defines its execution logic independently while benefiting from shared + execution orchestration, caching, and context injection. + + Example: + DAG: + ReadNode -> AggregateNode -> JoinNode -> TransformNode -> WriteNode + + Execution proceeds step by step, passing intermediate DAGValues through + the plan while respecting node dependencies and formats. + + This approach is inspired by execution DAGs in systems like Apache Spark, + Apache Beam, and Dask — but specialized for Feast’s feature computation domain. + """ + def __init__(self, nodes: List[DAGNode]): + self.nodes = nodes + + def execute(self, context: ExecutionContext) -> DAGValue: + node_outputs: dict[str, DAGValue] = {} + + for node in self.nodes: + # Gather input values + for input_node in node.inputs: + if input_node.name not in node_outputs: + node_outputs[input_node.name] = input_node.execute(context) + + # Execute this node + output = node.execute(context) + node_outputs[node.name] = output + + # Inject into context for downstream access + context.node_outputs = node_outputs + + # Return output of final node + return node_outputs[self.nodes[-1].name] + diff --git a/sdk/python/feast/infra/compute_engines/dag/value.py b/sdk/python/feast/infra/compute_engines/dag/value.py new file mode 100644 index 00000000000..f45b969cf77 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/value.py @@ -0,0 +1,18 @@ +from typing import Any, Optional + +from feast.infra.compute_engines.dag.model import DAGFormat + + +class DAGValue: + def __init__(self, + data: Any, + format: DAGFormat, + metadata: Optional[dict] = None): + self.data = data + self.format = format + self.metadata = metadata or {} + + def assert_format(self, + expected: DAGFormat): + if self.format != expected: + raise ValueError(f"Expected format {expected}, but got {self.format}") diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py new file mode 100644 index 00000000000..a93aadd5ee7 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -0,0 +1,81 @@ +import pyarrow as pa +from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask +from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder +from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob, \ + MaterializationJobStatus +from feast.infra.materialization.contrib.spark.spark_materialization_engine import SparkMaterializationJob +from feast.infra.compute_engines.dag.model import ExecutionContext + + +class SparkComputeEngine(ComputeEngine): + def materialize(self, task: MaterializationTask) -> MaterializationJob: + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + + try: + # ✅ 1. Build typed execution context + entities = [] + for entity_name in task.feature_view.entities: + entities.append(self.registry.get_entity(entity_name, task.project)) + + context = ExecutionContext( + project=task.project, + repo_config=self.repo_config, + offline_store=self.offline_store, + online_store=self.online_store, + entity_defs=entities + ) + + # ✅ 2. Construct DAG and run it + builder = SparkDAGBuilder( + feature_view=task.feature_view, + task=task, + ) + plan = builder.build() + plan.execute(context) + + # ✅ 3. 
Report success + return SparkMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED + ) + + except Exception as e: + # 🛑 Handle failure + return SparkMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e + ) + + def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: + # ✅ 1. Validate input + assert len(task.feature_views) == 1, "Multi-view support not yet implemented" + feature_view = task.feature_views[0] + + if isinstance(task.entity_df, str): + raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") + + # ✅ 2. Build typed execution context + entity_defs = [ + task.registry.get_entity(name, task.config.project) + for name in feature_view.entities + ] + + context = ExecutionContext( + project=task.config.project, + repo_config=task.config, + offline_store=self.offline_store, + online_store=self.online_store, + entity_defs=entity_defs, + entity_df=task.entity_df, + ) + + # ✅ 3. Construct and execute DAG + builder = SparkDAGBuilder(feature_view=feature_view, task=task) + plan = builder.build() + + result = plan.execute(context=context) + spark_df = result.data # should be a Spark DataFrame + + # ✅ 4. Return as Arrow + return spark_df.toPandas().to_arrow() diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py new file mode 100644 index 00000000000..3fccea748ee --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -0,0 +1,220 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Union, List, Optional, Dict + +from pyspark._typing import F +from pyspark.sql import DataFrame, Window + +from aggregation import Aggregation +from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.model import DAGFormat, ExecutionContext +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.materialization.batch_materialization_engine import MaterializationTask +from feast.infra.offline_stores.contrib.spark_offline_store.spark import SparkRetrievalJob +from feast.utils import _get_column_names + + +@dataclass +class SparkJoinContext: + name: str # feature view name or alias + join_keys: List[str] + feature_columns: List[str] + timestamp_field: str + created_timestamp_column: Optional[str] + ttl_seconds: Optional[int] + min_event_timestamp: Optional[datetime] + max_event_timestamp: Optional[datetime] + field_mapping: Dict[str, str] # original_column_name -> renamed_column + full_feature_names: bool = False # apply feature view name prefix + + +class SparkReadNode(DAGNode): + def __init__(self, + name: str, + task: Union[MaterializationTask, HistoricalRetrievalTask]): + super().__init__(name) + self.task = task + + def execute(self, + context: ExecutionContext) -> DAGValue: + offline_store = context.offline_store + + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = _get_column_names(self.task.feature_view, context.entity_defs) + + # 📥 Reuse Feast's robust query resolver + retrieval_job: SparkRetrievalJob = offline_store.pull_latest_from_table_or_query( + config=context.repo_config, + data_source=self.task.feature_view.batch_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + 
start_date=self.task.start_time, + end_date=self.task.end_time, + ) + + spark_df = retrieval_job.to_spark_df() + + return DAGValue( + data=spark_df, + format=DAGFormat.SPARK, + metadata={ + "source": "feature_view_batch_source", + "timestamp_field": timestamp_field, + "created_timestamp_column": created_timestamp_column, + "start_date": self.task.start_time, + "end_date": self.task.end_time, + }, + ) + + +class SparkAggregationNode(DAGNode): + def __init__(self, + name: str, + input_node: DAGNode, + aggregations: List[Aggregation], + group_by_keys: List[str], + timestamp_col: str): + super().__init__(name) + self.add_input(input_node) + self.aggregations = aggregations + self.group_by_keys = group_by_keys + self.timestamp_col = timestamp_col + + def execute(self, + context: ExecutionContext) -> DAGValue: + input_df: DataFrame = context.node_outputs[self.inputs[0].name].data + input_df.assert_format(DAGFormat.SPARK) + + agg_exprs = [] + for agg in self.aggregations: + func = getattr(F, agg.function) + expr = func(agg.column).alias( + f"{agg.function}_{agg.column}_{int(agg.time_window.total_seconds())}s" if agg.time_window else f"{agg.function}_{agg.column}") + agg_exprs.append(expr) + + if any(agg.time_window for agg in self.aggregations): + # 🕒 Use Spark's `window` function + time_window = self.aggregations[0].time_window # assume consistent window size for now + grouped = input_df.groupBy( + *self.group_by_keys, + F.window(F.col(self.timestamp_col), f"{int(time_window.total_seconds())} seconds") + ).agg(*agg_exprs) + else: + # Simple aggregation + grouped = input_df.groupBy(*self.group_by_keys).agg(*agg_exprs) + + return DAGValue( + data=grouped, + format=DAGFormat.SPARK, + metadata={"aggregated": True}) + + +class SparkJoinNode(DAGNode): + def __init__(self, + name: str, + feature_node: DAGNode, + join_keys: List[str], + feature_view + ): + super().__init__(name) + self.join_keys = join_keys + self.add_input(feature_node) + self.feature_view = feature_view + + def execute(self, + context: ExecutionContext) -> DAGValue: + feature_value = context.node_outputs[self.inputs[1].name] + feature_value.assert_format(DAGFormat.SPARK) + + entity_df = context.entity_df + feature_df = feature_value.data + + # Get timestamp fields from feature view + join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( + self.feature_view, context.entity_defs + ) + + entity_event_ts_col = "event_timestamp" # Standardized by SparkEntityLoadNode + + # Perform left join + event timestamp filtering + joined = feature_df.join(entity_df, on=join_keys, how="left") + joined = joined.filter(F.col(ts_col) <= F.col(entity_event_ts_col)) + + # Dedup with row_number + partition_cols = join_keys + [entity_event_ts_col] + ordering = [F.col(ts_col).desc()] + if created_ts_col: + ordering.append(F.col(created_ts_col).desc()) + + window = Window.partitionBy(*partition_cols).orderBy(*ordering) + deduped = joined.withColumn("row_num", F.row_number().over(window)).filter("row_num = 1").drop("row_num") + + return DAGValue( + data=deduped, + format=DAGFormat.SPARK, + metadata={"joined_on": join_keys} + ) + + +class SparkWriteNode(DAGNode): + def __init__(self, + name: str, + input_node: DAGNode, + task): + super().__init__(name) + self.task = task + self.add_input(input_node) + + def execute(self, + context: ExecutionContext) -> DAGValue: + spark_df: DataFrame = context.node_outputs[self.inputs[0].name].data + feature_view = self.task.feature_view + + # ✅ 1. 
Write to offline store (if enabled) + if getattr(self.task, "offline", True): + context.offline_store.offline_write_batch( + config=context.repo_config, + feature_view=feature_view, + table=spark_df, + ) + + # ✅ 2. Write to online store (if enabled) + if getattr(self.task, "online", False): + context.online_store.online_write_batch( + config=context.repo_config, + data=spark_df, + ) + + return DAGValue(data=spark_df, + format=DAGFormat.SPARK, + metadata={"written_to": "online+offline" if self.task.online else "offline"}) + + +class SparkTransformationNode(DAGNode): + def __init__(self, + name: str, + input_node: DAGNode, + udf): + super().__init__(name) + self.add_input(input_node) + self.udf = udf + + def execute(self, + context: ExecutionContext) -> DAGValue: + input_val = context.node_outputs[self.inputs[0].name] + input_val.assert_format(DAGFormat.SPARK) + + transformed_df = self.udf(input_val.data) + + return DAGValue( + data=transformed_df, + format=DAGFormat.SPARK, + metadata={"transformed": True} + ) diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py new file mode 100644 index 00000000000..2e5c3c42db8 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py @@ -0,0 +1,43 @@ +from feast.infra.compute_engines.dag.builder import DAGBuilder +from feast.infra.compute_engines.spark.node import SparkReadNode, SparkAggregationNode, SparkJoinNode, SparkWriteNode, \ + SparkTransformationNode + + +class SparkDAGBuilder(DAGBuilder): + + def build_source_node(self): + source_path = self.feature_view.source.path + node = SparkReadNode("source", source_path) + self.nodes.append(node) + return node + + def build_aggregation_node(self, + input_node): + agg_specs = self.feature_view.aggregations + node = SparkAggregationNode("agg", input_node, agg_specs) + self.nodes.append(node) + return node + + def build_join_node(self, + input_node): + join_keys = self.feature_view.entities + node = SparkJoinNode("join", input_node, join_keys) + self.nodes.append(node) + return node + + def build_transformation_node(self, + input_node): + udf_name = self.feature_view.transformation.name + udf = self.feature_view.transformation.udf + node = SparkTransformationNode(udf_name, input_node, udf) + self.nodes.append(node) + return node + + def build_output_nodes(self, + input_node): + output_node = SparkWriteNode("output", input_node) + self.nodes.append(output_node) + + def build_validation_node(self, + input_node): + pass From 2825ee4592e392a019dcffae045084a4b01f3490 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Fri, 4 Apr 2025 16:37:11 -0700 Subject: [PATCH 02/18] fix linting Signed-off-by: HaoXuAI --- .../feast/infra/compute_engines/base.py | 36 +++--- .../infra/compute_engines/dag/builder.py | 39 +++--- .../feast/infra/compute_engines/dag/model.py | 7 +- .../feast/infra/compute_engines/dag/node.py | 13 +- .../feast/infra/compute_engines/dag/plan.py | 4 +- .../feast/infra/compute_engines/dag/value.py | 8 +- .../infra/compute_engines/spark/compute.py | 30 ++--- .../feast/infra/compute_engines/spark/node.py | 119 +++++++++++------- .../spark/spark_dag_builder.py | 25 ++-- 9 files changed, 150 insertions(+), 131 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index b5120f9f1c9..9f4fffbc2d9 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -1,12 +1,16 @@ from 
abc import ABC from dataclasses import dataclass -from typing import Union, List +from datetime import datetime +from typing import Union import pandas as pd import pyarrow as pa -from feast import RepoConfig, BatchFeatureView, StreamFeatureView -from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob +from feast import BatchFeatureView, RepoConfig, StreamFeatureView +from feast.infra.materialization.batch_materialization_engine import ( + MaterializationJob, + MaterializationTask, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.registry import Registry @@ -15,10 +19,12 @@ @dataclass class HistoricalRetrievalTask: entity_df: Union[pd.DataFrame, str] - feature_views: List[Union[BatchFeatureView, StreamFeatureView]] - full_feature_names: bool + feature_view: Union[BatchFeatureView, StreamFeatureView] + full_feature_name: bool registry: Registry config: RepoConfig + start_time: datetime + end_time: datetime class ComputeEngine(ABC): @@ -27,23 +33,21 @@ class ComputeEngine(ABC): """ def __init__( - self, - *, - registry: Registry, - repo_config: RepoConfig, - offline_store: OfflineStore, - online_store: OnlineStore, - **kwargs, + self, + *, + registry: Registry, + repo_config: RepoConfig, + offline_store: OfflineStore, + online_store: OnlineStore, + **kwargs, ): self.registry = registry self.repo_config = repo_config self.offline_store = offline_store self.online_store = online_store - def materialize(self, - task: MaterializationTask) -> MaterializationJob: + def materialize(self, task: MaterializationTask) -> MaterializationJob: raise NotImplementedError - def get_historical_features(self, - task: HistoricalRetrievalTask) -> pa.Table: + def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: raise NotImplementedError diff --git a/sdk/python/feast/infra/compute_engines/dag/builder.py b/sdk/python/feast/infra/compute_engines/dag/builder.py index 7ba4329b26e..30166b0ea52 100644 --- a/sdk/python/feast/infra/compute_engines/dag/builder.py +++ b/sdk/python/feast/infra/compute_engines/dag/builder.py @@ -1,60 +1,57 @@ from abc import ABC, abstractmethod from typing import Union -from feast import BatchFeatureView, StreamFeatureView -from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast import BatchFeatureView, StreamFeatureView, FeatureView from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.materialization.batch_materialization_engine import MaterializationTask +from feast.infra.compute_engines.dag.node import DAGNode class DAGBuilder(ABC): - def __init__(self, - feature_view: Union[BatchFeatureView, StreamFeatureView], - task: Union[MaterializationTask, HistoricalRetrievalTask] - ): + def __init__( + self, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + task: Union[MaterializationTask, HistoricalRetrievalTask], + ): self.feature_view = feature_view self.task = task - self.nodes = [] + self.nodes: list[DAGNode] = [] @abstractmethod def build_source_node(self): raise NotImplementedError @abstractmethod - def build_aggregation_node(self, - input_node): + def build_aggregation_node(self, input_node): raise NotImplementedError @abstractmethod - def build_join_node(self, - input_node): + def build_join_node(self, input_node): raise NotImplementedError @abstractmethod - def 
build_transformation_node(self, - input_node): + def build_transformation_node(self, input_node): raise NotImplementedError @abstractmethod - def build_output_nodes(self, - input_node): + def build_output_nodes(self, input_node): raise NotImplementedError @abstractmethod - def build_validation_node(self, - input_node): + def build_validation_node(self, input_node): raise def build(self) -> ExecutionPlan: last_node = self.build_source_node() - if getattr(self.feature_view.transformation, "requires_aggregation", False): + if hasattr(self.feature_view, "aggregation") and self.feature_view.aggregation is not None: last_node = self.build_aggregation_node(last_node) if self._should_join(): last_node = self.build_join_node(last_node) - if self.feature_view.transformation: + if hasattr(self.feature_view, "feature_transformation") and self.feature_view.feature_transformation: last_node = self.build_transformation_node(last_node) if getattr(self.feature_view, "enable_validation", False): @@ -65,6 +62,6 @@ def build(self) -> ExecutionPlan: def _should_join(self): return ( - self.feature_view.compute_config.join_strategy == "engine" - or self.task.config.compute_engine.get("point_in_time_join") == "engine" + self.feature_view.compute_config.join_strategy == "engine" + or self.task.config.compute_engine.get("point_in_time_join") == "engine" ) diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py index b42ee62dc9d..a488adabc98 100644 --- a/sdk/python/feast/infra/compute_engines/dag/model.py +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -1,13 +1,14 @@ -from enum import Enum from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Union + import pandas as pd -from typing import Union, List, Dict from feast.entity import Entity +from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.repo_config import RepoConfig -from feast.infra.compute_engines.dag.value import DAGValue class DAGFormat(str, Enum): diff --git a/sdk/python/feast/infra/compute_engines/dag/node.py b/sdk/python/feast/infra/compute_engines/dag/node.py index 47b31eba0c0..f3dd0a0ad3d 100644 --- a/sdk/python/feast/infra/compute_engines/dag/node.py +++ b/sdk/python/feast/infra/compute_engines/dag/node.py @@ -1,29 +1,26 @@ from abc import ABC, abstractmethod from typing import List -from feast.infra.compute_engines.dag.model import ExecutionContext from infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.dag.model import ExecutionContext + class DAGNode(ABC): name: str inputs: List["DAGNode"] outputs: List["DAGNode"] - def __init__(self, - name: str): + def __init__(self, name: str): self.name = name self.inputs = [] self.outputs = [] - def add_input(self, - node: "DAGNode"): + def add_input(self, node: "DAGNode"): if node in self.inputs: raise ValueError(f"Input node {node.name} already added to {self.name}") self.inputs.append(node) node.outputs.append(self) @abstractmethod - def execute(self, - context: ExecutionContext) -> DAGValue: - ... + def execute(self, context: ExecutionContext) -> DAGValue: ... 
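As an aside for reviewers, a minimal sketch of a concrete node written against the interfaces defined above (`DAGNode`, `DAGValue`, `DAGFormat`, `ExecutionContext`). The node itself, the pandas format choice, and the `column` parameter are illustrative and not part of this patch series; the import paths follow the modules as they stand at this point in the series.

```python
import pandas as pd

from feast.infra.compute_engines.dag.model import DAGFormat, ExecutionContext
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.value import DAGValue


class PandasFilterNode(DAGNode):
    """Hypothetical example node: keeps rows where `column` is not null."""

    def __init__(self, name: str, input_node: DAGNode, column: str):
        super().__init__(name)
        self.add_input(input_node)
        self.column = column

    def execute(self, context: ExecutionContext) -> DAGValue:
        # Fetch the upstream node's cached output, as the Spark nodes in this
        # series do, and check it is in the expected format.
        upstream = context.node_outputs[self.inputs[0].name]
        upstream.assert_format(DAGFormat.PANDAS)
        df: pd.DataFrame = upstream.data

        return DAGValue(
            data=df[df[self.column].notna()],
            format=DAGFormat.PANDAS,
            metadata={"filtered_on": self.column},
        )
```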
diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py index ebfd1240315..80a5cd5cd36 100644 --- a/sdk/python/feast/infra/compute_engines/dag/plan.py +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -1,8 +1,8 @@ from typing import List +from feast.infra.compute_engines.dag.model import ExecutionContext from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.compute_engines.dag.model import ExecutionContext class ExecutionPlan: @@ -39,6 +39,7 @@ class ExecutionPlan: This approach is inspired by execution DAGs in systems like Apache Spark, Apache Beam, and Dask — but specialized for Feast’s feature computation domain. """ + def __init__(self, nodes: List[DAGNode]): self.nodes = nodes @@ -60,4 +61,3 @@ def execute(self, context: ExecutionContext) -> DAGValue: # Return output of final node return node_outputs[self.nodes[-1].name] - diff --git a/sdk/python/feast/infra/compute_engines/dag/value.py b/sdk/python/feast/infra/compute_engines/dag/value.py index f45b969cf77..0e2063d0dba 100644 --- a/sdk/python/feast/infra/compute_engines/dag/value.py +++ b/sdk/python/feast/infra/compute_engines/dag/value.py @@ -4,15 +4,11 @@ class DAGValue: - def __init__(self, - data: Any, - format: DAGFormat, - metadata: Optional[dict] = None): + def __init__(self, data: Any, format: DAGFormat, metadata: Optional[dict] = None): self.data = data self.format = format self.metadata = metadata or {} - def assert_format(self, - expected: DAGFormat): + def assert_format(self, expected: DAGFormat): if self.format != expected: raise ValueError(f"Expected format {expected}, but got {self.format}") diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index a93aadd5ee7..68ba5fd76a9 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,10 +1,16 @@ import pyarrow as pa + from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask -from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder -from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob, \ - MaterializationJobStatus -from feast.infra.materialization.contrib.spark.spark_materialization_engine import SparkMaterializationJob from feast.infra.compute_engines.dag.model import ExecutionContext +from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder +from feast.infra.materialization.batch_materialization_engine import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( + SparkMaterializationJob, +) class SparkComputeEngine(ComputeEngine): @@ -22,7 +28,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: repo_config=self.repo_config, offline_store=self.offline_store, online_store=self.online_store, - entity_defs=entities + entity_defs=entities, ) # ✅ 2. Construct DAG and run it @@ -35,22 +41,16 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: # ✅ 3. 
Report success return SparkMaterializationJob( - job_id=job_id, - status=MaterializationJobStatus.SUCCEEDED + job_id=job_id, status=MaterializationJobStatus.SUCCEEDED ) except Exception as e: # 🛑 Handle failure return SparkMaterializationJob( - job_id=job_id, - status=MaterializationJobStatus.ERROR, - error=e + job_id=job_id, status=MaterializationJobStatus.ERROR, error=e ) def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: - # ✅ 1. Validate input - assert len(task.feature_views) == 1, "Multi-view support not yet implemented" - feature_view = task.feature_views[0] if isinstance(task.entity_df, str): raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") @@ -58,7 +58,7 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: # ✅ 2. Build typed execution context entity_defs = [ task.registry.get_entity(name, task.config.project) - for name in feature_view.entities + for name in task.feature_view.entities ] context = ExecutionContext( @@ -71,7 +71,7 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: ) # ✅ 3. Construct and execute DAG - builder = SparkDAGBuilder(feature_view=feature_view, task=task) + builder = SparkDAGBuilder(feature_view=task.feature_view, task=task) plan = builder.build() result = plan.execute(context=context) diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 3fccea748ee..02ff3562cca 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -1,18 +1,22 @@ from dataclasses import dataclass from datetime import datetime -from typing import Union, List, Optional, Dict +from typing import Dict, List, Optional, Union, cast -from pyspark._typing import F +from aggregation import Aggregation +from pyspark.sql import functions as F from pyspark.sql import DataFrame, Window -from aggregation import Aggregation from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.dag.model import DAGFormat, ExecutionContext from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.materialization.batch_materialization_engine import MaterializationTask -from feast.infra.offline_stores.contrib.spark_offline_store.spark import SparkRetrievalJob +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkRetrievalJob, +) from feast.utils import _get_column_names +from infra.materialization.contrib.spark.spark_materialization_engine import _map_by_partition, \ + _SparkSerializedArtifacts @dataclass @@ -30,15 +34,19 @@ class SparkJoinContext: class SparkReadNode(DAGNode): - def __init__(self, - name: str, - task: Union[MaterializationTask, HistoricalRetrievalTask]): + def __init__( + self, + name: str, + task: Union[MaterializationTask, HistoricalRetrievalTask] + ): super().__init__(name) self.task = task def execute(self, context: ExecutionContext) -> DAGValue: offline_store = context.offline_store + start_time = self.task.start_time + end_time = self.task.end_time ( join_key_columns, @@ -48,18 +56,17 @@ def execute(self, ) = _get_column_names(self.task.feature_view, context.entity_defs) # 📥 Reuse Feast's robust query resolver - retrieval_job: SparkRetrievalJob = offline_store.pull_latest_from_table_or_query( + retrieval_job = offline_store.pull_latest_from_table_or_query( config=context.repo_config, 
data_source=self.task.feature_view.batch_source, join_key_columns=join_key_columns, feature_name_columns=feature_name_columns, timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, - start_date=self.task.start_time, - end_date=self.task.end_time, + start_date=start_time, + end_date=end_time, ) - - spark_df = retrieval_job.to_spark_df() + spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() return DAGValue( data=spark_df, @@ -68,19 +75,21 @@ def execute(self, "source": "feature_view_batch_source", "timestamp_field": timestamp_field, "created_timestamp_column": created_timestamp_column, - "start_date": self.task.start_time, - "end_date": self.task.end_time, + "start_date": start_time, + "end_date": end_time, }, ) class SparkAggregationNode(DAGNode): - def __init__(self, - name: str, - input_node: DAGNode, - aggregations: List[Aggregation], - group_by_keys: List[str], - timestamp_col: str): + def __init__( + self, + name: str, + input_node: DAGNode, + aggregations: List[Aggregation], + group_by_keys: List[str], + timestamp_col: str, + ): super().__init__(name) self.add_input(input_node) self.aggregations = aggregations @@ -89,40 +98,49 @@ def __init__(self, def execute(self, context: ExecutionContext) -> DAGValue: - input_df: DataFrame = context.node_outputs[self.inputs[0].name].data - input_df.assert_format(DAGFormat.SPARK) + input_value = context.node_outputs[self.inputs[0].name] + input_value.assert_format(DAGFormat.SPARK) + input_df: DataFrame = input_value.data agg_exprs = [] for agg in self.aggregations: func = getattr(F, agg.function) expr = func(agg.column).alias( - f"{agg.function}_{agg.column}_{int(agg.time_window.total_seconds())}s" if agg.time_window else f"{agg.function}_{agg.column}") + f"{agg.function}_{agg.column}_{int(agg.time_window.total_seconds())}s" + if agg.time_window + else f"{agg.function}_{agg.column}" + ) agg_exprs.append(expr) if any(agg.time_window for agg in self.aggregations): # 🕒 Use Spark's `window` function - time_window = self.aggregations[0].time_window # assume consistent window size for now + time_window = self.aggregations[ + 0 + ].time_window # assume consistent window size for now grouped = input_df.groupBy( *self.group_by_keys, - F.window(F.col(self.timestamp_col), f"{int(time_window.total_seconds())} seconds") + F.window( + F.col(self.timestamp_col), + f"{int(time_window.total_seconds())} seconds", + ), ).agg(*agg_exprs) else: # Simple aggregation grouped = input_df.groupBy(*self.group_by_keys).agg(*agg_exprs) return DAGValue( - data=grouped, - format=DAGFormat.SPARK, - metadata={"aggregated": True}) + data=grouped, format=DAGFormat.SPARK, metadata={"aggregated": True} + ) class SparkJoinNode(DAGNode): - def __init__(self, - name: str, - feature_node: DAGNode, - join_keys: List[str], - feature_view - ): + def __init__( + self, + name: str, + feature_node: DAGNode, + join_keys: List[str], + feature_view + ): super().__init__(name) self.join_keys = join_keys self.add_input(feature_node) @@ -154,12 +172,14 @@ def execute(self, ordering.append(F.col(created_ts_col).desc()) window = Window.partitionBy(*partition_cols).orderBy(*ordering) - deduped = joined.withColumn("row_num", F.row_number().over(window)).filter("row_num = 1").drop("row_num") + deduped = ( + joined.withColumn("row_num", F.row_number().over(window)) + .filter("row_num = 1") + .drop("row_num") + ) return DAGValue( - data=deduped, - format=DAGFormat.SPARK, - metadata={"joined_on": join_keys} + data=deduped, format=DAGFormat.SPARK, 
metadata={"joined_on": join_keys} ) @@ -183,18 +203,25 @@ def execute(self, config=context.repo_config, feature_view=feature_view, table=spark_df, + progress=None, ) # ✅ 2. Write to online store (if enabled) if getattr(self.task, "online", False): - context.online_store.online_write_batch( - config=context.repo_config, - data=spark_df, + spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( + feature_view=feature_view, repo_config=context.repo_config ) + spark_df.mapInPandas( + lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + ).count() - return DAGValue(data=spark_df, - format=DAGFormat.SPARK, - metadata={"written_to": "online+offline" if self.task.online else "offline"}) + return DAGValue( + data=spark_df, + format=DAGFormat.SPARK, + metadata={ + "written_to": "online+offline" if self.task.online else "offline" + }, + ) class SparkTransformationNode(DAGNode): @@ -214,7 +241,5 @@ def execute(self, transformed_df = self.udf(input_val.data) return DAGValue( - data=transformed_df, - format=DAGFormat.SPARK, - metadata={"transformed": True} + data=transformed_df, format=DAGFormat.SPARK, metadata={"transformed": True} ) diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py index 2e5c3c42db8..a68d06f237d 100644 --- a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py @@ -1,43 +1,42 @@ from feast.infra.compute_engines.dag.builder import DAGBuilder -from feast.infra.compute_engines.spark.node import SparkReadNode, SparkAggregationNode, SparkJoinNode, SparkWriteNode, \ - SparkTransformationNode +from feast.infra.compute_engines.spark.node import ( + SparkAggregationNode, + SparkJoinNode, + SparkReadNode, + SparkTransformationNode, + SparkWriteNode, +) class SparkDAGBuilder(DAGBuilder): - def build_source_node(self): source_path = self.feature_view.source.path node = SparkReadNode("source", source_path) self.nodes.append(node) return node - def build_aggregation_node(self, - input_node): + def build_aggregation_node(self, input_node): agg_specs = self.feature_view.aggregations node = SparkAggregationNode("agg", input_node, agg_specs) self.nodes.append(node) return node - def build_join_node(self, - input_node): + def build_join_node(self, input_node): join_keys = self.feature_view.entities node = SparkJoinNode("join", input_node, join_keys) self.nodes.append(node) return node - def build_transformation_node(self, - input_node): + def build_transformation_node(self, input_node): udf_name = self.feature_view.transformation.name udf = self.feature_view.transformation.udf node = SparkTransformationNode(udf_name, input_node, udf) self.nodes.append(node) return node - def build_output_nodes(self, - input_node): + def build_output_nodes(self, input_node): output_node = SparkWriteNode("output", input_node) self.nodes.append(output_node) - def build_validation_node(self, - input_node): + def build_validation_node(self, input_node): pass From 4210f68e0d1bfe9c0fe48ac019dfbb717ee4b7c3 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Fri, 4 Apr 2025 16:37:28 -0700 Subject: [PATCH 03/18] fix linting Signed-off-by: HaoXuAI --- .../infra/compute_engines/dag/builder.py | 14 +++-- .../infra/compute_engines/spark/compute.py | 1 - .../feast/infra/compute_engines/spark/node.py | 55 +++++++------------ 3 files changed, 30 insertions(+), 40 deletions(-) diff --git 
a/sdk/python/feast/infra/compute_engines/dag/builder.py b/sdk/python/feast/infra/compute_engines/dag/builder.py index 30166b0ea52..3b5824de87d 100644 --- a/sdk/python/feast/infra/compute_engines/dag/builder.py +++ b/sdk/python/feast/infra/compute_engines/dag/builder.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from typing import Union -from feast import BatchFeatureView, StreamFeatureView, FeatureView +from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.materialization.batch_materialization_engine import MaterializationTask -from feast.infra.compute_engines.dag.node import DAGNode class DAGBuilder(ABC): @@ -45,13 +45,19 @@ def build_validation_node(self, input_node): def build(self) -> ExecutionPlan: last_node = self.build_source_node() - if hasattr(self.feature_view, "aggregation") and self.feature_view.aggregation is not None: + if ( + hasattr(self.feature_view, "aggregation") + and self.feature_view.aggregation is not None + ): last_node = self.build_aggregation_node(last_node) if self._should_join(): last_node = self.build_join_node(last_node) - if hasattr(self.feature_view, "feature_transformation") and self.feature_view.feature_transformation: + if ( + hasattr(self.feature_view, "feature_transformation") + and self.feature_view.feature_transformation + ): last_node = self.build_transformation_node(last_node) if getattr(self.feature_view, "enable_validation", False): diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 68ba5fd76a9..9ad447293a8 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -51,7 +51,6 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: ) def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: - if isinstance(task.entity_df, str): raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 02ff3562cca..00e1991e3d5 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -3,8 +3,12 @@ from typing import Dict, List, Optional, Union, cast from aggregation import Aggregation -from pyspark.sql import functions as F +from infra.materialization.contrib.spark.spark_materialization_engine import ( + _map_by_partition, + _SparkSerializedArtifacts, +) from pyspark.sql import DataFrame, Window +from pyspark.sql import functions as F from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.dag.model import DAGFormat, ExecutionContext @@ -15,8 +19,6 @@ SparkRetrievalJob, ) from feast.utils import _get_column_names -from infra.materialization.contrib.spark.spark_materialization_engine import _map_by_partition, \ - _SparkSerializedArtifacts @dataclass @@ -35,15 +37,12 @@ class SparkJoinContext: class SparkReadNode(DAGNode): def __init__( - self, - name: str, - task: Union[MaterializationTask, HistoricalRetrievalTask] + self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask] ): super().__init__(name) self.task = task - def execute(self, - context: ExecutionContext) -> 
DAGValue: + def execute(self, context: ExecutionContext) -> DAGValue: offline_store = context.offline_store start_time = self.task.start_time end_time = self.task.end_time @@ -83,12 +82,12 @@ def execute(self, class SparkAggregationNode(DAGNode): def __init__( - self, - name: str, - input_node: DAGNode, - aggregations: List[Aggregation], - group_by_keys: List[str], - timestamp_col: str, + self, + name: str, + input_node: DAGNode, + aggregations: List[Aggregation], + group_by_keys: List[str], + timestamp_col: str, ): super().__init__(name) self.add_input(input_node) @@ -96,8 +95,7 @@ def __init__( self.group_by_keys = group_by_keys self.timestamp_col = timestamp_col - def execute(self, - context: ExecutionContext) -> DAGValue: + def execute(self, context: ExecutionContext) -> DAGValue: input_value = context.node_outputs[self.inputs[0].name] input_value.assert_format(DAGFormat.SPARK) input_df: DataFrame = input_value.data @@ -135,19 +133,14 @@ def execute(self, class SparkJoinNode(DAGNode): def __init__( - self, - name: str, - feature_node: DAGNode, - join_keys: List[str], - feature_view + self, name: str, feature_node: DAGNode, join_keys: List[str], feature_view ): super().__init__(name) self.join_keys = join_keys self.add_input(feature_node) self.feature_view = feature_view - def execute(self, - context: ExecutionContext) -> DAGValue: + def execute(self, context: ExecutionContext) -> DAGValue: feature_value = context.node_outputs[self.inputs[1].name] feature_value.assert_format(DAGFormat.SPARK) @@ -184,16 +177,12 @@ def execute(self, class SparkWriteNode(DAGNode): - def __init__(self, - name: str, - input_node: DAGNode, - task): + def __init__(self, name: str, input_node: DAGNode, task): super().__init__(name) self.task = task self.add_input(input_node) - def execute(self, - context: ExecutionContext) -> DAGValue: + def execute(self, context: ExecutionContext) -> DAGValue: spark_df: DataFrame = context.node_outputs[self.inputs[0].name].data feature_view = self.task.feature_view @@ -225,16 +214,12 @@ def execute(self, class SparkTransformationNode(DAGNode): - def __init__(self, - name: str, - input_node: DAGNode, - udf): + def __init__(self, name: str, input_node: DAGNode, udf): super().__init__(name) self.add_input(input_node) self.udf = udf - def execute(self, - context: ExecutionContext) -> DAGValue: + def execute(self, context: ExecutionContext) -> DAGValue: input_val = context.node_outputs[self.inputs[0].name] input_val.assert_format(DAGFormat.SPARK) From a398075ae38155a6eca9a42a2d5068fea8c9aca6 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Fri, 4 Apr 2025 16:47:18 -0700 Subject: [PATCH 04/18] fix linting Signed-off-by: HaoXuAI --- .../infra/compute_engines/dag/context.py | 21 +++++++++++++++++++ .../feast/infra/compute_engines/dag/model.py | 21 ------------------- .../feast/infra/compute_engines/dag/node.py | 2 +- .../feast/infra/compute_engines/dag/plan.py | 2 +- .../infra/compute_engines/spark/compute.py | 2 +- .../feast/infra/compute_engines/spark/node.py | 3 ++- 6 files changed, 26 insertions(+), 25 deletions(-) create mode 100644 sdk/python/feast/infra/compute_engines/dag/context.py diff --git a/sdk/python/feast/infra/compute_engines/dag/context.py b/sdk/python/feast/infra/compute_engines/dag/context.py new file mode 100644 index 00000000000..bebaf7e75a2 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/dag/context.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Union + +import pandas as pd + +from feast.entity import Entity 
+from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.online_stores.online_store import OnlineStore +from feast.repo_config import RepoConfig + + +@dataclass +class ExecutionContext: + project: str + repo_config: RepoConfig + offline_store: OfflineStore + online_store: OnlineStore + entity_defs: List[Entity] + entity_df: Union[pd.DataFrame, None] = None + node_outputs: Dict[str, DAGValue] = field(default_factory=dict) diff --git a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py index a488adabc98..f77fdd0b6c9 100644 --- a/sdk/python/feast/infra/compute_engines/dag/model.py +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -1,28 +1,7 @@ -from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Union - -import pandas as pd - -from feast.entity import Entity -from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.offline_stores.offline_store import OfflineStore -from feast.infra.online_stores.online_store import OnlineStore -from feast.repo_config import RepoConfig class DAGFormat(str, Enum): SPARK = "spark" PANDAS = "pandas" ARROW = "arrow" - - -@dataclass -class ExecutionContext: - project: str - repo_config: RepoConfig - offline_store: OfflineStore - online_store: OnlineStore - entity_defs: List[Entity] - entity_df: Union[pd.DataFrame, None] = None - node_outputs: Dict[str, DAGValue] = field(default_factory=dict) diff --git a/sdk/python/feast/infra/compute_engines/dag/node.py b/sdk/python/feast/infra/compute_engines/dag/node.py index f3dd0a0ad3d..e727811a1f4 100644 --- a/sdk/python/feast/infra/compute_engines/dag/node.py +++ b/sdk/python/feast/infra/compute_engines/dag/node.py @@ -3,7 +3,7 @@ from infra.compute_engines.dag.value import DAGValue -from feast.infra.compute_engines.dag.model import ExecutionContext +from feast.infra.compute_engines.dag.context import ExecutionContext class DAGNode(ABC): diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py index 80a5cd5cd36..9af26e635dd 100644 --- a/sdk/python/feast/infra/compute_engines/dag/plan.py +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -1,6 +1,6 @@ from typing import List -from feast.infra.compute_engines.dag.model import ExecutionContext +from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 9ad447293a8..cebe84fa0d7 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,7 +1,7 @@ import pyarrow as pa from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask -from feast.infra.compute_engines.dag.model import ExecutionContext +from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 00e1991e3d5..6d23dec90a0 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ 
b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -11,7 +11,8 @@ from pyspark.sql import functions as F from feast.infra.compute_engines.base import HistoricalRetrievalTask -from feast.infra.compute_engines.dag.model import DAGFormat, ExecutionContext +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.materialization.batch_materialization_engine import MaterializationTask From 68c48dc7bcac849b6daefbf1736f1efd62f9455c Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Mon, 7 Apr 2025 00:03:47 -0700 Subject: [PATCH 05/18] fix linting Signed-off-by: HaoXuAI --- sdk/python/feast/feature_view.py | 1 + .../feast/infra/compute_engines/spark/node.py | 48 ++++++++++++------- .../spark/spark_dag_builder.py | 14 ++++-- sdk/python/feast/stream_feature_view.py | 3 ++ sdk/python/feast/transformation/base.py | 2 +- 5 files changed, 44 insertions(+), 24 deletions(-) diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 70881588a71..5259d5d2b90 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -222,6 +222,7 @@ def __init__( source=source, ) self.online = online + self.offline = offline self.materialization_intervals = [] def __hash__(self): diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 6d23dec90a0..02823eecb4f 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -2,20 +2,21 @@ from datetime import datetime from typing import Dict, List, Optional, Union, cast -from aggregation import Aggregation -from infra.materialization.contrib.spark.spark_materialization_engine import ( - _map_by_partition, - _SparkSerializedArtifacts, -) from pyspark.sql import DataFrame, Window from pyspark.sql import functions as F +from feast import BatchFeatureView, StreamFeatureView +from feast.aggregation import Aggregation from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.materialization.batch_materialization_engine import MaterializationTask +from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( + _map_by_partition, + _SparkSerializedArtifacts, +) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, ) @@ -116,12 +117,13 @@ def execute(self, context: ExecutionContext) -> DAGValue: time_window = self.aggregations[ 0 ].time_window # assume consistent window size for now + if time_window is None: + raise ValueError("Aggregation requires time_window but got None.") + window_duration_str = f"{int(time_window.total_seconds())} seconds" + grouped = input_df.groupBy( *self.group_by_keys, - F.window( - F.col(self.timestamp_col), - f"{int(time_window.total_seconds())} seconds", - ), + F.window(F.col(self.timestamp_col), window_duration_str), ).agg(*agg_exprs) else: # Simple aggregation @@ -134,7 +136,11 @@ def execute(self, context: ExecutionContext) -> DAGValue: class SparkJoinNode(DAGNode): def __init__( - self, name: str, feature_node: DAGNode, join_keys: 
List[str], feature_view + self, + name: str, + feature_node: DAGNode, + join_keys: List[str], + feature_view: Union[BatchFeatureView, StreamFeatureView], ): super().__init__(name) self.join_keys = join_keys @@ -178,28 +184,32 @@ def execute(self, context: ExecutionContext) -> DAGValue: class SparkWriteNode(DAGNode): - def __init__(self, name: str, input_node: DAGNode, task): + def __init__( + self, + name: str, + input_node: DAGNode, + feature_view: Union[BatchFeatureView, StreamFeatureView], + ): super().__init__(name) - self.task = task self.add_input(input_node) + self.feature_view = feature_view def execute(self, context: ExecutionContext) -> DAGValue: spark_df: DataFrame = context.node_outputs[self.inputs[0].name].data - feature_view = self.task.feature_view # ✅ 1. Write to offline store (if enabled) - if getattr(self.task, "offline", True): + if self.feature_view.online: context.offline_store.offline_write_batch( config=context.repo_config, - feature_view=feature_view, + feature_view=self.feature_view, table=spark_df, progress=None, ) # ✅ 2. Write to online store (if enabled) - if getattr(self.task, "online", False): + if self.feature_view.offline: spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( - feature_view=feature_view, repo_config=context.repo_config + feature_view=self.feature_view, repo_config=context.repo_config ) spark_df.mapInPandas( lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" @@ -209,7 +219,9 @@ def execute(self, context: ExecutionContext) -> DAGValue: data=spark_df, format=DAGFormat.SPARK, metadata={ - "written_to": "online+offline" if self.task.online else "offline" + "feature_view": self.feature_view.name, + "write_to_online": self.feature_view.online, + "write_to_offline": self.feature_view.offline, }, ) diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py index a68d06f237d..16abee72fe5 100644 --- a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py @@ -17,25 +17,29 @@ def build_source_node(self): def build_aggregation_node(self, input_node): agg_specs = self.feature_view.aggregations - node = SparkAggregationNode("agg", input_node, agg_specs) + group_by_keys = self.feature_view.entities + timestamp_col = self.feature_view.batch_source.timestamp_field + node = SparkAggregationNode( + "agg", input_node, agg_specs, group_by_keys, timestamp_col + ) self.nodes.append(node) return node def build_join_node(self, input_node): join_keys = self.feature_view.entities - node = SparkJoinNode("join", input_node, join_keys) + node = SparkJoinNode("join", input_node, join_keys, self.feature_view) self.nodes.append(node) return node def build_transformation_node(self, input_node): - udf_name = self.feature_view.transformation.name - udf = self.feature_view.transformation.udf + udf_name = self.feature_view.feature_transformation.name + udf = self.feature_view.feature_transformation.udf node = SparkTransformationNode(udf_name, input_node, udf) self.nodes.append(node) return node def build_output_nodes(self, input_node): - output_node = SparkWriteNode("output", input_node) + output_node = SparkWriteNode("output", input_node, self.feature_view) self.nodes.append(output_node) def build_validation_node(self, input_node): diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 2f134001a5a..67e953b4033 100644 --- 
a/sdk/python/feast/stream_feature_view.py
+++ b/sdk/python/feast/stream_feature_view.py
@@ -72,6 +72,7 @@ class StreamFeatureView(FeatureView):
     entity_columns: List[Field]
     features: List[Field]
     online: bool
+    offline: bool
     description: str
     tags: Dict[str, str]
     owner: str
@@ -92,6 +93,7 @@ def __init__(
         ttl: timedelta = timedelta(days=0),
         tags: Optional[Dict[str, str]] = None,
         online: bool = True,
+        offline: bool = False,
         description: str = "",
         owner: str = "",
         schema: Optional[List[Field]] = None,
@@ -138,6 +140,7 @@ def __init__(
             ttl=ttl,
             tags=tags,
             online=online,
+            offline=offline,
             description=description,
             owner=owner,
             schema=schema,
diff --git a/sdk/python/feast/transformation/base.py b/sdk/python/feast/transformation/base.py
index b02be0a6708..8ff1925d0e0 100644
--- a/sdk/python/feast/transformation/base.py
+++ b/sdk/python/feast/transformation/base.py
@@ -84,7 +84,7 @@ def __init__(
         self.mode = mode
         self.udf = udf
         self.udf_string = udf_string
-        self.name = name
+        self.name = name or udf.__name__
         self.tags = tags or {}
         self.description = description
         self.owner = owner

From 1c9ae31bf73598a57345242afc08266afaf8438a Mon Sep 17 00:00:00 2001
From: HaoXuAI
Date: Mon, 7 Apr 2025 23:41:30 -0700
Subject: [PATCH 06/18] add doc

Signed-off-by: HaoXuAI
---
 .../architecture/feature-transformation.md    |  2 +-
 docs/reference/compute-engine/README.md       | 87 +++++++++++++++++++
 .../feast/infra/compute_engines/base.py       |  4 +
 .../infra/compute_engines/dag/builder.py      |  2 +
 .../feast/infra/compute_engines/dag/plan.py   |  7 ++
 .../infra/compute_engines/spark/compute.py    | 43 +++++++--
 .../feast/infra/compute_engines/spark/job.py  | 51 +++++++++++
 .../feast/infra/compute_engines/spark/node.py | 78 ++++++++++++++++-
 .../spark/spark_dag_builder.py                | 26 +++++-
 9 files changed, 285 insertions(+), 15 deletions(-)
 create mode 100644 docs/reference/compute-engine/README.md
 create mode 100644 sdk/python/feast/infra/compute_engines/spark/job.py

diff --git a/docs/getting-started/architecture/feature-transformation.md b/docs/getting-started/architecture/feature-transformation.md
index 1a15d4c3a51..6b09eb9f950 100644
--- a/docs/getting-started/architecture/feature-transformation.md
+++ b/docs/getting-started/architecture/feature-transformation.md
@@ -8,7 +8,7 @@ Feature transformations can be executed by three types of "transformation engine
 1. The Feast Feature Server
 2. An Offline Store (e.g., Snowflake, BigQuery, DuckDB, Spark, etc.)
-3. A Stream processor (e.g., Flink or Spark Streaming)
+3. A Compute Engine (see more [here](../../reference/compute-engine/README.md))
 
 The three transformation engines are coupled with the [communication pattern used for writes](write-patterns.md).
 
diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md
new file mode 100644
index 00000000000..8edba81908c
--- /dev/null
+++ b/docs/reference/compute-engine/README.md
@@ -0,0 +1,87 @@
+# 🧠 ComputeEngine (WIP)
+
+The `ComputeEngine` is Feast’s pluggable abstraction for executing feature pipelines — including transformations, aggregations, joins, and materialization/get_historical_features — on a backend of your choice (e.g., Spark, PyArrow, Pandas, Ray).
+
+It powers both:
+
+- `materialize()` – for batch and stream generation of features to offline/online stores
+- `get_historical_features()` – for point-in-time correct training dataset retrieval
+
+This system builds and executes DAGs (Directed Acyclic Graphs) of typed operations, enabling modular and scalable workflows. 
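To make this concrete, the following editorial sketch shows how the Spark engine added later in this PR might be driven for historical retrieval. It mirrors the integration test in a later commit of this series; `driver_stats_fv`, `registry`, `online_store`, and `repo_config` are hypothetical placeholders for objects defined in the surrounding feature repo, and the exact `HistoricalRetrievalTask` fields may still change while the API is marked WIP.

```python
from datetime import datetime, timedelta

import pandas as pd

from feast.infra.compute_engines.base import HistoricalRetrievalTask
from feast.infra.compute_engines.spark.compute import SparkComputeEngine
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStore,
)

# Entity rows that define the point-in-time join targets.
now = datetime.utcnow()
entity_df = pd.DataFrame(
    [
        {"driver_id": 1001, "event_timestamp": now},
        {"driver_id": 1002, "event_timestamp": now},
    ]
)

# driver_stats_fv, registry, online_store, and repo_config are assumed to be
# defined elsewhere in the repo (hypothetical names for this sketch).
task = HistoricalRetrievalTask(
    entity_df=entity_df,
    feature_view=driver_stats_fv,
    full_feature_name=False,
    registry=registry,
    config=repo_config,
    start_time=now - timedelta(days=1),
    end_time=now,
)

engine = SparkComputeEngine(
    repo_config=repo_config,
    offline_store=SparkOfflineStore(),
    online_store=online_store,
    registry=registry,
)

# Returns a retrieval job whose DAG is executed lazily on first access.
training_df = engine.get_historical_features(task).to_spark_df().toPandas()
```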
+ +--- + +## 🧠 Core Concepts + +| Component | Description | +|--------------------|--------------------------------------------------------------------| +| `ComputeEngine` | Interface for executing materialization and retrieval tasks | +| `DAGBuilder` | Constructs a DAG for a specific backend | +| `DAGNode` | Represents a logical operation (read, aggregate, join, etc.) | +| `ExecutionPlan` | Executes nodes in dependency order and stores intermediate outputs | +| `ExecutionContext` | Holds config, registry, stores, entity data, and node outputs | + +--- + +## ✨ Available Engines + +### 🔥 SparkComputeEngine + +- Distributed DAG execution via Apache Spark +- Supports point-in-time joins and large-scale materialization +- Integrates with `SparkOfflineStore` and `SparkMaterializationJob` + +### 🧪 LocalComputeEngine (WIP) + +- Runs on Arrow + Pandas (or optionally DuckDB) +- Designed for local dev, testing, or lightweight feature generation + +--- + +## 🛠️ Example DAG Flow +`Read → Aggregate → Join → Transform → Write` + +Each step is implemented as a `DAGNode`. An `ExecutionPlan` executes these nodes in topological order, caching `DAGValue` outputs. + +--- + +## 🧩 Implementing a Custom Compute Engine + +To create your own compute engine: + +1. **Implement the interface** + +```python +class MyComputeEngine(ComputeEngine): + def materialize(self, task: MaterializationTask) -> MaterializationJob: + ... + + def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: + ... +``` + +2. Create a DAGBuilder +```python +class MyDAGBuilder(DAGBuilder): + def build_source_node(self): ... + def build_aggregation_node(self, input_node): ... + def build_join_node(self, input_node): ... + def build_transformation_node(self, input_node): ... + def build_output_nodes(self, input_node): ... +``` + +3. Define DAGNode subclasses + * ReadNode, AggregationNode, JoinNode, WriteNode, etc. + * Each DAGNode.execute(context) -> DAGValue + +4. Return an ExecutionPlan + * ExecutionPlan stores DAG nodes in topological order + * Automatically handles intermediate value caching + +## 🚧 Roadmap +- [x] Modular, backend-agnostic DAG execution framework +- [x] Spark engine with native support for materialization + PIT joins +- [ ] PyArrow + Pandas engine for local compute +- [ ] Native multi-feature-view DAG optimization +- [ ] DAG validation, metrics, and debug output +- [ ] Scalable distributed backend via Ray or Polars diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index 9f4fffbc2d9..ec284372b11 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -30,6 +30,10 @@ class HistoricalRetrievalTask: class ComputeEngine(ABC): """ The interface that Feast uses to control the compute system that handles materialization and get_historical_features. + Each engine must implement: + - materialize(): to generate and persist features + - get_historical_features(): to perform point-in-time correct joins + Engines should use DAGBuilder and DAGNode abstractions to build modular, pluggable workflows. 
""" def __init__( diff --git a/sdk/python/feast/infra/compute_engines/dag/builder.py b/sdk/python/feast/infra/compute_engines/dag/builder.py index 3b5824de87d..a53a6f33bf4 100644 --- a/sdk/python/feast/infra/compute_engines/dag/builder.py +++ b/sdk/python/feast/infra/compute_engines/dag/builder.py @@ -9,6 +9,8 @@ class DAGBuilder(ABC): + """ """ + def __init__( self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py index 9af26e635dd..c6088a6b795 100644 --- a/sdk/python/feast/infra/compute_engines/dag/plan.py +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -61,3 +61,10 @@ def execute(self, context: ExecutionContext) -> DAGValue: # Return output of final node return node_outputs[self.nodes[-1].name] + + def to_sql(self, context: ExecutionContext) -> str: + """ + Generate SQL query for the entire execution plan. + This is a placeholder and should be implemented in subclasses. + """ + raise NotImplementedError("SQL generation is not implemented yet.") diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index cebe84fa0d7..821fd08432b 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,8 +1,8 @@ -import pyarrow as pa - from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder +from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, MaterializationJobStatus, @@ -11,9 +11,27 @@ from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( SparkMaterializationJob, ) +from feast.infra.offline_stores.offline_store import RetrievalJob class SparkComputeEngine(ComputeEngine): + def __init__( + self, + offline_store, + online_store, + registry, + repo_config, + **kwargs, + ): + super().__init__( + offline_store=offline_store, + online_store=online_store, + registry=registry, + repo_config=repo_config, + **kwargs, + ) + self.spark_session = get_or_create_new_spark_session() + def materialize(self, task: MaterializationTask) -> MaterializationJob: job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" @@ -33,6 +51,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: # ✅ 2. Construct DAG and run it builder = SparkDAGBuilder( + spark_session=self.spark_session, feature_view=task.feature_view, task=task, ) @@ -50,7 +69,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: job_id=job_id, status=MaterializationJobStatus.ERROR, error=e ) - def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: + def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob: if isinstance(task.entity_df, str): raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") @@ -70,11 +89,17 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: ) # ✅ 3. 
Construct and execute DAG - builder = SparkDAGBuilder(feature_view=task.feature_view, task=task) + builder = SparkDAGBuilder( + spark_session=self.spark_session, + feature_view=task.feature_view, + task=task, + ) plan = builder.build() - result = plan.execute(context=context) - spark_df = result.data # should be a Spark DataFrame - - # ✅ 4. Return as Arrow - return spark_df.toPandas().to_arrow() + return SparkDAGRetrievalJob( + plan=plan, + spark_session=self.spark_session, + context=context, + config=task.config, + full_feature_names=task.full_feature_name, + ) diff --git a/sdk/python/feast/infra/compute_engines/spark/job.py b/sdk/python/feast/infra/compute_engines/spark/job.py new file mode 100644 index 00000000000..3ddde5f9011 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/spark/job.py @@ -0,0 +1,51 @@ +from typing import List, Optional + +import pyspark +from pyspark.sql import SparkSession + +from feast import OnDemandFeatureView, RepoConfig +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkRetrievalJob, +) +from feast.infra.offline_stores.offline_store import RetrievalMetadata + + +class SparkDAGRetrievalJob(SparkRetrievalJob): + def __init__( + self, + spark_session: SparkSession, + plan: ExecutionPlan, + context: ExecutionContext, + full_feature_names: bool, + config: RepoConfig, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + ): + super().__init__( + spark_session=spark_session, + query="", + full_feature_names=full_feature_names, + config=config, + on_demand_feature_views=on_demand_feature_views, + metadata=metadata, + ) + self._plan = plan + self._context = context + self._metadata = metadata + self._spark_df = None # Will be populated on first access + + def _ensure_executed(self): + if self._spark_df is None: + result = self._plan.execute(self._context) + self._spark_df = result.data + + def to_spark_df(self) -> pyspark.sql.DataFrame: + self._ensure_executed() + assert self._spark_df is not None, "Execution plan did not produce a DataFrame" + return self._spark_df + + def to_sql(self) -> str: + self._ensure_executed() + return self._plan.to_sql(self._context) diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 02823eecb4f..5f3e6e341d8 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -2,13 +2,13 @@ from datetime import datetime from typing import Dict, List, Optional, Union, cast -from pyspark.sql import DataFrame, Window +from infra.compute_engines.dag.context import ExecutionContext +from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql import functions as F from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.infra.compute_engines.base import HistoricalRetrievalTask -from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue @@ -19,6 +19,11 @@ ) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, + _get_entity_df_event_timestamp_range, + _get_entity_schema, +) +from 
feast.infra.offline_stores.offline_utils import ( + infer_event_timestamp_from_entity_df, ) from feast.utils import _get_column_names @@ -37,7 +42,7 @@ class SparkJoinContext: full_feature_names: bool = False # apply feature view name prefix -class SparkReadNode(DAGNode): +class SparkMaterializationReadNode(DAGNode): def __init__( self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask] ): @@ -82,6 +87,73 @@ def execute(self, context: ExecutionContext) -> DAGValue: ) +class SparkHistoricalRetrievalReadNode(DAGNode): + def __init__( + self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession + ): + super().__init__(name) + self.task = task + self.spark_session = spark_session + + def execute(self, context: ExecutionContext) -> DAGValue: + """ + Read data from the offline store on the Spark engine. + TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features. + Args: + context: SparkExecutionContext + Returns: DAGValue + """ + offline_store = context.offline_store + fv = self.task.feature_view + entity_df = context.entity_df + source = fv.batch_source + entities = context.entity_defs + + ( + join_key_columns, + feature_name_columns, + timestamp_field, + _, + ) = _get_column_names(fv, entities) + + entity_schema = _get_entity_schema( + spark_session=self.spark_session, + entity_df=entity_df, + ) + event_timestamp_col = infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( + entity_df, + event_timestamp_col, + self.spark_session, + ) + min_ts = entity_df_event_timestamp_range[0] + max_ts = entity_df_event_timestamp_range[1] + + retrieval_job = offline_store.pull_all_from_table_or_query( + config=context.repo_config, + data_source=source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + start_date=min_ts, + end_date=max_ts, + ) + spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + + return DAGValue( + data=spark_df, + format=DAGFormat.SPARK, + metadata={ + "source": "feature_view_batch_source", + "timestamp_field": timestamp_field, + "start_date": min_ts, + "end_date": max_ts, + }, + ) + + class SparkAggregationNode(DAGNode): def __init__( self, diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py index 16abee72fe5..849b8370b67 100644 --- a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py @@ -1,17 +1,39 @@ +from typing import Union + +from pyspark.sql import SparkSession + +from feast import BatchFeatureView, FeatureView, StreamFeatureView +from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.dag.builder import DAGBuilder from feast.infra.compute_engines.spark.node import ( SparkAggregationNode, + SparkHistoricalRetrievalReadNode, SparkJoinNode, - SparkReadNode, + SparkMaterializationReadNode, SparkTransformationNode, SparkWriteNode, ) +from feast.infra.materialization.batch_materialization_engine import MaterializationTask class SparkDAGBuilder(DAGBuilder): + def __init__( + self, + spark_session: SparkSession, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + task: Union[MaterializationTask, HistoricalRetrievalTask], + ): + super().__init__(feature_view, task) + self.spark_session = 
spark_session + def build_source_node(self): source_path = self.feature_view.source.path - node = SparkReadNode("source", source_path) + if isinstance(self.task, MaterializationTask): + node = SparkMaterializationReadNode("source", source_path) + else: + node = SparkHistoricalRetrievalReadNode( + "source", source_path, self.spark_session + ) self.nodes.append(node) return node From 25af94e6ef099924bc60f3822475316eaa6e1f9a Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 9 Apr 2025 00:35:39 -0700 Subject: [PATCH 07/18] add test Signed-off-by: HaoXuAI --- .../infra/compute_engines/dag/context.py | 31 +++ .../feast/infra/compute_engines/dag/node.py | 25 +- .../feast/infra/compute_engines/spark/node.py | 38 +++- .../spark/spark_dag_builder.py | 4 +- .../infra/compute_engines/spark/test_nodes.py | 214 ++++++++++++++++++ 5 files changed, 300 insertions(+), 12 deletions(-) create mode 100644 sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py diff --git a/sdk/python/feast/infra/compute_engines/dag/context.py b/sdk/python/feast/infra/compute_engines/dag/context.py index bebaf7e75a2..f54a59ba36d 100644 --- a/sdk/python/feast/infra/compute_engines/dag/context.py +++ b/sdk/python/feast/infra/compute_engines/dag/context.py @@ -12,6 +12,37 @@ @dataclass class ExecutionContext: + """ + ExecutionContext holds all runtime information required to execute a DAG plan + within a ComputeEngine. It is passed into each DAGNode during execution and + contains shared context such as configuration, registry-backed entities, runtime + data (e.g. entity_df), and DAG evaluation state. + + Attributes: + project: Feast project name (namespace for features, entities, views). + + repo_config: Resolved RepoConfig containing provider and store configuration. + + offline_store: Reference to the configured OfflineStore implementation. + Used for loading raw feature data during materialization or retrieval. + + online_store: Reference to the OnlineStore implementation. + Used during materialization to write online features. + + entity_defs: List of Entity definitions fetched from the registry. + Used for resolving join keys, inferring timestamp columns, and + validating FeatureViews against schema. + + entity_df: A runtime DataFrame of entity rows used during historical + retrieval (e.g. for point-in-time join). Includes entity keys and + event timestamps. This is not part of the registry and is user-supplied + for training dataset generation. + + node_outputs: Internal cache of DAGValue outputs keyed by DAGNode name. + Automatically populated during ExecutionPlan execution to avoid redundant + computation. Used by downstream nodes to access their input data. 
+ """ + project: str repo_config: RepoConfig offline_store: OfflineStore diff --git a/sdk/python/feast/infra/compute_engines/dag/node.py b/sdk/python/feast/infra/compute_engines/dag/node.py index e727811a1f4..033ae8f1780 100644 --- a/sdk/python/feast/infra/compute_engines/dag/node.py +++ b/sdk/python/feast/infra/compute_engines/dag/node.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from typing import List -from infra.compute_engines.dag.value import DAGValue - from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.value import DAGValue class DAGNode(ABC): @@ -22,5 +21,27 @@ def add_input(self, node: "DAGNode"): self.inputs.append(node) node.outputs.append(self) + def get_input_values(self, context: ExecutionContext) -> List[DAGValue]: + input_values = [] + for input_node in self.inputs: + if input_node.name not in context.node_outputs: + raise KeyError( + f"Missing output for input node '{input_node.name}' in context." + ) + input_values.append(context.node_outputs[input_node.name]) + return input_values + + def get_single_input_value(self, context: ExecutionContext) -> DAGValue: + if len(self.inputs) != 1: + raise RuntimeError( + f"DAGNode '{self.name}' expected exactly 1 input, but got {len(self.inputs)}." + ) + input_node = self.inputs[0] + if input_node.name not in context.node_outputs: + raise KeyError( + f"Missing output for input node '{input_node.name}' in context." + ) + return context.node_outputs[input_node.name] + @abstractmethod def execute(self, context: ExecutionContext) -> DAGValue: ... diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 5f3e6e341d8..28e3414d3bc 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -2,13 +2,13 @@ from datetime import datetime from typing import Dict, List, Optional, Union, cast -from infra.compute_engines.dag.context import ExecutionContext from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql import functions as F from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue @@ -170,7 +170,7 @@ def __init__( self.timestamp_col = timestamp_col def execute(self, context: ExecutionContext) -> DAGValue: - input_value = context.node_outputs[self.inputs[0].name] + input_value = self.get_single_input_value(context) input_value.assert_format(DAGFormat.SPARK) input_df: DataFrame = input_value.data @@ -213,32 +213,52 @@ def __init__( feature_node: DAGNode, join_keys: List[str], feature_view: Union[BatchFeatureView, StreamFeatureView], + spark_session: SparkSession, ): super().__init__(name) self.join_keys = join_keys self.add_input(feature_node) self.feature_view = feature_view + self.spark_session = spark_session def execute(self, context: ExecutionContext) -> DAGValue: - feature_value = context.node_outputs[self.inputs[1].name] + feature_value = self.get_single_input_value(context) feature_value.assert_format(DAGFormat.SPARK) + feature_df = feature_value.data entity_df = context.entity_df - feature_df = feature_value.data + assert entity_df is not None, "entity_df must be set in 
ExecutionContext" # Get timestamp fields from feature view join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( self.feature_view, context.entity_defs ) - entity_event_ts_col = "event_timestamp" # Standardized by SparkEntityLoadNode + # Rename entity_df event_timestamp_col to match feature_df + entity_schema = _get_entity_schema( + spark_session=self.spark_session, + entity_df=entity_df, + ) + event_timestamp_col = infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + entity_ts_alias = "__entity_event_timestamp" + entity_df = entity_df.withColumnRenamed(event_timestamp_col, entity_ts_alias) # Perform left join + event timestamp filtering joined = feature_df.join(entity_df, on=join_keys, how="left") - joined = joined.filter(F.col(ts_col) <= F.col(entity_event_ts_col)) + joined = joined.filter(F.col(ts_col) <= F.col(entity_ts_alias)) + + # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl + if self.feature_view.ttl: + ttl_seconds = int(self.feature_view.ttl.total_seconds()) + lower_bound = F.col(entity_ts_alias) - F.expr( + f"INTERVAL {ttl_seconds} seconds" + ) + joined = joined.filter(F.col(ts_col) >= lower_bound) # Dedup with row_number - partition_cols = join_keys + [entity_event_ts_col] + partition_cols = join_keys + [entity_ts_alias] ordering = [F.col(ts_col).desc()] if created_ts_col: ordering.append(F.col(created_ts_col).desc()) @@ -267,7 +287,7 @@ def __init__( self.feature_view = feature_view def execute(self, context: ExecutionContext) -> DAGValue: - spark_df: DataFrame = context.node_outputs[self.inputs[0].name].data + spark_df: DataFrame = self.get_single_input_value(context).data # ✅ 1. Write to offline store (if enabled) if self.feature_view.online: @@ -305,7 +325,7 @@ def __init__(self, name: str, input_node: DAGNode, udf): self.udf = udf def execute(self, context: ExecutionContext) -> DAGValue: - input_val = context.node_outputs[self.inputs[0].name] + input_val = self.get_single_input_value(context) input_val.assert_format(DAGFormat.SPARK) transformed_df = self.udf(input_val.data) diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py index 849b8370b67..024d5e10de6 100644 --- a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py @@ -49,7 +49,9 @@ def build_aggregation_node(self, input_node): def build_join_node(self, input_node): join_keys = self.feature_view.entities - node = SparkJoinNode("join", input_node, join_keys, self.feature_view) + node = SparkJoinNode( + "join", input_node, join_keys, self.feature_view, self.spark_session + ) self.nodes.append(node) return node diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py new file mode 100644 index 00000000000..c0782095952 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -0,0 +1,214 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest +from pyspark.sql import SparkSession + +from feast.aggregation import Aggregation +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.spark.node import ( + SparkAggregationNode, + SparkJoinNode, + SparkTransformationNode, +) 
+from tests.example_repos.example_feature_repo_with_bfvs import ( + driver, + driver_hourly_stats_view, +) + + +@pytest.fixture(scope="session") +def spark_session(): + spark = ( + SparkSession.builder.appName("FeastSparkTests") + .master("local[*]") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + ) + + yield spark + + spark.stop() + + +def test_spark_transformation_node_executes_udf(spark_session): + # Sample Spark input + df = spark_session.createDataFrame( + [ + {"name": "John D.", "age": 30}, + {"name": "Alice G.", "age": 25}, + ] + ) + + def strip_extra_spaces(df): + from pyspark.sql.functions import col, regexp_replace + + return df.withColumn("name", regexp_replace(col("name"), "\\s+", " ")) + + # Wrap DAGValue + input_value = DAGValue(data=df, format=DAGFormat.SPARK) + + # Setup context + context = ExecutionContext( + project="test_proj", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=MagicMock(), + entity_df=None, + node_outputs={"source": input_value}, + ) + + # Create and run the node + node = SparkTransformationNode( + "transform", input_node=MagicMock(), udf=strip_extra_spaces + ) + + node.inputs[0].name = "source" + result = node.execute(context) + + # Assert output + out_df = result.data + rows = out_df.orderBy("age").collect() + assert rows[0]["name"] == "Alice G." + assert rows[1]["name"] == "John D." + + +def test_spark_aggregation_node_executes_correctly(spark_session): + # Sample input DataFrame + input_df = spark_session.createDataFrame( + [ + {"user_id": 1, "value": 10}, + {"user_id": 1, "value": 20}, + {"user_id": 2, "value": 5}, + ] + ) + + # Define Aggregation spec (e.g. COUNT on value) + agg_specs = [Aggregation(column="value", function="count")] + + # Wrap as DAGValue + input_value = DAGValue(data=input_df, format=DAGFormat.SPARK) + + # Setup context + context = ExecutionContext( + project="test_project", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=[], + entity_df=None, + node_outputs={"source": input_value}, + ) + + # Create and configure node + node = SparkAggregationNode( + name="agg", + input_node=MagicMock(), + aggregations=agg_specs, + group_by_keys=["user_id"], + timestamp_col="", + ) + node.inputs[0].name = "source" + + # Execute + result = node.execute(context) + result_df = result.data.orderBy("user_id").collect() + + # Validate output + assert result.format == DAGFormat.SPARK + assert result_df[0]["user_id"] == 1 + assert result_df[0]["count_value"] == 2 + assert result_df[1]["user_id"] == 2 + assert result_df[1]["count_value"] == 1 + + +def test_spark_join_node_executes_point_in_time_join(spark_session): + now = datetime.utcnow() + + # Entity DataFrame (point-in-time join targets) + entity_df = spark_session.createDataFrame( + [ + {"driver_id": 1001, "event_timestamp": now}, + {"driver_id": 1002, "event_timestamp": now}, + ] + ) + + # Feature DataFrame (raw features with timestamp) + feature_df = spark_session.createDataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": now - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.95, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": now - timedelta(days=2), + "created": now - timedelta(hours=4), + "conv_rate": 0.75, + "acc_rate": 0.90, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": now - timedelta(days=1), + "created": now - timedelta(hours=3), + "conv_rate": 0.7, + "acc_rate": 
0.88, + "avg_daily_trips": 12, + }, + ] + ) + + # Wrap as DAGValues + feature_val = DAGValue(data=feature_df, format=DAGFormat.SPARK) + + # Setup FeatureView mock with batch_source metadata + feature_view = driver_hourly_stats_view + + # Set up context + context = ExecutionContext( + project="test_project", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=[driver], + entity_df=entity_df, + node_outputs={ + "feature_node": feature_val, + }, + ) + + # Create the node and add input + node = SparkJoinNode( + name="join", + feature_node=MagicMock(name="feature_node"), + join_keys=["user_id"], + feature_view=feature_view, + spark_session=spark_session, + ) + node.inputs[0].name = "feature_node" # must match key in node_outputs + + # Execute the node + output = node.execute(context) + result_df = output.data.orderBy("driver_id").collect() + + # Assertions + assert output.format == DAGFormat.SPARK + assert len(result_df) == 2 + + # Validate result for driver_id = 1001 + assert result_df[0]["driver_id"] == 1001 + assert abs(result_df[0]["conv_rate"] - 0.8) < 1e-6 + assert result_df[0]["avg_daily_trips"] == 15 + + # Validate result for driver_id = 1002 + assert result_df[1]["driver_id"] == 1002 + assert abs(result_df[1]["conv_rate"] - 0.7) < 1e-6 + assert result_df[1]["avg_daily_trips"] == 12 From ed0cdf47996157acfc67c333db695f65cb3cdf50 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 9 Apr 2025 23:06:44 -0700 Subject: [PATCH 08/18] add integration test Signed-off-by: HaoXuAI --- sdk/python/feast/batch_feature_view.py | 7 + .../infra/compute_engines/dag/builder.py | 15 ++- .../spark/spark_dag_builder.py | 5 +- sdk/python/feast/stream_feature_view.py | 3 + .../example_feature_repo_with_bfvs_compute.py | 67 ++++++++++ .../compute_engines/spark/test_compute.py | 122 ++++++++++++++++++ 6 files changed, 212 insertions(+), 7 deletions(-) create mode 100644 sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py create mode 100644 sdk/python/tests/integration/compute_engines/spark/test_compute.py diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index c3c6784b79c..7b10b92e2f8 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -64,6 +64,7 @@ class BatchFeatureView(FeatureView): udf: Optional[Callable[[Any], Any]] udf_string: Optional[str] feature_transformation: Transformation + batch_engine: Optional[Field] def __init__( self, @@ -82,6 +83,7 @@ def __init__( udf: Optional[Callable[[Any], Any]], udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, + batch_engine: Optional[Field] = None, ): if not flags_helper.is_test(): warnings.warn( @@ -105,6 +107,7 @@ def __init__( self.feature_transformation = ( feature_transformation or self.get_feature_transformation() ) + self.batch_engine = batch_engine super().__init__( name=name, @@ -147,6 +150,7 @@ def batch_feature_view( source: Optional[DataSource] = None, tags: Optional[Dict[str, str]] = None, online: bool = True, + offline: bool = True, description: str = "", owner: str = "", schema: Optional[List[Field]] = None, @@ -154,11 +158,13 @@ def batch_feature_view( """ Args: name: + mode: entities: ttl: source: tags: online: + offline: description: owner: schema: @@ -184,6 +190,7 @@ def decorator(user_function): source=source, tags=tags, online=online, + offline=offline, description=description, owner=owner, schema=schema, diff --git 
a/sdk/python/feast/infra/compute_engines/dag/builder.py b/sdk/python/feast/infra/compute_engines/dag/builder.py index a53a6f33bf4..7fb311115f7 100644 --- a/sdk/python/feast/infra/compute_engines/dag/builder.py +++ b/sdk/python/feast/infra/compute_engines/dag/builder.py @@ -69,7 +69,14 @@ def build(self) -> ExecutionPlan: return ExecutionPlan(self.nodes) def _should_join(self): - return ( - self.feature_view.compute_config.join_strategy == "engine" - or self.task.config.compute_engine.get("point_in_time_join") == "engine" - ) + if hasattr(self.feature_view, "batch_engine"): + return hasattr(self.feature_view.batch_engine, "join_strategy") and ( + self.feature_view.batch_engine.join_strategy == "engine" + or self.task.config.batch_engine.get("point_in_time_join") == "engine" + ) + if hasattr(self.feature_view, "batch_engine_config"): + return hasattr(self.feature_view.stream_engine, "join_strategy") and ( + self.feature_view.stream_engine.join_strategy == "engine" + or self.task.config.stream_engine.get("point_in_time_join") == "engine" + ) + return False diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py index 024d5e10de6..545f51a59e7 100644 --- a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py @@ -27,12 +27,11 @@ def __init__( self.spark_session = spark_session def build_source_node(self): - source_path = self.feature_view.source.path if isinstance(self.task, MaterializationTask): - node = SparkMaterializationReadNode("source", source_path) + node = SparkMaterializationReadNode("source", self.task) else: node = SparkHistoricalRetrievalReadNode( - "source", source_path, self.spark_session + "source", self.task, self.spark_session ) self.nodes.append(node) return node diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 67e953b4033..083bbbe2771 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -83,6 +83,7 @@ class StreamFeatureView(FeatureView): udf: Optional[FunctionType] udf_string: Optional[str] feature_transformation: Optional[Transformation] + stream_engine: Optional[Field] def __init__( self, @@ -103,6 +104,7 @@ def __init__( udf: Optional[FunctionType] = None, udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, + stream_engine: Optional[Field] = None, ): if not flags_helper.is_test(): warnings.warn( @@ -133,6 +135,7 @@ def __init__( self.feature_transformation = ( feature_transformation or self.get_feature_transformation() ) + self.stream_engine = stream_engine super().__init__( name=name, diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py new file mode 100644 index 00000000000..b6b26a58d7d --- /dev/null +++ b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py @@ -0,0 +1,67 @@ +from datetime import timedelta + +from pyspark.sql import DataFrame + +from feast import BatchFeatureView, Entity, Field, FileSource +from feast.types import Float32, Int32, Int64 + +driver_hourly_stats = FileSource( + path="%PARQUET_PATH%", # placeholder to be replaced by the test + timestamp_field="event_timestamp", + created_timestamp_column="created", +) + +driver = Entity( + name="driver_id", + description="driver id", +) + + +def transform_feature(df: 
DataFrame) -> DataFrame: + df = df.withColumn("conv_rate", df["conv_rate"] * 2) + df = df.withColumn("acc_rate", df["acc_rate"] * 2) + return df + + +driver_hourly_stats_view = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="python", + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=1), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), + ], + online=True, + offline=True, + source=driver_hourly_stats, + tags={}, +) + + +global_daily_stats = FileSource( + path="%PARQUET_PATH_GLOBAL%", # placeholder to be replaced by the test + timestamp_field="event_timestamp", + created_timestamp_column="created", +) + + +global_stats_feature_view = BatchFeatureView( + name="global_daily_stats", + entities=None, + mode="python", + udf=lambda x: x, + ttl=timedelta(days=1), + schema=[ + Field(name="num_rides", dtype=Int32), + Field(name="avg_ride_length", dtype=Float32), + ], + online=True, + offline=True, + source=global_daily_stats, + tags={}, +) diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py new file mode 100644 index 00000000000..81d081d3152 --- /dev/null +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -0,0 +1,122 @@ +from datetime import datetime, timedelta +from typing import cast +from unittest.mock import MagicMock + +import pandas as pd +import pytest + +from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.compute_engines.spark.compute import SparkComputeEngine +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStore, +) +from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import ( + SparkDataSourceCreator, +) +from tests.example_repos.example_feature_repo_with_bfvs_compute import ( + global_stats_feature_view, +) +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.repo_configuration import ( + construct_test_environment, +) +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, +) + + +@pytest.mark.integration +def test_spark_compute_engine_get_historical_features(): + now = datetime.utcnow() + + spark_config = IntegrationTestRepoConfig( + provider="local", + online_store_creator=RedisOnlineStoreCreator, + offline_store_creator=SparkDataSourceCreator, + batch_engine={"type": "spark.engine", "partitions": 10}, + ) + spark_environment = construct_test_environment( + spark_config, None, entity_key_serialization_version=2 + ) + + spark_environment.setup() + + # 👷 Prepare test parquet feature file + df = pd.DataFrame( + [ + { + "driver_id": 1001, + "event_timestamp": now - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.8, + "acc_rate": 0.95, + "avg_daily_trips": 15, + }, + { + "driver_id": 1001, + "event_timestamp": now - timedelta(days=2), + "created": now - timedelta(hours=3), + "conv_rate": 0.75, + "acc_rate": 0.9, + "avg_daily_trips": 14, + }, + { + "driver_id": 1002, + "event_timestamp": now - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.7, + "acc_rate": 0.88, + "avg_daily_trips": 12, + }, + ] + ) + + ds = 
spark_environment.data_source_creator.create_data_source( + df, + spark_environment.feature_store.project, + field_mapping={"ts_1": "ts"}, + ) + global_stats_feature_view.source = ds + + # 📥 Entity DataFrame to join with + entity_df = pd.DataFrame( + [ + {"driver_id": 1001, "event_timestamp": now}, + {"driver_id": 1002, "event_timestamp": now}, + ] + ) + + # 🛠 Build retrieval task + task = HistoricalRetrievalTask( + entity_df=entity_df, + feature_view=global_stats_feature_view, + full_feature_name=False, + registry=MagicMock(), + config=spark_environment.config, + start_time=now - timedelta(days=1), + end_time=now, + ) + + # 🧪 Run SparkComputeEngine + engine = SparkComputeEngine( + repo_config=task.config, + offline_store=SparkOfflineStore(), + online_store=MagicMock(), + registry=MagicMock(), + ) + + spark_dag_retrieval_job = engine.get_historical_features(task) + spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() + df_out = spark_df.to_pandas().sort_values("driver_id").reset_index(drop=True) + + # ✅ Assert output + assert list(df_out.driver_id) == [1001, 1002] + assert abs(df_out.loc[0]["conv_rate"] - 0.8) < 1e-6 + assert abs(df_out.loc[1]["conv_rate"] - 0.7) < 1e-6 + + +if __name__ == "__main__": + test_spark_compute_engine_get_historical_features() From 6b57e94b4a47b68a6522d826edc0880208a92736 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sat, 12 Apr 2025 00:33:17 -0700 Subject: [PATCH 09/18] update API Signed-off-by: HaoXuAI --- docs/reference/compute-engine/README.md | 6 +- .../feast/infra/compute_engines/base.py | 2 +- .../feast/infra/compute_engines/dag/plan.py | 16 +--- .../{dag/builder.py => feature_builder.py} | 23 ++--- .../infra/compute_engines/spark/compute.py | 6 +- ...park_dag_builder.py => feature_builder.py} | 9 +- .../feast/infra/compute_engines/spark/node.py | 81 ++++++++-------- .../example_feature_repo_with_bfvs_compute.py | 9 +- .../compute_engines/spark/test_compute.py | 93 ++++++++++++------- 9 files changed, 127 insertions(+), 118 deletions(-) rename sdk/python/feast/infra/compute_engines/{dag/builder.py => feature_builder.py} (71%) rename sdk/python/feast/infra/compute_engines/spark/{spark_dag_builder.py => feature_builder.py} (90%) diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index 8edba81908c..2b475325201 100644 --- a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -16,7 +16,7 @@ This system builds and executes DAGs (Directed Acyclic Graphs) of typed operatio | Component | Description | |--------------------|--------------------------------------------------------------------| | `ComputeEngine` | Interface for executing materialization and retrieval tasks | -| `DAGBuilder` | Constructs a DAG for a specific backend | +| `FeatureBuilder` | Constructs a DAG for a specific backend | | `DAGNode` | Represents a logical operation (read, aggregate, join, etc.) | | `ExecutionPlan` | Executes nodes in dependency order and stores intermediate outputs | | `ExecutionContext` | Holds config, registry, stores, entity data, and node outputs | @@ -60,9 +60,9 @@ class MyComputeEngine(ComputeEngine): ... ``` -2. Create a DAGBuilder +2. Create a FeatureBuilder ```python -class MyDAGBuilder(DAGBuilder): +class MyDAGBuilder(FeatureBuilder): def build_source_node(self): ... def build_aggregation_node(self, input_node): ... def build_join_node(self, input_node): ... 
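For reference, the builder stubs above wire together `DAGNode` subclasses. A minimal sketch of a custom node is shown below; `FilterNode` and its predicate are hypothetical, while `DAGNode`, `DAGValue`, `DAGFormat`, and `get_single_input_value` are the APIs introduced earlier in this PR.

```python
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.dag.model import DAGFormat
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.value import DAGValue


class FilterNode(DAGNode):
    """Hypothetical node that keeps only rows matching a Spark SQL predicate."""

    def __init__(self, name: str, input_node: DAGNode, predicate: str):
        super().__init__(name)
        self.add_input(input_node)
        self.predicate = predicate

    def execute(self, context: ExecutionContext) -> DAGValue:
        # Fetch the upstream output that ExecutionPlan cached in the context.
        upstream = self.get_single_input_value(context)
        upstream.assert_format(DAGFormat.SPARK)

        # Apply the predicate and hand the result to downstream nodes.
        filtered_df = upstream.data.filter(self.predicate)
        return DAGValue(data=filtered_df, format=DAGFormat.SPARK)
```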
diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index ec284372b11..49c4fabb913 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -33,7 +33,7 @@ class ComputeEngine(ABC): Each engine must implement: - materialize(): to generate and persist features - get_historical_features(): to perform point-in-time correct joins - Engines should use DAGBuilder and DAGNode abstractions to build modular, pluggable workflows. + Engines should use FeatureBuilder and DAGNode abstractions to build modular, pluggable workflows. """ def __init__( diff --git a/sdk/python/feast/infra/compute_engines/dag/plan.py b/sdk/python/feast/infra/compute_engines/dag/plan.py index c6088a6b795..130a894bda8 100644 --- a/sdk/python/feast/infra/compute_engines/dag/plan.py +++ b/sdk/python/feast/infra/compute_engines/dag/plan.py @@ -44,23 +44,17 @@ def __init__(self, nodes: List[DAGNode]): self.nodes = nodes def execute(self, context: ExecutionContext) -> DAGValue: - node_outputs: dict[str, DAGValue] = {} + context.node_outputs = {} for node in self.nodes: - # Gather input values for input_node in node.inputs: - if input_node.name not in node_outputs: - node_outputs[input_node.name] = input_node.execute(context) + if input_node.name not in context.node_outputs: + context.node_outputs[input_node.name] = input_node.execute(context) - # Execute this node output = node.execute(context) - node_outputs[node.name] = output + context.node_outputs[node.name] = output - # Inject into context for downstream access - context.node_outputs = node_outputs - - # Return output of final node - return node_outputs[self.nodes[-1].name] + return context.node_outputs[self.nodes[-1].name] def to_sql(self, context: ExecutionContext) -> str: """ diff --git a/sdk/python/feast/infra/compute_engines/dag/builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py similarity index 71% rename from sdk/python/feast/infra/compute_engines/dag/builder.py rename to sdk/python/feast/infra/compute_engines/feature_builder.py index 7fb311115f7..a4f7425b3f9 100644 --- a/sdk/python/feast/infra/compute_engines/dag/builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -8,8 +8,11 @@ from feast.infra.materialization.batch_materialization_engine import MaterializationTask -class DAGBuilder(ABC): - """ """ +class FeatureBuilder(ABC): + """ + Translates a FeatureView definition and execution task into an execution DAG. + This builder is engine-specific and returns an ExecutionPlan that ComputeEngine can run. 
+ """ def __init__( self, @@ -53,8 +56,7 @@ def build(self) -> ExecutionPlan: ): last_node = self.build_aggregation_node(last_node) - if self._should_join(): - last_node = self.build_join_node(last_node) + last_node = self.build_join_node(last_node) if ( hasattr(self.feature_view, "feature_transformation") @@ -67,16 +69,3 @@ def build(self) -> ExecutionPlan: self.build_output_nodes(last_node) return ExecutionPlan(self.nodes) - - def _should_join(self): - if hasattr(self.feature_view, "batch_engine"): - return hasattr(self.feature_view.batch_engine, "join_strategy") and ( - self.feature_view.batch_engine.join_strategy == "engine" - or self.task.config.batch_engine.get("point_in_time_join") == "engine" - ) - if hasattr(self.feature_view, "batch_engine_config"): - return hasattr(self.feature_view.stream_engine, "join_strategy") and ( - self.feature_view.stream_engine.join_strategy == "engine" - or self.task.config.stream_engine.get("point_in_time_join") == "engine" - ) - return False diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 821fd08432b..5b0540e5bd2 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,7 +1,7 @@ from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob -from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, @@ -50,7 +50,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: ) # ✅ 2. Construct DAG and run it - builder = SparkDAGBuilder( + builder = SparkFeatureBuilder( spark_session=self.spark_session, feature_view=task.feature_view, task=task, @@ -89,7 +89,7 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob ) # ✅ 3. 
Construct and execute DAG - builder = SparkDAGBuilder( + builder = SparkFeatureBuilder( spark_session=self.spark_session, feature_view=task.feature_view, task=task, diff --git a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py similarity index 90% rename from sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py rename to sdk/python/feast/infra/compute_engines/spark/feature_builder.py index 545f51a59e7..d1b540b35e3 100644 --- a/sdk/python/feast/infra/compute_engines/spark/spark_dag_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -4,7 +4,7 @@ from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.infra.compute_engines.base import HistoricalRetrievalTask -from feast.infra.compute_engines.dag.builder import DAGBuilder +from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.spark.node import ( SparkAggregationNode, SparkHistoricalRetrievalReadNode, @@ -16,7 +16,7 @@ from feast.infra.materialization.batch_materialization_engine import MaterializationTask -class SparkDAGBuilder(DAGBuilder): +class SparkFeatureBuilder(FeatureBuilder): def __init__( self, spark_session: SparkSession, @@ -62,8 +62,9 @@ def build_transformation_node(self, input_node): return node def build_output_nodes(self, input_node): - output_node = SparkWriteNode("output", input_node, self.feature_view) - self.nodes.append(output_node) + node = SparkWriteNode("output", input_node, self.feature_view) + self.nodes.append(node) + return node def build_validation_node(self, input_node): pass diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 28e3414d3bc..f636e4cbd1c 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -19,13 +19,12 @@ ) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, - _get_entity_df_event_timestamp_range, _get_entity_schema, ) from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) -from feast.utils import _get_column_names +from feast.utils import _get_column_names, _get_fields_with_aliases @dataclass @@ -103,9 +102,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: context: SparkExecutionContext Returns: DAGValue """ - offline_store = context.offline_store fv = self.task.feature_view - entity_df = context.entity_df source = fv.batch_source entities = context.entity_defs @@ -113,34 +110,38 @@ def execute(self, context: ExecutionContext) -> DAGValue: join_key_columns, feature_name_columns, timestamp_field, - _, + created_timestamp_column, ) = _get_column_names(fv, entities) - entity_schema = _get_entity_schema( - spark_session=self.spark_session, - entity_df=entity_df, + # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp + # retrieval_job = offline_store.pull_all_from_table_or_query( + # config=context.repo_config, + # data_source=source, + # join_key_columns=join_key_columns, + # feature_name_columns=feature_name_columns, + # timestamp_field=timestamp_field, + # start_date=min_ts, + # end_date=max_ts, + # ) + # spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + + columns = join_key_columns + feature_name_columns + [timestamp_field] + if created_timestamp_column: + columns.append(created_timestamp_column) + + 
(fields_with_aliases, aliases) = _get_fields_with_aliases( + fields=columns, + field_mappings=source.field_mapping, ) - event_timestamp_col = infer_event_timestamp_from_entity_df( - entity_schema=entity_schema, - ) - entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - event_timestamp_col, - self.spark_session, - ) - min_ts = entity_df_event_timestamp_range[0] - max_ts = entity_df_event_timestamp_range[1] + fields_with_alias_string = ", ".join(fields_with_aliases) - retrieval_job = offline_store.pull_all_from_table_or_query( - config=context.repo_config, - data_source=source, - join_key_columns=join_key_columns, - feature_name_columns=feature_name_columns, - timestamp_field=timestamp_field, - start_date=min_ts, - end_date=max_ts, - ) - spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() + from_expression = source.get_table_query_string() + + query = f""" + SELECT {fields_with_alias_string} + FROM {from_expression} + """ + spark_df = self.spark_session.sql(query) return DAGValue( data=spark_df, @@ -148,8 +149,6 @@ def execute(self, context: ExecutionContext) -> DAGValue: metadata={ "source": "feature_view_batch_source", "timestamp_field": timestamp_field, - "start_date": min_ts, - "end_date": max_ts, }, ) @@ -243,7 +242,9 @@ def execute(self, context: ExecutionContext) -> DAGValue: entity_schema=entity_schema, ) entity_ts_alias = "__entity_event_timestamp" - entity_df = entity_df.withColumnRenamed(event_timestamp_col, entity_ts_alias) + entity_df = self.spark_session.createDataFrame(entity_df).withColumnRenamed( + event_timestamp_col, entity_ts_alias + ) # Perform left join + event timestamp filtering joined = feature_df.join(entity_df, on=join_keys, how="left") @@ -288,21 +289,17 @@ def __init__( def execute(self, context: ExecutionContext) -> DAGValue: spark_df: DataFrame = self.get_single_input_value(context).data + spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( + feature_view=self.feature_view, repo_config=context.repo_config + ) # ✅ 1. Write to offline store (if enabled) - if self.feature_view.online: - context.offline_store.offline_write_batch( - config=context.repo_config, - feature_view=self.feature_view, - table=spark_df, - progress=None, - ) + if self.feature_view.offline: + # TODO: Update _map_by_partition to be able to write to offline store + pass # ✅ 2. 
Write to online store (if enabled) - if self.feature_view.offline: - spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( - feature_view=self.feature_view, repo_config=context.repo_config - ) + if self.feature_view.online: spark_df.mapInPandas( lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" ).count() diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py index b6b26a58d7d..31fda14a66c 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py +++ b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py @@ -16,6 +16,13 @@ description="driver id", ) +schema = [ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + Field(name="driver_id", dtype=Int32), +] + def transform_feature(df: DataFrame) -> DataFrame: df = df.withColumn("conv_rate", df["conv_rate"] * 2) @@ -42,14 +49,12 @@ def transform_feature(df: DataFrame) -> DataFrame: tags={}, ) - global_daily_stats = FileSource( path="%PARQUET_PATH_GLOBAL%", # placeholder to be replaced by the test timestamp_field="event_timestamp", created_timestamp_column="created", ) - global_stats_feature_view = BatchFeatureView( name="global_daily_stats", entities=None, diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 81d081d3152..7fbdd0c4300 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -5,6 +5,7 @@ import pandas as pd import pytest +from feast import BatchFeatureView from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.spark.compute import SparkComputeEngine from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob @@ -15,7 +16,9 @@ SparkDataSourceCreator, ) from tests.example_repos.example_feature_repo_with_bfvs_compute import ( - global_stats_feature_view, + driver, + schema, + transform_feature, ) from tests.integration.feature_repos.integration_test_repo_config import ( IntegrationTestRepoConfig, @@ -30,7 +33,10 @@ @pytest.mark.integration def test_spark_compute_engine_get_historical_features(): - now = datetime.utcnow() + now = datetime.now() + today = datetime.today() + yesterday = today - timedelta(days=1) + last_week = today - timedelta(days=7) spark_config = IntegrationTestRepoConfig( provider="local", @@ -41,15 +47,16 @@ def test_spark_compute_engine_get_historical_features(): spark_environment = construct_test_environment( spark_config, None, entity_key_serialization_version=2 ) - spark_environment.setup() + fs = spark_environment.feature_store + registry = fs.registry # 👷 Prepare test parquet feature file df = pd.DataFrame( [ { "driver_id": 1001, - "event_timestamp": now - timedelta(days=1), + "event_timestamp": yesterday, "created": now - timedelta(hours=2), "conv_rate": 0.8, "acc_rate": 0.95, @@ -57,7 +64,7 @@ def test_spark_compute_engine_get_historical_features(): }, { "driver_id": 1001, - "event_timestamp": now - timedelta(days=2), + "event_timestamp": last_week, "created": now - timedelta(hours=3), "conv_rate": 0.75, "acc_rate": 0.9, @@ -65,7 +72,7 @@ def test_spark_compute_engine_get_historical_features(): }, { "driver_id": 1002, - "event_timestamp": now - timedelta(days=1), + "event_timestamp": 
yesterday, "created": now - timedelta(hours=2), "conv_rate": 0.7, "acc_rate": 0.88, @@ -73,49 +80,65 @@ def test_spark_compute_engine_get_historical_features(): }, ] ) - ds = spark_environment.data_source_creator.create_data_source( df, spark_environment.feature_store.project, - field_mapping={"ts_1": "ts"}, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + driver_stats_fv = BatchFeatureView( + name="driver_hourly_stats", + entities=[driver], + mode="python", + udf=transform_feature, + udf_string="transform_feature", + ttl=timedelta(days=2), + schema=schema, + online=False, + offline=False, + source=ds, ) - global_stats_feature_view.source = ds # 📥 Entity DataFrame to join with entity_df = pd.DataFrame( [ - {"driver_id": 1001, "event_timestamp": now}, - {"driver_id": 1002, "event_timestamp": now}, + {"driver_id": 1001, "event_timestamp": today}, + {"driver_id": 1002, "event_timestamp": today}, ] ) - # 🛠 Build retrieval task - task = HistoricalRetrievalTask( - entity_df=entity_df, - feature_view=global_stats_feature_view, - full_feature_name=False, - registry=MagicMock(), - config=spark_environment.config, - start_time=now - timedelta(days=1), - end_time=now, - ) + try: + fs.apply([driver, driver_stats_fv]) - # 🧪 Run SparkComputeEngine - engine = SparkComputeEngine( - repo_config=task.config, - offline_store=SparkOfflineStore(), - online_store=MagicMock(), - registry=MagicMock(), - ) + # 🛠 Build retrieval task + task = HistoricalRetrievalTask( + entity_df=entity_df, + feature_view=driver_stats_fv, + full_feature_name=False, + registry=registry, + config=spark_environment.config, + start_time=now - timedelta(days=1), + end_time=now, + ) + + # 🧪 Run SparkComputeEngine + engine = SparkComputeEngine( + repo_config=task.config, + offline_store=SparkOfflineStore(), + online_store=MagicMock(), + registry=registry, + ) - spark_dag_retrieval_job = engine.get_historical_features(task) - spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() - df_out = spark_df.to_pandas().sort_values("driver_id").reset_index(drop=True) + spark_dag_retrieval_job = engine.get_historical_features(task) + spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() + df_out = spark_df.to_pandas_on_spark() - # ✅ Assert output - assert list(df_out.driver_id) == [1001, 1002] - assert abs(df_out.loc[0]["conv_rate"] - 0.8) < 1e-6 - assert abs(df_out.loc[1]["conv_rate"] - 0.7) < 1e-6 + # ✅ Assert output + assert df_out.driver_id.to_list() == [1001, 1002] + assert abs(df_out["conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["conv_rate"].to_list()[1] - 1.4) < 1e-6 + finally: + spark_environment.teardown() if __name__ == "__main__": From 227f8f493e38d64a21f76cc758f6bb0431b7b92c Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sat, 12 Apr 2025 00:45:15 -0700 Subject: [PATCH 10/18] update API Signed-off-by: HaoXuAI --- docs/reference/compute-engine/README.md | 21 ++++++++++++------- .../feast/infra/compute_engines/base.py | 18 ++-------------- .../infra/compute_engines/feature_builder.py | 2 +- .../infra/compute_engines/spark/compute.py | 3 ++- .../feast/infra/compute_engines/spark/node.py | 2 +- .../feast/infra/compute_engines/tasks.py | 19 +++++++++++++++++ .../compute_engines/spark/test_compute.py | 2 +- 7 files changed, 39 insertions(+), 28 deletions(-) create mode 100644 sdk/python/feast/infra/compute_engines/tasks.py diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index 2b475325201..4324164b66c 100644 --- 
a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -13,13 +13,13 @@ This system builds and executes DAGs (Directed Acyclic Graphs) of typed operatio ## 🧠 Core Concepts -| Component | Description | -|--------------------|--------------------------------------------------------------------| -| `ComputeEngine` | Interface for executing materialization and retrieval tasks | -| `FeatureBuilder` | Constructs a DAG for a specific backend | -| `DAGNode` | Represents a logical operation (read, aggregate, join, etc.) | -| `ExecutionPlan` | Executes nodes in dependency order and stores intermediate outputs | -| `ExecutionContext` | Holds config, registry, stores, entity data, and node outputs | +| Component | Description | +|--------------------|----------------------------------------------------------------------| +| `ComputeEngine` | Interface for executing materialization and retrieval tasks | +| `FeatureBuilder` | Constructs a DAG from Feature View definition for a specific backend | +| `DAGNode` | Represents a logical operation (read, aggregate, join, etc.) | +| `ExecutionPlan` | Executes nodes in dependency order and stores intermediate outputs | +| `ExecutionContext` | Holds config, registry, stores, entity data, and node outputs | --- @@ -52,6 +52,9 @@ To create your own compute engine: 1. **Implement the interface** ```python +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask class MyComputeEngine(ComputeEngine): def materialize(self, task: MaterializationTask) -> MaterializationJob: ... @@ -62,7 +65,9 @@ class MyComputeEngine(ComputeEngine): 2. Create a FeatureBuilder ```python -class MyDAGBuilder(FeatureBuilder): +from feast.infra.compute_engines.feature_builder import FeatureBuilder + +class CustomFeatureBuilder(FeatureBuilder): def build_source_node(self): ... def build_aggregation_node(self, input_node): ... def build_join_node(self, input_node): ... diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index 49c4fabb913..f05d952c801 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -1,12 +1,9 @@ from abc import ABC -from dataclasses import dataclass -from datetime import datetime -from typing import Union -import pandas as pd import pyarrow as pa -from feast import BatchFeatureView, RepoConfig, StreamFeatureView +from feast import RepoConfig +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, MaterializationTask, @@ -16,17 +13,6 @@ from feast.infra.registry.registry import Registry -@dataclass -class HistoricalRetrievalTask: - entity_df: Union[pd.DataFrame, str] - feature_view: Union[BatchFeatureView, StreamFeatureView] - full_feature_name: bool - registry: Registry - config: RepoConfig - start_time: datetime - end_time: datetime - - class ComputeEngine(ABC): """ The interface that Feast uses to control the compute system that handles materialization and get_historical_features. 
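As a usage sketch of the interface above and of the `HistoricalRetrievalTask` now relocated to `feast.infra.compute_engines.tasks`, the snippet below mirrors the integration test in this series. It assumes an already applied feature repo; the `FeatureStore(repo_path=".")` call, the `driver_hourly_stats` view name, and the mocked online store are illustrative placeholders rather than required wiring.

```python
from datetime import datetime, timedelta
from unittest.mock import MagicMock

import pandas as pd

from feast import FeatureStore
from feast.infra.compute_engines.spark.compute import SparkComputeEngine
from feast.infra.compute_engines.tasks import HistoricalRetrievalTask
from feast.infra.offline_stores.contrib.spark_offline_store.spark import SparkOfflineStore

# Assumed: a repo where the batch feature view has already been applied.
fs = FeatureStore(repo_path=".")
driver_stats_fv = fs.get_feature_view("driver_hourly_stats")

entity_df = pd.DataFrame(
    [{"driver_id": 1001, "event_timestamp": datetime.now()}]
)

task = HistoricalRetrievalTask(
    entity_df=entity_df,
    feature_view=driver_stats_fv,
    full_feature_name=False,
    registry=fs.registry,
    config=fs.config,
    start_time=datetime.now() - timedelta(days=1),
    end_time=datetime.now(),
)

engine = SparkComputeEngine(
    repo_config=fs.config,
    offline_store=SparkOfflineStore(),
    online_store=MagicMock(),  # online store is not exercised by historical retrieval
    registry=fs.registry,
)

# The Spark engine returns a retrieval job whose DAG is executed lazily.
retrieval_job = engine.get_historical_features(task)
spark_df = retrieval_job.to_spark_df()
```

Keeping the task dataclass in a standalone `tasks` module lets builders and nodes import it without depending on `ComputeEngine` itself, which is consistent with the import changes in the hunks that follow.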
diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index a4f7425b3f9..9e584c86ca1 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -2,9 +2,9 @@ from typing import Union from feast import BatchFeatureView, FeatureView, StreamFeatureView -from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask from feast.infra.materialization.batch_materialization_engine import MaterializationTask diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 5b0540e5bd2..144fc002131 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,8 +1,9 @@ -from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, MaterializationJobStatus, diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index f636e4cbd1c..eb38335dc33 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -7,11 +7,11 @@ from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation -from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask from feast.infra.materialization.batch_materialization_engine import MaterializationTask from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( _map_by_partition, diff --git a/sdk/python/feast/infra/compute_engines/tasks.py b/sdk/python/feast/infra/compute_engines/tasks.py new file mode 100644 index 00000000000..a1fdd4940dc --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/tasks.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Union + +import pandas as pd + +from feast import BatchFeatureView, RepoConfig, StreamFeatureView +from feast.infra.registry.registry import Registry + + +@dataclass +class HistoricalRetrievalTask: + entity_df: Union[pd.DataFrame, str] + feature_view: Union[BatchFeatureView, StreamFeatureView] + full_feature_name: bool + registry: Registry + config: RepoConfig + start_time: datetime + end_time: datetime diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py 
b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 7fbdd0c4300..7bb4f52855c 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -6,9 +6,9 @@ import pytest from feast import BatchFeatureView -from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.spark.compute import SparkComputeEngine from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, ) From e9362de5623ca333e6db8cbc61003853be50dcc4 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sat, 12 Apr 2025 21:14:48 -0700 Subject: [PATCH 11/18] update API Signed-off-by: HaoXuAI --- sdk/python/feast/batch_feature_view.py | 4 + .../feast/infra/compute_engines/base.py | 40 ++++++++ .../infra/compute_engines/dag/context.py | 17 +++- .../infra/compute_engines/feature_builder.py | 20 +++- .../infra/compute_engines/spark/compute.py | 73 ++++++-------- .../compute_engines/spark/feature_builder.py | 17 ++++ .../feast/infra/compute_engines/spark/job.py | 9 +- .../feast/infra/compute_engines/spark/node.py | 99 +++++++++++++++---- .../feast/infra/compute_engines/tasks.py | 4 +- .../example_feature_repo_with_bfvs_compute.py | 4 +- .../compute_engines/spark/test_compute.py | 32 ++++-- .../infra/compute_engines/spark/test_nodes.py | 27 ++++- 12 files changed, 263 insertions(+), 83 deletions(-) diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 7b10b92e2f8..1616761e9ca 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -6,6 +6,7 @@ import dill from feast import flags_helper +from feast.aggregation import Aggregation from feast.data_source import DataSource from feast.entity import Entity from feast.feature_view import FeatureView @@ -65,6 +66,7 @@ class BatchFeatureView(FeatureView): udf_string: Optional[str] feature_transformation: Transformation batch_engine: Optional[Field] + aggregations: Optional[List[Aggregation]] def __init__( self, @@ -84,6 +86,7 @@ def __init__( udf_string: Optional[str] = "", feature_transformation: Optional[Transformation] = None, batch_engine: Optional[Field] = None, + aggregations: Optional[List[Aggregation]] = None, ): if not flags_helper.is_test(): warnings.warn( @@ -108,6 +111,7 @@ def __init__( feature_transformation or self.get_feature_transformation() ) self.batch_engine = batch_engine + self.aggregations = aggregations or [] super().__init__( name=name, diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index f05d952c801..bac1005cb3a 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -1,8 +1,10 @@ from abc import ABC +from typing import Union import pyarrow as pa from feast import RepoConfig +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.compute_engines.tasks import HistoricalRetrievalTask from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, @@ -11,6 +13,7 @@ from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.registry import Registry +from feast.utils import 
_get_column_names class ComputeEngine(ABC): @@ -41,3 +44,40 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: raise NotImplementedError + + def get_execution_context( + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + ) -> ExecutionContext: + entity_defs = [ + self.registry.get_entity(name, task.project) + for name in task.feature_view.entities + ] + entity_df = None + if task.entity_df is not None: + entity_df = task.entity_df + + column_info = self.get_column_info(task) + return ExecutionContext( + project=task.project, + repo_config=self.repo_config, + offline_store=self.offline_store, + online_store=self.online_store, + entity_defs=entity_defs, + column_info=column_info, + entity_df=entity_df, + ) + + def get_column_info( + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + ) -> ColumnInfo: + join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( + task.feature_view, self.registry.list_entities(task.project) + ) + return ColumnInfo( + join_keys=join_keys, + feature_cols=feature_cols, + ts_col=ts_col, + created_ts_col=created_ts_col, + ) diff --git a/sdk/python/feast/infra/compute_engines/dag/context.py b/sdk/python/feast/infra/compute_engines/dag/context.py index f54a59ba36d..8b170b67766 100644 --- a/sdk/python/feast/infra/compute_engines/dag/context.py +++ b/sdk/python/feast/infra/compute_engines/dag/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import pandas as pd @@ -10,6 +10,20 @@ from feast.repo_config import RepoConfig +@dataclass +class ColumnInfo: + join_keys: List[str] + feature_cols: List[str] + ts_col: str + created_ts_col: Optional[str] + + def __iter__(self): + yield self.join_keys + yield self.feature_cols + yield self.ts_col + yield self.created_ts_col + + @dataclass class ExecutionContext: """ @@ -47,6 +61,7 @@ class ExecutionContext: repo_config: RepoConfig offline_store: OfflineStore online_store: OnlineStore + column_info: ColumnInfo entity_defs: List[Entity] entity_df: Union[pd.DataFrame, None] = None node_outputs: Dict[str, DAGValue] = field(default_factory=dict) diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 9e584c86ca1..54f7587bfdd 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -35,6 +35,14 @@ def build_aggregation_node(self, input_node): def build_join_node(self, input_node): raise NotImplementedError + @abstractmethod + def build_filter_node(self, input_node): + raise NotImplementedError + + @abstractmethod + def build_dedup_node(self, input_node): + raise NotImplementedError + @abstractmethod def build_transformation_node(self, input_node): raise NotImplementedError @@ -50,13 +58,17 @@ def build_validation_node(self, input_node): def build(self) -> ExecutionPlan: last_node = self.build_source_node() + # PIT join entities to the feature data, and perform filtering + last_node = self.build_join_node(last_node) + last_node = self.build_filter_node(last_node) + if ( - hasattr(self.feature_view, "aggregation") - and self.feature_view.aggregation is not None + hasattr(self.feature_view, "aggregations") + and self.feature_view.aggregations is not None ): last_node = self.build_aggregation_node(last_node) - - last_node = 
self.build_join_node(last_node) + else: + last_node = self.build_dedup_node(last_node) if ( hasattr(self.feature_view, "feature_transformation") diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 144fc002131..13f402df7a7 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -36,21 +36,11 @@ def __init__( def materialize(self, task: MaterializationTask) -> MaterializationJob: job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" - try: - # ✅ 1. Build typed execution context - entities = [] - for entity_name in task.feature_view.entities: - entities.append(self.registry.get_entity(entity_name, task.project)) - - context = ExecutionContext( - project=task.project, - repo_config=self.repo_config, - offline_store=self.offline_store, - online_store=self.online_store, - entity_defs=entities, - ) + # ✅ 1. Build typed execution context + context = self.get_execution_context(task) - # ✅ 2. Construct DAG and run it + try: + # ✅ 2. Construct Feature Builder and run it builder = SparkFeatureBuilder( spark_session=self.spark_session, feature_view=task.feature_view, @@ -74,33 +64,32 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob if isinstance(task.entity_df, str): raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") - # ✅ 2. Build typed execution context - entity_defs = [ - task.registry.get_entity(name, task.config.project) - for name in task.feature_view.entities - ] - - context = ExecutionContext( - project=task.config.project, - repo_config=task.config, - offline_store=self.offline_store, - online_store=self.online_store, - entity_defs=entity_defs, - entity_df=task.entity_df, - ) + # ✅ 1. Build typed execution context + context = self.get_execution_context(task) - # ✅ 3. Construct and execute DAG - builder = SparkFeatureBuilder( - spark_session=self.spark_session, - feature_view=task.feature_view, - task=task, - ) - plan = builder.build() + try: + # ✅ 2. 
Construct Feature Builder and run it + builder = SparkFeatureBuilder( + spark_session=self.spark_session, + feature_view=task.feature_view, + task=task, + ) + plan = builder.build() - return SparkDAGRetrievalJob( - plan=plan, - spark_session=self.spark_session, - context=context, - config=task.config, - full_feature_names=task.full_feature_name, - ) + return SparkDAGRetrievalJob( + plan=plan, + spark_session=self.spark_session, + context=context, + config=self.repo_config, + full_feature_names=task.full_feature_name, + ) + except Exception as e: + # 🛑 Handle failure + return SparkDAGRetrievalJob( + plan=None, + spark_session=self.spark_session, + context=context, + config=self.repo_config, + full_feature_names=task.full_feature_name, + error=e, + ) diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index d1b540b35e3..7dc933a0529 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -7,6 +7,8 @@ from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.spark.node import ( SparkAggregationNode, + SparkDedupNode, + SparkFilterNode, SparkHistoricalRetrievalReadNode, SparkJoinNode, SparkMaterializationReadNode, @@ -54,6 +56,21 @@ def build_join_node(self, input_node): self.nodes.append(node) return node + def build_filter_node(self, input_node): + filter_expr = None + if hasattr(self.feature_view, "filter"): + filter_expr = self.feature_view.filter + node = SparkFilterNode("filter", input_node, self.feature_view, filter_expr) + self.nodes.append(node) + return node + + def build_dedup_node(self, input_node): + node = SparkDedupNode( + "dedup", input_node, self.feature_view, self.spark_session + ) + self.nodes.append(node) + return node + def build_transformation_node(self, input_node): udf_name = self.feature_view.feature_transformation.name udf = self.feature_view.feature_transformation.udf diff --git a/sdk/python/feast/infra/compute_engines/spark/job.py b/sdk/python/feast/infra/compute_engines/spark/job.py index 3ddde5f9011..07ae85e1178 100644 --- a/sdk/python/feast/infra/compute_engines/spark/job.py +++ b/sdk/python/feast/infra/compute_engines/spark/job.py @@ -16,12 +16,13 @@ class SparkDAGRetrievalJob(SparkRetrievalJob): def __init__( self, spark_session: SparkSession, - plan: ExecutionPlan, context: ExecutionContext, full_feature_names: bool, config: RepoConfig, + plan: Optional[ExecutionPlan] = None, on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, ): super().__init__( spark_session=spark_session, @@ -34,7 +35,11 @@ def __init__( self._plan = plan self._context = context self._metadata = metadata - self._spark_df = None # Will be populated on first access + self._spark_df = None + self._error = error + + def error(self) -> Optional[BaseException]: + return self._error def _ensure_executed(self): if self._spark_df is None: diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index eb38335dc33..ec45737b8db 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -24,7 +24,7 @@ from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) -from feast.utils import _get_column_names, 
_get_fields_with_aliases +from feast.utils import _get_fields_with_aliases @dataclass @@ -58,7 +58,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: feature_name_columns, timestamp_field, created_timestamp_column, - ) = _get_column_names(self.task.feature_view, context.entity_defs) + ) = context.column_info # 📥 Reuse Feast's robust query resolver retrieval_job = offline_store.pull_latest_from_table_or_query( @@ -104,14 +104,13 @@ def execute(self, context: ExecutionContext) -> DAGValue: """ fv = self.task.feature_view source = fv.batch_source - entities = context.entity_defs ( join_key_columns, feature_name_columns, timestamp_field, created_timestamp_column, - ) = _get_column_names(fv, entities) + ) = context.column_info # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp # retrieval_job = offline_store.pull_all_from_table_or_query( @@ -198,7 +197,9 @@ def execute(self, context: ExecutionContext) -> DAGValue: ).agg(*agg_exprs) else: # Simple aggregation - grouped = input_df.groupBy(*self.group_by_keys).agg(*agg_exprs) + grouped = input_df.groupBy( + *self.group_by_keys, + ).agg(*agg_exprs) return DAGValue( data=grouped, format=DAGFormat.SPARK, metadata={"aggregated": True} @@ -223,15 +224,13 @@ def __init__( def execute(self, context: ExecutionContext) -> DAGValue: feature_value = self.get_single_input_value(context) feature_value.assert_format(DAGFormat.SPARK) - feature_df = feature_value.data + feature_df: DataFrame = feature_value.data entity_df = context.entity_df assert entity_df is not None, "entity_df must be set in ExecutionContext" # Get timestamp fields from feature view - join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( - self.feature_view, context.entity_defs - ) + join_keys, feature_cols, ts_col, created_ts_col = context.column_info # Rename entity_df event_timestamp_col to match feature_df entity_schema = _get_entity_schema( @@ -242,13 +241,43 @@ def execute(self, context: ExecutionContext) -> DAGValue: entity_schema=entity_schema, ) entity_ts_alias = "__entity_event_timestamp" - entity_df = self.spark_session.createDataFrame(entity_df).withColumnRenamed( - event_timestamp_col, entity_ts_alias - ) + if not isinstance(entity_df, DataFrame): + entity_df = self.spark_session.createDataFrame(entity_df) + entity_df = entity_df.withColumnRenamed(event_timestamp_col, entity_ts_alias) # Perform left join + event timestamp filtering joined = feature_df.join(entity_df, on=join_keys, how="left") - joined = joined.filter(F.col(ts_col) <= F.col(entity_ts_alias)) + + return DAGValue( + data=joined, format=DAGFormat.SPARK, metadata={"joined_on": join_keys} + ) + + +class SparkFilterNode(DAGNode): + def __init__( + self, + name: str, + input_node: DAGNode, + feature_view: Union[BatchFeatureView, StreamFeatureView], + filter_condition: Optional[str] = None, + ): + super().__init__(name) + self.feature_view = feature_view + self.add_input(input_node) + self.filter_condition = filter_condition + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.SPARK) + input_df: DataFrame = input_value.data + + # Get timestamp fields from feature view + _, _, ts_col, _ = context.column_info + + # Apply filter condition + entity_ts_alias = "__entity_event_timestamp" + filtered_df = input_df + filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(entity_ts_alias)) # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl if 
self.feature_view.ttl: @@ -256,23 +285,59 @@ def execute(self, context: ExecutionContext) -> DAGValue: lower_bound = F.col(entity_ts_alias) - F.expr( f"INTERVAL {ttl_seconds} seconds" ) - joined = joined.filter(F.col(ts_col) >= lower_bound) + filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound) + # Optional custom filter condition + if self.filter_condition: + filtered_df = input_df.filter(self.filter_condition) + + return DAGValue( + data=filtered_df, + format=DAGFormat.SPARK, + metadata={"filter_applied": True}, + ) + + +class SparkDedupNode(DAGNode): + def __init__( + self, + name: str, + input_node: DAGNode, + feature_view: Union[BatchFeatureView, StreamFeatureView], + spark_session: SparkSession, + ): + super().__init__(name) + self.add_input(input_node) + self.feature_view = feature_view + self.spark_session = spark_session + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.SPARK) + input_df: DataFrame = input_value.data + + # Get timestamp fields from feature view + join_keys, _, ts_col, created_ts_col = context.column_info + + # Dedup based on join keys and event timestamp # Dedup with row_number + entity_ts_alias = "__entity_event_timestamp" partition_cols = join_keys + [entity_ts_alias] ordering = [F.col(ts_col).desc()] if created_ts_col: ordering.append(F.col(created_ts_col).desc()) window = Window.partitionBy(*partition_cols).orderBy(*ordering) - deduped = ( - joined.withColumn("row_num", F.row_number().over(window)) + deduped_df = ( + input_df.withColumn("row_num", F.row_number().over(window)) .filter("row_num = 1") .drop("row_num") ) return DAGValue( - data=deduped, format=DAGFormat.SPARK, metadata={"joined_on": join_keys} + data=deduped_df, + format=DAGFormat.SPARK, + metadata={"deduped": True}, ) diff --git a/sdk/python/feast/infra/compute_engines/tasks.py b/sdk/python/feast/infra/compute_engines/tasks.py index a1fdd4940dc..a5b5583b3ce 100644 --- a/sdk/python/feast/infra/compute_engines/tasks.py +++ b/sdk/python/feast/infra/compute_engines/tasks.py @@ -4,16 +4,16 @@ import pandas as pd -from feast import BatchFeatureView, RepoConfig, StreamFeatureView +from feast import BatchFeatureView, StreamFeatureView from feast.infra.registry.registry import Registry @dataclass class HistoricalRetrievalTask: + project: str entity_df: Union[pd.DataFrame, str] feature_view: Union[BatchFeatureView, StreamFeatureView] full_feature_name: bool registry: Registry - config: RepoConfig start_time: datetime end_time: datetime diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py index 31fda14a66c..71eee882bbb 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py +++ b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py @@ -25,8 +25,8 @@ def transform_feature(df: DataFrame) -> DataFrame: - df = df.withColumn("conv_rate", df["conv_rate"] * 2) - df = df.withColumn("acc_rate", df["acc_rate"] * 2) + df = df.withColumn("sum_conv_rate", df["sum_conv_rate"] * 2) + df = df.withColumn("avg_acc_rate", df["avg_acc_rate"] * 2) return df diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 7bb4f52855c..fabc0f6a4c0 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ 
b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -6,6 +6,7 @@ import pytest from feast import BatchFeatureView +from feast.aggregation import Aggregation from feast.infra.compute_engines.spark.compute import SparkComputeEngine from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.compute_engines.tasks import HistoricalRetrievalTask @@ -59,7 +60,7 @@ def test_spark_compute_engine_get_historical_features(): "event_timestamp": yesterday, "created": now - timedelta(hours=2), "conv_rate": 0.8, - "acc_rate": 0.95, + "acc_rate": 0.5, "avg_daily_trips": 15, }, { @@ -75,7 +76,15 @@ def test_spark_compute_engine_get_historical_features(): "event_timestamp": yesterday, "created": now - timedelta(hours=2), "conv_rate": 0.7, - "acc_rate": 0.88, + "acc_rate": 0.4, + "avg_daily_trips": 12, + }, + { + "driver_id": 1002, + "event_timestamp": yesterday - timedelta(days=1), + "created": now - timedelta(hours=2), + "conv_rate": 0.3, + "acc_rate": 0.6, "avg_daily_trips": 12, }, ] @@ -90,9 +99,13 @@ def test_spark_compute_engine_get_historical_features(): name="driver_hourly_stats", entities=[driver], mode="python", + aggregations=[ + Aggregation(column="conv_rate", function="sum"), + Aggregation(column="acc_rate", function="avg"), + ], udf=transform_feature, udf_string="transform_feature", - ttl=timedelta(days=2), + ttl=timedelta(days=3), schema=schema, online=False, offline=False, @@ -112,18 +125,18 @@ def test_spark_compute_engine_get_historical_features(): # 🛠 Build retrieval task task = HistoricalRetrievalTask( + project=spark_environment.project, entity_df=entity_df, feature_view=driver_stats_fv, full_feature_name=False, registry=registry, - config=spark_environment.config, start_time=now - timedelta(days=1), end_time=now, ) # 🧪 Run SparkComputeEngine engine = SparkComputeEngine( - repo_config=task.config, + repo_config=spark_environment.config, offline_store=SparkOfflineStore(), online_store=MagicMock(), registry=registry, @@ -131,12 +144,15 @@ def test_spark_compute_engine_get_historical_features(): spark_dag_retrieval_job = engine.get_historical_features(task) spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() - df_out = spark_df.to_pandas_on_spark() + df_out = spark_df.orderBy("driver_id").to_pandas_on_spark() # ✅ Assert output assert df_out.driver_id.to_list() == [1001, 1002] - assert abs(df_out["conv_rate"].to_list()[0] - 1.6) < 1e-6 - assert abs(df_out["conv_rate"].to_list()[1] - 1.4) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[0] - 1.6) < 1e-6 + assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[0] - 1.0) < 1e-6 + assert abs(df_out["avg_acc_rate"].to_list()[1] - 1.0) < 1e-6 + finally: spark_environment.teardown() diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py index c0782095952..9ffa8062533 100644 --- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -5,11 +5,12 @@ from pyspark.sql import SparkSession from feast.aggregation import Aggregation -from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.compute_engines.spark.node import ( 
SparkAggregationNode, + SparkDedupNode, SparkJoinNode, SparkTransformationNode, ) @@ -183,21 +184,37 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): node_outputs={ "feature_node": feature_val, }, + column_info=ColumnInfo( + join_keys=["driver_id"], + feature_cols=["conv_rate", "acc_rate", "avg_daily_trips"], + ts_col="event_timestamp", + created_ts_col="created", + ), ) # Create the node and add input - node = SparkJoinNode( + join_node = SparkJoinNode( name="join", feature_node=MagicMock(name="feature_node"), join_keys=["user_id"], feature_view=feature_view, spark_session=spark_session, ) - node.inputs[0].name = "feature_node" # must match key in node_outputs + join_node.inputs[0].name = "feature_node" # must match key in node_outputs # Execute the node - output = node.execute(context) - result_df = output.data.orderBy("driver_id").collect() + output = join_node.execute(context) + context.node_outputs["join"] = output + + dedup_node = SparkDedupNode( + name="dedup", + input_node=join_node, + feature_view=feature_view, + spark_session=spark_session, + ) + dedup_node.inputs[0].name = "join" # must match key in node_outputs + dedup_output = dedup_node.execute(context) + result_df = dedup_output.data.orderBy("driver_id").collect() # Assertions assert output.format == DAGFormat.SPARK From 68cf242565b4aaf135a012f6801f9856821ce481 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sat, 12 Apr 2025 21:23:19 -0700 Subject: [PATCH 12/18] update API Signed-off-by: HaoXuAI --- sdk/python/feast/infra/compute_engines/base.py | 2 +- sdk/python/feast/infra/compute_engines/spark/compute.py | 1 - sdk/python/feast/infra/compute_engines/spark/job.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index bac1005cb3a..d5372d246aa 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -54,7 +54,7 @@ def get_execution_context( for name in task.feature_view.entities ] entity_df = None - if task.entity_df is not None: + if hasattr(task, "entity_df") and task.entity_df is not None: entity_df = task.entity_df column_info = self.get_column_info(task) diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 13f402df7a7..e6e6cc52971 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,5 +1,4 @@ from feast.infra.compute_engines.base import ComputeEngine -from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session diff --git a/sdk/python/feast/infra/compute_engines/spark/job.py b/sdk/python/feast/infra/compute_engines/spark/job.py index 07ae85e1178..0f343789d96 100644 --- a/sdk/python/feast/infra/compute_engines/spark/job.py +++ b/sdk/python/feast/infra/compute_engines/spark/job.py @@ -16,10 +16,10 @@ class SparkDAGRetrievalJob(SparkRetrievalJob): def __init__( self, spark_session: SparkSession, + plan: Optional[ExecutionPlan], context: ExecutionContext, full_feature_names: bool, config: RepoConfig, - plan: Optional[ExecutionPlan] = None, on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, metadata: 
Optional[RetrievalMetadata] = None, error: Optional[BaseException] = None, @@ -52,5 +52,5 @@ def to_spark_df(self) -> pyspark.sql.DataFrame: return self._spark_df def to_sql(self) -> str: - self._ensure_executed() + assert self._plan is not None, "Execution plan is not set" return self._plan.to_sql(self._context) From 3a5cf921005c5a69e6611d6d38e5304f6ec5d4e4 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sat, 12 Apr 2025 22:47:38 -0700 Subject: [PATCH 13/18] update API Signed-off-by: HaoXuAI --- .../infra/compute_engines/feature_builder.py | 34 ++-- .../infra/compute_engines/spark/compute.py | 1 + .../compute_engines/spark/feature_builder.py | 4 +- .../feast/infra/compute_engines/spark/node.py | 49 ++++-- sdk/python/feast/utils.py | 1 + .../example_feature_repo_with_bfvs_compute.py | 72 -------- .../compute_engines/spark/test_compute.py | 158 ++++++++++++++---- 7 files changed, 185 insertions(+), 134 deletions(-) delete mode 100644 sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 54f7587bfdd..cab32d47d26 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -55,28 +55,40 @@ def build_output_nodes(self, input_node): def build_validation_node(self, input_node): raise + def _should_aggregate(self): + return ( + hasattr(self.feature_view, "aggregations") + and self.feature_view.aggregations is not None + and len(self.feature_view.aggregations) > 0 + ) + + def _should_transform(self): + return ( + hasattr(self.feature_view, "feature_transformation") + and self.feature_view.feature_transformation + ) + + def _should_validate(self): + return getattr(self.feature_view, "enable_validation", False) + def build(self) -> ExecutionPlan: last_node = self.build_source_node() # PIT join entities to the feature data, and perform filtering - last_node = self.build_join_node(last_node) + if isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_join_node(last_node) + last_node = self.build_filter_node(last_node) - if ( - hasattr(self.feature_view, "aggregations") - and self.feature_view.aggregations is not None - ): + if self._should_aggregate(): last_node = self.build_aggregation_node(last_node) - else: + elif isinstance(self.task, HistoricalRetrievalTask): last_node = self.build_dedup_node(last_node) - if ( - hasattr(self.feature_view, "feature_transformation") - and self.feature_view.feature_transformation - ): + if self._should_transform(): last_node = self.build_transformation_node(last_node) - if getattr(self.feature_view, "enable_validation", False): + if self._should_validate(): last_node = self.build_validation_node(last_node) self.build_output_nodes(last_node) diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index e6e6cc52971..7ea2eb69596 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -54,6 +54,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: ) except Exception as e: + raise e # 🛑 Handle failure return SparkMaterializationJob( job_id=job_id, status=MaterializationJobStatus.ERROR, error=e diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index 
7dc933a0529..e7efbfe1195 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -60,7 +60,9 @@ def build_filter_node(self, input_node): filter_expr = None if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter - node = SparkFilterNode("filter", input_node, self.feature_view, filter_expr) + node = SparkFilterNode( + "filter", self.spark_session, input_node, self.feature_view, filter_expr + ) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index ec45737b8db..e3f737a4fa6 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -26,6 +26,29 @@ ) from feast.utils import _get_fields_with_aliases +ENTITY_TS_ALIAS = "__entity_event_timestamp" + + +# Rename entity_df event_timestamp_col to match feature_df +def rename_entity_ts_column( + spark_session: SparkSession, entity_df: DataFrame +) -> DataFrame: + # check if entity_ts_alias already exists + if ENTITY_TS_ALIAS in entity_df.columns: + return entity_df + + entity_schema = _get_entity_schema( + spark_session=spark_session, + entity_df=entity_df, + ) + event_timestamp_col = infer_event_timestamp_from_entity_df( + entity_schema=entity_schema, + ) + if not isinstance(entity_df, DataFrame): + entity_df = spark_session.createDataFrame(entity_df) + entity_df = entity_df.withColumnRenamed(event_timestamp_col, ENTITY_TS_ALIAS) + return entity_df + @dataclass class SparkJoinContext: @@ -233,19 +256,12 @@ def execute(self, context: ExecutionContext) -> DAGValue: join_keys, feature_cols, ts_col, created_ts_col = context.column_info # Rename entity_df event_timestamp_col to match feature_df - entity_schema = _get_entity_schema( + entity_df = rename_entity_ts_column( spark_session=self.spark_session, entity_df=entity_df, ) - event_timestamp_col = infer_event_timestamp_from_entity_df( - entity_schema=entity_schema, - ) - entity_ts_alias = "__entity_event_timestamp" - if not isinstance(entity_df, DataFrame): - entity_df = self.spark_session.createDataFrame(entity_df) - entity_df = entity_df.withColumnRenamed(event_timestamp_col, entity_ts_alias) - # Perform left join + event timestamp filtering + # Perform left join on entity df joined = feature_df.join(entity_df, on=join_keys, how="left") return DAGValue( @@ -257,11 +273,13 @@ class SparkFilterNode(DAGNode): def __init__( self, name: str, + spark_session: SparkSession, input_node: DAGNode, feature_view: Union[BatchFeatureView, StreamFeatureView], filter_condition: Optional[str] = None, ): super().__init__(name) + self.spark_session = spark_session self.feature_view = feature_view self.add_input(input_node) self.filter_condition = filter_condition @@ -274,22 +292,22 @@ def execute(self, context: ExecutionContext) -> DAGValue: # Get timestamp fields from feature view _, _, ts_col, _ = context.column_info - # Apply filter condition - entity_ts_alias = "__entity_event_timestamp" + # Optional filter: feature.ts <= entity.event_timestamp filtered_df = input_df - filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(entity_ts_alias)) + if ENTITY_TS_ALIAS in input_df.columns: + filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS)) # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl if self.feature_view.ttl: ttl_seconds = int(self.feature_view.ttl.total_seconds()) - lower_bound = 
F.col(entity_ts_alias) - F.expr( + lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( f"INTERVAL {ttl_seconds} seconds" ) filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound) # Optional custom filter condition if self.filter_condition: - filtered_df = input_df.filter(self.filter_condition) + filtered_df = filtered_df.filter(self.filter_condition) return DAGValue( data=filtered_df, @@ -321,8 +339,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: # Dedup based on join keys and event timestamp # Dedup with row_number - entity_ts_alias = "__entity_event_timestamp" - partition_cols = join_keys + [entity_ts_alias] + partition_cols = join_keys + [ENTITY_TS_ALIAS] ordering = [F.col(ts_col).desc()] if created_ts_col: ordering.append(F.col(created_ts_col).desc()) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 4cca1379ed3..1520c2f7dd5 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -271,6 +271,7 @@ def _convert_arrow_fv_to_proto( if isinstance(table, pyarrow.Table): table = table.to_batches()[0] + # TODO: This will break if the feature view has aggregations or transformations columns = [ (field.name, field.dtype.to_value_type()) for field in feature_view.features ] + list(join_keys.items()) diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py b/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py deleted file mode 100644 index 71eee882bbb..00000000000 --- a/sdk/python/tests/example_repos/example_feature_repo_with_bfvs_compute.py +++ /dev/null @@ -1,72 +0,0 @@ -from datetime import timedelta - -from pyspark.sql import DataFrame - -from feast import BatchFeatureView, Entity, Field, FileSource -from feast.types import Float32, Int32, Int64 - -driver_hourly_stats = FileSource( - path="%PARQUET_PATH%", # placeholder to be replaced by the test - timestamp_field="event_timestamp", - created_timestamp_column="created", -) - -driver = Entity( - name="driver_id", - description="driver id", -) - -schema = [ - Field(name="conv_rate", dtype=Float32), - Field(name="acc_rate", dtype=Float32), - Field(name="avg_daily_trips", dtype=Int64), - Field(name="driver_id", dtype=Int32), -] - - -def transform_feature(df: DataFrame) -> DataFrame: - df = df.withColumn("sum_conv_rate", df["sum_conv_rate"] * 2) - df = df.withColumn("avg_acc_rate", df["avg_acc_rate"] * 2) - return df - - -driver_hourly_stats_view = BatchFeatureView( - name="driver_hourly_stats", - entities=[driver], - mode="python", - udf=transform_feature, - udf_string="transform_feature", - ttl=timedelta(days=1), - schema=[ - Field(name="conv_rate", dtype=Float32), - Field(name="acc_rate", dtype=Float32), - Field(name="avg_daily_trips", dtype=Int64), - Field(name="driver_id", dtype=Int32), - ], - online=True, - offline=True, - source=driver_hourly_stats, - tags={}, -) - -global_daily_stats = FileSource( - path="%PARQUET_PATH_GLOBAL%", # placeholder to be replaced by the test - timestamp_field="event_timestamp", - created_timestamp_column="created", -) - -global_stats_feature_view = BatchFeatureView( - name="global_daily_stats", - entities=None, - mode="python", - udf=lambda x: x, - ttl=timedelta(days=1), - schema=[ - Field(name="num_rides", dtype=Int32), - Field(name="avg_ride_length", dtype=Float32), - ], - online=True, - offline=True, - source=global_daily_stats, - tags={}, -) diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 
fabc0f6a4c0..0a5fc8c97b7 100644
--- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py
+++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py
@@ -4,23 +4,26 @@
 
 import pandas as pd
 import pytest
+from pyspark.sql import DataFrame
+from tqdm import tqdm
 
-from feast import BatchFeatureView
+from feast import BatchFeatureView, Entity, Field
 from feast.aggregation import Aggregation
+from feast.data_source import DataSource
 from feast.infra.compute_engines.spark.compute import SparkComputeEngine
 from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob
 from feast.infra.compute_engines.tasks import HistoricalRetrievalTask
+from feast.infra.materialization.batch_materialization_engine import (
+    MaterializationJobStatus,
+    MaterializationTask,
+)
 from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
     SparkOfflineStore,
 )
 from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
     SparkDataSourceCreator,
 )
-from tests.example_repos.example_feature_repo_with_bfvs_compute import (
-    driver,
-    schema,
-    transform_feature,
-)
+from feast.types import Float32, Int32, Int64
 from tests.integration.feature_repos.integration_test_repo_config import (
     IntegrationTestRepoConfig,
 )
@@ -31,28 +34,18 @@
     RedisOnlineStoreCreator,
 )
 
+now = datetime.now()
+today = datetime.today()
 
-@pytest.mark.integration
-def test_spark_compute_engine_get_historical_features():
-    now = datetime.now()
-    today = datetime.today()
-    yesterday = today - timedelta(days=1)
-    last_week = today - timedelta(days=7)
+driver = Entity(
+    name="driver_id",
+    description="driver id",
+)
 
 
-    spark_config = IntegrationTestRepoConfig(
-        provider="local",
-        online_store_creator=RedisOnlineStoreCreator,
-        offline_store_creator=SparkDataSourceCreator,
-        batch_engine={"type": "spark.engine", "partitions": 10},
-    )
-    spark_environment = construct_test_environment(
-        spark_config, None, entity_key_serialization_version=2
-    )
-    spark_environment.setup()
-    fs = spark_environment.feature_store
-    registry = fs.registry
-    # 👷 Prepare test parquet feature file
+def create_feature_dataset(spark_environment) -> DataSource:
+    yesterday = today - timedelta(days=1)
+    last_week = today - timedelta(days=7)
     df = pd.DataFrame(
         [
             {
@@ -95,6 +88,45 @@ def test_spark_compute_engine_get_historical_features():
         timestamp_field="event_timestamp",
         created_timestamp_column="created",
     )
+    return ds
+
+
+def create_entity_df() -> pd.DataFrame:
+    entity_df = pd.DataFrame(
+        [
+            {"driver_id": 1001, "event_timestamp": today},
+            {"driver_id": 1002, "event_timestamp": today},
+        ]
+    )
+    return entity_df
+
+
+def create_spark_environment():
+    spark_config = IntegrationTestRepoConfig(
+        provider="local",
+        online_store_creator=RedisOnlineStoreCreator,
+        offline_store_creator=SparkDataSourceCreator,
+        batch_engine={"type": "spark.engine", "partitions": 10},
+    )
+    spark_environment = construct_test_environment(
+        spark_config, None, entity_key_serialization_version=2
+    )
+    spark_environment.setup()
+    return spark_environment
+
+
+@pytest.mark.integration
+def test_spark_compute_engine_get_historical_features():
+    spark_environment = create_spark_environment()
+    fs = spark_environment.feature_store
+    registry = fs.registry
+    data_source = create_feature_dataset(spark_environment)
+
+    def transform_feature(df: DataFrame) -> DataFrame:
+        df = df.withColumn("sum_conv_rate", df["sum_conv_rate"] * 2)
+        df = df.withColumn("avg_acc_rate", df["avg_acc_rate"] * 2)
+        return df
+
     driver_stats_fv = BatchFeatureView(
         name="driver_hourly_stats",
         entities=[driver],
@@ -106,19 +138,18 @@ def test_spark_compute_engine_get_historical_features():
         udf=transform_feature,
         udf_string="transform_feature",
         ttl=timedelta(days=3),
-        schema=schema,
+        schema=[
+            Field(name="conv_rate", dtype=Float32),
+            Field(name="acc_rate", dtype=Float32),
+            Field(name="avg_daily_trips", dtype=Int64),
+            Field(name="driver_id", dtype=Int32),
+        ],
         online=False,
         offline=False,
-        source=ds,
+        source=data_source,
     )
 
-    # 📥 Entity DataFrame to join with
-    entity_df = pd.DataFrame(
-        [
-            {"driver_id": 1001, "event_timestamp": today},
-            {"driver_id": 1002, "event_timestamp": today},
-        ]
-    )
+    entity_df = create_entity_df()
 
     try:
         fs.apply([driver, driver_stats_fv])
@@ -157,5 +188,64 @@ def test_spark_compute_engine_get_historical_features():
         spark_environment.teardown()
 
 
+def test_spark_compute_engine_materialize():
+    spark_environment = create_spark_environment()
+    fs = spark_environment.feature_store
+    registry = fs.registry
+
+    data_source = create_feature_dataset(spark_environment)
+
+    def transform_feature(df: DataFrame) -> DataFrame:
+        df = df.withColumn("conv_rate", df["conv_rate"] * 2)
+        df = df.withColumn("acc_rate", df["acc_rate"] * 2)
+        return df
+
+    driver_stats_fv = BatchFeatureView(
+        name="driver_hourly_stats",
+        entities=[driver],
+        mode="python",
+        udf=transform_feature,
+        udf_string="transform_feature",
+        schema=[
+            Field(name="conv_rate", dtype=Float32),
+            Field(name="acc_rate", dtype=Float32),
+            Field(name="avg_daily_trips", dtype=Int64),
+            Field(name="driver_id", dtype=Int32),
+        ],
+        online=True,
+        offline=False,
+        source=data_source,
+    )
+
+    def tqdm_builder(length):
+        return tqdm(total=length, ncols=100)
+
+    try:
+        fs.apply([driver, driver_stats_fv])
+
+        # 🛠 Build materialization task
+        task = MaterializationTask(
+            project=spark_environment.project,
+            feature_view=driver_stats_fv,
+            start_time=now - timedelta(days=1),
+            end_time=now,
+            tqdm_builder=tqdm_builder,
+        )
+
+        # 🧪 Run SparkComputeEngine
+        engine = SparkComputeEngine(
+            repo_config=spark_environment.config,
+            offline_store=SparkOfflineStore(),
+            online_store=MagicMock(),
+            registry=registry,
+        )
+
+        spark_materialize_job = engine.materialize(task)
+
+        assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED
+    finally:
+        spark_environment.teardown()
+
+
 if __name__ == "__main__":
     test_spark_compute_engine_get_historical_features()

From c1ba3d6e36d8a2e6b1519ff92fc2b0ea3f730b45 Mon Sep 17 00:00:00 2001
From: HaoXuAI
Date: Sat, 12 Apr 2025 22:55:31 -0700
Subject: [PATCH 14/18] fix linting

Signed-off-by: HaoXuAI
---
 .../unit/infra/compute_engines/spark/test_nodes.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
index 9ffa8062533..afeea82008a 100644
--- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
+++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
@@ -59,6 +59,12 @@ def strip_extra_spaces(df):
         online_store=MagicMock(),
         entity_defs=MagicMock(),
         entity_df=None,
+        column_info=ColumnInfo(
+            join_keys=["name"],
+            feature_cols=["age"],
+            ts_col="",
+            created_ts_col="",
+        ),
         node_outputs={"source": input_value},
     )
 
@@ -101,6 +107,12 @@ def test_spark_aggregation_node_executes_correctly(spark_session):
         online_store=MagicMock(),
         entity_defs=[],
         entity_df=None,
+        column_info=ColumnInfo(
+            join_keys=["user_id"],
+            feature_cols=["value"],
+            ts_col="",
+            created_ts_col="",
+        ),
         node_outputs={"source": input_value},
     )
 

From 95f757d5444867b52f81f06e597320c317e02b02 Mon Sep 17 00:00:00 2001
From: HaoXuAI
Date: Sat, 12 Apr 2025 23:07:18 -0700
Subject: [PATCH 15/18] update doc

Signed-off-by: HaoXuAI
---
 docs/reference/compute-engine/README.md     | 33 +++++++++++++++++--
 .../infra/compute_engines/spark/compute.py  |  1 -
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md
index 4324164b66c..b9fde74234a 100644
--- a/docs/reference/compute-engine/README.md
+++ b/docs/reference/compute-engine/README.md
@@ -38,8 +38,33 @@ This system builds and executes DAGs (Directed Acyclic Graphs) of typed operatio
 
 ---
 
-## 🛠️ Example DAG Flow
-`Read → Aggregate → Join → Transform → Write`
+## 🛠️ Feature Builder Flow
+```markdown
+SourceReadNode
+   |
+   v
+JoinNode (Only for get_historical_features with entity df)
+   |
+   v
+FilterNode (Always included; applies TTL or user-defined filters)
+   |
+   v
+AggregationNode (If aggregations are defined in FeatureView)
+   |
+   v
+DeduplicationNode (If no aggregation is defined for get_historical_features)
+   |
+   v
+TransformationNode (If feature_transformation is defined)
+   |
+   v
+ValidationNode (If enable_validation = True)
+   |
+   v
+Output
+  ├──> RetrievalOutput (For get_historical_features)
+  └──> OnlineStoreWrite / OfflineStoreWrite (For materialize)
+```
 
 Each step is implemented as a `DAGNode`. An `ExecutionPlan` executes these nodes in topological order, caching `DAGValue` outputs.
 
@@ -59,7 +84,7 @@ class MyComputeEngine(ComputeEngine):
     def materialize(self, task: MaterializationTask) -> MaterializationJob:
         ...
 
-    def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
+    def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob:
         ...
 ```
 
@@ -71,6 +96,8 @@ class CustomFeatureBuilder(FeatureBuilder):
     def build_source_node(self): ...
     def build_aggregation_node(self, input_node): ...
     def build_join_node(self, input_node): ...
+    def build_filter_node(self, input_node): ...
+    def build_dedup_node(self, input_node): ...
     def build_transformation_node(self, input_node): ...
     def build_output_nodes(self, input_node): ...
 ```
diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py
index 7ea2eb69596..e6e6cc52971 100644
--- a/sdk/python/feast/infra/compute_engines/spark/compute.py
+++ b/sdk/python/feast/infra/compute_engines/spark/compute.py
@@ -54,7 +54,6 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:
             )
 
         except Exception as e:
-            raise e
             # 🛑 Handle failure
             return SparkMaterializationJob(
                 job_id=job_id, status=MaterializationJobStatus.ERROR, error=e

From e5445a2d559273ea2854fbcdc7690146643ae61c Mon Sep 17 00:00:00 2001
From: HaoXuAI
Date: Sat, 12 Apr 2025 23:19:40 -0700
Subject: [PATCH 16/18] update doc

Signed-off-by: HaoXuAI
---
 sdk/python/feast/batch_feature_view.py  | 5 +++--
 sdk/python/feast/stream_feature_view.py | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py
index 1616761e9ca..2441e4bc859 100644
--- a/sdk/python/feast/batch_feature_view.py
+++ b/sdk/python/feast/batch_feature_view.py
@@ -41,7 +41,8 @@ class BatchFeatureView(FeatureView):
         schema: The schema of the feature view, including feature, timestamp, and entity columns.
            If not specified, can be inferred from the underlying data source.
         source: The batch source of data where this group of features is stored.
-        online: A boolean indicating whether online retrieval is enabled for this feature view.
+        online: A boolean indicating whether online retrieval and writes to the online store are enabled for this feature view.
+        offline: A boolean indicating whether offline retrieval and writes to the offline store are enabled for this feature view.
         description: A human-readable description.
         tags: A dictionary of key-value pairs to store arbitrary metadata.
         owner: The owner of the batch feature view, typically the email of the primary maintainer.
@@ -77,7 +78,7 @@ def __init__(
         entities: Optional[List[Entity]] = None,
         ttl: Optional[timedelta] = None,
         tags: Optional[Dict[str, str]] = None,
-        online: bool = True,
+        online: bool = False,
         offline: bool = True,
         description: str = "",
         owner: str = "",
diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py
index 083bbbe2771..e3608b10354 100644
--- a/sdk/python/feast/stream_feature_view.py
+++ b/sdk/python/feast/stream_feature_view.py
@@ -57,7 +57,8 @@ class StreamFeatureView(FeatureView):
         aggregations: List of aggregations registered with the stream feature view.
         mode: The mode of execution.
         timestamp_field: Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows.
-        online: A boolean indicating whether online retrieval is enabled for this feature view.
+        online: A boolean indicating whether online retrieval and writes to the online store are enabled for this feature view.
+        offline: A boolean indicating whether offline retrieval and writes to the offline store are enabled for this feature view.
         description: A human-readable description.
         tags: A dictionary of key-value pairs to store arbitrary metadata.
         owner: The owner of the stream feature view, typically the email of the primary maintainer.

From 24330644aa24f45706aaa4581de929a756b86ede Mon Sep 17 00:00:00 2001
From: HaoXuAI
Date: Sun, 13 Apr 2025 00:41:33 -0700
Subject: [PATCH 17/18] update test

Signed-off-by: HaoXuAI
---
 .../tests/integration/compute_engines/spark/test_compute.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py
index 0a5fc8c97b7..b8046c12296 100644
--- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py
+++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py
@@ -188,6 +188,7 @@ def transform_feature(df: DataFrame) -> DataFrame:
         spark_environment.teardown()
 
 
+@pytest.mark.integration
 def test_spark_compute_engine_materialize():
     spark_environment = create_spark_environment()
     fs = spark_environment.feature_store

From 87e51c7a3e5c56535699295180e35491ad42d667 Mon Sep 17 00:00:00 2001
From: HaoXuAI
Date: Tue, 15 Apr 2025 15:31:10 -0700
Subject: [PATCH 18/18] update doc

Signed-off-by: HaoXuAI
---
 docs/getting-started/architecture/feature-transformation.md | 2 +-
 docs/reference/compute-engine/README.md                      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/getting-started/architecture/feature-transformation.md b/docs/getting-started/architecture/feature-transformation.md
index 6b09eb9f950..562a733fef0 100644
--- a/docs/getting-started/architecture/feature-transformation.md
+++ b/docs/getting-started/architecture/feature-transformation.md
@@ -8,7 +8,7 @@ Feature transformations can be executed by three types of "transformation engine
 
 1. The Feast Feature Server
 2. An Offline Store (e.g., Snowflake, BigQuery, DuckDB, Spark, etc.)
-3. An Compute Engine (see more [here](../../reference/compute-engine/README.md))
+3. [A Compute Engine](../../reference/compute-engine/README.md)
 
 The three transformation engines are coupled with the [communication pattern used for writes](write-patterns.md).
 
diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md
index b9fde74234a..50aaa5befab 100644
--- a/docs/reference/compute-engine/README.md
+++ b/docs/reference/compute-engine/README.md
@@ -1,6 +1,6 @@
 # 🧠 ComputeEngine (WIP)
 
-The `ComputeEngine` is Feast’s pluggable abstraction for executing feature pipelines — including transformations, aggregations, joins, and materialization/get_historical_features — on a backend of your choice (e.g., Spark, PyArrow, Pandas, Ray).
+The `ComputeEngine` is Feast’s pluggable abstraction for executing feature pipelines — including transformations, aggregations, joins, and materializations/get_historical_features — on a backend of your choice (e.g., Spark, PyArrow, Pandas, Ray).
 
 It powers both: