diff --git a/docs/getting-started/components/batch-materialization-engine.md b/docs/getting-started/components/batch-materialization-engine.md index 7be22fe1255..9a3f7af6c7d 100644 --- a/docs/getting-started/components/batch-materialization-engine.md +++ b/docs/getting-started/components/batch-materialization-engine.md @@ -1,9 +1,11 @@ # Batch Materialization Engine +Note: The materialization engine is not constructed via the unified compute engine interface. + A batch materialization engine is a component of Feast that's responsible for moving data from the offline store into the online store. -A materialization engine abstracts over specific technologies or frameworks that are used to materialize data. It allows users to use a pure local serialized approach (which is the default LocalMaterializationEngine), or delegates the materialization to seperate components (e.g. AWS Lambda, as implemented by the the LambdaMaterializaionEngine). +A materialization engine abstracts over specific technologies or frameworks that are used to materialize data. It allows users to use a pure local serialized approach (which is the default LocalComputeEngine), or to delegate the materialization to separate components (e.g. AWS Lambda, as implemented by the LambdaComputeEngine). -If the built-in engines are not sufficient, you can create your own custom materialization engine. Please see [this guide](../../how-to-guides/customizing-feast/creating-a-custom-materialization-engine.md) for more details. +If the built-in engines are not sufficient, you can create your own custom materialization engine. Please see [this guide](../../how-to-guides/customizing-feast/creating-a-custom-compute-engine.md) for more details. Please see [feature\_store.yaml](../../reference/feature-repository/feature-store-yaml.md#overview) for configuring engines. diff --git a/docs/how-to-guides/customizing-feast/creating-a-custom-materialization-engine.md b/docs/how-to-guides/customizing-feast/creating-a-custom-compute-engine.md similarity index 64% rename from docs/how-to-guides/customizing-feast/creating-a-custom-materialization-engine.md rename to docs/how-to-guides/customizing-feast/creating-a-custom-compute-engine.md index eb003d9db4d..115fb6945b7 100644 --- a/docs/how-to-guides/customizing-feast/creating-a-custom-materialization-engine.md +++ b/docs/how-to-guides/customizing-feast/creating-a-custom-compute-engine.md @@ -1,24 +1,24 @@ -# Adding a custom batch materialization engine +# Adding a custom compute engine ### Overview -Feast batch materialization operations (`materialize` and `materialize-incremental`) execute through a `BatchMaterializationEngine`. +Feast batch materialization operations (`materialize` and `materialize-incremental`) and `get_historical_features` are executed through a `ComputeEngine`. -Custom batch materialization engines allow Feast users to extend Feast to customize the materialization process. Examples include: +Custom compute engines allow Feast users to extend Feast to customize the materialization and `get_historical_features` processes. Examples include: * Setting up custom materialization-specific infrastructure during `feast apply` (e.g. setting up Spark clusters or Lambda Functions) * Launching custom batch ingestion (materialization) jobs (Spark, Beam, AWS Lambda) * Tearing down custom materialization-specific infrastructure during `feast teardown` (e.g. 
tearing down Spark clusters, or deleting Lambda Functions) -Feast comes with built-in materialization engines, e.g, `LocalMaterializationEngine`, and an experimental `LambdaMaterializationEngine`. However, users can develop their own materialization engines by creating a class that implements the contract in the [BatchMaterializationEngine class](https://github.com/feast-dev/feast/blob/6d7b38a39024b7301c499c20cf4e7aef6137c47c/sdk/python/feast/infra/materialization/batch\_materialization\_engine.py#L72). +Feast comes with built-in compute engines, e.g., `LocalComputeEngine`, and an experimental `LambdaComputeEngine`. However, users can develop their own compute engines by creating a class that implements the contract in the [ComputeEngine class](https://github.com/feast-dev/feast/blob/85514edbb181df083e6a0d24672c00f0624dcaa3/sdk/python/feast/infra/compute_engines/base.py#L19). ### Guide -The fastest way to add custom logic to Feast is to extend an existing materialization engine. The most generic engine is the `LocalMaterializationEngine` which contains no cloud-specific logic. The guide that follows will extend the `LocalProvider` with operations that print text to the console. It is up to you as a developer to add your custom code to the engine methods, but the guide below will provide the necessary scaffolding to get you started. +The fastest way to add custom logic to Feast is to implement the `ComputeEngine` interface. The guide that follows will build a custom engine with operations that print text to the console. It is up to you as a developer to add your custom code to the engine methods, but the guide below will provide the necessary scaffolding to get you started. #### Step 1: Define an Engine class -The first step is to define a custom materialization engine class. We've created the `MyCustomEngine` below. This python file can be placed in your `feature_repo` directory if you're following the Quickstart guide. +The first step is to define a custom compute engine class. We've created the `MyCustomEngine` below. This Python file can be placed in your `feature_repo` directory if you're following the Quickstart guide. ```python from typing import List, Sequence, Union @@ -27,14 +27,16 @@ from feast.entity import Entity from feast.feature_view import FeatureView from feast.batch_feature_view import BatchFeatureView from feast.stream_feature_view import StreamFeatureView -from feast.infra.materialization.local_engine import LocalMaterializationJob, LocalMaterializationEngine +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.local.job import LocalMaterializationJob +from feast.infra.compute_engines.base import ComputeEngine from feast.infra.common.materialization_job import MaterializationTask -from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.online_stores.online_store import OnlineStore from feast.repo_config import RepoConfig -class MyCustomEngine(LocalMaterializationEngine): +class MyCustomEngine(ComputeEngine): def __init__( self, *, @@ -80,9 +82,13 @@ class MyCustomEngine(LocalMaterializationEngine): ) for task in tasks ] + + def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob: + raise NotImplementedError ``` -Notice how in the above engine we have only overwritten two of the methods on the `LocalMaterializatinEngine`, 
These two methods are convenient to replace if you are planning to launch custom batch jobs. +Notice how in the above engine we have only overridden two of the methods on the `ComputeEngine`, namely `update` and `materialize`. These two methods are convenient to replace if you are planning to launch custom batch jobs. +If you want to use the compute engine to execute `get_historical_features`, you will need to implement the `get_historical_features` method as well. #### Step 2: Configuring Feast to use the engine diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 2441e4bc859..933696ced33 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -79,7 +79,7 @@ def __init__( ttl: Optional[timedelta] = None, tags: Optional[Dict[str, str]] = None, online: bool = False, - offline: bool = True, + offline: bool = False, description: str = "", owner: str = "", schema: Optional[List[Field]] = None, diff --git a/sdk/python/feast/infra/common/materialization_job.py b/sdk/python/feast/infra/common/materialization_job.py index 60ded6735a6..f4ce5b09548 100644 --- a/sdk/python/feast/infra/common/materialization_job.py +++ b/sdk/python/feast/infra/common/materialization_job.py @@ -20,7 +20,8 @@ class MaterializationTask: feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] start_time: datetime end_time: datetime - tqdm_builder: Callable[[int], tqdm] + only_latest: bool = True + tqdm_builder: Union[None, Callable[[int], tqdm]] = None class MaterializationJobStatus(enum.Enum): diff --git a/sdk/python/feast/infra/materialization/aws_lambda/Dockerfile b/sdk/python/feast/infra/compute_engines/aws_lambda/Dockerfile similarity index 100% rename from sdk/python/feast/infra/materialization/aws_lambda/Dockerfile rename to sdk/python/feast/infra/compute_engines/aws_lambda/Dockerfile diff --git a/sdk/python/feast/infra/materialization/aws_lambda/app.py b/sdk/python/feast/infra/compute_engines/aws_lambda/app.py similarity index 100% rename from sdk/python/feast/infra/materialization/aws_lambda/app.py rename to sdk/python/feast/infra/compute_engines/aws_lambda/app.py diff --git a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py b/sdk/python/feast/infra/compute_engines/aws_lambda/lambda_engine.py similarity index 89% rename from sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py rename to sdk/python/feast/infra/compute_engines/aws_lambda/lambda_engine.py index 03eb51a2b66..cc32e5b74b3 100644 --- a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py +++ b/sdk/python/feast/infra/compute_engines/aws_lambda/lambda_engine.py @@ -3,13 +3,12 @@ import logging from concurrent.futures import ThreadPoolExecutor, wait from dataclasses import dataclass -from datetime import datetime -from typing import Callable, List, Literal, Optional, Sequence, Union import boto3 +import pyarrow as pa from botocore.config import Config from pydantic import StrictStr -from tqdm import tqdm from feast import utils from feast.batch_feature_view import BatchFeatureView @@ -21,9 +20,8 @@ MaterializationJobStatus, MaterializationTask, ) -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, -) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine from feast.infra.offline_stores.offline_store import OfflineStore from 
feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry @@ -40,8 +38,8 @@ logger = logging.getLogger(__name__) -class LambdaMaterializationEngineConfig(FeastConfigBaseModel): - """Batch Materialization Engine config for lambda based engine""" +class LambdaComputeEngineConfig(FeastConfigBaseModel): + """Batch Compute Engine config for lambda based engine""" type: Literal["lambda"] = "lambda" """ Type selector""" @@ -82,11 +80,18 @@ def url(self) -> Optional[str]: return None -class LambdaMaterializationEngine(BatchMaterializationEngine): +class LambdaComputeEngine(ComputeEngine): """ WARNING: This engine should be considered "Alpha" functionality. """ + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> pa.Table: + raise NotImplementedError( + "Lambda Compute Engine does not support get_historical_features" + ) + def update( self, project: str, @@ -160,30 +165,14 @@ def __init__( config = Config(read_timeout=DEFAULT_TIMEOUT + 10) self.lambda_client = boto3.client("lambda", config=config) - def materialize( - self, registry, tasks: List[MaterializationTask] - ) -> List[MaterializationJob]: - return [ - self._materialize_one( - registry, - task.feature_view, - task.start_time, - task.end_time, - task.project, - task.tqdm_builder, - ) - for task in tasks - ] - def _materialize_one( - self, - registry: BaseRegistry, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], - start_date: datetime, - end_date: datetime, - project: str, - tqdm_builder: Callable[[int], tqdm], + self, registry: BaseRegistry, task: MaterializationTask, **kwargs ): + feature_view = task.feature_view + start_date = task.start_time + end_date = task.end_time + project = task.project + entities = [] for entity_name in feature_view.entities: entities.append(registry.get_entity(entity_name, project)) diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index 6e1a90f45b8..6acdb8d11d6 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -1,63 +1,130 @@ -from abc import ABC -from typing import Union +from abc import ABC, abstractmethod +from typing import List, Optional, Sequence, Union import pyarrow as pa from feast import RepoConfig +from feast.batch_feature_view import BatchFeatureView +from feast.entity import Entity +from feast.feature_view import FeatureView from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationTask, ) from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext -from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.registry.registry import Registry +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.stream_feature_view import StreamFeatureView from feast.utils import _get_column_names class ComputeEngine(ABC): """ - The interface that Feast uses to control the compute system that handles materialization and get_historical_features. + The interface that Feast uses to control to compute system that handles materialization and get_historical_features. 
Each engine must implement: - materialize(): to generate and persist features - - get_historical_features(): to perform point-in-time correct joins + - get_historical_features(): to perform historical retrieval of features Engines should use FeatureBuilder and DAGNode abstractions to build modular, pluggable workflows. """ def __init__( self, *, - registry: Registry, repo_config: RepoConfig, offline_store: OfflineStore, online_store: OnlineStore, **kwargs, ): - self.registry = registry self.repo_config = repo_config self.offline_store = offline_store self.online_store = online_store - def materialize(self, task: MaterializationTask) -> MaterializationJob: - raise NotImplementedError + @abstractmethod + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ): + """ + Prepares cloud resources required for batch materialization for the specified set of Feast objects. + + Args: + project: Feast project to which the objects belong. + views_to_delete: Feature views whose corresponding infrastructure should be deleted. + views_to_keep: Feature views whose corresponding infrastructure should not be deleted, and + may need to be updated. + entities_to_delete: Entities whose corresponding infrastructure should be deleted. + entities_to_keep: Entities whose corresponding infrastructure should not be deleted, and + may need to be updated. + """ + pass + + @abstractmethod + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ): + """ + Tears down all cloud resources used by the materialization engine for the specified set of Feast objects. + + Args: + project: Feast project to which the objects belong. + fvs: Feature views whose corresponding infrastructure should be deleted. + entities: Entities whose corresponding infrastructure should be deleted. + """ + pass - def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table: + def materialize( + self, + registry: BaseRegistry, + tasks: Union[MaterializationTask, List[MaterializationTask]], + **kwargs, + ) -> List[MaterializationJob]: + if isinstance(tasks, MaterializationTask): + tasks = [tasks] + return [self._materialize_one(registry, task, **kwargs) for task in tasks] + + def _materialize_one( + self, + registry: BaseRegistry, + task: MaterializationTask, + **kwargs, + ) -> MaterializationJob: + raise NotImplementedError( + "Materialization is not implemented for this compute engine." 
+ ) + + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> Union[RetrievalJob, pa.Table]: raise NotImplementedError def get_execution_context( self, + registry: BaseRegistry, task: Union[MaterializationTask, HistoricalRetrievalTask], ) -> ExecutionContext: entity_defs = [ - self.registry.get_entity(name, task.project) + registry.get_entity(name, task.project) for name in task.feature_view.entities ] entity_df = None if hasattr(task, "entity_df") and task.entity_df is not None: entity_df = task.entity_df - column_info = self.get_column_info(task) + column_info = self.get_column_info(registry, task) return ExecutionContext( project=task.project, repo_config=self.repo_config, @@ -70,14 +137,39 @@ def get_execution_context( def get_column_info( self, + registry: BaseRegistry, task: Union[MaterializationTask, HistoricalRetrievalTask], ) -> ColumnInfo: + entities = [] + for entity_name in task.feature_view.entities: + entities.append(registry.get_entity(entity_name, task.project)) + join_keys, feature_cols, ts_col, created_ts_col = _get_column_names( - task.feature_view, self.registry.list_entities(task.project) + task.feature_view, entities ) + field_mapping = self.get_field_mapping(task.feature_view) + return ColumnInfo( join_keys=join_keys, feature_cols=feature_cols, ts_col=ts_col, created_ts_col=created_ts_col, + field_mapping=field_mapping, ) + + def get_field_mapping( + self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] + ) -> Optional[dict]: + """ + Get the field mapping for a feature view. + Args: + feature_view: The feature view to get the field mapping for. + + Returns: + A dictionary mapping field names to column names. + """ + if feature_view.stream_source: + return feature_view.stream_source.field_mapping + if feature_view.batch_source: + return feature_view.batch_source.field_mapping + return None diff --git a/sdk/python/feast/infra/compute_engines/dag/context.py b/sdk/python/feast/infra/compute_engines/dag/context.py index 8b170b67766..6b1970d25f8 100644 --- a/sdk/python/feast/infra/compute_engines/dag/context.py +++ b/sdk/python/feast/infra/compute_engines/dag/context.py @@ -16,12 +16,39 @@ class ColumnInfo: feature_cols: List[str] ts_col: str created_ts_col: Optional[str] + field_mapping: Optional[Dict[str, str]] = None def __iter__(self): yield self.join_keys yield self.feature_cols yield self.ts_col yield self.created_ts_col + yield self.field_mapping + + @property + def timestamp_column(self) -> str: + """ + Get the event timestamp column from the context. + """ + mapped_column = self._get_mapped_column(self.ts_col) + if mapped_column is None: + raise ValueError("Timestamp column cannot be None") + return mapped_column + + @property + def created_timestamp_column(self) -> Optional[str]: + """ + Get the created timestamp column from the context. + """ + return self._get_mapped_column(self.created_ts_col) + + def _get_mapped_column(self, column: Optional[str]) -> Optional[str]: + """ + Helper method to get the mapped column name if it exists in field_mapping. + """ + if column and self.field_mapping: + return self.field_mapping.get(column, column) + return column @dataclass @@ -55,6 +82,8 @@ class ExecutionContext: node_outputs: Internal cache of DAGValue outputs keyed by DAGNode name. Automatically populated during ExecutionPlan execution to avoid redundant computation. Used by downstream nodes to access their input data. 
+ + field_mapping: A mapping of field names to their corresponding column names in the """ project: str diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index ceed3e2d4f3..9d4e4466499 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -69,6 +69,9 @@ def _should_transform(self): def _should_validate(self): return getattr(self.feature_view, "enable_validation", False) + def _should_dedupe(self, task): + return isinstance(task, HistoricalRetrievalTask) or task.only_latest + def build(self) -> ExecutionPlan: last_node = self.build_source_node() @@ -80,7 +83,9 @@ def build(self) -> ExecutionPlan: if self._should_aggregate(): last_node = self.build_aggregation_node(last_node) - elif isinstance(self.task, HistoricalRetrievalTask): + + # Dedupe only if not aggregated + elif self._should_dedupe(self.task): last_node = self.build_dedup_node(last_node) if self._should_transform(): diff --git a/sdk/python/feast/infra/materialization/kubernetes/Dockerfile b/sdk/python/feast/infra/compute_engines/kubernetes/Dockerfile similarity index 100% rename from sdk/python/feast/infra/materialization/kubernetes/Dockerfile rename to sdk/python/feast/infra/compute_engines/kubernetes/Dockerfile diff --git a/sdk/python/feast/infra/materialization/__init__.py b/sdk/python/feast/infra/compute_engines/kubernetes/__init__.py similarity index 100% rename from sdk/python/feast/infra/materialization/__init__.py rename to sdk/python/feast/infra/compute_engines/kubernetes/__init__.py diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py b/sdk/python/feast/infra/compute_engines/kubernetes/k8s_engine.py similarity index 86% rename from sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py rename to sdk/python/feast/infra/compute_engines/kubernetes/k8s_engine.py index adf14eaf419..0dcff09f027 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py +++ b/sdk/python/feast/infra/compute_engines/kubernetes/k8s_engine.py @@ -1,34 +1,27 @@ import logging import uuid -from datetime import datetime from time import sleep -from typing import Callable, List, Literal, Sequence, Union +from typing import List, Literal +import pyarrow as pa import yaml from kubernetes import client, utils from kubernetes import config as k8s_config from kubernetes.client.exceptions import ApiException from kubernetes.utils import FailToCreateError from pydantic import StrictStr -from tqdm import tqdm -from feast import FeatureView, RepoConfig -from feast.batch_feature_view import BatchFeatureView -from feast.entity import Entity +from feast import RepoConfig from feast.infra.common.materialization_job import ( - MaterializationJob, MaterializationJobStatus, MaterializationTask, ) -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, -) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry -from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel -from feast.stream_feature_view import StreamFeatureView from feast.utils import 
_get_column_names from .k8s_materialization_job import KubernetesMaterializationJob @@ -36,8 +29,8 @@ logger = logging.getLogger(__name__) -class KubernetesMaterializationEngineConfig(FeastConfigBaseModel): - """Batch Materialization Engine config for Kubernetes""" +class KubernetesComputeEngineConfig(FeastConfigBaseModel): + """Batch Compute Engine config for Kubernetes""" type: Literal["k8s"] = "k8s" """ Materialization type selector""" @@ -94,7 +87,14 @@ class KubernetesMaterializationEngineConfig(FeastConfigBaseModel): """(optional) Print pod logs on job failure. Only applies to synchronous materialization""" -class KubernetesMaterializationEngine(BatchMaterializationEngine): +class KubernetesComputeEngine(ComputeEngine): + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> pa.Table: + raise NotImplementedError( + "KubernetesComputeEngine does not support get_historical_features()" + ) + def __init__( self, *, @@ -121,57 +121,17 @@ def __init__( self.batch_engine_config = repo_config.batch_engine self.namespace = self.batch_engine_config.namespace - def update( - self, - project: str, - views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - ): - """This method ensures that any necessary infrastructure or resources needed by the - engine are set up ahead of materialization.""" - pass - - def teardown_infra( - self, - project: str, - fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], - entities: Sequence[Entity], - ): - """This method ensures that any infrastructure or resources set up by ``update()``are torn down.""" - pass - - def materialize( - self, - registry: BaseRegistry, - tasks: List[MaterializationTask], - ) -> List[MaterializationJob]: - return [ - self._materialize_one( - registry, - task.feature_view, - task.start_time, - task.end_time, - task.project, - task.tqdm_builder, - ) - for task in tasks - ] - def _materialize_one( self, registry: BaseRegistry, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], - start_date: datetime, - end_date: datetime, - project: str, - tqdm_builder: Callable[[int], tqdm], + task: MaterializationTask, + **kwargs, ): + feature_view = task.feature_view + start_date = task.start_time + end_date = task.end_time + project = task.project + entities = [] for entity_name in feature_view.entities: entities.append(registry.get_entity(entity_name, project)) diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py b/sdk/python/feast/infra/compute_engines/kubernetes/k8s_materialization_job.py similarity index 100% rename from sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py rename to sdk/python/feast/infra/compute_engines/kubernetes/k8s_materialization_job.py diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py b/sdk/python/feast/infra/compute_engines/kubernetes/k8s_materialization_task.py similarity index 100% rename from sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py rename to sdk/python/feast/infra/compute_engines/kubernetes/k8s_materialization_task.py diff --git a/sdk/python/feast/infra/materialization/kubernetes/main.py b/sdk/python/feast/infra/compute_engines/kubernetes/main.py similarity index 
100% rename from sdk/python/feast/infra/materialization/kubernetes/main.py rename to sdk/python/feast/infra/compute_engines/kubernetes/main.py diff --git a/sdk/python/feast/infra/compute_engines/local/backends/factory.py b/sdk/python/feast/infra/compute_engines/local/backends/factory.py index 0a5f40cccf2..6d3774f6393 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/factory.py +++ b/sdk/python/feast/infra/compute_engines/local/backends/factory.py @@ -23,7 +23,11 @@ def from_name(name: str) -> DataFrameBackend: @staticmethod def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]: - if isinstance(entity_df, pyarrow.Table) or isinstance(entity_df, pd.DataFrame): + if ( + not entity_df + or isinstance(entity_df, pyarrow.Table) + or isinstance(entity_df, pd.DataFrame) + ): return PandasBackend() if BackendFactory._is_polars(entity_df): diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py index 5b5fa7c06ab..0b99a58c304 100644 --- a/sdk/python/feast/infra/compute_engines/local/compute.py +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -1,5 +1,12 @@ -from typing import Optional +from typing import Optional, Sequence, Union +from feast import ( + BatchFeatureView, + Entity, + FeatureView, + OnDemandFeatureView, + StreamFeatureView, +) from feast.infra.common.materialization_job import ( MaterializationJobStatus, MaterializationTask, @@ -10,11 +17,36 @@ from feast.infra.compute_engines.local.backends.base import DataFrameBackend from feast.infra.compute_engines.local.backends.factory import BackendFactory from feast.infra.compute_engines.local.feature_builder import LocalFeatureBuilder -from feast.infra.compute_engines.local.job import LocalRetrievalJob -from feast.infra.materialization.local_engine import LocalMaterializationJob +from feast.infra.compute_engines.local.job import ( + LocalMaterializationJob, + LocalRetrievalJob, +) +from feast.infra.registry.base_registry import BaseRegistry class LocalComputeEngine(ComputeEngine): + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ): + pass + + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ): + pass + def __init__(self, backend: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.backend_name = backend @@ -28,9 +60,11 @@ def _get_backend(self, context: ExecutionContext) -> DataFrameBackend: return backend raise ValueError("Could not infer backend from context.entity_df") - def materialize(self, task: MaterializationTask) -> LocalMaterializationJob: + def _materialize_one( + self, registry: BaseRegistry, task: MaterializationTask, **kwargs + ) -> LocalMaterializationJob: job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" - context = self.get_execution_context(task) + context = self.get_execution_context(registry, task) backend = self._get_backend(context) try: @@ -50,9 +84,9 @@ def materialize(self, task: MaterializationTask) -> LocalMaterializationJob: ) def get_historical_features( - self, task: HistoricalRetrievalTask + self, registry: BaseRegistry, task: HistoricalRetrievalTask ) -> LocalRetrievalJob: - context = 
self.get_execution_context(task) + context = self.get_execution_context(registry, task) backend = self._get_backend(context) try: diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index 4f9dcc871d5..e3e29099360 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -26,7 +26,10 @@ def __init__( self.backend = backend def build_source_node(self): - node = LocalSourceReadNode("source", self.feature_view, self.task) + source = self.feature_view.batch_source + start_time = self.task.start_time + end_time = self.task.end_time + node = LocalSourceReadNode("source", source, start_time, end_time) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/local/job.py b/sdk/python/feast/infra/compute_engines/local/job.py index 530bee8d59b..eaf003e55ea 100644 --- a/sdk/python/feast/infra/compute_engines/local/job.py +++ b/sdk/python/feast/infra/compute_engines/local/job.py @@ -1,13 +1,18 @@ +from dataclasses import dataclass from typing import List, Optional, cast import pandas as pd import pyarrow -from feast import OnDemandFeatureView +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, +) from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.on_demand_feature_view import OnDemandFeatureView from feast.saved_dataset import SavedDatasetStorage @@ -75,3 +80,32 @@ def to_sql(self) -> str: raise NotImplementedError( "SQL generation is not supported in LocalRetrievalJob" ) + + +@dataclass +class LocalMaterializationJob(MaterializationJob): + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + error: Optional[BaseException] = None, + ) -> None: + super().__init__() + self._job_id: str = job_id + self._status: MaterializationJobStatus = status + self._error: Optional[BaseException] = error + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + return False + + def job_id(self) -> str: + return self._job_id + + def url(self) -> Optional[str]: + return None diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index 709b592f97c..a8c4405dd06 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -9,6 +9,9 @@ from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue from feast.infra.compute_engines.local.backends.base import DataFrameBackend from feast.infra.compute_engines.local.local_node import LocalNode +from feast.infra.compute_engines.utils import ( + create_offline_store_retrieval_job, +) from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) @@ -31,26 +34,18 @@ def __init__( self.end_time = end_time def execute(self, context: ExecutionContext) -> ArrowTableValue: - offline_store = context.offline_store - ( - join_key_columns, - feature_name_columns, - timestamp_field, - created_timestamp_column, - ) = context.column_info - - # 📥 Reuse 
Feast's robust query resolver - retrieval_job = offline_store.pull_all_from_table_or_query( - config=context.repo_config, + retrieval_job = create_offline_store_retrieval_job( data_source=self.source, - join_key_columns=join_key_columns, - feature_name_columns=feature_name_columns, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, - start_date=self.start_time, - end_date=self.end_time, + context=context, + start_time=self.start_time, + end_time=self.end_time, ) arrow_table = retrieval_job.to_arrow() + field_mapping = context.column_info.field_mapping + if field_mapping: + arrow_table = arrow_table.rename_columns( + [field_mapping.get(col, col) for col in arrow_table.column_names] + ) return ArrowTableValue(data=arrow_table) @@ -63,8 +58,9 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: feature_table = self.get_single_table(context).data if context.entity_df is None: - context.node_outputs[self.name] = feature_table - return feature_table + output = ArrowTableValue(feature_table) + context.node_outputs[self.name] = output + return output entity_table = pa.Table.from_pandas(context.entity_df) feature_df = self.backend.from_arrow(feature_table) @@ -75,13 +71,15 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: entity_schema ) - join_keys, feature_cols, ts_col, created_ts_col = context.column_info + column_info = context.column_info entity_df = self.backend.rename_columns( entity_df, {entity_df_event_timestamp_col: ENTITY_TS_ALIAS} ) - joined_df = self.backend.join(feature_df, entity_df, on=join_keys, how="left") + joined_df = self.backend.join( + feature_df, entity_df, on=column_info.join_keys, how="left" + ) result = self.backend.to_arrow(joined_df) output = ArrowTableValue(result) context.node_outputs[self.name] = output @@ -105,18 +103,18 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data df = self.backend.from_arrow(input_table) - _, _, ts_col, _ = context.column_info + timestamp_column = context.column_info.timestamp_column if ENTITY_TS_ALIAS in self.backend.columns(df): # filter where feature.ts <= entity.event_timestamp - df = df[df[ts_col] <= df[ENTITY_TS_ALIAS]] + df = df[df[timestamp_column] <= df[ENTITY_TS_ALIAS]] # TTL: feature.ts >= entity.event_timestamp - ttl if self.ttl: lower_bound = df[ENTITY_TS_ALIAS] - self.backend.to_timedelta_value( self.ttl ) - df = df[df[ts_col] >= lower_bound] + df = df[df[timestamp_column] >= lower_bound] # Optional user-defined filter expression (e.g., "value > 0") if self.filter_expr: @@ -157,17 +155,21 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: df = self.backend.from_arrow(input_table) # Extract join_keys, timestamp, and created_ts from context - join_keys, _, ts_col, created_ts_col = context.column_info + column_info = context.column_info # Dedup strategy: sort and drop_duplicates - sort_keys = [ts_col] - if created_ts_col: - sort_keys.append(created_ts_col) - - dedup_keys = join_keys + [ENTITY_TS_ALIAS] - df = self.backend.drop_duplicates( - df, keys=dedup_keys, sort_by=sort_keys, ascending=False - ) + dedup_keys = context.column_info.join_keys + if dedup_keys: + sort_keys = [column_info.timestamp_column] + if ( + column_info.created_timestamp_column + and column_info.created_timestamp_column in df.columns + ): + sort_keys.append(column_info.created_timestamp_column) + + df = self.backend.drop_duplicates( + df, keys=dedup_keys, sort_by=sort_keys, ascending=False + ) result = 
self.backend.to_arrow(df) output = ArrowTableValue(result) context.node_outputs[self.name] = output @@ -219,6 +221,9 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data context.node_outputs[self.name] = input_table + if input_table.num_rows == 0: + return input_table + if self.feature_view.online: online_store = context.online_store diff --git a/sdk/python/feast/infra/materialization/contrib/__init__.py b/sdk/python/feast/infra/compute_engines/snowflake/__init__.py similarity index 100% rename from sdk/python/feast/infra/materialization/contrib/__init__.py rename to sdk/python/feast/infra/compute_engines/snowflake/__init__.py diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/compute_engines/snowflake/snowflake_engine.py similarity index 90% rename from sdk/python/feast/infra/materialization/snowflake_engine.py rename to sdk/python/feast/infra/compute_engines/snowflake/snowflake_engine.py index 9c535c334e3..31c420613a8 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/compute_engines/snowflake/snowflake_engine.py @@ -1,11 +1,11 @@ import os import shutil -from dataclasses import dataclass -from datetime import datetime, timezone -from typing import Callable, List, Literal, Optional, Sequence, Union +from datetime import timezone +from typing import Literal, Optional, Sequence, Union import click import pandas as pd +import pyarrow as pa from colorama import Fore, Style from pydantic import ConfigDict, Field, StrictStr from tqdm import tqdm @@ -15,12 +15,13 @@ from feast.entity import Entity from feast.feature_view import DUMMY_ENTITY_ID, FeatureView from feast.infra.common.materialization_job import ( - MaterializationJob, MaterializationJobStatus, MaterializationTask, ) -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.snowflake.snowflake_materialization_job import ( + SnowflakeMaterializationJob, ) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore @@ -42,8 +43,8 @@ from feast.utils import _coerce_datetime, _get_column_names -class SnowflakeMaterializationEngineConfig(FeastConfigBaseModel): - """Batch Materialization Engine config for Snowflake Snowpark Python UDFs""" +class SnowflakeComputeEngineConfig(FeastConfigBaseModel): + """Batch Compute Engine config for Snowflake Snowpark Python UDFs""" type: Literal["snowflake.engine"] = "snowflake.engine" """ Type selector""" @@ -89,36 +90,14 @@ class SnowflakeMaterializationEngineConfig(FeastConfigBaseModel): model_config = ConfigDict(populate_by_name=True) -@dataclass -class SnowflakeMaterializationJob(MaterializationJob): - def __init__( - self, - job_id: str, - status: MaterializationJobStatus, - error: Optional[BaseException] = None, - ) -> None: - super().__init__() - self._job_id: str = job_id - self._status: MaterializationJobStatus = status - self._error: Optional[BaseException] = error - - def status(self) -> MaterializationJobStatus: - return self._status - - def error(self) -> Optional[BaseException]: - return self._error - - def should_be_retried(self) -> bool: - return False - - def job_id(self) -> str: - return self._job_id - - def url(self) -> Optional[str]: - return None - 
+class SnowflakeComputeEngine(ComputeEngine): + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> pa.Table: + raise NotImplementedError( + "SnowflakeComputeEngine does not support get_historical_features" + ) -class SnowflakeMaterializationEngine(BatchMaterializationEngine): def update( self, project: str, @@ -209,7 +188,7 @@ def __init__( **kwargs, ): assert repo_config.offline_store.type == "snowflake.offline", ( - "To use SnowflakeMaterializationEngine, you must use Snowflake as an offline store." + "To use Snowflake Compute Engine, you must use Snowflake as an offline store." ) super().__init__( @@ -219,30 +198,18 @@ def __init__( **kwargs, ) - def materialize( - self, registry, tasks: List[MaterializationTask] - ) -> List[MaterializationJob]: - return [ - self._materialize_one( - registry, - task.feature_view, - task.start_time, - task.end_time, - task.project, - task.tqdm_builder, - ) - for task in tasks - ] - def _materialize_one( self, registry: BaseRegistry, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], - start_date: datetime, - end_date: datetime, - project: str, - tqdm_builder: Callable[[int], tqdm], + task: MaterializationTask, + **kwargs, ): + feature_view = task.feature_view + start_date = task.start_time + end_date = task.end_time + project = task.project + tqdm_builder = task.tqdm_builder if task.tqdm_builder else tqdm + assert isinstance(feature_view, BatchFeatureView) or isinstance( feature_view, FeatureView ), ( diff --git a/sdk/python/feast/infra/compute_engines/snowflake/snowflake_materialization_job.py b/sdk/python/feast/infra/compute_engines/snowflake/snowflake_materialization_job.py new file mode 100644 index 00000000000..fdacd0edf66 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/snowflake/snowflake_materialization_job.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import Optional + +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, +) + + +@dataclass +class SnowflakeMaterializationJob(MaterializationJob): + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + error: Optional[BaseException] = None, + ) -> None: + super().__init__() + self._job_id: str = job_id + self._status: MaterializationJobStatus = status + self._error: Optional[BaseException] = error + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + return False + + def job_id(self) -> str: + return self._job_id + + def url(self) -> Optional[str]: + return None diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 981e786cf7f..618a3b780f6 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,42 +1,117 @@ +import logging +from datetime import datetime +from typing import Dict, Literal, Optional, Sequence, Union, cast + +from pydantic import StrictStr + +from feast import ( + BatchFeatureView, + Entity, + FeatureView, + OnDemandFeatureView, + StreamFeatureView, +) from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.common.serde import SerializedArtifacts from 
feast.infra.compute_engines.base import ComputeEngine from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder -from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob -from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session -from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( +from feast.infra.compute_engines.spark.job import ( + SparkDAGRetrievalJob, SparkMaterializationJob, ) +from feast.infra.compute_engines.spark.utils import ( + get_or_create_new_spark_session, + map_in_pandas, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkRetrievalJob, +) from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.registry.base_registry import BaseRegistry +from feast.repo_config import FeastConfigBaseModel +from feast.utils import _get_column_names + + +class SparkComputeEngineConfig(FeastConfigBaseModel): + type: Literal["spark.engine"] = "spark.engine" + """ Spark Compute type selector""" + + spark_conf: Optional[Dict[str, str]] = None + """ Configuration overlay for the spark session """ + + staging_location: Optional[StrictStr] = None + """ Remote path for batch materialization jobs""" + + region: Optional[StrictStr] = None + """ AWS Region if applicable for s3-based staging locations""" + + partitions: int = 0 + """Number of partitions to use when writing data to online store. If 0, no repartitioning is done""" class SparkComputeEngine(ComputeEngine): + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ): + pass + + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ): + pass + def __init__( self, offline_store, online_store, - registry, repo_config, **kwargs, ): super().__init__( offline_store=offline_store, online_store=online_store, - registry=registry, repo_config=repo_config, **kwargs, ) self.spark_session = get_or_create_new_spark_session() - def materialize(self, task: MaterializationTask) -> MaterializationJob: + def _materialize_one( + self, + registry: BaseRegistry, + task: MaterializationTask, + from_offline_store: bool = False, + **kwargs, + ) -> MaterializationJob: + if from_offline_store: + return self._materialize_from_offline_store( + registry=registry, + feature_view=task.feature_view, + start_date=task.start_time, + end_date=task.end_time, + project=task.project, + ) + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" # ✅ 1. Build typed execution context - context = self.get_execution_context(task) + context = self.get_execution_context(registry, task) try: # ✅ 2. Construct Feature Builder and run it @@ -58,12 +133,80 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: job_id=job_id, status=MaterializationJobStatus.ERROR, error=e ) - def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob: + def _materialize_from_offline_store( + self, + registry: BaseRegistry, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + start_date: datetime, + end_date: datetime, + project: str, + ): + """ + Legacy materialization method for Spark Compute Engine. 
This method is used to materialize features from the + offline store to the online store directly. + """ + logging.warning( + "Materializing from offline store will be deprecated in the future. Please use the new " + "materialization API." + ) + entities = [] + for entity_name in feature_view.entities: + entities.append(registry.get_entity(entity_name, project)) + + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = _get_column_names(feature_view, entities) + + job_id = f"{feature_view.name}-{start_date}-{end_date}" + + try: + offline_job = cast( + SparkRetrievalJob, + self.offline_store.pull_latest_from_table_or_query( + config=self.repo_config, + data_source=feature_view.batch_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ), + ) + + serialized_artifacts = SerializedArtifacts.serialize( + feature_view=feature_view, repo_config=self.repo_config + ) + + spark_df = offline_job.to_spark_df() + if self.repo_config.batch_engine.partitions != 0: + spark_df = spark_df.repartition( + self.repo_config.batch_engine.partitions + ) + + spark_df.mapInPandas( + lambda x: map_in_pandas(x, serialized_artifacts), "status int" + ).count() # dummy action to force evaluation + + return SparkMaterializationJob( + job_id=job_id, status=MaterializationJobStatus.SUCCEEDED + ) + except BaseException as e: + return SparkMaterializationJob( + job_id=job_id, status=MaterializationJobStatus.ERROR, error=e + ) + + def get_historical_features( + self, registry: BaseRegistry, task: HistoricalRetrievalTask + ) -> RetrievalJob: if isinstance(task.entity_df, str): raise NotImplementedError("SQL-based entity_df is not yet supported in DAG") # ✅ 1. Build typed execution context - context = self.get_execution_context(task) + context = self.get_execution_context(registry, task) try: # ✅ 2. 
Construct Feature Builder and run it diff --git a/sdk/python/feast/infra/compute_engines/spark/config.py b/sdk/python/feast/infra/compute_engines/spark/config.py deleted file mode 100644 index 070cf204dce..00000000000 --- a/sdk/python/feast/infra/compute_engines/spark/config.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Dict, Optional - -from pydantic import StrictStr - -from feast.repo_config import FeastConfigBaseModel - - -class SparkComputeConfig(FeastConfigBaseModel): - type: StrictStr = "spark" - """ Spark Compute type selector""" - - spark_conf: Optional[Dict[str, str]] = None - """ Configuration overlay for the spark session """ - # sparksession is not serializable and we dont want to pass it around as an argument - - staging_location: Optional[StrictStr] = None - """ Remote path for batch materialization jobs""" - - region: Optional[StrictStr] = None - """ AWS Region if applicable for s3-based staging locations""" diff --git a/sdk/python/feast/infra/compute_engines/spark/job.py b/sdk/python/feast/infra/compute_engines/spark/job.py index 0f343789d96..6b39da00758 100644 --- a/sdk/python/feast/infra/compute_engines/spark/job.py +++ b/sdk/python/feast/infra/compute_engines/spark/job.py @@ -1,9 +1,14 @@ +from dataclasses import dataclass from typing import List, Optional import pyspark from pyspark.sql import SparkSession from feast import OnDemandFeatureView, RepoConfig +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, +) from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( @@ -54,3 +59,32 @@ def to_spark_df(self) -> pyspark.sql.DataFrame: def to_sql(self) -> str: assert self._plan is not None, "Execution plan is not set" return self._plan.to_sql(self._context) + + +@dataclass +class SparkMaterializationJob(MaterializationJob): + def url(self) -> Optional[str]: + pass + + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + error: Optional[BaseException] = None, + ) -> None: + super().__init__() + self._job_id: str = job_id + self._status: MaterializationJobStatus = status + self._error: Optional[BaseException] = error + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + return False + + def job_id(self) -> str: + return self._job_id diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index 7fe7fdb45e8..8d00f124439 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -13,6 +13,9 @@ from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.compute_engines.spark.utils import map_in_arrow +from feast.infra.compute_engines.utils import ( + create_offline_store_retrieval_job, +) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, _get_entity_schema, @@ -62,24 +65,12 @@ def __init__( self.end_time = end_time def execute(self, context: ExecutionContext) -> DAGValue: - offline_store = context.offline_store - ( - join_key_columns, - feature_name_columns, - timestamp_field, - created_timestamp_column, - ) = context.column_info - - # 📥 Reuse Feast's robust query resolver - 
retrieval_job = offline_store.pull_all_from_table_or_query( - config=context.repo_config, + column_info = context.column_info + retrieval_job = create_offline_store_retrieval_job( data_source=self.source, - join_key_columns=join_key_columns, - feature_name_columns=feature_name_columns, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, - start_date=self.start_time, - end_date=self.end_time, + context=context, + start_time=self.start_time, + end_time=self.end_time, ) spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() @@ -88,8 +79,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: format=DAGFormat.SPARK, metadata={ "source": "feature_view_batch_source", - "timestamp_field": timestamp_field, - "created_timestamp_column": created_timestamp_column, + "timestamp_field": column_info.timestamp_column, + "created_timestamp_column": column_info.created_timestamp_column, "start_date": self.start_time, "end_date": self.end_time, }, @@ -171,7 +162,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: ) # Get timestamp fields from feature view - join_keys, feature_cols, ts_col, created_ts_col = context.column_info + column_info = context.column_info # Rename entity_df event_timestamp_col to match feature_df entity_df = rename_entity_ts_column( @@ -180,10 +171,13 @@ def execute(self, context: ExecutionContext) -> DAGValue: ) # Perform left join on entity df - joined = feature_df.join(entity_df, on=join_keys, how="left") + # TODO: give a config option to use other join types + joined = feature_df.join(entity_df, on=column_info.join_keys, how="left") return DAGValue( - data=joined, format=DAGFormat.SPARK, metadata={"joined_on": join_keys} + data=joined, + format=DAGFormat.SPARK, + metadata={"joined_on": column_info.join_keys}, ) @@ -206,12 +200,14 @@ def execute(self, context: ExecutionContext) -> DAGValue: input_df: DataFrame = input_value.data # Get timestamp fields from feature view - _, _, ts_col, _ = context.column_info + timestamp_column = context.column_info.timestamp_column # Optional filter: feature.ts <= entity.event_timestamp filtered_df = input_df if ENTITY_TS_ALIAS in input_df.columns: - filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS)) + filtered_df = filtered_df.filter( + F.col(timestamp_column) <= F.col(ENTITY_TS_ALIAS) + ) # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl if self.ttl: @@ -219,7 +215,7 @@ def execute(self, context: ExecutionContext) -> DAGValue: lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( f"INTERVAL {ttl_seconds} seconds" ) - filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound) + filtered_df = filtered_df.filter(F.col(timestamp_column) >= lower_bound) # Optional custom filter condition if self.filter_condition: @@ -247,21 +243,23 @@ def execute(self, context: ExecutionContext) -> DAGValue: input_df: DataFrame = input_value.data # Get timestamp fields from feature view - join_keys, _, ts_col, created_ts_col = context.column_info + column_info = context.column_info - # Dedup based on join keys and event timestamp + # Dedup based on join keys and event timestamp column # Dedup with row_number - partition_cols = join_keys + [ENTITY_TS_ALIAS] - ordering = [F.col(ts_col).desc()] - if created_ts_col: - ordering.append(F.col(created_ts_col).desc()) - - window = Window.partitionBy(*partition_cols).orderBy(*ordering) - deduped_df = ( - input_df.withColumn("row_num", F.row_number().over(window)) - .filter("row_num = 1") - .drop("row_num") - ) + partition_cols =
context.column_info.join_keys + deduped_df = input_df + if partition_cols: + ordering = [F.col(column_info.timestamp_column).desc()] + if column_info.created_timestamp_column: + ordering.append(F.col(column_info.created_timestamp_column).desc()) + + window = Window.partitionBy(*partition_cols).orderBy(*ordering) + deduped_df = ( + input_df.withColumn("row_num", F.row_number().over(window)) + .filter("row_num = 1") + .drop("row_num") + ) return DAGValue( data=deduped_df, diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 7808ca0118a..4e429f8e075 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -1,11 +1,13 @@ from typing import Dict, Iterable, Literal, Optional +import pandas as pd +import pyarrow import pyarrow as pa from pyspark import SparkConf from pyspark.sql import SparkSession from feast.infra.common.serde import SerializedArtifacts -from feast.utils import _convert_arrow_to_proto +from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping def get_or_create_new_spark_session( @@ -64,3 +66,45 @@ def map_in_arrow( ) yield batch + + +def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts): + for pdf in iterator: + if pdf.shape[0] == 0: + print("Skipping") + return + + table = pyarrow.Table.from_pandas(pdf) + + ( + feature_view, + online_store, + _, + repo_config, + ) = serialized_artifacts.unserialize() + + if feature_view.batch_source.field_mapping is not None: + # Spark offline store does the field mapping in pull_latest_from_table_or_query() call + # This may be needed in future if this materialization engine supports other offline stores + table = _run_pyarrow_field_mapping( + table, feature_view.batch_source.field_mapping + ) + + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + + rows_to_write = _convert_arrow_to_proto( + table, feature_view, join_key_to_value_type + ) + online_store.online_write_batch( + repo_config, + feature_view, + rows_to_write, + lambda x: None, + ) + + yield pd.DataFrame( + [pd.Series(range(1, 2))] + ) # dummy result because mapInPandas needs to return something diff --git a/sdk/python/feast/infra/compute_engines/utils.py b/sdk/python/feast/infra/compute_engines/utils.py new file mode 100644 index 00000000000..09a13a72193 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/utils.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import Optional + +from feast.data_source import DataSource +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.offline_stores.offline_store import RetrievalJob + + +def create_offline_store_retrieval_job( + data_source: DataSource, + context: ExecutionContext, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, +) -> RetrievalJob: + """ + Create a retrieval job for the offline store. + Args: + data_source: The data source to pull from.
+ context: Execution context providing the repo config, offline store, and resolved column info. + start_time: Optional start of the time range to pull from the data source. + end_time: Optional end of the time range to pull from the data source. + + Returns: + A RetrievalJob over the given data source and time range. + """ + offline_store = context.offline_store + column_info = context.column_info + # 📥 Reuse Feast's robust query resolver + retrieval_job = offline_store.pull_all_from_table_or_query( + config=context.repo_config, + data_source=data_source, + join_key_columns=column_info.join_keys, + feature_name_columns=column_info.feature_cols, + timestamp_field=column_info.ts_col, + created_timestamp_column=column_info.created_ts_col, + start_date=start_time, + end_date=end_time, + ) + return retrieval_job diff --git a/sdk/python/feast/infra/materialization/batch_materialization_engine.py b/sdk/python/feast/infra/materialization/batch_materialization_engine.py deleted file mode 100644 index 17bc6134cdb..00000000000 --- a/sdk/python/feast/infra/materialization/batch_materialization_engine.py +++ /dev/null @@ -1,94 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Sequence, Union - -from feast.batch_feature_view import BatchFeatureView -from feast.entity import Entity -from feast.feature_view import FeatureView -from feast.infra.common.materialization_job import ( - MaterializationJob, - MaterializationTask, ) -from feast.infra.offline_stores.offline_store import OfflineStore -from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.registry.base_registry import BaseRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.repo_config import RepoConfig -from feast.stream_feature_view import StreamFeatureView - - -class BatchMaterializationEngine(ABC): - """ - The interface that Feast uses to control the compute system that handles batch materialization. - """ - - def __init__( - self, - *, - repo_config: RepoConfig, - offline_store: OfflineStore, - online_store: OnlineStore, - **kwargs, - ): - self.repo_config = repo_config - self.offline_store = offline_store - self.online_store = online_store - - @abstractmethod - def update( - self, - project: str, - views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView] - ], - views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - ): - """ - Prepares cloud resources required for batch materialization for the specified set of Feast objects. - - Args: - project: Feast project to which the objects belong. - views_to_delete: Feature views whose corresponding infrastructure should be deleted. - views_to_keep: Feature views whose corresponding infrastructure should not be deleted, and - may need to be updated. - entities_to_delete: Entities whose corresponding infrastructure should be deleted. - entities_to_keep: Entities whose corresponding infrastructure should not be deleted, and - may need to be updated. - """ - pass - - @abstractmethod - def materialize( - self, registry: BaseRegistry, tasks: List[MaterializationTask] - ) -> List[MaterializationJob]: - """ - Materialize data from the offline store to the online store for this feature repo. - - Args: - registry: The registry for the current feature store. - tasks: A list of individual materialization tasks. - - Returns: - A list of materialization jobs representing each task.
- """ - pass - - @abstractmethod - def teardown_infra( - self, - project: str, - fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], - entities: Sequence[Entity], - ): - """ - Tears down all cloud resources used by the materialization engine for the specified set of Feast objects. - - Args: - project: Feast project to which the objects belong. - fvs: Feature views whose corresponding infrastructure should be deleted. - entities: Entities whose corresponding infrastructure should be deleted. - """ - pass diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py deleted file mode 100644 index c4809df3678..00000000000 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ /dev/null @@ -1,234 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Callable, List, Literal, Optional, Sequence, Union, cast - -import pandas as pd -import pyarrow -from tqdm import tqdm - -from feast.batch_feature_view import BatchFeatureView -from feast.entity import Entity -from feast.feature_view import FeatureView -from feast.infra.common.materialization_job import ( - MaterializationJob, - MaterializationJobStatus, - MaterializationTask, -) -from feast.infra.common.serde import SerializedArtifacts -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, -) -from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( - SparkOfflineStore, - SparkRetrievalJob, -) -from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.registry.base_registry import BaseRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.stream_feature_view import StreamFeatureView -from feast.utils import ( - _convert_arrow_to_proto, - _get_column_names, - _run_pyarrow_field_mapping, -) - - -class SparkMaterializationEngineConfig(FeastConfigBaseModel): - """Batch Materialization Engine config for spark engine""" - - type: Literal["spark.engine"] = "spark.engine" - """ Type selector""" - - partitions: int = 0 - """Number of partitions to use when writing data to online store. If 0, no repartitioning is done""" - - -@dataclass -class SparkMaterializationJob(MaterializationJob): - def __init__( - self, - job_id: str, - status: MaterializationJobStatus, - error: Optional[BaseException] = None, - ) -> None: - super().__init__() - self._job_id: str = job_id - self._status: MaterializationJobStatus = status - self._error: Optional[BaseException] = error - - def status(self) -> MaterializationJobStatus: - return self._status - - def error(self) -> Optional[BaseException]: - return self._error - - def should_be_retried(self) -> bool: - return False - - def job_id(self) -> str: - return self._job_id - - def url(self) -> Optional[str]: - return None - - -class SparkMaterializationEngine(BatchMaterializationEngine): - def update( - self, - project: str, - views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - ): - # Nothing to set up. 
- pass - - def teardown_infra( - self, - project: str, - fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], - entities: Sequence[Entity], - ): - # Nothing to tear down. - pass - - def __init__( - self, - *, - repo_config: RepoConfig, - offline_store: SparkOfflineStore, - online_store: OnlineStore, - **kwargs, - ): - if not isinstance(offline_store, SparkOfflineStore): - raise TypeError( - "SparkMaterializationEngine is only compatible with the SparkOfflineStore" - ) - super().__init__( - repo_config=repo_config, - offline_store=offline_store, - online_store=online_store, - **kwargs, - ) - - def materialize( - self, registry, tasks: List[MaterializationTask] - ) -> List[MaterializationJob]: - return [ - self._materialize_one( - registry, - task.feature_view, - task.start_time, - task.end_time, - task.project, - task.tqdm_builder, - ) - for task in tasks - ] - - def _materialize_one( - self, - registry: BaseRegistry, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], - start_date: datetime, - end_date: datetime, - project: str, - tqdm_builder: Callable[[int], tqdm], - ): - entities = [] - for entity_name in feature_view.entities: - entities.append(registry.get_entity(entity_name, project)) - - ( - join_key_columns, - feature_name_columns, - timestamp_field, - created_timestamp_column, - ) = _get_column_names(feature_view, entities) - - job_id = f"{feature_view.name}-{start_date}-{end_date}" - - try: - offline_job = cast( - SparkRetrievalJob, - self.offline_store.pull_latest_from_table_or_query( - config=self.repo_config, - data_source=feature_view.batch_source, - join_key_columns=join_key_columns, - feature_name_columns=feature_name_columns, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, - start_date=start_date, - end_date=end_date, - ), - ) - - serialized_artifacts = SerializedArtifacts.serialize( - feature_view=feature_view, repo_config=self.repo_config - ) - - spark_df = offline_job.to_spark_df() - if self.repo_config.batch_engine.partitions != 0: - spark_df = spark_df.repartition( - self.repo_config.batch_engine.partitions - ) - - spark_df.mapInPandas( - lambda x: _map_by_partition(x, serialized_artifacts), "status int" - ).count() # dummy action to force evaluation - - return SparkMaterializationJob( - job_id=job_id, status=MaterializationJobStatus.SUCCEEDED - ) - except BaseException as e: - return SparkMaterializationJob( - job_id=job_id, status=MaterializationJobStatus.ERROR, error=e - ) - - -def _map_by_partition(iterator, serialized_artifacts: SerializedArtifacts): - for pdf in iterator: - if pdf.shape[0] == 0: - print("Skipping") - return - - table = pyarrow.Table.from_pandas(pdf) - - ( - feature_view, - online_store, - _, - repo_config, - ) = serialized_artifacts.unserialize() - - if feature_view.batch_source.field_mapping is not None: - # Spark offline store does the field mapping in pull_latest_from_table_or_query() call - # This may be needed in future if this materialization engine supports other offline stores - table = _run_pyarrow_field_mapping( - table, feature_view.batch_source.field_mapping - ) - - join_key_to_value_type = { - entity.name: entity.dtype.to_value_type() - for entity in feature_view.entity_columns - } - - rows_to_write = _convert_arrow_to_proto( - table, feature_view, join_key_to_value_type - ) - online_store.online_write_batch( - repo_config, - feature_view, - rows_to_write, - lambda x: None, - ) - - yield pd.DataFrame( - [pd.Series(range(1, 2))] - ) # dummy result because 
mapInPandas needs to return something diff --git a/sdk/python/feast/infra/materialization/kubernetes/__init__.py b/sdk/python/feast/infra/materialization/kubernetes/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/sdk/python/feast/infra/materialization/local_engine.py b/sdk/python/feast/infra/materialization/local_engine.py deleted file mode 100644 index ed71d11586d..00000000000 --- a/sdk/python/feast/infra/materialization/local_engine.py +++ /dev/null @@ -1,188 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Callable, List, Literal, Optional, Sequence, Union - -from tqdm import tqdm - -from feast.batch_feature_view import BatchFeatureView -from feast.entity import Entity -from feast.feature_view import FeatureView -from feast.infra.common.materialization_job import ( - MaterializationJob, - MaterializationJobStatus, - MaterializationTask, -) -from feast.infra.offline_stores.offline_store import OfflineStore -from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.registry.base_registry import BaseRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.stream_feature_view import StreamFeatureView -from feast.utils import ( - _convert_arrow_to_proto, - _get_column_names, - _run_pyarrow_field_mapping, -) - -from .batch_materialization_engine import ( - BatchMaterializationEngine, -) - -DEFAULT_BATCH_SIZE = 10_000 - - -class LocalMaterializationEngineConfig(FeastConfigBaseModel): - """Batch Materialization Engine config for local in-process engine""" - - type: Literal["local"] = "local" - """ Type selector""" - - -@dataclass -class LocalMaterializationJob(MaterializationJob): - def __init__( - self, - job_id: str, - status: MaterializationJobStatus, - error: Optional[BaseException] = None, - ) -> None: - super().__init__() - self._job_id: str = job_id - self._status: MaterializationJobStatus = status - self._error: Optional[BaseException] = error - - def status(self) -> MaterializationJobStatus: - return self._status - - def error(self) -> Optional[BaseException]: - return self._error - - def should_be_retried(self) -> bool: - return False - - def job_id(self) -> str: - return self._job_id - - def url(self) -> Optional[str]: - return None - - -class LocalMaterializationEngine(BatchMaterializationEngine): - def update( - self, - project: str, - views_to_delete: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - views_to_keep: Sequence[ - Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView] - ], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - ): - # Nothing to set up. - pass - - def teardown_infra( - self, - project: str, - fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], - entities: Sequence[Entity], - ): - # Nothing to tear down. 
- pass - - def __init__( - self, - *, - repo_config: RepoConfig, - offline_store: OfflineStore, - online_store: OnlineStore, - **kwargs, - ): - super().__init__( - repo_config=repo_config, - offline_store=offline_store, - online_store=online_store, - **kwargs, - ) - - def materialize( - self, registry, tasks: List[MaterializationTask] - ) -> List[MaterializationJob]: - return [ - self._materialize_one( - registry, - task.feature_view, - task.start_time, - task.end_time, - task.project, - task.tqdm_builder, - ) - for task in tasks - ] - - def _materialize_one( - self, - registry: BaseRegistry, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], - start_date: datetime, - end_date: datetime, - project: str, - tqdm_builder: Callable[[int], tqdm], - ): - entities = [] - for entity_name in feature_view.entities: - entities.append(registry.get_entity(entity_name, project)) - - ( - join_key_columns, - feature_name_columns, - timestamp_field, - created_timestamp_column, - ) = _get_column_names(feature_view, entities) - - job_id = f"{feature_view.name}-{start_date}-{end_date}" - - try: - offline_job = self.offline_store.pull_latest_from_table_or_query( - config=self.repo_config, - data_source=feature_view.batch_source, - join_key_columns=join_key_columns, - feature_name_columns=feature_name_columns, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, - start_date=start_date, - end_date=end_date, - ) - - table = offline_job.to_arrow() - - if feature_view.batch_source.field_mapping is not None: - table = _run_pyarrow_field_mapping( - table, feature_view.batch_source.field_mapping - ) - - join_key_to_value_type = { - entity.name: entity.dtype.to_value_type() - for entity in feature_view.entity_columns - } - - with tqdm_builder(table.num_rows) as pbar: - for batch in table.to_batches(DEFAULT_BATCH_SIZE): - rows_to_write = _convert_arrow_to_proto( - batch, feature_view, join_key_to_value_type - ) - self.online_store.online_write_batch( - self.repo_config, - feature_view, - rows_to_write, - lambda x: pbar.update(x), - ) - return LocalMaterializationJob( - job_id=job_id, status=MaterializationJobStatus.SUCCEEDED - ) - except BaseException as e: - return LocalMaterializationJob( - job_id=job_id, status=MaterializationJobStatus.ERROR, error=e - ) diff --git a/sdk/python/feast/infra/offline_stores/dask.py b/sdk/python/feast/infra/offline_stores/dask.py index 87af51337dd..72359c4b793 100644 --- a/sdk/python/feast/infra/offline_stores/dask.py +++ b/sdk/python/feast/infra/offline_stores/dask.py @@ -320,81 +320,101 @@ def pull_latest_from_table_or_query( assert isinstance(config.offline_store, DaskOfflineStoreConfig) assert isinstance(data_source, FileSource) - # Create lazy function that is only called from the RetrievalJob object - def evaluate_offline_job(): - source_df = _read_datasource(data_source, config.repo_path) - - source_df = _normalize_timestamp( - source_df, timestamp_field, created_timestamp_column + def evaluate_func(): + df = DaskOfflineStore.evaluate_offline_job( + config=config, + data_source=data_source, + join_key_columns=join_key_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, ) - - source_columns = set(source_df.columns) - if not set(join_key_columns).issubset(source_columns): - raise FeastJoinKeysDuringMaterialization( - data_source.path, set(join_key_columns), source_columns - ) - ts_columns = ( [timestamp_field, created_timestamp_column] 
if created_timestamp_column else [timestamp_field] ) - # try-catch block is added to deal with this issue https://github.com/dask/dask/issues/8939. - # TODO(kevjumba): remove try catch when fix is merged upstream in Dask. - try: - if created_timestamp_column: - source_df = source_df.sort_values( - by=created_timestamp_column, - ) - - source_df = source_df.sort_values(by=timestamp_field) - - except ZeroDivisionError: - # Use 1 partition to get around case where everything in timestamp column is the same so the partition algorithm doesn't - # try to divide by zero. - if created_timestamp_column: - source_df = source_df.sort_values( - by=created_timestamp_column, npartitions=1 - ) - - source_df = source_df.sort_values(by=timestamp_field, npartitions=1) - - # TODO: The old implementation is inclusive of start_date and exclusive of end_date. - # Which is inconsistent with other offline stores. - if start_date or end_date: - if start_date and end_date: - source_df = source_df[ - source_df[timestamp_field].between( - start_date, end_date, inclusive="both" - ) - ] - elif start_date: - source_df = source_df[source_df[timestamp_field] >= start_date] - elif end_date: - source_df = source_df[source_df[timestamp_field] <= end_date] - - source_df = source_df.persist() - columns_to_extract = set( join_key_columns + feature_name_columns + ts_columns ) if join_key_columns: - source_df = source_df.drop_duplicates( + df = df.drop_duplicates( join_key_columns, keep="last", ignore_index=True ) else: - source_df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL columns_to_extract.add(DUMMY_ENTITY_ID) - return source_df[list(columns_to_extract)].persist() + return df[list(columns_to_extract)].persist() # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized return DaskRetrievalJob( - evaluation_function=evaluate_offline_job, + evaluation_function=evaluate_func, full_feature_names=False, repo_path=str(config.repo_path), ) + @staticmethod + def evaluate_offline_job( + config: RepoConfig, + data_source: FileSource, + join_key_columns: List[str], + timestamp_field: str, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ) -> dd.DataFrame: + # Create lazy function that is only called from the RetrievalJob object + source_df = _read_datasource(data_source, config.repo_path) + + source_df = _normalize_timestamp( + source_df, timestamp_field, created_timestamp_column + ) + + source_columns = set(source_df.columns) + if not set(join_key_columns).issubset(source_columns): + raise FeastJoinKeysDuringMaterialization( + data_source.path, set(join_key_columns), source_columns + ) + + # try-catch block is added to deal with this issue https://github.com/dask/dask/issues/8939. + # TODO(kevjumba): remove try catch when fix is merged upstream in Dask. + try: + if created_timestamp_column: + source_df = source_df.sort_values( + by=created_timestamp_column, + ) + + source_df = source_df.sort_values(by=timestamp_field) + + except ZeroDivisionError: + # Use 1 partition to get around case where everything in timestamp column is the same so the partition algorithm doesn't + # try to divide by zero. 
+ if created_timestamp_column: + source_df = source_df.sort_values( + by=created_timestamp_column, npartitions=1 + ) + + source_df = source_df.sort_values(by=timestamp_field, npartitions=1) + + # TODO: The old implementation is inclusive of start_date and exclusive of end_date. + # Which is inconsistent with other offline stores. + if start_date or end_date: + if start_date and end_date: + source_df = source_df[ + source_df[timestamp_field].between( + start_date, end_date, inclusive="both" + ) + ] + elif start_date: + source_df = source_df[source_df[timestamp_field] >= start_date] + elif end_date: + source_df = source_df[source_df[timestamp_field] <= end_date] + + source_df = source_df.persist() + return source_df + @staticmethod def pull_all_from_table_or_query( config: RepoConfig, @@ -409,16 +429,38 @@ def pull_all_from_table_or_query( assert isinstance(config.offline_store, DaskOfflineStoreConfig) assert isinstance(data_source, FileSource) - return DaskOfflineStore.pull_latest_from_table_or_query( - config=config, - data_source=data_source, - join_key_columns=join_key_columns - + [timestamp_field], # avoid deduplication - feature_name_columns=feature_name_columns, - timestamp_field=timestamp_field, - created_timestamp_column=created_timestamp_column, - start_date=start_date, - end_date=end_date, + def evaluate_func(): + df = DaskOfflineStore.evaluate_offline_job( + config=config, + data_source=data_source, + join_key_columns=join_key_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + ts_columns = ( + [timestamp_field, created_timestamp_column] + if created_timestamp_column + else [timestamp_field] + ) + columns_to_extract = set( + join_key_columns + feature_name_columns + ts_columns + ) + if not join_key_columns: + df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL + columns_to_extract.add(DUMMY_ENTITY_ID) + # TODO: Decides if we want to field mapping for pull_latest_from_table_or_query + # This is default for other offline store. + df = df[list(columns_to_extract)] + df.persist() + return df + + # When materializing a single feature view, we don't need full feature names. 
On demand transforms aren't materialized + return DaskRetrievalJob( + evaluation_function=evaluate_func, + full_feature_names=False, + repo_path=str(config.repo_path), ) @staticmethod @@ -642,7 +684,7 @@ def _merge( def _normalize_timestamp( df_to_join: dd.DataFrame, timestamp_field: str, - created_timestamp_column: str, + created_timestamp_column: Optional[str] = None, ) -> dd.DataFrame: df_to_join_types = df_to_join.dtypes timestamp_field_type = df_to_join_types[timestamp_field] diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index b30e695de52..b532ac563d4 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -28,10 +28,10 @@ MaterializationJobStatus, MaterializationTask, ) -from feast.infra.infra_object import Infra, InfraObject -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.compute_engines.base import ( + ComputeEngine, ) +from feast.infra.infra_object import Infra, InfraObject from feast.infra.offline_stores.offline_store import RetrievalJob from feast.infra.offline_stores.offline_utils import get_offline_store_from_config from feast.infra.online_stores.helpers import get_online_store_from_config @@ -64,7 +64,7 @@ def __init__(self, config: RepoConfig): self.repo_config = config self._offline_store = None self._online_store = None - self._batch_engine: Optional[BatchMaterializationEngine] = None + self._batch_engine: Optional[ComputeEngine] = None @property def online_store(self): @@ -89,7 +89,7 @@ def async_supported(self) -> ProviderAsyncMethods: ) @property - def batch_engine(self) -> BatchMaterializationEngine: + def batch_engine(self) -> ComputeEngine: if self._batch_engine: return self._batch_engine else: @@ -439,7 +439,7 @@ def materialize_single_feature_view( end_time=end_date, tqdm_builder=tqdm_builder, ) - jobs = self.batch_engine.materialize(registry, [task]) + jobs = self.batch_engine.materialize(registry, task) assert len(jobs) == 1 if jobs[0].status() == MaterializationJobStatus.ERROR and jobs[0].error(): e = jobs[0].error() diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index f3215ca0e47..9c7e04dfe31 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -358,6 +358,7 @@ def pull_all_from_table_or_query(self, command: dict): data_source=data_source, join_key_columns=command["join_key_columns"], feature_name_columns=command["feature_name_columns"], + created_timestamp_column=command["created_timestamp_column"], timestamp_field=command["timestamp_field"], start_date=utils.make_tzaware( datetime.fromisoformat(command["start_date"]) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 7195eb831bf..8a4e6d1896e 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -44,11 +44,11 @@ } BATCH_ENGINE_CLASS_FOR_TYPE = { - "local": "feast.infra.materialization.local_engine.LocalMaterializationEngine", - "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine", - "lambda": "feast.infra.materialization.aws_lambda.lambda_engine.LambdaMaterializationEngine", - "k8s": "feast.infra.materialization.kubernetes.k8s_materialization_engine.KubernetesMaterializationEngine", - "spark.engine": "feast.infra.materialization.contrib.spark.spark_materialization_engine.SparkMaterializationEngine", + "local": 
"feast.infra.compute_engines.local.compute.LocalComputeEngine", + "snowflake.engine": "feast.infra.compute_engines.snowflake.snowflake_engine.SnowflakeComputeEngine", + "lambda": "feast.infra.compute_engines.aws_lambda.lambda_engine.LambdaComputeEngine", + "k8s": "feast.infra.compute_engines.kubernetes.k8s_engine.KubernetesComputeEngine", + "spark.engine": "feast.infra.compute_engines.spark.compute.SparkComputeEngine", } LEGACY_ONLINE_STORE_CLASS_FOR_TYPE = { diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index 621190643a4..3062953897c 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -168,10 +168,9 @@ def transform_feature(df: DataFrame) -> DataFrame: repo_config=spark_environment.config, offline_store=SparkOfflineStore(), online_store=MagicMock(), - registry=registry, ) - spark_dag_retrieval_job = engine.get_historical_features(task) + spark_dag_retrieval_job = engine.get_historical_features(registry, task) spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() df_out = spark_df.orderBy("driver_id").to_pandas_on_spark() @@ -253,9 +252,11 @@ def tqdm_builder(length): registry=registry, ) - spark_materialize_job = engine.materialize(task) + spark_materialize_jobs = engine.materialize(registry, task) + + assert len(spark_materialize_jobs) == 1 - assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED + assert spark_materialize_jobs[0].status() == MaterializationJobStatus.SUCCEEDED _check_online_features( fs=fs,