From 76bbdf1af2855e08b0c104d66757c67a4e331bf9 Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Mon, 15 Aug 2022 13:40:24 -0700 Subject: [PATCH 1/5] Prevent overwriting existing file for file offline store Signed-off-by: Felix Wang --- sdk/python/feast/errors.py | 5 ++ .../athena_offline_store/tests/data_source.py | 4 +- .../tests/data_source.py | 4 +- .../spark_offline_store/tests/data_source.py | 6 ++- .../trino_offline_store/tests/data_source.py | 4 +- sdk/python/feast/infra/offline_stores/file.py | 14 +++++- .../infra/offline_stores/offline_store.py | 8 ++- .../universal/data_source_creator.py | 8 ++- .../universal/data_sources/bigquery.py | 4 +- .../universal/data_sources/file.py | 16 +++++- .../universal/data_sources/redshift.py | 4 +- .../universal/data_sources/snowflake.py | 4 +- .../integration/offline_store/test_persist.py | 50 +++++++++++++++++++ 13 files changed, 117 insertions(+), 14 deletions(-) create mode 100644 sdk/python/tests/integration/offline_store/test_persist.py diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index f8a288940a5..01e8ecbacab 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -204,6 +204,11 @@ def __init__( ) +class SavedDatasetLocationAlreadyExists(Exception): + def __init__(self, location: str): + super().__init__(f"Saved dataset location {location} already exists.") + + class FeastOfflineStoreInvalidName(Exception): def __init__(self, offline_store_class_name: str): super().__init__( diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index 92e0d6e5f60..2818be36276 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -94,7 +94,9 @@ def create_data_source( data_source=self.offline_store_config.data_source, ) - def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetAthenaStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py index c84fce03dcd..3d0b0087c6a 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py @@ -119,7 +119,9 @@ def create_online_store(self) -> Dict[str, str]: "password": POSTGRES_PASSWORD, } - def create_saved_dataset_destination(self): + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ): # FIXME: ... return None diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py index 95bedd1b409..939e4902327 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py @@ -1,5 +1,5 @@ import uuid -from typing import Dict, List +from typing import Dict, List, Optional import pandas as pd from pyspark import SparkConf @@ -96,7 +96,9 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination(self) -> SavedDatasetSparkStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetSparkStorage: table = f"persisted_{str(uuid.uuid4()).replace('-', '_')}" return SavedDatasetSparkStorage( table=table, query=None, path=None, file_format=None diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py index 67efa6a27f8..83266314683 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py @@ -105,7 +105,9 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination(self) -> SavedDatasetTrinoStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetTrinoStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index ca945c3ff3b..4ca5ce219fe 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -1,3 +1,4 @@ +import os import uuid from datetime import datetime from pathlib import Path @@ -11,13 +12,16 @@ import pytz from pydantic.typing import Literal -from feast import FileSource, OnDemandFeatureView from feast.data_source import DataSource -from feast.errors import FeastJoinKeysDuringMaterialization +from feast.errors import ( + FeastJoinKeysDuringMaterialization, + SavedDatasetLocationAlreadyExists, +) from feast.feature_logging import LoggingConfig, LoggingSource from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView from feast.infra.offline_stores.file_source import ( FileLoggingDestination, + FileSource, SavedDatasetFileStorage, ) from feast.infra.offline_stores.offline_store import ( @@ -30,6 +34,7 @@ get_pyarrow_schema_from_batch_source, ) from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -85,6 +90,11 @@ def _to_arrow_internal(self): def persist(self, storage: SavedDatasetStorage): assert isinstance(storage, SavedDatasetFileStorage) + + # Check if the specified location already exists. + if os.path.exists(storage.file_options.uri): + raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri) + filesystem, path = FileSource.create_filesystem_and_path( storage.file_options.uri, storage.file_options.s3_endpoint_override, diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 741b97e2fd7..95f28e8ade7 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -174,7 +174,13 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]: @abstractmethod def persist(self, storage: SavedDatasetStorage): - """Synchronously executes the underlying query and persists the result in the same offline store.""" + """ + Synchronously executes the underlying query and persists the result in the same offline store + at the specified destination. + + Currently does not prevent overwriting a pre-existing location in the offline store, although + individual implementations may do so. Eventually all implementations should prevent overwriting. + """ pass @property diff --git a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py index b36af0db472..7b8f2869646 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py @@ -49,7 +49,13 @@ def create_offline_store_config(self) -> FeastConfigBaseModel: ... @abstractmethod - def create_saved_dataset_destination(self) -> SavedDatasetStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetStorage: + """ + Creates a saved dataset destination. If data_source is specified, uses the location of that + data source as the destination for the saved dataset. + """ ... def create_logged_features_destination(self) -> LoggingDestination: diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py index 384037eef14..f3e4a8599e9 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py @@ -95,7 +95,9 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination(self) -> SavedDatasetBigQueryStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetBigQueryStorage: table = self.get_prefixed_table_name( f"persisted_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py index 7b8e5e80e67..baada15a217 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py @@ -59,7 +59,17 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination(self) -> SavedDatasetFileStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetFileStorage: + if data_source: + assert isinstance(data_source, FileSource) + return SavedDatasetFileStorage( + path=data_source.path, + file_format=ParquetFormat(), + s3_endpoint_override=None, + ) + d = tempfile.mkdtemp(prefix=self.project_name) self.dirs.append(d) return SavedDatasetFileStorage( @@ -154,7 +164,9 @@ def create_data_source( s3_endpoint_override=f"http://{host}:{port}", ) - def create_saved_dataset_destination(self) -> SavedDatasetFileStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetFileStorage: port = self.minio.get_exposed_port("9000") host = self.minio.get_container_host_ip() diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py index c92a413616b..43d167b3aa8 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py @@ -78,7 +78,9 @@ def create_data_source( database=self.offline_store_config.database, ) - def create_saved_dataset_destination(self) -> SavedDatasetRedshiftStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetRedshiftStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py index b5fc2448d4f..b5acada2011 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py @@ -66,7 +66,9 @@ def create_data_source( warehouse=self.offline_store_config.warehouse, ) - def create_saved_dataset_destination(self) -> SavedDatasetSnowflakeStorage: + def create_saved_dataset_destination( + self, data_source: Optional[DataSource] = None + ) -> SavedDatasetSnowflakeStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/tests/integration/offline_store/test_persist.py b/sdk/python/tests/integration/offline_store/test_persist.py new file mode 100644 index 00000000000..d286f7722e6 --- /dev/null +++ b/sdk/python/tests/integration/offline_store/test_persist.py @@ -0,0 +1,50 @@ +import pytest + +from feast.errors import SavedDatasetLocationAlreadyExists +from tests.integration.feature_repos.repo_configuration import ( + construct_universal_feature_views, +) +from tests.integration.feature_repos.universal.entities import ( + customer, + driver, + location, +) + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores(only=["file"]) +def test_persist_does_not_overwrite(environment, universal_data_sources): + """ + Tests that the persist method does not overwrite an existing location in the offline store. + + This test currently is only run against the file offline store as it is the only implementation + that prevents overwriting. As more offline stores add this check, they should be added to this test. + """ + store = environment.feature_store + entities, datasets, data_sources = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + store.apply([driver(), customer(), location(), *feature_views.values()]) + + features = [ + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ] + + entity_df = datasets.entity_df.drop( + columns=["order_id", "origin_id", "destination_id"] + ) + job = store.get_historical_features( + entity_df=entity_df, + features=features, + ) + + with pytest.raises(SavedDatasetLocationAlreadyExists): + # This should fail since persisting to a preexisting location is not allowed. + store.create_saved_dataset( + from_=job, + name="my_training_dataset", + storage=environment.data_source_creator.create_saved_dataset_destination( + data_source=data_sources.customer + ), + ) From cc726e18b0b5c33b10aef00d4f7190fb6cc9e3fb Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Wed, 17 Aug 2022 16:59:14 -0700 Subject: [PATCH 2/5] Add `from_data_source` method Signed-off-by: Felix Wang --- .../feast/infra/offline_stores/file_source.py | 42 +++++++++++++++---- sdk/python/feast/saved_dataset.py | 28 ++++++++++++- .../integration/offline_store/test_persist.py | 10 +++-- 3 files changed, 68 insertions(+), 12 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py index 52a687a52c9..135409ed04a 100644 --- a/sdk/python/feast/infra/offline_stores/file_source.py +++ b/sdk/python/feast/infra/offline_stores/file_source.py @@ -96,12 +96,20 @@ def __eq__(self, other): ) @property - def path(self): - """ - Returns the path of this file data source. - """ + def path(self) -> str: + """Returns the path of this file data source.""" return self.file_options.uri + @property + def file_format(self) -> Optional[FileFormat]: + """Returns the file format of this file data source.""" + return self.file_options.file_format + + @property + def s3_endpoint_override(self) -> Optional[str]: + """Returns the s3 endpoint override of this file data source.""" + return self.file_options.s3_endpoint_override + @staticmethod def from_proto(data_source: DataSourceProto): return FileSource( @@ -177,24 +185,33 @@ def get_table_query_string(self) -> str: class FileOptions: """ Configuration options for a file data source. + + Attributes: + uri: File source url, e.g. s3:// or local file. + s3_endpoint_override: Custom s3 endpoint (used only with s3 uri). + file_format: File source format, e.g. parquet. """ + uri: str + file_format: Optional[FileFormat] + s3_endpoint_override: str + def __init__( self, + uri: str, file_format: Optional[FileFormat], s3_endpoint_override: Optional[str], - uri: Optional[str], ): """ Initializes a FileOptions object. Args: + uri: File source url, e.g. s3:// or local file. file_format (optional): File source format, e.g. parquet. s3_endpoint_override (optional): Custom s3 endpoint (used only with s3 uri). - uri (optional): File source url, e.g. s3:// or local file. """ + self.uri = uri self.file_format = file_format - self.uri = uri or "" self.s3_endpoint_override = s3_endpoint_override or "" @classmethod @@ -269,6 +286,17 @@ def to_data_source(self) -> DataSource: s3_endpoint_override=self.file_options.s3_endpoint_override, ) + @staticmethod + def from_data_source(data_source: DataSource) -> "SavedDatasetStorage": + assert isinstance(data_source, FileSource) + return SavedDatasetFileStorage( + path=data_source.path, + file_format=data_source.file_format + if data_source.file_format + else ParquetFormat(), + s3_endpoint_override=data_source.s3_endpoint_override, + ) + class FileLoggingDestination(LoggingDestination): _proto_kind = "file_destination" diff --git a/sdk/python/feast/saved_dataset.py b/sdk/python/feast/saved_dataset.py index e2004d15f4c..4a3043a8731 100644 --- a/sdk/python/feast/saved_dataset.py +++ b/sdk/python/feast/saved_dataset.py @@ -8,6 +8,7 @@ from feast.data_source import DataSource from feast.dqm.profilers.profiler import Profile, Profiler +from feast.importer import import_class from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto from feast.protos.feast.core.SavedDataset_pb2 import SavedDatasetMeta, SavedDatasetSpec from feast.protos.feast.core.SavedDataset_pb2 import ( @@ -31,6 +32,16 @@ def __new__(cls, name, bases, dct): return kls +_DATA_SOURCE_TO_SAVED_DATASET_STORAGE = { + "FileSource": "feast.infra.offline_stores.file_source.SavedDatasetFileStorage", +} + + +def get_saved_dataset_storage_class_from_path(saved_dataset_storage_path: str): + module_name, class_name = saved_dataset_storage_path.rsplit(".", 1) + return import_class(module_name, class_name, "SavedDatasetStorage") + + class SavedDatasetStorage(metaclass=_StorageRegistry): _proto_attr_name: str @@ -43,11 +54,24 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> "SavedDatasetStorage" @abstractmethod def to_proto(self) -> SavedDatasetStorageProto: - ... + pass @abstractmethod def to_data_source(self) -> DataSource: - ... + pass + + @staticmethod + def from_data_source(data_source: DataSource) -> "SavedDatasetStorage": + data_source_type = type(data_source).__name__ + if data_source_type in _DATA_SOURCE_TO_SAVED_DATASET_STORAGE: + cls = get_saved_dataset_storage_class_from_path( + _DATA_SOURCE_TO_SAVED_DATASET_STORAGE[data_source_type] + ) + return cls.from_data_source(data_source) + else: + raise ValueError( + f"This method currently does not support {data_source_type}." + ) class SavedDataset: diff --git a/sdk/python/tests/integration/offline_store/test_persist.py b/sdk/python/tests/integration/offline_store/test_persist.py index d286f7722e6..8e6f1829174 100644 --- a/sdk/python/tests/integration/offline_store/test_persist.py +++ b/sdk/python/tests/integration/offline_store/test_persist.py @@ -1,6 +1,7 @@ import pytest from feast.errors import SavedDatasetLocationAlreadyExists +from feast.saved_dataset import SavedDatasetStorage from tests.integration.feature_repos.repo_configuration import ( construct_universal_feature_views, ) @@ -40,11 +41,14 @@ def test_persist_does_not_overwrite(environment, universal_data_sources): ) with pytest.raises(SavedDatasetLocationAlreadyExists): + # Copy data source destination to a saved dataset destination. + saved_dataset_destination = SavedDatasetStorage.from_data_source( + data_sources.customer + ) + # This should fail since persisting to a preexisting location is not allowed. store.create_saved_dataset( from_=job, name="my_training_dataset", - storage=environment.data_source_creator.create_saved_dataset_destination( - data_source=data_sources.customer - ), + storage=saved_dataset_destination, ) From 91f3ab7c0a6af3982d071d1b43da5dc76631bac9 Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Wed, 17 Aug 2022 17:03:06 -0700 Subject: [PATCH 3/5] Remove unnecessary changes Signed-off-by: Felix Wang --- .../athena_offline_store/tests/data_source.py | 2 +- .../postgres_offline_store/tests/data_source.py | 2 +- .../spark_offline_store/tests/data_source.py | 2 +- .../trino_offline_store/tests/data_source.py | 2 +- .../feature_repos/universal/data_source_creator.py | 8 +------- .../universal/data_sources/bigquery.py | 2 +- .../feature_repos/universal/data_sources/file.py | 14 ++------------ .../universal/data_sources/redshift.py | 2 +- .../universal/data_sources/snowflake.py | 4 +--- 9 files changed, 10 insertions(+), 28 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index 2818be36276..200f3b009e4 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -95,7 +95,7 @@ def create_data_source( ) def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ) -> SavedDatasetAthenaStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py index 3d0b0087c6a..18be61d7930 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py @@ -120,7 +120,7 @@ def create_online_store(self) -> Dict[str, str]: } def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ): # FIXME: ... return None diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py index 939e4902327..7c0a6af7536 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py @@ -97,7 +97,7 @@ def create_data_source( ) def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ) -> SavedDatasetSparkStorage: table = f"persisted_{str(uuid.uuid4()).replace('-', '_')}" return SavedDatasetSparkStorage( diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py index 83266314683..7aba5730b73 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py @@ -106,7 +106,7 @@ def create_data_source( ) def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ) -> SavedDatasetTrinoStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" diff --git a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py index 7b8f2869646..b36af0db472 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_source_creator.py @@ -49,13 +49,7 @@ def create_offline_store_config(self) -> FeastConfigBaseModel: ... @abstractmethod - def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None - ) -> SavedDatasetStorage: - """ - Creates a saved dataset destination. If data_source is specified, uses the location of that - data source as the destination for the saved dataset. - """ + def create_saved_dataset_destination(self) -> SavedDatasetStorage: ... def create_logged_features_destination(self) -> LoggingDestination: diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py index f3e4a8599e9..abd532ea5f9 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py @@ -96,7 +96,7 @@ def create_data_source( ) def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ) -> SavedDatasetBigQueryStorage: table = self.get_prefixed_table_name( f"persisted_{str(uuid.uuid4()).replace('-', '_')}" diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py index baada15a217..6aa4e9a4535 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py @@ -59,17 +59,7 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None - ) -> SavedDatasetFileStorage: - if data_source: - assert isinstance(data_source, FileSource) - return SavedDatasetFileStorage( - path=data_source.path, - file_format=ParquetFormat(), - s3_endpoint_override=None, - ) - + def create_saved_dataset_destination(self) -> SavedDatasetFileStorage: d = tempfile.mkdtemp(prefix=self.project_name) self.dirs.append(d) return SavedDatasetFileStorage( @@ -165,7 +155,7 @@ def create_data_source( ) def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ) -> SavedDatasetFileStorage: port = self.minio.get_exposed_port("9000") host = self.minio.get_container_host_ip() diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py index 43d167b3aa8..5bce2d2b133 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py @@ -79,7 +79,7 @@ def create_data_source( ) def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None + self, ) -> SavedDatasetRedshiftStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py index b5acada2011..b5fc2448d4f 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py @@ -66,9 +66,7 @@ def create_data_source( warehouse=self.offline_store_config.warehouse, ) - def create_saved_dataset_destination( - self, data_source: Optional[DataSource] = None - ) -> SavedDatasetSnowflakeStorage: + def create_saved_dataset_destination(self) -> SavedDatasetSnowflakeStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) From bf256a298cb08a22b37c8319da48996f0c326d79 Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Wed, 17 Aug 2022 17:04:57 -0700 Subject: [PATCH 4/5] Format Signed-off-by: Felix Wang --- .../contrib/athena_offline_store/tests/data_source.py | 4 +--- .../contrib/postgres_offline_store/tests/data_source.py | 4 +--- .../contrib/spark_offline_store/tests/data_source.py | 6 ++---- .../contrib/trino_offline_store/tests/data_source.py | 4 +--- .../feature_repos/universal/data_sources/bigquery.py | 4 +--- .../feature_repos/universal/data_sources/file.py | 4 +--- .../feature_repos/universal/data_sources/redshift.py | 4 +--- 7 files changed, 8 insertions(+), 22 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py index 200f3b009e4..92e0d6e5f60 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py @@ -94,9 +94,7 @@ def create_data_source( data_source=self.offline_store_config.data_source, ) - def create_saved_dataset_destination( - self, - ) -> SavedDatasetAthenaStorage: + def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py index 18be61d7930..c84fce03dcd 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py @@ -119,9 +119,7 @@ def create_online_store(self) -> Dict[str, str]: "password": POSTGRES_PASSWORD, } - def create_saved_dataset_destination( - self, - ): + def create_saved_dataset_destination(self): # FIXME: ... return None diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py index 7c0a6af7536..95bedd1b409 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py @@ -1,5 +1,5 @@ import uuid -from typing import Dict, List, Optional +from typing import Dict, List import pandas as pd from pyspark import SparkConf @@ -96,9 +96,7 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination( - self, - ) -> SavedDatasetSparkStorage: + def create_saved_dataset_destination(self) -> SavedDatasetSparkStorage: table = f"persisted_{str(uuid.uuid4()).replace('-', '_')}" return SavedDatasetSparkStorage( table=table, query=None, path=None, file_format=None diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py index 7aba5730b73..67efa6a27f8 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py @@ -105,9 +105,7 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination( - self, - ) -> SavedDatasetTrinoStorage: + def create_saved_dataset_destination(self) -> SavedDatasetTrinoStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py index abd532ea5f9..384037eef14 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py @@ -95,9 +95,7 @@ def create_data_source( field_mapping=field_mapping or {"ts_1": "ts"}, ) - def create_saved_dataset_destination( - self, - ) -> SavedDatasetBigQueryStorage: + def create_saved_dataset_destination(self) -> SavedDatasetBigQueryStorage: table = self.get_prefixed_table_name( f"persisted_{str(uuid.uuid4()).replace('-', '_')}" ) diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py index 6aa4e9a4535..7b8e5e80e67 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py @@ -154,9 +154,7 @@ def create_data_source( s3_endpoint_override=f"http://{host}:{port}", ) - def create_saved_dataset_destination( - self, - ) -> SavedDatasetFileStorage: + def create_saved_dataset_destination(self) -> SavedDatasetFileStorage: port = self.minio.get_exposed_port("9000") host = self.minio.get_container_host_ip() diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py index 5bce2d2b133..c92a413616b 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py @@ -78,9 +78,7 @@ def create_data_source( database=self.offline_store_config.database, ) - def create_saved_dataset_destination( - self, - ) -> SavedDatasetRedshiftStorage: + def create_saved_dataset_destination(self) -> SavedDatasetRedshiftStorage: table = self.get_prefixed_table_name( f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}" ) From 8a1a6f0c649f14e799a1eaf79257f1f737e321a5 Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Thu, 18 Aug 2022 09:41:10 -0700 Subject: [PATCH 5/5] Make overwriting optional Signed-off-by: Felix Wang --- sdk/python/feast/feature_store.py | 11 ++++++++++- sdk/python/feast/infra/offline_stores/bigquery.py | 2 +- .../contrib/athena_offline_store/athena.py | 2 +- .../contrib/postgres_offline_store/postgres.py | 2 +- .../contrib/spark_offline_store/spark.py | 2 +- .../contrib/trino_offline_store/trino.py | 2 +- sdk/python/feast/infra/offline_stores/file.py | 4 ++-- .../feast/infra/offline_stores/offline_store.py | 8 +++++--- sdk/python/feast/infra/offline_stores/redshift.py | 2 +- sdk/python/feast/infra/offline_stores/snowflake.py | 2 +- sdk/python/tests/integration/e2e/test_validation.py | 4 ++++ .../test_universal_historical_retrieval.py | 1 + 12 files changed, 29 insertions(+), 13 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index ea13c3a8db3..02225a7b52e 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1146,6 +1146,7 @@ def create_saved_dataset( storage: SavedDatasetStorage, tags: Optional[Dict[str, str]] = None, feature_service: Optional[FeatureService] = None, + allow_overwrite: bool = False, ) -> SavedDataset: """ Execute provided retrieval job and persist its outcome in given storage. @@ -1154,6 +1155,14 @@ def create_saved_dataset( Name for the saved dataset should be unique within project, since it's possible to overwrite previously stored dataset with the same name. + Args: + from_: The retrieval job whose result should be persisted. + name: The name of the saved dataset. + storage: The saved dataset storage object indicating where the result should be persisted. + tags (optional): A dictionary of key-value pairs to store arbitrary metadata. + feature_service (optional): The feature service that should be associated with this saved dataset. + allow_overwrite (optional): If True, the persisted result can overwrite an existing table or file. + Returns: SavedDataset object with attached RetrievalJob @@ -1186,7 +1195,7 @@ def create_saved_dataset( dataset.min_event_timestamp = from_.metadata.min_event_timestamp dataset.max_event_timestamp = from_.metadata.max_event_timestamp - from_.persist(storage) + from_.persist(storage=storage, allow_overwrite=allow_overwrite) dataset = dataset.with_retrieval_job( self._get_provider().retrieve_saved_dataset( diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index da19bff5eca..5c3535071a8 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -493,7 +493,7 @@ def _execute_query( block_until_done(client=self.client, bq_job=bq_job, timeout=timeout) return bq_job - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetBigQueryStorage) self.to_bigquery( diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py index d7f11fb39f8..92e133d02ea 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -402,7 +402,7 @@ def _to_arrow_internal(self) -> pa.Table: def metadata(self) -> Optional[RetrievalMetadata]: return self._metadata - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetAthenaStorage) self.to_athena(table_name=storage.athena_options.table) diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py index 1347f8b37c0..80b1e089a19 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py @@ -297,7 +297,7 @@ def _to_arrow_internal(self) -> pa.Table: def metadata(self) -> Optional[RetrievalMetadata]: return self._metadata - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetPostgreSQLStorage) df_to_postgres_table( diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 26d414232f2..7c19b1e4e3c 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -275,7 +275,7 @@ def _to_arrow_internal(self) -> pyarrow.Table: self.to_spark_df().write.parquet(temp_dir, mode="overwrite") return pq.read_table(temp_dir) - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): """ Run the retrieval and persist the results in the same offline store used for read. Please note the persisting is done only within the scope of the spark session. diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py index b5f0b1f9504..5a3a9737d30 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py @@ -126,7 +126,7 @@ def to_trino( self._client.execute_query(query_text=query) return destination_table - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): """ Run the retrieval and persist the results in the same offline store used for read. """ diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 4ca5ce219fe..742366d42ee 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -88,11 +88,11 @@ def _to_arrow_internal(self): df = self.evaluation_function().compute() return pyarrow.Table.from_pandas(df) - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetFileStorage) # Check if the specified location already exists. - if os.path.exists(storage.file_options.uri): + if not allow_overwrite and os.path.exists(storage.file_options.uri): raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri) filesystem, path = FileSource.create_filesystem_and_path( diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 95f28e8ade7..b3b17eaed3b 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -173,13 +173,15 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]: pass @abstractmethod - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): """ Synchronously executes the underlying query and persists the result in the same offline store at the specified destination. - Currently does not prevent overwriting a pre-existing location in the offline store, although - individual implementations may do so. Eventually all implementations should prevent overwriting. + Args: + storage: The saved dataset storage object specifying where the result should be persisted. + allow_overwrite: If True, a pre-existing location (e.g. table or file) can be overwritten. + Currently not all individual offline store implementations make use of this parameter. """ pass diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 2acf06017da..1c20ff0c5a9 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -483,7 +483,7 @@ def to_redshift(self, table_name: str) -> None: query, ) - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetRedshiftStorage) self.to_redshift(table_name=storage.redshift_options.table) diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 98db97b1799..8239aec34c2 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -460,7 +460,7 @@ def to_arrow_chunks(self, arrow_options: Optional[Dict] = None) -> Optional[List return arrow_batches - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetSnowflakeStorage) self.to_snowflake(table_name=storage.snowflake_options.table) diff --git a/sdk/python/tests/integration/e2e/test_validation.py b/sdk/python/tests/integration/e2e/test_validation.py index 26b46d96483..771061b2069 100644 --- a/sdk/python/tests/integration/e2e/test_validation.py +++ b/sdk/python/tests/integration/e2e/test_validation.py @@ -65,6 +65,7 @@ def test_historical_retrieval_with_validation(environment, universal_data_source from_=reference_job, name="my_training_dataset", storage=environment.data_source_creator.create_saved_dataset_destination(), + allow_overwrite=True, ) saved_dataset = store.get_saved_dataset("my_training_dataset") @@ -95,6 +96,7 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so from_=reference_job, name="my_other_dataset", storage=environment.data_source_creator.create_saved_dataset_destination(), + allow_overwrite=True, ) job = store.get_historical_features( @@ -172,6 +174,7 @@ def test_logged_features_validation(environment, universal_data_sources): ), name="reference_for_validating_logged_features", storage=environment.data_source_creator.create_saved_dataset_destination(), + allow_overwrite=True, ) log_source_df = store.get_historical_features( @@ -245,6 +248,7 @@ def test_e2e_validation_via_cli(environment, universal_data_sources): from_=retrieval_job, name="reference_for_validating_logged_features", storage=environment.data_source_creator.create_saved_dataset_destination(), + allow_overwrite=True, ) reference = saved_dataset.as_reference( name="test_reference", profiler=configurable_profiler diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index cd61822e1c0..73c5152d477 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -381,6 +381,7 @@ def test_historical_features_persisting( name="saved_dataset", storage=environment.data_source_creator.create_saved_dataset_destination(), tags={"env": "test"}, + allow_overwrite=True, ) event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL