diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 6112279a027..f2fa33e53a8 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -706,7 +706,7 @@ def plan(
             >>> fs = FeatureStore(repo_path="project/feature_repo")
             >>> driver = Entity(name="driver_id", description="driver id")
             >>> driver_hourly_stats = FileSource(
-            ...     path="project/feature_repo/data/driver_stats.parquet",
+            ...     path="data/driver_stats.parquet",
             ...     timestamp_field="event_timestamp",
             ...     created_timestamp_column="created",
             ... )
@@ -820,7 +820,7 @@ def apply(
             >>> fs = FeatureStore(repo_path="project/feature_repo")
             >>> driver = Entity(name="driver_id", description="driver id")
             >>> driver_hourly_stats = FileSource(
-            ...     path="project/feature_repo/data/driver_stats.parquet",
+            ...     path="data/driver_stats.parquet",
             ...     timestamp_field="event_timestamp",
             ...     created_timestamp_column="created",
             ... )
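Note (not part of the diff): the two docstring hunks above switch the example FileSource to a path relative to the feature repository, so that with FeatureStore(repo_path="project/feature_repo") the file is looked up under the repo root rather than the process working directory. A minimal sketch of the documented usage, reusing only names that appear in the docstring:

    from feast import Entity, FeatureStore, FileSource

    fs = FeatureStore(repo_path="project/feature_repo")
    driver = Entity(name="driver_id", description="driver id")
    driver_hourly_stats = FileSource(
        # resolved as project/feature_repo/data/driver_stats.parquet at read time
        path="data/driver_stats.parquet",
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
    )
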
diff --git a/sdk/python/feast/infra/offline_stores/dask.py b/sdk/python/feast/infra/offline_stores/dask.py
index 52ad88d2997..d26e8609bae 100644
--- a/sdk/python/feast/infra/offline_stores/dask.py
+++ b/sdk/python/feast/infra/offline_stores/dask.py
@@ -57,6 +57,7 @@ def __init__(
         self,
         evaluation_function: Callable,
         full_feature_names: bool,
+        repo_path: str,
         on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
         metadata: Optional[RetrievalMetadata] = None,
     ):
@@ -67,6 +68,7 @@ def __init__(
         self._full_feature_names = full_feature_names
         self._on_demand_feature_views = on_demand_feature_views or []
         self._metadata = metadata
+        self.repo_path = repo_path

     @property
     def full_feature_names(self) -> bool:
@@ -99,8 +101,13 @@ def persist(
         if not allow_overwrite and os.path.exists(storage.file_options.uri):
             raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)

+        if not Path(storage.file_options.uri).is_absolute():
+            absolute_path = Path(self.repo_path) / storage.file_options.uri
+        else:
+            absolute_path = Path(storage.file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            storage.file_options.uri,
+            str(absolute_path),
             storage.file_options.s3_endpoint_override,
         )
@@ -243,7 +250,9 @@ def evaluate_historical_retrieval():
                     all_join_keys = list(set(all_join_keys + join_keys))

-                    df_to_join = _read_datasource(feature_view.batch_source)
+                    df_to_join = _read_datasource(
+                        feature_view.batch_source, config.repo_path
+                    )

                     df_to_join, timestamp_field = _field_mapping(
                         df_to_join,
@@ -297,6 +306,7 @@ def evaluate_historical_retrieval():
                 min_event_timestamp=entity_df_event_timestamp_range[0],
                 max_event_timestamp=entity_df_event_timestamp_range[1],
             ),
+            repo_path=str(config.repo_path),
         )

         return job
@@ -316,7 +326,7 @@ def pull_latest_from_table_or_query(

         # Create lazy function that is only called from the RetrievalJob object
         def evaluate_offline_job():
-            source_df = _read_datasource(data_source)
+            source_df = _read_datasource(data_source, config.repo_path)

             source_df = _normalize_timestamp(
                 source_df, timestamp_field, created_timestamp_column
@@ -377,6 +387,7 @@ def evaluate_offline_job():
         return DaskRetrievalJob(
             evaluation_function=evaluate_offline_job,
             full_feature_names=False,
+            repo_path=str(config.repo_path),
         )

     @staticmethod
@@ -420,8 +431,13 @@ def write_logged_features(
         # Since this code will be mostly used from Go-created thread, it's better to avoid producing new threads
         data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)

+        if config.repo_path is not None and not Path(destination.path).is_absolute():
+            absolute_path = config.repo_path / destination.path
+        else:
+            absolute_path = Path(destination.path)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            destination.path,
+            str(absolute_path),
             destination.s3_endpoint_override,
         )
@@ -456,8 +472,14 @@ def offline_write_batch(
         )

         file_options = feature_view.batch_source.file_options
+
+        if config.repo_path is not None and not Path(file_options.uri).is_absolute():
+            absolute_path = config.repo_path / file_options.uri
+        else:
+            absolute_path = Path(file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            file_options.uri, file_options.s3_endpoint_override
+            str(absolute_path), file_options.s3_endpoint_override
         )
         prev_table = pyarrow.parquet.read_table(
             path, filesystem=filesystem, memory_map=True
@@ -493,7 +515,7 @@ def _get_entity_df_event_timestamp_range(
     )


-def _read_datasource(data_source) -> dd.DataFrame:
+def _read_datasource(data_source, repo_path) -> dd.DataFrame:
     storage_options = (
         {
             "client_kwargs": {
@@ -504,8 +526,12 @@ def _read_datasource(data_source) -> dd.DataFrame:
         else None
     )

+    if not Path(data_source.path).is_absolute():
+        path = repo_path / data_source.path
+    else:
+        path = data_source.path
     return dd.read_parquet(
-        data_source.path,
+        path,
         storage_options=storage_options,
     )
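Note (not part of the diff): persist, write_logged_features, offline_write_batch and _read_datasource above all apply the same resolution rule. Pulled out as a standalone sketch, with an illustrative helper name that does not exist in the change:

    from pathlib import Path
    from typing import Optional, Union

    def resolve_uri(uri: str, repo_path: Optional[Union[str, Path]]) -> Path:
        # Relative URIs are anchored at the feature repo root; absolute URIs pass through.
        if repo_path is not None and not Path(uri).is_absolute():
            return Path(repo_path) / uri
        return Path(uri)

    assert resolve_uri("data/driver_stats.parquet", "/srv/repo") == Path("/srv/repo/data/driver_stats.parquet")
    assert resolve_uri("/tmp/out.parquet", "/srv/repo") == Path("/tmp/out.parquet")
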
diff --git a/sdk/python/feast/infra/offline_stores/duckdb.py b/sdk/python/feast/infra/offline_stores/duckdb.py
index a639d54add5..e64da029a6a 100644
--- a/sdk/python/feast/infra/offline_stores/duckdb.py
+++ b/sdk/python/feast/infra/offline_stores/duckdb.py
@@ -27,7 +27,7 @@
 from feast.repo_config import FeastConfigBaseModel, RepoConfig


-def _read_data_source(data_source: DataSource) -> Table:
+def _read_data_source(data_source: DataSource, repo_path: str) -> Table:
     assert isinstance(data_source, FileSource)

     if isinstance(data_source.file_format, ParquetFormat):
@@ -43,6 +43,7 @@ def _read_data_source(data_source: DataSource) -> Table:
 def _write_data_source(
     table: Table,
     data_source: DataSource,
+    repo_path: str,
     mode: str = "append",
     allow_overwrite: bool = False,
 ):
@@ -50,14 +51,24 @@ def _write_data_source(
     assert isinstance(data_source, FileSource)

     file_options = data_source.file_options

-    if mode == "overwrite" and not allow_overwrite and os.path.exists(file_options.uri):
+    if not Path(file_options.uri).is_absolute():
+        absolute_path = Path(repo_path) / file_options.uri
+    else:
+        absolute_path = Path(file_options.uri)
+
+    if (
+        mode == "overwrite"
+        and not allow_overwrite
+        and os.path.exists(str(absolute_path))
+    ):
         raise SavedDatasetLocationAlreadyExists(location=file_options.uri)

     if isinstance(data_source.file_format, ParquetFormat):
         if mode == "overwrite":
             table = table.to_pyarrow()
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            file_options.uri,
+            str(absolute_path),
             file_options.s3_endpoint_override,
         )
diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py
index 3fdc6cba31a..9557b8077d0 100644
--- a/sdk/python/feast/infra/offline_stores/file_source.py
+++ b/sdk/python/feast/infra/offline_stores/file_source.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import Callable, Dict, Iterable, List, Optional, Tuple

 import pyarrow
@@ -154,8 +155,16 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
     def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
+        if (
+            config.repo_path is not None
+            and not Path(self.file_options.uri).is_absolute()
+        ):
+            absolute_path = config.repo_path / self.file_options.uri
+        else:
+            absolute_path = Path(self.file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            self.path, self.file_options.s3_endpoint_override
+            str(absolute_path), self.file_options.s3_endpoint_override
         )

         # TODO why None check necessary
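Note (not part of the diff): in both files above, FileSource.create_filesystem_and_path now receives the already-resolved path plus the optional S3 endpoint override. A hedged usage sketch with made-up local paths:

    from pathlib import Path

    from feast.infra.offline_stores.file_source import FileSource

    repo_path = Path("/home/user/project/feature_repo")  # illustrative
    uri = "data/driver_stats.parquet"

    absolute_path = Path(uri) if Path(uri).is_absolute() else repo_path / uri
    filesystem, path = FileSource.create_filesystem_and_path(
        str(absolute_path),
        None,  # s3_endpoint_override
    )
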
diff --git a/sdk/python/feast/infra/offline_stores/ibis.py b/sdk/python/feast/infra/offline_stores/ibis.py
index 61c477baec6..66d00ca6292 100644
--- a/sdk/python/feast/infra/offline_stores/ibis.py
+++ b/sdk/python/feast/infra/offline_stores/ibis.py
@@ -46,8 +46,8 @@ def pull_latest_from_table_or_query_ibis(
     created_timestamp_column: Optional[str],
     start_date: datetime,
     end_date: datetime,
-    data_source_reader: Callable[[DataSource], Table],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_reader: Callable[[DataSource, str], Table],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
     staging_location: Optional[str] = None,
     staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
@@ -57,7 +57,7 @@ def pull_latest_from_table_or_query_ibis(
     start_date = start_date.astimezone(tz=timezone.utc)
     end_date = end_date.astimezone(tz=timezone.utc)

-    table = data_source_reader(data_source)
+    table = data_source_reader(data_source, str(config.repo_path))

     table = table.select(*fields)
@@ -87,6 +87,7 @@ def pull_latest_from_table_or_query_ibis(
         data_source_writer=data_source_writer,
         staging_location=staging_location,
         staging_location_endpoint_override=staging_location_endpoint_override,
+        repo_path=str(config.repo_path),
     )

@@ -147,8 +148,8 @@ def get_historical_features_ibis(
     entity_df: Union[pd.DataFrame, str],
     registry: BaseRegistry,
     project: str,
-    data_source_reader: Callable[[DataSource], Table],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_reader: Callable[[DataSource, str], Table],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
     full_feature_names: bool = False,
     staging_location: Optional[str] = None,
     staging_location_endpoint_override: Optional[str] = None,
@@ -174,7 +175,9 @@ def get_historical_features_ibis(
     def read_fv(
         feature_view: FeatureView, feature_refs: List[str], full_feature_names: bool
     ) -> Tuple:
-        fv_table: Table = data_source_reader(feature_view.batch_source)
+        fv_table: Table = data_source_reader(
+            feature_view.batch_source, str(config.repo_path)
+        )

         for old_name, new_name in feature_view.batch_source.field_mapping.items():
             if old_name in fv_table.columns:
@@ -247,6 +250,7 @@ def read_fv(
         data_source_writer=data_source_writer,
         staging_location=staging_location,
         staging_location_endpoint_override=staging_location_endpoint_override,
+        repo_path=str(config.repo_path),
     )

@@ -258,8 +262,8 @@ def pull_all_from_table_or_query_ibis(
     timestamp_field: str,
     start_date: datetime,
     end_date: datetime,
-    data_source_reader: Callable[[DataSource], Table],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_reader: Callable[[DataSource, str], Table],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
     staging_location: Optional[str] = None,
     staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
@@ -267,7 +271,7 @@ def pull_all_from_table_or_query_ibis(
     start_date = start_date.astimezone(tz=timezone.utc)
     end_date = end_date.astimezone(tz=timezone.utc)

-    table = data_source_reader(data_source)
+    table = data_source_reader(data_source, str(config.repo_path))

     table = table.select(*fields)

@@ -290,6 +294,7 @@ def pull_all_from_table_or_query_ibis(
         data_source_writer=data_source_writer,
         staging_location=staging_location,
         staging_location_endpoint_override=staging_location_endpoint_override,
+        repo_path=str(config.repo_path),
     )

@@ -319,7 +324,7 @@ def offline_write_batch_ibis(
     feature_view: FeatureView,
     table: pyarrow.Table,
     progress: Optional[Callable[[int], Any]],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
 ):
     pa_schema, column_names = get_pyarrow_schema_from_batch_source(
         config, feature_view.batch_source
@@ -330,7 +335,9 @@ def offline_write_batch_ibis(
             f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
         )

-    data_source_writer(ibis.memtable(table), feature_view.batch_source)
+    data_source_writer(
+        ibis.memtable(table), feature_view.batch_source, str(config.repo_path)
+    )


 def deduplicate(
@@ -469,6 +476,7 @@ def __init__(
         data_source_writer,
         staging_location,
         staging_location_endpoint_override,
+        repo_path,
     ) -> None:
         super().__init__()
         self.table = table
@@ -480,6 +488,7 @@ def __init__(
         self.data_source_writer = data_source_writer
         self.staging_location = staging_location
         self.staging_location_endpoint_override = staging_location_endpoint_override
+        self.repo_path = repo_path

     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
         return self.table.execute()
@@ -502,7 +511,11 @@ def persist(
         timeout: Optional[int] = None,
     ):
         self.data_source_writer(
-            self.table, storage.to_data_source(), "overwrite", allow_overwrite
+            self.table,
+            storage.to_data_source(),
+            self.repo_path,
+            "overwrite",
+            allow_overwrite,
         )

     @property
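Note (not part of the diff): the reader and writer callables handed to these ibis helpers now receive the repo path as a trailing string argument, i.e. Callable[[DataSource, str], Table] and Callable[[pyarrow.Table, DataSource, str], None]. A minimal reader conforming to the new shape, using ibis.read_parquet purely for illustration:

    from pathlib import Path

    import ibis
    from ibis.expr.types import Table

    from feast.data_source import DataSource
    from feast.infra.offline_stores.file_source import FileSource

    def read_parquet_source(data_source: DataSource, repo_path: str) -> Table:
        # Mirrors the duckdb reader: relative FileSource paths resolve against repo_path.
        assert isinstance(data_source, FileSource)
        path = Path(data_source.path)
        if not path.is_absolute():
            path = Path(repo_path) / path
        return ibis.read_parquet(path)
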
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index bf0bde6fcbf..845b5505c9f 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -193,6 +193,7 @@ class RepoConfig(FeastBaseModel):
     """ Flags (deprecated field): Feature flags for experimental features """

     repo_path: Optional[Path] = None
+    """ When a relative path is used in a FileSource path, this parameter is mandatory """

     entity_key_serialization_version: StrictInt = 1
     """
     Entity key serialization version: This version is used to control what serialization scheme is
diff --git a/sdk/python/feast/templates/cassandra/bootstrap.py b/sdk/python/feast/templates/cassandra/bootstrap.py
index fa70917914f..33385141145 100644
--- a/sdk/python/feast/templates/cassandra/bootstrap.py
+++ b/sdk/python/feast/templates/cassandra/bootstrap.py
@@ -275,7 +275,9 @@ def bootstrap():

     # example_repo.py
     example_py_file = repo_path / "example_repo.py"
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )

     # store config yaml, interact with user and then customize file:
     settings = collect_cassandra_store_settings()
diff --git a/sdk/python/feast/templates/hazelcast/bootstrap.py b/sdk/python/feast/templates/hazelcast/bootstrap.py
index e5018e4fe02..7a2b49d2493 100644
--- a/sdk/python/feast/templates/hazelcast/bootstrap.py
+++ b/sdk/python/feast/templates/hazelcast/bootstrap.py
@@ -165,7 +165,9 @@ def bootstrap():

     # example_repo.py
     example_py_file = repo_path / "example_repo.py"
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )

     # store config yaml, interact with user and then customize file:
     settings = collect_hazelcast_online_store_settings()
diff --git a/sdk/python/feast/templates/hbase/bootstrap.py b/sdk/python/feast/templates/hbase/bootstrap.py
index 125eb7c2e72..94be8e441da 100644
--- a/sdk/python/feast/templates/hbase/bootstrap.py
+++ b/sdk/python/feast/templates/hbase/bootstrap.py
@@ -23,7 +23,9 @@ def bootstrap():
     driver_df.to_parquet(path=str(driver_stats_path), allow_truncated_timestamps=True)

     example_py_file = repo_path / "example_repo.py"
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )


 if __name__ == "__main__":
diff --git a/sdk/python/feast/templates/local/bootstrap.py b/sdk/python/feast/templates/local/bootstrap.py
index e2c1efdbc49..9f6a5a6c969 100644
--- a/sdk/python/feast/templates/local/bootstrap.py
+++ b/sdk/python/feast/templates/local/bootstrap.py
@@ -25,8 +25,12 @@ def bootstrap():

     example_py_file = repo_path / "example_repo.py"
     replace_str_in_file(example_py_file, "%PROJECT_NAME%", str(project_name))
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
-    replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )
+    replace_str_in_file(
+        example_py_file, "%LOGGING_PATH%", str(data_path.relative_to(repo_path))
+    )


 if __name__ == "__main__":
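Note (not part of the diff): the four template bootstrap scripts above now substitute a repo-relative path for %PARQUET_PATH% (and %LOGGING_PATH%) via Path.relative_to, which works because the data files are created under the repo. A small standalone illustration with made-up paths:

    from pathlib import Path

    repo_path = Path("/tmp/my_project/feature_repo")  # illustrative
    driver_stats_path = repo_path / "data" / "driver_stats.parquet"

    # Value written into example_repo.py after this change:
    print(driver_stats_path.relative_to(repo_path))  # data/driver_stats.parquet
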
diff --git a/sdk/python/tests/doctest/test_all.py b/sdk/python/tests/doctest/test_all.py
index 52348e7da4e..d1b2161252f 100644
--- a/sdk/python/tests/doctest/test_all.py
+++ b/sdk/python/tests/doctest/test_all.py
@@ -26,7 +26,7 @@ def setup_feature_store():
         description="driver id",
     )
     driver_hourly_stats = FileSource(
-        path="project/feature_repo/data/driver_stats.parquet",
+        path="data/driver_stats.parquet",
         timestamp_field="event_timestamp",
         created_timestamp_column="created",
     )
diff --git a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py
index 6d5eeb90c71..afc0e4e5c8f 100644
--- a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py
+++ b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py
@@ -109,7 +109,7 @@ def metadata(self) -> Optional[RetrievalMetadata]:
 )
 def retrieval_job(request, environment):
     if request.param is DaskRetrievalJob:
-        return DaskRetrievalJob(lambda: 1, full_feature_names=False)
+        return DaskRetrievalJob(lambda: 1, full_feature_names=False, repo_path="")
     elif request.param is RedshiftRetrievalJob:
         offline_store_config = RedshiftOfflineStoreConfig(
             cluster_id="feast-int-bucket",
diff --git a/sdk/python/tests/unit/test_offline_server.py b/sdk/python/tests/unit/test_offline_server.py
index 7c38d9bfca4..e82e2fa6872 100644
--- a/sdk/python/tests/unit/test_offline_server.py
+++ b/sdk/python/tests/unit/test_offline_server.py
@@ -95,6 +95,7 @@ def remote_feature_store(offline_server):
             provider="local",
             offline_store=offline_config,
             entity_key_serialization_version=2,
+            # repo_config =
         )
     )
     return store
diff --git a/sdk/python/tests/utils/auth_permissions_util.py b/sdk/python/tests/utils/auth_permissions_util.py
index 3b5e589812a..b8ca7355e98 100644
--- a/sdk/python/tests/utils/auth_permissions_util.py
+++ b/sdk/python/tests/utils/auth_permissions_util.py
@@ -119,6 +119,7 @@ def get_remote_registry_store(server_port, feature_store):
             registry=registry_config,
             provider="local",
             entity_key_serialization_version=2,
+            repo_path=feature_store.repo_path,
         )
     )
     return store
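Note (not part of the diff): as the last two test changes show, a RepoConfig that is constructed by hand rather than loaded from a feature repo needs repo_path set explicitly before relative FileSource paths can resolve. A hedged sketch of such a config; the project name, registry path and repo location are illustrative:

    from pathlib import Path

    from feast.repo_config import RepoConfig

    config = RepoConfig(
        project="my_project",
        registry="data/registry.db",
        provider="local",
        entity_key_serialization_version=2,
        repo_path=Path("/tmp/my_project/feature_repo"),
    )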