diff --git a/docs/reference/offline-stores/spark.md b/docs/reference/offline-stores/spark.md
index 2e2facba64a..7f2d4094785 100644
--- a/docs/reference/offline-stores/spark.md
+++ b/docs/reference/offline-stores/spark.md
@@ -32,9 +32,13 @@ offline_store:
         spark.sql.session.timeZone: "UTC"
         spark.sql.execution.arrow.fallback.enabled: "true"
         spark.sql.execution.arrow.pyspark.enabled: "true"
+    # Optional: spill large materializations to the staging location instead of collecting in the driver
+    staging_location: "s3://my-bucket/tmp/feast"
 online_store:
     path: data/online_store.db
 ```
+
+> The `staging_location` can point to object storage (like S3, GCS, or Azure blobs) or a local filesystem directory (e.g., `/tmp/feast/staging`) to spill large materialization outputs before reading them back into Feast.
 {% endcode %}
 
 The full set of configuration options is available in [SparkOfflineStoreConfig](https://rtd.feast.dev/en/master/#feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStoreConfig).
@@ -60,7 +64,7 @@ Below is a matrix indicating which functionality is supported by `SparkRetrieval
 | export to arrow table | yes |
 | export to arrow batches | no |
 | export to SQL | no |
-| export to data lake (S3, GCS, etc.) | no |
+| export to data lake (S3, GCS, etc.) | yes |
 | export to data warehouse | no |
 | export as Spark dataframe | yes |
 | local execution of Python-based on-demand transforms | no |
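Separate from the diff above, the `staging_location` note can also be expressed programmatically. A minimal sketch, assuming the Spark contrib offline store is installed; the bucket, region, and local path are placeholders:

```python
# Sketch only (not part of the change): programmatic equivalents of the YAML above.
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStoreConfig,
)

# Object-storage staging; `region` is used when listing staged files on S3.
s3_config = SparkOfflineStoreConfig(
    type="spark",
    staging_location="s3://my-bucket/tmp/feast",
    region="us-west-2",
)

# Local-filesystem staging, as mentioned in the note above.
local_config = SparkOfflineStoreConfig(
    type="spark",
    staging_location="/tmp/feast/staging",
)
```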
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index 47e76a014f0..af83d303504 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -16,6 +16,7 @@
     Union,
     cast,
 )
+from urllib.parse import urlparse
 
 if TYPE_CHECKING:
     from feast.saved_dataset import ValidationReference
@@ -24,6 +25,7 @@
 import pandas
 import pandas as pd
 import pyarrow
+import pyarrow.dataset as ds
 import pyarrow.parquet as pq
 import pyspark
 from pydantic import StrictStr
@@ -445,8 +447,43 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
         """Return dataset as pyarrow Table synchronously"""
+        if self._should_use_staging_for_arrow():
+            return self._to_arrow_via_staging()
+
         return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))
 
+    def _should_use_staging_for_arrow(self) -> bool:
+        offline_store = getattr(self._config, "offline_store", None)
+        return bool(
+            isinstance(offline_store, SparkOfflineStoreConfig)
+            and getattr(offline_store, "staging_location", None)
+        )
+
+    def _to_arrow_via_staging(self) -> pyarrow.Table:
+        paths = self.to_remote_storage()
+        if not paths:
+            return pyarrow.table({})
+
+        parquet_paths = _filter_parquet_files(paths)
+        if not parquet_paths:
+            return pyarrow.table({})
+
+        normalized_paths = self._normalize_staging_paths(parquet_paths)
+        dataset = ds.dataset(normalized_paths, format="parquet")
+        return dataset.to_table()
+
+    def _normalize_staging_paths(self, paths: List[str]) -> List[str]:
+        """Normalize staging paths for PyArrow datasets."""
+        normalized = []
+        for path in paths:
+            if path.startswith("file://"):
+                normalized.append(path[len("file://") :])
+            elif "://" in path:
+                normalized.append(path)
+            else:
+                normalized.append(path)
+        return normalized
+
     def to_feast_df(
         self,
         validation_reference: Optional["ValidationReference"] = None,
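As a standalone illustration of the read-back step in `_to_arrow_via_staging`, the sketch below (placeholder paths, not part of the change) shows how a list of staged parquet files becomes a single Arrow table:

```python
# Sketch only: load staged parquet files into one pyarrow.Table, as the method above does.
import pyarrow.dataset as ds

# Placeholder paths standing in for whatever to_remote_storage() returned.
staged_files = [
    "/tmp/feast/staging/1b2e4d7a/part-00000.parquet",
    "/tmp/feast/staging/1b2e4d7a/part-00001.parquet",
]
table = ds.dataset(staged_files, format="parquet").to_table()
print(table.num_rows, table.schema)
```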
@@ -508,55 +545,53 @@ def supports_remote_storage_export(self) -> bool:
 
     def to_remote_storage(self) -> List[str]:
         """Currently only works for local and s3-based staging locations"""
-        if self.supports_remote_storage_export():
-            sdf: pyspark.sql.DataFrame = self.to_spark_df()
-
-            if self._config.offline_store.staging_location.startswith("/"):
-                local_file_staging_location = os.path.abspath(
-                    self._config.offline_store.staging_location
-                )
-
-                # write to staging location
-                output_uri = os.path.join(
-                    str(local_file_staging_location), str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-
-                return _list_files_in_folder(output_uri)
-            elif self._config.offline_store.staging_location.startswith("s3://"):
-                from feast.infra.utils import aws_utils
-
-                spark_compatible_s3_staging_location = (
-                    self._config.offline_store.staging_location.replace(
-                        "s3://", "s3a://"
-                    )
-                )
-
-                # write to staging location
-                output_uri = os.path.join(
-                    str(spark_compatible_s3_staging_location), str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-
-                return aws_utils.list_s3_files(
-                    self._config.offline_store.region, output_uri
-                )
-            elif self._config.offline_store.staging_location.startswith("hdfs://"):
-                output_uri = os.path.join(
-                    self._config.offline_store.staging_location, str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-                spark_session = get_spark_session_or_start_new_with_repoconfig(
-                    store_config=self._config.offline_store
-                )
-                return _list_hdfs_files(spark_session, output_uri)
-            else:
-                raise NotImplementedError(
-                    "to_remote_storage is only implemented for file://, s3:// and hdfs:// uri schemes"
-                )
+        if not self.supports_remote_storage_export():
+            raise NotImplementedError()
+
+        sdf: pyspark.sql.DataFrame = self.to_spark_df()
+        staging_location = self._config.offline_store.staging_location
+
+        if staging_location.startswith("/"):
+            local_file_staging_location = os.path.abspath(staging_location)
+            output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_files_in_folder(output_uri)
+        elif staging_location.startswith("s3://"):
+            from feast.infra.utils import aws_utils
+            spark_compatible_s3_staging_location = staging_location.replace(
+                "s3://", "s3a://"
+            )
+            output_uri = os.path.join(
+                spark_compatible_s3_staging_location, str(uuid.uuid4())
+            )
+            sdf.write.parquet(output_uri)
+            s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
+            return aws_utils.list_s3_files(
+                self._config.offline_store.region, s3_uri_for_listing
+            )
+        elif staging_location.startswith("gs://"):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_gcs_files(output_uri)
+        elif staging_location.startswith(("wasbs://", "abfs://", "abfss://")) or (
+            staging_location.startswith("https://")
+            and ".blob.core.windows.net" in staging_location
+        ):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_azure_files(output_uri)
+        elif staging_location.startswith("hdfs://"):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            spark_session = get_spark_session_or_start_new_with_repoconfig(
+                store_config=self._config.offline_store
+            )
+            return _list_hdfs_files(spark_session, output_uri)
         else:
-            raise NotImplementedError()
+            raise NotImplementedError(
+                "to_remote_storage is only implemented for file://, s3://, gs://, azure, and hdfs uri schemes"
+            )
 
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
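The rewritten `to_remote_storage` keeps the original write-then-list convention; the sketch below (placeholder values, not part of the change) spells out the S3 scheme juggling it performs:

```python
# Sketch only: the staging-path convention used by to_remote_storage().
import os
import uuid

staging_location = "s3://my-bucket/tmp/feast"  # placeholder

# Spark's Hadoop connector writes through the s3a:// scheme, and each export
# lands under a fresh UUID so concurrent jobs never collide.
output_uri = os.path.join(
    staging_location.replace("s3://", "s3a://"), str(uuid.uuid4())
)

# Listing via feast.infra.utils.aws_utils expects a plain s3:// URI, hence the
# normalization back before listing (the s3_uri_for_listing variable above).
listing_uri = output_uri.replace("s3a://", "s3://", 1)
print(output_uri, listing_uri)
```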
@@ -789,6 +824,10 @@ def _list_files_in_folder(folder):
     return files
 
 
+def _filter_parquet_files(paths: List[str]) -> List[str]:
+    return [path for path in paths if path.endswith(".parquet")]
+
+
 def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
     jvm = spark_session._jvm
     jsc = spark_session._jsc
@@ -805,6 +844,81 @@ def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
     return files
 
 
+def _list_gcs_files(path: str) -> List[str]:
+    try:
+        from google.cloud import storage
+    except ImportError as e:
+        from feast.errors import FeastExtrasDependencyImportError
+
+        raise FeastExtrasDependencyImportError("gcp", str(e))
+
+    assert path.startswith("gs://"), "GCS path must start with gs://"
+    bucket_path = path[len("gs://") :]
+    if "/" in bucket_path:
+        bucket, prefix = bucket_path.split("/", 1)
+    else:
+        bucket, prefix = bucket_path, ""
+
+    client = storage.Client()
+    bucket_obj = client.bucket(bucket)
+    blobs = bucket_obj.list_blobs(prefix=prefix)
+
+    files = []
+    for blob in blobs:
+        if not blob.name.endswith("/"):
+            files.append(f"gs://{bucket}/{blob.name}")
+    return files
+
+
+def _list_azure_files(path: str) -> List[str]:
+    try:
+        from azure.identity import DefaultAzureCredential
+        from azure.storage.blob import BlobServiceClient
+    except ImportError as e:
+        from feast.errors import FeastExtrasDependencyImportError
+
+        raise FeastExtrasDependencyImportError("azure", str(e))
+
+    parsed = urlparse(path)
+    scheme = parsed.scheme
+
+    if scheme in ("wasbs", "abfs", "abfss"):
+        if "@" not in parsed.netloc:
+            raise ValueError("Azure staging URI must include container@account host")
+        container, account_host = parsed.netloc.split("@", 1)
+        account_url = f"https://{account_host}"
+        base = f"{scheme}://{container}@{account_host}"
+        prefix = parsed.path.lstrip("/")
+    else:
+        account_url = f"{parsed.scheme}://{parsed.netloc}"
+        container_and_prefix = parsed.path.lstrip("/").split("/", 1)
+        container = container_and_prefix[0]
+        base = f"{account_url}/{container}"
+        prefix = container_and_prefix[1] if len(container_and_prefix) > 1 else ""
+
+    credential = os.environ.get("AZURE_STORAGE_KEY") or os.environ.get(
+        "AZURE_STORAGE_ACCOUNT_KEY"
+    )
+    if credential:
+        client = BlobServiceClient(account_url=account_url, credential=credential)
+    else:
+        default_credential = DefaultAzureCredential(
+            exclude_shared_token_cache_credential=True
+        )
+        client = BlobServiceClient(
+            account_url=account_url, credential=default_credential
+        )
+
+    container_client = client.get_container_client(container)
+    blobs = container_client.list_blobs(name_starts_with=prefix if prefix else None)
+
+    files = []
+    for blob in blobs:
+        if not blob.name.endswith("/"):
+            files.append(f"{base}/{blob.name}")
+    return files
+
+
 def _cast_data_frame(
     df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
 ) -> pyspark.sql.DataFrame:
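For readers unfamiliar with the `container@account` form handled by `_list_azure_files`, here is a standalone sketch (placeholder names, not part of the change) of how `urlparse` decomposes such a URI:

```python
# Sketch only: decomposition of a wasbs:// staging URI, mirroring _list_azure_files.
from urllib.parse import urlparse

parsed = urlparse("wasbs://container@account.blob.core.windows.net/feast/staging")
container, account_host = parsed.netloc.split("@", 1)
account_url = f"https://{account_host}"  # https://account.blob.core.windows.net
prefix = parsed.path.lstrip("/")  # feast/staging
print(container, account_url, prefix)
```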
diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
index 22c75ebf387..fbbba055e56 100644
--- a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
+++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark.py
@@ -1,13 +1,17 @@
+import os
 from datetime import datetime
 from unittest.mock import MagicMock, patch
 
 import pandas as pd
+import pyarrow as pa
 
 from feast.entity import Entity
 from feast.feature_view import FeatureView, Field
+from feast.infra.offline_stores.contrib.spark_offline_store import spark as spark_module
 from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
     SparkOfflineStore,
     SparkOfflineStoreConfig,
+    SparkRetrievalJob,
 )
 from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
     SparkSource,
@@ -504,3 +508,207 @@ def test_get_historical_features_non_entity_with_only_end_date(mock_get_spark_se
     # Verify data: mocked DataFrame flows to Pandas
     pdf = retrieval_job._to_df_internal()
     assert pdf.equals(expected_pdf)
+
+
+def test_to_arrow_uses_staging_when_enabled(monkeypatch, tmp_path):
+    repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(
+            type="spark",
+            staging_location=str(tmp_path),
+        ),
+    )
+
+    job = SparkRetrievalJob(
+        spark_session=MagicMock(),
+        query="select 1",
+        full_feature_names=False,
+        config=repo_config,
+    )
+
+    expected_table = pa.table({"a": [1]})
+    dataset_mock = MagicMock()
+
+    monkeypatch.setattr(
+        job, "to_remote_storage", MagicMock(return_value=["file:///test.parquet"])
+    )
+    monkeypatch.setattr(
+        spark_module.ds, "dataset", MagicMock(return_value=dataset_mock)
+    )
+    dataset_mock.to_table.return_value = expected_table
+
+    result = job._to_arrow_internal()
+
+    job.to_remote_storage.assert_called_once()
+    spark_module.ds.dataset.assert_called_once_with(["/test.parquet"], format="parquet")
+    assert result.equals(expected_table)
+
+
+def test_to_arrow_normalizes_local_staging_paths(monkeypatch, tmp_path):
+    repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(
+            type="spark",
+            staging_location=str(tmp_path / "local"),
+        ),
+    )
+
+    job = SparkRetrievalJob(
+        spark_session=MagicMock(),
+        query="select 1",
+        full_feature_names=False,
+        config=repo_config,
+    )
+
+    expected_table = pa.table({"a": [1]})
+    dataset_mock = MagicMock()
+
+    raw_path = os.path.join(str(tmp_path), "staged.parquet")
+    monkeypatch.setattr(job, "to_remote_storage", MagicMock(return_value=[raw_path]))
+    monkeypatch.setattr(
+        spark_module.ds, "dataset", MagicMock(return_value=dataset_mock)
+    )
+    dataset_mock.to_table.return_value = expected_table
+
+    result = job._to_arrow_internal()
+
+    job.to_remote_storage.assert_called_once()
+    spark_module.ds.dataset.assert_called_once_with([raw_path], format="parquet")
+    assert result.equals(expected_table)
+
+
+def test_to_arrow_falls_back_to_pandas_when_staging_disabled(monkeypatch):
+    repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(
+            type="spark",
+            staging_location=None,
+        ),
+    )
+
+    job = SparkRetrievalJob(
+        spark_session=MagicMock(),
+        query="select 1",
+        full_feature_names=False,
+        config=repo_config,
+    )
+
+    pdf = pd.DataFrame({"a": [1]})
+    monkeypatch.setattr(job, "_to_df_internal", MagicMock(return_value=pdf))
+    monkeypatch.setattr(
+        job, "to_remote_storage", MagicMock(side_effect=AssertionError("not called"))
+    )
+
+    result = job._to_arrow_internal()
+
+    assert result.equals(pa.Table.from_pandas(pdf))
+
+
+@patch("feast.infra.utils.aws_utils.list_s3_files")
+def test_to_remote_storage_lists_with_s3_scheme(mock_list_s3_files):
+    spark_df = MagicMock()
+    spark_df.write.parquet = MagicMock()
+    spark_session = MagicMock()
+    spark_session.sql.return_value = spark_df
+
+    mock_list_s3_files.return_value = ["s3://bucket/prefix/file.parquet"]
+
+    repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(
+            type="spark",
+            staging_location="s3://bucket/prefix",
+            region="us-west-2",
+        ),
+    )
+
+    job = SparkRetrievalJob(
+        spark_session=spark_session,
+        query="select 1",
+        full_feature_names=False,
+        config=repo_config,
+    )
+
+    result = job.to_remote_storage()
+
+    assert spark_df.write.parquet.call_args[0][0].startswith("s3a://bucket/prefix")
+    assert mock_list_s3_files.call_args[0][1].startswith("s3://bucket/prefix")
+    assert result == mock_list_s3_files.return_value
+
+
+@patch("feast.infra.offline_stores.contrib.spark_offline_store.spark._list_gcs_files")
+def test_to_remote_storage_lists_with_gcs_scheme(mock_list_gcs_files):
+    spark_df = MagicMock()
+    spark_df.write.parquet = MagicMock()
+    spark_session = MagicMock()
+    spark_session.sql.return_value = spark_df
+
+    mock_list_gcs_files.return_value = ["gs://bucket/prefix/file.parquet"]
+
+    repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(
+            type="spark",
+            staging_location="gs://bucket/prefix",
+        ),
+    )
+
+    job = SparkRetrievalJob(
+        spark_session=spark_session,
+        query="select 1",
+        full_feature_names=False,
+        config=repo_config,
+    )
+
+    result = job.to_remote_storage()
+
+    assert spark_df.write.parquet.call_args[0][0].startswith("gs://bucket/prefix")
+    mock_list_gcs_files.assert_called_once()
+    assert result == mock_list_gcs_files.return_value
+
+
+@patch("feast.infra.offline_stores.contrib.spark_offline_store.spark._list_azure_files")
+def test_to_remote_storage_lists_with_azure_scheme(mock_list_azure_files):
+    spark_df = MagicMock()
+    spark_df.write.parquet = MagicMock()
+    spark_session = MagicMock()
+    spark_session.sql.return_value = spark_df
+
+    mock_list_azure_files.return_value = [
+        "wasbs://container@account.blob.core.windows.net/path/file.parquet"
+    ]
+
+    repo_config = RepoConfig(
+        project="test_project",
+        registry="test_registry",
+        provider="local",
+        offline_store=SparkOfflineStoreConfig(
+            type="spark",
+            staging_location="wasbs://container@account.blob.core.windows.net/path",
+        ),
+    )
+
+    job = SparkRetrievalJob(
+        spark_session=spark_session,
+        query="select 1",
+        full_feature_names=False,
+        config=repo_config,
+    )
+
+    result = job.to_remote_storage()
+
+    assert spark_df.write.parquet.call_args[0][0].startswith(
+        "wasbs://container@account.blob.core.windows.net/path"
+    )
+    mock_list_azure_files.assert_called_once()
+    assert result == mock_list_azure_files.return_value
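The local-filesystem branch of `to_remote_storage` is not exercised above; a possible companion test, sketched under the assumption that it reuses the same imports, fixtures, and `SparkRetrievalJob` constructor shown in this file (it is not part of the change):

```python
# Sketch only: hypothetical test for the "/" (local filesystem) staging branch.
@patch(
    "feast.infra.offline_stores.contrib.spark_offline_store.spark._list_files_in_folder"
)
def test_to_remote_storage_lists_local_staging(mock_list_files, tmp_path):
    spark_df = MagicMock()
    spark_df.write.parquet = MagicMock()
    spark_session = MagicMock()
    spark_session.sql.return_value = spark_df

    mock_list_files.return_value = [str(tmp_path / "file.parquet")]

    repo_config = RepoConfig(
        project="test_project",
        registry="test_registry",
        provider="local",
        offline_store=SparkOfflineStoreConfig(
            type="spark",
            staging_location=str(tmp_path),
        ),
    )

    job = SparkRetrievalJob(
        spark_session=spark_session,
        query="select 1",
        full_feature_names=False,
        config=repo_config,
    )

    result = job.to_remote_storage()

    # Writes land under a fresh UUID inside the staging directory.
    assert spark_df.write.parquet.call_args[0][0].startswith(str(tmp_path))
    mock_list_files.assert_called_once()
    assert result == mock_list_files.return_value
```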