diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index ffafe311257..87765e132f7 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -30,6 +30,11 @@ class FeastObjectNotFoundException(Exception): pass +class ProjectMetadataNotFoundException(FeastObjectNotFoundException): + def __init__(self, project: str): + super().__init__(f"Project Metadata does not exist in project {project}") + + class EntityNotFoundException(FeastObjectNotFoundException): def __init__(self, name, project=None): if project: diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index 33adb6b7c95..8ad164e73ae 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -545,15 +545,47 @@ def list_validation_references( """ raise NotImplementedError + @abstractmethod + def apply_project_metadata( + self, + project: str, + commit: bool = True, + ): + """ + Persist a project metadata with a new uuid + + Args: + project: Feast project name + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def delete_project_metadata( + self, + project: str, + commit: bool = True, + ): + """ + Deletes a project metadata or raises ProjectMetadataNotFoundException exception if not found. + + Args: + project: Feast project name + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + @abstractmethod def list_project_metadata( - self, project: str, allow_cache: bool = False + self, + project: Optional[str], + allow_cache: bool = False, ) -> List[ProjectMetadata]: """ - Retrieves project metadata + Retrieves project metadata if given project name otherwise all project metadata Args: - project: Filter metadata based on project name + project: Filter metadata based on project name or None to retrieve all project metadata allow_cache: Allow returning feature views from the cached registry Returns: @@ -561,6 +593,24 @@ def list_project_metadata( """ raise NotImplementedError + @abstractmethod + def get_project_metadata( + self, + project: str, + allow_cache: bool = False, + ) -> Optional[ProjectMetadata]: + """ + Retrieves project metadata if present otherwise None + + Args: + project: Filter metadata based on project name + allow_cache: Allow returning feature views from the cached registry + + Returns: + Get project metadata if project exists otherwise None + """ + raise NotImplementedError + @abstractmethod def update_infra(self, infra: Infra, project: str, commit: bool = True): """ diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 611d67de96d..15514dc4ed0 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -302,11 +302,11 @@ def list_validation_references( return self._list_validation_references(project, tags) @abstractmethod - def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: + def _list_project_metadata(self, project: Optional[str]) -> List[ProjectMetadata]: pass def list_project_metadata( - self, project: str, allow_cache: bool = False + self, project: Optional[str], allow_cache: bool = False ) -> List[ProjectMetadata]: if allow_cache: self._refresh_cached_registry_if_necessary() @@ -315,6 +315,24 @@ def list_project_metadata( ) return self._list_project_metadata(project) + @abstractmethod + def _get_project_metadata(self, project: str) -> Optional[ProjectMetadata]: + pass + + def get_project_metadata( + self, project: str, allow_cache: bool = False + ) -> Optional[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + project_metadata_proto = proto_registry_utils.get_project_metadata( + self.cached_registry_proto, project + ) + if project_metadata_proto is None: + return None + else: + return ProjectMetadata.from_proto(project_metadata_proto) + return self._get_project_metadata(project) + @abstractmethod def _get_infra(self, project: str) -> Infra: pass diff --git a/sdk/python/feast/infra/registry/proto_registry_utils.py b/sdk/python/feast/infra/registry/proto_registry_utils.py index f67808aab55..9955550f8c3 100644 --- a/sdk/python/feast/infra/registry/proto_registry_utils.py +++ b/sdk/python/feast/infra/registry/proto_registry_utils.py @@ -284,13 +284,19 @@ def list_validation_references( @registry_proto_cache def list_project_metadata( - registry_proto: RegistryProto, project: str + registry_proto: RegistryProto, project: Optional[str] ) -> List[ProjectMetadata]: - return [ - ProjectMetadata.from_proto(project_metadata) - for project_metadata in registry_proto.project_metadata - if project_metadata.project == project - ] + if not project: + return [ + ProjectMetadata.from_proto(project_metadata) + for project_metadata in registry_proto.project_metadata + ] + else: + return [ + ProjectMetadata.from_proto(project_metadata) + for project_metadata in registry_proto.project_metadata + if project_metadata.project == project + ] @registry_proto_cache_with_tags diff --git a/sdk/python/feast/infra/registry/registry.py b/sdk/python/feast/infra/registry/registry.py index 366f3aacaad..7f4f0be68b1 100644 --- a/sdk/python/feast/infra/registry/registry.py +++ b/sdk/python/feast/infra/registry/registry.py @@ -32,6 +32,7 @@ FeatureServiceNotFoundException, FeatureViewNotFoundException, PermissionNotFoundException, + ProjectMetadataNotFoundException, ValidationReferenceNotFound, ) from feast.feature_service import FeatureService @@ -214,9 +215,11 @@ def __init__( self._registry_store = cls(registry_config, repo_path) self.cached_registry_proto_ttl = timedelta( - seconds=registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 + seconds=( + registry_config.cache_ttl_seconds + if registry_config.cache_ttl_seconds is not None + else 0 + ) ) def clone(self) -> "Registry": @@ -787,14 +790,86 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr return raise ValidationReferenceNotFound(name, project=project) + def apply_project_metadata(self, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + if commit: + self.commit() + def list_project_metadata( - self, project: str, allow_cache: bool = False + self, project: Optional[str], allow_cache: bool = False ) -> List[ProjectMetadata]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) return proto_registry_utils.list_project_metadata(registry_proto, project) + def get_project_metadata( + self, project: str, allow_cache: bool = False + ) -> Optional[ProjectMetadata]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + project_metadata_proto = proto_registry_utils.get_project_metadata( + registry_proto, project + ) + if project_metadata_proto is None: + return None + else: + return ProjectMetadata.from_proto(project_metadata_proto) + + def delete_project_metadata(self, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, project_metadata_proto in enumerate( + self.cached_registry_proto.project_metadata + ): + if project_metadata_proto.project == project: + list_entities = self.list_entities(project) + list_feature_views = self.list_feature_views(project) + list_on_demand_feature_views = self.list_on_demand_feature_views( + project + ) + list_stream_feature_views = self.list_stream_feature_views(project) + list_feature_services = self.list_feature_services(project) + list_data_sources = self.list_data_sources(project) + list_saved_datasets = self.list_saved_datasets(project) + list_validation_references = self.list_validation_references(project) + list_permissions = self.list_permissions(project) + for entity in list_entities: + self.delete_entity(entity.name, project, commit=False) + for feature_view in list_feature_views: + self.delete_feature_view(feature_view.name, project, commit=False) + for on_demand_feature_view in list_on_demand_feature_views: + self.delete_feature_view( + on_demand_feature_view.name, project, commit=False + ) + for stream_feature_view in list_stream_feature_views: + self.delete_feature_view( + stream_feature_view.name, project, commit=False + ) + for feature_service in list_feature_services: + self.delete_feature_service( + feature_service.name, project, commit=False + ) + for data_source in list_data_sources: + self.delete_data_source(data_source.name, project, commit=False) + for saved_dataset in list_saved_datasets: + self.delete_saved_dataset(saved_dataset.name, project, commit=False) + for validation_reference in list_validation_references: + self.delete_validation_reference( + validation_reference.name, project, commit=False + ) + for permission in list_permissions: + self.delete_permission(permission.name, project, commit=False) + del self.cached_registry_proto.project_metadata[idx] + if commit: + self.commit() + return + + raise ProjectMetadataNotFoundException(project) + def commit(self): """Commits the state of the registry cache to the remote registry store.""" if self.cached_registry_proto: diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index 618628bc071..67fa776cb45 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -16,10 +16,7 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.auth.auth_type import AuthType -from feast.permissions.auth_model import ( - AuthConfig, - NoAuthConfig, -) +from feast.permissions.auth_model import AuthConfig, NoAuthConfig from feast.permissions.client.grpc_client_auth_interceptor import ( GrpcClientAuthHeaderInterceptor, ) @@ -173,15 +170,17 @@ def apply_feature_view( arg_name = "on_demand_feature_view" request = RegistryServer_pb2.ApplyFeatureViewRequest( - feature_view=feature_view.to_proto() - if arg_name == "feature_view" - else None, - stream_feature_view=feature_view.to_proto() - if arg_name == "stream_feature_view" - else None, - on_demand_feature_view=feature_view.to_proto() - if arg_name == "on_demand_feature_view" - else None, + feature_view=( + feature_view.to_proto() if arg_name == "feature_view" else None + ), + stream_feature_view=( + feature_view.to_proto() if arg_name == "stream_feature_view" else None + ), + on_demand_feature_view=( + feature_view.to_proto() + if arg_name == "on_demand_feature_view" + else None + ), project=project, commit=commit, ) @@ -375,14 +374,29 @@ def list_validation_references( ] def list_project_metadata( - self, project: str, allow_cache: bool = False + self, project: Optional[str], allow_cache: bool = False ) -> List[ProjectMetadata]: request = RegistryServer_pb2.ListProjectMetadataRequest( - project=project, allow_cache=allow_cache + project="" if project is None else project, + allow_cache=allow_cache, ) response = self.stub.ListProjectMetadata(request) return [ProjectMetadata.from_proto(pm) for pm in response.project_metadata] + def apply_project_metadata(self, project: StrictStr, commit: bool = True): + # TODO: Add logic for applying project metadata + pass + + def get_project_metadata( + self, project: StrictStr, allow_cache: bool = False + ) -> Optional[ProjectMetadata]: + # TODO: Add logic for getting project metadata + pass + + def delete_project_metadata(self, project: StrictStr, commit: bool = True): + # TODO: Add logic for deleting project metadata + pass + def update_infra(self, infra: Infra, project: str, commit: bool = True): request = RegistryServer_pb2.UpdateInfraRequest( infra=infra.to_proto(), project=project, commit=commit diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 801b90afe38..02df8801dfc 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -143,9 +143,11 @@ def __init__( self.cached_registry_proto_created = _utc_now() self._refresh_lock = Lock() self.cached_registry_proto_ttl = timedelta( - seconds=registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 + seconds=( + registry_config.cache_ttl_seconds + if registry_config.cache_ttl_seconds is not None + else 0 + ) ) self.project = project @@ -280,7 +282,7 @@ def _apply_object( proto_field_name: str, name: Optional[str] = None, ): - self._maybe_init_project_metadata(project) + self.apply_project_metadata(project) name = name or (obj.name if hasattr(obj, "name") else None) assert name, f"name needs to be provided for {obj}" @@ -620,7 +622,7 @@ def _get_object( proto_field_name: str, not_found_exception: Optional[Callable], ): - self._maybe_init_project_metadata(project) + self.apply_project_metadata(project) with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT @@ -821,7 +823,7 @@ def _list_objects( proto_field_name: str, tags: Optional[dict[str, str]] = None, ): - self._maybe_init_project_metadata(project) + self.apply_project_metadata(project) with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT @@ -899,8 +901,9 @@ def apply_materialization( ) def list_project_metadata( - self, project: str, allow_cache: bool = False + self, project: Optional[str], allow_cache: bool = False ) -> List[ProjectMetadata]: + # TODO: List all projects when project is None if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_project_metadata( @@ -928,6 +931,44 @@ def list_project_metadata( return [project_metadata] return [] + def get_project_metadata( + self, project: str, allow_cache: bool = False + ) -> Optional[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + project_metadata_proto = proto_registry_utils.get_project_metadata( + self.cached_registry_proto, project + ) + if project_metadata_proto is None: + return None + else: + return ProjectMetadata.from_proto(project_metadata_proto) + + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_key, + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + project_metadata = ProjectMetadata(project_name=project) + for row in df.iterrows(): + if row[1]["METADATA_KEY"] == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata.project_uuid = row[1]["METADATA_VALUE"] + break + # TODO(adchia): Add other project metadata in a structured way + return project_metadata + return None + + def delete_project_metadata(self, project: StrictStr, commit: bool = True): + pass + def apply_user_metadata( self, project: str, @@ -1091,7 +1132,7 @@ def _infer_fv_table(self, feature_view) -> str: raise ValueError(f"Unexpected feature view type: {type(feature_view)}") return table - def _maybe_init_project_metadata(self, project): + def apply_project_metadata(self, project): with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index 90c6e82e7d9..7917a7f43d2 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -31,6 +31,7 @@ FeatureServiceNotFoundException, FeatureViewNotFoundException, PermissionNotFoundException, + ProjectMetadataNotFoundException, SavedDatasetNotFound, ValidationReferenceNotFound, ) @@ -487,24 +488,60 @@ def _list_on_demand_feature_views( tags=tags, ) - def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: + def _list_project_metadata(self, project: Optional[str]) -> List[ProjectMetadata]: + with self.engine.begin() as conn: + if not project: + stmt = select(feast_metadata) + else: + stmt = select(feast_metadata).where( + feast_metadata.c.project_id == project, + ) + rows = conn.execute(stmt).all() + if rows: + project_metadata_dict: Dict[str, ProjectMetadata] = {} + for row in rows: + project_id = row._mapping["project_id"] + metadata_key = row._mapping["metadata_key"] + metadata_value = row._mapping["metadata_value"] + + if project_id not in project_metadata_dict: + project_metadata_dict[project_id] = ProjectMetadata( + project_name=project_id + ) + + project_metadata_model: ProjectMetadata = project_metadata_dict[ + project_id + ] + if metadata_key == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata_model.project_uuid = metadata_value + return list(project_metadata_dict.values()) + return [] + + def _get_project_metadata( + self, + project: str, + ) -> Optional[ProjectMetadata]: + """ + Returns given project metadata. + """ with self.engine.begin() as conn: stmt = select(feast_metadata).where( feast_metadata.c.project_id == project, ) rows = conn.execute(stmt).all() if rows: - project_metadata = ProjectMetadata(project_name=project) + project_metadata: ProjectMetadata = ProjectMetadata( + project_name=project + ) for row in rows: - if ( - row._mapping["metadata_key"] - == FeastMetadataKeys.PROJECT_UUID.value - ): - project_metadata.project_uuid = row._mapping["metadata_value"] - break - # TODO(adchia): Add other project metadata in a structured way - return [project_metadata] - return [] + metadata_key = row._mapping["metadata_key"] + metadata_value = row._mapping["metadata_value"] + + if metadata_key == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata.project_uuid = metadata_value + return project_metadata + else: + return None def apply_saved_dataset( self, @@ -720,7 +757,7 @@ def _apply_object( proto_field_name: str, name: Optional[str] = None, ): - self._maybe_init_project_metadata(project) + self.apply_project_metadata(project) name = name or (obj.name if hasattr(obj, "name") else None) assert name, f"name needs to be provided for {obj}" @@ -791,7 +828,7 @@ def _apply_object( self._set_last_updated_metadata(update_datetime, project) - def _maybe_init_project_metadata(self, project): + def apply_project_metadata(self, project): # Initialize project metadata if needed with self.engine.begin() as conn: update_datetime = _utc_now() @@ -812,6 +849,25 @@ def _maybe_init_project_metadata(self, project): insert_stmt = insert(feast_metadata).values(values) conn.execute(insert_stmt) + def delete_project_metadata(self, project: str, commit: bool = True): + project_metadata = self.get_project_metadata(project, allow_cache=False) + if project_metadata is None: + raise ProjectMetadataNotFoundException(project) + with self.engine.begin() as conn: + for t in { + entities, + data_sources, + feature_views, + feature_services, + on_demand_feature_views, + saved_datasets, + validation_references, + managed_infra, + feast_metadata, + }: + stmt = delete(t).where(t.c.project_id == project) + conn.execute(stmt) + def _delete_object( self, table: Table, @@ -842,7 +898,7 @@ def _get_object( proto_field_name: str, not_found_exception: Optional[Callable], ): - self._maybe_init_project_metadata(project) + self.apply_project_metadata(project) with self.engine.begin() as conn: stmt = select(table).where( @@ -866,7 +922,7 @@ def _list_objects( proto_field_name: str, tags: Optional[dict[str, str]] = None, ): - self._maybe_init_project_metadata(project) + self.apply_project_metadata(project) with self.engine.begin() as conn: stmt = select(table).where(table.c.project_id == project) rows = conn.execute(stmt).all() diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index c528cee4a84..6be268851b1 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -342,10 +342,12 @@ def test_apply_entity_success(test_registry): # Register Entity test_registry.apply_entity(entity, project) - project_metadata = test_registry.list_project_metadata(project=project) - assert len(project_metadata) == 1 - project_uuid = project_metadata[0].project_uuid - assert len(project_metadata[0].project_uuid) == 36 + project_metadata_list = test_registry.list_project_metadata(project=project) + assert len(project_metadata_list) == 1 + project_metadata_list = test_registry.list_project_metadata(project=None) + assert len(project_metadata_list) == 1 + project_uuid = project_metadata_list[0].project_uuid + assert len(project_metadata_list[0].project_uuid) == 36 assert_project_uuid(project, project_uuid, test_registry) entities = test_registry.list_entities(project, tags=entity.tags)