From c1e869e555c456e696951bdb459f3b92b8fa0886 Mon Sep 17 00:00:00 2001 From: Calvin Hopkins Date: Wed, 25 Oct 2023 15:22:54 -0700 Subject: [PATCH 1/3] Update SparkSource to have proper comparable Signed-off-by: Calvin Hopkins --- .../spark_offline_store/spark_source.py | 23 ++++++++ sdk/python/tests/unit/test_data_sources.py | 53 +++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index a27065fb5ed..294c8f59e4d 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -185,6 +185,18 @@ def get_table_query_string(self) -> str: return f"`{tmp_table_name}`" + # Note: Python requires redefining hash in child classes that override __eq__ + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + if not isinstance(other, SparkSource): + raise TypeError("Comparisons should only involve SparkSource class objects.") + return ( + super().__eq__(other) + and self.spark_options == other.spark_options + ) + class SparkOptions: allowed_formats = [format.value for format in SparkSourceFormat] @@ -282,6 +294,17 @@ def to_proto(self) -> DataSourceProto.SparkOptions: return spark_options_proto + def __eq__(self, other: object) -> bool: + if not isinstance(other, SparkOptions): + raise TypeError("Comparisons should only involve SparkOptions class objects.") + + return ( + self.table == other.table + and self.query == other.query + and self.path == other.path + and self.file_format == other.file_format + ) + class SavedDatasetSparkStorage(SavedDatasetStorage): _proto_attr_name = "spark_storage" diff --git a/sdk/python/tests/unit/test_data_sources.py b/sdk/python/tests/unit/test_data_sources.py index 990c5d3b698..c72d4899dcf 100644 --- 
a/sdk/python/tests/unit/test_data_sources.py +++ b/sdk/python/tests/unit/test_data_sources.py @@ -13,6 +13,7 @@ from feast.infra.offline_stores.file_source import FileSource from feast.infra.offline_stores.redshift_source import RedshiftSource from feast.infra.offline_stores.snowflake_source import SnowflakeSource +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import SparkSource from feast.types import Bool, Float32, Int64 @@ -233,3 +234,55 @@ def test_redshift_fully_qualified_table_name(source_kwargs, expected_name): ) assert redshift_source.redshift_options.fully_qualified_table_name == expected_name + +@pytest.mark.parametrize( + "test_data,are_equal", + [ + ( + SparkSource( + name='name', + table='table', + query='query', + file_format='file_format' + ), + True + ), + ( + SparkSource( + table='table', + query='query', + file_format='file_format' + ), + False + ), + ( + SparkSource( + name='name', + table='table', + query='query', + file_format='file_format1' + ), + False + ), + ( + SparkSource( + name='name', + table='table', + query='query1', + file_format='file_format' + ), + True + ), + ] +) +def test_spark_source_equality(test_data, are_equal): + default = SparkSource( + name='name', + table='table1', + query='query', + file_format='file_format' + ) + if are_equal: + assert default == test_data + else: + assert default != test_data From 63a22d9aea08fe036fc2c99424b7520638cdb4be Mon Sep 17 00:00:00 2001 From: Calvin Hopkins Date: Wed, 25 Oct 2023 15:41:23 -0700 Subject: [PATCH 2/3] lint fixes Signed-off-by: Calvin Hopkins --- .../spark_offline_store/spark_source.py | 13 ++-- sdk/python/tests/unit/test_data_sources.py | 64 +++++++------------ 2 files changed, 31 insertions(+), 46 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 294c8f59e4d..65cc5041063 100644 --- 
a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -191,11 +191,10 @@ def __hash__(self): def __eq__(self, other): if not isinstance(other, SparkSource): - raise TypeError("Comparisons should only involve SparkSource class objects.") - return ( - super().__eq__(other) - and self.spark_options == other.spark_options - ) + raise TypeError( + "Comparisons should only involve SparkSource class objects." + ) + return super().__eq__(other) and self.spark_options == other.spark_options class SparkOptions: @@ -296,7 +295,9 @@ def to_proto(self) -> DataSourceProto.SparkOptions: def __eq__(self, other: object) -> bool: if not isinstance(other, SparkOptions): - raise TypeError("Comparisons should only involve SparkOptions class objects.") + raise TypeError( + "Comparisons should only involve SparkOptions class objects." + ) return ( self.table == other.table diff --git a/sdk/python/tests/unit/test_data_sources.py b/sdk/python/tests/unit/test_data_sources.py index c72d4899dcf..6015d937cee 100644 --- a/sdk/python/tests/unit/test_data_sources.py +++ b/sdk/python/tests/unit/test_data_sources.py @@ -10,10 +10,12 @@ ) from feast.field import Field from feast.infra.offline_stores.bigquery_source import BigQuerySource +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) from feast.infra.offline_stores.file_source import FileSource from feast.infra.offline_stores.redshift_source import RedshiftSource from feast.infra.offline_stores.snowflake_source import SnowflakeSource -from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import SparkSource from feast.types import Bool, Float32, Int64 @@ -235,52 +237,34 @@ def test_redshift_fully_qualified_table_name(source_kwargs, expected_name): assert redshift_source.redshift_options.fully_qualified_table_name == expected_name + @pytest.mark.parametrize( - 
"test_data,are_equal", - [ - ( - SparkSource( - name='name', - table='table', - query='query', - file_format='file_format' - ), - True - ), - ( - SparkSource( - table='table', - query='query', - file_format='file_format' - ), - False + "test_data,are_equal", + [ + ( + SparkSource( + name="name", table="table", query="query", file_format="file_format" ), - ( - SparkSource( - name='name', - table='table', - query='query', - file_format='file_format1' - ), - False + True, + ), + (SparkSource(table="table", query="query", file_format="file_format"), False), + ( + SparkSource( + name="name", table="table", query="query", file_format="file_format1" ), - ( - SparkSource( - name='name', - table='table', - query='query1', - file_format='file_format' - ), - True + False, + ), + ( + SparkSource( + name="name", table="table", query="query1", file_format="file_format" ), - ] + True, + ), + ], ) def test_spark_source_equality(test_data, are_equal): default = SparkSource( - name='name', - table='table1', - query='query', - file_format='file_format' + name="name", table="table1", query="query", file_format="file_format" ) if are_equal: assert default == test_data From 61f604a0987a2521d6106124ece7f2253feb503f Mon Sep 17 00:00:00 2001 From: Calvin Hopkins Date: Wed, 25 Oct 2023 15:43:12 -0700 Subject: [PATCH 3/3] force precommit Signed-off-by: Calvin Hopkins --- .../offline_stores/contrib/spark_offline_store/spark_source.py | 1 + sdk/python/tests/unit/test_data_sources.py | 1 + 2 files changed, 2 insertions(+) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 65cc5041063..1083cc56278 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -187,6 +187,7 @@ def get_table_query_string(self) -> str: # Note: Python requires 
redefining hash in child classes that override __eq__ def __hash__(self): + return super().__hash__() def __eq__(self, other): diff --git a/sdk/python/tests/unit/test_data_sources.py b/sdk/python/tests/unit/test_data_sources.py index 6015d937cee..990752621b0 100644 --- a/sdk/python/tests/unit/test_data_sources.py +++ b/sdk/python/tests/unit/test_data_sources.py @@ -270,3 +270,4 @@ def test_spark_source_equality(test_data, are_equal): assert default == test_data else: assert default != test_data +