From 64c26ccb15398b57bc57f78d5ff6f96e4f495799 Mon Sep 17 00:00:00 2001 From: Miray Yuce Date: Mon, 28 Mar 2022 17:14:34 -0700 Subject: [PATCH 1/3] add new value-types Signed-off-by: Miray Yuce --- ui/src/parsers/types.ts | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ui/src/parsers/types.ts b/ui/src/parsers/types.ts index e32d5b102ea..2f88eea4f06 100644 --- a/ui/src/parsers/types.ts +++ b/ui/src/parsers/types.ts @@ -13,7 +13,17 @@ enum FEAST_FEATURE_VALUE_TYPES { BYTES = "BYTES", INT32 = "INT32", DOUBLE = "DOUBLE", - UNIX_TIMESTAMP = "UNIX_TIMESTAMP" + UNIX_TIMESTAMP = "UNIX_TIMESTAMP", + INVALID = "INVALID", + BYTES_LIST = "BYTES_LIST", + STRING_LIST = "STRING_LIST", + INT32_LIST = "INT32_LIST", + INT64_LIST = "INT64_LIST", + DOUBLE_LIST = "DOUBLE_LIST", + FLOAT_LIST = "FLOAT_LIST", + BOOL_LIST = "BOOL_LIST", + UNIX_TIMESTAMP_LIST = "UNIX_TIMESTAMP_LIST", + NULL = "NULL" } export { FEAST_FCO_TYPES, FEAST_FEATURE_VALUE_TYPES }; From 293a2e37c7394f93b6a3e6cde1a1e9f0bbbb50d3 Mon Sep 17 00:00:00 2001 From: Miray Yuce Date: Tue, 29 Mar 2022 09:51:44 -0700 Subject: [PATCH 2/3] auto formatted files Signed-off-by: Miray Yuce --- sdk/python/feast/cli.py | 14 +- sdk/python/feast/data_source.py | 14 +- sdk/python/feast/diff/infra_diff.py | 10 +- sdk/python/feast/diff/registry_diff.py | 12 +- sdk/python/feast/driver_test_data.py | 7 +- sdk/python/feast/feature.py | 9 +- sdk/python/feast/feature_store.py | 135 ++++++++++-------- sdk/python/feast/feature_view.py | 4 +- sdk/python/feast/go_server.py | 6 +- sdk/python/feast/inference.py | 11 +- sdk/python/feast/infra/aws.py | 5 +- .../feast/infra/offline_stores/bigquery.py | 33 +++-- .../infra/offline_stores/bigquery_source.py | 35 ++--- .../contrib/spark_offline_store/spark.py | 4 +- .../spark_offline_store/spark_source.py | 5 +- sdk/python/feast/infra/offline_stores/file.py | 23 ++- .../infra/offline_stores/offline_store.py | 6 +- .../feast/infra/offline_stores/redshift.py | 15 +- .../infra/offline_stores/redshift_source.py | 4 +- .../feast/infra/offline_stores/snowflake.py | 12 +- .../feast/infra/online_stores/datastore.py | 17 ++- .../feast/infra/online_stores/dynamodb.py | 6 +- .../feast/infra/online_stores/sqlite.py | 10 +- .../feast/infra/passthrough_provider.py | 10 +- sdk/python/feast/infra/provider.py | 18 ++- sdk/python/feast/infra/utils/aws_utils.py | 35 +++-- sdk/python/feast/on_demand_feature_view.py | 7 +- sdk/python/feast/registry.py | 5 +- sdk/python/feast/repo_config.py | 24 ++-- sdk/python/feast/repo_operations.py | 4 +- sdk/python/feast/templates/aws/test.py | 3 +- sdk/python/feast/templates/gcp/test.py | 3 +- sdk/python/feast/templates/local/example.py | 6 +- sdk/python/feast/templates/snowflake/test.py | 3 +- sdk/python/feast/templates/spark/example.py | 10 +- sdk/python/feast/type_map.py | 8 +- sdk/python/feast/usage.py | 42 +++--- sdk/python/feast/utils.py | 2 +- ...st_benchmark_universal_online_retrieval.py | 4 +- sdk/python/tests/conftest.py | 14 +- sdk/python/tests/doctest/test_all.py | 7 +- .../example_repos/example_feature_repo_2.py | 6 +- sdk/python/tests/foo_provider.py | 5 +- .../tests/integration/e2e/test_usage_e2e.py | 2 +- .../tests/integration/e2e/test_validation.py | 16 ++- .../feature_repos/repo_configuration.py | 11 +- .../universal/data_sources/file.py | 4 +- .../offline_store/test_s3_custom_endpoint.py | 4 +- .../test_universal_historical_retrieval.py | 9 +- .../online_store/test_e2e_local.py | 2 +- .../integration/registration/test_cli.py | 3 +- .../registration/test_feature_store.py | 17 ++- .../registration/test_inference.py | 13 +- .../integration/registration/test_registry.py | 18 ++- .../registration/test_universal_types.py | 20 ++- .../tests/unit/diff/test_registry_diff.py | 15 +- sdk/python/tests/unit/test_usage.py | 2 +- sdk/python/tests/utils/data_source_utils.py | 3 +- .../tests/utils/online_read_write_test.py | 2 +- 59 files changed, 511 insertions(+), 243 deletions(-) diff --git a/sdk/python/feast/cli.py b/sdk/python/feast/cli.py index d2a71bc561b..8da6bfd7ced 100644 --- a/sdk/python/feast/cli.py +++ b/sdk/python/feast/cli.py @@ -471,7 +471,10 @@ def registry_dump_command(ctx: click.Context): @click.argument("start_ts") @click.argument("end_ts") @click.option( - "--views", "-v", help="Feature views to materialize", multiple=True, + "--views", + "-v", + help="Feature views to materialize", + multiple=True, ) @click.pass_context def materialize_command( @@ -498,7 +501,10 @@ def materialize_command( @cli.command("materialize-incremental") @click.argument("end_ts") @click.option( - "--views", "-v", help="Feature views to incrementally materialize", multiple=True, + "--views", + "-v", + help="Feature views to incrementally materialize", + multiple=True, ) @click.pass_context def materialize_incremental_command(ctx: click.Context, end_ts: str, views: List[str]): @@ -560,7 +566,9 @@ def init_command(project_directory, minimal: bool, template: str): help="Specify a port for the server [default: 6566]", ) @click.option( - "--no-access-log", is_flag=True, help="Disable the Uvicorn access log.", + "--no-access-log", + is_flag=True, + help="Disable the Uvicorn access log.", ) @click.pass_context def serve_command(ctx: click.Context, host: str, port: int, no_access_log: bool): diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index 2f66f846bcb..b887fdb3b50 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -42,7 +42,10 @@ class KafkaOptions: """ def __init__( - self, bootstrap_servers: str, message_format: StreamFormat, topic: str, + self, + bootstrap_servers: str, + message_format: StreamFormat, + topic: str, ): self.bootstrap_servers = bootstrap_servers self.message_format = message_format @@ -91,7 +94,10 @@ class KinesisOptions: """ def __init__( - self, record_format: StreamFormat, region: str, stream_name: str, + self, + record_format: StreamFormat, + region: str, + stream_name: str, ): self.record_format = record_format self.region = region @@ -378,7 +384,9 @@ class RequestDataSource(DataSource): schema: Dict[str, ValueType] def __init__( - self, name: str, schema: Dict[str, ValueType], + self, + name: str, + schema: Dict[str, ValueType], ): """Creates a RequestDataSource object.""" super().__init__(name) diff --git a/sdk/python/feast/diff/infra_diff.py b/sdk/python/feast/diff/infra_diff.py index a09eaf39ebe..51bece33dd6 100644 --- a/sdk/python/feast/diff/infra_diff.py +++ b/sdk/python/feast/diff/infra_diff.py @@ -126,7 +126,8 @@ def diff_infra_protos( infra_objects_to_delete, infra_objects_to_add, ) = tag_infra_proto_objects_for_keep_delete_add( - current_infra_objects, new_infra_objects, + current_infra_objects, + new_infra_objects, ) for e in infra_objects_to_add: @@ -199,5 +200,10 @@ def diff_between( ) ) return InfraObjectDiff( - new.name, infra_object_type, current, new, property_diffs, transition, + new.name, + infra_object_type, + current, + new, + property_diffs, + transition, ) diff --git a/sdk/python/feast/diff/registry_diff.py b/sdk/python/feast/diff/registry_diff.py index 4558a149a5c..21cab0425b0 100644 --- a/sdk/python/feast/diff/registry_diff.py +++ b/sdk/python/feast/diff/registry_diff.py @@ -147,7 +147,9 @@ def diff_registry_objects( def extract_objects_for_keep_delete_update_add( - registry: Registry, current_project: str, desired_repo_contents: RepoContents, + registry: Registry, + current_project: str, + desired_repo_contents: RepoContents, ) -> Tuple[ Dict[FeastObjectType, Set[FeastObject]], Dict[FeastObjectType, Set[FeastObject]], @@ -194,7 +196,9 @@ def extract_objects_for_keep_delete_update_add( def diff_between( - registry: Registry, current_project: str, desired_repo_contents: RepoContents, + registry: Registry, + current_project: str, + desired_repo_contents: RepoContents, ) -> RegistryDiff: """ Returns the difference between the current and desired repo states. @@ -287,7 +291,9 @@ def apply_diff_to_registry( BaseFeatureView, feast_object_diff.current_feast_object ) registry.delete_feature_view( - feature_view_obj.name, project, commit=False, + feature_view_obj.name, + project, + commit=False, ) if feast_object_diff.transition_type in [ diff --git a/sdk/python/feast/driver_test_data.py b/sdk/python/feast/driver_test_data.py index 117bfcbd9cb..07018c9004d 100644 --- a/sdk/python/feast/driver_test_data.py +++ b/sdk/python/feast/driver_test_data.py @@ -30,7 +30,12 @@ def _convert_event_timestamp(event_timestamp: pd.Timestamp, t: EventTimestampTyp def create_orders_df( - customers, drivers, start_date, end_date, order_count, locations=None, + customers, + drivers, + start_date, + end_date, + order_count, + locations=None, ) -> pd.DataFrame: """ Example df generated by this function (if locations): diff --git a/sdk/python/feast/feature.py b/sdk/python/feast/feature.py index b37e0f562b2..81f99e8cb30 100644 --- a/sdk/python/feast/feature.py +++ b/sdk/python/feast/feature.py @@ -30,7 +30,10 @@ class Feature: """ def __init__( - self, name: str, dtype: ValueType, labels: Optional[Dict[str, str]] = None, + self, + name: str, + dtype: ValueType, + labels: Optional[Dict[str, str]] = None, ): """Creates a Feature object.""" self._name = name @@ -91,7 +94,9 @@ def to_proto(self) -> FeatureSpecProto: value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name) return FeatureSpecProto( - name=self.name, value_type=value_type, labels=self.labels, + name=self.name, + value_type=value_type, + labels=self.labels, ) @classmethod diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 73c2f14a638..ddc68419602 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -112,7 +112,9 @@ class FeatureStore: @log_exceptions def __init__( - self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None, + self, + repo_path: Optional[str] = None, + config: Optional[RepoConfig] = None, ): """ Creates a FeatureStore object. @@ -243,7 +245,9 @@ def list_request_feature_views( ) def _list_feature_views( - self, allow_cache: bool = False, hide_dummy_entity: bool = True, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, ) -> List[FeatureView]: feature_views = [] for fv in self._registry.list_feature_views( @@ -399,18 +403,20 @@ def delete_feature_view(self, name: str): @log_exceptions_and_usage def delete_feature_service(self, name: str): """ - Deletes a feature service. + Deletes a feature service. - Args: - name: Name of feature service. + Args: + name: Name of feature service. - Raises: - FeatureServiceNotFoundException: The feature view could not be found. - """ + Raises: + FeatureServiceNotFoundException: The feature view could not be found. + """ return self._registry.delete_feature_service(name, self.project) def _get_features( - self, features: Union[List[str], FeatureService], allow_cache: bool = False, + self, + features: Union[List[str], FeatureService], + allow_cache: bool = False, ) -> List[str]: _features = features @@ -872,7 +878,8 @@ def get_historical_features( for feature_name in odfv_request_data_schema.keys(): if feature_name not in entity_pd_df.columns: raise RequestDataNotFoundInEntityDfException( - feature_name=feature_name, feature_view_name=odfv.name, + feature_name=feature_name, + feature_view_name=odfv.name, ) _validate_feature_refs(_feature_refs, full_feature_names) @@ -903,17 +910,17 @@ def create_saved_dataset( feature_service: Optional[FeatureService] = None, ) -> SavedDataset: """ - Execute provided retrieval job and persist its outcome in given storage. - Storage type (eg, BigQuery or Redshift) must be the same as globally configured offline store. - After data successfully persisted saved dataset object with dataset metadata is committed to the registry. - Name for the saved dataset should be unique within project, since it's possible to overwrite previously stored dataset - with the same name. + Execute provided retrieval job and persist its outcome in given storage. + Storage type (eg, BigQuery or Redshift) must be the same as globally configured offline store. + After data successfully persisted saved dataset object with dataset metadata is committed to the registry. + Name for the saved dataset should be unique within project, since it's possible to overwrite previously stored dataset + with the same name. - Returns: - SavedDataset object with attached RetrievalJob + Returns: + SavedDataset object with attached RetrievalJob - Raises: - ValueError if given retrieval job doesn't have metadata + Raises: + ValueError if given retrieval job doesn't have metadata """ warnings.warn( "Saving dataset is an experimental feature. " @@ -985,7 +992,9 @@ def get_saved_dataset(self, name: str) -> SavedDataset: @log_exceptions_and_usage def materialize_incremental( - self, end_date: datetime, feature_views: Optional[List[str]] = None, + self, + end_date: datetime, + feature_views: Optional[List[str]] = None, ) -> None: """ Materialize incremental new data from the offline store into the online store. @@ -1072,7 +1081,10 @@ def tqdm_builder(length): ) self._registry.apply_materialization( - feature_view, self.project, start_date, end_date, + feature_view, + self.project, + start_date, + end_date, ) @log_exceptions_and_usage @@ -1159,7 +1171,10 @@ def tqdm_builder(length): ) self._registry.apply_materialization( - feature_view, self.project, start_date, end_date, + feature_view, + self.project, + start_date, + end_date, ) @log_exceptions_and_usage @@ -1427,12 +1442,17 @@ def _get_online_features( for table, requested_features in grouped_refs: # Get the correct set of entity values with the correct join keys. table_entity_values, idxs = self._get_unique_entities( - table, join_key_values, entity_name_to_join_key_map, + table, + join_key_values, + entity_name_to_join_key_map, ) # Fetch feature data for the minimum set of Entities. feature_data = self._read_from_online_store( - table_entity_values, provider, requested_features, table, + table_entity_values, + provider, + requested_features, + table, ) # Populate the result_rows with the Features from the OnlineStore inplace. @@ -1589,15 +1609,17 @@ def _get_unique_entities( join_key_values: Dict[str, List[Value]], entity_name_to_join_key_map: Dict[str, str], ) -> Tuple[Tuple[Dict[str, Value], ...], Tuple[List[int], ...]]: - """ Return the set of unique composite Entities for a Feature View and the indexes at which they appear. + """Return the set of unique composite Entities for a Feature View and the indexes at which they appear. - This method allows us to query the OnlineStore for data we need only once - rather than requesting and processing data for the same combination of - Entities multiple times. + This method allows us to query the OnlineStore for data we need only once + rather than requesting and processing data for the same combination of + Entities multiple times. """ # Get the correct set of entity values with the correct join keys. table_entity_values = self._get_table_entity_values( - table, entity_name_to_join_key_map, join_key_values, + table, + entity_name_to_join_key_map, + join_key_values, ) # Convert back to rowise. @@ -1629,14 +1651,14 @@ def _read_from_online_store( requested_features: List[str], table: FeatureView, ) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[Value]]]: - """ Read and process data from the OnlineStore for a given FeatureView. + """Read and process data from the OnlineStore for a given FeatureView. - This method guarantees that the order of the data in each element of the - List returned is the same as the order of `requested_features`. + This method guarantees that the order of the data in each element of the + List returned is the same as the order of `requested_features`. - This method assumes that `provider.online_read` returns data for each - combination of Entities in `entity_rows` in the same order as they - are provided. + This method assumes that `provider.online_read` returns data for each + combination of Entities in `entity_rows` in the same order as they + are provided. """ # Instantiate one EntityKeyProto per Entity. entity_key_protos = [ @@ -1693,23 +1715,23 @@ def _populate_response_from_feature_data( requested_features: Iterable[str], table: FeatureView, ): - """ Populate the GetOnlineFeaturesResponse with feature data. - - This method assumes that `_read_from_online_store` returns data for each - combination of Entities in `entity_rows` in the same order as they - are provided. - - Args: - feature_data: A list of data in Protobuf form which was retrieved from the OnlineStore. - indexes: A list of indexes which should be the same length as `feature_data`. Each list - of indexes corresponds to a set of result rows in `online_features_response`. - online_features_response: The object to populate. - full_feature_names: A boolean that provides the option to add the feature view prefixes to the feature names, - changing them from the format "feature" to "feature_view__feature" (e.g., "daily_transactions" changes to - "customer_fv__daily_transactions"). - requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the - data in `feature_data`. - table: The FeatureView that `feature_data` was retrieved from. + """Populate the GetOnlineFeaturesResponse with feature data. + + This method assumes that `_read_from_online_store` returns data for each + combination of Entities in `entity_rows` in the same order as they + are provided. + + Args: + feature_data: A list of data in Protobuf form which was retrieved from the OnlineStore. + indexes: A list of indexes which should be the same length as `feature_data`. Each list + of indexes corresponds to a set of result rows in `online_features_response`. + online_features_response: The object to populate. + full_feature_names: A boolean that provides the option to add the feature view prefixes to the feature names, + changing them from the format "feature" to "feature_view__feature" (e.g., "daily_transactions" changes to + "customer_fv__daily_transactions"). + requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the + data in `feature_data`. + table: The FeatureView that `feature_data` was retrieved from. """ # Add the feature names to the response. requested_feature_refs = [ @@ -1782,7 +1804,8 @@ def _augment_response_with_on_demand_transforms( for odfv_name, _feature_refs in odfv_feature_refs.items(): odfv = requested_odfv_map[odfv_name] transformed_features_df = odfv.get_transformed_features_df( - initial_response_df, full_feature_names, + initial_response_df, + full_feature_names, ) selected_subset = [ f for f in transformed_features_df.columns if f in _feature_refs @@ -1973,7 +1996,7 @@ def _group_feature_refs( List[Tuple[RequestFeatureView, List[str]]], Set[str], ]: - """ Get list of feature views and corresponding feature names based on feature references""" + """Get list of feature views and corresponding feature names based on feature references""" # view name to view proto view_index = {view.projection.name_to_use(): view for view in all_feature_views} @@ -2046,7 +2069,7 @@ def _print_materialization_log( def _validate_feature_views(feature_views: List[BaseFeatureView]): - """ Verify feature views have case-insensitively unique names""" + """Verify feature views have case-insensitively unique names""" fv_names = set() for fv in feature_views: case_insensitive_fv_name = fv.name.lower() @@ -2061,7 +2084,7 @@ def _validate_feature_views(feature_views: List[BaseFeatureView]): def _validate_data_sources(data_sources: List[DataSource]): - """ Verify data sources have case-insensitively unique names""" + """Verify data sources have case-insensitively unique names""" ds_names = set() for fv in data_sources: case_insensitive_ds_name = fv.name.lower() diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index b81c0afa510..8a5019530fa 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -44,7 +44,9 @@ DUMMY_ENTITY_NAME = "__dummy" DUMMY_ENTITY_VAL = "" DUMMY_ENTITY = Entity( - name=DUMMY_ENTITY_NAME, join_key=DUMMY_ENTITY_ID, value_type=ValueType.STRING, + name=DUMMY_ENTITY_NAME, + join_key=DUMMY_ENTITY_ID, + value_type=ValueType.STRING, ) diff --git a/sdk/python/feast/go_server.py b/sdk/python/feast/go_server.py index 1fcbab61f08..379e52a64c6 100644 --- a/sdk/python/feast/go_server.py +++ b/sdk/python/feast/go_server.py @@ -84,7 +84,11 @@ def connect(self) -> bool: else feast.__path__[0] + "/binaries/server" ) # Automatically reconnect with go subprocess exits - self._process = Popen([executable], cwd=cwd, env=env,) + self._process = Popen( + [executable], + cwd=cwd, + env=env, + ) channel = grpc.insecure_channel(f"unix:{self.sock_file}") self._client = ServingServiceStub(channel) diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py index 51c4e9d78ec..bca46cee5ee 100644 --- a/sdk/python/feast/inference.py +++ b/sdk/python/feast/inference.py @@ -45,7 +45,10 @@ def update_entities_with_inferred_types_from_feature_views( # get entity information from information extracted from the view batch source extracted_entity_name_type_pairs = list( - filter(lambda tup: tup[0] == entity.join_key, col_names_and_types,) + filter( + lambda tup: tup[0] == entity.join_key, + col_names_and_types, + ) ) if len(extracted_entity_name_type_pairs) == 0: # Doesn't mention inference error because would also be an error without inferencing @@ -54,8 +57,10 @@ def update_entities_with_inferred_types_from_feature_views( its entity's name.""" ) - inferred_value_type = view.batch_source.source_datatype_to_feast_value_type()( - extracted_entity_name_type_pairs[0][1] + inferred_value_type = ( + view.batch_source.source_datatype_to_feast_value_type()( + extracted_entity_name_type_pairs[0][1] + ) ) if ( diff --git a/sdk/python/feast/infra/aws.py b/sdk/python/feast/infra/aws.py index b7cc61de0e5..e1ec507bca7 100644 --- a/sdk/python/feast/infra/aws.py +++ b/sdk/python/feast/infra/aws.py @@ -196,7 +196,10 @@ def _deploy_feature_server(self, project: str, image_uri: str): @log_exceptions_and_usage(provider="AwsProvider") def teardown_infra( - self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], + self, + project: str, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ) -> None: if self.online_store: self.online_store.teardown(self.repo_config, tables, entities) diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 44e62d6ad1a..fe703916740 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -56,7 +56,7 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel): - """ Offline store config for GCP BigQuery """ + """Offline store config for GCP BigQuery""" type: Literal["bigquery"] = "bigquery" """ Offline store type selector""" @@ -121,7 +121,10 @@ def pull_latest_from_table_or_query( # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized return BigQueryRetrievalJob( - query=query, client=client, config=config, full_feature_names=False, + query=query, + client=client, + config=config, + full_feature_names=False, ) @staticmethod @@ -151,7 +154,10 @@ def pull_all_from_table_or_query( WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') """ return BigQueryRetrievalJob( - query=query, client=client, config=config, full_feature_names=False, + query=query, + client=client, + config=config, + full_feature_names=False, ) @staticmethod @@ -182,20 +188,27 @@ def get_historical_features( config.offline_store.location, ) - entity_schema = _get_entity_schema(client=client, entity_df=entity_df,) + entity_schema = _get_entity_schema( + client=client, + entity_df=entity_df, + ) - entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema + entity_df_event_timestamp_col = ( + offline_utils.infer_event_timestamp_from_entity_df(entity_schema) ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, entity_df_event_timestamp_col, client, + entity_df, + entity_df_event_timestamp_col, + client, ) @contextlib.contextmanager def query_generator() -> Iterator[str]: _upload_entity_df( - client=client, table_name=table_reference, entity_df=entity_df, + client=client, + table_name=table_reference, + entity_df=entity_df, ) expected_join_keys = offline_utils.get_expected_join_keys( @@ -437,7 +450,9 @@ def _get_table_reference_for_new_entity( def _upload_entity_df( - client: Client, table_name: str, entity_df: Union[pd.DataFrame, str], + client: Client, + table_name: str, + entity_df: Union[pd.DataFrame, str], ) -> Table: """Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table""" diff --git a/sdk/python/feast/infra/offline_stores/bigquery_source.py b/sdk/python/feast/infra/offline_stores/bigquery_source.py index 92b6939fc3a..0025949527c 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery_source.py +++ b/sdk/python/feast/infra/offline_stores/bigquery_source.py @@ -27,20 +27,20 @@ def __init__( ): """Create a BigQuerySource from an existing table or query. - Args: - table (optional): The BigQuery table where features can be found. - table_ref (optional): (Deprecated) The BigQuery table where features can be found. - event_timestamp_column: Event timestamp column used for point in time joins of feature values. - created_timestamp_column (optional): Timestamp column when row was created, used for deduplicating rows. - field_mapping: A dictionary mapping of column names in this data source to feature names in a feature table - or view. Only used for feature columns, not entities or timestamp columns. - date_partition_column (optional): Timestamp column used for partitioning. - query (optional): SQL query to execute to generate data for this data source. - name (optional): Name for the source. Defaults to the table_ref if not specified. - Example: - >>> from feast import BigQuerySource - >>> my_bigquery_source = BigQuerySource(table="gcp_project:bq_dataset.bq_table") - """ + Args: + table (optional): The BigQuery table where features can be found. + table_ref (optional): (Deprecated) The BigQuery table where features can be found. + event_timestamp_column: Event timestamp column used for point in time joins of feature values. + created_timestamp_column (optional): Timestamp column when row was created, used for deduplicating rows. + field_mapping: A dictionary mapping of column names in this data source to feature names in a feature table + or view. Only used for feature columns, not entities or timestamp columns. + date_partition_column (optional): Timestamp column used for partitioning. + query (optional): SQL query to execute to generate data for this data source. + name (optional): Name for the source. Defaults to the table_ref if not specified. + Example: + >>> from feast import BigQuerySource + >>> my_bigquery_source = BigQuerySource(table="gcp_project:bq_dataset.bq_table") + """ if table is None and table_ref is None and query is None: raise ValueError('No "table" or "query" argument provided.') if not table and table_ref: @@ -186,7 +186,9 @@ class BigQueryOptions: """ def __init__( - self, table_ref: Optional[str], query: Optional[str], + self, + table_ref: Optional[str], + query: Optional[str], ): self._table_ref = table_ref self._query = query @@ -247,7 +249,8 @@ def to_proto(self) -> DataSourceProto.BigQueryOptions: """ bigquery_options_proto = DataSourceProto.BigQueryOptions( - table_ref=self.table_ref, query=self.query, + table_ref=self.table_ref, + query=self.query, ) return bigquery_options_proto diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index fb40edd16d7..2bbe198d308 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -134,7 +134,9 @@ def get_historical_features( entity_schema=entity_schema, ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, event_timestamp_col, spark_session, + entity_df, + event_timestamp_col, + spark_session, ) expected_join_keys = offline_utils.get_expected_join_keys( diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 3ffdf6eda0c..1db9d0887e5 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -79,7 +79,10 @@ def __init__( ) self.spark_options = SparkOptions( - table=table, query=query, path=path, file_format=file_format, + table=table, + query=query, + path=path, + file_format=file_format, ) @property diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index c71f0c3ff74..6a4b5f72011 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -31,7 +31,7 @@ class FileOfflineStoreConfig(FeastConfigBaseModel): - """ Offline store config for local (file-based) store """ + """Offline store config for local (file-based) store""" type: Literal["file"] = "file" """ Offline store type selector""" @@ -79,7 +79,8 @@ def persist(self, storage: SavedDatasetStorage): assert isinstance(storage, SavedDatasetFileStorage) filesystem, path = FileSource.create_filesystem_and_path( - storage.file_options.file_url, storage.file_options.s3_endpoint_override, + storage.file_options.file_url, + storage.file_options.s3_endpoint_override, ) if path.endswith(".parquet"): @@ -332,7 +333,8 @@ def evaluate_offline_job(): # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized return FileRetrievalJob( - evaluation_function=evaluate_offline_job, full_feature_names=False, + evaluation_function=evaluate_offline_job, + full_feature_names=False, ) @staticmethod @@ -360,7 +362,8 @@ def pull_all_from_table_or_query( def _get_entity_df_event_timestamp_range( - entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, + entity_df: Union[pd.DataFrame, str], + entity_df_event_timestamp_col: str, ) -> Tuple[datetime, datetime]: if not isinstance(entity_df, pd.DataFrame): raise ValueError( @@ -390,7 +393,10 @@ def _read_datasource(data_source) -> dd.DataFrame: else None ) - return dd.read_parquet(data_source.path, storage_options=storage_options,) + return dd.read_parquet( + data_source.path, + storage_options=storage_options, + ) def _field_mapping( @@ -440,7 +446,8 @@ def _field_mapping( # Make sure to not have duplicated columns if entity_df_event_timestamp_col == event_timestamp_column: df_to_join = _run_dask_field_mapping( - df_to_join, {event_timestamp_column: f"__{event_timestamp_column}"}, + df_to_join, + {event_timestamp_column: f"__{event_timestamp_column}"}, ) event_timestamp_column = f"__{event_timestamp_column}" @@ -553,7 +560,9 @@ def _drop_duplicates( df_to_join = df_to_join.persist() df_to_join = df_to_join.drop_duplicates( - all_join_keys + [entity_df_event_timestamp_col], keep="last", ignore_index=True, + all_join_keys + [entity_df_event_timestamp_col], + keep="last", + ignore_index=True, ) return df_to_join.persist() diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index e5937712f69..9b00773c34c 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -81,7 +81,8 @@ def to_df( for odfv in self.on_demand_feature_views: features_df = features_df.join( odfv.get_transformed_features_df( - features_df, self.full_feature_names, + features_df, + self.full_feature_names, ) ) @@ -125,7 +126,8 @@ def to_arrow( for odfv in self.on_demand_feature_views: features_df = features_df.join( odfv.get_transformed_features_df( - features_df, self.full_feature_names, + features_df, + self.full_feature_names, ) ) diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 3efd45bc741..828ffba46ab 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -39,7 +39,7 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel): - """ Offline store config for AWS Redshift """ + """Offline store config for AWS Redshift""" type: Literal["redshift"] = "redshift" """ Offline store type selector""" @@ -185,12 +185,15 @@ def get_historical_features( entity_df, redshift_client, config, s3_resource ) - entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema + entity_df_event_timestamp_col = ( + offline_utils.infer_event_timestamp_from_entity_df(entity_schema) ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, entity_df_event_timestamp_col, redshift_client, config, + entity_df, + entity_df_event_timestamp_col, + redshift_client, + config, ) @contextlib.contextmanager @@ -341,7 +344,7 @@ def _to_arrow_internal(self) -> pa.Table: @log_exceptions_and_usage def to_s3(self) -> str: - """ Export dataset to S3 in Parquet format and return path """ + """Export dataset to S3 in Parquet format and return path""" if self.on_demand_feature_views: transformed_df = self.to_df() aws_utils.upload_df_to_s3(self._s3_resource, self._s3_path, transformed_df) @@ -361,7 +364,7 @@ def to_s3(self) -> str: @log_exceptions_and_usage def to_redshift(self, table_name: str) -> None: - """ Save dataset as a new Redshift table """ + """Save dataset as a new Redshift table""" if self.on_demand_feature_views: transformed_df = self.to_df() aws_utils.upload_df_to_redshift( diff --git a/sdk/python/feast/infra/offline_stores/redshift_source.py b/sdk/python/feast/infra/offline_stores/redshift_source.py index 8573396aca4..d3afb532d9a 100644 --- a/sdk/python/feast/infra/offline_stores/redshift_source.py +++ b/sdk/python/feast/infra/offline_stores/redshift_source.py @@ -282,7 +282,9 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions: A RedshiftOptionsProto protobuf. """ redshift_options_proto = DataSourceProto.RedshiftOptions( - table=self.table, schema=self.schema, query=self.query, + table=self.table, + schema=self.schema, + query=self.query, ) return redshift_options_proto diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index ee8cd71ce05..b1bc2b1d136 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -54,7 +54,7 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel): - """ Offline store config for Snowflake """ + """Offline store config for Snowflake""" type: Literal["snowflake.offline"] = "snowflake.offline" """ Offline store type selector""" @@ -208,12 +208,14 @@ def get_historical_features( entity_schema = _get_entity_schema(entity_df, snowflake_conn, config) - entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( - entity_schema + entity_df_event_timestamp_col = ( + offline_utils.infer_event_timestamp_from_entity_df(entity_schema) ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, entity_df_event_timestamp_col, snowflake_conn, + entity_df, + entity_df_event_timestamp_col, + snowflake_conn, ) @contextlib.contextmanager @@ -336,7 +338,7 @@ def _to_arrow_internal(self) -> pa.Table: ) def to_snowflake(self, table_name: str) -> None: - """ Save dataset as a new Snowflake table """ + """Save dataset as a new Snowflake table""" if self.on_demand_feature_views is not None: transformed_df = self.to_df() diff --git a/sdk/python/feast/infra/online_stores/datastore.py b/sdk/python/feast/infra/online_stores/datastore.py index e7621ab88f8..c495d324f50 100644 --- a/sdk/python/feast/infra/online_stores/datastore.py +++ b/sdk/python/feast/infra/online_stores/datastore.py @@ -55,7 +55,7 @@ class DatastoreOnlineStoreConfig(FeastConfigBaseModel): - """ Online store config for GCP Datastore """ + """Online store config for GCP Datastore""" type: Literal["datastore"] = "datastore" """ Online store type selector""" @@ -197,7 +197,12 @@ def _write_minibatch( document_id = compute_entity_id(entity_key) key = client.key( - "Project", project, "Table", table.name, "Row", document_id, + "Project", + project, + "Table", + table.name, + "Row", + document_id, ) entity = datastore.Entity( @@ -316,7 +321,10 @@ def _initialize_client( project_id: Optional[str], namespace: Optional[str] ) -> datastore.Client: try: - client = datastore.Client(project=project_id, namespace=namespace,) + client = datastore.Client( + project=project_id, + namespace=namespace, + ) return client except DefaultCredentialsError as e: raise FeastProviderLoginError( @@ -392,7 +400,8 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: @staticmethod def from_proto(datastore_table_proto: DatastoreTableProto) -> Any: datastore_table = DatastoreTable( - project=datastore_table_proto.project, name=datastore_table_proto.name, + project=datastore_table_proto.project, + name=datastore_table_proto.name, ) # Distinguish between null and empty string, since project_id and namespace are StringValues. diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index c161a5b9554..6e212ef55f1 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -304,7 +304,8 @@ def _get_table_name( def _delete_table_idempotent( - dynamodb_resource, table_name: str, + dynamodb_resource, + table_name: str, ): try: table = dynamodb_resource.Table(table_name) @@ -358,7 +359,8 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: @staticmethod def from_proto(dynamodb_table_proto: DynamoDBTableProto) -> Any: return DynamoDBTable( - name=dynamodb_table_proto.name, region=dynamodb_table_proto.region, + name=dynamodb_table_proto.name, + region=dynamodb_table_proto.region, ) def update(self): diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index e65aab4e7be..23af3c13e8e 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -37,7 +37,7 @@ class SqliteOnlineStoreConfig(FeastConfigBaseModel): - """ Online store config for local (SQLite-based) store """ + """Online store config for local (SQLite-based) store""" type: Literal[ "sqlite", "feast.infra.online_stores.sqlite.SqliteOnlineStore" @@ -230,7 +230,8 @@ def teardown( def _initialize_conn(db_path: str): Path(db_path).parent.mkdir(exist_ok=True) return sqlite3.connect( - db_path, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, + db_path, + detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, ) @@ -278,7 +279,10 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: @staticmethod def from_proto(sqlite_table_proto: SqliteTableProto) -> Any: - return SqliteTable(path=sqlite_table_proto.path, name=sqlite_table_proto.name,) + return SqliteTable( + path=sqlite_table_proto.path, + name=sqlite_table_proto.name, + ) def update(self): self.conn.execute( diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 3468b9dc927..b630aca0457 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -66,7 +66,10 @@ def update_infra( ) def teardown_infra( - self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], + self, + project: str, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ) -> None: set_usage_attribute("provider", self.__class__.__name__) if self.online_store: @@ -102,7 +105,10 @@ def online_read( return result def ingest_df( - self, feature_view: FeatureView, entities: List[Entity], df: pandas.DataFrame, + self, + feature_view: FeatureView, + entities: List[Entity], + df: pandas.DataFrame, ): set_usage_attribute("provider", self.__class__.__name__) table = pa.Table.from_pandas(df) diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index b3f10292423..952e6be4549 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -79,7 +79,10 @@ def plan_infra( @abc.abstractmethod def teardown_infra( - self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], + self, + project: str, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): """ Tear down all cloud resources for a repo. @@ -119,7 +122,10 @@ def online_write_batch( ... def ingest_df( - self, feature_view: FeatureView, entities: List[Entity], df: pandas.DataFrame, + self, + feature_view: FeatureView, + entities: List[Entity], + df: pandas.DataFrame, ): """ Ingests a DataFrame directly into the online store @@ -183,7 +189,7 @@ def retrieve_saved_dataset( Returns: RetrievalJob object, which is lazy wrapper for actual query performed under the hood. - """ + """ ... def get_feature_server_endpoint(self) -> Optional[str]: @@ -302,7 +308,8 @@ def _get_column_names( def _run_field_mapping( - table: pyarrow.Table, field_mapping: Dict[str, str], + table: pyarrow.Table, + field_mapping: Dict[str, str], ) -> pyarrow.Table: # run field mapping in the forward direction cols = table.column_names @@ -314,7 +321,8 @@ def _run_field_mapping( def _run_dask_field_mapping( - table: dd.DataFrame, field_mapping: Dict[str, str], + table: dd.DataFrame, + field_mapping: Dict[str, str], ): if field_mapping: # run field mapping in the forward direction diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index e7f628795d8..cfe50fbc2ec 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -87,7 +87,10 @@ def execute_redshift_statement_async( """ try: return redshift_data_client.execute_statement( - ClusterIdentifier=cluster_id, Database=database, DbUser=user, Sql=query, + ClusterIdentifier=cluster_id, + Database=database, + DbUser=user, + Sql=query, ) except ClientError as e: if e.response["Error"]["Code"] == "ValidationException": @@ -151,11 +154,15 @@ def execute_redshift_statement( def get_redshift_statement_result(redshift_data_client, statement_id: str) -> dict: - """ Get the Redshift statement result """ + """Get the Redshift statement result""" return redshift_data_client.get_statement_result(Id=statement_id) -def upload_df_to_s3(s3_resource, s3_path: str, df: pd.DataFrame,) -> None: +def upload_df_to_s3( + s3_resource, + s3_path: str, + df: pd.DataFrame, +) -> None: """Uploads a Pandas DataFrame to S3 as a parquet file Args: @@ -301,12 +308,16 @@ def temporarily_upload_df_to_redshift( # Clean up the uploaded Redshift table execute_redshift_statement( - redshift_data_client, cluster_id, database, user, f"DROP TABLE {table_name}", + redshift_data_client, + cluster_id, + database, + user, + f"DROP TABLE {table_name}", ) def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str): - """ Download the S3 directory to a local disk """ + """Download the S3 directory to a local disk""" bucket_obj = s3_resource.Bucket(bucket) if key != "" and not key.endswith("/"): key = key + "/" @@ -318,7 +329,7 @@ def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str): def delete_s3_directory(s3_resource, bucket: str, key: str): - """ Delete S3 directory recursively """ + """Delete S3 directory recursively""" bucket_obj = s3_resource.Bucket(bucket) if key != "" and not key.endswith("/"): key = key + "/" @@ -365,11 +376,17 @@ def unload_redshift_query_to_pa( iam_role: str, query: str, ) -> pa.Table: - """ Unload Redshift Query results to S3 and get the results in PyArrow Table format """ + """Unload Redshift Query results to S3 and get the results in PyArrow Table format""" bucket, key = get_bucket_and_key(s3_path) execute_redshift_query_and_unload_to_s3( - redshift_data_client, cluster_id, database, user, s3_path, iam_role, query, + redshift_data_client, + cluster_id, + database, + user, + s3_path, + iam_role, + query, ) with tempfile.TemporaryDirectory() as temp_dir: @@ -388,7 +405,7 @@ def unload_redshift_query_to_df( iam_role: str, query: str, ) -> pd.DataFrame: - """ Unload Redshift Query results to S3 and get the results in Pandas DataFrame format """ + """Unload Redshift Query results to S3 and get the results in Pandas DataFrame format""" table = unload_redshift_query_to_pa( redshift_data_client, cluster_id, diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index f0eaf987ef3..d2130480646 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -162,7 +162,8 @@ def to_proto(self) -> OnDemandFeatureViewProto: features=[feature.to_proto() for feature in self.features], inputs=inputs, user_defined_function=UserDefinedFunctionProto( - name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True), + name=self.udf.__name__, + body=dill.dumps(self.udf, recurse=True), ), description=self.description, tags=self.tags, @@ -242,7 +243,9 @@ def get_request_data_schema(self) -> Dict[str, ValueType]: return schema def get_transformed_features_df( - self, df_with_features: pd.DataFrame, full_feature_names: bool = False, + self, + df_with_features: pd.DataFrame, + full_feature_names: bool = False, ) -> pd.DataFrame: # Apply on demand transformations columns_to_cleanup = [] diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index 0bf73fcd24c..36cc243d670 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -789,7 +789,10 @@ def delete_entity(self, name: str, project: str, commit: bool = True): raise EntityNotFoundException(name, project) def apply_saved_dataset( - self, saved_dataset: SavedDataset, project: str, commit: bool = True, + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, ): """ Registers a single entity with Feast diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 4fc46b91faf..7ab4a0e60e1 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -57,7 +57,7 @@ class FeastBaseModel(BaseModel): - """ Feast Pydantic Configuration Class """ + """Feast Pydantic Configuration Class""" class Config: arbitrary_types_allowed = True @@ -65,7 +65,7 @@ class Config: class FeastConfigBaseModel(BaseModel): - """ Feast Pydantic Configuration Class """ + """Feast Pydantic Configuration Class""" class Config: arbitrary_types_allowed = True @@ -73,7 +73,7 @@ class Config: class RegistryConfig(FeastBaseModel): - """ Metadata Store Configuration. Configuration that relates to reading from and writing to the Feast registry.""" + """Metadata Store Configuration. Configuration that relates to reading from and writing to the Feast registry.""" registry_store_type: Optional[StrictStr] """ str: Provider name or a class name that implements RegistryStore. """ @@ -89,7 +89,7 @@ class RegistryConfig(FeastBaseModel): class RepoConfig(FeastBaseModel): - """ Repo config. Typically loaded from `feature_store.yaml` """ + """Repo config. Typically loaded from `feature_store.yaml`""" registry: Union[StrictStr, RegistryConfig] = "data/registry.db" """ str: Path to metadata store. Can be a local path, or remote object storage path, e.g. a GCS URI """ @@ -190,7 +190,8 @@ def _validate_online_store_config(cls, values): online_config_class(**values["online_store"]) except ValidationError as e: raise ValidationError( - [ErrorWrapper(e, loc="online_store")], model=RepoConfig, + [ErrorWrapper(e, loc="online_store")], + model=RepoConfig, ) return values @@ -224,7 +225,8 @@ def _validate_offline_store_config(cls, values): offline_config_class(**values["offline_store"]) except ValidationError as e: raise ValidationError( - [ErrorWrapper(e, loc="offline_store")], model=RepoConfig, + [ErrorWrapper(e, loc="offline_store")], + model=RepoConfig, ) return values @@ -258,7 +260,8 @@ def _validate_feature_server_config(cls, values): feature_server_config_class(**values["feature_server"]) except ValidationError as e: raise ValidationError( - [ErrorWrapper(e, loc="feature_server")], model=RepoConfig, + [ErrorWrapper(e, loc="feature_server")], + model=RepoConfig, ) return values @@ -295,7 +298,12 @@ def write_to_path(self, repo_path: Path): config_path = repo_path / "feature_store.yaml" with open(config_path, mode="w") as f: yaml.dump( - yaml.safe_load(self.json(exclude={"repo_path"}, exclude_unset=True,)), + yaml.safe_load( + self.json( + exclude={"repo_path"}, + exclude_unset=True, + ) + ), f, sort_keys=False, ) diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py index 3457aa48866..5800e6c62b4 100644 --- a/sdk/python/feast/repo_operations.py +++ b/sdk/python/feast/repo_operations.py @@ -94,7 +94,7 @@ def get_repo_files(repo_root: Path) -> List[Path]: def parse_repo(repo_root: Path) -> RepoContents: - """ Collect feature table definitions from feature repo """ + """Collect feature table definitions from feature repo""" res = RepoContents( data_sources=set(), entities=set(), @@ -264,7 +264,7 @@ def teardown(repo_config: RepoConfig, repo_path: Path): @log_exceptions_and_usage def registry_dump(repo_config: RepoConfig, repo_path: Path): - """ For debugging only: output contents of the metadata registry """ + """For debugging only: output contents of the metadata registry""" registry_config = repo_config.get_registry_config() project = repo_config.project registry = Registry(registry_config=registry_config, repo_path=repo_path) diff --git a/sdk/python/feast/templates/aws/test.py b/sdk/python/feast/templates/aws/test.py index 07410954f7b..3d223e8f266 100644 --- a/sdk/python/feast/templates/aws/test.py +++ b/sdk/python/feast/templates/aws/test.py @@ -54,7 +54,8 @@ def main(): # Retrieve features from the online store (Firestore) online_features = fs.get_online_features( - features=features, entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], + features=features, + entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], ).to_dict() print() diff --git a/sdk/python/feast/templates/gcp/test.py b/sdk/python/feast/templates/gcp/test.py index 538334044bf..8ff11bda5c7 100644 --- a/sdk/python/feast/templates/gcp/test.py +++ b/sdk/python/feast/templates/gcp/test.py @@ -54,7 +54,8 @@ def main(): # Retrieve features from the online store (Firestore) online_features = fs.get_online_features( - features=features, entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], + features=features, + entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], ).to_dict() print() diff --git a/sdk/python/feast/templates/local/example.py b/sdk/python/feast/templates/local/example.py index 6ab549b8c5b..455b8b780c9 100644 --- a/sdk/python/feast/templates/local/example.py +++ b/sdk/python/feast/templates/local/example.py @@ -15,7 +15,11 @@ # Define an entity for the driver. You can think of entity as a primary key used to # fetch features. -driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",) +driver = Entity( + name="driver_id", + value_type=ValueType.INT64, + description="driver id", +) # Our parquet files contain sample data that includes a driver_id column, timestamps and # three feature column. Here we define a Feature View that will allow us to serve this diff --git a/sdk/python/feast/templates/snowflake/test.py b/sdk/python/feast/templates/snowflake/test.py index 32aa6380d51..3c33f6aefda 100644 --- a/sdk/python/feast/templates/snowflake/test.py +++ b/sdk/python/feast/templates/snowflake/test.py @@ -54,7 +54,8 @@ def main(): # Retrieve features from the online store online_features = fs.get_online_features( - features=features, entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], + features=features, + entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], ).to_dict() print() diff --git a/sdk/python/feast/templates/spark/example.py b/sdk/python/feast/templates/spark/example.py index b1741fd6883..1cdffaf9097 100644 --- a/sdk/python/feast/templates/spark/example.py +++ b/sdk/python/feast/templates/spark/example.py @@ -15,9 +15,15 @@ # Entity definitions -driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",) +driver = Entity( + name="driver_id", + value_type=ValueType.INT64, + description="driver id", +) customer = Entity( - name="customer_id", value_type=ValueType.INT64, description="customer id", + name="customer_id", + value_type=ValueType.INT64, + description="customer id", ) # Sources diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 713b952d092..9798faf508b 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -580,10 +580,10 @@ def pa_to_redshift_value_type(pa_type: pyarrow.DataType) -> str: def _non_empty_value(value: Any) -> bool: """ - Check that there's enough data we can use for type inference. - If primitive type - just checking that it's not None - If iterable - checking that there's some elements (len > 0) - String is special case: "" - empty string is considered non empty + Check that there's enough data we can use for type inference. + If primitive type - just checking that it's not None + If iterable - checking that there's some elements (len > 0) + String is special case: "" - empty string is considered non empty """ return value is not None and ( not isinstance(value, Sized) or len(value) > 0 or isinstance(value, str) diff --git a/sdk/python/feast/usage.py b/sdk/python/feast/usage.py index 6a6a7146ce7..90b659479d1 100644 --- a/sdk/python/feast/usage.py +++ b/sdk/python/feast/usage.py @@ -224,27 +224,27 @@ def tracing_span(name): def log_exceptions_and_usage(*args, **attrs): """ - This function decorator enables three components: - 1. Error tracking - 2. Usage statistic collection - 3. Time profiling - - This data is being collected, anonymized and sent to Feast Developers. - All events from nested decorated functions are being grouped into single event - to build comprehensive context useful for profiling and error tracking. - - Usage example (will result in one output event): - @log_exceptions_and_usage - def fn(...): - nested() - - @log_exceptions_and_usage(attr='value') - def nested(...): - deeply_nested() - - @log_exceptions_and_usage(attr2='value2', sample=RateSampler(rate=0.1)) - def deeply_nested(...): - ... + This function decorator enables three components: + 1. Error tracking + 2. Usage statistic collection + 3. Time profiling + + This data is being collected, anonymized and sent to Feast Developers. + All events from nested decorated functions are being grouped into single event + to build comprehensive context useful for profiling and error tracking. + + Usage example (will result in one output event): + @log_exceptions_and_usage + def fn(...): + nested() + + @log_exceptions_and_usage(attr='value') + def nested(...): + deeply_nested() + + @log_exceptions_and_usage(attr2='value2', sample=RateSampler(rate=0.1)) + def deeply_nested(...): + ... """ sampler = attrs.pop("sampler", AlwaysSampler()) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 890c48fcbe2..e521338680c 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -4,7 +4,7 @@ def make_tzaware(t: datetime) -> datetime: - """ We assume tz-naive datetimes are UTC """ + """We assume tz-naive datetimes are UTC""" if t.tzinfo is None: return t.replace(tzinfo=utc) else: diff --git a/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py b/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py index 3d2a253dc0e..5aec2ac1305 100644 --- a/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py +++ b/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py @@ -59,5 +59,7 @@ def test_online_retrieval(environment, universal_data_sources, benchmark): unprefixed_feature_refs.remove("conv_rate_plus_val_to_add") benchmark( - fs.get_online_features, features=feature_refs, entity_rows=entity_rows, + fs.get_online_features, + features=feature_refs, + entity_rows=entity_rows, ) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 1254604a0be..928364cac50 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -70,10 +70,16 @@ def pytest_addoption(parser): help="Run tests with external dependencies", ) parser.addoption( - "--benchmark", action="store_true", default=False, help="Run benchmark tests", + "--benchmark", + action="store_true", + default=False, + help="Run benchmark tests", ) parser.addoption( - "--universal", action="store_true", default=False, help="Run universal tests", + "--universal", + action="store_true", + default=False, + help="Run universal tests", ) parser.addoption( "--goserver", @@ -254,7 +260,9 @@ def cleanup(): def e2e_data_sources(environment: Environment, request): df = create_dataset() data_source = environment.data_source_creator.create_data_source( - df, environment.feature_store.project, field_mapping={"ts_1": "ts"}, + df, + environment.feature_store.project, + field_mapping={"ts_1": "ts"}, ) def cleanup(): diff --git a/sdk/python/tests/doctest/test_all.py b/sdk/python/tests/doctest/test_all.py index bf3a09db1e2..c2d3db351b0 100644 --- a/sdk/python/tests/doctest/test_all.py +++ b/sdk/python/tests/doctest/test_all.py @@ -17,7 +17,9 @@ def setup_feature_store(): init_repo("feature_repo", "local") fs = FeatureStore(repo_path="feature_repo") driver = Entity( - name="driver_id", value_type=ValueType.INT64, description="driver id", + name="driver_id", + value_type=ValueType.INT64, + description="driver id", ) driver_hourly_stats = FileSource( path="feature_repo/data/driver_stats.parquet", @@ -89,7 +91,8 @@ def test_docstrings(): setup_function() test_suite = doctest.DocTestSuite( - temp_module, optionflags=doctest.ELLIPSIS, + temp_module, + optionflags=doctest.ELLIPSIS, ) if test_suite.countTestCases() > 0: result = unittest.TextTestRunner(sys.stdout).run(test_suite) diff --git a/sdk/python/tests/example_repos/example_feature_repo_2.py b/sdk/python/tests/example_repos/example_feature_repo_2.py index 96da67ac92d..714bd8c8b50 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_2.py +++ b/sdk/python/tests/example_repos/example_feature_repo_2.py @@ -8,7 +8,11 @@ created_timestamp_column="created", ) -driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",) +driver = Entity( + name="driver_id", + value_type=ValueType.INT64, + description="driver id", +) driver_hourly_stats_view = FeatureView( diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 1d4ce7d6cb6..574bbd12b11 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -29,7 +29,10 @@ def update_infra( pass def teardown_infra( - self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], + self, + project: str, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): pass diff --git a/sdk/python/tests/integration/e2e/test_usage_e2e.py b/sdk/python/tests/integration/e2e/test_usage_e2e.py index c7b62b3a5da..12c1eb86281 100644 --- a/sdk/python/tests/integration/e2e/test_usage_e2e.py +++ b/sdk/python/tests/integration/e2e/test_usage_e2e.py @@ -136,7 +136,7 @@ def test_exception_usage_off(dummy_exporter, enabling_toggle): def _reload_feast(): - """ After changing environment need to reload modules and rerun usage decorators """ + """After changing environment need to reload modules and rerun usage decorators""" modules = ( "feast.infra.local", "feast.infra.online_stores.sqlite", diff --git a/sdk/python/tests/integration/e2e/test_validation.py b/sdk/python/tests/integration/e2e/test_validation.py index 77706b74d10..d7ea69904be 100644 --- a/sdk/python/tests/integration/e2e/test_validation.py +++ b/sdk/python/tests/integration/e2e/test_validation.py @@ -72,7 +72,8 @@ def test_historical_retrieval_with_validation(environment, universal_data_source ) reference_job = store.get_historical_features( - entity_df=entity_df, features=_features, + entity_df=entity_df, + features=_features, ) store.create_saved_dataset( @@ -81,7 +82,10 @@ def test_historical_retrieval_with_validation(environment, universal_data_source storage=environment.data_source_creator.create_saved_dataset_destination(), ) - job = store.get_historical_features(entity_df=entity_df, features=_features,) + job = store.get_historical_features( + entity_df=entity_df, + features=_features, + ) # if validation pass there will be no exceptions on this point job.to_df( @@ -106,7 +110,8 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so ) reference_job = store.get_historical_features( - entity_df=entity_df, features=_features, + entity_df=entity_df, + features=_features, ) store.create_saved_dataset( @@ -115,7 +120,10 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so storage=environment.data_source_creator.create_saved_dataset_destination(), ) - job = store.get_historical_features(entity_df=entity_df, features=_features,) + job = store.get_historical_features( + entity_df=entity_df, + features=_features, + ) with pytest.raises(ValidationFailed) as exc_info: job.to_df( diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 9cbe11f86b9..eb60ac3852e 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -117,7 +117,10 @@ FULL_REPO_CONFIGS = DEFAULT_FULL_REPO_CONFIGS GO_REPO_CONFIGS = [ - IntegrationTestRepoConfig(online_store=REDIS_CONFIG, go_feature_server=True,), + IntegrationTestRepoConfig( + online_store=REDIS_CONFIG, + go_feature_server=True, + ), ] @@ -271,7 +274,8 @@ def values(self): def construct_universal_feature_views( - data_sources: UniversalDataSources, with_odfv: bool = True, + data_sources: UniversalDataSources, + with_odfv: bool = True, ) -> UniversalFeatureViews: driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) return UniversalFeatureViews( @@ -367,7 +371,8 @@ def construct_test_environment( # Note: even if it's a local feature server, the repo config does not have this configured feature_server = None registry = RegistryConfig( - path=str(Path(repo_dir_name) / "registry.db"), cache_ttl_seconds=1, + path=str(Path(repo_dir_name) / "registry.db"), + cache_ttl_seconds=1, ) config = RepoConfig( registry=registry, diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py index baa3db6afc1..6740015a910 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py @@ -106,7 +106,9 @@ def _upload_parquet_file(self, df, file_name, minio_endpoint): if not client.bucket_exists(self.bucket): client.make_bucket(self.bucket) client.fput_object( - self.bucket, file_name, self.f.name, + self.bucket, + file_name, + self.f.name, ) def create_data_source( diff --git a/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py b/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py index 6066009a6d6..b9f8cefa780 100644 --- a/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py +++ b/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py @@ -17,7 +17,9 @@ @pytest.mark.skip( reason="No way to run this test today. Credentials conflict with real AWS credentials in CI" ) -def test_registration_and_retrieval_from_custom_s3_endpoint(universal_data_sources,): +def test_registration_and_retrieval_from_custom_s3_endpoint( + universal_data_sources, +): config = IntegrationTestRepoConfig( offline_store_creator="tests.integration.feature_repos.universal.data_sources.file.S3FileDataSourceCreator" ) diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index 541033433f6..390f8e40c0e 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -209,7 +209,10 @@ def get_expected_training_df( (f"global_stats__{k}" if full_feature_names else k): global_record.get( k, None ) - for k in ("num_rides", "avg_ride_length",) + for k in ( + "num_rides", + "avg_ride_length", + ) } ) @@ -855,5 +858,7 @@ def assert_frame_equal(expected_df, actual_df, keys): ) pd_assert_frame_equal( - expected_df, actual_df, check_dtype=False, + expected_df, + actual_df, + check_dtype=False, ) diff --git a/sdk/python/tests/integration/online_store/test_e2e_local.py b/sdk/python/tests/integration/online_store/test_e2e_local.py index d14bc5ab1cc..c1aa10900ae 100644 --- a/sdk/python/tests/integration/online_store/test_e2e_local.py +++ b/sdk/python/tests/integration/online_store/test_e2e_local.py @@ -12,7 +12,7 @@ def _get_last_feature_row(df: pd.DataFrame, driver_id, max_date: datetime): - """ Manually extract last feature value from a dataframe for a given driver_id with up to `max_date` date """ + """Manually extract last feature value from a dataframe for a given driver_id with up to `max_date` date""" filtered = df[ (df["driver_id"] == driver_id) & (df["event_timestamp"] < max_date.replace(tzinfo=utc)) diff --git a/sdk/python/tests/integration/registration/test_cli.py b/sdk/python/tests/integration/registration/test_cli.py index 655e53e7593..e721bb28c97 100644 --- a/sdk/python/tests/integration/registration/test_cli.py +++ b/sdk/python/tests/integration/registration/test_cli.py @@ -86,7 +86,8 @@ def test_universal_cli(environment: Environment): assertpy.assert_that(result.returncode).is_equal_to(0) assertpy.assert_that(fs.list_feature_views()).is_length(4) result = runner.run( - ["data-sources", "describe", "customer_profile_source"], cwd=repo_path, + ["data-sources", "describe", "customer_profile_source"], + cwd=repo_path, ) assertpy.assert_that(result.returncode).is_equal_to(0) assertpy.assert_that(fs.list_data_sources()).is_length(4) diff --git a/sdk/python/tests/integration/registration/test_feature_store.py b/sdk/python/tests/integration/registration/test_feature_store.py index d5496a6de75..8234c904829 100644 --- a/sdk/python/tests/integration/registration/test_feature_store.py +++ b/sdk/python/tests/integration/registration/test_feature_store.py @@ -88,7 +88,8 @@ def feature_store_with_s3_registry(): @pytest.mark.parametrize( - "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", + [lazy_fixture("feature_store_with_local_registry")], ) def test_apply_entity_success(test_feature_store): entity = Entity( @@ -160,7 +161,8 @@ def test_apply_entity_integration(test_feature_store): @pytest.mark.parametrize( - "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", + [lazy_fixture("feature_store_with_local_registry")], ) def test_apply_feature_view_success(test_feature_store): # Create Feature Views @@ -211,7 +213,8 @@ def test_apply_feature_view_success(test_feature_store): @pytest.mark.integration @pytest.mark.parametrize( - "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", + [lazy_fixture("feature_store_with_local_registry")], ) @pytest.mark.parametrize("dataframe_source", [lazy_fixture("simple_dataset_1")]) def test_feature_view_inference_success(test_feature_store, dataframe_source): @@ -351,7 +354,8 @@ def test_apply_feature_view_integration(test_feature_store): @pytest.mark.parametrize( - "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", + [lazy_fixture("feature_store_with_local_registry")], ) def test_apply_object_and_read(test_feature_store): assert isinstance(test_feature_store, FeatureStore) @@ -427,7 +431,8 @@ def test_apply_remote_repo(): @pytest.mark.parametrize( - "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", + [lazy_fixture("feature_store_with_local_registry")], ) @pytest.mark.parametrize("dataframe_source", [lazy_fixture("simple_dataset_1")]) def test_reapply_feature_view_success(test_feature_store, dataframe_source): @@ -485,7 +490,7 @@ def test_reapply_feature_view_success(test_feature_store, dataframe_source): def test_apply_conflicting_featureview_names(feature_store_with_local_registry): - """ Test applying feature views with non-case-insensitively unique names""" + """Test applying feature views with non-case-insensitively unique names""" driver_stats = FeatureView( name="driver_hourly_stats", diff --git a/sdk/python/tests/integration/registration/test_inference.py b/sdk/python/tests/integration/registration/test_inference.py index 0ea6276669d..2f8b932acb2 100644 --- a/sdk/python/tests/integration/registration/test_inference.py +++ b/sdk/python/tests/integration/registration/test_inference.py @@ -45,10 +45,16 @@ def test_update_entities_with_inferred_types_from_feature_views( ) as file_source_2: fv1 = FeatureView( - name="fv1", entities=["id"], batch_source=file_source, ttl=None, + name="fv1", + entities=["id"], + batch_source=file_source, + ttl=None, ) fv2 = FeatureView( - name="fv2", entities=["id"], batch_source=file_source_2, ttl=None, + name="fv2", + entities=["id"], + batch_source=file_source_2, + ttl=None, ) actual_1 = Entity(name="id", join_key="id_join_key") @@ -151,7 +157,8 @@ def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_so data_source.event_timestamp_column = None update_data_sources_with_inferred_event_timestamp_col( - data_sources_copy.values(), RepoConfig(provider="local", project="test"), + data_sources_copy.values(), + RepoConfig(provider="local", project="test"), ) actual_event_timestamp_cols = [ source.event_timestamp_column for source in data_sources_copy.values() diff --git a/sdk/python/tests/integration/registration/test_registry.py b/sdk/python/tests/integration/registration/test_registry.py index 535497634d4..3a510b235a0 100644 --- a/sdk/python/tests/integration/registration/test_registry.py +++ b/sdk/python/tests/integration/registration/test_registry.py @@ -67,7 +67,8 @@ def s3_registry() -> Registry: @pytest.mark.parametrize( - "test_registry", [lazy_fixture("local_registry")], + "test_registry", + [lazy_fixture("local_registry")], ) def test_apply_entity_success(test_registry): entity = Entity( @@ -116,7 +117,8 @@ def test_apply_entity_success(test_registry): @pytest.mark.integration @pytest.mark.parametrize( - "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], + "test_registry", + [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], ) def test_apply_entity_integration(test_registry): entity = Entity( @@ -160,7 +162,8 @@ def test_apply_entity_integration(test_registry): @pytest.mark.parametrize( - "test_registry", [lazy_fixture("local_registry")], + "test_registry", + [lazy_fixture("local_registry")], ) def test_apply_feature_view_success(test_registry): # Create Feature Views @@ -234,7 +237,8 @@ def test_apply_feature_view_success(test_registry): @pytest.mark.parametrize( - "test_registry", [lazy_fixture("local_registry")], + "test_registry", + [lazy_fixture("local_registry")], ) def test_modify_feature_views_success(test_registry): # Create Feature Views @@ -355,7 +359,8 @@ def odfv1(feature_df: pd.DataFrame) -> pd.DataFrame: @pytest.mark.integration @pytest.mark.parametrize( - "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], + "test_registry", + [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], ) def test_apply_feature_view_integration(test_registry): # Create Feature Views @@ -430,7 +435,8 @@ def test_apply_feature_view_integration(test_registry): @pytest.mark.integration @pytest.mark.parametrize( - "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], + "test_registry", + [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], ) def test_apply_data_source(test_registry: Registry): # Create Feature Views diff --git a/sdk/python/tests/integration/registration/test_universal_types.py b/sdk/python/tests/integration/registration/test_universal_types.py index 59ca119f98a..7ec2e183e1e 100644 --- a/sdk/python/tests/integration/registration/test_universal_types.py +++ b/sdk/python/tests/integration/registration/test_universal_types.py @@ -91,9 +91,11 @@ def online_types_test_fixtures(request): def get_fixtures(request): config: TypeTestConfig = request.param # Lower case needed because Redshift lower-cases all table names - test_project_id = f"{config.entity_type}{config.feature_dtype}{config.feature_is_list}".replace( - ".", "" - ).lower() + test_project_id = ( + f"{config.entity_type}{config.feature_dtype}{config.feature_is_list}".replace( + ".", "" + ).lower() + ) type_test_environment = construct_test_environment( test_repo_config=config.test_repo_config, test_suite_name=f"test_{test_project_id}", @@ -182,7 +184,8 @@ def test_feature_get_historical_features_types_match(offline_types_test_fixtures ts - timedelta(hours=2), ] historical_features = fs.get_historical_features( - entity_df=entity_df, features=features, + entity_df=entity_df, + features=features, ) # Note: Pandas doesn't play well with nan values in ints. BQ will also coerce to floats if there are NaNs historical_features_df = historical_features.to_df() @@ -230,7 +233,8 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures): driver_id_value = "1" if config.entity_type == ValueType.STRING else 1 online_features = fs.get_online_features( - features=features, entity_rows=[{"driver": driver_id_value}], + features=features, + entity_rows=[{"driver": driver_id_value}], ).to_dict() feature_list_dtype_to_expected_online_response_value_type = { @@ -286,7 +290,11 @@ def create_feature_view( elif feature_dtype == "datetime": value_type = ValueType.UNIX_TIMESTAMP - return driver_feature_view(data_source, name=name, value_type=value_type,) + return driver_feature_view( + data_source, + name=name, + value_type=value_type, + ) def assert_expected_historical_feature_types( diff --git a/sdk/python/tests/unit/diff/test_registry_diff.py b/sdk/python/tests/unit/diff/test_registry_diff.py index 0322ab47abf..58b7a4c00b6 100644 --- a/sdk/python/tests/unit/diff/test_registry_diff.py +++ b/sdk/python/tests/unit/diff/test_registry_diff.py @@ -11,10 +11,16 @@ def test_tag_objects_for_keep_delete_update_add(simple_dataset_1): df=simple_dataset_1, event_timestamp_column="ts_1" ) as file_source: to_delete = FeatureView( - name="to_delete", entities=["id"], batch_source=file_source, ttl=None, + name="to_delete", + entities=["id"], + batch_source=file_source, + ttl=None, ) unchanged_fv = FeatureView( - name="fv1", entities=["id"], batch_source=file_source, ttl=None, + name="fv1", + entities=["id"], + batch_source=file_source, + ttl=None, ) pre_changed = FeatureView( name="fv2", @@ -31,7 +37,10 @@ def test_tag_objects_for_keep_delete_update_add(simple_dataset_1): tags={"when": "after"}, ) to_add = FeatureView( - name="to_add", entities=["id"], batch_source=file_source, ttl=None, + name="to_add", + entities=["id"], + batch_source=file_source, + ttl=None, ) keep, delete, update, add = tag_objects_for_keep_delete_update_add( diff --git a/sdk/python/tests/unit/test_usage.py b/sdk/python/tests/unit/test_usage.py index 13988d32642..ca842474307 100644 --- a/sdk/python/tests/unit/test_usage.py +++ b/sdk/python/tests/unit/test_usage.py @@ -234,4 +234,4 @@ def call_length_ms(call): return ( datetime.datetime.fromisoformat(call["end"]) - datetime.datetime.fromisoformat(call["start"]) - ).total_seconds() * 10 ** 3 + ).total_seconds() * 10**3 diff --git a/sdk/python/tests/utils/data_source_utils.py b/sdk/python/tests/utils/data_source_utils.py index 5a5baceef07..d2ad337e21a 100644 --- a/sdk/python/tests/utils/data_source_utils.py +++ b/sdk/python/tests/utils/data_source_utils.py @@ -43,7 +43,8 @@ def simple_bq_source_using_table_ref_arg( job.result() return BigQuerySource( - table_ref=table_ref, event_timestamp_column=event_timestamp_column, + table_ref=table_ref, + event_timestamp_column=event_timestamp_column, ) diff --git a/sdk/python/tests/utils/online_read_write_test.py b/sdk/python/tests/utils/online_read_write_test.py index fe03217dabe..39846cd2ad4 100644 --- a/sdk/python/tests/utils/online_read_write_test.py +++ b/sdk/python/tests/utils/online_read_write_test.py @@ -22,7 +22,7 @@ def basic_rw_test( ) def _driver_rw_test(event_ts, created_ts, write, expect_read): - """ A helper function to write values and read them back """ + """A helper function to write values and read them back""" write_lat, write_lon = write expect_lat, expect_lon = expect_read provider.online_write_batch( From 9afe66d21dc98529206fa89050f08eaa0e1ae366 Mon Sep 17 00:00:00 2001 From: Achal Shah Date: Tue, 29 Mar 2022 10:16:35 -0700 Subject: [PATCH 3/3] make format-python Signed-off-by: Achal Shah --- sdk/python/feast/cli.py | 14 ++---- sdk/python/feast/data_source.py | 14 ++---- sdk/python/feast/diff/infra_diff.py | 10 +---- sdk/python/feast/diff/registry_diff.py | 12 ++--- sdk/python/feast/driver_test_data.py | 7 +-- sdk/python/feast/feature.py | 9 +--- sdk/python/feast/feature_store.py | 45 +++++-------------- sdk/python/feast/feature_view.py | 4 +- sdk/python/feast/go_server.py | 6 +-- sdk/python/feast/inference.py | 11 ++--- sdk/python/feast/infra/aws.py | 5 +-- .../feast/infra/offline_stores/bigquery.py | 31 ++++--------- .../infra/offline_stores/bigquery_source.py | 7 +-- .../contrib/spark_offline_store/spark.py | 4 +- .../spark_offline_store/spark_source.py | 5 +-- sdk/python/feast/infra/offline_stores/file.py | 21 +++------ .../infra/offline_stores/offline_store.py | 6 +-- .../feast/infra/offline_stores/redshift.py | 9 ++-- .../infra/offline_stores/redshift_source.py | 4 +- .../feast/infra/offline_stores/snowflake.py | 8 ++-- .../feast/infra/online_stores/datastore.py | 15 ++----- .../feast/infra/online_stores/dynamodb.py | 6 +-- .../feast/infra/online_stores/sqlite.py | 8 +--- .../feast/infra/passthrough_provider.py | 10 +---- sdk/python/feast/infra/provider.py | 16 ++----- sdk/python/feast/infra/utils/aws_utils.py | 25 ++--------- sdk/python/feast/on_demand_feature_view.py | 7 +-- sdk/python/feast/registry.py | 5 +-- sdk/python/feast/repo_config.py | 16 ++----- sdk/python/feast/templates/aws/test.py | 3 +- sdk/python/feast/templates/gcp/test.py | 3 +- sdk/python/feast/templates/local/example.py | 6 +-- sdk/python/feast/templates/snowflake/test.py | 3 +- sdk/python/feast/templates/spark/example.py | 10 +---- ...st_benchmark_universal_online_retrieval.py | 4 +- sdk/python/tests/conftest.py | 14 ++---- sdk/python/tests/doctest/test_all.py | 7 +-- .../example_repos/example_feature_repo_2.py | 6 +-- sdk/python/tests/foo_provider.py | 5 +-- .../tests/integration/e2e/test_validation.py | 16 ++----- .../feature_repos/repo_configuration.py | 11 ++--- .../universal/data_sources/file.py | 4 +- .../offline_store/test_s3_custom_endpoint.py | 4 +- .../test_universal_historical_retrieval.py | 9 +--- .../integration/registration/test_cli.py | 3 +- .../registration/test_feature_store.py | 15 +++---- .../registration/test_inference.py | 13 ++---- .../integration/registration/test_registry.py | 18 +++----- .../registration/test_universal_types.py | 20 +++------ .../tests/unit/diff/test_registry_diff.py | 15 ++----- sdk/python/tests/unit/test_usage.py | 2 +- sdk/python/tests/utils/data_source_utils.py | 3 +- 52 files changed, 133 insertions(+), 401 deletions(-) diff --git a/sdk/python/feast/cli.py b/sdk/python/feast/cli.py index 8da6bfd7ced..d2a71bc561b 100644 --- a/sdk/python/feast/cli.py +++ b/sdk/python/feast/cli.py @@ -471,10 +471,7 @@ def registry_dump_command(ctx: click.Context): @click.argument("start_ts") @click.argument("end_ts") @click.option( - "--views", - "-v", - help="Feature views to materialize", - multiple=True, + "--views", "-v", help="Feature views to materialize", multiple=True, ) @click.pass_context def materialize_command( @@ -501,10 +498,7 @@ def materialize_command( @cli.command("materialize-incremental") @click.argument("end_ts") @click.option( - "--views", - "-v", - help="Feature views to incrementally materialize", - multiple=True, + "--views", "-v", help="Feature views to incrementally materialize", multiple=True, ) @click.pass_context def materialize_incremental_command(ctx: click.Context, end_ts: str, views: List[str]): @@ -566,9 +560,7 @@ def init_command(project_directory, minimal: bool, template: str): help="Specify a port for the server [default: 6566]", ) @click.option( - "--no-access-log", - is_flag=True, - help="Disable the Uvicorn access log.", + "--no-access-log", is_flag=True, help="Disable the Uvicorn access log.", ) @click.pass_context def serve_command(ctx: click.Context, host: str, port: int, no_access_log: bool): diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index b887fdb3b50..2f66f846bcb 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -42,10 +42,7 @@ class KafkaOptions: """ def __init__( - self, - bootstrap_servers: str, - message_format: StreamFormat, - topic: str, + self, bootstrap_servers: str, message_format: StreamFormat, topic: str, ): self.bootstrap_servers = bootstrap_servers self.message_format = message_format @@ -94,10 +91,7 @@ class KinesisOptions: """ def __init__( - self, - record_format: StreamFormat, - region: str, - stream_name: str, + self, record_format: StreamFormat, region: str, stream_name: str, ): self.record_format = record_format self.region = region @@ -384,9 +378,7 @@ class RequestDataSource(DataSource): schema: Dict[str, ValueType] def __init__( - self, - name: str, - schema: Dict[str, ValueType], + self, name: str, schema: Dict[str, ValueType], ): """Creates a RequestDataSource object.""" super().__init__(name) diff --git a/sdk/python/feast/diff/infra_diff.py b/sdk/python/feast/diff/infra_diff.py index 51bece33dd6..a09eaf39ebe 100644 --- a/sdk/python/feast/diff/infra_diff.py +++ b/sdk/python/feast/diff/infra_diff.py @@ -126,8 +126,7 @@ def diff_infra_protos( infra_objects_to_delete, infra_objects_to_add, ) = tag_infra_proto_objects_for_keep_delete_add( - current_infra_objects, - new_infra_objects, + current_infra_objects, new_infra_objects, ) for e in infra_objects_to_add: @@ -200,10 +199,5 @@ def diff_between( ) ) return InfraObjectDiff( - new.name, - infra_object_type, - current, - new, - property_diffs, - transition, + new.name, infra_object_type, current, new, property_diffs, transition, ) diff --git a/sdk/python/feast/diff/registry_diff.py b/sdk/python/feast/diff/registry_diff.py index 21cab0425b0..4558a149a5c 100644 --- a/sdk/python/feast/diff/registry_diff.py +++ b/sdk/python/feast/diff/registry_diff.py @@ -147,9 +147,7 @@ def diff_registry_objects( def extract_objects_for_keep_delete_update_add( - registry: Registry, - current_project: str, - desired_repo_contents: RepoContents, + registry: Registry, current_project: str, desired_repo_contents: RepoContents, ) -> Tuple[ Dict[FeastObjectType, Set[FeastObject]], Dict[FeastObjectType, Set[FeastObject]], @@ -196,9 +194,7 @@ def extract_objects_for_keep_delete_update_add( def diff_between( - registry: Registry, - current_project: str, - desired_repo_contents: RepoContents, + registry: Registry, current_project: str, desired_repo_contents: RepoContents, ) -> RegistryDiff: """ Returns the difference between the current and desired repo states. @@ -291,9 +287,7 @@ def apply_diff_to_registry( BaseFeatureView, feast_object_diff.current_feast_object ) registry.delete_feature_view( - feature_view_obj.name, - project, - commit=False, + feature_view_obj.name, project, commit=False, ) if feast_object_diff.transition_type in [ diff --git a/sdk/python/feast/driver_test_data.py b/sdk/python/feast/driver_test_data.py index 07018c9004d..117bfcbd9cb 100644 --- a/sdk/python/feast/driver_test_data.py +++ b/sdk/python/feast/driver_test_data.py @@ -30,12 +30,7 @@ def _convert_event_timestamp(event_timestamp: pd.Timestamp, t: EventTimestampTyp def create_orders_df( - customers, - drivers, - start_date, - end_date, - order_count, - locations=None, + customers, drivers, start_date, end_date, order_count, locations=None, ) -> pd.DataFrame: """ Example df generated by this function (if locations): diff --git a/sdk/python/feast/feature.py b/sdk/python/feast/feature.py index 81f99e8cb30..b37e0f562b2 100644 --- a/sdk/python/feast/feature.py +++ b/sdk/python/feast/feature.py @@ -30,10 +30,7 @@ class Feature: """ def __init__( - self, - name: str, - dtype: ValueType, - labels: Optional[Dict[str, str]] = None, + self, name: str, dtype: ValueType, labels: Optional[Dict[str, str]] = None, ): """Creates a Feature object.""" self._name = name @@ -94,9 +91,7 @@ def to_proto(self) -> FeatureSpecProto: value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name) return FeatureSpecProto( - name=self.name, - value_type=value_type, - labels=self.labels, + name=self.name, value_type=value_type, labels=self.labels, ) @classmethod diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index ddc68419602..b2a00e4a736 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -112,9 +112,7 @@ class FeatureStore: @log_exceptions def __init__( - self, - repo_path: Optional[str] = None, - config: Optional[RepoConfig] = None, + self, repo_path: Optional[str] = None, config: Optional[RepoConfig] = None, ): """ Creates a FeatureStore object. @@ -245,9 +243,7 @@ def list_request_feature_views( ) def _list_feature_views( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, + self, allow_cache: bool = False, hide_dummy_entity: bool = True, ) -> List[FeatureView]: feature_views = [] for fv in self._registry.list_feature_views( @@ -414,9 +410,7 @@ def delete_feature_service(self, name: str): return self._registry.delete_feature_service(name, self.project) def _get_features( - self, - features: Union[List[str], FeatureService], - allow_cache: bool = False, + self, features: Union[List[str], FeatureService], allow_cache: bool = False, ) -> List[str]: _features = features @@ -878,8 +872,7 @@ def get_historical_features( for feature_name in odfv_request_data_schema.keys(): if feature_name not in entity_pd_df.columns: raise RequestDataNotFoundInEntityDfException( - feature_name=feature_name, - feature_view_name=odfv.name, + feature_name=feature_name, feature_view_name=odfv.name, ) _validate_feature_refs(_feature_refs, full_feature_names) @@ -992,9 +985,7 @@ def get_saved_dataset(self, name: str) -> SavedDataset: @log_exceptions_and_usage def materialize_incremental( - self, - end_date: datetime, - feature_views: Optional[List[str]] = None, + self, end_date: datetime, feature_views: Optional[List[str]] = None, ) -> None: """ Materialize incremental new data from the offline store into the online store. @@ -1081,10 +1072,7 @@ def tqdm_builder(length): ) self._registry.apply_materialization( - feature_view, - self.project, - start_date, - end_date, + feature_view, self.project, start_date, end_date, ) @log_exceptions_and_usage @@ -1171,10 +1159,7 @@ def tqdm_builder(length): ) self._registry.apply_materialization( - feature_view, - self.project, - start_date, - end_date, + feature_view, self.project, start_date, end_date, ) @log_exceptions_and_usage @@ -1442,17 +1427,12 @@ def _get_online_features( for table, requested_features in grouped_refs: # Get the correct set of entity values with the correct join keys. table_entity_values, idxs = self._get_unique_entities( - table, - join_key_values, - entity_name_to_join_key_map, + table, join_key_values, entity_name_to_join_key_map, ) # Fetch feature data for the minimum set of Entities. feature_data = self._read_from_online_store( - table_entity_values, - provider, - requested_features, - table, + table_entity_values, provider, requested_features, table, ) # Populate the result_rows with the Features from the OnlineStore inplace. @@ -1617,9 +1597,7 @@ def _get_unique_entities( """ # Get the correct set of entity values with the correct join keys. table_entity_values = self._get_table_entity_values( - table, - entity_name_to_join_key_map, - join_key_values, + table, entity_name_to_join_key_map, join_key_values, ) # Convert back to rowise. @@ -1804,8 +1782,7 @@ def _augment_response_with_on_demand_transforms( for odfv_name, _feature_refs in odfv_feature_refs.items(): odfv = requested_odfv_map[odfv_name] transformed_features_df = odfv.get_transformed_features_df( - initial_response_df, - full_feature_names, + initial_response_df, full_feature_names, ) selected_subset = [ f for f in transformed_features_df.columns if f in _feature_refs diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 8a5019530fa..b81c0afa510 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -44,9 +44,7 @@ DUMMY_ENTITY_NAME = "__dummy" DUMMY_ENTITY_VAL = "" DUMMY_ENTITY = Entity( - name=DUMMY_ENTITY_NAME, - join_key=DUMMY_ENTITY_ID, - value_type=ValueType.STRING, + name=DUMMY_ENTITY_NAME, join_key=DUMMY_ENTITY_ID, value_type=ValueType.STRING, ) diff --git a/sdk/python/feast/go_server.py b/sdk/python/feast/go_server.py index 379e52a64c6..1fcbab61f08 100644 --- a/sdk/python/feast/go_server.py +++ b/sdk/python/feast/go_server.py @@ -84,11 +84,7 @@ def connect(self) -> bool: else feast.__path__[0] + "/binaries/server" ) # Automatically reconnect with go subprocess exits - self._process = Popen( - [executable], - cwd=cwd, - env=env, - ) + self._process = Popen([executable], cwd=cwd, env=env,) channel = grpc.insecure_channel(f"unix:{self.sock_file}") self._client = ServingServiceStub(channel) diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py index bca46cee5ee..51c4e9d78ec 100644 --- a/sdk/python/feast/inference.py +++ b/sdk/python/feast/inference.py @@ -45,10 +45,7 @@ def update_entities_with_inferred_types_from_feature_views( # get entity information from information extracted from the view batch source extracted_entity_name_type_pairs = list( - filter( - lambda tup: tup[0] == entity.join_key, - col_names_and_types, - ) + filter(lambda tup: tup[0] == entity.join_key, col_names_and_types,) ) if len(extracted_entity_name_type_pairs) == 0: # Doesn't mention inference error because would also be an error without inferencing @@ -57,10 +54,8 @@ def update_entities_with_inferred_types_from_feature_views( its entity's name.""" ) - inferred_value_type = ( - view.batch_source.source_datatype_to_feast_value_type()( - extracted_entity_name_type_pairs[0][1] - ) + inferred_value_type = view.batch_source.source_datatype_to_feast_value_type()( + extracted_entity_name_type_pairs[0][1] ) if ( diff --git a/sdk/python/feast/infra/aws.py b/sdk/python/feast/infra/aws.py index e1ec507bca7..b7cc61de0e5 100644 --- a/sdk/python/feast/infra/aws.py +++ b/sdk/python/feast/infra/aws.py @@ -196,10 +196,7 @@ def _deploy_feature_server(self, project: str, image_uri: str): @log_exceptions_and_usage(provider="AwsProvider") def teardown_infra( - self, - project: str, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], ) -> None: if self.online_store: self.online_store.teardown(self.repo_config, tables, entities) diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index fe703916740..6c0d56562ca 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -121,10 +121,7 @@ def pull_latest_from_table_or_query( # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized return BigQueryRetrievalJob( - query=query, - client=client, - config=config, - full_feature_names=False, + query=query, client=client, config=config, full_feature_names=False, ) @staticmethod @@ -154,10 +151,7 @@ def pull_all_from_table_or_query( WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') """ return BigQueryRetrievalJob( - query=query, - client=client, - config=config, - full_feature_names=False, + query=query, client=client, config=config, full_feature_names=False, ) @staticmethod @@ -188,27 +182,20 @@ def get_historical_features( config.offline_store.location, ) - entity_schema = _get_entity_schema( - client=client, - entity_df=entity_df, - ) + entity_schema = _get_entity_schema(client=client, entity_df=entity_df,) - entity_df_event_timestamp_col = ( - offline_utils.infer_event_timestamp_from_entity_df(entity_schema) + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - entity_df_event_timestamp_col, - client, + entity_df, entity_df_event_timestamp_col, client, ) @contextlib.contextmanager def query_generator() -> Iterator[str]: _upload_entity_df( - client=client, - table_name=table_reference, - entity_df=entity_df, + client=client, table_name=table_reference, entity_df=entity_df, ) expected_join_keys = offline_utils.get_expected_join_keys( @@ -450,9 +437,7 @@ def _get_table_reference_for_new_entity( def _upload_entity_df( - client: Client, - table_name: str, - entity_df: Union[pd.DataFrame, str], + client: Client, table_name: str, entity_df: Union[pd.DataFrame, str], ) -> Table: """Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table""" diff --git a/sdk/python/feast/infra/offline_stores/bigquery_source.py b/sdk/python/feast/infra/offline_stores/bigquery_source.py index 0025949527c..24593581c7c 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery_source.py +++ b/sdk/python/feast/infra/offline_stores/bigquery_source.py @@ -186,9 +186,7 @@ class BigQueryOptions: """ def __init__( - self, - table_ref: Optional[str], - query: Optional[str], + self, table_ref: Optional[str], query: Optional[str], ): self._table_ref = table_ref self._query = query @@ -249,8 +247,7 @@ def to_proto(self) -> DataSourceProto.BigQueryOptions: """ bigquery_options_proto = DataSourceProto.BigQueryOptions( - table_ref=self.table_ref, - query=self.query, + table_ref=self.table_ref, query=self.query, ) return bigquery_options_proto diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 2bbe198d308..fb40edd16d7 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -134,9 +134,7 @@ def get_historical_features( entity_schema=entity_schema, ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - event_timestamp_col, - spark_session, + entity_df, event_timestamp_col, spark_session, ) expected_join_keys = offline_utils.get_expected_join_keys( diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 1db9d0887e5..3ffdf6eda0c 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -79,10 +79,7 @@ def __init__( ) self.spark_options = SparkOptions( - table=table, - query=query, - path=path, - file_format=file_format, + table=table, query=query, path=path, file_format=file_format, ) @property diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 6a4b5f72011..b39e8f5c2de 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -79,8 +79,7 @@ def persist(self, storage: SavedDatasetStorage): assert isinstance(storage, SavedDatasetFileStorage) filesystem, path = FileSource.create_filesystem_and_path( - storage.file_options.file_url, - storage.file_options.s3_endpoint_override, + storage.file_options.file_url, storage.file_options.s3_endpoint_override, ) if path.endswith(".parquet"): @@ -333,8 +332,7 @@ def evaluate_offline_job(): # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized return FileRetrievalJob( - evaluation_function=evaluate_offline_job, - full_feature_names=False, + evaluation_function=evaluate_offline_job, full_feature_names=False, ) @staticmethod @@ -362,8 +360,7 @@ def pull_all_from_table_or_query( def _get_entity_df_event_timestamp_range( - entity_df: Union[pd.DataFrame, str], - entity_df_event_timestamp_col: str, + entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, ) -> Tuple[datetime, datetime]: if not isinstance(entity_df, pd.DataFrame): raise ValueError( @@ -393,10 +390,7 @@ def _read_datasource(data_source) -> dd.DataFrame: else None ) - return dd.read_parquet( - data_source.path, - storage_options=storage_options, - ) + return dd.read_parquet(data_source.path, storage_options=storage_options,) def _field_mapping( @@ -446,8 +440,7 @@ def _field_mapping( # Make sure to not have duplicated columns if entity_df_event_timestamp_col == event_timestamp_column: df_to_join = _run_dask_field_mapping( - df_to_join, - {event_timestamp_column: f"__{event_timestamp_column}"}, + df_to_join, {event_timestamp_column: f"__{event_timestamp_column}"}, ) event_timestamp_column = f"__{event_timestamp_column}" @@ -560,9 +553,7 @@ def _drop_duplicates( df_to_join = df_to_join.persist() df_to_join = df_to_join.drop_duplicates( - all_join_keys + [entity_df_event_timestamp_col], - keep="last", - ignore_index=True, + all_join_keys + [entity_df_event_timestamp_col], keep="last", ignore_index=True, ) return df_to_join.persist() diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 9b00773c34c..e5937712f69 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -81,8 +81,7 @@ def to_df( for odfv in self.on_demand_feature_views: features_df = features_df.join( odfv.get_transformed_features_df( - features_df, - self.full_feature_names, + features_df, self.full_feature_names, ) ) @@ -126,8 +125,7 @@ def to_arrow( for odfv in self.on_demand_feature_views: features_df = features_df.join( odfv.get_transformed_features_df( - features_df, - self.full_feature_names, + features_df, self.full_feature_names, ) ) diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 828ffba46ab..e67cf13f5c4 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -185,15 +185,12 @@ def get_historical_features( entity_df, redshift_client, config, s3_resource ) - entity_df_event_timestamp_col = ( - offline_utils.infer_event_timestamp_from_entity_df(entity_schema) + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - entity_df_event_timestamp_col, - redshift_client, - config, + entity_df, entity_df_event_timestamp_col, redshift_client, config, ) @contextlib.contextmanager diff --git a/sdk/python/feast/infra/offline_stores/redshift_source.py b/sdk/python/feast/infra/offline_stores/redshift_source.py index d3afb532d9a..8573396aca4 100644 --- a/sdk/python/feast/infra/offline_stores/redshift_source.py +++ b/sdk/python/feast/infra/offline_stores/redshift_source.py @@ -282,9 +282,7 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions: A RedshiftOptionsProto protobuf. """ redshift_options_proto = DataSourceProto.RedshiftOptions( - table=self.table, - schema=self.schema, - query=self.query, + table=self.table, schema=self.schema, query=self.query, ) return redshift_options_proto diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index b1bc2b1d136..cc346251a82 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -208,14 +208,12 @@ def get_historical_features( entity_schema = _get_entity_schema(entity_df, snowflake_conn, config) - entity_df_event_timestamp_col = ( - offline_utils.infer_event_timestamp_from_entity_df(entity_schema) + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( - entity_df, - entity_df_event_timestamp_col, - snowflake_conn, + entity_df, entity_df_event_timestamp_col, snowflake_conn, ) @contextlib.contextmanager diff --git a/sdk/python/feast/infra/online_stores/datastore.py b/sdk/python/feast/infra/online_stores/datastore.py index c495d324f50..e975ce138ca 100644 --- a/sdk/python/feast/infra/online_stores/datastore.py +++ b/sdk/python/feast/infra/online_stores/datastore.py @@ -197,12 +197,7 @@ def _write_minibatch( document_id = compute_entity_id(entity_key) key = client.key( - "Project", - project, - "Table", - table.name, - "Row", - document_id, + "Project", project, "Table", table.name, "Row", document_id, ) entity = datastore.Entity( @@ -321,10 +316,7 @@ def _initialize_client( project_id: Optional[str], namespace: Optional[str] ) -> datastore.Client: try: - client = datastore.Client( - project=project_id, - namespace=namespace, - ) + client = datastore.Client(project=project_id, namespace=namespace,) return client except DefaultCredentialsError as e: raise FeastProviderLoginError( @@ -400,8 +392,7 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: @staticmethod def from_proto(datastore_table_proto: DatastoreTableProto) -> Any: datastore_table = DatastoreTable( - project=datastore_table_proto.project, - name=datastore_table_proto.name, + project=datastore_table_proto.project, name=datastore_table_proto.name, ) # Distinguish between null and empty string, since project_id and namespace are StringValues. diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index 6e212ef55f1..c161a5b9554 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -304,8 +304,7 @@ def _get_table_name( def _delete_table_idempotent( - dynamodb_resource, - table_name: str, + dynamodb_resource, table_name: str, ): try: table = dynamodb_resource.Table(table_name) @@ -359,8 +358,7 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: @staticmethod def from_proto(dynamodb_table_proto: DynamoDBTableProto) -> Any: return DynamoDBTable( - name=dynamodb_table_proto.name, - region=dynamodb_table_proto.region, + name=dynamodb_table_proto.name, region=dynamodb_table_proto.region, ) def update(self): diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 23af3c13e8e..710f4c386a6 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -230,8 +230,7 @@ def teardown( def _initialize_conn(db_path: str): Path(db_path).parent.mkdir(exist_ok=True) return sqlite3.connect( - db_path, - detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, + db_path, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, ) @@ -279,10 +278,7 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: @staticmethod def from_proto(sqlite_table_proto: SqliteTableProto) -> Any: - return SqliteTable( - path=sqlite_table_proto.path, - name=sqlite_table_proto.name, - ) + return SqliteTable(path=sqlite_table_proto.path, name=sqlite_table_proto.name,) def update(self): self.conn.execute( diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index b630aca0457..3468b9dc927 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -66,10 +66,7 @@ def update_infra( ) def teardown_infra( - self, - project: str, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], ) -> None: set_usage_attribute("provider", self.__class__.__name__) if self.online_store: @@ -105,10 +102,7 @@ def online_read( return result def ingest_df( - self, - feature_view: FeatureView, - entities: List[Entity], - df: pandas.DataFrame, + self, feature_view: FeatureView, entities: List[Entity], df: pandas.DataFrame, ): set_usage_attribute("provider", self.__class__.__name__) table = pa.Table.from_pandas(df) diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 952e6be4549..4441b77c644 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -79,10 +79,7 @@ def plan_infra( @abc.abstractmethod def teardown_infra( - self, - project: str, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], ): """ Tear down all cloud resources for a repo. @@ -122,10 +119,7 @@ def online_write_batch( ... def ingest_df( - self, - feature_view: FeatureView, - entities: List[Entity], - df: pandas.DataFrame, + self, feature_view: FeatureView, entities: List[Entity], df: pandas.DataFrame, ): """ Ingests a DataFrame directly into the online store @@ -308,8 +302,7 @@ def _get_column_names( def _run_field_mapping( - table: pyarrow.Table, - field_mapping: Dict[str, str], + table: pyarrow.Table, field_mapping: Dict[str, str], ) -> pyarrow.Table: # run field mapping in the forward direction cols = table.column_names @@ -321,8 +314,7 @@ def _run_field_mapping( def _run_dask_field_mapping( - table: dd.DataFrame, - field_mapping: Dict[str, str], + table: dd.DataFrame, field_mapping: Dict[str, str], ): if field_mapping: # run field mapping in the forward direction diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index cfe50fbc2ec..fe5eed774ec 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -87,10 +87,7 @@ def execute_redshift_statement_async( """ try: return redshift_data_client.execute_statement( - ClusterIdentifier=cluster_id, - Database=database, - DbUser=user, - Sql=query, + ClusterIdentifier=cluster_id, Database=database, DbUser=user, Sql=query, ) except ClientError as e: if e.response["Error"]["Code"] == "ValidationException": @@ -158,11 +155,7 @@ def get_redshift_statement_result(redshift_data_client, statement_id: str) -> di return redshift_data_client.get_statement_result(Id=statement_id) -def upload_df_to_s3( - s3_resource, - s3_path: str, - df: pd.DataFrame, -) -> None: +def upload_df_to_s3(s3_resource, s3_path: str, df: pd.DataFrame,) -> None: """Uploads a Pandas DataFrame to S3 as a parquet file Args: @@ -308,11 +301,7 @@ def temporarily_upload_df_to_redshift( # Clean up the uploaded Redshift table execute_redshift_statement( - redshift_data_client, - cluster_id, - database, - user, - f"DROP TABLE {table_name}", + redshift_data_client, cluster_id, database, user, f"DROP TABLE {table_name}", ) @@ -380,13 +369,7 @@ def unload_redshift_query_to_pa( bucket, key = get_bucket_and_key(s3_path) execute_redshift_query_and_unload_to_s3( - redshift_data_client, - cluster_id, - database, - user, - s3_path, - iam_role, - query, + redshift_data_client, cluster_id, database, user, s3_path, iam_role, query, ) with tempfile.TemporaryDirectory() as temp_dir: diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index d2130480646..f0eaf987ef3 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -162,8 +162,7 @@ def to_proto(self) -> OnDemandFeatureViewProto: features=[feature.to_proto() for feature in self.features], inputs=inputs, user_defined_function=UserDefinedFunctionProto( - name=self.udf.__name__, - body=dill.dumps(self.udf, recurse=True), + name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True), ), description=self.description, tags=self.tags, @@ -243,9 +242,7 @@ def get_request_data_schema(self) -> Dict[str, ValueType]: return schema def get_transformed_features_df( - self, - df_with_features: pd.DataFrame, - full_feature_names: bool = False, + self, df_with_features: pd.DataFrame, full_feature_names: bool = False, ) -> pd.DataFrame: # Apply on demand transformations columns_to_cleanup = [] diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index 36cc243d670..0bf73fcd24c 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -789,10 +789,7 @@ def delete_entity(self, name: str, project: str, commit: bool = True): raise EntityNotFoundException(name, project) def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, + self, saved_dataset: SavedDataset, project: str, commit: bool = True, ): """ Registers a single entity with Feast diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 7ab4a0e60e1..fce13d8f61e 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -190,8 +190,7 @@ def _validate_online_store_config(cls, values): online_config_class(**values["online_store"]) except ValidationError as e: raise ValidationError( - [ErrorWrapper(e, loc="online_store")], - model=RepoConfig, + [ErrorWrapper(e, loc="online_store")], model=RepoConfig, ) return values @@ -225,8 +224,7 @@ def _validate_offline_store_config(cls, values): offline_config_class(**values["offline_store"]) except ValidationError as e: raise ValidationError( - [ErrorWrapper(e, loc="offline_store")], - model=RepoConfig, + [ErrorWrapper(e, loc="offline_store")], model=RepoConfig, ) return values @@ -260,8 +258,7 @@ def _validate_feature_server_config(cls, values): feature_server_config_class(**values["feature_server"]) except ValidationError as e: raise ValidationError( - [ErrorWrapper(e, loc="feature_server")], - model=RepoConfig, + [ErrorWrapper(e, loc="feature_server")], model=RepoConfig, ) return values @@ -298,12 +295,7 @@ def write_to_path(self, repo_path: Path): config_path = repo_path / "feature_store.yaml" with open(config_path, mode="w") as f: yaml.dump( - yaml.safe_load( - self.json( - exclude={"repo_path"}, - exclude_unset=True, - ) - ), + yaml.safe_load(self.json(exclude={"repo_path"}, exclude_unset=True,)), f, sort_keys=False, ) diff --git a/sdk/python/feast/templates/aws/test.py b/sdk/python/feast/templates/aws/test.py index 3d223e8f266..07410954f7b 100644 --- a/sdk/python/feast/templates/aws/test.py +++ b/sdk/python/feast/templates/aws/test.py @@ -54,8 +54,7 @@ def main(): # Retrieve features from the online store (Firestore) online_features = fs.get_online_features( - features=features, - entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], + features=features, entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], ).to_dict() print() diff --git a/sdk/python/feast/templates/gcp/test.py b/sdk/python/feast/templates/gcp/test.py index 8ff11bda5c7..538334044bf 100644 --- a/sdk/python/feast/templates/gcp/test.py +++ b/sdk/python/feast/templates/gcp/test.py @@ -54,8 +54,7 @@ def main(): # Retrieve features from the online store (Firestore) online_features = fs.get_online_features( - features=features, - entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], + features=features, entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], ).to_dict() print() diff --git a/sdk/python/feast/templates/local/example.py b/sdk/python/feast/templates/local/example.py index 455b8b780c9..6ab549b8c5b 100644 --- a/sdk/python/feast/templates/local/example.py +++ b/sdk/python/feast/templates/local/example.py @@ -15,11 +15,7 @@ # Define an entity for the driver. You can think of entity as a primary key used to # fetch features. -driver = Entity( - name="driver_id", - value_type=ValueType.INT64, - description="driver id", -) +driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",) # Our parquet files contain sample data that includes a driver_id column, timestamps and # three feature column. Here we define a Feature View that will allow us to serve this diff --git a/sdk/python/feast/templates/snowflake/test.py b/sdk/python/feast/templates/snowflake/test.py index 3c33f6aefda..32aa6380d51 100644 --- a/sdk/python/feast/templates/snowflake/test.py +++ b/sdk/python/feast/templates/snowflake/test.py @@ -54,8 +54,7 @@ def main(): # Retrieve features from the online store online_features = fs.get_online_features( - features=features, - entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], + features=features, entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}], ).to_dict() print() diff --git a/sdk/python/feast/templates/spark/example.py b/sdk/python/feast/templates/spark/example.py index 1cdffaf9097..b1741fd6883 100644 --- a/sdk/python/feast/templates/spark/example.py +++ b/sdk/python/feast/templates/spark/example.py @@ -15,15 +15,9 @@ # Entity definitions -driver = Entity( - name="driver_id", - value_type=ValueType.INT64, - description="driver id", -) +driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",) customer = Entity( - name="customer_id", - value_type=ValueType.INT64, - description="customer id", + name="customer_id", value_type=ValueType.INT64, description="customer id", ) # Sources diff --git a/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py b/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py index 5aec2ac1305..3d2a253dc0e 100644 --- a/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py +++ b/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py @@ -59,7 +59,5 @@ def test_online_retrieval(environment, universal_data_sources, benchmark): unprefixed_feature_refs.remove("conv_rate_plus_val_to_add") benchmark( - fs.get_online_features, - features=feature_refs, - entity_rows=entity_rows, + fs.get_online_features, features=feature_refs, entity_rows=entity_rows, ) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 928364cac50..1254604a0be 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -70,16 +70,10 @@ def pytest_addoption(parser): help="Run tests with external dependencies", ) parser.addoption( - "--benchmark", - action="store_true", - default=False, - help="Run benchmark tests", + "--benchmark", action="store_true", default=False, help="Run benchmark tests", ) parser.addoption( - "--universal", - action="store_true", - default=False, - help="Run universal tests", + "--universal", action="store_true", default=False, help="Run universal tests", ) parser.addoption( "--goserver", @@ -260,9 +254,7 @@ def cleanup(): def e2e_data_sources(environment: Environment, request): df = create_dataset() data_source = environment.data_source_creator.create_data_source( - df, - environment.feature_store.project, - field_mapping={"ts_1": "ts"}, + df, environment.feature_store.project, field_mapping={"ts_1": "ts"}, ) def cleanup(): diff --git a/sdk/python/tests/doctest/test_all.py b/sdk/python/tests/doctest/test_all.py index c2d3db351b0..bf3a09db1e2 100644 --- a/sdk/python/tests/doctest/test_all.py +++ b/sdk/python/tests/doctest/test_all.py @@ -17,9 +17,7 @@ def setup_feature_store(): init_repo("feature_repo", "local") fs = FeatureStore(repo_path="feature_repo") driver = Entity( - name="driver_id", - value_type=ValueType.INT64, - description="driver id", + name="driver_id", value_type=ValueType.INT64, description="driver id", ) driver_hourly_stats = FileSource( path="feature_repo/data/driver_stats.parquet", @@ -91,8 +89,7 @@ def test_docstrings(): setup_function() test_suite = doctest.DocTestSuite( - temp_module, - optionflags=doctest.ELLIPSIS, + temp_module, optionflags=doctest.ELLIPSIS, ) if test_suite.countTestCases() > 0: result = unittest.TextTestRunner(sys.stdout).run(test_suite) diff --git a/sdk/python/tests/example_repos/example_feature_repo_2.py b/sdk/python/tests/example_repos/example_feature_repo_2.py index 714bd8c8b50..96da67ac92d 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_2.py +++ b/sdk/python/tests/example_repos/example_feature_repo_2.py @@ -8,11 +8,7 @@ created_timestamp_column="created", ) -driver = Entity( - name="driver_id", - value_type=ValueType.INT64, - description="driver id", -) +driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",) driver_hourly_stats_view = FeatureView( diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 574bbd12b11..1d4ce7d6cb6 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -29,10 +29,7 @@ def update_infra( pass def teardown_infra( - self, - project: str, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], ): pass diff --git a/sdk/python/tests/integration/e2e/test_validation.py b/sdk/python/tests/integration/e2e/test_validation.py index d7ea69904be..77706b74d10 100644 --- a/sdk/python/tests/integration/e2e/test_validation.py +++ b/sdk/python/tests/integration/e2e/test_validation.py @@ -72,8 +72,7 @@ def test_historical_retrieval_with_validation(environment, universal_data_source ) reference_job = store.get_historical_features( - entity_df=entity_df, - features=_features, + entity_df=entity_df, features=_features, ) store.create_saved_dataset( @@ -82,10 +81,7 @@ def test_historical_retrieval_with_validation(environment, universal_data_source storage=environment.data_source_creator.create_saved_dataset_destination(), ) - job = store.get_historical_features( - entity_df=entity_df, - features=_features, - ) + job = store.get_historical_features(entity_df=entity_df, features=_features,) # if validation pass there will be no exceptions on this point job.to_df( @@ -110,8 +106,7 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so ) reference_job = store.get_historical_features( - entity_df=entity_df, - features=_features, + entity_df=entity_df, features=_features, ) store.create_saved_dataset( @@ -120,10 +115,7 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so storage=environment.data_source_creator.create_saved_dataset_destination(), ) - job = store.get_historical_features( - entity_df=entity_df, - features=_features, - ) + job = store.get_historical_features(entity_df=entity_df, features=_features,) with pytest.raises(ValidationFailed) as exc_info: job.to_df( diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index eb60ac3852e..9cbe11f86b9 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -117,10 +117,7 @@ FULL_REPO_CONFIGS = DEFAULT_FULL_REPO_CONFIGS GO_REPO_CONFIGS = [ - IntegrationTestRepoConfig( - online_store=REDIS_CONFIG, - go_feature_server=True, - ), + IntegrationTestRepoConfig(online_store=REDIS_CONFIG, go_feature_server=True,), ] @@ -274,8 +271,7 @@ def values(self): def construct_universal_feature_views( - data_sources: UniversalDataSources, - with_odfv: bool = True, + data_sources: UniversalDataSources, with_odfv: bool = True, ) -> UniversalFeatureViews: driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) return UniversalFeatureViews( @@ -371,8 +367,7 @@ def construct_test_environment( # Note: even if it's a local feature server, the repo config does not have this configured feature_server = None registry = RegistryConfig( - path=str(Path(repo_dir_name) / "registry.db"), - cache_ttl_seconds=1, + path=str(Path(repo_dir_name) / "registry.db"), cache_ttl_seconds=1, ) config = RepoConfig( registry=registry, diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py index 6740015a910..baa3db6afc1 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py @@ -106,9 +106,7 @@ def _upload_parquet_file(self, df, file_name, minio_endpoint): if not client.bucket_exists(self.bucket): client.make_bucket(self.bucket) client.fput_object( - self.bucket, - file_name, - self.f.name, + self.bucket, file_name, self.f.name, ) def create_data_source( diff --git a/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py b/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py index b9f8cefa780..6066009a6d6 100644 --- a/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py +++ b/sdk/python/tests/integration/offline_store/test_s3_custom_endpoint.py @@ -17,9 +17,7 @@ @pytest.mark.skip( reason="No way to run this test today. Credentials conflict with real AWS credentials in CI" ) -def test_registration_and_retrieval_from_custom_s3_endpoint( - universal_data_sources, -): +def test_registration_and_retrieval_from_custom_s3_endpoint(universal_data_sources,): config = IntegrationTestRepoConfig( offline_store_creator="tests.integration.feature_repos.universal.data_sources.file.S3FileDataSourceCreator" ) diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index 390f8e40c0e..541033433f6 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -209,10 +209,7 @@ def get_expected_training_df( (f"global_stats__{k}" if full_feature_names else k): global_record.get( k, None ) - for k in ( - "num_rides", - "avg_ride_length", - ) + for k in ("num_rides", "avg_ride_length",) } ) @@ -858,7 +855,5 @@ def assert_frame_equal(expected_df, actual_df, keys): ) pd_assert_frame_equal( - expected_df, - actual_df, - check_dtype=False, + expected_df, actual_df, check_dtype=False, ) diff --git a/sdk/python/tests/integration/registration/test_cli.py b/sdk/python/tests/integration/registration/test_cli.py index e721bb28c97..655e53e7593 100644 --- a/sdk/python/tests/integration/registration/test_cli.py +++ b/sdk/python/tests/integration/registration/test_cli.py @@ -86,8 +86,7 @@ def test_universal_cli(environment: Environment): assertpy.assert_that(result.returncode).is_equal_to(0) assertpy.assert_that(fs.list_feature_views()).is_length(4) result = runner.run( - ["data-sources", "describe", "customer_profile_source"], - cwd=repo_path, + ["data-sources", "describe", "customer_profile_source"], cwd=repo_path, ) assertpy.assert_that(result.returncode).is_equal_to(0) assertpy.assert_that(fs.list_data_sources()).is_length(4) diff --git a/sdk/python/tests/integration/registration/test_feature_store.py b/sdk/python/tests/integration/registration/test_feature_store.py index 8234c904829..c7345d3f4d6 100644 --- a/sdk/python/tests/integration/registration/test_feature_store.py +++ b/sdk/python/tests/integration/registration/test_feature_store.py @@ -88,8 +88,7 @@ def feature_store_with_s3_registry(): @pytest.mark.parametrize( - "test_feature_store", - [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], ) def test_apply_entity_success(test_feature_store): entity = Entity( @@ -161,8 +160,7 @@ def test_apply_entity_integration(test_feature_store): @pytest.mark.parametrize( - "test_feature_store", - [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], ) def test_apply_feature_view_success(test_feature_store): # Create Feature Views @@ -213,8 +211,7 @@ def test_apply_feature_view_success(test_feature_store): @pytest.mark.integration @pytest.mark.parametrize( - "test_feature_store", - [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], ) @pytest.mark.parametrize("dataframe_source", [lazy_fixture("simple_dataset_1")]) def test_feature_view_inference_success(test_feature_store, dataframe_source): @@ -354,8 +351,7 @@ def test_apply_feature_view_integration(test_feature_store): @pytest.mark.parametrize( - "test_feature_store", - [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], ) def test_apply_object_and_read(test_feature_store): assert isinstance(test_feature_store, FeatureStore) @@ -431,8 +427,7 @@ def test_apply_remote_repo(): @pytest.mark.parametrize( - "test_feature_store", - [lazy_fixture("feature_store_with_local_registry")], + "test_feature_store", [lazy_fixture("feature_store_with_local_registry")], ) @pytest.mark.parametrize("dataframe_source", [lazy_fixture("simple_dataset_1")]) def test_reapply_feature_view_success(test_feature_store, dataframe_source): diff --git a/sdk/python/tests/integration/registration/test_inference.py b/sdk/python/tests/integration/registration/test_inference.py index 2f8b932acb2..0ea6276669d 100644 --- a/sdk/python/tests/integration/registration/test_inference.py +++ b/sdk/python/tests/integration/registration/test_inference.py @@ -45,16 +45,10 @@ def test_update_entities_with_inferred_types_from_feature_views( ) as file_source_2: fv1 = FeatureView( - name="fv1", - entities=["id"], - batch_source=file_source, - ttl=None, + name="fv1", entities=["id"], batch_source=file_source, ttl=None, ) fv2 = FeatureView( - name="fv2", - entities=["id"], - batch_source=file_source_2, - ttl=None, + name="fv2", entities=["id"], batch_source=file_source_2, ttl=None, ) actual_1 = Entity(name="id", join_key="id_join_key") @@ -157,8 +151,7 @@ def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_so data_source.event_timestamp_column = None update_data_sources_with_inferred_event_timestamp_col( - data_sources_copy.values(), - RepoConfig(provider="local", project="test"), + data_sources_copy.values(), RepoConfig(provider="local", project="test"), ) actual_event_timestamp_cols = [ source.event_timestamp_column for source in data_sources_copy.values() diff --git a/sdk/python/tests/integration/registration/test_registry.py b/sdk/python/tests/integration/registration/test_registry.py index 3a510b235a0..535497634d4 100644 --- a/sdk/python/tests/integration/registration/test_registry.py +++ b/sdk/python/tests/integration/registration/test_registry.py @@ -67,8 +67,7 @@ def s3_registry() -> Registry: @pytest.mark.parametrize( - "test_registry", - [lazy_fixture("local_registry")], + "test_registry", [lazy_fixture("local_registry")], ) def test_apply_entity_success(test_registry): entity = Entity( @@ -117,8 +116,7 @@ def test_apply_entity_success(test_registry): @pytest.mark.integration @pytest.mark.parametrize( - "test_registry", - [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], + "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], ) def test_apply_entity_integration(test_registry): entity = Entity( @@ -162,8 +160,7 @@ def test_apply_entity_integration(test_registry): @pytest.mark.parametrize( - "test_registry", - [lazy_fixture("local_registry")], + "test_registry", [lazy_fixture("local_registry")], ) def test_apply_feature_view_success(test_registry): # Create Feature Views @@ -237,8 +234,7 @@ def test_apply_feature_view_success(test_registry): @pytest.mark.parametrize( - "test_registry", - [lazy_fixture("local_registry")], + "test_registry", [lazy_fixture("local_registry")], ) def test_modify_feature_views_success(test_registry): # Create Feature Views @@ -359,8 +355,7 @@ def odfv1(feature_df: pd.DataFrame) -> pd.DataFrame: @pytest.mark.integration @pytest.mark.parametrize( - "test_registry", - [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], + "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], ) def test_apply_feature_view_integration(test_registry): # Create Feature Views @@ -435,8 +430,7 @@ def test_apply_feature_view_integration(test_registry): @pytest.mark.integration @pytest.mark.parametrize( - "test_registry", - [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], + "test_registry", [lazy_fixture("gcs_registry"), lazy_fixture("s3_registry")], ) def test_apply_data_source(test_registry: Registry): # Create Feature Views diff --git a/sdk/python/tests/integration/registration/test_universal_types.py b/sdk/python/tests/integration/registration/test_universal_types.py index 7ec2e183e1e..59ca119f98a 100644 --- a/sdk/python/tests/integration/registration/test_universal_types.py +++ b/sdk/python/tests/integration/registration/test_universal_types.py @@ -91,11 +91,9 @@ def online_types_test_fixtures(request): def get_fixtures(request): config: TypeTestConfig = request.param # Lower case needed because Redshift lower-cases all table names - test_project_id = ( - f"{config.entity_type}{config.feature_dtype}{config.feature_is_list}".replace( - ".", "" - ).lower() - ) + test_project_id = f"{config.entity_type}{config.feature_dtype}{config.feature_is_list}".replace( + ".", "" + ).lower() type_test_environment = construct_test_environment( test_repo_config=config.test_repo_config, test_suite_name=f"test_{test_project_id}", @@ -184,8 +182,7 @@ def test_feature_get_historical_features_types_match(offline_types_test_fixtures ts - timedelta(hours=2), ] historical_features = fs.get_historical_features( - entity_df=entity_df, - features=features, + entity_df=entity_df, features=features, ) # Note: Pandas doesn't play well with nan values in ints. BQ will also coerce to floats if there are NaNs historical_features_df = historical_features.to_df() @@ -233,8 +230,7 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures): driver_id_value = "1" if config.entity_type == ValueType.STRING else 1 online_features = fs.get_online_features( - features=features, - entity_rows=[{"driver": driver_id_value}], + features=features, entity_rows=[{"driver": driver_id_value}], ).to_dict() feature_list_dtype_to_expected_online_response_value_type = { @@ -290,11 +286,7 @@ def create_feature_view( elif feature_dtype == "datetime": value_type = ValueType.UNIX_TIMESTAMP - return driver_feature_view( - data_source, - name=name, - value_type=value_type, - ) + return driver_feature_view(data_source, name=name, value_type=value_type,) def assert_expected_historical_feature_types( diff --git a/sdk/python/tests/unit/diff/test_registry_diff.py b/sdk/python/tests/unit/diff/test_registry_diff.py index 58b7a4c00b6..0322ab47abf 100644 --- a/sdk/python/tests/unit/diff/test_registry_diff.py +++ b/sdk/python/tests/unit/diff/test_registry_diff.py @@ -11,16 +11,10 @@ def test_tag_objects_for_keep_delete_update_add(simple_dataset_1): df=simple_dataset_1, event_timestamp_column="ts_1" ) as file_source: to_delete = FeatureView( - name="to_delete", - entities=["id"], - batch_source=file_source, - ttl=None, + name="to_delete", entities=["id"], batch_source=file_source, ttl=None, ) unchanged_fv = FeatureView( - name="fv1", - entities=["id"], - batch_source=file_source, - ttl=None, + name="fv1", entities=["id"], batch_source=file_source, ttl=None, ) pre_changed = FeatureView( name="fv2", @@ -37,10 +31,7 @@ def test_tag_objects_for_keep_delete_update_add(simple_dataset_1): tags={"when": "after"}, ) to_add = FeatureView( - name="to_add", - entities=["id"], - batch_source=file_source, - ttl=None, + name="to_add", entities=["id"], batch_source=file_source, ttl=None, ) keep, delete, update, add = tag_objects_for_keep_delete_update_add( diff --git a/sdk/python/tests/unit/test_usage.py b/sdk/python/tests/unit/test_usage.py index ca842474307..13988d32642 100644 --- a/sdk/python/tests/unit/test_usage.py +++ b/sdk/python/tests/unit/test_usage.py @@ -234,4 +234,4 @@ def call_length_ms(call): return ( datetime.datetime.fromisoformat(call["end"]) - datetime.datetime.fromisoformat(call["start"]) - ).total_seconds() * 10**3 + ).total_seconds() * 10 ** 3 diff --git a/sdk/python/tests/utils/data_source_utils.py b/sdk/python/tests/utils/data_source_utils.py index d2ad337e21a..5a5baceef07 100644 --- a/sdk/python/tests/utils/data_source_utils.py +++ b/sdk/python/tests/utils/data_source_utils.py @@ -43,8 +43,7 @@ def simple_bq_source_using_table_ref_arg( job.result() return BigQuerySource( - table_ref=table_ref, - event_timestamp_column=event_timestamp_column, + table_ref=table_ref, event_timestamp_column=event_timestamp_column, )