diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py
index c66af0db18e..57d5aa1b07e 100644
--- a/sdk/python/feast/batch_feature_view.py
+++ b/sdk/python/feast/batch_feature_view.py
@@ -1,8 +1,7 @@
 import functools
 import warnings
 from datetime import datetime, timedelta
-from types import FunctionType
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import dill
 
@@ -61,7 +60,7 @@ class BatchFeatureView(FeatureView):
     owner: str
     timestamp_field: str
     materialization_intervals: List[Tuple[datetime, datetime]]
-    udf: Optional[FunctionType]
+    udf: Optional[Callable[[Any], Any]]
     udf_string: Optional[str]
     feature_transformation: Transformation
 
@@ -78,7 +77,7 @@ def __init__(
         description: str = "",
         owner: str = "",
         schema: Optional[List[Field]] = None,
-        udf: Optional[FunctionType] = None,
+        udf: Optional[Callable[[Any], Any]],
         udf_string: Optional[str] = "",
         feature_transformation: Optional[Transformation] = None,
     ):
diff --git a/sdk/python/feast/infra/compute_engines/__init__.py b/sdk/python/feast/infra/compute_engines/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/sdk/python/feast/infra/compute_engines/spark/__init__.py b/sdk/python/feast/infra/compute_engines/spark/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/sdk/python/feast/infra/compute_engines/spark/config.py b/sdk/python/feast/infra/compute_engines/spark/config.py
new file mode 100644
index 00000000000..070cf204dce
--- /dev/null
+++ b/sdk/python/feast/infra/compute_engines/spark/config.py
@@ -0,0 +1,20 @@
+from typing import Dict, Optional
+
+from pydantic import StrictStr
+
+from feast.repo_config import FeastConfigBaseModel
+
+
+class SparkComputeConfig(FeastConfigBaseModel):
+    type: StrictStr = "spark"
+    """ Spark Compute type selector"""
+
+    spark_conf: Optional[Dict[str, str]] = None
+    """ Configuration overlay for the spark session """
+    # A SparkSession is not serializable, so we don't want to pass it around as an argument
+
+    staging_location: Optional[StrictStr] = None
+    """ Remote path for batch materialization jobs"""
+
+    region: Optional[StrictStr] = None
+    """ AWS region, if applicable for S3-based staging locations"""
diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py
new file mode 100644
index 00000000000..262876f9dbb
--- /dev/null
+++ b/sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -0,0 +1,19 @@
+from typing import Dict, Optional
+
+from pyspark import SparkConf
+from pyspark.sql import SparkSession
+
+
+def get_or_create_new_spark_session(
+    spark_config: Optional[Dict[str, str]] = None,
+) -> SparkSession:
+    spark_session = SparkSession.getActiveSession()
+    if not spark_session:
+        spark_builder = SparkSession.builder
+        if spark_config:
+            spark_builder = spark_builder.config(
+                conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
+            )
+
+        spark_session = spark_builder.getOrCreate()
+    return spark_session
diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py
index 42802993226..2f134001a5a 100644
--- a/sdk/python/feast/stream_feature_view.py
+++ b/sdk/python/feast/stream_feature_view.py
@@ -151,8 +151,9 @@ def get_feature_transformation(self) -> Optional[Transformation]:
         if self.mode in (
             TransformationMode.PANDAS,
             TransformationMode.PYTHON,
+            TransformationMode.SPARK_SQL,
             TransformationMode.SPARK,
-        ) or self.mode in ("pandas", "python", "spark"):
+        ) or self.mode in ("pandas", "python", "spark_sql", "spark"):
             return Transformation(
                 mode=self.mode, udf=self.udf, udf_string=self.udf_string or ""
             )
diff --git a/sdk/python/feast/transformation/base.py b/sdk/python/feast/transformation/base.py
index 7489e16be97..b02be0a6708 100644
--- a/sdk/python/feast/transformation/base.py
+++ b/sdk/python/feast/transformation/base.py
@@ -81,7 +81,7 @@ def __init__(
         description: str = "",
         owner: str = "",
     ):
-        self.mode = mode if isinstance(mode, str) else mode.value
+        self.mode = mode
         self.udf = udf
         self.udf_string = udf_string
         self.name = name
@@ -99,7 +99,7 @@ def to_proto(self) -> Union[UserDefinedFunctionProto, SubstraitTransformationPro
     def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "Transformation":
         return Transformation(mode=self.mode, udf=self.udf, udf_string=self.udf_string)
 
-    def transform(self, inputs: Any) -> Any:
+    def transform(self, *inputs: Any) -> Any:
         raise NotImplementedError
 
     def transform_arrow(self, *args, **kwargs) -> Any:
diff --git a/sdk/python/feast/transformation/factory.py b/sdk/python/feast/transformation/factory.py
index 5097d71353a..50c3c665764 100644
--- a/sdk/python/feast/transformation/factory.py
+++ b/sdk/python/feast/transformation/factory.py
@@ -5,6 +5,7 @@
     "pandas": "feast.transformation.pandas_transformation.PandasTransformation",
     "substrait": "feast.transformation.substrait_transformation.SubstraitTransformation",
     "sql": "feast.transformation.sql_transformation.SQLTransformation",
+    "spark_sql": "feast.transformation.spark_transformation.SparkTransformation",
     "spark": "feast.transformation.spark_transformation.SparkTransformation",
 }
 
diff --git a/sdk/python/feast/transformation/mode.py b/sdk/python/feast/transformation/mode.py
index 4bd5ddbe7a3..2b453477b3a 100644
--- a/sdk/python/feast/transformation/mode.py
+++ b/sdk/python/feast/transformation/mode.py
@@ -4,6 +4,7 @@
 class TransformationMode(Enum):
     PYTHON = "python"
     PANDAS = "pandas"
+    SPARK_SQL = "spark_sql"
     SPARK = "spark"
     SQL = "sql"
     SUBSTRAIT = "substrait"
diff --git a/sdk/python/feast/transformation/spark_transformation.py b/sdk/python/feast/transformation/spark_transformation.py
index d288cf58b08..84d4c010c17 100644
--- a/sdk/python/feast/transformation/spark_transformation.py
+++ b/sdk/python/feast/transformation/spark_transformation.py
@@ -1,11 +1,125 @@
-from typing import Any
+from typing import Any, Dict, Optional, Union, cast
 
+import pandas as pd
+import pyspark.sql
+
+from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session
 from feast.transformation.base import Transformation
+from feast.transformation.mode import TransformationMode
 
 
 class SparkTransformation(Transformation):
-    def transform(self, inputs: Any) -> Any:
-        pass
+    r"""
+    SparkTransformation can be used to define a transformation using a Spark UDF or SQL query.
+    The current spark session will be used or a new one will be created if not available.
+    E.g.:
+        spark_transformation = SparkTransformation(
+            mode=TransformationMode.SPARK,
+            udf=remove_extra_spaces,
+            udf_string="remove extra spaces",
+        )
+    OR
+        spark_transformation = Transformation(
+            mode=TransformationMode.SPARK_SQL,
+            udf=remove_extra_spaces_sql,
+            udf_string="remove extra spaces sql",
+        )
+    OR
+        @transformation(mode=TransformationMode.SPARK)
+        def remove_extra_spaces_udf(df: pd.DataFrame) -> pd.DataFrame:
+            return df.assign(name=df['name'].str.replace('\s+', ' '))
+    """
+
+    def __new__(
+        cls,
+        mode: Union[TransformationMode, str],
+        udf: Any,
+        udf_string: str,
+        spark_config: Dict[str, Any] = {},
+        name: Optional[str] = None,
+        tags: Optional[Dict[str, str]] = None,
+        description: str = "",
+        owner: str = "",
+        *args,
+        **kwargs,
+    ) -> "SparkTransformation":
+        """
+        Creates a SparkTransformation.
+        Args:
+            mode: (required) The mode of the transformation. Choose one from TransformationMode.SPARK or TransformationMode.SPARK_SQL.
+            udf: (required) The user-defined transformation function.
+            udf_string: (required) The string representation of the udf, stored explicitly because dill's source extraction is not always reliable.
+            spark_config: (optional) The spark configuration to use for the transformation.
+            name: (optional) The name of the transformation.
+            tags: (optional) Metadata tags for the transformation.
+            description: (optional) A description of the transformation.
+            owner: (optional) The owner of the transformation.
+        """
+        instance = super(SparkTransformation, cls).__new__(
+            cls,
+            mode=mode,
+            spark_config=spark_config,
+            udf=udf,
+            udf_string=udf_string,
+            name=name,
+            tags=tags,
+            description=description,
+            owner=owner,
+        )
+        return cast(SparkTransformation, instance)
+
+    def __init__(
+        self,
+        mode: Union[TransformationMode, str],
+        udf: Any,
+        udf_string: str,
+        spark_config: Dict[str, Any] = {},
+        name: Optional[str] = None,
+        tags: Optional[Dict[str, str]] = None,
+        description: str = "",
+        owner: str = "",
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            mode=mode,
+            udf=udf,
+            name=name,
+            udf_string=udf_string,
+            tags=tags,
+            description=description,
+            owner=owner,
+        )
+        self.spark_session = get_or_create_new_spark_session(spark_config)
+
+    def transform(
+        self,
+        *inputs: Union[str, pd.DataFrame, pyspark.sql.DataFrame],
+    ) -> Union[pd.DataFrame, pyspark.sql.DataFrame]:
+        if self.mode == TransformationMode.SPARK_SQL:
+            return self._transform_spark_sql(*inputs)
+        else:
+            return self._transform_spark_udf(*inputs)
+
+    @staticmethod
+    def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, name: str):
+        df_temp_view = f"feast_transformation_temp_view_{name}"
+        df.createOrReplaceTempView(df_temp_view)
+        return df_temp_view
+
+    def _transform_spark_sql(
+        self, *inputs: Union[pyspark.sql.DataFrame, str]
+    ) -> pyspark.sql.DataFrame:
+        inputs_str = [
+            self._create_temp_view_for_dataframe(v, f"index_{i}")
+            if isinstance(v, pyspark.sql.DataFrame)
+            else v
+            for i, v in enumerate(inputs)
+        ]
+        return self.spark_session.sql(self.udf(*inputs_str))
+
+    def _transform_spark_udf(
+        self, *inputs: Any
+    ) -> Union[pd.DataFrame, pyspark.sql.DataFrame]:
+        return self.udf(*inputs)
 
     def infer_features(self, *args, **kwargs) -> Any:
         pass
diff --git a/sdk/python/tests/unit/transformation/test_pandas_transformation.py b/sdk/python/tests/unit/transformation/test_pandas_transformation.py
new file mode 100644
index 00000000000..d20204ceb93
--- /dev/null
+++ b/sdk/python/tests/unit/transformation/test_pandas_transformation.py
@@ -0,0 +1,18 @@
+import pandas as pd
+
+from feast.transformation.pandas_transformation import PandasTransformation
+
+
+def pandas_udf(features_df: pd.DataFrame) -> pd.DataFrame:
+    df = pd.DataFrame()
+    df["output1"] = features_df["feature1"]
+    df["output2"] = features_df["feature2"]
+    return df
+
+
+def test_init_pandas_transformation():
+    transformation = PandasTransformation(udf=pandas_udf, udf_string="udf1")
+    features_df = pd.DataFrame.from_dict({"feature1": [1, 2], "feature2": [2, 3]})
+    transformed_df = transformation.transform(features_df)
+    assert transformed_df["output1"].values[0] == 1
+    assert transformed_df["output2"].values[1] == 3
diff --git a/sdk/python/tests/unit/transformation/test_spark_transformation.py b/sdk/python/tests/unit/transformation/test_spark_transformation.py
new file mode 100644
index 00000000000..8ee9d22bf28
--- /dev/null
+++ b/sdk/python/tests/unit/transformation/test_spark_transformation.py
@@ -0,0 +1,104 @@
+from unittest.mock import patch
+
+import pytest
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import col, regexp_replace
+from pyspark.testing.utils import assertDataFrameEqual
+
+from feast.transformation.base import Transformation
+from feast.transformation.mode import TransformationMode
+from feast.transformation.spark_transformation import SparkTransformation
+
+
+def get_sample_df(spark):
+    sample_data = [
+        {"name": "John D.", "age": 30},
+        {"name": "Alice G.", "age": 25},
+        {"name": "Bob T.", "age": 35},
+        {"name": "Eve A.", "age": 28},
+    ]
+    df = spark.createDataFrame(sample_data)
+    return df
+
+
+def get_expected_df(spark):
+    expected_data = [
+        {"name": "John D.", "age": 30},
+        {"name": "Alice G.", "age": 25},
+        {"name": "Bob T.", "age": 35},
+        {"name": "Eve A.", "age": 28},
+    ]
+
+    expected_df = spark.createDataFrame(expected_data)
+    return expected_df
+
+
+def remove_extra_spaces(df, column_name):
+    df_transformed = df.withColumn(
+        column_name, regexp_replace(col(column_name), "\\s+", " ")
+    )
+    return df_transformed
+
+
+def remove_extra_spaces_sql(df, column_name):
+    sql = f"""
+    SELECT
+        age,
+        regexp_replace({column_name}, '\\\\s+', ' ') as {column_name}
+    FROM {df}
+    """
+    return sql
+
+
+@pytest.fixture
+def spark_fixture():
+    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    yield spark
+
+
+@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
+def test_spark_transformation(spark_fixture):
+    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    df = get_sample_df(spark)
+
+    spark_transformation = Transformation(
+        mode=TransformationMode.SPARK,
+        udf=remove_extra_spaces,
+        udf_string="remove extra spaces",
+    )
+
+    transformed_df = spark_transformation.transform(df, "name")
+    expected_df = get_expected_df(spark)
+    assertDataFrameEqual(transformed_df, expected_df)
+
+
+@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
+def test_spark_transformation_init_transformation(spark_fixture):
+    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    df = get_sample_df(spark)
+
+    spark_transformation = SparkTransformation(
+        mode=TransformationMode.SPARK,
+        udf=remove_extra_spaces,
+        udf_string="remove extra spaces",
+    )
+
+    transformed_df = spark_transformation.transform(df, "name")
+    expected_df = get_expected_df(spark)
+    assertDataFrameEqual(transformed_df, expected_df)
+
+
+@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
+def test_spark_transformation_sql(spark_fixture):
+    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
+    df = get_sample_df(spark)
+
+    spark_transformation = SparkTransformation(
+        mode=TransformationMode.SPARK_SQL,
+        udf=remove_extra_spaces_sql,
+        udf_string="remove extra spaces sql",
+    )
+
+    transformed_df = spark_transformation.transform(df, "name")
+    expected_df = get_expected_df(spark)
+    assertDataFrameEqual(transformed_df, expected_df)
diff --git a/sdk/python/tests/utils/ssl_certifcates_util.py b/sdk/python/tests/utils/ssl_certifcates_util.py
index 53a56e04f3d..53b9df3973c 100644
--- a/sdk/python/tests/utils/ssl_certifcates_util.py
+++ b/sdk/python/tests/utils/ssl_certifcates_util.py
@@ -8,7 +8,7 @@
 from cryptography import x509
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import hashes, serialization
-from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.primitives.asymmetric import dh, dsa, ec, rsa
 from cryptography.x509 import load_pem_x509_certificate
 from cryptography.x509.oid import NameOID
 
@@ -126,13 +126,33 @@ def create_ca_trust_store(
     private_key = serialization.load_pem_private_key(
         private_key_data, password=None, backend=default_backend()
     )
-    # Check the public/private key match
-    if (
-        private_key.public_key().public_numbers()
-        != public_cert.public_key().public_numbers()
+    private_pub = private_key.public_key()
+    cert_pub = public_cert.public_key()
+
+    if isinstance(
+        private_pub,
+        (
+            rsa.RSAPublicKey,
+            dsa.DSAPublicKey,
+            ec.EllipticCurvePublicKey,
+            dh.DHPublicKey,
+        ),
+    ) and isinstance(
+        cert_pub,
+        (
+            rsa.RSAPublicKey,
+            dsa.DSAPublicKey,
+            ec.EllipticCurvePublicKey,
+            dh.DHPublicKey,
+        ),
     ):
-        raise ValueError(
-            "Public certificate does not match the private key."
+        if private_pub.public_numbers() != cert_pub.public_numbers():
+            raise ValueError(
+                "Public certificate does not match the private key."
+            )
+    else:
+        logger.warning(
+            "Key type does not support public_numbers(). Skipping strict public key match."
         )
 
     # Step 4: Add the public certificate to the new trust store
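
For reference, below is a minimal usage sketch of the new SPARK_SQL transformation mode, assembled from the SparkTransformation docstring and test_spark_transformation_sql above. It is not part of the patch: the sample DataFrame, the column being cleaned, and the collapse_spaces_sql helper are illustrative assumptions.

# Illustrative sketch only; DataFrame contents, column name, and the
# collapse_spaces_sql helper are assumptions, not part of this diff.
from pyspark.sql import SparkSession

from feast.transformation.mode import TransformationMode
from feast.transformation.spark_transformation import SparkTransformation

spark = SparkSession.builder.appName("spark_sql_mode_example").getOrCreate()
df = spark.createDataFrame([{"name": "John   D.", "age": 30}])


def collapse_spaces_sql(table_name: str) -> str:
    # In SPARK_SQL mode the udf returns a SQL string; DataFrame inputs are
    # registered as temp views and their generated view names are passed in.
    return f"SELECT age, regexp_replace(name, '\\\\s+', ' ') AS name FROM {table_name}"


transformation = SparkTransformation(
    mode=TransformationMode.SPARK_SQL,
    udf=collapse_spaces_sql,
    udf_string="collapse extra spaces",
)
result_df = transformation.transform(df)  # returns a pyspark.sql.DataFrame
result_df.show()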