diff --git a/protos/feast/core/FeatureView.proto b/protos/feast/core/FeatureView.proto index 3e9aa17256f..481ae00403f 100644 --- a/protos/feast/core/FeatureView.proto +++ b/protos/feast/core/FeatureView.proto @@ -74,7 +74,11 @@ message FeatureViewSpec { DataSource stream_source = 9; // Whether these features should be served online or not + // This is also used to determine whether the features should be written to the online store bool online = 8; + + // Whether these features should be written to the offline store + bool offline = 13; } message FeatureViewMeta { diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 5259d5d2b90..2c2106f5a3e 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -236,6 +236,7 @@ def __copy__(self): schema=self.schema, tags=self.tags, online=self.online, + offline=self.offline, ) # This is deliberately set outside of the FV initialization as we do not have the Entity objects. @@ -258,6 +259,7 @@ def __eq__(self, other): sorted(self.entities) != sorted(other.entities) or self.ttl != other.ttl or self.online != other.online + or self.offline != other.offline or self.batch_source != other.batch_source or self.stream_source != other.stream_source or sorted(self.entity_columns) != sorted(other.entity_columns) @@ -363,6 +365,7 @@ def to_proto(self) -> FeatureViewProto: owner=self.owner, ttl=(ttl_duration if ttl_duration is not None else None), online=self.online, + offline=self.offline, batch_source=batch_source_proto, stream_source=stream_source_proto, ) @@ -412,6 +415,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto): tags=dict(feature_view_proto.spec.tags), owner=feature_view_proto.spec.owner, online=feature_view_proto.spec.online, + offline=feature_view_proto.spec.offline, ttl=( timedelta(days=0) if feature_view_proto.spec.ttl.ToNanoseconds() == 0 diff --git a/sdk/python/feast/infra/common/serde.py b/sdk/python/feast/infra/common/serde.py new file mode 100644 index 00000000000..90e1be9234e --- /dev/null +++ b/sdk/python/feast/infra/common/serde.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass + +import dill + +from feast import FeatureView +from feast.infra.passthrough_provider import PassthroughProvider +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto + + +@dataclass +class SerializedArtifacts: + """Class to assist with serializing unpicklable artifacts to be passed to the compute engine.""" + + feature_view_proto: str + repo_config_byte: str + + @classmethod + def serialize(cls, feature_view, repo_config): + # serialize to proto + feature_view_proto = feature_view.to_proto().SerializeToString() + + # serialize repo_config to disk. 
Will be used to instantiate the online store + repo_config_byte = dill.dumps(repo_config) + + return SerializedArtifacts( + feature_view_proto=feature_view_proto, repo_config_byte=repo_config_byte + ) + + def unserialize(self): + # unserialize + proto = FeatureViewProto() + proto.ParseFromString(self.feature_view_proto) + feature_view = FeatureView.from_proto(proto) + + # load + repo_config = dill.loads(self.repo_config_byte) + + provider = PassthroughProvider(repo_config) + online_store = provider.online_store + offline_store = provider.offline_store + return feature_view, online_store, offline_store, repo_config diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index aee245da21c..4f9dcc871d5 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -91,6 +91,6 @@ def build_validation_node(self, input_node): return node def build_output_nodes(self, input_node): - node = LocalOutputNode("output") + node = LocalOutputNode("output", self.feature_view) node.add_input(input_node) self.nodes.append(node) diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index aea83921351..709b592f97c 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -1,8 +1,9 @@ from datetime import datetime, timedelta -from typing import Optional +from typing import Optional, Union import pyarrow as pa +from feast import BatchFeatureView, StreamFeatureView from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue @@ -11,6 +12,7 @@ from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) +from feast.utils import _convert_arrow_to_proto ENTITY_TS_ALIAS = "__entity_event_timestamp" @@ -207,11 +209,42 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalOutputNode(LocalNode): - def __init__(self, name: str): + def __init__( + self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView] + ): super().__init__(name) + self.feature_view = feature_view def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data context.node_outputs[self.name] = input_table - # TODO: implement the logic to write to offline store + + if self.feature_view.online: + online_store = context.online_store + + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in self.feature_view.entity_columns + } + + rows_to_write = _convert_arrow_to_proto( + input_table, self.feature_view, join_key_to_value_type + ) + + online_store.online_write_batch( + config=context.repo_config, + table=self.feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + + if self.feature_view.offline: + offline_store = context.offline_store + offline_store.offline_write_batch( + config=context.repo_config, + feature_view=self.feature_view, + table=input_table, + progress=lambda x: None, + ) + return input_table diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index 944feccf903..df882bfc2c3 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py 
+++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -5,7 +5,7 @@ from feast.infra.common.materialization_job import MaterializationTask from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.feature_builder import FeatureBuilder -from feast.infra.compute_engines.spark.node import ( +from feast.infra.compute_engines.spark.nodes import ( SparkAggregationNode, SparkDedupNode, SparkFilterNode, @@ -73,7 +73,8 @@ def build_transformation_node(self, input_node): return node def build_output_nodes(self, input_node): - node = SparkWriteNode("output", input_node, self.feature_view) + node = SparkWriteNode("output", self.feature_view) + node.add_input(input_node) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py similarity index 88% rename from sdk/python/feast/infra/compute_engines/spark/node.py rename to sdk/python/feast/infra/compute_engines/spark/nodes.py index 0c1c1476613..7fe7fdb45e8 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -7,18 +7,19 @@ from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation from feast.data_source import DataSource +from feast.infra.common.serde import SerializedArtifacts from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( - _map_by_partition, - _SparkSerializedArtifacts, -) +from feast.infra.compute_engines.spark.utils import map_in_arrow from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkRetrievalJob, _get_entity_schema, ) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) from feast.infra.offline_stores.offline_utils import ( infer_event_timestamp_from_entity_df, ) @@ -273,30 +274,41 @@ class SparkWriteNode(DAGNode): def __init__( self, name: str, - input_node: DAGNode, feature_view: Union[BatchFeatureView, StreamFeatureView], ): super().__init__(name) - self.add_input(input_node) self.feature_view = feature_view def execute(self, context: ExecutionContext) -> DAGValue: spark_df: DataFrame = self.get_single_input_value(context).data - spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( + serialized_artifacts = SerializedArtifacts.serialize( feature_view=self.feature_view, repo_config=context.repo_config ) - # ✅ 1. Write to offline store (if enabled) - if self.feature_view.offline: - # TODO: Update _map_by_partition to be able to write to offline store - pass - - # ✅ 2. Write to online store (if enabled) + # ✅ 1. Write to online store if online enabled if self.feature_view.online: - spark_df.mapInPandas( - lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + spark_df.mapInArrow( + lambda x: map_in_arrow(x, serialized_artifacts, mode="online"), + spark_df.schema, ).count() + # ✅ 2. 
Write to offline store if offline enabled + if self.feature_view.offline: + if not isinstance(self.feature_view.batch_source, SparkSource): + spark_df.mapInArrow( + lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"), + spark_df.schema, + ).count() + # Directly write spark df to spark offline store without using mapInArrow + else: + dest_path = self.feature_view.batch_source.path + file_format = self.feature_view.batch_source.file_format + if not dest_path or not file_format: + raise ValueError( + "Destination path and file format must be specified for SparkSource." + ) + spark_df.write.format(file_format).mode("append").save(dest_path) + return DAGValue( data=spark_df, format=DAGFormat.SPARK, diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 262876f9dbb..7808ca0118a 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -1,8 +1,12 @@ -from typing import Dict, Optional +from typing import Dict, Iterable, Literal, Optional +import pyarrow as pa from pyspark import SparkConf from pyspark.sql import SparkSession +from feast.infra.common.serde import SerializedArtifacts +from feast.utils import _convert_arrow_to_proto + def get_or_create_new_spark_session( spark_config: Optional[Dict[str, str]] = None, @@ -16,4 +20,47 @@ def get_or_create_new_spark_session( ) spark_session = spark_builder.getOrCreate() + spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") return spark_session + + +def map_in_arrow( + iterator: Iterable[pa.RecordBatch], + serialized_artifacts: "SerializedArtifacts", + mode: Literal["online", "offline"] = "online", +): + for batch in iterator: + table = pa.Table.from_batches([batch]) + + ( + feature_view, + online_store, + offline_store, + repo_config, + ) = serialized_artifacts.unserialize() + + if mode == "online": + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + + rows_to_write = _convert_arrow_to_proto( + table, feature_view, join_key_to_value_type + ) + + online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + if mode == "offline": + offline_store.offline_write_batch( + config=repo_config, + feature_view=feature_view, + table=table, + progress=lambda x: None, + ) + + yield batch diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 246297cc1d5..c4809df3678 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import Callable, List, Literal, Optional, Sequence, Union, cast -import dill import pandas as pd import pyarrow from tqdm import tqdm @@ -15,6 +14,7 @@ MaterializationJobStatus, MaterializationTask, ) +from feast.infra.common.serde import SerializedArtifacts from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, ) @@ -23,10 +23,8 @@ SparkRetrievalJob, ) from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.passthrough_provider import PassthroughProvider from feast.infra.registry.base_registry import BaseRegistry from 
feast.on_demand_feature_view import OnDemandFeatureView -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.stream_feature_view import StreamFeatureView from feast.utils import ( @@ -171,7 +169,7 @@ def _materialize_one( ), ) - spark_serialized_artifacts = _SparkSerializedArtifacts.serialize( + serialized_artifacts = SerializedArtifacts.serialize( feature_view=feature_view, repo_config=self.repo_config ) @@ -182,7 +180,7 @@ def _materialize_one( ) spark_df.mapInPandas( - lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int" + lambda x: _map_by_partition(x, serialized_artifacts), "status int" ).count() # dummy action to force evaluation return SparkMaterializationJob( @@ -194,40 +192,7 @@ def _materialize_one( ) -@dataclass -class _SparkSerializedArtifacts: - """Class to assist with serializing unpicklable artifacts to the spark workers""" - - feature_view_proto: str - repo_config_byte: str - - @classmethod - def serialize(cls, feature_view, repo_config): - # serialize to proto - feature_view_proto = feature_view.to_proto().SerializeToString() - - # serialize repo_config to disk. Will be used to instantiate the online store - repo_config_byte = dill.dumps(repo_config) - - return _SparkSerializedArtifacts( - feature_view_proto=feature_view_proto, repo_config_byte=repo_config_byte - ) - - def unserialize(self): - # unserialize - proto = FeatureViewProto() - proto.ParseFromString(self.feature_view_proto) - feature_view = FeatureView.from_proto(proto) - - # load - repo_config = dill.loads(self.repo_config_byte) - - provider = PassthroughProvider(repo_config) - online_store = provider.online_store - return feature_view, online_store, repo_config - - -def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArtifacts): +def _map_by_partition(iterator, serialized_artifacts: SerializedArtifacts): for pdf in iterator: if pdf.shape[0] == 0: print("Skipping") @@ -238,8 +203,9 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti ( feature_view, online_store, + _, repo_config, - ) = spark_serialized_artifacts.unserialize() + ) = serialized_artifacts.unserialize() if feature_view.batch_source.field_mapping is not None: # Spark offline store does the field mapping in pull_latest_from_table_or_query() call diff --git a/sdk/python/feast/protos/feast/core/FeatureView_pb2.py b/sdk/python/feast/protos/feast/core/FeatureView_pb2.py index 80d04c1ec3f..d1456cf9faf 100644 --- a/sdk/python/feast/protos/feast/core/FeatureView_pb2.py +++ b/sdk/python/feast/protos/feast/core/FeatureView_pb2.py @@ -18,7 +18,7 @@ from feast.protos.feast.core import Feature_pb2 as feast_dot_core_dot_Feature__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x61st/core/FeatureView.proto\x12\nfeast.core\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1b\x66\x65\x61st/core/DataSource.proto\x1a\x18\x66\x65\x61st/core/Feature.proto\"c\n\x0b\x46\x65\x61tureView\x12)\n\x04spec\x18\x01 \x01(\x0b\x32\x1b.feast.core.FeatureViewSpec\x12)\n\x04meta\x18\x02 \x01(\x0b\x32\x1b.feast.core.FeatureViewMeta\"\xbd\x03\n\x0f\x46\x65\x61tureViewSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07project\x18\x02 \x01(\t\x12\x10\n\x08\x65ntities\x18\x03 \x03(\t\x12+\n\x08\x66\x65\x61tures\x18\x04 \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x31\n\x0e\x65ntity_columns\x18\x0c 
\x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\x12\x33\n\x04tags\x18\x05 \x03(\x0b\x32%.feast.core.FeatureViewSpec.TagsEntry\x12\r\n\x05owner\x18\x0b \x01(\t\x12&\n\x03ttl\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0c\x62\x61tch_source\x18\x07 \x01(\x0b\x32\x16.feast.core.DataSource\x12-\n\rstream_source\x18\t \x01(\x0b\x32\x16.feast.core.DataSource\x12\x0e\n\x06online\x18\x08 \x01(\x08\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xcc\x01\n\x0f\x46\x65\x61tureViewMeta\x12\x35\n\x11\x63reated_timestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x16last_updated_timestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x46\n\x19materialization_intervals\x18\x03 \x03(\x0b\x32#.feast.core.MaterializationInterval\"w\n\x17MaterializationInterval\x12.\n\nstart_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"@\n\x0f\x46\x65\x61tureViewList\x12-\n\x0c\x66\x65\x61tureviews\x18\x01 \x03(\x0b\x32\x17.feast.core.FeatureViewBU\n\x10\x66\x65\x61st.proto.coreB\x10\x46\x65\x61tureViewProtoZ/github.com/feast-dev/feast/go/protos/feast/coreb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66\x65\x61st/core/FeatureView.proto\x12\nfeast.core\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1b\x66\x65\x61st/core/DataSource.proto\x1a\x18\x66\x65\x61st/core/Feature.proto\"c\n\x0b\x46\x65\x61tureView\x12)\n\x04spec\x18\x01 \x01(\x0b\x32\x1b.feast.core.FeatureViewSpec\x12)\n\x04meta\x18\x02 \x01(\x0b\x32\x1b.feast.core.FeatureViewMeta\"\xce\x03\n\x0f\x46\x65\x61tureViewSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07project\x18\x02 \x01(\t\x12\x10\n\x08\x65ntities\x18\x03 \x03(\t\x12+\n\x08\x66\x65\x61tures\x18\x04 \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x31\n\x0e\x65ntity_columns\x18\x0c \x03(\x0b\x32\x19.feast.core.FeatureSpecV2\x12\x13\n\x0b\x64\x65scription\x18\n \x01(\t\x12\x33\n\x04tags\x18\x05 \x03(\x0b\x32%.feast.core.FeatureViewSpec.TagsEntry\x12\r\n\x05owner\x18\x0b \x01(\t\x12&\n\x03ttl\x18\x06 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0c\x62\x61tch_source\x18\x07 \x01(\x0b\x32\x16.feast.core.DataSource\x12-\n\rstream_source\x18\t \x01(\x0b\x32\x16.feast.core.DataSource\x12\x0e\n\x06online\x18\x08 \x01(\x08\x12\x0f\n\x07offline\x18\r \x01(\x08\x1a+\n\tTagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xcc\x01\n\x0f\x46\x65\x61tureViewMeta\x12\x35\n\x11\x63reated_timestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x16last_updated_timestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x46\n\x19materialization_intervals\x18\x03 \x03(\x0b\x32#.feast.core.MaterializationInterval\"w\n\x17MaterializationInterval\x12.\n\nstart_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"@\n\x0f\x46\x65\x61tureViewList\x12-\n\x0c\x66\x65\x61tureviews\x18\x01 \x03(\x0b\x32\x17.feast.core.FeatureViewBU\n\x10\x66\x65\x61st.proto.coreB\x10\x46\x65\x61tureViewProtoZ/github.com/feast-dev/feast/go/protos/feast/coreb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,13 +31,13 @@ _globals['_FEATUREVIEW']._serialized_start=164 _globals['_FEATUREVIEW']._serialized_end=263 _globals['_FEATUREVIEWSPEC']._serialized_start=266 - 
_globals['_FEATUREVIEWSPEC']._serialized_end=711 - _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_start=668 - _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_end=711 - _globals['_FEATUREVIEWMETA']._serialized_start=714 - _globals['_FEATUREVIEWMETA']._serialized_end=918 - _globals['_MATERIALIZATIONINTERVAL']._serialized_start=920 - _globals['_MATERIALIZATIONINTERVAL']._serialized_end=1039 - _globals['_FEATUREVIEWLIST']._serialized_start=1041 - _globals['_FEATUREVIEWLIST']._serialized_end=1105 + _globals['_FEATUREVIEWSPEC']._serialized_end=728 + _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_start=685 + _globals['_FEATUREVIEWSPEC_TAGSENTRY']._serialized_end=728 + _globals['_FEATUREVIEWMETA']._serialized_start=731 + _globals['_FEATUREVIEWMETA']._serialized_end=935 + _globals['_MATERIALIZATIONINTERVAL']._serialized_start=937 + _globals['_MATERIALIZATIONINTERVAL']._serialized_end=1056 + _globals['_FEATUREVIEWLIST']._serialized_start=1058 + _globals['_FEATUREVIEWLIST']._serialized_end=1122 # @@protoc_insertion_point(module_scope) diff --git a/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi b/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi index 57158fc2c6c..6abeb85e263 100644 --- a/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi +++ b/sdk/python/feast/protos/feast/core/FeatureView_pb2.pyi @@ -90,6 +90,7 @@ class FeatureViewSpec(google.protobuf.message.Message): BATCH_SOURCE_FIELD_NUMBER: builtins.int STREAM_SOURCE_FIELD_NUMBER: builtins.int ONLINE_FIELD_NUMBER: builtins.int + OFFLINE_FIELD_NUMBER: builtins.int name: builtins.str """Name of the feature view. Must be unique. Not updated.""" project: builtins.str @@ -124,7 +125,11 @@ class FeatureViewSpec(google.protobuf.message.Message): def stream_source(self) -> feast.core.DataSource_pb2.DataSource: """Streaming DataSource from where this view can consume "online" feature data.""" online: builtins.bool - """Whether these features should be served online or not""" + """Whether these features should be served online or not + This is also used to determine whether the features should be written to the online store + """ + offline: builtins.bool + """Whether these features should be written to the offline store""" def __init__( self, *, @@ -140,9 +145,10 @@ class FeatureViewSpec(google.protobuf.message.Message): batch_source: feast.core.DataSource_pb2.DataSource | None = ..., stream_source: feast.core.DataSource_pb2.DataSource | None = ..., online: builtins.bool = ..., + offline: builtins.bool = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "stream_source", b"stream_source", "ttl", b"ttl"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "description", b"description", "entities", b"entities", "entity_columns", b"entity_columns", "features", b"features", "name", b"name", "online", b"online", "owner", b"owner", "project", b"project", "stream_source", b"stream_source", "tags", b"tags", "ttl", b"ttl"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["batch_source", b"batch_source", "description", b"description", "entities", b"entities", "entity_columns", b"entity_columns", "features", b"features", "name", b"name", "offline", b"offline", "online", b"online", "owner", b"owner", "project", b"project", "stream_source", b"stream_source", "tags", b"tags", "ttl", b"ttl"]) -> None: ... 
 global___FeatureViewSpec = FeatureViewSpec
diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py
index 15b6e850c65..621190643a4 100644
--- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py
+++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py
@@ -188,6 +188,20 @@ def transform_feature(df: DataFrame) -> DataFrame:

 @pytest.mark.integration
 def test_spark_compute_engine_materialize():
+    """
+    Test the SparkComputeEngine materialize method.
+
+    For the driver_hourly_stats feature view, the expected execution plan is:
+    1. Load the feature data produced by create_feature_dataset.
+    2. Filter by start_time and end_time, i.e. the last two days:
+       for driver_id 1001 only row 0 remains,
+       for driver_id 1002 only row 2 remains.
+    3. Apply the transform_feature function, which multiplies every
+       feature value by 2.
+    4. Write the result to both the online store and the offline store.
+
+    The online and offline reads below verify both write paths.
+    """
     spark_environment = create_spark_environment()
     fs = spark_environment.feature_store
     registry = fs.registry
@@ -212,7 +226,7 @@ def transform_feature(df: DataFrame) -> DataFrame:
             Field(name="driver_id", dtype=Int32),
         ],
         online=True,
-        offline=False,
+        offline=True,
         source=data_source,
     )

@@ -242,9 +256,59 @@ def tqdm_builder(length):
         spark_materialize_job = engine.materialize(task)

         assert spark_materialize_job.status() == MaterializationJobStatus.SUCCEEDED
+
+        _check_online_features(
+            fs=fs,
+            driver_id=1001,
+            feature="driver_hourly_stats:conv_rate",
+            expected_value=1.6,
+            full_feature_names=True,
+        )
+
+        entity_df = create_entity_df()
+
+        _check_offline_features(
+            fs=fs,
+            feature="driver_hourly_stats:conv_rate",
+            entity_df=entity_df,
+        )
     finally:
         spark_environment.teardown()


+def _check_online_features(
+    fs,
+    driver_id,
+    feature,
+    expected_value,
+    full_feature_names: bool = True,
+):
+    online_response = fs.get_online_features(
+        features=[feature],
+        entity_rows=[{"driver_id": driver_id}],
+        full_feature_names=full_feature_names,
+    ).to_dict()
+
+    feature_ref = "__".join(feature.split(":"))
+
+    assert len(online_response["driver_id"]) == 1
+    assert online_response["driver_id"][0] == driver_id
+    assert abs(online_response[feature_ref][0] - expected_value) < 1e-6, (
+        "Transformed result"
+    )
+
+
+def _check_offline_features(
+    fs,
+    feature,
+    entity_df,
+):
+    offline_df = fs.get_historical_features(
+        entity_df=entity_df,
+        features=[feature],
+    ).to_df()
+    assert len(offline_df) == 4
+
+
 if __name__ == "__main__":
     test_spark_compute_engine_get_historical_features()
diff --git a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py
index d5f6085a67c..c486b4148fc 100644
--- a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py
+++ b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py
@@ -194,7 +194,7 @@ def test_local_output_node():
     context = create_context(
         node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
     )
-    node = LocalOutputNode("output")
+    node = LocalOutputNode("output", MagicMock())
     node.add_input(MagicMock())
     node.inputs[0].name = "source"
     result = node.execute(context)
diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
index ae69c0a6fcd..3f681017e89 100644
--- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
+++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py
@@ -8,7 +8,7 @@
 from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.value import DAGValue
-from feast.infra.compute_engines.spark.node import (
+from feast.infra.compute_engines.spark.nodes import (
     SparkAggregationNode,
     SparkDedupNode,
     SparkJoinNode,
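
For context on the new flag this diff introduces, here is a minimal, illustrative sketch (not part of the diff) of a feature view that opts in to both write paths. It assumes the `FeatureView` constructor accepts the new `offline` keyword, as the `__copy__` change above implies; the entity, schema, and file path are made up for the example.

```python
from datetime import timedelta

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Float32, Int32

driver = Entity(name="driver", join_keys=["driver_id"])

# Hypothetical source path; any batch source supported by the offline store works.
driver_stats_source = FileSource(
    path="data/driver_stats.parquet",
    timestamp_field="event_timestamp",
)

driver_hourly_stats = FeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    ttl=timedelta(days=2),
    schema=[
        Field(name="conv_rate", dtype=Float32),
        Field(name="driver_id", dtype=Int32),
    ],
    online=True,   # materialized rows are written to the online store
    offline=True,  # new flag in this change: also write the rows back to the offline store
    source=driver_stats_source,
)
```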
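
And a hedged end-to-end sketch of what the two flags enable, loosely following the integration test above: materialize through the configured compute engine, then read the same feature back from both stores. The repo path is illustrative, and it assumes `store` is a `FeatureStore` whose batch engine is the Spark or local compute engine; `driver` and `driver_hourly_stats` are the objects sketched in the previous example.

```python
from datetime import datetime, timedelta

import pandas as pd

from feast import FeatureStore

# Assumes a feature repo configured with the Spark or local compute engine.
store = FeatureStore(repo_path=".")
store.apply([driver, driver_hourly_stats])

# Materialization runs the compute engine plan and, with online=True / offline=True,
# writes the transformed rows to both stores.
store.materialize(
    start_date=datetime.utcnow() - timedelta(days=2),
    end_date=datetime.utcnow(),
)

# Online read path (served because online=True).
online = store.get_online_features(
    features=["driver_hourly_stats:conv_rate"],
    entity_rows=[{"driver_id": 1001}],
).to_dict()

# Offline read path (available because offline=True wrote the rows back).
entity_df = pd.DataFrame(
    {"driver_id": [1001], "event_timestamp": [datetime.utcnow()]}
)
offline_df = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
).to_df()
```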