From 0d8cb664bd18c0b9730aa62063c85ec7f6194e51 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Wed, 18 Jun 2025 22:39:24 -0400 Subject: [PATCH 1/7] feat: Enable materialization for ODFV Transform on Write Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/feature_store.py | 63 +- sdk/python/feast/infra/provider.py | 1076 +++---- .../feast/infra/registry/base_registry.py | 1848 +++++------ sdk/python/feast/infra/registry/registry.py | 2180 ++++++------- sdk/python/feast/infra/registry/remote.py | 1188 +++---- sdk/python/feast/infra/registry/snowflake.py | 2750 ++++++++--------- sdk/python/feast/infra/registry/sql.py | 2516 +++++++-------- .../test_on_demand_python_transformation.py | 175 ++ 8 files changed, 5996 insertions(+), 5800 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 5cc232d5fca..8632d619fec 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -656,7 +656,7 @@ def _make_inferences( def _get_feature_views_to_materialize( self, feature_views: Optional[List[str]], - ) -> List[FeatureView]: + ) -> List[Union[FeatureView, OnDemandFeatureView]]: """ Returns the list of feature views that should be materialized. @@ -669,34 +669,53 @@ def _get_feature_views_to_materialize( FeatureViewNotFoundException: One of the specified feature views could not be found. ValueError: One of the specified feature views is not configured for materialization. """ - feature_views_to_materialize: List[FeatureView] = [] + feature_views_to_materialize: List[Union[FeatureView, OnDemandFeatureView]] = [] if feature_views is None: - feature_views_to_materialize = utils._list_feature_views( + regular_feature_views = utils._list_feature_views( self._registry, self.project, hide_dummy_entity=False ) - feature_views_to_materialize = [ - fv for fv in feature_views_to_materialize if fv.online - ] + feature_views_to_materialize.extend( + [fv for fv in regular_feature_views if fv.online] + ) stream_feature_views_to_materialize = self._list_stream_feature_views( hide_dummy_entity=False ) - feature_views_to_materialize += [ - sfv for sfv in stream_feature_views_to_materialize if sfv.online - ] + feature_views_to_materialize.extend( + [sfv for sfv in stream_feature_views_to_materialize if sfv.online] + ) + on_demand_feature_views_to_materialize = self.list_on_demand_feature_views() + feature_views_to_materialize.extend( + [ + odfv + for odfv in on_demand_feature_views_to_materialize + if odfv.write_to_online_store + ] + ) else: for name in feature_views: + feature_view: Union[FeatureView, OnDemandFeatureView] try: feature_view = self._get_feature_view(name, hide_dummy_entity=False) except FeatureViewNotFoundException: - feature_view = self._get_stream_feature_view( - name, hide_dummy_entity=False - ) + try: + feature_view = self._get_stream_feature_view( + name, hide_dummy_entity=False + ) + except FeatureViewNotFoundException: + feature_view = self.get_on_demand_feature_view(name) - if not feature_view.online: + if hasattr(feature_view, "online") and not feature_view.online: raise ValueError( f"FeatureView {feature_view.name} is not configured to be served online." ) + elif ( + hasattr(feature_view, "write_to_online_store") + and not feature_view.write_to_online_store + ): + raise ValueError( + f"OnDemandFeatureView {feature_view.name} is not configured for write_to_online_store." 
+ ) feature_views_to_materialize.append(feature_view) return feature_views_to_materialize @@ -1312,6 +1331,8 @@ def materialize_incremental( ) # TODO paging large loads for feature_view in feature_views_to_materialize: + if isinstance(feature_view, OnDemandFeatureView): + continue start_date = feature_view.most_recent_end_time if start_date is None: if feature_view.ttl is None: @@ -1340,7 +1361,7 @@ def tqdm_builder(length): return tqdm(total=length, ncols=100) start_date = utils.make_tzaware(start_date) - end_date = utils.make_tzaware(end_date) + end_date = utils.make_tzaware(end_date) or _utc_now() provider.materialize_single_feature_view( config=self.config, @@ -1351,13 +1372,13 @@ def tqdm_builder(length): project=self.project, tqdm_builder=tqdm_builder, ) - - self._registry.apply_materialization( - feature_view, - self.project, - start_date, - end_date, - ) + if not isinstance(feature_view, OnDemandFeatureView): + self._registry.apply_materialization( + feature_view, + self.project, + start_date, + end_date, + ) def materialize( self, diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 15917420af0..39686895a87 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -1,538 +1,538 @@ -from abc import ABC, abstractmethod -from datetime import datetime -from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) - -import pandas as pd -import pyarrow -from tqdm import tqdm - -from feast import FeatureService, errors -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.feature_view import FeatureView -from feast.importer import import_class -from feast.infra.infra_object import Infra -from feast.infra.offline_stores.offline_store import RetrievalJob -from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.supported_async_methods import ProviderAsyncMethods -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.online_response import OnlineResponse -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import RepeatedValue -from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.repo_config import RepoConfig -from feast.saved_dataset import SavedDataset - -PROVIDERS_CLASS_FOR_TYPE = { - "gcp": "feast.infra.passthrough_provider.PassthroughProvider", - "aws": "feast.infra.passthrough_provider.PassthroughProvider", - "local": "feast.infra.passthrough_provider.PassthroughProvider", - "azure": "feast.infra.passthrough_provider.PassthroughProvider", -} - - -class Provider(ABC): - """ - A provider defines an implementation of a feature store object. It orchestrates the various - components of a feature store, such as the offline store, online store, and materialization - engine. It is configured through a RepoConfig object. 
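
To make the feature_store.py changes above concrete: an on-demand feature view that opts into transform-on-write is now selected by _get_feature_views_to_materialize(), while the incremental loop skips its history replay and apply_materialization is not recorded for it. A minimal usage sketch, assuming the @on_demand_feature_view decorator's write_to_online_store flag referenced in the hunk; driver_stats_fv and the field names are hypothetical:

    from datetime import datetime, timezone
    from feast import FeatureStore, Field
    from feast.on_demand_feature_view import on_demand_feature_view
    from feast.types import Float64

    @on_demand_feature_view(
        sources=[driver_stats_fv],  # hypothetical upstream feature view
        schema=[Field(name="conv_rate_plus_one", dtype=Float64)],
        mode="python",
        write_to_online_store=True,  # opts this ODFV into materialization
    )
    def transformed_conv_rate(inputs: dict) -> dict:
        return {"conv_rate_plus_one": [v + 1.0 for v in inputs["conv_rate"]]}

    store = FeatureStore(repo_path=".")
    # The ODFV is returned by _get_feature_views_to_materialize(); the
    # incremental loop skips it, and apply_materialization is not called for it.
    store.materialize_incremental(end_date=datetime.now(timezone.utc))
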
- """ - - @abstractmethod - def __init__(self, config: RepoConfig): - pass - - @property - def async_supported(self) -> ProviderAsyncMethods: - return ProviderAsyncMethods() - - @abstractmethod - def update_infra( - self, - project: str, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[Union[FeatureView, OnDemandFeatureView]], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, - ): - """ - Reconciles cloud resources with the specified set of Feast objects. - - Args: - project: Feast project to which the objects belong. - tables_to_delete: Feature views whose corresponding infrastructure should be deleted. - tables_to_keep: Feature views whose corresponding infrastructure should not be deleted, and - may need to be updated. - entities_to_delete: Entities whose corresponding infrastructure should be deleted. - entities_to_keep: Entities whose corresponding infrastructure should not be deleted, and - may need to be updated. - partial: If true, tables_to_delete and tables_to_keep are not exhaustive lists, so - infrastructure corresponding to other feature views should be not be touched. - """ - pass - - def plan_infra( - self, config: RepoConfig, desired_registry_proto: RegistryProto - ) -> Infra: - """ - Returns the Infra required to support the desired registry. - - Args: - config: The RepoConfig for the current FeatureStore. - desired_registry_proto: The desired registry, in proto form. - """ - return Infra() - - @abstractmethod - def teardown_infra( - self, - project: str, - tables: Sequence[FeatureView], - entities: Sequence[Entity], - ): - """ - Tears down all cloud resources for the specified set of Feast objects. - - Args: - project: Feast project to which the objects belong. - tables: Feature views whose corresponding infrastructure should be deleted. - entities: Entities whose corresponding infrastructure should be deleted. - """ - pass - - @abstractmethod - def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - ) -> None: - """ - Writes a batch of feature rows to the online store. - - If a tz-naive timestamp is passed to this method, it is assumed to be UTC. - - Args: - config: The config for the current feature store. - table: Feature view to which these feature rows correspond. - data: A list of quadruplets containing feature data. Each quadruplet contains an entity - key, a dict containing feature values, an event timestamp for the row, and the created - timestamp for the row if it exists. - progress: Function to be called once a batch of rows is written to the online store, used - to show progress. - """ - pass - - @abstractmethod - async def online_write_batch_async( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - ) -> None: - """ - Writes a batch of feature rows to the online store asynchronously. - - If a tz-naive timestamp is passed to this method, it is assumed to be UTC. - - Args: - config: The config for the current feature store. - table: Feature view to which these feature rows correspond. - data: A list of quadruplets containing feature data. Each quadruplet contains an entity - key, a dict containing feature values, an event timestamp for the row, and the created - timestamp for the row if it exists. 
- progress: Function to be called once a batch of rows is written to the online store, used - to show progress. - """ - pass - - def ingest_df( - self, - feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], - df: pd.DataFrame, - field_mapping: Optional[Dict] = None, - ): - """ - Persists a dataframe to the online store. - - Args: - feature_view: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. - field_mapping: A dictionary mapping dataframe column names to feature names. - """ - pass - - async def ingest_df_async( - self, - feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], - df: pd.DataFrame, - field_mapping: Optional[Dict] = None, - ): - """ - Persists a dataframe to the online store asynchronously. - - Args: - feature_view: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. - field_mapping: A dictionary mapping dataframe column names to feature names. - """ - pass - - def ingest_df_to_offline_store( - self, - feature_view: FeatureView, - df: pyarrow.Table, - ): - """ - Persists a dataframe to the offline store. - - Args: - feature_view: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. - """ - pass - - @abstractmethod - def materialize_single_feature_view( - self, - config: RepoConfig, - feature_view: FeatureView, - start_date: datetime, - end_date: datetime, - registry: BaseRegistry, - project: str, - tqdm_builder: Callable[[int], tqdm], - ) -> None: - """ - Writes latest feature values in the specified time range to the online store. - - Args: - config: The config for the current feature store. - feature_view: The feature view to materialize. - start_date: The start of the time range. - end_date: The end of the time range. - registry: The registry for the current feature store. - project: Feast project to which the objects belong. - tqdm_builder: A function to monitor the progress of materialization. - """ - pass - - @abstractmethod - def get_historical_features( - self, - config: RepoConfig, - feature_views: List[Union[FeatureView, OnDemandFeatureView]], - feature_refs: List[str], - entity_df: Union[pd.DataFrame, str], - registry: BaseRegistry, - project: str, - full_feature_names: bool, - ) -> RetrievalJob: - """ - Retrieves the point-in-time correct historical feature values for the specified entity rows. - - Args: - config: The config for the current feature store. - feature_views: A list containing all feature views that are referenced in the entity rows. - feature_refs: The features to be retrieved. - entity_df: A collection of rows containing all entity columns on which features need to be joined, - as well as the timestamp column used for point-in-time joins. Either a pandas dataframe can be - provided or a SQL query. - registry: The registry for the current feature store. - project: Feast project to which the feature views belong. - full_feature_names: If True, feature names will be prefixed with the corresponding feature view name, - changing them from the format "feature" to "feature_view__feature" (e.g. "daily_transactions" - changes to "customer_fv__daily_transactions"). - - Returns: - A RetrievalJob that can be executed to get the features. 
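
The full_feature_names prefixing described above, seen from the caller's side; a sketch using the public FeatureStore API (the feature reference and entity column are hypothetical):

    import pandas as pd
    from feast import FeatureStore

    store = FeatureStore(repo_path=".")
    entity_df = pd.DataFrame(
        {
            "driver_id": [1001, 1002],
            "event_timestamp": pd.to_datetime(["2025-06-01", "2025-06-02"], utc=True),
        }
    )
    job = store.get_historical_features(
        entity_df=entity_df,
        features=["customer_fv:daily_transactions"],
        full_feature_names=True,  # column comes back as "customer_fv__daily_transactions"
    )
    df = job.to_df()  # RetrievalJob executed here
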
- """ - pass - - @abstractmethod - def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - """ - Reads features values for the given entity keys. - - Args: - config: The config for the current feature store. - table: The feature view whose feature values should be read. - entity_keys: The list of entity keys for which feature values should be read. - requested_features: The list of features that should be read. - - Returns: - A list of the same length as entity_keys. Each item in the list is a tuple where the first - item is the event timestamp for the row, and the second item is a dict mapping feature names - to values, which are returned in proto format. - """ - pass - - @abstractmethod - def get_online_features( - self, - config: RepoConfig, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], - ], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, - ) -> OnlineResponse: - pass - - @abstractmethod - async def get_online_features_async( - self, - config: RepoConfig, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], - ], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, - ) -> OnlineResponse: - pass - - @abstractmethod - async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - """ - Reads features values for the given entity keys asynchronously. - - Args: - config: The config for the current feature store. - table: The feature view whose feature values should be read. - entity_keys: The list of entity keys for which feature values should be read. - requested_features: The list of features that should be read. - - Returns: - A list of the same length as entity_keys. Each item in the list is a tuple where the first - item is the event timestamp for the row, and the second item is a dict mapping feature names - to values, which are returned in proto format. - """ - pass - - @abstractmethod - def retrieve_saved_dataset( - self, config: RepoConfig, dataset: SavedDataset - ) -> RetrievalJob: - """ - Reads a saved dataset. - - Args: - config: The config for the current feature store. - dataset: A SavedDataset object containing all parameters necessary for retrieving the dataset. - - Returns: - A RetrievalJob that can be executed to get the saved dataset. - """ - pass - - @abstractmethod - def write_feature_service_logs( - self, - feature_service: FeatureService, - logs: Union[pyarrow.Table, Path], - config: RepoConfig, - registry: BaseRegistry, - ): - """ - Writes features and entities logged by a feature server to the offline store. - - The schema of the logs table is inferred from the specified feature service. Only feature - services with configured logging are accepted. - - Args: - feature_service: The feature service to be logged. - logs: The logs, either as an arrow table or as a path to a parquet directory. - config: The config for the current feature store. - registry: The registry for the current feature store. 
- """ - pass - - @abstractmethod - def retrieve_feature_service_logs( - self, - feature_service: FeatureService, - start_date: datetime, - end_date: datetime, - config: RepoConfig, - registry: BaseRegistry, - ) -> RetrievalJob: - """ - Reads logged features for the specified time window. - - Args: - feature_service: The feature service whose logs should be retrieved. - start_date: The start of the window. - end_date: The end of the window. - config: The config for the current feature store. - registry: The registry for the current feature store. - - Returns: - A RetrievalJob that can be executed to get the feature service logs. - """ - pass - - def get_feature_server_endpoint(self) -> Optional[str]: - """Returns endpoint for the feature server, if it exists.""" - return None - - @abstractmethod - def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_features: Optional[List[str]], - query: List[float], - top_k: int, - distance_metric: Optional[str] = None, - ) -> List[ - Tuple[ - Optional[datetime], - Optional[EntityKeyProto], - Optional[ValueProto], - Optional[ValueProto], - Optional[ValueProto], - ], - ]: - """ - Searches for the top-k most similar documents in the online document store. - - Args: - distance_metric: distance metric to use for the search. - config: The config for the current feature store. - table: The feature view whose embeddings should be searched. - requested_features: the requested document feature names. - query: The query embedding to search for. - top_k: The number of documents to return. - - Returns: - A list of dictionaries, where each dictionary contains the document feature. - """ - pass - - @abstractmethod - def retrieve_online_documents_v2( - self, - config: RepoConfig, - table: FeatureView, - requested_features: List[str], - query: Optional[List[float]], - top_k: int, - distance_metric: Optional[str] = None, - query_string: Optional[str] = None, - ) -> List[ - Tuple[ - Optional[datetime], - Optional[EntityKeyProto], - Optional[Dict[str, ValueProto]], - ] - ]: - """ - Searches for the top-k most similar documents in the online document store. - - Args: - distance_metric: distance metric to use for the search. - config: The config for the current feature store. - table: The feature view whose embeddings should be searched. - requested_features: the requested document feature names. - query: The query embedding to search for (optional). - top_k: The number of documents to return. - query_string: The query string to search for using keyword search (bm25) (optional) - - Returns: - A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary - of feature key value pairs - """ - pass - - @abstractmethod - def validate_data_source( - self, - config: RepoConfig, - data_source: DataSource, - ): - """ - Validates the underlying data source. - - Args: - config: Configuration object used to configure a feature store. - data_source: DataSource object that needs to be validated - """ - pass - - @abstractmethod - def get_table_column_names_and_types_from_data_source( - self, config: RepoConfig, data_source: DataSource - ) -> Iterable[Tuple[str, str]]: - """ - Returns the list of column names and raw column types for a DataSource. - - Args: - config: Configuration object used to configure a feature store. 
- data_source: DataSource object - """ - pass - - @abstractmethod - async def initialize(self, config: RepoConfig) -> None: - pass - - @abstractmethod - async def close(self) -> None: - pass - - -def get_provider(config: RepoConfig) -> Provider: - if "." not in config.provider: - if config.provider not in PROVIDERS_CLASS_FOR_TYPE: - raise errors.FeastProviderNotImplementedError(config.provider) - - provider = PROVIDERS_CLASS_FOR_TYPE[config.provider] - else: - provider = config.provider - - # Split provider into module and class names by finding the right-most dot. - # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' - module_name, class_name = provider.rsplit(".", 1) - - cls = import_class(module_name, class_name, "Provider") - - return cls(config) +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +import pandas as pd +import pyarrow +from tqdm import tqdm + +from feast import FeatureService, errors +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_view import FeatureView +from feast.importer import import_class +from feast.infra.infra_object import Infra +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.supported_async_methods import ProviderAsyncMethods +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.online_response import OnlineResponse +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import RepeatedValue +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDataset + +PROVIDERS_CLASS_FOR_TYPE = { + "gcp": "feast.infra.passthrough_provider.PassthroughProvider", + "aws": "feast.infra.passthrough_provider.PassthroughProvider", + "local": "feast.infra.passthrough_provider.PassthroughProvider", + "azure": "feast.infra.passthrough_provider.PassthroughProvider", +} + + +class Provider(ABC): + """ + A provider defines an implementation of a feature store object. It orchestrates the various + components of a feature store, such as the offline store, online store, and materialization + engine. It is configured through a RepoConfig object. + """ + + @abstractmethod + def __init__(self, config: RepoConfig): + pass + + @property + def async_supported(self) -> ProviderAsyncMethods: + return ProviderAsyncMethods() + + @abstractmethod + def update_infra( + self, + project: str, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[Union[FeatureView, OnDemandFeatureView]], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, + ): + """ + Reconciles cloud resources with the specified set of Feast objects. + + Args: + project: Feast project to which the objects belong. + tables_to_delete: Feature views whose corresponding infrastructure should be deleted. + tables_to_keep: Feature views whose corresponding infrastructure should not be deleted, and + may need to be updated. + entities_to_delete: Entities whose corresponding infrastructure should be deleted. 
+ entities_to_keep: Entities whose corresponding infrastructure should not be deleted, and + may need to be updated. + partial: If true, tables_to_delete and tables_to_keep are not exhaustive lists, so + infrastructure corresponding to other feature views should be not be touched. + """ + pass + + def plan_infra( + self, config: RepoConfig, desired_registry_proto: RegistryProto + ) -> Infra: + """ + Returns the Infra required to support the desired registry. + + Args: + config: The RepoConfig for the current FeatureStore. + desired_registry_proto: The desired registry, in proto form. + """ + return Infra() + + @abstractmethod + def teardown_infra( + self, + project: str, + tables: Sequence[FeatureView], + entities: Sequence[Entity], + ): + """ + Tears down all cloud resources for the specified set of Feast objects. + + Args: + project: Feast project to which the objects belong. + tables: Feature views whose corresponding infrastructure should be deleted. + entities: Entities whose corresponding infrastructure should be deleted. + """ + pass + + @abstractmethod + def online_write_batch( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + """ + Writes a batch of feature rows to the online store. + + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. + + Args: + config: The config for the current feature store. + table: Feature view to which these feature rows correspond. + data: A list of quadruplets containing feature data. Each quadruplet contains an entity + key, a dict containing feature values, an event timestamp for the row, and the created + timestamp for the row if it exists. + progress: Function to be called once a batch of rows is written to the online store, used + to show progress. + """ + pass + + @abstractmethod + async def online_write_batch_async( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + """ + Writes a batch of feature rows to the online store asynchronously. + + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. + + Args: + config: The config for the current feature store. + table: Feature view to which these feature rows correspond. + data: A list of quadruplets containing feature data. Each quadruplet contains an entity + key, a dict containing feature values, an event timestamp for the row, and the created + timestamp for the row if it exists. + progress: Function to be called once a batch of rows is written to the online store, used + to show progress. + """ + pass + + def ingest_df( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + """ + Persists a dataframe to the online store. + + Args: + feature_view: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + field_mapping: A dictionary mapping dataframe column names to feature names. + """ + pass + + async def ingest_df_async( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + """ + Persists a dataframe to the online store asynchronously. + + Args: + feature_view: The feature view to which the dataframe corresponds. 
+ df: The dataframe to be persisted. + field_mapping: A dictionary mapping dataframe column names to feature names. + """ + pass + + def ingest_df_to_offline_store( + self, + feature_view: FeatureView, + df: pyarrow.Table, + ): + """ + Persists a dataframe to the offline store. + + Args: + feature_view: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + """ + pass + + @abstractmethod + def materialize_single_feature_view( + self, + config: RepoConfig, + feature_view: Union[FeatureView, OnDemandFeatureView], + start_date: datetime, + end_date: datetime, + registry: BaseRegistry, + project: str, + tqdm_builder: Callable[[int], tqdm], + ) -> None: + """ + Writes latest feature values in the specified time range to the online store. + + Args: + config: The config for the current feature store. + feature_view: The feature view to materialize. + start_date: The start of the time range. + end_date: The end of the time range. + registry: The registry for the current feature store. + project: Feast project to which the objects belong. + tqdm_builder: A function to monitor the progress of materialization. + """ + pass + + @abstractmethod + def get_historical_features( + self, + config: RepoConfig, + feature_views: List[Union[FeatureView, OnDemandFeatureView]], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: BaseRegistry, + project: str, + full_feature_names: bool, + ) -> RetrievalJob: + """ + Retrieves the point-in-time correct historical feature values for the specified entity rows. + + Args: + config: The config for the current feature store. + feature_views: A list containing all feature views that are referenced in the entity rows. + feature_refs: The features to be retrieved. + entity_df: A collection of rows containing all entity columns on which features need to be joined, + as well as the timestamp column used for point-in-time joins. Either a pandas dataframe can be + provided or a SQL query. + registry: The registry for the current feature store. + project: Feast project to which the feature views belong. + full_feature_names: If True, feature names will be prefixed with the corresponding feature view name, + changing them from the format "feature" to "feature_view__feature" (e.g. "daily_transactions" + changes to "customer_fv__daily_transactions"). + + Returns: + A RetrievalJob that can be executed to get the features. + """ + pass + + @abstractmethod + def online_read( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + """ + Reads features values for the given entity keys. + + Args: + config: The config for the current feature store. + table: The feature view whose feature values should be read. + entity_keys: The list of entity keys for which feature values should be read. + requested_features: The list of features that should be read. + + Returns: + A list of the same length as entity_keys. Each item in the list is a tuple where the first + item is the event timestamp for the row, and the second item is a dict mapping feature names + to values, which are returned in proto format. 
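
Since materialize_single_feature_view now accepts Union[FeatureView, OnDemandFeatureView], a custom provider can mirror the guard added to feature_store.py earlier in this patch; an illustrative sketch, not the shipped implementation:

    from feast.on_demand_feature_view import OnDemandFeatureView

    def materialize_single_feature_view(self, config, feature_view, start_date,
                                        end_date, registry, project, tqdm_builder):
        if isinstance(feature_view, OnDemandFeatureView):
            # Transform-on-write views have no offline history to replay;
            # skip them, matching the materialize_incremental() guard above.
            return
        ...  # existing FeatureView materialization path
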
+ """ + pass + + @abstractmethod + def get_online_features( + self, + config: RepoConfig, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], + ], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> OnlineResponse: + pass + + @abstractmethod + async def get_online_features_async( + self, + config: RepoConfig, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], + ], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> OnlineResponse: + pass + + @abstractmethod + async def online_read_async( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + """ + Reads features values for the given entity keys asynchronously. + + Args: + config: The config for the current feature store. + table: The feature view whose feature values should be read. + entity_keys: The list of entity keys for which feature values should be read. + requested_features: The list of features that should be read. + + Returns: + A list of the same length as entity_keys. Each item in the list is a tuple where the first + item is the event timestamp for the row, and the second item is a dict mapping feature names + to values, which are returned in proto format. + """ + pass + + @abstractmethod + def retrieve_saved_dataset( + self, config: RepoConfig, dataset: SavedDataset + ) -> RetrievalJob: + """ + Reads a saved dataset. + + Args: + config: The config for the current feature store. + dataset: A SavedDataset object containing all parameters necessary for retrieving the dataset. + + Returns: + A RetrievalJob that can be executed to get the saved dataset. + """ + pass + + @abstractmethod + def write_feature_service_logs( + self, + feature_service: FeatureService, + logs: Union[pyarrow.Table, Path], + config: RepoConfig, + registry: BaseRegistry, + ): + """ + Writes features and entities logged by a feature server to the offline store. + + The schema of the logs table is inferred from the specified feature service. Only feature + services with configured logging are accepted. + + Args: + feature_service: The feature service to be logged. + logs: The logs, either as an arrow table or as a path to a parquet directory. + config: The config for the current feature store. + registry: The registry for the current feature store. + """ + pass + + @abstractmethod + def retrieve_feature_service_logs( + self, + feature_service: FeatureService, + start_date: datetime, + end_date: datetime, + config: RepoConfig, + registry: BaseRegistry, + ) -> RetrievalJob: + """ + Reads logged features for the specified time window. + + Args: + feature_service: The feature service whose logs should be retrieved. + start_date: The start of the window. + end_date: The end of the window. + config: The config for the current feature store. + registry: The registry for the current feature store. + + Returns: + A RetrievalJob that can be executed to get the feature service logs. 
+ """ + pass + + def get_feature_server_endpoint(self) -> Optional[str]: + """Returns endpoint for the feature server, if it exists.""" + return None + + @abstractmethod + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_features: Optional[List[str]], + query: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ], + ]: + """ + Searches for the top-k most similar documents in the online document store. + + Args: + distance_metric: distance metric to use for the search. + config: The config for the current feature store. + table: The feature view whose embeddings should be searched. + requested_features: the requested document feature names. + query: The query embedding to search for. + top_k: The number of documents to return. + + Returns: + A list of dictionaries, where each dictionary contains the document feature. + """ + pass + + @abstractmethod + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + query: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Searches for the top-k most similar documents in the online document store. + + Args: + distance_metric: distance metric to use for the search. + config: The config for the current feature store. + table: The feature view whose embeddings should be searched. + requested_features: the requested document feature names. + query: The query embedding to search for (optional). + top_k: The number of documents to return. + query_string: The query string to search for using keyword search (bm25) (optional) + + Returns: + A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary + of feature key value pairs + """ + pass + + @abstractmethod + def validate_data_source( + self, + config: RepoConfig, + data_source: DataSource, + ): + """ + Validates the underlying data source. + + Args: + config: Configuration object used to configure a feature store. + data_source: DataSource object that needs to be validated + """ + pass + + @abstractmethod + def get_table_column_names_and_types_from_data_source( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + """ + Returns the list of column names and raw column types for a DataSource. + + Args: + config: Configuration object used to configure a feature store. + data_source: DataSource object + """ + pass + + @abstractmethod + async def initialize(self, config: RepoConfig) -> None: + pass + + @abstractmethod + async def close(self) -> None: + pass + + +def get_provider(config: RepoConfig) -> Provider: + if "." not in config.provider: + if config.provider not in PROVIDERS_CLASS_FOR_TYPE: + raise errors.FeastProviderNotImplementedError(config.provider) + + provider = PROVIDERS_CLASS_FOR_TYPE[config.provider] + else: + provider = config.provider + + # Split provider into module and class names by finding the right-most dot. 
+ # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' + module_name, class_name = provider.rsplit(".", 1) + + cls = import_class(module_name, class_name, "Provider") + + return cls(config) diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index f2374edf1b2..c6780ef546d 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -1,924 +1,924 @@ -# Copyright 2019 The Feast Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import warnings -from abc import ABC, abstractmethod -from collections import defaultdict -from datetime import datetime -from typing import Any, Dict, List, Optional - -from google.protobuf.json_format import MessageToJson -from google.protobuf.message import Message - -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object import Infra -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto -from feast.protos.feast.core.FeatureService_pb2 import ( - FeatureService as FeatureServiceProto, -) -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto -from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( - OnDemandFeatureView as OnDemandFeatureViewProto, -) -from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto -from feast.protos.feast.core.Project_pb2 import Project as ProjectProto -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto -from feast.protos.feast.core.StreamFeatureView_pb2 import ( - StreamFeatureView as StreamFeatureViewProto, -) -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.transformation.pandas_transformation import PandasTransformation -from feast.transformation.substrait_transformation import SubstraitTransformation - - -class BaseRegistry(ABC): - """ - The interface that Feast uses to apply, list, retrieve, and delete Feast objects (e.g. entities, - feature views, and data sources). 
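
An illustrative use of the registry interface described above, exercising methods defined later in this file (the registry instance, objects, and project name are hypothetical):

    registry.apply_entity(driver, project="my_project")
    entities = registry.list_entities(project="my_project")

    registry.apply_feature_view(driver_stats_fv, project="my_project")
    fv = registry.get_feature_view("driver_hourly_stats", project="my_project")
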
- """ - - # Entity operations - @abstractmethod - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - """ - Registers a single entity with Feast - - Args: - entity: Entity that will be registered - project: Feast project that this entity belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def delete_entity(self, name: str, project: str, commit: bool = True): - """ - Deletes an entity or raises an exception if not found. - - Args: - name: Name of entity - project: Feast project that this entity belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - """ - Retrieves an entity. - - Args: - name: Name of entity - project: Feast project that this entity belongs to - allow_cache: Whether to allow returning this entity from a cached registry - - Returns: - Returns either the specified entity, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - """ - Retrieve a list of entities from the registry - - Args: - allow_cache: Whether to allow returning entities from a cached registry - project: Filter entities based on project name - tags: Filter by tags - - Returns: - List of entities - """ - raise NotImplementedError - - # Data source operations - @abstractmethod - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - """ - Registers a single data source with Feast - - Args: - data_source: A data source that will be registered - project: Feast project that this data source belongs to - commit: Whether to immediately commit to the registry - """ - raise NotImplementedError - - @abstractmethod - def delete_data_source(self, name: str, project: str, commit: bool = True): - """ - Deletes a data source or raises an exception if not found. - - Args: - name: Name of data source - project: Feast project that this data source belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - """ - Retrieves a data source. 
- - Args: - name: Name of data source - project: Feast project that this data source belongs to - allow_cache: Whether to allow returning this data source from a cached registry - - Returns: - Returns either the specified data source, or raises an exception if none is found - """ - raise NotImplementedError - - @abstractmethod - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - """ - Retrieve a list of data sources from the registry - - Args: - project: Filter data source based on project name - allow_cache: Whether to allow returning data sources from a cached registry - tags: Filter by tags - - Returns: - List of data sources - """ - raise NotImplementedError - - # Feature service operations - @abstractmethod - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - """ - Registers a single feature service with Feast - - Args: - feature_service: A feature service that will be registered - project: Feast project that this entity belongs to - """ - raise NotImplementedError - - @abstractmethod - def delete_feature_service(self, name: str, project: str, commit: bool = True): - """ - Deletes a feature service or raises an exception if not found. - - Args: - name: Name of feature service - project: Feast project that this feature service belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - """ - Retrieves a feature service. - - Args: - name: Name of feature service - project: Feast project that this feature service belongs to - allow_cache: Whether to allow returning this feature service from a cached registry - - Returns: - Returns either the specified feature service, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - """ - Retrieve a list of feature services from the registry - - Args: - allow_cache: Whether to allow returning entities from a cached registry - project: Filter entities based on project name - tags: Filter by tags - - Returns: - List of feature services - """ - raise NotImplementedError - - # Feature view operations - @abstractmethod - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - """ - Registers a single feature view with Feast - - Args: - feature_view: Feature view that will be registered - project: Feast project that this feature view belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def delete_feature_view(self, name: str, project: str, commit: bool = True): - """ - Deletes a feature view or raises an exception if not found. - - Args: - name: Name of feature view - project: Feast project that this feature view belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - # stream feature view operations - @abstractmethod - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> StreamFeatureView: - """ - Retrieves a stream feature view. 
- - Args: - name: Name of stream feature view - project: Feast project that this feature view belongs to - allow_cache: Allow returning feature view from the cached registry - - Returns: - Returns either the specified feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - """ - Retrieve a list of stream feature views from the registry - - Args: - project: Filter stream feature views based on project name - allow_cache: Whether to allow returning stream feature views from a cached registry - tags: Filter by tags - - Returns: - List of stream feature views - """ - raise NotImplementedError - - # on demand feature view operations - @abstractmethod - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: - """ - Retrieves an on demand feature view. - - Args: - name: Name of on demand feature view - project: Feast project that this on demand feature view belongs to - allow_cache: Whether to allow returning this on demand feature view from a cached registry - - Returns: - Returns either the specified on demand feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_on_demand_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - """ - Retrieve a list of on demand feature views from the registry - - Args: - project: Filter on demand feature views based on project name - allow_cache: Whether to allow returning on demand feature views from a cached registry - tags: Filter by tags - - Returns: - List of on demand feature views - """ - raise NotImplementedError - - # regular feature view operations - @abstractmethod - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - """ - Retrieves a feature view. - - Args: - name: Name of feature view - project: Feast project that this feature view belongs to - allow_cache: Allow returning feature view from the cached registry - - Returns: - Returns either the specified feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - """ - Retrieve a list of feature views from the registry - - Args: - allow_cache: Allow returning feature views from the cached registry - project: Filter feature views based on project name - tags: Filter by tags - - Returns: - List of feature views - """ - raise NotImplementedError - - @abstractmethod - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - """ - Retrieves a feature view of any type. 
- - Args: - name: Name of feature view - project: Feast project that this feature view belongs to - allow_cache: Allow returning feature view from the cached registry - - Returns: - Returns either the specified feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - """ - Retrieve a list of feature views of all types from the registry - - Args: - allow_cache: Allow returning feature views from the cached registry - project: Filter feature views based on project name - tags: Filter by tags - - Returns: - List of feature views - """ - raise NotImplementedError - - @abstractmethod - def apply_materialization( - self, - feature_view: FeatureView, - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - """ - Updates materialization intervals tracked for a single feature view in Feast - - Args: - feature_view: Feature view that will be updated with an additional materialization interval tracked - project: Feast project that this feature view belongs to - start_date (datetime): Start date of the materialization interval to track - end_date (datetime): End date of the materialization interval to track - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - # Saved dataset operations - @abstractmethod - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - """ - Stores a saved dataset metadata with Feast - - Args: - saved_dataset: SavedDataset that will be added / updated to registry - project: Feast project that this dataset belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - """ - Retrieves a saved dataset. - - Args: - name: Name of dataset - project: Feast project that this dataset belongs to - allow_cache: Whether to allow returning this dataset from a cached registry - - Returns: - Returns either the specified SavedDataset, or raises an exception if - none is found - """ - raise NotImplementedError - - def delete_saved_dataset(self, name: str, project: str, commit: bool = True): - """ - Delete a saved dataset. 
- - Args: - name: Name of dataset - project: Feast project that this dataset belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - """ - Retrieves a list of all saved datasets in specified project - - Args: - project: Feast project - allow_cache: Whether to allow returning this dataset from a cached registry - tags: Filter by tags - - Returns: - Returns the list of SavedDatasets - """ - raise NotImplementedError - - # Validation reference operations - @abstractmethod - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - """ - Persist a validation reference - - Args: - validation_reference: ValidationReference that will be added / updated to registry - project: Feast project that this dataset belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - """ - Deletes a validation reference or raises an exception if not found. - - Args: - name: Name of validation reference - project: Feast project that this object belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - """ - Retrieves a validation reference. - - Args: - name: Name of dataset - project: Feast project that this dataset belongs to - allow_cache: Whether to allow returning this dataset from a cached registry - - Returns: - Returns either the specified ValidationReference, or raises an exception if - none is found - """ - raise NotImplementedError - - # TODO: Needs to be implemented. - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - """ - Retrieve a list of validation references from the registry - - Args: - project: Filter validation references based on project name - allow_cache: Allow returning validation references from the cached registry - tags: Filter by tags - - Returns: - List of request validation references - """ - raise NotImplementedError - - @abstractmethod - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - """ - Retrieves project metadata - - Args: - project: Filter metadata based on project name - allow_cache: Allow returning feature views from the cached registry - - Returns: - List of project metadata - """ - raise NotImplementedError - - @abstractmethod - def update_infra(self, infra: Infra, project: str, commit: bool = True): - """ - Updates the stored Infra object. - - Args: - infra: The new Infra object to be stored. - project: Feast project that the Infra object refers to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - """ - Retrieves the stored Infra object. - - Args: - project: Feast project that the Infra object refers to - allow_cache: Whether to allow returning this entity from a cached registry - - Returns: - The stored Infra object. 
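
Tying apply_materialization back to the feature_store.py hunk at the top of this patch: materialization intervals are still recorded for regular feature views, but not for on-demand feature views. A condensed sketch of that call site:

    from feast.on_demand_feature_view import OnDemandFeatureView

    for feature_view in feature_views_to_materialize:
        ...  # provider.materialize_single_feature_view(...)
        if not isinstance(feature_view, OnDemandFeatureView):
            registry.apply_materialization(
                feature_view, project, start_date, end_date
            )
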
- """ - raise NotImplementedError - - @abstractmethod - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): ... - - @abstractmethod - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: ... - - # Permission operations - @abstractmethod - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - """ - Registers a single permission with Feast - - Args: - permission: A permission that will be registered - project: Feast project that this permission belongs to - commit: Whether to immediately commit to the registry - """ - raise NotImplementedError - - @abstractmethod - def delete_permission(self, name: str, project: str, commit: bool = True): - """ - Deletes a permission or raises an exception if not found. - - Args: - name: Name of permission - project: Feast project that this permission belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - """ - Retrieves a permission. - - Args: - name: Name of permission - project: Feast project that this permission belongs to - allow_cache: Whether to allow returning this permission from a cached registry - - Returns: - Returns either the specified permission, or raises an exception if none is found - """ - raise NotImplementedError - - @abstractmethod - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - """ - Retrieve a list of permissions from the registry - - Args: - project: Filter permission based on project name - allow_cache: Whether to allow returning permissions from a cached registry - - Returns: - List of permissions - """ - raise NotImplementedError - - @abstractmethod - def apply_project( - self, - project: Project, - commit: bool = True, - ): - """ - Registers a project with Feast - - Args: - project: A project that will be registered - commit: Whether to immediately commit to the registry - """ - raise NotImplementedError - - @abstractmethod - def delete_project( - self, - name: str, - commit: bool = True, - ): - """ - Deletes a project or raises an ProjectNotFoundException exception if not found. - - Args: - project: Feast project name that needs to be deleted - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - """ - Retrieves a project. - - Args: - name: Feast project name - allow_cache: Whether to allow returning this permission from a cached registry - - Returns: - Returns either the specified project, or raises ProjectObjectNotFoundException exception if none is found - """ - raise NotImplementedError - - @abstractmethod - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - """ - Retrieve a list of projects from the registry - - Args: - allow_cache: Whether to allow returning permissions from a cached registry - - Returns: - List of project - """ - raise NotImplementedError - - @abstractmethod - def proto(self) -> RegistryProto: - """ - Retrieves a proto version of the registry. - - Returns: - The registry proto object. 
- """ - raise NotImplementedError - - @abstractmethod - def commit(self): - """Commits the state of the registry cache to the remote registry store.""" - raise NotImplementedError - - @abstractmethod - def refresh(self, project: Optional[str] = None): - """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" - raise NotImplementedError - - @staticmethod - def _message_to_sorted_dict(message: Message) -> Dict[str, Any]: - return json.loads(MessageToJson(message, sort_keys=True)) - - def to_dict(self, project: str) -> Dict[str, List[Any]]: - """Returns a dictionary representation of the registry contents for the specified project. - - For each list in the dictionary, the elements are sorted by name, so this - method can be used to compare two registries. - - Args: - project: Feast project to convert to a dict - """ - registry_dict: Dict[str, Any] = defaultdict(list) - registry_dict["project"] = project - for project_metadata in sorted(self.list_project_metadata(project=project)): - registry_dict["projectMetadata"].append( - self._message_to_sorted_dict(project_metadata.to_proto()) - ) - for data_source in sorted( - self.list_data_sources(project=project), key=lambda ds: ds.name - ): - registry_dict["dataSources"].append( - self._message_to_sorted_dict(data_source.to_proto()) - ) - for entity in sorted( - self.list_entities(project=project), - key=lambda entity: entity.name, - ): - registry_dict["entities"].append( - self._message_to_sorted_dict(entity.to_proto()) - ) - for feature_view in sorted( - self.list_feature_views(project=project), - key=lambda feature_view: feature_view.name, - ): - registry_dict["featureViews"].append( - self._message_to_sorted_dict(feature_view.to_proto()) - ) - for feature_service in sorted( - self.list_feature_services(project=project), - key=lambda feature_service: feature_service.name, - ): - registry_dict["featureServices"].append( - self._message_to_sorted_dict(feature_service.to_proto()) - ) - for on_demand_feature_view in sorted( - self.list_on_demand_feature_views(project=project), - key=lambda on_demand_feature_view: on_demand_feature_view.name, - ): - odfv_dict = self._message_to_sorted_dict(on_demand_feature_view.to_proto()) - # We are logging a warning because the registry object may be read from a proto that is not updated - # i.e., we have to submit dual writes but in order to ensure the read behavior succeeds we have to load - # both objects to compare any changes in the registry - warnings.warn( - "We will be deprecating the usage of spec.userDefinedFunction in a future release please upgrade cautiously.", - DeprecationWarning, - ) - if on_demand_feature_view.feature_transformation: - if isinstance( - on_demand_feature_view.feature_transformation, PandasTransformation - ): - if "userDefinedFunction" not in odfv_dict["spec"]: - odfv_dict["spec"]["userDefinedFunction"] = {} - odfv_dict["spec"]["userDefinedFunction"]["body"] = ( - on_demand_feature_view.feature_transformation.udf_string - ) - odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][ - "body" - ] = on_demand_feature_view.feature_transformation.udf_string - elif isinstance( - on_demand_feature_view.feature_transformation, - SubstraitTransformation, - ): - odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ - "body" - ] = on_demand_feature_view.feature_transformation.substrait_plan - else: - odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][ - "body" - ] = None - 
odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ - "body" - ] = None - registry_dict["onDemandFeatureViews"].append(odfv_dict) - for stream_feature_view in sorted( - self.list_stream_feature_views(project=project), - key=lambda stream_feature_view: stream_feature_view.name, - ): - sfv_dict = self._message_to_sorted_dict(stream_feature_view.to_proto()) - - sfv_dict["spec"]["userDefinedFunction"]["body"] = ( - stream_feature_view.udf_string - ) - registry_dict["streamFeatureViews"].append(sfv_dict) - - for saved_dataset in sorted( - self.list_saved_datasets(project=project), key=lambda item: item.name - ): - registry_dict["savedDatasets"].append( - self._message_to_sorted_dict(saved_dataset.to_proto()) - ) - for infra_object in sorted(self.get_infra(project=project).infra_objects): - registry_dict["infra"].append( - self._message_to_sorted_dict(infra_object.to_proto()) - ) - for permission in sorted( - self.list_permissions(project=project), key=lambda ds: ds.name - ): - registry_dict["permissions"].append( - self._message_to_sorted_dict(permission.to_proto()) - ) - - return registry_dict - - @staticmethod - def deserialize_registry_values(serialized_proto, feast_obj_type) -> Any: - if feast_obj_type == Entity: - return EntityProto.FromString(serialized_proto) - if feast_obj_type == SavedDataset: - return SavedDatasetProto.FromString(serialized_proto) - if feast_obj_type == FeatureView: - return FeatureViewProto.FromString(serialized_proto) - if feast_obj_type == StreamFeatureView: - return StreamFeatureViewProto.FromString(serialized_proto) - if feast_obj_type == OnDemandFeatureView: - return OnDemandFeatureViewProto.FromString(serialized_proto) - if feast_obj_type == FeatureService: - return FeatureServiceProto.FromString(serialized_proto) - if feast_obj_type == Permission: - return PermissionProto.FromString(serialized_proto) - if feast_obj_type == Project: - return ProjectProto.FromString(serialized_proto) - return None +# Copyright 2019 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
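For orientation, the interface defined below implies a calling convention of allow_cache on reads and commit on writes. A minimal sketch of that convention follows; this is an illustrative aside rather than part of the patch, the repo path and project are assumptions, and FeatureStore is used only as a convenient way to obtain a concrete registry:

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # assumes a configured feature repo
    registry = store.registry

    # Reads may be served from the local registry cache when allow_cache=True.
    entities = registry.list_entities(project=store.project, allow_cache=True)

    # Writes accept commit=False so several changes can be batched before a
    # single commit() persists the cached registry state.
    for entity in entities:
        registry.apply_entity(entity, project=store.project, commit=False)
    registry.commit()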
+import json +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +from google.protobuf.json_format import MessageToJson +from google.protobuf.message import Message + +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto +from feast.protos.feast.core.FeatureService_pb2 import ( + FeatureService as FeatureServiceProto, +) +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( + OnDemandFeatureView as OnDemandFeatureViewProto, +) +from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.transformation.pandas_transformation import PandasTransformation +from feast.transformation.substrait_transformation import SubstraitTransformation + + +class BaseRegistry(ABC): + """ + The interface that Feast uses to apply, list, retrieve, and delete Feast objects (e.g. entities, + feature views, and data sources). + """ + + # Entity operations + @abstractmethod + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + """ + Registers a single entity with Feast + + Args: + entity: Entity that will be registered + project: Feast project that this entity belongs to + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def delete_entity(self, name: str, project: str, commit: bool = True): + """ + Deletes an entity or raises an exception if not found. + + Args: + name: Name of entity + project: Feast project that this entity belongs to + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + """ + Retrieves an entity. 
+
+
+        Args:
+            name: Name of entity
+            project: Feast project that this entity belongs to
+            allow_cache: Whether to allow returning this entity from a cached registry
+
+        Returns:
+            Returns either the specified entity, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_entities(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Entity]:
+        """
+        Retrieve a list of entities from the registry
+
+        Args:
+            project: Filter entities based on project name
+            allow_cache: Whether to allow returning entities from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of entities
+        """
+        raise NotImplementedError
+
+    # Data source operations
+    @abstractmethod
+    def apply_data_source(
+        self, data_source: DataSource, project: str, commit: bool = True
+    ):
+        """
+        Registers a single data source with Feast
+
+        Args:
+            data_source: A data source that will be registered
+            project: Feast project that this data source belongs to
+            commit: Whether to immediately commit to the registry
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_data_source(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a data source or raises an exception if not found.
+
+        Args:
+            name: Name of data source
+            project: Feast project that this data source belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_data_source(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> DataSource:
+        """
+        Retrieves a data source.
+
+        Args:
+            name: Name of data source
+            project: Feast project that this data source belongs to
+            allow_cache: Whether to allow returning this data source from a cached registry
+
+        Returns:
+            Returns either the specified data source, or raises an exception if none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_data_sources(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[DataSource]:
+        """
+        Retrieve a list of data sources from the registry
+
+        Args:
+            project: Filter data sources based on project name
+            allow_cache: Whether to allow returning data sources from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of data sources
+        """
+        raise NotImplementedError
+
+    # Feature service operations
+    @abstractmethod
+    def apply_feature_service(
+        self, feature_service: FeatureService, project: str, commit: bool = True
+    ):
+        """
+        Registers a single feature service with Feast
+
+        Args:
+            feature_service: A feature service that will be registered
+            project: Feast project that this feature service belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_feature_service(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a feature service or raises an exception if not found.
+
+        Args:
+            name: Name of feature service
+            project: Feast project that this feature service belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_feature_service(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> FeatureService:
+        """
+        Retrieves a feature service.
+
+
+        Args:
+            name: Name of feature service
+            project: Feast project that this feature service belongs to
+            allow_cache: Whether to allow returning this feature service from a cached registry
+
+        Returns:
+            Returns either the specified feature service, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_feature_services(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[FeatureService]:
+        """
+        Retrieve a list of feature services from the registry
+
+        Args:
+            project: Filter feature services based on project name
+            allow_cache: Whether to allow returning feature services from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of feature services
+        """
+        raise NotImplementedError
+
+    # Feature view operations
+    @abstractmethod
+    def apply_feature_view(
+        self, feature_view: BaseFeatureView, project: str, commit: bool = True
+    ):
+        """
+        Registers a single feature view with Feast
+
+        Args:
+            feature_view: Feature view that will be registered
+            project: Feast project that this feature view belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_feature_view(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a feature view or raises an exception if not found.
+
+        Args:
+            name: Name of feature view
+            project: Feast project that this feature view belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    # Stream feature view operations
+    @abstractmethod
+    def get_stream_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> StreamFeatureView:
+        """
+        Retrieves a stream feature view.
+
+        Args:
+            name: Name of stream feature view
+            project: Feast project that this stream feature view belongs to
+            allow_cache: Allow returning feature view from the cached registry
+
+        Returns:
+            Returns either the specified feature view, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_stream_feature_views(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[StreamFeatureView]:
+        """
+        Retrieve a list of stream feature views from the registry
+
+        Args:
+            project: Filter stream feature views based on project name
+            allow_cache: Whether to allow returning stream feature views from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of stream feature views
+        """
+        raise NotImplementedError
+
+    # On demand feature view operations
+    @abstractmethod
+    def get_on_demand_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> OnDemandFeatureView:
+        """
+        Retrieves an on demand feature view.
+ + Args: + name: Name of on demand feature view + project: Feast project that this on demand feature view belongs to + allow_cache: Whether to allow returning this on demand feature view from a cached registry + + Returns: + Returns either the specified on demand feature view, or raises an exception if + none is found + """ + raise NotImplementedError + + @abstractmethod + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + """ + Retrieve a list of on demand feature views from the registry + + Args: + project: Filter on demand feature views based on project name + allow_cache: Whether to allow returning on demand feature views from a cached registry + tags: Filter by tags + + Returns: + List of on demand feature views + """ + raise NotImplementedError + + # regular feature view operations + @abstractmethod + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + """ + Retrieves a feature view. + + Args: + name: Name of feature view + project: Feast project that this feature view belongs to + allow_cache: Allow returning feature view from the cached registry + + Returns: + Returns either the specified feature view, or raises an exception if + none is found + """ + raise NotImplementedError + + @abstractmethod + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + """ + Retrieve a list of feature views from the registry + + Args: + allow_cache: Allow returning feature views from the cached registry + project: Filter feature views based on project name + tags: Filter by tags + + Returns: + List of feature views + """ + raise NotImplementedError + + @abstractmethod + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + """ + Retrieves a feature view of any type. 
+ + Args: + name: Name of feature view + project: Feast project that this feature view belongs to + allow_cache: Allow returning feature view from the cached registry + + Returns: + Returns either the specified feature view, or raises an exception if + none is found + """ + raise NotImplementedError + + @abstractmethod + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + """ + Retrieve a list of feature views of all types from the registry + + Args: + allow_cache: Allow returning feature views from the cached registry + project: Filter feature views based on project name + tags: Filter by tags + + Returns: + List of feature views + """ + raise NotImplementedError + + @abstractmethod + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + """ + Updates materialization intervals tracked for a single feature view in Feast + + Args: + feature_view: Feature view that will be updated with an additional materialization interval tracked + project: Feast project that this feature view belongs to + start_date (datetime): Start date of the materialization interval to track + end_date (datetime): End date of the materialization interval to track + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + # Saved dataset operations + @abstractmethod + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + """ + Stores a saved dataset metadata with Feast + + Args: + saved_dataset: SavedDataset that will be added / updated to registry + project: Feast project that this dataset belongs to + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + """ + Retrieves a saved dataset. + + Args: + name: Name of dataset + project: Feast project that this dataset belongs to + allow_cache: Whether to allow returning this dataset from a cached registry + + Returns: + Returns either the specified SavedDataset, or raises an exception if + none is found + """ + raise NotImplementedError + + def delete_saved_dataset(self, name: str, project: str, commit: bool = True): + """ + Delete a saved dataset. 
+
+
+        Args:
+            name: Name of dataset
+            project: Feast project that this dataset belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_saved_datasets(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[SavedDataset]:
+        """
+        Retrieves a list of all saved datasets in the specified project
+
+        Args:
+            project: Feast project
+            allow_cache: Whether to allow returning datasets from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            Returns the list of SavedDatasets
+        """
+        raise NotImplementedError
+
+    # Validation reference operations
+    @abstractmethod
+    def apply_validation_reference(
+        self,
+        validation_reference: ValidationReference,
+        project: str,
+        commit: bool = True,
+    ):
+        """
+        Persist a validation reference
+
+        Args:
+            validation_reference: ValidationReference that will be added / updated to registry
+            project: Feast project that this validation reference belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_validation_reference(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a validation reference or raises an exception if not found.
+
+        Args:
+            name: Name of validation reference
+            project: Feast project that this validation reference belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_validation_reference(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> ValidationReference:
+        """
+        Retrieves a validation reference.
+
+        Args:
+            name: Name of validation reference
+            project: Feast project that this validation reference belongs to
+            allow_cache: Whether to allow returning this validation reference from a cached registry
+
+        Returns:
+            Returns either the specified ValidationReference, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    # TODO: Needs to be implemented.
+    def list_validation_references(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[ValidationReference]:
+        """
+        Retrieve a list of validation references from the registry
+
+        Args:
+            project: Filter validation references based on project name
+            allow_cache: Allow returning validation references from the cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of validation references
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_project_metadata(
+        self, project: str, allow_cache: bool = False
+    ) -> List[ProjectMetadata]:
+        """
+        Retrieves project metadata
+
+        Args:
+            project: Filter metadata based on project name
+            allow_cache: Allow returning project metadata from the cached registry
+
+        Returns:
+            List of project metadata
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def update_infra(self, infra: Infra, project: str, commit: bool = True):
+        """
+        Updates the stored Infra object.
+
+        Args:
+            infra: The new Infra object to be stored.
+            project: Feast project that the Infra object refers to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
+        """
+        Retrieves the stored Infra object.
+
+        Args:
+            project: Feast project that the Infra object refers to
+            allow_cache: Whether to allow returning the Infra object from a cached registry
+
+        Returns:
+            The stored Infra object.
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_user_metadata(
+        self,
+        project: str,
+        feature_view: BaseFeatureView,
+        metadata_bytes: Optional[bytes],
+    ): ...
+
+    @abstractmethod
+    def get_user_metadata(
+        self, project: str, feature_view: BaseFeatureView
+    ) -> Optional[bytes]: ...
+
+    # Permission operations
+    @abstractmethod
+    def apply_permission(
+        self, permission: Permission, project: str, commit: bool = True
+    ):
+        """
+        Registers a single permission with Feast
+
+        Args:
+            permission: A permission that will be registered
+            project: Feast project that this permission belongs to
+            commit: Whether to immediately commit to the registry
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_permission(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a permission or raises an exception if not found.
+
+        Args:
+            name: Name of permission
+            project: Feast project that this permission belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_permission(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> Permission:
+        """
+        Retrieves a permission.
+
+        Args:
+            name: Name of permission
+            project: Feast project that this permission belongs to
+            allow_cache: Whether to allow returning this permission from a cached registry
+
+        Returns:
+            Returns either the specified permission, or raises an exception if none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_permissions(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Permission]:
+        """
+        Retrieve a list of permissions from the registry
+
+        Args:
+            project: Filter permissions based on project name
+            allow_cache: Whether to allow returning permissions from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of permissions
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_project(
+        self,
+        project: Project,
+        commit: bool = True,
+    ):
+        """
+        Registers a project with Feast
+
+        Args:
+            project: A project that will be registered
+            commit: Whether to immediately commit to the registry
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_project(
+        self,
+        name: str,
+        commit: bool = True,
+    ):
+        """
+        Deletes a project or raises a ProjectNotFoundException if not found.
+
+        Args:
+            name: Feast project name that needs to be deleted
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_project(
+        self,
+        name: str,
+        allow_cache: bool = False,
+    ) -> Project:
+        """
+        Retrieves a project.
+
+        Args:
+            name: Feast project name
+            allow_cache: Whether to allow returning this project from a cached registry
+
+        Returns:
+            Returns either the specified project, or raises a ProjectObjectNotFoundException if none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_projects(
+        self,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Project]:
+        """
+        Retrieve a list of projects from the registry
+
+        Args:
+            allow_cache: Whether to allow returning projects from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of projects
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def proto(self) -> RegistryProto:
+        """
+        Retrieves a proto version of the registry.
+
+        Returns:
+            The registry proto object.
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def commit(self):
+        """Commits the state of the registry cache to the remote registry store."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def refresh(self, project: Optional[str] = None):
+        """Refreshes the state of the registry cache by fetching the registry state from the remote registry store."""
+        raise NotImplementedError
+
+    @staticmethod
+    def _message_to_sorted_dict(message: Message) -> Dict[str, Any]:
+        return json.loads(MessageToJson(message, sort_keys=True))
+
+    def to_dict(self, project: str) -> Dict[str, List[Any]]:
+        """Returns a dictionary representation of the registry contents for the specified project.
+
+        For each list in the dictionary, the elements are sorted by name, so this
+        method can be used to compare two registries.
+
+        Args:
+            project: Feast project to convert to a dict
+        """
+        registry_dict: Dict[str, Any] = defaultdict(list)
+        registry_dict["project"] = project
+        for project_metadata in sorted(self.list_project_metadata(project=project)):
+            registry_dict["projectMetadata"].append(
+                self._message_to_sorted_dict(project_metadata.to_proto())
+            )
+        for data_source in sorted(
+            self.list_data_sources(project=project), key=lambda ds: ds.name
+        ):
+            registry_dict["dataSources"].append(
+                self._message_to_sorted_dict(data_source.to_proto())
+            )
+        for entity in sorted(
+            self.list_entities(project=project),
+            key=lambda entity: entity.name,
+        ):
+            registry_dict["entities"].append(
+                self._message_to_sorted_dict(entity.to_proto())
+            )
+        for feature_view in sorted(
+            self.list_feature_views(project=project),
+            key=lambda feature_view: feature_view.name,
+        ):
+            registry_dict["featureViews"].append(
+                self._message_to_sorted_dict(feature_view.to_proto())
+            )
+        for feature_service in sorted(
+            self.list_feature_services(project=project),
+            key=lambda feature_service: feature_service.name,
+        ):
+            registry_dict["featureServices"].append(
+                self._message_to_sorted_dict(feature_service.to_proto())
+            )
+        for on_demand_feature_view in sorted(
+            self.list_on_demand_feature_views(project=project),
+            key=lambda on_demand_feature_view: on_demand_feature_view.name,
+        ):
+            odfv_dict = self._message_to_sorted_dict(on_demand_feature_view.to_proto())
+            # Warn on read because the registry proto may predate this change;
+            # the UDF body is dual-written to both fields so that either read
+            # path succeeds when comparing registry contents.
+            warnings.warn(
+                "The usage of spec.userDefinedFunction will be deprecated in a future release. Please upgrade carefully.",
+                DeprecationWarning,
+            )
+            if on_demand_feature_view.feature_transformation:
+                if isinstance(
+                    on_demand_feature_view.feature_transformation, PandasTransformation
+                ):
+                    if "userDefinedFunction" not in odfv_dict["spec"]:
+                        odfv_dict["spec"]["userDefinedFunction"] = {}
+                    odfv_dict["spec"]["userDefinedFunction"]["body"] = (
+                        on_demand_feature_view.feature_transformation.udf_string
+                    )
+                    odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][
+                        "body"
+                    ] = on_demand_feature_view.feature_transformation.udf_string
+                elif isinstance(
+                    on_demand_feature_view.feature_transformation,
+                    SubstraitTransformation,
+                ):
+                    odfv_dict["spec"]["featureTransformation"]["substraitPlan"][
+                        "body"
+                    ] = on_demand_feature_view.feature_transformation.substrait_plan
+                else:
+                    odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][
+                        "body"
+                    ] = None
+
odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ + "body" + ] = None + registry_dict["onDemandFeatureViews"].append(odfv_dict) + for stream_feature_view in sorted( + self.list_stream_feature_views(project=project), + key=lambda stream_feature_view: stream_feature_view.name, + ): + sfv_dict = self._message_to_sorted_dict(stream_feature_view.to_proto()) + + sfv_dict["spec"]["userDefinedFunction"]["body"] = ( + stream_feature_view.udf_string + ) + registry_dict["streamFeatureViews"].append(sfv_dict) + + for saved_dataset in sorted( + self.list_saved_datasets(project=project), key=lambda item: item.name + ): + registry_dict["savedDatasets"].append( + self._message_to_sorted_dict(saved_dataset.to_proto()) + ) + for infra_object in sorted(self.get_infra(project=project).infra_objects): + registry_dict["infra"].append( + self._message_to_sorted_dict(infra_object.to_proto()) + ) + for permission in sorted( + self.list_permissions(project=project), key=lambda ds: ds.name + ): + registry_dict["permissions"].append( + self._message_to_sorted_dict(permission.to_proto()) + ) + + return registry_dict + + @staticmethod + def deserialize_registry_values(serialized_proto, feast_obj_type) -> Any: + if feast_obj_type == Entity: + return EntityProto.FromString(serialized_proto) + if feast_obj_type == SavedDataset: + return SavedDatasetProto.FromString(serialized_proto) + if feast_obj_type == FeatureView: + return FeatureViewProto.FromString(serialized_proto) + if feast_obj_type == StreamFeatureView: + return StreamFeatureViewProto.FromString(serialized_proto) + if feast_obj_type == OnDemandFeatureView: + return OnDemandFeatureViewProto.FromString(serialized_proto) + if feast_obj_type == FeatureService: + return FeatureServiceProto.FromString(serialized_proto) + if feast_obj_type == Permission: + return PermissionProto.FromString(serialized_proto) + if feast_obj_type == Project: + return ProjectProto.FromString(serialized_proto) + return None diff --git a/sdk/python/feast/infra/registry/registry.py b/sdk/python/feast/infra/registry/registry.py index 62a21d5c433..acb82546b4f 100644 --- a/sdk/python/feast/infra/registry/registry.py +++ b/sdk/python/feast/infra/registry/registry.py @@ -1,1090 +1,1090 @@ -# Copyright 2019 The Feast Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
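The new base_registry.py above widens apply_materialization to accept Union[FeatureView, OnDemandFeatureView], while the interval bookkeeping in the old registry.py removed below handles only regular and stream feature views. A hedged sketch of the type guard an implementation might apply; the helper name is illustrative, not the shipped code:

    from typing import Union

    from feast.feature_view import FeatureView
    from feast.on_demand_feature_view import OnDemandFeatureView

    def should_track_interval(
        feature_view: Union[FeatureView, OnDemandFeatureView],
    ) -> bool:
        # Assumption: ODFVs carry no materialization intervals in the
        # registry, so interval bookkeeping is skipped for them and
        # applied only to regular (and stream) feature views.
        return not isinstance(feature_view, OnDemandFeatureView)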
-import logging -from datetime import datetime, timedelta, timezone -from enum import Enum -from pathlib import Path -from threading import Lock -from typing import Any, Dict, List, Optional -from urllib.parse import urlparse - -from google.protobuf.internal.containers import RepeatedCompositeFieldContainer -from google.protobuf.message import Message - -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.errors import ( - ConflictingFeatureViewNames, - DataSourceNotFoundException, - EntityNotFoundException, - FeatureServiceNotFoundException, - FeatureViewNotFoundException, - PermissionNotFoundException, - ProjectNotFoundException, - ProjectObjectNotFoundException, - ValidationReferenceNotFound, -) -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.importer import import_class -from feast.infra.infra_object import Infra -from feast.infra.registry import proto_registry_utils -from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.registry.registry_store import NoopRegistryStore -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.auth_model import AuthConfig, NoAuthConfig -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.repo_config import RegistryConfig -from feast.repo_contents import RepoContents -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.utils import _utc_now - -REGISTRY_SCHEMA_VERSION = "1" - -REGISTRY_STORE_CLASS_FOR_TYPE = { - "GCSRegistryStore": "feast.infra.registry.gcs.GCSRegistryStore", - "S3RegistryStore": "feast.infra.registry.s3.S3RegistryStore", - "FileRegistryStore": "feast.infra.registry.file.FileRegistryStore", - "AzureRegistryStore": "feast.infra.registry.contrib.azure.azure_registry_store.AzBlobRegistryStore", -} - -REGISTRY_STORE_CLASS_FOR_SCHEME = { - "gs": "GCSRegistryStore", - "s3": "S3RegistryStore", - "file": "FileRegistryStore", - "": "FileRegistryStore", -} - - -class FeastObjectType(Enum): - PROJECT = "project" - DATA_SOURCE = "data source" - ENTITY = "entity" - FEATURE_VIEW = "feature view" - ON_DEMAND_FEATURE_VIEW = "on demand feature view" - STREAM_FEATURE_VIEW = "stream feature view" - FEATURE_SERVICE = "feature service" - PERMISSION = "permission" - - @staticmethod - def get_objects_from_registry( - registry: "BaseRegistry", project: str - ) -> Dict["FeastObjectType", List[Any]]: - return { - FeastObjectType.PROJECT: [ - project_obj - for project_obj in registry.list_projects() - if project_obj.name == project - ], - FeastObjectType.DATA_SOURCE: registry.list_data_sources(project=project), - FeastObjectType.ENTITY: registry.list_entities(project=project), - FeastObjectType.FEATURE_VIEW: registry.list_feature_views(project=project), - FeastObjectType.ON_DEMAND_FEATURE_VIEW: registry.list_on_demand_feature_views( - project=project - ), - FeastObjectType.STREAM_FEATURE_VIEW: registry.list_stream_feature_views( - project=project, - ), - FeastObjectType.FEATURE_SERVICE: registry.list_feature_services( - project=project - ), - FeastObjectType.PERMISSION: registry.list_permissions(project=project), - } - - @staticmethod - def get_objects_from_repo_contents( - repo_contents: RepoContents, - ) -> 
Dict["FeastObjectType", List[Any]]: - return { - FeastObjectType.PROJECT: repo_contents.projects, - FeastObjectType.DATA_SOURCE: repo_contents.data_sources, - FeastObjectType.ENTITY: repo_contents.entities, - FeastObjectType.FEATURE_VIEW: repo_contents.feature_views, - FeastObjectType.ON_DEMAND_FEATURE_VIEW: repo_contents.on_demand_feature_views, - FeastObjectType.STREAM_FEATURE_VIEW: repo_contents.stream_feature_views, - FeastObjectType.FEATURE_SERVICE: repo_contents.feature_services, - FeastObjectType.PERMISSION: repo_contents.permissions, - } - - -FEAST_OBJECT_TYPES = [feast_object_type for feast_object_type in FeastObjectType] - -logger = logging.getLogger(__name__) - - -def get_registry_store_class_from_type(registry_store_type: str): - if not registry_store_type.endswith("RegistryStore"): - raise Exception('Registry store class name should end with "RegistryStore"') - if registry_store_type in REGISTRY_STORE_CLASS_FOR_TYPE: - registry_store_type = REGISTRY_STORE_CLASS_FOR_TYPE[registry_store_type] - module_name, registry_store_class_name = registry_store_type.rsplit(".", 1) - - return import_class(module_name, registry_store_class_name, "RegistryStore") - - -def get_registry_store_class_from_scheme(registry_path: str): - uri = urlparse(registry_path) - if uri.scheme not in REGISTRY_STORE_CLASS_FOR_SCHEME: - raise Exception( - f"Registry path {registry_path} has unsupported scheme {uri.scheme}. " - f"Supported schemes are file, s3 and gs." - ) - else: - registry_store_type = REGISTRY_STORE_CLASS_FOR_SCHEME[uri.scheme] - return get_registry_store_class_from_type(registry_store_type) - - -class Registry(BaseRegistry): - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - pass - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - pass - - # The cached_registry_proto object is used for both reads and writes. In particular, - # all write operations refresh the cache and modify it in memory; the write must - # then be persisted to the underlying RegistryStore with a call to commit(). - cached_registry_proto: RegistryProto - cached_registry_proto_created: datetime - cached_registry_proto_ttl: timedelta - - def __init__( - self, - project: str, - registry_config: Optional[RegistryConfig], - repo_path: Optional[Path], - auth_config: AuthConfig = NoAuthConfig(), - ): - """ - Create the Registry object. - - Args: - registry_config: RegistryConfig object containing the destination path and cache ttl, - repo_path: Path to the base of the Feast repository - or where it will be created if it does not exist yet. 
- """ - - self._refresh_lock = Lock() - self._auth_config = auth_config - - registry_proto = RegistryProto() - registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - - self.purge_feast_metadata = ( - registry_config.purge_feast_metadata - if registry_config is not None - else False - ) - - if registry_config: - registry_store_type = registry_config.registry_store_type - registry_path = registry_config.path - if registry_store_type is None: - cls = get_registry_store_class_from_scheme(registry_path) - else: - cls = get_registry_store_class_from_type(str(registry_store_type)) - - self._registry_store = cls(registry_config, repo_path) - self.cached_registry_proto_ttl = timedelta( - seconds=( - registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 - ) - ) - - try: - registry_proto = self._registry_store.get_registry_proto() - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - # Sync feast_metadata to projects table - # when purge_feast_metadata is set to True, Delete data from - # feast_metadata table and list_project_metadata will not return any data - self._sync_feast_metadata_to_projects_table() - except FileNotFoundError: - logger.info("Registry file not found. Creating new registry.") - self.commit() - - def _sync_feast_metadata_to_projects_table(self): - """ - Sync feast_metadata to projects table - """ - feast_metadata_projects = [] - projects_set = [] - # List of project in project_metadata - for project_metadata in self.cached_registry_proto.project_metadata: - project = ProjectMetadata.from_proto(project_metadata) - feast_metadata_projects.append(project.project_name) - if len(feast_metadata_projects) > 0: - # List of project in projects - for project_metadata in self.cached_registry_proto.projects: - project = Project.from_proto(project_metadata) - projects_set.append(project.name) - - # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects) - set(projects_set) - # Sync feast_metadata to projects table - for project_name in projects_to_sync: - project = Project(name=project_name) - self.cached_registry_proto.projects.append(project.to_proto()) - - if self.purge_feast_metadata: - self.cached_registry_proto.project_metadata = [] - - def clone(self) -> "Registry": - new_registry = Registry("project", None, None, self._auth_config) - new_registry.cached_registry_proto_ttl = timedelta(seconds=0) - new_registry.cached_registry_proto = ( - self.cached_registry_proto.__deepcopy__() - if self.cached_registry_proto - else RegistryProto() - ) - new_registry.cached_registry_proto_created = _utc_now() - new_registry._registry_store = NoopRegistryStore() - return new_registry - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - self.cached_registry_proto.infra.CopyFrom(infra.to_proto()) - if commit: - self.commit() - - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return Infra.from_proto(registry_proto.infra) - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - entity.is_valid() - - now = _utc_now() - if not entity.created_timestamp: - entity.created_timestamp = now - entity.last_updated_timestamp = 
now - - entity_proto = entity.to_proto() - entity_proto.spec.project = project - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_entity_proto in enumerate( - self.cached_registry_proto.entities - ): - if ( - existing_entity_proto.spec.name == entity_proto.spec.name - and existing_entity_proto.spec.project == project - ): - entity.created_timestamp = ( - existing_entity_proto.meta.created_timestamp.ToDatetime() - ) - entity_proto = entity.to_proto() - entity_proto.spec.project = project - del self.cached_registry_proto.entities[idx] - break - self.cached_registry_proto.entities.append(entity_proto) - if commit: - self.commit() - - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_entities(registry_proto, project, tags) - - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_data_sources(registry_proto, project, tags) - - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - registry = self._prepare_registry_for_changes(project) - for idx, existing_data_source_proto in enumerate(registry.data_sources): - if existing_data_source_proto.name == data_source.name: - del registry.data_sources[idx] - data_source_proto = data_source.to_proto() - data_source_proto.project = project - data_source_proto.data_source_class_type = ( - f"{data_source.__class__.__module__}.{data_source.__class__.__name__}" - ) - self.cached_registry_proto.data_sources.append(data_source_proto) - if commit: - self.commit() - - def delete_data_source(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, data_source_proto in enumerate( - self.cached_registry_proto.data_sources - ): - if data_source_proto.name == name: - del self.cached_registry_proto.data_sources[idx] - if commit: - self.commit() - return - raise DataSourceNotFoundException(name) - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - now = _utc_now() - if not feature_service.created_timestamp: - feature_service.created_timestamp = now - feature_service.last_updated_timestamp = now - - feature_service_proto = feature_service.to_proto() - feature_service_proto.spec.project = project - - registry = self._prepare_registry_for_changes(project) - - for idx, existing_feature_service_proto in enumerate(registry.feature_services): - if ( - existing_feature_service_proto.spec.name - == feature_service_proto.spec.name - and existing_feature_service_proto.spec.project == project - ): - feature_service.created_timestamp = ( - existing_feature_service_proto.meta.created_timestamp.ToDatetime() - ) - feature_service_proto = feature_service.to_proto() - feature_service_proto.spec.project = project - del registry.feature_services[idx] - self.cached_registry_proto.feature_services.append(feature_service_proto) - if commit: - self.commit() - - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - registry_proto = 
self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_feature_services(registry_proto, project, tags) - - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_feature_service(registry_proto, name, project) - - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_entity(registry_proto, name, project) - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - feature_view.ensure_valid() - - now = _utc_now() - if not feature_view.created_timestamp: - feature_view.created_timestamp = now - feature_view.last_updated_timestamp = now - - feature_view_proto = feature_view.to_proto() - feature_view_proto.spec.project = project - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - self._check_conflicting_feature_view_names(feature_view) - existing_feature_views_of_same_type: RepeatedCompositeFieldContainer - if isinstance(feature_view, StreamFeatureView): - existing_feature_views_of_same_type = ( - self.cached_registry_proto.stream_feature_views - ) - elif isinstance(feature_view, FeatureView): - existing_feature_views_of_same_type = ( - self.cached_registry_proto.feature_views - ) - elif isinstance(feature_view, OnDemandFeatureView): - existing_feature_views_of_same_type = ( - self.cached_registry_proto.on_demand_feature_views - ) - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - - for idx, existing_feature_view_proto in enumerate( - existing_feature_views_of_same_type - ): - if ( - existing_feature_view_proto.spec.name == feature_view_proto.spec.name - and existing_feature_view_proto.spec.project == project - ): - if ( - feature_view.__class__.from_proto(existing_feature_view_proto) - == feature_view - ): - return - else: - existing_feature_view = type(feature_view).from_proto( - existing_feature_view_proto - ) - feature_view.created_timestamp = ( - existing_feature_view.created_timestamp - ) - if isinstance(feature_view, (FeatureView, StreamFeatureView)): - feature_view.update_materialization_intervals( - existing_feature_view.materialization_intervals - ) - feature_view_proto = feature_view.to_proto() - feature_view_proto.spec.project = project - del existing_feature_views_of_same_type[idx] - break - - existing_feature_views_of_same_type.append(feature_view_proto) - if commit: - self.commit() - - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_stream_feature_views( - registry_proto, project, tags - ) - - def list_on_demand_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_on_demand_feature_views( - registry_proto, project, tags - ) - - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: 
- registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_on_demand_feature_view( - registry_proto, name, project - ) - - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_data_source(registry_proto, name, project) - - def apply_materialization( - self, - feature_view: FeatureView, - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_feature_view_proto in enumerate( - self.cached_registry_proto.feature_views - ): - if ( - existing_feature_view_proto.spec.name == feature_view.name - and existing_feature_view_proto.spec.project == project - ): - existing_feature_view = FeatureView.from_proto( - existing_feature_view_proto - ) - existing_feature_view.materialization_intervals.append( - (start_date, end_date) - ) - existing_feature_view.last_updated_timestamp = _utc_now() - feature_view_proto = existing_feature_view.to_proto() - feature_view_proto.spec.project = project - del self.cached_registry_proto.feature_views[idx] - self.cached_registry_proto.feature_views.append(feature_view_proto) - if commit: - self.commit() - return - - for idx, existing_stream_feature_view_proto in enumerate( - self.cached_registry_proto.stream_feature_views - ): - if ( - existing_stream_feature_view_proto.spec.name == feature_view.name - and existing_stream_feature_view_proto.spec.project == project - ): - existing_stream_feature_view = StreamFeatureView.from_proto( - existing_stream_feature_view_proto - ) - existing_stream_feature_view.materialization_intervals.append( - (start_date, end_date) - ) - existing_stream_feature_view.last_updated_timestamp = _utc_now() - stream_feature_view_proto = existing_stream_feature_view.to_proto() - stream_feature_view_proto.spec.project = project - del self.cached_registry_proto.stream_feature_views[idx] - self.cached_registry_proto.stream_feature_views.append( - stream_feature_view_proto - ) - if commit: - self.commit() - return - - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_all_feature_views( - registry_proto, project, tags - ) - - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_any_feature_view(registry_proto, name, project) - - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_feature_views(registry_proto, project, tags) - - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_feature_view(registry_proto, name, project) - - def get_stream_feature_view( - self, name: str, project: str, 
allow_cache: bool = False - ) -> StreamFeatureView: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_stream_feature_view( - registry_proto, name, project - ) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, feature_service_proto in enumerate( - self.cached_registry_proto.feature_services - ): - if ( - feature_service_proto.spec.name == name - and feature_service_proto.spec.project == project - ): - del self.cached_registry_proto.feature_services[idx] - if commit: - self.commit() - return - raise FeatureServiceNotFoundException(name, project) - - def delete_feature_view(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_feature_view_proto in enumerate( - self.cached_registry_proto.feature_views - ): - if ( - existing_feature_view_proto.spec.name == name - and existing_feature_view_proto.spec.project == project - ): - del self.cached_registry_proto.feature_views[idx] - if commit: - self.commit() - return - - for idx, existing_on_demand_feature_view_proto in enumerate( - self.cached_registry_proto.on_demand_feature_views - ): - if ( - existing_on_demand_feature_view_proto.spec.name == name - and existing_on_demand_feature_view_proto.spec.project == project - ): - del self.cached_registry_proto.on_demand_feature_views[idx] - if commit: - self.commit() - return - - for idx, existing_stream_feature_view_proto in enumerate( - self.cached_registry_proto.stream_feature_views - ): - if ( - existing_stream_feature_view_proto.spec.name == name - and existing_stream_feature_view_proto.spec.project == project - ): - del self.cached_registry_proto.stream_feature_views[idx] - if commit: - self.commit() - return - - raise FeatureViewNotFoundException(name, project) - - def delete_entity(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_entity_proto in enumerate( - self.cached_registry_proto.entities - ): - if ( - existing_entity_proto.spec.name == name - and existing_entity_proto.spec.project == project - ): - del self.cached_registry_proto.entities[idx] - if commit: - self.commit() - return - - raise EntityNotFoundException(name, project) - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - now = _utc_now() - if not saved_dataset.created_timestamp: - saved_dataset.created_timestamp = now - saved_dataset.last_updated_timestamp = now - - saved_dataset_proto = saved_dataset.to_proto() - saved_dataset_proto.spec.project = project - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_saved_dataset_proto in enumerate( - self.cached_registry_proto.saved_datasets - ): - if ( - existing_saved_dataset_proto.spec.name == saved_dataset_proto.spec.name - and existing_saved_dataset_proto.spec.project == project - ): - saved_dataset.created_timestamp = ( - existing_saved_dataset_proto.meta.created_timestamp.ToDatetime() - ) - saved_dataset.min_event_timestamp = ( - existing_saved_dataset_proto.meta.min_event_timestamp.ToDatetime() - ) - saved_dataset.max_event_timestamp = ( - existing_saved_dataset_proto.meta.max_event_timestamp.ToDatetime() - ) - saved_dataset_proto = saved_dataset.to_proto() - 
saved_dataset_proto.spec.project = project - del self.cached_registry_proto.saved_datasets[idx] - break - - self.cached_registry_proto.saved_datasets.append(saved_dataset_proto) - if commit: - self.commit() - - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_saved_dataset(registry_proto, name, project) - - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_saved_datasets(registry_proto, project, tags) - - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - validation_reference_proto = validation_reference.to_proto() - validation_reference_proto.project = project - - registry_proto = self._prepare_registry_for_changes(project) - for idx, existing_validation_reference in enumerate( - registry_proto.validation_references - ): - if ( - existing_validation_reference.name == validation_reference_proto.name - and existing_validation_reference.project == project - ): - del registry_proto.validation_references[idx] - break - - registry_proto.validation_references.append(validation_reference_proto) - if commit: - self.commit() - - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_validation_reference( - registry_proto, name, project - ) - - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_validation_references( - registry_proto, project, tags - ) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - for idx, existing_validation_reference in enumerate( - self.cached_registry_proto.validation_references - ): - if ( - existing_validation_reference.name == name - and existing_validation_reference.project == project - ): - del self.cached_registry_proto.validation_references[idx] - if commit: - self.commit() - return - raise ValidationReferenceNotFound(name, project=project) - - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_project_metadata(registry_proto, project) - - def commit(self): - """Commits the state of the registry cache to the remote registry store.""" - if self.cached_registry_proto: - self._registry_store.update_registry_proto(self.cached_registry_proto) - - def refresh(self, project: Optional[str] = None): - """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" - self._get_registry_proto(project=project, allow_cache=False) - - def teardown(self): - """Tears down (removes) the registry.""" - self._registry_store.teardown() - - def 
proto(self) -> RegistryProto: - return self.cached_registry_proto or RegistryProto() - - def _prepare_registry_for_changes(self, project_name: str): - """Prepares the Registry for changes by refreshing the cache if necessary.""" - - assert self.cached_registry_proto is not None - - try: - # Check if the project exists in the registry cache - self.get_project(name=project_name, allow_cache=True) - return self.cached_registry_proto - except ProjectObjectNotFoundException: - # If the project does not exist in cache, refresh cache from store - registry_proto = self._registry_store.get_registry_proto() - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - - try: - # Check if the project exists in the registry cache after refresh from store - self.get_project(name=project_name) - except ProjectObjectNotFoundException: - # If the project still does not exist, create it - project_proto = Project(name=project_name).to_proto() - self.cached_registry_proto.projects.append(project_proto) - if not self.purge_feast_metadata: - project_metadata_proto = ProjectMetadata( - project_name=project_name - ).to_proto() - self.cached_registry_proto.project_metadata.append( - project_metadata_proto - ) - self.commit() - return self.cached_registry_proto - - def _get_registry_proto( - self, project: Optional[str], allow_cache: bool = False - ) -> RegistryProto: - """Returns the cached or remote registry state - - Args: - project: Name of the Feast project (optional) - allow_cache: Whether to allow the use of the registry cache when fetching the RegistryProto - - Returns: Returns a RegistryProto object which represents the state of the registry - """ - with self._refresh_lock: - expired = (self.cached_registry_proto_created is None) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl - ) - ) - ) - - if allow_cache and not expired: - return self.cached_registry_proto - logger.info("Registry cache expired, so refreshing") - registry_proto = self._registry_store.get_registry_proto() - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - return registry_proto - - def _check_conflicting_feature_view_names(self, feature_view: BaseFeatureView): - name_to_fv_protos = self._existing_feature_view_names_to_fvs() - if feature_view.name in name_to_fv_protos: - if not isinstance( - name_to_fv_protos.get(feature_view.name), feature_view.proto_class - ): - raise ConflictingFeatureViewNames(feature_view.name) - - def _existing_feature_view_names_to_fvs(self) -> Dict[str, Message]: - assert self.cached_registry_proto - odfvs = { - fv.spec.name: fv - for fv in self.cached_registry_proto.on_demand_feature_views - } - fvs = {fv.spec.name: fv for fv in self.cached_registry_proto.feature_views} - sfv = { - fv.spec.name: fv for fv in self.cached_registry_proto.stream_feature_views - } - return {**odfvs, **fvs, **sfv} - - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_permission(registry_proto, name, project) - - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return 
proto_registry_utils.list_permissions(registry_proto, project, tags) - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - now = _utc_now() - if not permission.created_timestamp: - permission.created_timestamp = now - permission.last_updated_timestamp = now - - registry = self._prepare_registry_for_changes(project) - for idx, existing_permission_proto in enumerate(registry.permissions): - if ( - existing_permission_proto.spec.name == permission.name - and existing_permission_proto.spec.project == project - ): - permission.created_timestamp = ( - existing_permission_proto.meta.created_timestamp.ToDatetime() - ) - del registry.permissions[idx] - - permission_proto = permission.to_proto() - permission_proto.spec.project = project - self.cached_registry_proto.permissions.append(permission_proto) - if commit: - self.commit() - - def delete_permission(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, permission_proto in enumerate(self.cached_registry_proto.permissions): - if ( - permission_proto.spec.name == name - and permission_proto.spec.project == project - ): - del self.cached_registry_proto.permissions[idx] - if commit: - self.commit() - return - raise PermissionNotFoundException(name, project) - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - registry = self.cached_registry_proto - - for idx, existing_project_proto in enumerate(registry.projects): - if existing_project_proto.spec.name == project.name: - project.created_timestamp = ( - existing_project_proto.meta.created_timestamp.ToDatetime().replace( - tzinfo=timezone.utc - ) - ) - del registry.projects[idx] - - project_proto = project.to_proto() - self.cached_registry_proto.projects.append(project_proto) - if commit: - self.commit() - - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - registry_proto = self._get_registry_proto(project=name, allow_cache=allow_cache) - return proto_registry_utils.get_project(registry_proto, name) - - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - registry_proto = self._get_registry_proto(project=None, allow_cache=allow_cache) - return proto_registry_utils.list_projects( - registry_proto=registry_proto, tags=tags - ) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - assert self.cached_registry_proto - - for idx, project_proto in enumerate(self.cached_registry_proto.projects): - if project_proto.spec.name == name: - list_validation_references = self.list_validation_references(name) - for validation_reference in list_validation_references: - self.delete_validation_reference(validation_reference.name, name) - - list_saved_datasets = self.list_saved_datasets(name) - for saved_dataset in list_saved_datasets: - self.delete_saved_dataset(saved_dataset.name, name) - - list_feature_services = self.list_feature_services(name) - for feature_service in list_feature_services: - self.delete_feature_service(feature_service.name, name) - - list_on_demand_feature_views = self.list_on_demand_feature_views(name) - for on_demand_feature_view in list_on_demand_feature_views: - self.delete_feature_view(on_demand_feature_view.name, name) - - list_stream_feature_views = self.list_stream_feature_views(name) - for stream_feature_view in list_stream_feature_views: - self.delete_feature_view(stream_feature_view.name, name) - - 
list_feature_views = self.list_feature_views(name) - for feature_view in list_feature_views: - self.delete_feature_view(feature_view.name, name) - - list_data_sources = self.list_data_sources(name) - for data_source in list_data_sources: - self.delete_data_source(data_source.name, name) - - list_entities = self.list_entities(name) - for entity in list_entities: - self.delete_entity(entity.name, name) - list_permissions = self.list_permissions(name) - for permission in list_permissions: - self.delete_permission(permission.name, name) - del self.cached_registry_proto.projects[idx] - if commit: - self.commit() - return - raise ProjectNotFoundException(name) +# Copyright 2019 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from datetime import datetime, timedelta, timezone +from enum import Enum +from pathlib import Path +from threading import Lock +from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer +from google.protobuf.message import Message + +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.errors import ( + ConflictingFeatureViewNames, + DataSourceNotFoundException, + EntityNotFoundException, + FeatureServiceNotFoundException, + FeatureViewNotFoundException, + PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, + ValidationReferenceNotFound, +) +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.importer import import_class +from feast.infra.infra_object import Infra +from feast.infra.registry import proto_registry_utils +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.registry.registry_store import NoopRegistryStore +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.auth_model import AuthConfig, NoAuthConfig +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.repo_config import RegistryConfig +from feast.repo_contents import RepoContents +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.utils import _utc_now + +REGISTRY_SCHEMA_VERSION = "1" + +REGISTRY_STORE_CLASS_FOR_TYPE = { + "GCSRegistryStore": "feast.infra.registry.gcs.GCSRegistryStore", + "S3RegistryStore": "feast.infra.registry.s3.S3RegistryStore", + "FileRegistryStore": "feast.infra.registry.file.FileRegistryStore", + "AzureRegistryStore": "feast.infra.registry.contrib.azure.azure_registry_store.AzBlobRegistryStore", +} + +REGISTRY_STORE_CLASS_FOR_SCHEME = { + "gs": "GCSRegistryStore", + "s3": "S3RegistryStore", + "file": "FileRegistryStore", + "": "FileRegistryStore", +} + + 
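+# Illustrative sketch (not part of this module's API): the two tables above resolve a
+# registry path to a concrete RegistryStore class. A path such as
+# "s3://bucket/registry.db" (a placeholder URI) parses to the scheme "s3", which maps
+# to "S3RegistryStore" and then to "feast.infra.registry.s3.S3RegistryStore":
+#
+#     store_cls = get_registry_store_class_from_scheme("s3://bucket/registry.db")
+#     store = store_cls(registry_config, repo_path)  # args supplied by the caller
+#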
+class FeastObjectType(Enum): + PROJECT = "project" + DATA_SOURCE = "data source" + ENTITY = "entity" + FEATURE_VIEW = "feature view" + ON_DEMAND_FEATURE_VIEW = "on demand feature view" + STREAM_FEATURE_VIEW = "stream feature view" + FEATURE_SERVICE = "feature service" + PERMISSION = "permission" + + @staticmethod + def get_objects_from_registry( + registry: "BaseRegistry", project: str + ) -> Dict["FeastObjectType", List[Any]]: + return { + FeastObjectType.PROJECT: [ + project_obj + for project_obj in registry.list_projects() + if project_obj.name == project + ], + FeastObjectType.DATA_SOURCE: registry.list_data_sources(project=project), + FeastObjectType.ENTITY: registry.list_entities(project=project), + FeastObjectType.FEATURE_VIEW: registry.list_feature_views(project=project), + FeastObjectType.ON_DEMAND_FEATURE_VIEW: registry.list_on_demand_feature_views( + project=project + ), + FeastObjectType.STREAM_FEATURE_VIEW: registry.list_stream_feature_views( + project=project, + ), + FeastObjectType.FEATURE_SERVICE: registry.list_feature_services( + project=project + ), + FeastObjectType.PERMISSION: registry.list_permissions(project=project), + } + + @staticmethod + def get_objects_from_repo_contents( + repo_contents: RepoContents, + ) -> Dict["FeastObjectType", List[Any]]: + return { + FeastObjectType.PROJECT: repo_contents.projects, + FeastObjectType.DATA_SOURCE: repo_contents.data_sources, + FeastObjectType.ENTITY: repo_contents.entities, + FeastObjectType.FEATURE_VIEW: repo_contents.feature_views, + FeastObjectType.ON_DEMAND_FEATURE_VIEW: repo_contents.on_demand_feature_views, + FeastObjectType.STREAM_FEATURE_VIEW: repo_contents.stream_feature_views, + FeastObjectType.FEATURE_SERVICE: repo_contents.feature_services, + FeastObjectType.PERMISSION: repo_contents.permissions, + } + + +FEAST_OBJECT_TYPES = [feast_object_type for feast_object_type in FeastObjectType] + +logger = logging.getLogger(__name__) + + +def get_registry_store_class_from_type(registry_store_type: str): + if not registry_store_type.endswith("RegistryStore"): + raise Exception('Registry store class name should end with "RegistryStore"') + if registry_store_type in REGISTRY_STORE_CLASS_FOR_TYPE: + registry_store_type = REGISTRY_STORE_CLASS_FOR_TYPE[registry_store_type] + module_name, registry_store_class_name = registry_store_type.rsplit(".", 1) + + return import_class(module_name, registry_store_class_name, "RegistryStore") + + +def get_registry_store_class_from_scheme(registry_path: str): + uri = urlparse(registry_path) + if uri.scheme not in REGISTRY_STORE_CLASS_FOR_SCHEME: + raise Exception( + f"Registry path {registry_path} has unsupported scheme {uri.scheme}. " + f"Supported schemes are file, s3 and gs." + ) + else: + registry_store_type = REGISTRY_STORE_CLASS_FOR_SCHEME[uri.scheme] + return get_registry_store_class_from_type(registry_store_type) + + +class Registry(BaseRegistry): + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + pass + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + pass + + # The cached_registry_proto object is used for both reads and writes. In particular, + # all write operations refresh the cache and modify it in memory; the write must + # then be persisted to the underlying RegistryStore with a call to commit(). 
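+    # An illustrative write sequence (sketch only; "driver", "rider", and "my_project"
+    # are placeholder names):
+    #
+    #     registry.apply_entity(Entity(name="driver"), project="my_project", commit=False)
+    #     registry.apply_entity(Entity(name="rider"), project="my_project", commit=False)
+    #     registry.commit()  # one write persists both in-memory changes to the store
+    #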
+    cached_registry_proto: RegistryProto
+    cached_registry_proto_created: datetime
+    cached_registry_proto_ttl: timedelta
+
+    def __init__(
+        self,
+        project: str,
+        registry_config: Optional[RegistryConfig],
+        repo_path: Optional[Path],
+        auth_config: AuthConfig = NoAuthConfig(),
+    ):
+        """
+        Create the Registry object.
+
+        Args:
+            project: Name of the Feast project.
+            registry_config: RegistryConfig object containing the destination path and cache ttl.
+            repo_path: Path to the base of the Feast repository,
+                or where it will be created if it does not exist yet.
+            auth_config: Authentication configuration used when accessing the registry.
+        """
+
+        self._refresh_lock = Lock()
+        self._auth_config = auth_config
+
+        registry_proto = RegistryProto()
+        registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION
+        self.cached_registry_proto = registry_proto
+        self.cached_registry_proto_created = _utc_now()
+
+        self.purge_feast_metadata = (
+            registry_config.purge_feast_metadata
+            if registry_config is not None
+            else False
+        )
+
+        if registry_config:
+            registry_store_type = registry_config.registry_store_type
+            registry_path = registry_config.path
+            if registry_store_type is None:
+                cls = get_registry_store_class_from_scheme(registry_path)
+            else:
+                cls = get_registry_store_class_from_type(str(registry_store_type))
+
+            self._registry_store = cls(registry_config, repo_path)
+            self.cached_registry_proto_ttl = timedelta(
+                seconds=(
+                    registry_config.cache_ttl_seconds
+                    if registry_config.cache_ttl_seconds is not None
+                    else 0
+                )
+            )
+
+            try:
+                registry_proto = self._registry_store.get_registry_proto()
+                self.cached_registry_proto = registry_proto
+                self.cached_registry_proto_created = _utc_now()
+                # Sync feast_metadata to the projects table. When purge_feast_metadata
+                # is set to True, data is deleted from the feast_metadata table and
+                # list_project_metadata will not return any data.
+                self._sync_feast_metadata_to_projects_table()
+            except FileNotFoundError:
+                logger.info("Registry file not found. Creating new registry.")
+                self.commit()
+
+    def _sync_feast_metadata_to_projects_table(self):
+        """
+        Sync feast_metadata to the projects table.
+        """
+        feast_metadata_projects = []
+        projects_set = []
+        # Collect the project names recorded in project_metadata
+        for project_metadata in self.cached_registry_proto.project_metadata:
+            project = ProjectMetadata.from_proto(project_metadata)
+            feast_metadata_projects.append(project.project_name)
+        if len(feast_metadata_projects) > 0:
+            # Collect the project names already present in the projects table
+            for project_metadata in self.cached_registry_proto.projects:
+                project = Project.from_proto(project_metadata)
+                projects_set.append(project.name)
+
+            # Find projects recorded in feast_metadata but missing from the projects table
+            projects_to_sync = set(feast_metadata_projects) - set(projects_set)
+            # Create the missing projects in the projects table
+            for project_name in projects_to_sync:
+                project = Project(name=project_name)
+                self.cached_registry_proto.projects.append(project.to_proto())
+
+            if self.purge_feast_metadata:
+                self.cached_registry_proto.project_metadata = []
+
+    def clone(self) -> "Registry":
+        new_registry = Registry("project", None, None, self._auth_config)
+        new_registry.cached_registry_proto_ttl = timedelta(seconds=0)
+        new_registry.cached_registry_proto = (
+            self.cached_registry_proto.__deepcopy__()
+            if self.cached_registry_proto
+            else RegistryProto()
+        )
+        new_registry.cached_registry_proto_created = _utc_now()
+        new_registry._registry_store = NoopRegistryStore()
+        return new_registry
+
+    def update_infra(self, infra: Infra, project: str, commit: bool = True):
+        self._prepare_registry_for_changes(project)
+        assert self.cached_registry_proto
+
+        self.cached_registry_proto.infra.CopyFrom(infra.to_proto())
+        if commit:
+            self.commit()
+
+    def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
+        registry_proto = self._get_registry_proto(
+            project=project, allow_cache=allow_cache
+        )
+        return Infra.from_proto(registry_proto.infra)
+
+    def apply_entity(self, entity: Entity, project: str, commit: bool = True):
+        entity.is_valid()
+
+        now = _utc_now()
+        if not entity.created_timestamp:
+            entity.created_timestamp = now
+        entity.last_updated_timestamp = now
+
+        entity_proto = entity.to_proto()
+        entity_proto.spec.project = project
+        self._prepare_registry_for_changes(project)
+        assert self.cached_registry_proto
+
+        for idx, existing_entity_proto in enumerate(
+            self.cached_registry_proto.entities
+        ):
+            if (
+                existing_entity_proto.spec.name == entity_proto.spec.name
+                and existing_entity_proto.spec.project == project
+            ):
+                entity.created_timestamp = (
+                    existing_entity_proto.meta.created_timestamp.ToDatetime()
+                )
+                entity_proto = entity.to_proto()
+                entity_proto.spec.project = project
+                del self.cached_registry_proto.entities[idx]
+                break
+        self.cached_registry_proto.entities.append(entity_proto)
+        if commit:
+            self.commit()
+
+    def list_entities(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Entity]:
+        registry_proto = self._get_registry_proto(
+            project=project, allow_cache=allow_cache
+        )
+        return proto_registry_utils.list_entities(registry_proto, project, tags)
+
+    def list_data_sources(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[DataSource]:
+        registry_proto = self._get_registry_proto(
+            project=project, allow_cache=allow_cache
+        )
+        return proto_registry_utils.list_data_sources(registry_proto, project, tags)
+
+    def apply_data_source(
+        self, data_source: DataSource, 
project: str, commit: bool = True + ): + registry = self._prepare_registry_for_changes(project) + for idx, existing_data_source_proto in enumerate(registry.data_sources): + if existing_data_source_proto.name == data_source.name: + del registry.data_sources[idx] + data_source_proto = data_source.to_proto() + data_source_proto.project = project + data_source_proto.data_source_class_type = ( + f"{data_source.__class__.__module__}.{data_source.__class__.__name__}" + ) + self.cached_registry_proto.data_sources.append(data_source_proto) + if commit: + self.commit() + + def delete_data_source(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, data_source_proto in enumerate( + self.cached_registry_proto.data_sources + ): + if data_source_proto.name == name: + del self.cached_registry_proto.data_sources[idx] + if commit: + self.commit() + return + raise DataSourceNotFoundException(name) + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + now = _utc_now() + if not feature_service.created_timestamp: + feature_service.created_timestamp = now + feature_service.last_updated_timestamp = now + + feature_service_proto = feature_service.to_proto() + feature_service_proto.spec.project = project + + registry = self._prepare_registry_for_changes(project) + + for idx, existing_feature_service_proto in enumerate(registry.feature_services): + if ( + existing_feature_service_proto.spec.name + == feature_service_proto.spec.name + and existing_feature_service_proto.spec.project == project + ): + feature_service.created_timestamp = ( + existing_feature_service_proto.meta.created_timestamp.ToDatetime() + ) + feature_service_proto = feature_service.to_proto() + feature_service_proto.spec.project = project + del registry.feature_services[idx] + self.cached_registry_proto.feature_services.append(feature_service_proto) + if commit: + self.commit() + + def list_feature_services( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureService]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_feature_services(registry_proto, project, tags) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_feature_service(registry_proto, name, project) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_entity(registry_proto, name, project) + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + feature_view.ensure_valid() + + now = _utc_now() + if not feature_view.created_timestamp: + feature_view.created_timestamp = now + feature_view.last_updated_timestamp = now + + feature_view_proto = feature_view.to_proto() + feature_view_proto.spec.project = project + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + self._check_conflicting_feature_view_names(feature_view) + existing_feature_views_of_same_type: RepeatedCompositeFieldContainer + if isinstance(feature_view, StreamFeatureView): + existing_feature_views_of_same_type = ( + 
self.cached_registry_proto.stream_feature_views + ) + elif isinstance(feature_view, FeatureView): + existing_feature_views_of_same_type = ( + self.cached_registry_proto.feature_views + ) + elif isinstance(feature_view, OnDemandFeatureView): + existing_feature_views_of_same_type = ( + self.cached_registry_proto.on_demand_feature_views + ) + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + + for idx, existing_feature_view_proto in enumerate( + existing_feature_views_of_same_type + ): + if ( + existing_feature_view_proto.spec.name == feature_view_proto.spec.name + and existing_feature_view_proto.spec.project == project + ): + if ( + feature_view.__class__.from_proto(existing_feature_view_proto) + == feature_view + ): + return + else: + existing_feature_view = type(feature_view).from_proto( + existing_feature_view_proto + ) + feature_view.created_timestamp = ( + existing_feature_view.created_timestamp + ) + if isinstance(feature_view, (FeatureView, StreamFeatureView)): + feature_view.update_materialization_intervals( + existing_feature_view.materialization_intervals + ) + feature_view_proto = feature_view.to_proto() + feature_view_proto.spec.project = project + del existing_feature_views_of_same_type[idx] + break + + existing_feature_views_of_same_type.append(feature_view_proto) + if commit: + self.commit() + + def list_stream_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[StreamFeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_stream_feature_views( + registry_proto, project, tags + ) + + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_on_demand_feature_views( + registry_proto, project, tags + ) + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_on_demand_feature_view( + registry_proto, name, project + ) + + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_data_source(registry_proto, name, project) + + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_feature_view_proto in enumerate( + self.cached_registry_proto.feature_views + ): + if ( + existing_feature_view_proto.spec.name == feature_view.name + and existing_feature_view_proto.spec.project == project + ): + existing_feature_view = FeatureView.from_proto( + existing_feature_view_proto + ) + existing_feature_view.materialization_intervals.append( + (start_date, end_date) + ) + existing_feature_view.last_updated_timestamp = _utc_now() + feature_view_proto = existing_feature_view.to_proto() + feature_view_proto.spec.project = project + del self.cached_registry_proto.feature_views[idx] + 
self.cached_registry_proto.feature_views.append(feature_view_proto) + if commit: + self.commit() + return + + for idx, existing_stream_feature_view_proto in enumerate( + self.cached_registry_proto.stream_feature_views + ): + if ( + existing_stream_feature_view_proto.spec.name == feature_view.name + and existing_stream_feature_view_proto.spec.project == project + ): + existing_stream_feature_view = StreamFeatureView.from_proto( + existing_stream_feature_view_proto + ) + existing_stream_feature_view.materialization_intervals.append( + (start_date, end_date) + ) + existing_stream_feature_view.last_updated_timestamp = _utc_now() + stream_feature_view_proto = existing_stream_feature_view.to_proto() + stream_feature_view_proto.spec.project = project + del self.cached_registry_proto.stream_feature_views[idx] + self.cached_registry_proto.stream_feature_views.append( + stream_feature_view_proto + ) + if commit: + self.commit() + return + + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_all_feature_views( + registry_proto, project, tags + ) + + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_any_feature_view(registry_proto, name, project) + + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_feature_views(registry_proto, project, tags) + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_feature_view(registry_proto, name, project) + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> StreamFeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_stream_feature_view( + registry_proto, name, project + ) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, feature_service_proto in enumerate( + self.cached_registry_proto.feature_services + ): + if ( + feature_service_proto.spec.name == name + and feature_service_proto.spec.project == project + ): + del self.cached_registry_proto.feature_services[idx] + if commit: + self.commit() + return + raise FeatureServiceNotFoundException(name, project) + + def delete_feature_view(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_feature_view_proto in enumerate( + self.cached_registry_proto.feature_views + ): + if ( + existing_feature_view_proto.spec.name == name + and existing_feature_view_proto.spec.project == project + ): + del self.cached_registry_proto.feature_views[idx] + if commit: + self.commit() + return + + for idx, existing_on_demand_feature_view_proto in enumerate( + 
self.cached_registry_proto.on_demand_feature_views + ): + if ( + existing_on_demand_feature_view_proto.spec.name == name + and existing_on_demand_feature_view_proto.spec.project == project + ): + del self.cached_registry_proto.on_demand_feature_views[idx] + if commit: + self.commit() + return + + for idx, existing_stream_feature_view_proto in enumerate( + self.cached_registry_proto.stream_feature_views + ): + if ( + existing_stream_feature_view_proto.spec.name == name + and existing_stream_feature_view_proto.spec.project == project + ): + del self.cached_registry_proto.stream_feature_views[idx] + if commit: + self.commit() + return + + raise FeatureViewNotFoundException(name, project) + + def delete_entity(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_entity_proto in enumerate( + self.cached_registry_proto.entities + ): + if ( + existing_entity_proto.spec.name == name + and existing_entity_proto.spec.project == project + ): + del self.cached_registry_proto.entities[idx] + if commit: + self.commit() + return + + raise EntityNotFoundException(name, project) + + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + now = _utc_now() + if not saved_dataset.created_timestamp: + saved_dataset.created_timestamp = now + saved_dataset.last_updated_timestamp = now + + saved_dataset_proto = saved_dataset.to_proto() + saved_dataset_proto.spec.project = project + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_saved_dataset_proto in enumerate( + self.cached_registry_proto.saved_datasets + ): + if ( + existing_saved_dataset_proto.spec.name == saved_dataset_proto.spec.name + and existing_saved_dataset_proto.spec.project == project + ): + saved_dataset.created_timestamp = ( + existing_saved_dataset_proto.meta.created_timestamp.ToDatetime() + ) + saved_dataset.min_event_timestamp = ( + existing_saved_dataset_proto.meta.min_event_timestamp.ToDatetime() + ) + saved_dataset.max_event_timestamp = ( + existing_saved_dataset_proto.meta.max_event_timestamp.ToDatetime() + ) + saved_dataset_proto = saved_dataset.to_proto() + saved_dataset_proto.spec.project = project + del self.cached_registry_proto.saved_datasets[idx] + break + + self.cached_registry_proto.saved_datasets.append(saved_dataset_proto) + if commit: + self.commit() + + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_saved_dataset(registry_proto, name, project) + + def list_saved_datasets( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[SavedDataset]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_saved_datasets(registry_proto, project, tags) + + def apply_validation_reference( + self, + validation_reference: ValidationReference, + project: str, + commit: bool = True, + ): + validation_reference_proto = validation_reference.to_proto() + validation_reference_proto.project = project + + registry_proto = self._prepare_registry_for_changes(project) + for idx, existing_validation_reference in enumerate( + registry_proto.validation_references + ): + if ( + existing_validation_reference.name == validation_reference_proto.name + and 
existing_validation_reference.project == project + ): + del registry_proto.validation_references[idx] + break + + registry_proto.validation_references.append(validation_reference_proto) + if commit: + self.commit() + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_validation_reference( + registry_proto, name, project + ) + + def list_validation_references( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[ValidationReference]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_validation_references( + registry_proto, project, tags + ) + + def delete_validation_reference(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + for idx, existing_validation_reference in enumerate( + self.cached_registry_proto.validation_references + ): + if ( + existing_validation_reference.name == name + and existing_validation_reference.project == project + ): + del self.cached_registry_proto.validation_references[idx] + if commit: + self.commit() + return + raise ValidationReferenceNotFound(name, project=project) + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_project_metadata(registry_proto, project) + + def commit(self): + """Commits the state of the registry cache to the remote registry store.""" + if self.cached_registry_proto: + self._registry_store.update_registry_proto(self.cached_registry_proto) + + def refresh(self, project: Optional[str] = None): + """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" + self._get_registry_proto(project=project, allow_cache=False) + + def teardown(self): + """Tears down (removes) the registry.""" + self._registry_store.teardown() + + def proto(self) -> RegistryProto: + return self.cached_registry_proto or RegistryProto() + + def _prepare_registry_for_changes(self, project_name: str): + """Prepares the Registry for changes by refreshing the cache if necessary.""" + + assert self.cached_registry_proto is not None + + try: + # Check if the project exists in the registry cache + self.get_project(name=project_name, allow_cache=True) + return self.cached_registry_proto + except ProjectObjectNotFoundException: + # If the project does not exist in cache, refresh cache from store + registry_proto = self._registry_store.get_registry_proto() + self.cached_registry_proto = registry_proto + self.cached_registry_proto_created = _utc_now() + + try: + # Check if the project exists in the registry cache after refresh from store + self.get_project(name=project_name) + except ProjectObjectNotFoundException: + # If the project still does not exist, create it + project_proto = Project(name=project_name).to_proto() + self.cached_registry_proto.projects.append(project_proto) + if not self.purge_feast_metadata: + project_metadata_proto = ProjectMetadata( + project_name=project_name + ).to_proto() + self.cached_registry_proto.project_metadata.append( + project_metadata_proto + ) + self.commit() + return self.cached_registry_proto + + def 
_get_registry_proto( + self, project: Optional[str], allow_cache: bool = False + ) -> RegistryProto: + """Returns the cached or remote registry state + + Args: + project: Name of the Feast project (optional) + allow_cache: Whether to allow the use of the registry cache when fetching the RegistryProto + + Returns: Returns a RegistryProto object which represents the state of the registry + """ + with self._refresh_lock: + expired = (self.cached_registry_proto_created is None) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + _utc_now() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) + ) + ) + + if allow_cache and not expired: + return self.cached_registry_proto + logger.info("Registry cache expired, so refreshing") + registry_proto = self._registry_store.get_registry_proto() + self.cached_registry_proto = registry_proto + self.cached_registry_proto_created = _utc_now() + return registry_proto + + def _check_conflicting_feature_view_names(self, feature_view: BaseFeatureView): + name_to_fv_protos = self._existing_feature_view_names_to_fvs() + if feature_view.name in name_to_fv_protos: + if not isinstance( + name_to_fv_protos.get(feature_view.name), feature_view.proto_class + ): + raise ConflictingFeatureViewNames(feature_view.name) + + def _existing_feature_view_names_to_fvs(self) -> Dict[str, Message]: + assert self.cached_registry_proto + odfvs = { + fv.spec.name: fv + for fv in self.cached_registry_proto.on_demand_feature_views + } + fvs = {fv.spec.name: fv for fv in self.cached_registry_proto.feature_views} + sfv = { + fv.spec.name: fv for fv in self.cached_registry_proto.stream_feature_views + } + return {**odfvs, **fvs, **sfv} + + def get_permission( + self, name: str, project: str, allow_cache: bool = False + ) -> Permission: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_permission(registry_proto, name, project) + + def list_permissions( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Permission]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_permissions(registry_proto, project, tags) + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + now = _utc_now() + if not permission.created_timestamp: + permission.created_timestamp = now + permission.last_updated_timestamp = now + + registry = self._prepare_registry_for_changes(project) + for idx, existing_permission_proto in enumerate(registry.permissions): + if ( + existing_permission_proto.spec.name == permission.name + and existing_permission_proto.spec.project == project + ): + permission.created_timestamp = ( + existing_permission_proto.meta.created_timestamp.ToDatetime() + ) + del registry.permissions[idx] + + permission_proto = permission.to_proto() + permission_proto.spec.project = project + self.cached_registry_proto.permissions.append(permission_proto) + if commit: + self.commit() + + def delete_permission(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, permission_proto in enumerate(self.cached_registry_proto.permissions): + if ( + permission_proto.spec.name == name + and permission_proto.spec.project == project + ): + del self.cached_registry_proto.permissions[idx] + if commit: + self.commit() + 
return + raise PermissionNotFoundException(name, project) + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + registry = self.cached_registry_proto + + for idx, existing_project_proto in enumerate(registry.projects): + if existing_project_proto.spec.name == project.name: + project.created_timestamp = ( + existing_project_proto.meta.created_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ) + ) + del registry.projects[idx] + + project_proto = project.to_proto() + self.cached_registry_proto.projects.append(project_proto) + if commit: + self.commit() + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + registry_proto = self._get_registry_proto(project=name, allow_cache=allow_cache) + return proto_registry_utils.get_project(registry_proto, name) + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + registry_proto = self._get_registry_proto(project=None, allow_cache=allow_cache) + return proto_registry_utils.list_projects( + registry_proto=registry_proto, tags=tags + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + assert self.cached_registry_proto + + for idx, project_proto in enumerate(self.cached_registry_proto.projects): + if project_proto.spec.name == name: + list_validation_references = self.list_validation_references(name) + for validation_reference in list_validation_references: + self.delete_validation_reference(validation_reference.name, name) + + list_saved_datasets = self.list_saved_datasets(name) + for saved_dataset in list_saved_datasets: + self.delete_saved_dataset(saved_dataset.name, name) + + list_feature_services = self.list_feature_services(name) + for feature_service in list_feature_services: + self.delete_feature_service(feature_service.name, name) + + list_on_demand_feature_views = self.list_on_demand_feature_views(name) + for on_demand_feature_view in list_on_demand_feature_views: + self.delete_feature_view(on_demand_feature_view.name, name) + + list_stream_feature_views = self.list_stream_feature_views(name) + for stream_feature_view in list_stream_feature_views: + self.delete_feature_view(stream_feature_view.name, name) + + list_feature_views = self.list_feature_views(name) + for feature_view in list_feature_views: + self.delete_feature_view(feature_view.name, name) + + list_data_sources = self.list_data_sources(name) + for data_source in list_data_sources: + self.delete_data_source(data_source.name, name) + + list_entities = self.list_entities(name) + for entity in list_entities: + self.delete_entity(entity.name, name) + list_permissions = self.list_permissions(name) + for permission in list_permissions: + self.delete_permission(permission.name, name) + del self.cached_registry_proto.projects[idx] + if commit: + self.commit() + return + raise ProjectNotFoundException(name) diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index 590c0454b73..78f901ca202 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -1,594 +1,594 @@ -import os -from datetime import datetime -from pathlib import Path -from typing import List, Optional, Union - -import grpc -from google.protobuf.empty_pb2 import Empty -from google.protobuf.timestamp_pb2 import Timestamp -from pydantic import StrictStr - -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity 
-from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object import Infra -from feast.infra.registry.base_registry import BaseRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.auth_model import AuthConfig, NoAuthConfig -from feast.permissions.client.grpc_client_auth_interceptor import ( - GrpcClientAuthHeaderInterceptor, -) -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc -from feast.repo_config import RegistryConfig -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView - - -def extract_base_feature_view( - any_feature_view: RegistryServer_pb2.AnyFeatureView, -) -> BaseFeatureView: - feature_view_type = any_feature_view.WhichOneof("any_feature_view") - if feature_view_type == "feature_view": - feature_view = FeatureView.from_proto(any_feature_view.feature_view) - elif feature_view_type == "on_demand_feature_view": - feature_view = OnDemandFeatureView.from_proto( - any_feature_view.on_demand_feature_view - ) - elif feature_view_type == "stream_feature_view": - feature_view = StreamFeatureView.from_proto( - any_feature_view.stream_feature_view - ) - - return feature_view - - -class RemoteRegistryConfig(RegistryConfig): - registry_type: StrictStr = "remote" - """ str: Provider name or a class name that implements Registry.""" - - path: StrictStr = "" - """ str: Path to metadata store. - If registry_type is 'remote', then this is a URL for registry server """ - - cert: StrictStr = "" - """ str: Path to the public certificate when the registry server starts in TLS(SSL) mode. This may be needed if the registry server started with a self-signed certificate, typically this file ends with `*.crt`, `*.cer`, or `*.pem`. - If registry_type is 'remote', then this configuration is needed to connect to remote registry server in TLS mode. If the remote registry started in non-tls mode then this configuration is not needed.""" - - is_tls: bool = False - """ bool: Set to `True` if you plan to connect to a registry server running in TLS (SSL) mode. - If you intend to add the public certificate to the trust store instead of passing it via the `cert` parameter, this field must be set to `True`. - If you are planning to add the public certificate as part of the trust store instead of passing it as a `cert` parameters then setting this field to `true` is mandatory. 
- """ - - -class RemoteRegistry(BaseRegistry): - def __init__( - self, - registry_config: Union[RegistryConfig, RemoteRegistryConfig], - project: str, - repo_path: Optional[Path], - auth_config: AuthConfig = NoAuthConfig(), - ): - self.auth_config = auth_config - assert isinstance(registry_config, RemoteRegistryConfig) - self.channel = self._create_grpc_channel(registry_config) - - auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) - self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor) - self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel) - - def _create_grpc_channel(self, registry_config): - assert isinstance(registry_config, RemoteRegistryConfig) - if registry_config.cert or registry_config.is_tls: - cafile = os.getenv("SSL_CERT_FILE") or os.getenv("REQUESTS_CA_BUNDLE") - if not cafile and not registry_config.cert: - raise EnvironmentError( - "SSL_CERT_FILE or REQUESTS_CA_BUNDLE environment variable must be set to use secure TLS or set the cert parameter in feature_Store.yaml file under remote registry configuration." - ) - with open( - registry_config.cert if registry_config.cert else cafile, "rb" - ) as cert_file: - trusted_certs = cert_file.read() - tls_credentials = grpc.ssl_channel_credentials( - root_certificates=trusted_certs - ) - return grpc.secure_channel(registry_config.path, tls_credentials) - else: - # Create an insecure gRPC channel - return grpc.insecure_channel(registry_config.path) - - def close(self): - if self.channel: - self.channel.close() - - def __del__(self): - self.close() - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - request = RegistryServer_pb2.ApplyEntityRequest( - entity=entity.to_proto(), project=project, commit=commit - ) - self.stub.ApplyEntity(request) - - def delete_entity(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteEntityRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteEntity(request) - - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - request = RegistryServer_pb2.GetEntityRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetEntity(request) - return Entity.from_proto(response) - - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - request = RegistryServer_pb2.ListEntitiesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListEntities(request) - return [Entity.from_proto(entity) for entity in response.entities] - - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - request = RegistryServer_pb2.ApplyDataSourceRequest( - data_source=data_source.to_proto(), project=project, commit=commit - ) - self.stub.ApplyDataSource(request) - - def delete_data_source(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteDataSourceRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteDataSource(request) - - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - request = RegistryServer_pb2.GetDataSourceRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetDataSource(request) - return DataSource.from_proto(response) - - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, 
str]] = None, - ) -> List[DataSource]: - request = RegistryServer_pb2.ListDataSourcesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListDataSources(request) - return [ - DataSource.from_proto(data_source) for data_source in response.data_sources - ] - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - request = RegistryServer_pb2.ApplyFeatureServiceRequest( - feature_service=feature_service.to_proto(), project=project, commit=commit - ) - self.stub.ApplyFeatureService(request) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteFeatureServiceRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteFeatureService(request) - - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - request = RegistryServer_pb2.GetFeatureServiceRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetFeatureService(request) - return FeatureService.from_proto(response) - - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - request = RegistryServer_pb2.ListFeatureServicesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListFeatureServices(request) - return [ - FeatureService.from_proto(feature_service) - for feature_service in response.feature_services - ] - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - if isinstance(feature_view, StreamFeatureView): - arg_name = "stream_feature_view" - elif isinstance(feature_view, FeatureView): - arg_name = "feature_view" - elif isinstance(feature_view, OnDemandFeatureView): - arg_name = "on_demand_feature_view" - - request = RegistryServer_pb2.ApplyFeatureViewRequest( - feature_view=( - feature_view.to_proto() if arg_name == "feature_view" else None - ), - stream_feature_view=( - feature_view.to_proto() if arg_name == "stream_feature_view" else None - ), - on_demand_feature_view=( - feature_view.to_proto() - if arg_name == "on_demand_feature_view" - else None - ), - project=project, - commit=commit, - ) - - self.stub.ApplyFeatureView(request) - - def delete_feature_view(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteFeatureViewRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteFeatureView(request) - - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> StreamFeatureView: - request = RegistryServer_pb2.GetStreamFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetStreamFeatureView(request) - return StreamFeatureView.from_proto(response) - - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - request = RegistryServer_pb2.ListStreamFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListStreamFeatureViews(request) - return [ - StreamFeatureView.from_proto(stream_feature_view) - for stream_feature_view in response.stream_feature_views - ] - - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: - request = 
RegistryServer_pb2.GetOnDemandFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetOnDemandFeatureView(request) - return OnDemandFeatureView.from_proto(response) - - def list_on_demand_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - request = RegistryServer_pb2.ListOnDemandFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListOnDemandFeatureViews(request) - return [ - OnDemandFeatureView.from_proto(on_demand_feature_view) - for on_demand_feature_view in response.on_demand_feature_views - ] - - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - request = RegistryServer_pb2.GetAnyFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - - response: RegistryServer_pb2.GetAnyFeatureViewResponse = ( - self.stub.GetAnyFeatureView(request) - ) - any_feature_view = response.any_feature_view - return extract_base_feature_view(any_feature_view) - - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - request = RegistryServer_pb2.ListAllFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - - response: RegistryServer_pb2.ListAllFeatureViewsResponse = ( - self.stub.ListAllFeatureViews(request) - ) - return [ - extract_base_feature_view(any_feature_view) - for any_feature_view in response.feature_views - ] - - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - request = RegistryServer_pb2.GetFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetFeatureView(request) - return FeatureView.from_proto(response) - - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - request = RegistryServer_pb2.ListFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListFeatureViews(request) - - return [ - FeatureView.from_proto(feature_view) - for feature_view in response.feature_views - ] - - def apply_materialization( - self, - feature_view: FeatureView, - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - start_date_timestamp = Timestamp() - end_date_timestamp = Timestamp() - start_date_timestamp.FromDatetime(start_date) - end_date_timestamp.FromDatetime(end_date) - - # TODO: for this to work for stream feature views, ApplyMaterializationRequest needs to be updated - request = RegistryServer_pb2.ApplyMaterializationRequest( - feature_view=feature_view.to_proto(), - project=project, - start_date=start_date_timestamp, - end_date=end_date_timestamp, - commit=commit, - ) - self.stub.ApplyMaterialization(request) - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - request = RegistryServer_pb2.ApplySavedDatasetRequest( - saved_dataset=saved_dataset.to_proto(), project=project, commit=commit - ) - self.stub.ApplyFeatureService(request) - - def delete_saved_dataset(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteSavedDatasetRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteSavedDataset(request) - - def get_saved_dataset( 
- self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - request = RegistryServer_pb2.GetSavedDatasetRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetSavedDataset(request) - return SavedDataset.from_proto(response) - - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - request = RegistryServer_pb2.ListSavedDatasetsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListSavedDatasets(request) - return [ - SavedDataset.from_proto(saved_dataset) - for saved_dataset in response.saved_datasets - ] - - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - request = RegistryServer_pb2.ApplyValidationReferenceRequest( - validation_reference=validation_reference.to_proto(), - project=project, - commit=commit, - ) - self.stub.ApplyValidationReference(request) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteValidationReferenceRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteValidationReference(request) - - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - request = RegistryServer_pb2.GetValidationReferenceRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetValidationReference(request) - return ValidationReference.from_proto(response) - - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - request = RegistryServer_pb2.ListValidationReferencesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListValidationReferences(request) - return [ - ValidationReference.from_proto(validation_reference) - for validation_reference in response.validation_references - ] - - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - request = RegistryServer_pb2.ListProjectMetadataRequest( - project=project, allow_cache=allow_cache - ) - response = self.stub.ListProjectMetadata(request) - return [ProjectMetadata.from_proto(pm) for pm in response.project_metadata] - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - request = RegistryServer_pb2.UpdateInfraRequest( - infra=infra.to_proto(), project=project, commit=commit - ) - self.stub.UpdateInfra(request) - - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - request = RegistryServer_pb2.GetInfraRequest( - project=project, allow_cache=allow_cache - ) - response = self.stub.GetInfra(request) - return Infra.from_proto(response) - - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - pass - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - pass - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - permission_proto = permission.to_proto() - permission_proto.spec.project = project - - request = RegistryServer_pb2.ApplyPermissionRequest( - permission=permission_proto, project=project, commit=commit - ) - self.stub.ApplyPermission(request) - - def delete_permission(self, name: str, 
project: str, commit: bool = True): - request = RegistryServer_pb2.DeletePermissionRequest( - name=name, project=project, commit=commit - ) - self.stub.DeletePermission(request) - - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - request = RegistryServer_pb2.GetPermissionRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetPermission(request) - - return Permission.from_proto(response) - - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - request = RegistryServer_pb2.ListPermissionsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListPermissions(request) - return [ - Permission.from_proto(permission) for permission in response.permissions - ] - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - project_proto = project.to_proto() - - request = RegistryServer_pb2.ApplyProjectRequest( - project=project_proto, commit=commit - ) - self.stub.ApplyProject(request) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - request = RegistryServer_pb2.DeleteProjectRequest(name=name, commit=commit) - self.stub.DeleteProject(request) - - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - request = RegistryServer_pb2.GetProjectRequest( - name=name, allow_cache=allow_cache - ) - response = self.stub.GetProject(request) - - return Project.from_proto(response) - - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - request = RegistryServer_pb2.ListProjectsRequest( - allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListProjects(request) - return [Project.from_proto(project) for project in response.projects] - - def proto(self) -> RegistryProto: - return self.stub.Proto(Empty()) - - def commit(self): - self.stub.Commit(Empty()) - - def refresh(self, project: Optional[str] = None): - request = RegistryServer_pb2.RefreshRequest(project=str(project)) - self.stub.Refresh(request) - - def teardown(self): - pass +import os +from datetime import datetime +from pathlib import Path +from typing import List, Optional, Union + +import grpc +from google.protobuf.empty_pb2 import Empty +from google.protobuf.timestamp_pb2 import Timestamp +from pydantic import StrictStr + +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.auth_model import AuthConfig, NoAuthConfig +from feast.permissions.client.grpc_client_auth_interceptor import ( + GrpcClientAuthHeaderInterceptor, +) +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc +from feast.repo_config import RegistryConfig +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView + + +def extract_base_feature_view( + any_feature_view: 
RegistryServer_pb2.AnyFeatureView,
+) -> BaseFeatureView:
+    feature_view_type = any_feature_view.WhichOneof("any_feature_view")
+    if feature_view_type == "feature_view":
+        feature_view = FeatureView.from_proto(any_feature_view.feature_view)
+    elif feature_view_type == "on_demand_feature_view":
+        feature_view = OnDemandFeatureView.from_proto(
+            any_feature_view.on_demand_feature_view
+        )
+    elif feature_view_type == "stream_feature_view":
+        feature_view = StreamFeatureView.from_proto(
+            any_feature_view.stream_feature_view
+        )
+    else:
+        raise ValueError(f"Unexpected feature view type: {feature_view_type}")
+
+    return feature_view
+
+
+class RemoteRegistryConfig(RegistryConfig):
+    registry_type: StrictStr = "remote"
+    """ str: Registry type selector."""
+
+    path: StrictStr = ""
+    """ str: Path to metadata store.
+    If registry_type is 'remote', then this is the URL of the registry server """
+
+    cert: StrictStr = ""
+    """ str: Path to the public certificate when the registry server starts in TLS (SSL) mode. This may be needed if the registry server was started with a self-signed certificate; typically this file ends with `*.crt`, `*.cer`, or `*.pem`.
+    If registry_type is 'remote', then this configuration is needed to connect to a remote registry server in TLS mode. If the remote registry server is running in non-TLS mode, this configuration is not needed."""
+
+    is_tls: bool = False
+    """ bool: Set to `True` if you plan to connect to a registry server running in TLS (SSL) mode.
+    If you intend to add the public certificate to the trust store instead of passing it via the `cert` parameter, this field must be set to `True`.
+    """
+
+
+class RemoteRegistry(BaseRegistry):
+    def __init__(
+        self,
+        registry_config: Union[RegistryConfig, RemoteRegistryConfig],
+        project: str,
+        repo_path: Optional[Path],
+        auth_config: AuthConfig = NoAuthConfig(),
+    ):
+        self.auth_config = auth_config
+        assert isinstance(registry_config, RemoteRegistryConfig)
+        self.channel = self._create_grpc_channel(registry_config)
+
+        auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
+        self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor)
+        self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel)
+
+    def _create_grpc_channel(self, registry_config):
+        assert isinstance(registry_config, RemoteRegistryConfig)
+        if registry_config.cert or registry_config.is_tls:
+            cafile = os.getenv("SSL_CERT_FILE") or os.getenv("REQUESTS_CA_BUNDLE")
+            if not cafile and not registry_config.cert:
+                raise EnvironmentError(
+                    "To connect to the registry server over TLS, set the SSL_CERT_FILE or REQUESTS_CA_BUNDLE environment variable, or set the `cert` parameter under the remote registry configuration in feature_store.yaml."
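+                    # Resolution order for the CA bundle used by the secure
+                    # channel: the `cert` path from feature_store.yaml wins,
+                    # otherwise SSL_CERT_FILE, otherwise REQUESTS_CA_BUNDLE.
+                    # A minimal illustrative registry block (host and paths
+                    # are hypothetical):
+                    #
+                    #   registry:
+                    #     registry_type: remote
+                    #     path: registry.example.com:443
+                    #     cert: /etc/feast/ca.crt
+                    #     is_tls: true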
+ ) + with open( + registry_config.cert if registry_config.cert else cafile, "rb" + ) as cert_file: + trusted_certs = cert_file.read() + tls_credentials = grpc.ssl_channel_credentials( + root_certificates=trusted_certs + ) + return grpc.secure_channel(registry_config.path, tls_credentials) + else: + # Create an insecure gRPC channel + return grpc.insecure_channel(registry_config.path) + + def close(self): + if self.channel: + self.channel.close() + + def __del__(self): + self.close() + + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + request = RegistryServer_pb2.ApplyEntityRequest( + entity=entity.to_proto(), project=project, commit=commit + ) + self.stub.ApplyEntity(request) + + def delete_entity(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteEntityRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteEntity(request) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + request = RegistryServer_pb2.GetEntityRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetEntity(request) + return Entity.from_proto(response) + + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: + request = RegistryServer_pb2.ListEntitiesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListEntities(request) + return [Entity.from_proto(entity) for entity in response.entities] + + def apply_data_source( + self, data_source: DataSource, project: str, commit: bool = True + ): + request = RegistryServer_pb2.ApplyDataSourceRequest( + data_source=data_source.to_proto(), project=project, commit=commit + ) + self.stub.ApplyDataSource(request) + + def delete_data_source(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteDataSourceRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteDataSource(request) + + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + request = RegistryServer_pb2.GetDataSourceRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetDataSource(request) + return DataSource.from_proto(response) + + def list_data_sources( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[DataSource]: + request = RegistryServer_pb2.ListDataSourcesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListDataSources(request) + return [ + DataSource.from_proto(data_source) for data_source in response.data_sources + ] + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + request = RegistryServer_pb2.ApplyFeatureServiceRequest( + feature_service=feature_service.to_proto(), project=project, commit=commit + ) + self.stub.ApplyFeatureService(request) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteFeatureServiceRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteFeatureService(request) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + request = RegistryServer_pb2.GetFeatureServiceRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetFeatureService(request) + 
return FeatureService.from_proto(response) + + def list_feature_services( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureService]: + request = RegistryServer_pb2.ListFeatureServicesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListFeatureServices(request) + return [ + FeatureService.from_proto(feature_service) + for feature_service in response.feature_services + ] + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + if isinstance(feature_view, StreamFeatureView): + arg_name = "stream_feature_view" + elif isinstance(feature_view, FeatureView): + arg_name = "feature_view" + elif isinstance(feature_view, OnDemandFeatureView): + arg_name = "on_demand_feature_view" + + request = RegistryServer_pb2.ApplyFeatureViewRequest( + feature_view=( + feature_view.to_proto() if arg_name == "feature_view" else None + ), + stream_feature_view=( + feature_view.to_proto() if arg_name == "stream_feature_view" else None + ), + on_demand_feature_view=( + feature_view.to_proto() + if arg_name == "on_demand_feature_view" + else None + ), + project=project, + commit=commit, + ) + + self.stub.ApplyFeatureView(request) + + def delete_feature_view(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteFeatureViewRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteFeatureView(request) + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> StreamFeatureView: + request = RegistryServer_pb2.GetStreamFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetStreamFeatureView(request) + return StreamFeatureView.from_proto(response) + + def list_stream_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[StreamFeatureView]: + request = RegistryServer_pb2.ListStreamFeatureViewsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListStreamFeatureViews(request) + return [ + StreamFeatureView.from_proto(stream_feature_view) + for stream_feature_view in response.stream_feature_views + ] + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + request = RegistryServer_pb2.GetOnDemandFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetOnDemandFeatureView(request) + return OnDemandFeatureView.from_proto(response) + + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + request = RegistryServer_pb2.ListOnDemandFeatureViewsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListOnDemandFeatureViews(request) + return [ + OnDemandFeatureView.from_proto(on_demand_feature_view) + for on_demand_feature_view in response.on_demand_feature_views + ] + + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + request = RegistryServer_pb2.GetAnyFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + + response: RegistryServer_pb2.GetAnyFeatureViewResponse = ( + self.stub.GetAnyFeatureView(request) + ) + any_feature_view = response.any_feature_view + return 
extract_base_feature_view(any_feature_view) + + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + request = RegistryServer_pb2.ListAllFeatureViewsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + + response: RegistryServer_pb2.ListAllFeatureViewsResponse = ( + self.stub.ListAllFeatureViews(request) + ) + return [ + extract_base_feature_view(any_feature_view) + for any_feature_view in response.feature_views + ] + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + request = RegistryServer_pb2.GetFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetFeatureView(request) + return FeatureView.from_proto(response) + + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + request = RegistryServer_pb2.ListFeatureViewsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListFeatureViews(request) + + return [ + FeatureView.from_proto(feature_view) + for feature_view in response.feature_views + ] + + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + start_date_timestamp = Timestamp() + end_date_timestamp = Timestamp() + start_date_timestamp.FromDatetime(start_date) + end_date_timestamp.FromDatetime(end_date) + + # TODO: for this to work for stream feature views, ApplyMaterializationRequest needs to be updated + request = RegistryServer_pb2.ApplyMaterializationRequest( + feature_view=feature_view.to_proto(), + project=project, + start_date=start_date_timestamp, + end_date=end_date_timestamp, + commit=commit, + ) + self.stub.ApplyMaterialization(request) + + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + request = RegistryServer_pb2.ApplySavedDatasetRequest( + saved_dataset=saved_dataset.to_proto(), project=project, commit=commit + ) + self.stub.ApplyFeatureService(request) + + def delete_saved_dataset(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteSavedDatasetRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteSavedDataset(request) + + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + request = RegistryServer_pb2.GetSavedDatasetRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetSavedDataset(request) + return SavedDataset.from_proto(response) + + def list_saved_datasets( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[SavedDataset]: + request = RegistryServer_pb2.ListSavedDatasetsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListSavedDatasets(request) + return [ + SavedDataset.from_proto(saved_dataset) + for saved_dataset in response.saved_datasets + ] + + def apply_validation_reference( + self, + validation_reference: ValidationReference, + project: str, + commit: bool = True, + ): + request = RegistryServer_pb2.ApplyValidationReferenceRequest( + validation_reference=validation_reference.to_proto(), + project=project, + commit=commit, + ) + self.stub.ApplyValidationReference(request) + + def 
delete_validation_reference(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteValidationReferenceRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteValidationReference(request) + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + request = RegistryServer_pb2.GetValidationReferenceRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetValidationReference(request) + return ValidationReference.from_proto(response) + + def list_validation_references( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[ValidationReference]: + request = RegistryServer_pb2.ListValidationReferencesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListValidationReferences(request) + return [ + ValidationReference.from_proto(validation_reference) + for validation_reference in response.validation_references + ] + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + request = RegistryServer_pb2.ListProjectMetadataRequest( + project=project, allow_cache=allow_cache + ) + response = self.stub.ListProjectMetadata(request) + return [ProjectMetadata.from_proto(pm) for pm in response.project_metadata] + + def update_infra(self, infra: Infra, project: str, commit: bool = True): + request = RegistryServer_pb2.UpdateInfraRequest( + infra=infra.to_proto(), project=project, commit=commit + ) + self.stub.UpdateInfra(request) + + def get_infra(self, project: str, allow_cache: bool = False) -> Infra: + request = RegistryServer_pb2.GetInfraRequest( + project=project, allow_cache=allow_cache + ) + response = self.stub.GetInfra(request) + return Infra.from_proto(response) + + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + pass + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + pass + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + permission_proto = permission.to_proto() + permission_proto.spec.project = project + + request = RegistryServer_pb2.ApplyPermissionRequest( + permission=permission_proto, project=project, commit=commit + ) + self.stub.ApplyPermission(request) + + def delete_permission(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeletePermissionRequest( + name=name, project=project, commit=commit + ) + self.stub.DeletePermission(request) + + def get_permission( + self, name: str, project: str, allow_cache: bool = False + ) -> Permission: + request = RegistryServer_pb2.GetPermissionRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetPermission(request) + + return Permission.from_proto(response) + + def list_permissions( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Permission]: + request = RegistryServer_pb2.ListPermissionsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListPermissions(request) + return [ + Permission.from_proto(permission) for permission in response.permissions + ] + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + project_proto = project.to_proto() + + request = RegistryServer_pb2.ApplyProjectRequest( + 
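+            # Illustrative end-to-end use of this client (values hypothetical):
+            #   config = RemoteRegistryConfig(path="localhost:6570")
+            #   registry = RemoteRegistry(config, project="demo", repo_path=None)
+            #   registry.apply_project(Project(name="demo"))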
project=project_proto, commit=commit + ) + self.stub.ApplyProject(request) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + request = RegistryServer_pb2.DeleteProjectRequest(name=name, commit=commit) + self.stub.DeleteProject(request) + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + request = RegistryServer_pb2.GetProjectRequest( + name=name, allow_cache=allow_cache + ) + response = self.stub.GetProject(request) + + return Project.from_proto(response) + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + request = RegistryServer_pb2.ListProjectsRequest( + allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListProjects(request) + return [Project.from_proto(project) for project in response.projects] + + def proto(self) -> RegistryProto: + return self.stub.Proto(Empty()) + + def commit(self): + self.stub.Commit(Empty()) + + def refresh(self, project: Optional[str] = None): + request = RegistryServer_pb2.RefreshRequest(project=str(project)) + self.stub.Refresh(request) + + def teardown(self): + pass diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 06403fe9aee..71a8aa067ec 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -1,1375 +1,1375 @@ -import logging -import os -import uuid -from binascii import hexlify -from datetime import datetime, timedelta, timezone -from enum import Enum -from threading import Lock -from typing import Any, Callable, List, Literal, Optional, Union, cast - -from pydantic import ConfigDict, Field, StrictStr - -import feast -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.errors import ( - DataSourceObjectNotFoundException, - EntityNotFoundException, - FeatureServiceNotFoundException, - FeatureViewNotFoundException, - PermissionNotFoundException, - ProjectNotFoundException, - ProjectObjectNotFoundException, - SavedDatasetNotFound, - ValidationReferenceNotFound, -) -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object import Infra -from feast.infra.registry import proto_registry_utils -from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.utils.snowflake.snowflake_utils import ( - GetSnowflakeConnection, - execute_snowflake_statement, -) -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto -from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto -from feast.protos.feast.core.FeatureService_pb2 import ( - FeatureService as FeatureServiceProto, -) -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto -from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto -from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( - OnDemandFeatureView as OnDemandFeatureViewProto, -) -from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto -from feast.protos.feast.core.Project_pb2 import Project as ProjectProto -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from 
feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto -from feast.protos.feast.core.StreamFeatureView_pb2 import ( - StreamFeatureView as StreamFeatureViewProto, -) -from feast.protos.feast.core.ValidationProfile_pb2 import ( - ValidationReference as ValidationReferenceProto, -) -from feast.repo_config import RegistryConfig -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.utils import _utc_now, has_all_tags - -logger = logging.getLogger(__name__) - - -class FeastMetadataKeys(Enum): - LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" - PROJECT_UUID = "project_uuid" - - -class SnowflakeRegistryConfig(RegistryConfig): - """Registry config for Snowflake""" - - registry_type: Literal["snowflake.registry"] = "snowflake.registry" - """ Registry type selector """ - - type: Literal["snowflake.registry"] = "snowflake.registry" - """ Registry type selector """ - - config_path: Optional[str] = os.path.expanduser("~/.snowsql/config") - """ Snowflake snowsql config path -- absolute path required (Cant use ~)""" - - connection_name: Optional[str] = None - """ Snowflake connector connection name -- typically defined in ~/.snowflake/connections.toml """ - - account: Optional[str] = None - """ Snowflake deployment identifier -- drop .snowflakecomputing.com """ - - user: Optional[str] = None - """ Snowflake user name """ - - password: Optional[str] = None - """ Snowflake password """ - - role: Optional[str] = None - """ Snowflake role name """ - - warehouse: Optional[str] = None - """ Snowflake warehouse name """ - - authenticator: Optional[str] = None - """ Snowflake authenticator name """ - - private_key: Optional[str] = None - """ Snowflake private key file path""" - - private_key_content: Optional[bytes] = None - """ Snowflake private key stored as bytes""" - - private_key_passphrase: Optional[str] = None - """ Snowflake private key file passphrase""" - - database: StrictStr - """ Snowflake database name """ - - schema_: Optional[str] = Field("PUBLIC", alias="schema") - """ Snowflake schema name """ - model_config = ConfigDict(populate_by_name=True) - - -class SnowflakeRegistry(BaseRegistry): - def __init__( - self, - registry_config, - project: str, - repo_path, - ): - assert registry_config is not None and isinstance( - registry_config, SnowflakeRegistryConfig - ), "SnowflakeRegistry needs a valid registry_config, a path does not work" - - self.registry_config = registry_config - self.registry_path = ( - f'"{self.registry_config.database}"."{self.registry_config.schema_}"' - ) - - with GetSnowflakeConnection(self.registry_config) as conn: - sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql" - with open(sql_function_file, "r") as file: - sql_file = file.read() - sql_cmds = sql_file.split(";") - for command in sql_cmds: - query = command.replace("REGISTRY_PATH", f"{self.registry_path}") - execute_snowflake_statement(conn, query) - - self.purge_feast_metadata = registry_config.purge_feast_metadata - self._sync_feast_metadata_to_projects_table() - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() - self._refresh_lock = Lock() - self.cached_registry_proto_ttl = timedelta( - seconds=( - registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 - ) - ) - self.project = 
project - - def _sync_feast_metadata_to_projects_table(self): - feast_metadata_projects: set = [] - projects_set: set = [] - - with GetSnowflakeConnection(self.registry_config) as conn: - query = ( - f'SELECT DISTINCT project_id FROM {self.registry_path}."FEAST_METADATA"' - ) - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - for row in df.iterrows(): - feast_metadata_projects.add(row[1]["PROJECT_ID"]) - - if len(feast_metadata_projects) > 0: - with GetSnowflakeConnection(self.registry_config) as conn: - query = f'SELECT project_id FROM {self.registry_path}."PROJECTS"' - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - for row in df.iterrows(): - projects_set.add(row[1]["PROJECT_ID"]) - - # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects) - set(projects_set) - for project_name in projects_to_sync: - self.apply_project(Project(name=project_name), commit=True) - - if self.purge_feast_metadata: - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - DELETE FROM {self.registry_path}."FEAST_METADATA" - """ - execute_snowflake_statement(conn, query) - - def refresh(self, project: Optional[str] = None): - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() - - def _refresh_cached_registry_if_necessary(self): - with self._refresh_lock: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl - ) - ) - ) - - if expired: - logger.info("Registry cache expired, so refreshing") - self.refresh() - - def teardown(self): - with GetSnowflakeConnection(self.registry_config) as conn: - sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" - with open(sql_function_file, "r") as file: - sqlFile = file.read() - sqlCommands = sqlFile.split(";") - for command in sqlCommands: - query = command.replace("REGISTRY_PATH", f"{self.registry_path}") - execute_snowflake_statement(conn, query) - - # apply operations - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - return self._apply_object( - "DATA_SOURCES", - project, - "DATA_SOURCE_NAME", - data_source, - "DATA_SOURCE_PROTO", - ) - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - return self._apply_object( - "ENTITIES", project, "ENTITY_NAME", entity, "ENTITY_PROTO" - ) - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - return self._apply_object( - "FEATURE_SERVICES", - project, - "FEATURE_SERVICE_NAME", - feature_service, - "FEATURE_SERVICE_PROTO", - ) - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1] - return self._apply_object( - fv_table_str, - project, - f"{fv_column_name}_NAME", - feature_view, - f"{fv_column_name}_PROTO", - ) - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - return self._apply_object( - "SAVED_DATASETS", - project, - "SAVED_DATASET_NAME", - saved_dataset, - "SAVED_DATASET_PROTO", - ) - - def apply_validation_reference( - self, - validation_reference: 
ValidationReference, - project: str, - commit: bool = True, - ): - return self._apply_object( - "VALIDATION_REFERENCES", - project, - "VALIDATION_REFERENCE_NAME", - validation_reference, - "VALIDATION_REFERENCE_PROTO", - ) - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - self._apply_object( - "MANAGED_INFRA", - project, - "INFRA_NAME", - infra, - "INFRA_PROTO", - name="infra_obj", - ) - - def _initialize_project_if_not_exists(self, project_name: str): - try: - self.get_project(project_name, allow_cache=True) - return - except ProjectObjectNotFoundException: - try: - self.get_project(project_name, allow_cache=False) - return - except ProjectObjectNotFoundException: - self.apply_project(Project(name=project_name), commit=True) - - def _apply_object( - self, - table: str, - project: str, - id_field_name: str, - obj: Any, - proto_field_name: str, - name: Optional[str] = None, - ): - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option - if not isinstance(obj, Project): - self._initialize_project_if_not_exists(project_name=project) - - name = name or (obj.name if hasattr(obj, "name") else None) - assert name, f"name needs to be provided for {obj}" - - update_datetime = _utc_now() - if hasattr(obj, "last_updated_timestamp"): - obj.last_updated_timestamp = update_datetime - - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - project_id - FROM - {self.registry_path}."{table}" - WHERE - project_id = '{project}' - AND {id_field_name.lower()} = '{name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - proto = hexlify(obj.to_proto().SerializeToString()).__str__()[1:] - query = f""" - UPDATE {self.registry_path}."{table}" - SET - {proto_field_name} = TO_BINARY({proto}), - last_updated_timestamp = CURRENT_TIMESTAMP() - WHERE - {id_field_name.lower()} = '{name}' - """ - execute_snowflake_statement(conn, query) - - else: - obj_proto = obj.to_proto() - - if hasattr(obj_proto, "meta") and hasattr( - obj_proto.meta, "created_timestamp" - ): - obj_proto.meta.created_timestamp.FromDatetime(update_datetime) - - proto = hexlify(obj_proto.SerializeToString()).__str__()[1:] - if table == "FEATURE_VIEWS": - query = f""" - INSERT INTO {self.registry_path}."{table}" - VALUES - ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '', '') - """ - elif "_FEATURE_VIEWS" in table: - query = f""" - INSERT INTO {self.registry_path}."{table}" - VALUES - ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '') - """ - else: - query = f""" - INSERT INTO {self.registry_path}."{table}" - VALUES - ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto})) - """ - execute_snowflake_statement(conn, query) - - if not isinstance(obj, Project): - self.apply_project( - self.get_project(name=project, allow_cache=False), commit=True - ) - - if not self.purge_feast_metadata: - self._set_last_updated_metadata(update_datetime, project) - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - return self._apply_object( - "PERMISSIONS", - project, - "PERMISSION_NAME", - permission, - "PERMISSION_PROTO", - ) - - # delete operations - def delete_data_source(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "DATA_SOURCES", - name, - project, - "DATA_SOURCE_NAME", - 
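-            # Generic delete pattern: each object type maps to a (table,
-            # id column, not-found exception) triple, and _delete_object
-            # raises the exception when no row matched the name and project.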
DataSourceObjectNotFoundException, - ) - - def delete_entity(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "ENTITIES", name, project, "ENTITY_NAME", EntityNotFoundException - ) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "FEATURE_SERVICES", - name, - project, - "FEATURE_SERVICE_NAME", - FeatureServiceNotFoundException, - ) - - # can you have featureviews with the same name - def delete_feature_view(self, name: str, project: str, commit: bool = True): - deleted_count = 0 - for table in { - "FEATURE_VIEWS", - "ON_DEMAND_FEATURE_VIEWS", - "STREAM_FEATURE_VIEWS", - }: - deleted_count += self._delete_object( - table, name, project, "FEATURE_VIEW_NAME", None - ) - if deleted_count == 0: - raise FeatureViewNotFoundException(name, project) - - def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): - self._delete_object( - "SAVED_DATASETS", - name, - project, - "SAVED_DATASET_NAME", - SavedDatasetNotFound, - ) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - self._delete_object( - "VALIDATION_REFERENCES", - name, - project, - "VALIDATION_REFERENCE_NAME", - ValidationReferenceNotFound, - ) - - def _delete_object( - self, - table: str, - name: str, - project: str, - id_field_name: str, - not_found_exception: Optional[Callable], - ): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - DELETE FROM {self.registry_path}."{table}" - WHERE - project_id = '{project}' - AND {id_field_name.lower()} = '{name}' - """ - cursor = execute_snowflake_statement(conn, query) - - if cursor.rowcount < 1 and not_found_exception: # type: ignore - raise not_found_exception(name, project) - self._set_last_updated_metadata(_utc_now(), project) - - return cursor.rowcount - - def delete_permission(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "PERMISSIONS", - name, - project, - "PERMISSION_NAME", - PermissionNotFoundException, - ) - - # get operations - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_data_source( - self.cached_registry_proto, name, project - ) - return self._get_object( - "DATA_SOURCES", - name, - project, - DataSourceProto, - DataSource, - "DATA_SOURCE_NAME", - "DATA_SOURCE_PROTO", - DataSourceObjectNotFoundException, - ) - - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_entity( - self.cached_registry_proto, name, project - ) - return self._get_object( - "ENTITIES", - name, - project, - EntityProto, - Entity, - "ENTITY_NAME", - "ENTITY_PROTO", - EntityNotFoundException, - ) - - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_service( - self.cached_registry_proto, name, project - ) - return self._get_object( - "FEATURE_SERVICES", - name, - project, - FeatureServiceProto, - FeatureService, - "FEATURE_SERVICE_NAME", - "FEATURE_SERVICE_PROTO", - FeatureServiceNotFoundException, - ) - - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - if allow_cache: - 
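-            # Cached read path: refresh the in-memory registry proto when the
-            # configured TTL has lapsed (a TTL of 0 means the cache never
-            # expires), then serve the lookup from the cached proto.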
self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_view( - self.cached_registry_proto, name, project - ) - return self._get_object( - "FEATURE_VIEWS", - name, - project, - FeatureViewProto, - FeatureView, - "FEATURE_VIEW_NAME", - "FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_any_feature_view( - self.cached_registry_proto, name, project - ) - fv = self._get_object( - "FEATURE_VIEWS", - name, - project, - FeatureViewProto, - FeatureView, - "FEATURE_VIEW_NAME", - "FEATURE_VIEW_PROTO", - None, - ) - - if not fv: - fv = self._get_object( - "STREAM_FEATURE_VIEWS", - name, - project, - StreamFeatureViewProto, - StreamFeatureView, - "STREAM_FEATURE_VIEW_NAME", - "STREAM_FEATURE_VIEW_PROTO", - None, - ) - if not fv: - fv = self._get_object( - "ON_DEMAND_FEATURE_VIEWS", - name, - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "ON_DEMAND_FEATURE_VIEW_NAME", - "ON_DEMAND_FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - return fv - - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_all_feature_views( - self.cached_registry_proto, project, tags - ) - - return ( - cast( - list[BaseFeatureView], - self.list_feature_views(project, allow_cache, tags), - ) - + cast( - list[BaseFeatureView], - self.list_stream_feature_views(project, allow_cache, tags), - ) - + cast( - list[BaseFeatureView], - self.list_on_demand_feature_views(project, allow_cache, tags), - ) - ) - - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - infra_object = self._get_object( - "MANAGED_INFRA", - "infra_obj", - project, - InfraProto, - Infra, - "INFRA_NAME", - "INFRA_PROTO", - None, - ) - infra_object = infra_object or InfraProto() - return Infra.from_proto(infra_object) - - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_on_demand_feature_view( - self.cached_registry_proto, name, project - ) - return self._get_object( - "ON_DEMAND_FEATURE_VIEWS", - name, - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "ON_DEMAND_FEATURE_VIEW_NAME", - "ON_DEMAND_FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_saved_dataset( - self.cached_registry_proto, name, project - ) - return self._get_object( - "SAVED_DATASETS", - name, - project, - SavedDatasetProto, - SavedDataset, - "SAVED_DATASET_NAME", - "SAVED_DATASET_PROTO", - SavedDatasetNotFound, - ) - - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ): - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_stream_feature_view( - self.cached_registry_proto, name, project - ) - return self._get_object( - "STREAM_FEATURE_VIEWS", - name, - project, - StreamFeatureViewProto, - StreamFeatureView, - "STREAM_FEATURE_VIEW_NAME", - 
"STREAM_FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_validation_reference( - self.cached_registry_proto, name, project - ) - return self._get_object( - "VALIDATION_REFERENCES", - name, - project, - ValidationReferenceProto, - ValidationReference, - "VALIDATION_REFERENCE_NAME", - "VALIDATION_REFERENCE_PROTO", - ValidationReferenceNotFound, - ) - - def _get_object( - self, - table: str, - name: str, - project: str, - proto_class: Any, - python_class: Any, - id_field_name: str, - proto_field_name: str, - not_found_exception: Optional[Callable], - ): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - {proto_field_name} - FROM - {self.registry_path}."{table}" - WHERE - project_id = '{project}' - AND {id_field_name.lower()} = '{name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - _proto = proto_class.FromString(df.squeeze()) - return python_class.from_proto(_proto) - elif not_found_exception: - raise not_found_exception(name, project) - else: - return None - - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_permission( - self.cached_registry_proto, name, project - ) - return self._get_object( - "PERMISSIONS", - name, - project, - PermissionProto, - Permission, - "PERMISSION_NAME", - "PERMISSION_PROTO", - PermissionNotFoundException, - ) - - # list operations - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_data_sources( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "DATA_SOURCES", - project, - DataSourceProto, - DataSource, - "DATA_SOURCE_PROTO", - tags=tags, - ) - - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_entities( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO", tags=tags - ) - - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_feature_services( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "FEATURE_SERVICES", - project, - FeatureServiceProto, - FeatureService, - "FEATURE_SERVICE_PROTO", - tags=tags, - ) - - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_feature_views( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "FEATURE_VIEWS", - project, - FeatureViewProto, - FeatureView, - "FEATURE_VIEW_PROTO", - tags=tags, - ) - - def list_on_demand_feature_views( - self, - 
project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_on_demand_feature_views( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "ON_DEMAND_FEATURE_VIEWS", - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "ON_DEMAND_FEATURE_VIEW_PROTO", - tags=tags, - ) - - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_saved_datasets( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "SAVED_DATASETS", - project, - SavedDatasetProto, - SavedDataset, - "SAVED_DATASET_PROTO", - tags=tags, - ) - - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_stream_feature_views( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "STREAM_FEATURE_VIEWS", - project, - StreamFeatureViewProto, - StreamFeatureView, - "STREAM_FEATURE_VIEW_PROTO", - tags=tags, - ) - - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - return self._list_objects( - "VALIDATION_REFERENCES", - project, - ValidationReferenceProto, - ValidationReference, - "VALIDATION_REFERENCE_PROTO", - tags=tags, - ) - - def _list_objects( - self, - table: str, - project: str, - proto_class: Any, - python_class: Any, - proto_field_name: str, - tags: Optional[dict[str, str]] = None, - ): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - {proto_field_name} - FROM - {self.registry_path}."{table}" - WHERE - project_id = '{project}' - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - if not df.empty: - objects = [] - for row in df.iterrows(): - obj = python_class.from_proto( - proto_class.FromString(row[1][proto_field_name]) - ) - if has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_permissions( - self.cached_registry_proto, project - ) - return self._list_objects( - "PERMISSIONS", - project, - PermissionProto, - Permission, - "PERMISSION_PROTO", - tags, - ) - - def apply_materialization( - self, - feature_view: FeatureView, - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1] - python_class, proto_class = self._infer_fv_classes(feature_view) - - if python_class in {OnDemandFeatureView}: - raise ValueError( - f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" - ) - fv: Union[FeatureView, StreamFeatureView] = self._get_object( - fv_table_str, - feature_view.name, - project, - proto_class, - python_class, - f"{fv_column_name}_NAME", - f"{fv_column_name}_PROTO", - FeatureViewNotFoundException, - ) - 
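-        # Bookkeeping sketch: each materialization run appends its
-        # (start_date, end_date) window to materialization_intervals, e.g.
-        #   fv.materialization_intervals == [(t0, t1), (t1, t2), ...]
-        # most_recent_end_time is derived from these windows when
-        # materialize_incremental picks the next start date.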
fv.materialization_intervals.append((start_date, end_date)) - self._apply_object( - fv_table_str, - project, - f"{fv_column_name}_NAME", - fv, - f"{fv_column_name}_PROTO", - ) - - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_project_metadata( - self.cached_registry_proto, project - ) - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - metadata_key, - metadata_value - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - project_metadata = ProjectMetadata(project_name=project) - for row in df.iterrows(): - if row[1]["METADATA_KEY"] == FeastMetadataKeys.PROJECT_UUID.value: - project_metadata.project_uuid = row[1]["METADATA_VALUE"] - break - # TODO(adchia): Add other project metadata in a structured way - return [project_metadata] - return [] - - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1].lower() - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - project_id - FROM - {self.registry_path}."{fv_table_str}" - WHERE - project_id = '{project}' - AND {fv_column_name}_name = '{feature_view.name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - if metadata_bytes: - metadata_hex = hexlify(metadata_bytes).__str__()[1:] - else: - metadata_hex = "''" - query = f""" - UPDATE {self.registry_path}."{fv_table_str}" - SET - user_metadata = TO_BINARY({metadata_hex}), - last_updated_timestamp = CURRENT_TIMESTAMP() - WHERE - project_id = '{project}' - AND {fv_column_name}_name = '{feature_view.name}' - """ - execute_snowflake_statement(conn, query) - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1].lower() - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - user_metadata - FROM - {self.registry_path}."{fv_table_str}" - WHERE - {fv_column_name}_name = '{feature_view.name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - return df.squeeze() - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def proto(self) -> RegistryProto: - r = RegistryProto() - last_updated_timestamps = [] - - def process_project(project: Project): - nonlocal r, last_updated_timestamps - project_name = project.name - last_updated_timestamp = project.last_updated_timestamp - - try: - cached_project = self.get_project(project_name, True) - except ProjectObjectNotFoundException: - cached_project = None - - allow_cache = False - - if cached_project is not None: - allow_cache = ( - last_updated_timestamp <= cached_project.last_updated_timestamp - ) - - r.projects.extend([project.to_proto()]) - last_updated_timestamps.append(last_updated_timestamp) - - for lister, registry_proto_field in [ - (self.list_entities, r.entities), - (self.list_feature_views, r.feature_views), - (self.list_data_sources, r.data_sources), - 
(self.list_on_demand_feature_views, r.on_demand_feature_views), - (self.list_stream_feature_views, r.stream_feature_views), - (self.list_feature_services, r.feature_services), - (self.list_saved_datasets, r.saved_datasets), - (self.list_validation_references, r.validation_references), - (self.list_permissions, r.permissions), - ]: - objs: List[Any] = lister(project_name, allow_cache) # type: ignore - if objs: - obj_protos = [obj.to_proto() for obj in objs] - for obj_proto in obj_protos: - if "spec" in obj_proto.DESCRIPTOR.fields_by_name: - obj_proto.spec.project = project_name - else: - obj_proto.project = project_name - registry_proto_field.extend(obj_protos) - - # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, - # the registry proto only has a single infra field, which we're currently setting as the "last" project. - r.infra.CopyFrom(self.get_infra(project_name).to_proto()) - - projects_list = self.list_projects(allow_cache=False) - for project in projects_list: - process_project(project) - - if last_updated_timestamps: - r.last_updated.FromDatetime(max(last_updated_timestamps)) - - return r - - def _get_last_updated_metadata(self, project: str): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - metadata_value - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if df.empty: - return None - - return datetime.fromtimestamp(int(df.squeeze()), tz=timezone.utc) - - def _infer_fv_classes(self, feature_view): - if isinstance(feature_view, StreamFeatureView): - python_class, proto_class = StreamFeatureView, StreamFeatureViewProto - elif isinstance(feature_view, FeatureView): - python_class, proto_class = FeatureView, FeatureViewProto - elif isinstance(feature_view, OnDemandFeatureView): - python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return python_class, proto_class - - def _infer_fv_table(self, feature_view) -> str: - if isinstance(feature_view, StreamFeatureView): - table = "STREAM_FEATURE_VIEWS" - elif isinstance(feature_view, FeatureView): - table = "FEATURE_VIEWS" - elif isinstance(feature_view, OnDemandFeatureView): - table = "ON_DEMAND_FEATURE_VIEWS" - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return table - - def _maybe_init_project_metadata(self, project): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - metadata_value - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.PROJECT_UUID.value}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if df.empty: - new_project_uuid = f"{uuid.uuid4()}" - query = f""" - INSERT INTO {self.registry_path}."FEAST_METADATA" - VALUES - ('{project}', '{FeastMetadataKeys.PROJECT_UUID.value}', '{new_project_uuid}', CURRENT_TIMESTAMP()) - """ - execute_snowflake_statement(conn, query) - - def _set_last_updated_metadata(self, last_updated: datetime, project: str): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - project_id - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' - 
LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - update_time = int(last_updated.timestamp()) - if not df.empty: - query = f""" - UPDATE {self.registry_path}."FEAST_METADATA" - SET - project_id = '{project}', - metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', - metadata_value = '{update_time}', - last_updated_timestamp = CURRENT_TIMESTAMP() - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' - """ - execute_snowflake_statement(conn, query) - - else: - query = f""" - INSERT INTO {self.registry_path}."FEAST_METADATA" - VALUES - ('{project}', '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', '{update_time}', CURRENT_TIMESTAMP()) - """ - execute_snowflake_statement(conn, query) - - def commit(self): - pass - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - return self._apply_object( - "PROJECTS", project.name, "project_name", project, "project_proto" - ) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - project = self.get_project(name, allow_cache=False) - if project: - with GetSnowflakeConnection(self.registry_config) as conn: - for table in { - "MANAGED_INFRA", - "SAVED_DATASETS", - "VALIDATION_REFERENCES", - "FEATURE_SERVICES", - "FEATURE_VIEWS", - "ON_DEMAND_FEATURE_VIEWS", - "STREAM_FEATURE_VIEWS", - "DATA_SOURCES", - "ENTITIES", - "PERMISSIONS", - "FEAST_METADATA", - "PROJECTS", - }: - query = f""" - DELETE FROM {self.registry_path}."{table}" - WHERE - project_id = '{project}' - """ - execute_snowflake_statement(conn, query) - return - - raise ProjectNotFoundException(name) - - def _get_project( - self, - name: str, - ) -> Project: - return self._get_object( - table="PROJECTS", - name=name, - project=name, - proto_class=ProjectProto, - python_class=Project, - id_field_name="project_name", - proto_field_name="project_proto", - not_found_exception=ProjectObjectNotFoundException, - ) - - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_project(self.cached_registry_proto, name) - return self._get_project(name) - - def _list_projects( - self, - tags: Optional[dict[str, str]], - ) -> List[Project]: - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT project_proto FROM {self.registry_path}."PROJECTS" - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - if not df.empty: - objects = [] - for row in df.iterrows(): - obj = Project.from_proto( - ProjectProto.FromString(row[1]["project_proto"]) - ) - if has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_projects(self.cached_registry_proto, tags) - return self._list_projects(tags) +import logging +import os +import uuid +from binascii import hexlify +from datetime import datetime, timedelta, timezone +from enum import Enum +from threading import Lock +from typing import Any, Callable, List, Literal, Optional, Union, cast + +from pydantic import ConfigDict, Field, StrictStr + +import feast +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.errors import ( + 
DataSourceObjectNotFoundException, + EntityNotFoundException, + FeatureServiceNotFoundException, + FeatureViewNotFoundException, + PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, + SavedDatasetNotFound, + ValidationReferenceNotFound, +) +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry import proto_registry_utils +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, + execute_snowflake_statement, +) +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto +from feast.protos.feast.core.FeatureService_pb2 import ( + FeatureService as FeatureServiceProto, +) +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto +from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( + OnDemandFeatureView as OnDemandFeatureViewProto, +) +from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) +from feast.protos.feast.core.ValidationProfile_pb2 import ( + ValidationReference as ValidationReferenceProto, +) +from feast.repo_config import RegistryConfig +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.utils import _utc_now, has_all_tags + +logger = logging.getLogger(__name__) + + +class FeastMetadataKeys(Enum): + LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" + PROJECT_UUID = "project_uuid" + + +class SnowflakeRegistryConfig(RegistryConfig): + """Registry config for Snowflake""" + + registry_type: Literal["snowflake.registry"] = "snowflake.registry" + """ Registry type selector """ + + type: Literal["snowflake.registry"] = "snowflake.registry" + """ Registry type selector """ + + config_path: Optional[str] = os.path.expanduser("~/.snowsql/config") + """ Snowflake snowsql config path -- absolute path required (Can't use ~)""" + + connection_name: Optional[str] = None + """ Snowflake connector connection name -- typically defined in ~/.snowflake/connections.toml """ + + account: Optional[str] = None + """ Snowflake deployment identifier -- drop .snowflakecomputing.com """ + + user: Optional[str] = None + """ Snowflake user name """ + + password: Optional[str] = None + """ Snowflake password """ + + role: Optional[str] = None + """ Snowflake role name """ + + warehouse: Optional[str] = None + """ Snowflake warehouse name """ + + authenticator: Optional[str] = None + """ Snowflake authenticator name """ + + private_key: Optional[str] = None + """ Snowflake private key file path""" + + private_key_content: Optional[bytes] = None + """ Snowflake private key stored as bytes""" + + private_key_passphrase: Optional[str] = 
None + """ Snowflake private key file passphrase""" + + database: StrictStr + """ Snowflake database name """ + + schema_: Optional[str] = Field("PUBLIC", alias="schema") + """ Snowflake schema name """ + model_config = ConfigDict(populate_by_name=True) + + +class SnowflakeRegistry(BaseRegistry): + def __init__( + self, + registry_config, + project: str, + repo_path, + ): + assert registry_config is not None and isinstance( + registry_config, SnowflakeRegistryConfig + ), "SnowflakeRegistry needs a valid SnowflakeRegistryConfig; a bare registry path does not work" + + self.registry_config = registry_config + self.registry_path = ( + f'"{self.registry_config.database}"."{self.registry_config.schema_}"' + ) + + with GetSnowflakeConnection(self.registry_config) as conn: + sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql" + with open(sql_function_file, "r") as file: + sql_file = file.read() + sql_cmds = sql_file.split(";") + for command in sql_cmds: + query = command.replace("REGISTRY_PATH", f"{self.registry_path}") + execute_snowflake_statement(conn, query) + + self.purge_feast_metadata = registry_config.purge_feast_metadata + self._sync_feast_metadata_to_projects_table() + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) + + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = _utc_now() + self._refresh_lock = Lock() + self.cached_registry_proto_ttl = timedelta( + seconds=( + registry_config.cache_ttl_seconds + if registry_config.cache_ttl_seconds is not None + else 0 + ) + ) + self.project = project + + def _sync_feast_metadata_to_projects_table(self): + feast_metadata_projects: set = set() + projects_set: set = set() + + with GetSnowflakeConnection(self.registry_config) as conn: + query = ( + f'SELECT DISTINCT project_id FROM {self.registry_path}."FEAST_METADATA"' + ) + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + for row in df.iterrows(): + feast_metadata_projects.add(row[1]["PROJECT_ID"]) + + if len(feast_metadata_projects) > 0: + with GetSnowflakeConnection(self.registry_config) as conn: + query = f'SELECT project_id FROM {self.registry_path}."PROJECTS"' + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + for row in df.iterrows(): + projects_set.add(row[1]["PROJECT_ID"]) + + # Sync projects that exist in FEAST_METADATA but are missing from the PROJECTS table + projects_to_sync = feast_metadata_projects - projects_set + for project_name in projects_to_sync: + self.apply_project(Project(name=project_name), commit=True) + + if self.purge_feast_metadata: + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + DELETE FROM {self.registry_path}."FEAST_METADATA" + """ + execute_snowflake_statement(conn, query) + + def refresh(self, project: Optional[str] = None): + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = _utc_now() + + def _refresh_cached_registry_if_necessary(self): + with self._refresh_lock: + expired = ( + self.cached_registry_proto is None + or self.cached_registry_proto_created is None + ) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + _utc_now() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) + ) + ) + + if expired: + logger.info("Registry cache expired, so refreshing") + self.refresh() + + def teardown(self): + with GetSnowflakeConnection(self.registry_config) as conn: + sql_function_file = 
f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" + with open(sql_function_file, "r") as file: + sqlFile = file.read() + sqlCommands = sqlFile.split(";") + for command in sqlCommands: + query = command.replace("REGISTRY_PATH", f"{self.registry_path}") + execute_snowflake_statement(conn, query) + + # apply operations + def apply_data_source( + self, data_source: DataSource, project: str, commit: bool = True + ): + return self._apply_object( + "DATA_SOURCES", + project, + "DATA_SOURCE_NAME", + data_source, + "DATA_SOURCE_PROTO", + ) + + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + return self._apply_object( + "ENTITIES", project, "ENTITY_NAME", entity, "ENTITY_PROTO" + ) + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + return self._apply_object( + "FEATURE_SERVICES", + project, + "FEATURE_SERVICE_NAME", + feature_service, + "FEATURE_SERVICE_PROTO", + ) + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1] + return self._apply_object( + fv_table_str, + project, + f"{fv_column_name}_NAME", + feature_view, + f"{fv_column_name}_PROTO", + ) + + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + return self._apply_object( + "SAVED_DATASETS", + project, + "SAVED_DATASET_NAME", + saved_dataset, + "SAVED_DATASET_PROTO", + ) + + def apply_validation_reference( + self, + validation_reference: ValidationReference, + project: str, + commit: bool = True, + ): + return self._apply_object( + "VALIDATION_REFERENCES", + project, + "VALIDATION_REFERENCE_NAME", + validation_reference, + "VALIDATION_REFERENCE_PROTO", + ) + + def update_infra(self, infra: Infra, project: str, commit: bool = True): + self._apply_object( + "MANAGED_INFRA", + project, + "INFRA_NAME", + infra, + "INFRA_PROTO", + name="infra_obj", + ) + + def _initialize_project_if_not_exists(self, project_name: str): + try: + self.get_project(project_name, allow_cache=True) + return + except ProjectObjectNotFoundException: + try: + self.get_project(project_name, allow_cache=False) + return + except ProjectObjectNotFoundException: + self.apply_project(Project(name=project_name), commit=True) + + def _apply_object( + self, + table: str, + project: str, + id_field_name: str, + obj: Any, + proto_field_name: str, + name: Optional[str] = None, + ): + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) + # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option + if not isinstance(obj, Project): + self._initialize_project_if_not_exists(project_name=project) + + name = name or (obj.name if hasattr(obj, "name") else None) + assert name, f"name needs to be provided for {obj}" + + update_datetime = _utc_now() + if hasattr(obj, "last_updated_timestamp"): + obj.last_updated_timestamp = update_datetime + + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + AND {id_field_name.lower()} = '{name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + proto = hexlify(obj.to_proto().SerializeToString()).__str__()[1:] + query = f""" + UPDATE 
{self.registry_path}."{table}" + SET + {proto_field_name} = TO_BINARY({proto}), + last_updated_timestamp = CURRENT_TIMESTAMP() + WHERE + {id_field_name.lower()} = '{name}' + """ + execute_snowflake_statement(conn, query) + + else: + obj_proto = obj.to_proto() + + if hasattr(obj_proto, "meta") and hasattr( + obj_proto.meta, "created_timestamp" + ): + obj_proto.meta.created_timestamp.FromDatetime(update_datetime) + + proto = hexlify(obj_proto.SerializeToString()).__str__()[1:] + if table == "FEATURE_VIEWS": + query = f""" + INSERT INTO {self.registry_path}."{table}" + VALUES + ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '', '') + """ + elif "_FEATURE_VIEWS" in table: + query = f""" + INSERT INTO {self.registry_path}."{table}" + VALUES + ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '') + """ + else: + query = f""" + INSERT INTO {self.registry_path}."{table}" + VALUES + ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto})) + """ + execute_snowflake_statement(conn, query) + + if not isinstance(obj, Project): + self.apply_project( + self.get_project(name=project, allow_cache=False), commit=True + ) + + if not self.purge_feast_metadata: + self._set_last_updated_metadata(update_datetime, project) + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + return self._apply_object( + "PERMISSIONS", + project, + "PERMISSION_NAME", + permission, + "PERMISSION_PROTO", + ) + + # delete operations + def delete_data_source(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "DATA_SOURCES", + name, + project, + "DATA_SOURCE_NAME", + DataSourceObjectNotFoundException, + ) + + def delete_entity(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "ENTITIES", name, project, "ENTITY_NAME", EntityNotFoundException + ) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "FEATURE_SERVICES", + name, + project, + "FEATURE_SERVICE_NAME", + FeatureServiceNotFoundException, + ) + + # can you have featureviews with the same name + def delete_feature_view(self, name: str, project: str, commit: bool = True): + deleted_count = 0 + for table in { + "FEATURE_VIEWS", + "ON_DEMAND_FEATURE_VIEWS", + "STREAM_FEATURE_VIEWS", + }: + deleted_count += self._delete_object( + table, name, project, "FEATURE_VIEW_NAME", None + ) + if deleted_count == 0: + raise FeatureViewNotFoundException(name, project) + + def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): + self._delete_object( + "SAVED_DATASETS", + name, + project, + "SAVED_DATASET_NAME", + SavedDatasetNotFound, + ) + + def delete_validation_reference(self, name: str, project: str, commit: bool = True): + self._delete_object( + "VALIDATION_REFERENCES", + name, + project, + "VALIDATION_REFERENCE_NAME", + ValidationReferenceNotFound, + ) + + def _delete_object( + self, + table: str, + name: str, + project: str, + id_field_name: str, + not_found_exception: Optional[Callable], + ): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + DELETE FROM {self.registry_path}."{table}" + WHERE + project_id = '{project}' + AND {id_field_name.lower()} = '{name}' + """ + cursor = execute_snowflake_statement(conn, query) + + if cursor.rowcount < 1 and not_found_exception: # type: ignore + raise not_found_exception(name, project) + self._set_last_updated_metadata(_utc_now(), project) + + return cursor.rowcount + + def 
delete_permission(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "PERMISSIONS", + name, + project, + "PERMISSION_NAME", + PermissionNotFoundException, + ) + + # get operations + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_data_source( + self.cached_registry_proto, name, project + ) + return self._get_object( + "DATA_SOURCES", + name, + project, + DataSourceProto, + DataSource, + "DATA_SOURCE_NAME", + "DATA_SOURCE_PROTO", + DataSourceObjectNotFoundException, + ) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_entity( + self.cached_registry_proto, name, project + ) + return self._get_object( + "ENTITIES", + name, + project, + EntityProto, + Entity, + "ENTITY_NAME", + "ENTITY_PROTO", + EntityNotFoundException, + ) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_service( + self.cached_registry_proto, name, project + ) + return self._get_object( + "FEATURE_SERVICES", + name, + project, + FeatureServiceProto, + FeatureService, + "FEATURE_SERVICE_NAME", + "FEATURE_SERVICE_PROTO", + FeatureServiceNotFoundException, + ) + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "FEATURE_VIEWS", + name, + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_NAME", + "FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_any_feature_view( + self.cached_registry_proto, name, project + ) + fv = self._get_object( + "FEATURE_VIEWS", + name, + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_NAME", + "FEATURE_VIEW_PROTO", + None, + ) + + if not fv: + fv = self._get_object( + "STREAM_FEATURE_VIEWS", + name, + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_NAME", + "STREAM_FEATURE_VIEW_PROTO", + None, + ) + if not fv: + fv = self._get_object( + "ON_DEMAND_FEATURE_VIEWS", + name, + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_NAME", + "ON_DEMAND_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + return fv + + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_all_feature_views( + self.cached_registry_proto, project, tags + ) + + return ( + cast( + list[BaseFeatureView], + self.list_feature_views(project, allow_cache, tags), + ) + + cast( + list[BaseFeatureView], + self.list_stream_feature_views(project, allow_cache, tags), + ) + + cast( + list[BaseFeatureView], + self.list_on_demand_feature_views(project, allow_cache, tags), + ) + ) + + def get_infra(self, project: str, 
allow_cache: bool = False) -> Infra: + infra_object = self._get_object( + "MANAGED_INFRA", + "infra_obj", + project, + InfraProto, + Infra, + "INFRA_NAME", + "INFRA_PROTO", + None, + ) + infra_object = infra_object or InfraProto() + return Infra.from_proto(infra_object) + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_on_demand_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "ON_DEMAND_FEATURE_VIEWS", + name, + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_NAME", + "ON_DEMAND_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_saved_dataset( + self.cached_registry_proto, name, project + ) + return self._get_object( + "SAVED_DATASETS", + name, + project, + SavedDatasetProto, + SavedDataset, + "SAVED_DATASET_NAME", + "SAVED_DATASET_PROTO", + SavedDatasetNotFound, + ) + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ): + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_stream_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "STREAM_FEATURE_VIEWS", + name, + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_NAME", + "STREAM_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_validation_reference( + self.cached_registry_proto, name, project + ) + return self._get_object( + "VALIDATION_REFERENCES", + name, + project, + ValidationReferenceProto, + ValidationReference, + "VALIDATION_REFERENCE_NAME", + "VALIDATION_REFERENCE_PROTO", + ValidationReferenceNotFound, + ) + + def _get_object( + self, + table: str, + name: str, + project: str, + proto_class: Any, + python_class: Any, + id_field_name: str, + proto_field_name: str, + not_found_exception: Optional[Callable], + ): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + {proto_field_name} + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + AND {id_field_name.lower()} = '{name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + _proto = proto_class.FromString(df.squeeze()) + return python_class.from_proto(_proto) + elif not_found_exception: + raise not_found_exception(name, project) + else: + return None + + def get_permission( + self, name: str, project: str, allow_cache: bool = False + ) -> Permission: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_permission( + self.cached_registry_proto, name, project + ) + return self._get_object( + "PERMISSIONS", + name, + project, + PermissionProto, + Permission, + "PERMISSION_NAME", + "PERMISSION_PROTO", + PermissionNotFoundException, + ) + + # list operations + def list_data_sources( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[DataSource]: + 
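+ # Editorial note: the get_*/list_* accessors in this class all share one
+ # read-through pattern -- with allow_cache=True, the TTL-guarded cache is
+ # refreshed if stale and the request is served from the cached registry
+ # proto; otherwise the lookup queries Snowflake directly. A minimal usage
+ # sketch (variable names are illustrative, not part of this module):
+ #
+ #   registry = SnowflakeRegistry(config, project="my_project", repo_path=None)
+ #   cached = registry.list_data_sources("my_project", allow_cache=True)
+ #   fresh = registry.list_data_sources("my_project", allow_cache=False)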
if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_data_sources( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "DATA_SOURCES", + project, + DataSourceProto, + DataSource, + "DATA_SOURCE_PROTO", + tags=tags, + ) + + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_entities( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO", tags=tags + ) + + def list_feature_services( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureService]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_services( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "FEATURE_SERVICES", + project, + FeatureServiceProto, + FeatureService, + "FEATURE_SERVICE_PROTO", + tags=tags, + ) + + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_views( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "FEATURE_VIEWS", + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_PROTO", + tags=tags, + ) + + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_on_demand_feature_views( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "ON_DEMAND_FEATURE_VIEWS", + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_PROTO", + tags=tags, + ) + + def list_saved_datasets( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[SavedDataset]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_saved_datasets( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "SAVED_DATASETS", + project, + SavedDatasetProto, + SavedDataset, + "SAVED_DATASET_PROTO", + tags=tags, + ) + + def list_stream_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[StreamFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_stream_feature_views( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "STREAM_FEATURE_VIEWS", + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_PROTO", + tags=tags, + ) + + def list_validation_references( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[ValidationReference]: + return self._list_objects( + "VALIDATION_REFERENCES", + project, + ValidationReferenceProto, + ValidationReference, + "VALIDATION_REFERENCE_PROTO", + tags=tags, + ) + + def _list_objects( + self, + table: str, + project: str, + proto_class: Any, + python_class: Any, + proto_field_name: str, + tags: Optional[dict[str, str]] = 
None, + ): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + {proto_field_name} + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + if not df.empty: + objects = [] + for row in df.iterrows(): + obj = python_class.from_proto( + proto_class.FromString(row[1][proto_field_name]) + ) + if has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def list_permissions( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Permission]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_permissions( + self.cached_registry_proto, project + ) + return self._list_objects( + "PERMISSIONS", + project, + PermissionProto, + Permission, + "PERMISSION_PROTO", + tags, + ) + + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1] + python_class, proto_class = self._infer_fv_classes(feature_view) + + if python_class in {OnDemandFeatureView}: + raise ValueError( + f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" + ) + fv: Union[FeatureView, StreamFeatureView] = self._get_object( + fv_table_str, + feature_view.name, + project, + proto_class, + python_class, + f"{fv_column_name}_NAME", + f"{fv_column_name}_PROTO", + FeatureViewNotFoundException, + ) + fv.materialization_intervals.append((start_date, end_date)) + self._apply_object( + fv_table_str, + project, + f"{fv_column_name}_NAME", + fv, + f"{fv_column_name}_PROTO", + ) + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_project_metadata( + self.cached_registry_proto, project + ) + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_key, + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + project_metadata = ProjectMetadata(project_name=project) + for row in df.iterrows(): + if row[1]["METADATA_KEY"] == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata.project_uuid = row[1]["METADATA_VALUE"] + break + # TODO(adchia): Add other project metadata in a structured way + return [project_metadata] + return [] + + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1].lower() + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."{fv_table_str}" + WHERE + project_id = '{project}' + AND {fv_column_name}_name = '{feature_view.name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + if metadata_bytes: + metadata_hex = hexlify(metadata_bytes).__str__()[1:] + else: + metadata_hex = "''" + query = f""" + UPDATE {self.registry_path}."{fv_table_str}" + SET + user_metadata = TO_BINARY({metadata_hex}), + last_updated_timestamp = 
CURRENT_TIMESTAMP() + WHERE + project_id = '{project}' + AND {fv_column_name}_name = '{feature_view.name}' + """ + execute_snowflake_statement(conn, query) + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1].lower() + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + user_metadata + FROM + {self.registry_path}."{fv_table_str}" + WHERE + {fv_column_name}_name = '{feature_view.name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + return df.squeeze() + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def proto(self) -> RegistryProto: + r = RegistryProto() + last_updated_timestamps = [] + + def process_project(project: Project): + nonlocal r, last_updated_timestamps + project_name = project.name + last_updated_timestamp = project.last_updated_timestamp + + try: + cached_project = self.get_project(project_name, True) + except ProjectObjectNotFoundException: + cached_project = None + + allow_cache = False + + if cached_project is not None: + allow_cache = ( + last_updated_timestamp <= cached_project.last_updated_timestamp + ) + + r.projects.extend([project.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) + + for lister, registry_proto_field in [ + (self.list_entities, r.entities), + (self.list_feature_views, r.feature_views), + (self.list_data_sources, r.data_sources), + (self.list_on_demand_feature_views, r.on_demand_feature_views), + (self.list_stream_feature_views, r.stream_feature_views), + (self.list_feature_services, r.feature_services), + (self.list_saved_datasets, r.saved_datasets), + (self.list_validation_references, r.validation_references), + (self.list_permissions, r.permissions), + ]: + objs: List[Any] = lister(project_name, allow_cache) # type: ignore + if objs: + obj_protos = [obj.to_proto() for obj in objs] + for obj_proto in obj_protos: + if "spec" in obj_proto.DESCRIPTOR.fields_by_name: + obj_proto.spec.project = project_name + else: + obj_proto.project = project_name + registry_proto_field.extend(obj_protos) + + # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, + # the registry proto only has a single infra field, which we're currently setting as the "last" project. 
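+ # Editorial sketch: RegistryProto.infra is a singular field, so the CopyFrom
+ # below overwrites it on every process_project call; after the loop over
+ # list_projects(), only the last project's infra survives. Illustratively:
+ #
+ #   for p in ["proj_a", "proj_b"]:
+ #       r.infra.CopyFrom(self.get_infra(p).to_proto())
+ #   # r.infra now describes "proj_b" only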
+ r.infra.CopyFrom(self.get_infra(project_name).to_proto()) + + projects_list = self.list_projects(allow_cache=False) + for project in projects_list: + process_project(project) + + if last_updated_timestamps: + r.last_updated.FromDatetime(max(last_updated_timestamps)) + + return r + + def _get_last_updated_metadata(self, project: str): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if df.empty: + return None + + return datetime.fromtimestamp(int(df.squeeze()), tz=timezone.utc) + + def _infer_fv_classes(self, feature_view): + if isinstance(feature_view, StreamFeatureView): + python_class, proto_class = StreamFeatureView, StreamFeatureViewProto + elif isinstance(feature_view, FeatureView): + python_class, proto_class = FeatureView, FeatureViewProto + elif isinstance(feature_view, OnDemandFeatureView): + python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return python_class, proto_class + + def _infer_fv_table(self, feature_view) -> str: + if isinstance(feature_view, StreamFeatureView): + table = "STREAM_FEATURE_VIEWS" + elif isinstance(feature_view, FeatureView): + table = "FEATURE_VIEWS" + elif isinstance(feature_view, OnDemandFeatureView): + table = "ON_DEMAND_FEATURE_VIEWS" + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return table + + def _maybe_init_project_metadata(self, project): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.PROJECT_UUID.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if df.empty: + new_project_uuid = f"{uuid.uuid4()}" + query = f""" + INSERT INTO {self.registry_path}."FEAST_METADATA" + VALUES + ('{project}', '{FeastMetadataKeys.PROJECT_UUID.value}', '{new_project_uuid}', CURRENT_TIMESTAMP()) + """ + execute_snowflake_statement(conn, query) + + def _set_last_updated_metadata(self, last_updated: datetime, project: str): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + update_time = int(last_updated.timestamp()) + if not df.empty: + query = f""" + UPDATE {self.registry_path}."FEAST_METADATA" + SET + project_id = '{project}', + metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', + metadata_value = '{update_time}', + last_updated_timestamp = CURRENT_TIMESTAMP() + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + """ + execute_snowflake_statement(conn, query) + + else: + query = f""" + INSERT INTO {self.registry_path}."FEAST_METADATA" + VALUES + ('{project}', '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', '{update_time}', CURRENT_TIMESTAMP()) + """ + execute_snowflake_statement(conn, query) + + def commit(self): + pass + + def apply_project( + self, + 
project: Project, + commit: bool = True, + ): + return self._apply_object( + "PROJECTS", project.name, "project_name", project, "project_proto" + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + project = self.get_project(name, allow_cache=False) + if project: + with GetSnowflakeConnection(self.registry_config) as conn: + for table in { + "MANAGED_INFRA", + "SAVED_DATASETS", + "VALIDATION_REFERENCES", + "FEATURE_SERVICES", + "FEATURE_VIEWS", + "ON_DEMAND_FEATURE_VIEWS", + "STREAM_FEATURE_VIEWS", + "DATA_SOURCES", + "ENTITIES", + "PERMISSIONS", + "FEAST_METADATA", + "PROJECTS", + }: + query = f""" + DELETE FROM {self.registry_path}."{table}" + WHERE + project_id = '{name}' + """ + execute_snowflake_statement(conn, query) + return + + raise ProjectNotFoundException(name) + + def _get_project( + self, + name: str, + ) -> Project: + return self._get_object( + table="PROJECTS", + name=name, + project=name, + proto_class=ProjectProto, + python_class=Project, + id_field_name="project_name", + proto_field_name="project_proto", + not_found_exception=ProjectObjectNotFoundException, + ) + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_project(self.cached_registry_proto, name) + return self._get_project(name) + + def _list_projects( + self, + tags: Optional[dict[str, str]], + ) -> List[Project]: + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT project_proto FROM {self.registry_path}."PROJECTS" + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + if not df.empty: + objects = [] + for row in df.iterrows(): + obj = Project.from_proto( + ProjectProto.FromString(row[1]["PROJECT_PROTO"]) + ) + if has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_projects(self.cached_registry_proto, tags) + return self._list_projects(tags) diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index c42e6e8b82b..36b8174d2b6 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -1,1258 +1,1258 @@ -import logging -import uuid -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timezone -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, cast - -from pydantic import StrictInt, StrictStr -from sqlalchemy import ( # type: ignore - BigInteger, - Column, - Index, - LargeBinary, - MetaData, - String, - Table, - create_engine, - delete, - insert, - select, - update, -) -from sqlalchemy.engine import Engine - -from feast import utils -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.errors import ( - DataSourceObjectNotFoundException, - EntityNotFoundException, - FeatureServiceNotFoundException, - FeatureViewNotFoundException, - PermissionNotFoundException, - ProjectNotFoundException, - ProjectObjectNotFoundException, - SavedDatasetNotFound, - ValidationReferenceNotFound, -) -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object 
import Infra -from feast.infra.registry.caching_registry import CachingRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto -from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto -from feast.protos.feast.core.FeatureService_pb2 import ( - FeatureService as FeatureServiceProto, -) -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto -from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto -from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( - OnDemandFeatureView as OnDemandFeatureViewProto, -) -from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto -from feast.protos.feast.core.Project_pb2 import Project as ProjectProto -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto -from feast.protos.feast.core.StreamFeatureView_pb2 import ( - StreamFeatureView as StreamFeatureViewProto, -) -from feast.protos.feast.core.ValidationProfile_pb2 import ( - ValidationReference as ValidationReferenceProto, -) -from feast.repo_config import RegistryConfig -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.utils import _utc_now - -metadata = MetaData() - - -projects = Table( - "projects", - metadata, - Column("project_id", String(255), primary_key=True), - Column("project_name", String(255), nullable=False), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("project_proto", LargeBinary, nullable=False), -) - -Index("idx_projects_project_id", projects.c.project_id) - -entities = Table( - "entities", - metadata, - Column("entity_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("entity_proto", LargeBinary, nullable=False), -) - -Index("idx_entities_project_id", entities.c.project_id) - -data_sources = Table( - "data_sources", - metadata, - Column("data_source_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("data_source_proto", LargeBinary, nullable=False), -) - -Index("idx_data_sources_project_id", data_sources.c.project_id) - -feature_views = Table( - "feature_views", - metadata, - Column("feature_view_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("materialized_intervals", LargeBinary, nullable=True), - Column("feature_view_proto", LargeBinary, nullable=False), - Column("user_metadata", LargeBinary, nullable=True), -) - -Index("idx_feature_views_project_id", feature_views.c.project_id) - -stream_feature_views = Table( - "stream_feature_views", - metadata, - Column("feature_view_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("feature_view_proto", LargeBinary, nullable=False), - Column("user_metadata", LargeBinary, nullable=True), -) - -Index("idx_stream_feature_views_project_id", 
stream_feature_views.c.project_id) - -on_demand_feature_views = Table( - "on_demand_feature_views", - metadata, - Column("feature_view_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("feature_view_proto", LargeBinary, nullable=False), - Column("user_metadata", LargeBinary, nullable=True), -) - -Index("idx_on_demand_feature_views_project_id", on_demand_feature_views.c.project_id) - -feature_services = Table( - "feature_services", - metadata, - Column("feature_service_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("feature_service_proto", LargeBinary, nullable=False), -) - -Index("idx_feature_services_project_id", feature_services.c.project_id) - -saved_datasets = Table( - "saved_datasets", - metadata, - Column("saved_dataset_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("saved_dataset_proto", LargeBinary, nullable=False), -) - -Index("idx_saved_datasets_project_id", saved_datasets.c.project_id) - -validation_references = Table( - "validation_references", - metadata, - Column("validation_reference_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("validation_reference_proto", LargeBinary, nullable=False), -) -Index("idx_validation_references_project_id", validation_references.c.project_id) - -managed_infra = Table( - "managed_infra", - metadata, - Column("infra_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("infra_proto", LargeBinary, nullable=False), -) - -Index("idx_managed_infra_project_id", managed_infra.c.project_id) - -permissions = Table( - "permissions", - metadata, - Column("permission_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("permission_proto", LargeBinary, nullable=False), -) - -Index("idx_permissions_project_id", permissions.c.project_id) - - -class FeastMetadataKeys(Enum): - LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" - PROJECT_UUID = "project_uuid" - - -feast_metadata = Table( - "feast_metadata", - metadata, - Column("project_id", String(255), primary_key=True), - Column("metadata_key", String(50), primary_key=True), - Column("metadata_value", String(50), nullable=False), - Column("last_updated_timestamp", BigInteger, nullable=False), -) - -Index("idx_feast_metadata_project_id", feast_metadata.c.project_id) - -logger = logging.getLogger(__name__) - - -class SqlRegistryConfig(RegistryConfig): - registry_type: StrictStr = "sql" - """ str: Provider name or a class name that implements Registry.""" - - path: StrictStr = "" - """ str: Path to metadata store. - If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """ - - read_path: Optional[StrictStr] = None - """ str: Read Path to metadata store if different from path. - If registry_type is 'sql', then this is a Read Endpoint for database URL. If not set, path will be used for read and write. 
""" - - sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False} - """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ - - cache_mode: StrictStr = "sync" - """ str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)""" - - thread_pool_executor_worker_count: StrictInt = 0 - """ int: Number of worker threads to use for asynchronous caching in SQL Registry. If set to 0, it doesn't use ThreadPoolExecutor. """ - - -class SqlRegistry(CachingRegistry): - def __init__( - self, - registry_config, - project: str, - repo_path: Optional[Path], - ): - assert registry_config is not None and isinstance( - registry_config, SqlRegistryConfig - ), "SqlRegistry needs a valid registry_config" - - self.registry_config = registry_config - - self.write_engine: Engine = create_engine( - registry_config.path, **registry_config.sqlalchemy_config_kwargs - ) - if registry_config.read_path: - self.read_engine: Engine = create_engine( - registry_config.read_path, - **registry_config.sqlalchemy_config_kwargs, - ) - else: - self.read_engine = self.write_engine - metadata.create_all(self.write_engine) - self.thread_pool_executor_worker_count = ( - registry_config.thread_pool_executor_worker_count - ) - self.purge_feast_metadata = registry_config.purge_feast_metadata - # Sync feast_metadata to projects table - # when purge_feast_metadata is set to True, Delete data from - # feast_metadata table and list_project_metadata will not return any data - self._sync_feast_metadata_to_projects_table() - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - super().__init__( - project=project, - cache_ttl_seconds=registry_config.cache_ttl_seconds, - cache_mode=registry_config.cache_mode, - ) - - def _sync_feast_metadata_to_projects_table(self): - feast_metadata_projects: dict = {} - projects_set: set = [] - with self.read_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value - ) - rows = conn.execute(stmt).all() - for row in rows: - feast_metadata_projects[row._mapping["project_id"]] = int( - row._mapping["last_updated_timestamp"] - ) - - if len(feast_metadata_projects) > 0: - with self.read_engine.begin() as conn: - stmt = select(projects) - rows = conn.execute(stmt).all() - for row in rows: - projects_set.append(row._mapping["project_id"]) - - # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects.keys()) - set(projects_set) - for project_name in projects_to_sync: - self.apply_project( - Project( - name=project_name, - created_timestamp=datetime.fromtimestamp( - feast_metadata_projects[project_name], tz=timezone.utc - ), - ), - commit=True, - ) - - if self.purge_feast_metadata: - with self.write_engine.begin() as conn: - for project_name in feast_metadata_projects: - stmt = delete(feast_metadata).where( - feast_metadata.c.project_id == project_name - ) - conn.execute(stmt) - - def teardown(self): - for t in { - entities, - data_sources, - feature_views, - feature_services, - on_demand_feature_views, - saved_datasets, - validation_references, - permissions, - }: - with self.write_engine.begin() as conn: - stmt = delete(t) - conn.execute(stmt) - - def _get_stream_feature_view(self, name: str, project: str): - return self._get_object( - table=stream_feature_views, - name=name, - project=project, - proto_class=StreamFeatureViewProto, - python_class=StreamFeatureView, - 
id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - - def _list_stream_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[StreamFeatureView]: - return self._list_objects( - stream_feature_views, - project, - StreamFeatureViewProto, - StreamFeatureView, - "feature_view_proto", - tags=tags, - ) - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - return self._apply_object( - table=entities, - project=project, - id_field_name="entity_name", - obj=entity, - proto_field_name="entity_proto", - ) - - def _get_entity(self, name: str, project: str) -> Entity: - return self._get_object( - table=entities, - name=name, - project=project, - proto_class=EntityProto, - python_class=Entity, - id_field_name="entity_name", - proto_field_name="entity_proto", - not_found_exception=EntityNotFoundException, - ) - - def _get_any_feature_view(self, name: str, project: str) -> BaseFeatureView: - fv = self._get_object( - table=feature_views, - name=name, - project=project, - proto_class=FeatureViewProto, - python_class=FeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=None, - ) - - if not fv: - fv = self._get_object( - table=on_demand_feature_views, - name=name, - project=project, - proto_class=OnDemandFeatureViewProto, - python_class=OnDemandFeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=None, - ) - - if not fv: - fv = self._get_object( - table=stream_feature_views, - name=name, - project=project, - proto_class=StreamFeatureViewProto, - python_class=StreamFeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - return fv - - def _list_all_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[BaseFeatureView]: - return ( - cast( - list[BaseFeatureView], - self._list_feature_views(project=project, tags=tags), - ) - + cast( - list[BaseFeatureView], - self._list_stream_feature_views(project=project, tags=tags), - ) - + cast( - list[BaseFeatureView], - self._list_on_demand_feature_views(project=project, tags=tags), - ) - ) - - def _get_feature_view(self, name: str, project: str) -> FeatureView: - return self._get_object( - table=feature_views, - name=name, - project=project, - proto_class=FeatureViewProto, - python_class=FeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - - def _get_on_demand_feature_view( - self, name: str, project: str - ) -> OnDemandFeatureView: - return self._get_object( - table=on_demand_feature_views, - name=name, - project=project, - proto_class=OnDemandFeatureViewProto, - python_class=OnDemandFeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - - def _get_feature_service(self, name: str, project: str) -> FeatureService: - return self._get_object( - table=feature_services, - name=name, - project=project, - proto_class=FeatureServiceProto, - python_class=FeatureService, - id_field_name="feature_service_name", - proto_field_name="feature_service_proto", - not_found_exception=FeatureServiceNotFoundException, - ) - - def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: - return self._get_object( - 
table=saved_datasets, - name=name, - project=project, - proto_class=SavedDatasetProto, - python_class=SavedDataset, - id_field_name="saved_dataset_name", - proto_field_name="saved_dataset_proto", - not_found_exception=SavedDatasetNotFound, - ) - - def _get_validation_reference(self, name: str, project: str) -> ValidationReference: - return self._get_object( - table=validation_references, - name=name, - project=project, - proto_class=ValidationReferenceProto, - python_class=ValidationReference, - id_field_name="validation_reference_name", - proto_field_name="validation_reference_proto", - not_found_exception=ValidationReferenceNotFound, - ) - - def _list_validation_references( - self, project: str, tags: Optional[dict[str, str]] = None - ) -> List[ValidationReference]: - return self._list_objects( - table=validation_references, - project=project, - proto_class=ValidationReferenceProto, - python_class=ValidationReference, - proto_field_name="validation_reference_proto", - tags=tags, - ) - - def _list_entities( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[Entity]: - return self._list_objects( - entities, project, EntityProto, Entity, "entity_proto", tags=tags - ) - - def delete_entity(self, name: str, project: str, commit: bool = True): - return self._delete_object( - entities, name, project, "entity_name", EntityNotFoundException - ) - - def delete_feature_view(self, name: str, project: str, commit: bool = True): - deleted_count = 0 - for table in { - feature_views, - on_demand_feature_views, - stream_feature_views, - }: - deleted_count += self._delete_object( - table, name, project, "feature_view_name", None - ) - if deleted_count == 0: - raise FeatureViewNotFoundException(name, project) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - return self._delete_object( - feature_services, - name, - project, - "feature_service_name", - FeatureServiceNotFoundException, - ) - - def _get_data_source(self, name: str, project: str) -> DataSource: - return self._get_object( - table=data_sources, - name=name, - project=project, - proto_class=DataSourceProto, - python_class=DataSource, - id_field_name="data_source_name", - proto_field_name="data_source_proto", - not_found_exception=DataSourceObjectNotFoundException, - ) - - def _list_data_sources( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[DataSource]: - return self._list_objects( - data_sources, - project, - DataSourceProto, - DataSource, - "data_source_proto", - tags=tags, - ) - - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - return self._apply_object( - data_sources, project, "data_source_name", data_source, "data_source_proto" - ) - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - fv_table = self._infer_fv_table(feature_view) - - return self._apply_object( - fv_table, project, "feature_view_name", feature_view, "feature_view_proto" - ) - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - return self._apply_object( - feature_services, - project, - "feature_service_name", - feature_service, - "feature_service_proto", - ) - - def delete_data_source(self, name: str, project: str, commit: bool = True): - with self.write_engine.begin() as conn: - stmt = delete(data_sources).where( - data_sources.c.data_source_name == name, - data_sources.c.project_id == project, - ) - rows = conn.execute(stmt) - if rows.rowcount < 
1: - raise DataSourceObjectNotFoundException(name, project) - - def _list_feature_services( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[FeatureService]: - return self._list_objects( - feature_services, - project, - FeatureServiceProto, - FeatureService, - "feature_service_proto", - tags=tags, - ) - - def _list_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[FeatureView]: - return self._list_objects( - feature_views, - project, - FeatureViewProto, - FeatureView, - "feature_view_proto", - tags=tags, - ) - - def _list_saved_datasets( - self, project: str, tags: Optional[dict[str, str]] = None - ) -> List[SavedDataset]: - return self._list_objects( - saved_datasets, - project, - SavedDatasetProto, - SavedDataset, - "saved_dataset_proto", - tags=tags, - ) - - def _list_on_demand_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[OnDemandFeatureView]: - return self._list_objects( - on_demand_feature_views, - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "feature_view_proto", - tags=tags, - ) - - def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: - with self.read_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.project_id == project, - ) - rows = conn.execute(stmt).all() - if rows: - project_metadata = ProjectMetadata(project_name=project) - for row in rows: - if ( - row._mapping["metadata_key"] - == FeastMetadataKeys.PROJECT_UUID.value - ): - project_metadata.project_uuid = row._mapping["metadata_value"] - break - # TODO(adchia): Add other project metadata in a structured way - return [project_metadata] - return [] - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - return self._apply_object( - saved_datasets, - project, - "saved_dataset_name", - saved_dataset, - "saved_dataset_proto", - ) - - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - return self._apply_object( - validation_references, - project, - "validation_reference_name", - validation_reference, - "validation_reference_proto", - ) - - def apply_materialization( - self, - feature_view: FeatureView, - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - table = self._infer_fv_table(feature_view) - python_class, proto_class = self._infer_fv_classes(feature_view) - - if python_class in {OnDemandFeatureView}: - raise ValueError( - f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" - ) - fv: Union[FeatureView, StreamFeatureView] = self._get_object( - table, - feature_view.name, - project, - proto_class, - python_class, - "feature_view_name", - "feature_view_proto", - FeatureViewNotFoundException, - ) - fv.materialization_intervals.append((start_date, end_date)) - self._apply_object( - table, project, "feature_view_name", fv, "feature_view_proto" - ) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - self._delete_object( - validation_references, - name, - project, - "validation_reference_name", - ValidationReferenceNotFound, - ) - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - self._apply_object( - table=managed_infra, - project=project, - id_field_name="infra_name", - obj=infra, - proto_field_name="infra_proto", - name="infra_obj", - ) - - def _get_infra(self, project: str) -> Infra: - infra_object = 
self._get_object( - table=managed_infra, - name="infra_obj", - project=project, - proto_class=InfraProto, - python_class=Infra, - id_field_name="infra_name", - proto_field_name="infra_proto", - not_found_exception=None, - ) - if infra_object: - return infra_object - return Infra() - - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - table = self._infer_fv_table(feature_view) - - name = feature_view.name - with self.write_engine.begin() as conn: - stmt = select(table).where( - getattr(table.c, "feature_view_name") == name, - table.c.project_id == project, - ) - row = conn.execute(stmt).first() - update_datetime = _utc_now() - update_time = int(update_datetime.timestamp()) - if row: - values = { - "user_metadata": metadata_bytes, - "last_updated_timestamp": update_time, - } - update_stmt = ( - update(table) - .where( - getattr(table.c, "feature_view_name") == name, - table.c.project_id == project, - ) - .values( - values, - ) - ) - conn.execute(update_stmt) - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def _infer_fv_table(self, feature_view): - if isinstance(feature_view, StreamFeatureView): - table = stream_feature_views - elif isinstance(feature_view, FeatureView): - table = feature_views - elif isinstance(feature_view, OnDemandFeatureView): - table = on_demand_feature_views - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return table - - def _infer_fv_classes(self, feature_view): - if isinstance(feature_view, StreamFeatureView): - python_class, proto_class = StreamFeatureView, StreamFeatureViewProto - elif isinstance(feature_view, FeatureView): - python_class, proto_class = FeatureView, FeatureViewProto - elif isinstance(feature_view, OnDemandFeatureView): - python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return python_class, proto_class - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - table = self._infer_fv_table(feature_view) - - name = feature_view.name - with self.read_engine.begin() as conn: - stmt = select(table).where(getattr(table.c, "feature_view_name") == name) - row = conn.execute(stmt).first() - if row: - return row._mapping["user_metadata"] - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def proto(self) -> RegistryProto: - r = RegistryProto() - last_updated_timestamps = [] - - def process_project(project: Project): - nonlocal r, last_updated_timestamps - project_name = project.name - last_updated_timestamp = project.last_updated_timestamp - - try: - cached_project = self.get_project(project_name, True) - except ProjectObjectNotFoundException: - cached_project = None - - allow_cache = False - - if cached_project is not None: - allow_cache = ( - last_updated_timestamp <= cached_project.last_updated_timestamp - ) - - r.projects.extend([project.to_proto()]) - last_updated_timestamps.append(last_updated_timestamp) - - for lister, registry_proto_field in [ - (self.list_entities, r.entities), - (self.list_feature_views, r.feature_views), - (self.list_data_sources, r.data_sources), - (self.list_on_demand_feature_views, r.on_demand_feature_views), - (self.list_stream_feature_views, r.stream_feature_views), - (self.list_feature_services, r.feature_services), - (self.list_saved_datasets, r.saved_datasets), - 
(self.list_validation_references, r.validation_references), - (self.list_permissions, r.permissions), - ]: - objs: List[Any] = lister(project_name, allow_cache) # type: ignore - if objs: - obj_protos = [obj.to_proto() for obj in objs] - for obj_proto in obj_protos: - if "spec" in obj_proto.DESCRIPTOR.fields_by_name: - obj_proto.spec.project = project_name - else: - obj_proto.project = project_name - registry_proto_field.extend(obj_protos) - - # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, - # the registry proto only has a single infra field, which we're currently setting as the "last" project. - r.infra.CopyFrom(self.get_infra(project_name).to_proto()) - - projects_list = self.list_projects(allow_cache=False) - if self.thread_pool_executor_worker_count == 0: - for project in projects_list: - process_project(project) - else: - with ThreadPoolExecutor( - max_workers=self.thread_pool_executor_worker_count - ) as executor: - executor.map(process_project, projects_list) - - if last_updated_timestamps: - r.last_updated.FromDatetime(max(last_updated_timestamps)) - - return r - - def commit(self): - # This method is a no-op since we're always writing values eagerly to the db. - pass - - def _initialize_project_if_not_exists(self, project_name: str): - try: - self.get_project(project_name, allow_cache=True) - return - except ProjectObjectNotFoundException: - try: - self.get_project(project_name, allow_cache=False) - return - except ProjectObjectNotFoundException: - self.apply_project(Project(name=project_name), commit=True) - - def _apply_object( - self, - table: Table, - project: str, - id_field_name: str, - obj: Any, - proto_field_name: str, - name: Optional[str] = None, - ): - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option - if not isinstance(obj, Project): - self._initialize_project_if_not_exists(project_name=project) - name = name or (obj.name if hasattr(obj, "name") else None) - assert name, f"name needs to be provided for {obj}" - - with self.write_engine.begin() as conn: - update_datetime = _utc_now() - update_time = int(update_datetime.timestamp()) - stmt = select(table).where( - getattr(table.c, id_field_name) == name, table.c.project_id == project - ) - row = conn.execute(stmt).first() - if hasattr(obj, "last_updated_timestamp"): - obj.last_updated_timestamp = update_datetime - - if row: - if proto_field_name in [ - "entity_proto", - "saved_dataset_proto", - "feature_view_proto", - "feature_service_proto", - "permission_proto", - "project_proto", - ]: - deserialized_proto = self.deserialize_registry_values( - row._mapping[proto_field_name], type(obj) - ) - obj.created_timestamp = ( - deserialized_proto.meta.created_timestamp.ToDatetime().replace( - tzinfo=timezone.utc - ) - ) - if isinstance(obj, (FeatureView, StreamFeatureView)): - obj.update_materialization_intervals( - type(obj) - .from_proto(deserialized_proto) - .materialization_intervals - ) - values = { - proto_field_name: obj.to_proto().SerializeToString(), - "last_updated_timestamp": update_time, - } - update_stmt = ( - update(table) - .where( - getattr(table.c, id_field_name) == name, - table.c.project_id == project, - ) - .values( - values, - ) - ) - conn.execute(update_stmt) - else: - obj_proto = obj.to_proto() - - if hasattr(obj_proto, "meta") and hasattr( - obj_proto.meta, "created_timestamp" - ): - if not 
obj_proto.meta.HasField("created_timestamp"): - obj_proto.meta.created_timestamp.FromDatetime(update_datetime) - - values = { - id_field_name: name, - proto_field_name: obj_proto.SerializeToString(), - "last_updated_timestamp": update_time, - "project_id": project, - } - insert_stmt = insert(table).values( - values, - ) - conn.execute(insert_stmt) - - if not isinstance(obj, Project): - self.apply_project( - self.get_project(name=project, allow_cache=False), commit=True - ) - if not self.purge_feast_metadata: - self._set_last_updated_metadata(update_datetime, project) - - def _maybe_init_project_metadata(self, project): - # Initialize project metadata if needed - with self.write_engine.begin() as conn: - update_datetime = _utc_now() - update_time = int(update_datetime.timestamp()) - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - if not row: - new_project_uuid = f"{uuid.uuid4()}" - values = { - "metadata_key": FeastMetadataKeys.PROJECT_UUID.value, - "metadata_value": new_project_uuid, - "last_updated_timestamp": update_time, - "project_id": project, - } - insert_stmt = insert(feast_metadata).values(values) - conn.execute(insert_stmt) - - def _delete_object( - self, - table: Table, - name: str, - project: str, - id_field_name: str, - not_found_exception: Optional[Callable], - ): - with self.write_engine.begin() as conn: - stmt = delete(table).where( - getattr(table.c, id_field_name) == name, table.c.project_id == project - ) - rows = conn.execute(stmt) - if rows.rowcount < 1 and not_found_exception: - raise not_found_exception(name, project) - self.apply_project( - self.get_project(name=project, allow_cache=False), commit=True - ) - if not self.purge_feast_metadata: - self._set_last_updated_metadata(_utc_now(), project) - - return rows.rowcount - - def _get_object( - self, - table: Table, - name: str, - project: str, - proto_class: Any, - python_class: Any, - id_field_name: str, - proto_field_name: str, - not_found_exception: Optional[Callable], - ): - with self.read_engine.begin() as conn: - stmt = select(table).where( - getattr(table.c, id_field_name) == name, table.c.project_id == project - ) - row = conn.execute(stmt).first() - if row: - _proto = proto_class.FromString(row._mapping[proto_field_name]) - return python_class.from_proto(_proto) - if not_found_exception: - raise not_found_exception(name, project) - else: - return None - - def _list_objects( - self, - table: Table, - project: str, - proto_class: Any, - python_class: Any, - proto_field_name: str, - tags: Optional[dict[str, str]] = None, - ): - with self.read_engine.begin() as conn: - stmt = select(table).where(table.c.project_id == project) - rows = conn.execute(stmt).all() - if rows: - objects = [] - for row in rows: - obj = python_class.from_proto( - proto_class.FromString(row._mapping[proto_field_name]) - ) - if utils.has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def _set_last_updated_metadata(self, last_updated: datetime, project: str): - with self.write_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - - update_time = int(last_updated.timestamp()) - - values = { - "metadata_key": FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - "metadata_value": f"{update_time}", 
- "last_updated_timestamp": update_time, - "project_id": project, - } - if row: - update_stmt = ( - update(feast_metadata) - .where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - .values(values) - ) - conn.execute(update_stmt) - else: - insert_stmt = insert(feast_metadata).values( - values, - ) - conn.execute(insert_stmt) - - def _get_last_updated_metadata(self, project: str): - with self.read_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - if not row: - return None - update_time = int(row._mapping["last_updated_timestamp"]) - - return datetime.fromtimestamp(update_time, tz=timezone.utc) - - def _get_permission(self, name: str, project: str) -> Permission: - return self._get_object( - table=permissions, - name=name, - project=project, - proto_class=PermissionProto, - python_class=Permission, - id_field_name="permission_name", - proto_field_name="permission_proto", - not_found_exception=PermissionNotFoundException, - ) - - def _list_permissions( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[Permission]: - return self._list_objects( - permissions, - project, - PermissionProto, - Permission, - "permission_proto", - tags=tags, - ) - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - return self._apply_object( - permissions, project, "permission_name", permission, "permission_proto" - ) - - def delete_permission(self, name: str, project: str, commit: bool = True): - with self.write_engine.begin() as conn: - stmt = delete(permissions).where( - permissions.c.permission_name == name, - permissions.c.project_id == project, - ) - rows = conn.execute(stmt) - if rows.rowcount < 1: - raise PermissionNotFoundException(name, project) - - def _list_projects( - self, - tags: Optional[dict[str, str]], - ) -> List[Project]: - with self.read_engine.begin() as conn: - stmt = select(projects) - rows = conn.execute(stmt).all() - if rows: - objects = [] - for row in rows: - obj = Project.from_proto( - ProjectProto.FromString(row._mapping["project_proto"]) - ) - if utils.has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def _get_project( - self, - name: str, - ) -> Project: - return self._get_object( - table=projects, - name=name, - project=name, - proto_class=ProjectProto, - python_class=Project, - id_field_name="project_name", - proto_field_name="project_proto", - not_found_exception=ProjectObjectNotFoundException, - ) - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - return self._apply_object( - projects, project.name, "project_name", project, "project_proto" - ) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - project = self.get_project(name, allow_cache=False) - if project: - with self.write_engine.begin() as conn: - for t in { - managed_infra, - saved_datasets, - validation_references, - feature_services, - feature_views, - on_demand_feature_views, - stream_feature_views, - data_sources, - entities, - permissions, - feast_metadata, - projects, - }: - stmt = delete(t).where(t.c.project_id == name) - conn.execute(stmt) - return - - raise ProjectNotFoundException(name) +import logging +import uuid +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone 
+from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union, cast + +from pydantic import StrictInt, StrictStr +from sqlalchemy import ( # type: ignore + BigInteger, + Column, + Index, + LargeBinary, + MetaData, + String, + Table, + create_engine, + delete, + insert, + select, + update, +) +from sqlalchemy.engine import Engine + +from feast import utils +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.errors import ( + DataSourceObjectNotFoundException, + EntityNotFoundException, + FeatureServiceNotFoundException, + FeatureViewNotFoundException, + PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, + SavedDatasetNotFound, + ValidationReferenceNotFound, +) +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry.caching_registry import CachingRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto +from feast.protos.feast.core.FeatureService_pb2 import ( + FeatureService as FeatureServiceProto, +) +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto +from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( + OnDemandFeatureView as OnDemandFeatureViewProto, +) +from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) +from feast.protos.feast.core.ValidationProfile_pb2 import ( + ValidationReference as ValidationReferenceProto, +) +from feast.repo_config import RegistryConfig +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.utils import _utc_now + +metadata = MetaData() + + +projects = Table( + "projects", + metadata, + Column("project_id", String(255), primary_key=True), + Column("project_name", String(255), nullable=False), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("project_proto", LargeBinary, nullable=False), +) + +Index("idx_projects_project_id", projects.c.project_id) + +entities = Table( + "entities", + metadata, + Column("entity_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("entity_proto", LargeBinary, nullable=False), +) + +Index("idx_entities_project_id", entities.c.project_id) + +data_sources = Table( + "data_sources", + metadata, + Column("data_source_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("data_source_proto", LargeBinary, nullable=False), +) + 
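+# NOTE (editor's comment): the per-object tables in this schema share one
+# shape: a composite (object name, project_id) primary key, a last-updated
+# epoch timestamp, and the object's serialized protobuf payload, plus a
+# project_id index like the one below to keep project-scoped listings cheap.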
+Index("idx_data_sources_project_id", data_sources.c.project_id) + +feature_views = Table( + "feature_views", + metadata, + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("materialized_intervals", LargeBinary, nullable=True), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +Index("idx_feature_views_project_id", feature_views.c.project_id) + +stream_feature_views = Table( + "stream_feature_views", + metadata, + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +Index("idx_stream_feature_views_project_id", stream_feature_views.c.project_id) + +on_demand_feature_views = Table( + "on_demand_feature_views", + metadata, + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +Index("idx_on_demand_feature_views_project_id", on_demand_feature_views.c.project_id) + +feature_services = Table( + "feature_services", + metadata, + Column("feature_service_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_service_proto", LargeBinary, nullable=False), +) + +Index("idx_feature_services_project_id", feature_services.c.project_id) + +saved_datasets = Table( + "saved_datasets", + metadata, + Column("saved_dataset_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("saved_dataset_proto", LargeBinary, nullable=False), +) + +Index("idx_saved_datasets_project_id", saved_datasets.c.project_id) + +validation_references = Table( + "validation_references", + metadata, + Column("validation_reference_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("validation_reference_proto", LargeBinary, nullable=False), +) +Index("idx_validation_references_project_id", validation_references.c.project_id) + +managed_infra = Table( + "managed_infra", + metadata, + Column("infra_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("infra_proto", LargeBinary, nullable=False), +) + +Index("idx_managed_infra_project_id", managed_infra.c.project_id) + +permissions = Table( + "permissions", + metadata, + Column("permission_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("permission_proto", LargeBinary, nullable=False), +) + +Index("idx_permissions_project_id", permissions.c.project_id) + + +class FeastMetadataKeys(Enum): + LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" + PROJECT_UUID = "project_uuid" + + +feast_metadata = Table( + "feast_metadata", + metadata, 
+    Column("project_id", String(255), primary_key=True),
+    Column("metadata_key", String(50), primary_key=True),
+    Column("metadata_value", String(50), nullable=False),
+    Column("last_updated_timestamp", BigInteger, nullable=False),
+)
+
+Index("idx_feast_metadata_project_id", feast_metadata.c.project_id)
+
+logger = logging.getLogger(__name__)
+
+
+class SqlRegistryConfig(RegistryConfig):
+    registry_type: StrictStr = "sql"
+    """ str: Provider name or a class name that implements Registry."""
+
+    path: StrictStr = ""
+    """ str: Path to metadata store.
+    If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy. """
+
+    read_path: Optional[StrictStr] = None
+    """ str: Read path to the metadata store, if different from path.
+    If registry_type is 'sql', then this is a read endpoint database URL. If not set, path is used for both reads and writes. """
+
+    sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False}
+    """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """
+
+    cache_mode: StrictStr = "sync"
+    """ str: Cache mode. Possible options are "sync" and "thread" (asynchronous caching using the threading library). """
+
+    thread_pool_executor_worker_count: StrictInt = 0
+    """ int: Number of worker threads to use for asynchronous caching in the SQL registry. If set to 0, ThreadPoolExecutor is not used. """
+
+
+class SqlRegistry(CachingRegistry):
+    def __init__(
+        self,
+        registry_config,
+        project: str,
+        repo_path: Optional[Path],
+    ):
+        assert registry_config is not None and isinstance(
+            registry_config, SqlRegistryConfig
+        ), "SqlRegistry needs a valid registry_config"
+
+        self.registry_config = registry_config
+
+        self.write_engine: Engine = create_engine(
+            registry_config.path, **registry_config.sqlalchemy_config_kwargs
+        )
+        if registry_config.read_path:
+            self.read_engine: Engine = create_engine(
+                registry_config.read_path,
+                **registry_config.sqlalchemy_config_kwargs,
+            )
+        else:
+            self.read_engine = self.write_engine
+        metadata.create_all(self.write_engine)
+        self.thread_pool_executor_worker_count = (
+            registry_config.thread_pool_executor_worker_count
+        )
+        self.purge_feast_metadata = registry_config.purge_feast_metadata
+        # Sync feast_metadata to the projects table. When purge_feast_metadata
+        # is set to True, data is deleted from the feast_metadata table and
+        # list_project_metadata will not return any data.
+        self._sync_feast_metadata_to_projects_table()
+        if not self.purge_feast_metadata:
+            self._maybe_init_project_metadata(project)
+        super().__init__(
+            project=project,
+            cache_ttl_seconds=registry_config.cache_ttl_seconds,
+            cache_mode=registry_config.cache_mode,
+        )
+
+    def _sync_feast_metadata_to_projects_table(self):
+        feast_metadata_projects: dict = {}
+        projects_set: list = []
+        with self.read_engine.begin() as conn:
+            stmt = select(feast_metadata).where(
+                feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value
+            )
+            rows = conn.execute(stmt).all()
+            for row in rows:
+                feast_metadata_projects[row._mapping["project_id"]] = int(
+                    row._mapping["last_updated_timestamp"]
+                )
+
+        if len(feast_metadata_projects) > 0:
+            with self.read_engine.begin() as conn:
+                stmt = select(projects)
+                rows = conn.execute(stmt).all()
+                for row in rows:
+                    projects_set.append(row._mapping["project_id"])
+
+            # Find projects recorded in feast_metadata but missing from the projects table
+            projects_to_sync = set(feast_metadata_projects.keys()) - set(projects_set)
+            for project_name in projects_to_sync:
+                self.apply_project(
+                    Project(
+                        name=project_name,
+                        
created_timestamp=datetime.fromtimestamp( + feast_metadata_projects[project_name], tz=timezone.utc + ), + ), + commit=True, + ) + + if self.purge_feast_metadata: + with self.write_engine.begin() as conn: + for project_name in feast_metadata_projects: + stmt = delete(feast_metadata).where( + feast_metadata.c.project_id == project_name + ) + conn.execute(stmt) + + def teardown(self): + for t in { + entities, + data_sources, + feature_views, + feature_services, + on_demand_feature_views, + saved_datasets, + validation_references, + permissions, + }: + with self.write_engine.begin() as conn: + stmt = delete(t) + conn.execute(stmt) + + def _get_stream_feature_view(self, name: str, project: str): + return self._get_object( + table=stream_feature_views, + name=name, + project=project, + proto_class=StreamFeatureViewProto, + python_class=StreamFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + + def _list_stream_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[StreamFeatureView]: + return self._list_objects( + stream_feature_views, + project, + StreamFeatureViewProto, + StreamFeatureView, + "feature_view_proto", + tags=tags, + ) + + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + return self._apply_object( + table=entities, + project=project, + id_field_name="entity_name", + obj=entity, + proto_field_name="entity_proto", + ) + + def _get_entity(self, name: str, project: str) -> Entity: + return self._get_object( + table=entities, + name=name, + project=project, + proto_class=EntityProto, + python_class=Entity, + id_field_name="entity_name", + proto_field_name="entity_proto", + not_found_exception=EntityNotFoundException, + ) + + def _get_any_feature_view(self, name: str, project: str) -> BaseFeatureView: + fv = self._get_object( + table=feature_views, + name=name, + project=project, + proto_class=FeatureViewProto, + python_class=FeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=None, + ) + + if not fv: + fv = self._get_object( + table=on_demand_feature_views, + name=name, + project=project, + proto_class=OnDemandFeatureViewProto, + python_class=OnDemandFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=None, + ) + + if not fv: + fv = self._get_object( + table=stream_feature_views, + name=name, + project=project, + proto_class=StreamFeatureViewProto, + python_class=StreamFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + return fv + + def _list_all_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[BaseFeatureView]: + return ( + cast( + list[BaseFeatureView], + self._list_feature_views(project=project, tags=tags), + ) + + cast( + list[BaseFeatureView], + self._list_stream_feature_views(project=project, tags=tags), + ) + + cast( + list[BaseFeatureView], + self._list_on_demand_feature_views(project=project, tags=tags), + ) + ) + + def _get_feature_view(self, name: str, project: str) -> FeatureView: + return self._get_object( + table=feature_views, + name=name, + project=project, + proto_class=FeatureViewProto, + python_class=FeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + + def 
_get_on_demand_feature_view( + self, name: str, project: str + ) -> OnDemandFeatureView: + return self._get_object( + table=on_demand_feature_views, + name=name, + project=project, + proto_class=OnDemandFeatureViewProto, + python_class=OnDemandFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + + def _get_feature_service(self, name: str, project: str) -> FeatureService: + return self._get_object( + table=feature_services, + name=name, + project=project, + proto_class=FeatureServiceProto, + python_class=FeatureService, + id_field_name="feature_service_name", + proto_field_name="feature_service_proto", + not_found_exception=FeatureServiceNotFoundException, + ) + + def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: + return self._get_object( + table=saved_datasets, + name=name, + project=project, + proto_class=SavedDatasetProto, + python_class=SavedDataset, + id_field_name="saved_dataset_name", + proto_field_name="saved_dataset_proto", + not_found_exception=SavedDatasetNotFound, + ) + + def _get_validation_reference(self, name: str, project: str) -> ValidationReference: + return self._get_object( + table=validation_references, + name=name, + project=project, + proto_class=ValidationReferenceProto, + python_class=ValidationReference, + id_field_name="validation_reference_name", + proto_field_name="validation_reference_proto", + not_found_exception=ValidationReferenceNotFound, + ) + + def _list_validation_references( + self, project: str, tags: Optional[dict[str, str]] = None + ) -> List[ValidationReference]: + return self._list_objects( + table=validation_references, + project=project, + proto_class=ValidationReferenceProto, + python_class=ValidationReference, + proto_field_name="validation_reference_proto", + tags=tags, + ) + + def _list_entities( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[Entity]: + return self._list_objects( + entities, project, EntityProto, Entity, "entity_proto", tags=tags + ) + + def delete_entity(self, name: str, project: str, commit: bool = True): + return self._delete_object( + entities, name, project, "entity_name", EntityNotFoundException + ) + + def delete_feature_view(self, name: str, project: str, commit: bool = True): + deleted_count = 0 + for table in { + feature_views, + on_demand_feature_views, + stream_feature_views, + }: + deleted_count += self._delete_object( + table, name, project, "feature_view_name", None + ) + if deleted_count == 0: + raise FeatureViewNotFoundException(name, project) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + return self._delete_object( + feature_services, + name, + project, + "feature_service_name", + FeatureServiceNotFoundException, + ) + + def _get_data_source(self, name: str, project: str) -> DataSource: + return self._get_object( + table=data_sources, + name=name, + project=project, + proto_class=DataSourceProto, + python_class=DataSource, + id_field_name="data_source_name", + proto_field_name="data_source_proto", + not_found_exception=DataSourceObjectNotFoundException, + ) + + def _list_data_sources( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[DataSource]: + return self._list_objects( + data_sources, + project, + DataSourceProto, + DataSource, + "data_source_proto", + tags=tags, + ) + + def apply_data_source( + self, data_source: DataSource, project: str, commit: bool = True + ): + return self._apply_object( + data_sources, 
project, "data_source_name", data_source, "data_source_proto" + ) + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + fv_table = self._infer_fv_table(feature_view) + + return self._apply_object( + fv_table, project, "feature_view_name", feature_view, "feature_view_proto" + ) + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + return self._apply_object( + feature_services, + project, + "feature_service_name", + feature_service, + "feature_service_proto", + ) + + def delete_data_source(self, name: str, project: str, commit: bool = True): + with self.write_engine.begin() as conn: + stmt = delete(data_sources).where( + data_sources.c.data_source_name == name, + data_sources.c.project_id == project, + ) + rows = conn.execute(stmt) + if rows.rowcount < 1: + raise DataSourceObjectNotFoundException(name, project) + + def _list_feature_services( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[FeatureService]: + return self._list_objects( + feature_services, + project, + FeatureServiceProto, + FeatureService, + "feature_service_proto", + tags=tags, + ) + + def _list_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[FeatureView]: + return self._list_objects( + feature_views, + project, + FeatureViewProto, + FeatureView, + "feature_view_proto", + tags=tags, + ) + + def _list_saved_datasets( + self, project: str, tags: Optional[dict[str, str]] = None + ) -> List[SavedDataset]: + return self._list_objects( + saved_datasets, + project, + SavedDatasetProto, + SavedDataset, + "saved_dataset_proto", + tags=tags, + ) + + def _list_on_demand_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[OnDemandFeatureView]: + return self._list_objects( + on_demand_feature_views, + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "feature_view_proto", + tags=tags, + ) + + def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: + with self.read_engine.begin() as conn: + stmt = select(feast_metadata).where( + feast_metadata.c.project_id == project, + ) + rows = conn.execute(stmt).all() + if rows: + project_metadata = ProjectMetadata(project_name=project) + for row in rows: + if ( + row._mapping["metadata_key"] + == FeastMetadataKeys.PROJECT_UUID.value + ): + project_metadata.project_uuid = row._mapping["metadata_value"] + break + # TODO(adchia): Add other project metadata in a structured way + return [project_metadata] + return [] + + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + return self._apply_object( + saved_datasets, + project, + "saved_dataset_name", + saved_dataset, + "saved_dataset_proto", + ) + + def apply_validation_reference( + self, + validation_reference: ValidationReference, + project: str, + commit: bool = True, + ): + return self._apply_object( + validation_references, + project, + "validation_reference_name", + validation_reference, + "validation_reference_proto", + ) + + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + table = self._infer_fv_table(feature_view) + python_class, proto_class = self._infer_fv_classes(feature_view) + + if python_class in {OnDemandFeatureView}: + raise ValueError( + f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" + ) + fv: 
Union[FeatureView, StreamFeatureView] = self._get_object( + table, + feature_view.name, + project, + proto_class, + python_class, + "feature_view_name", + "feature_view_proto", + FeatureViewNotFoundException, + ) + fv.materialization_intervals.append((start_date, end_date)) + self._apply_object( + table, project, "feature_view_name", fv, "feature_view_proto" + ) + + def delete_validation_reference(self, name: str, project: str, commit: bool = True): + self._delete_object( + validation_references, + name, + project, + "validation_reference_name", + ValidationReferenceNotFound, + ) + + def update_infra(self, infra: Infra, project: str, commit: bool = True): + self._apply_object( + table=managed_infra, + project=project, + id_field_name="infra_name", + obj=infra, + proto_field_name="infra_proto", + name="infra_obj", + ) + + def _get_infra(self, project: str) -> Infra: + infra_object = self._get_object( + table=managed_infra, + name="infra_obj", + project=project, + proto_class=InfraProto, + python_class=Infra, + id_field_name="infra_name", + proto_field_name="infra_proto", + not_found_exception=None, + ) + if infra_object: + return infra_object + return Infra() + + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + table = self._infer_fv_table(feature_view) + + name = feature_view.name + with self.write_engine.begin() as conn: + stmt = select(table).where( + getattr(table.c, "feature_view_name") == name, + table.c.project_id == project, + ) + row = conn.execute(stmt).first() + update_datetime = _utc_now() + update_time = int(update_datetime.timestamp()) + if row: + values = { + "user_metadata": metadata_bytes, + "last_updated_timestamp": update_time, + } + update_stmt = ( + update(table) + .where( + getattr(table.c, "feature_view_name") == name, + table.c.project_id == project, + ) + .values( + values, + ) + ) + conn.execute(update_stmt) + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def _infer_fv_table(self, feature_view): + if isinstance(feature_view, StreamFeatureView): + table = stream_feature_views + elif isinstance(feature_view, FeatureView): + table = feature_views + elif isinstance(feature_view, OnDemandFeatureView): + table = on_demand_feature_views + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return table + + def _infer_fv_classes(self, feature_view): + if isinstance(feature_view, StreamFeatureView): + python_class, proto_class = StreamFeatureView, StreamFeatureViewProto + elif isinstance(feature_view, FeatureView): + python_class, proto_class = FeatureView, FeatureViewProto + elif isinstance(feature_view, OnDemandFeatureView): + python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return python_class, proto_class + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + table = self._infer_fv_table(feature_view) + + name = feature_view.name + with self.read_engine.begin() as conn: + stmt = select(table).where(getattr(table.c, "feature_view_name") == name) + row = conn.execute(stmt).first() + if row: + return row._mapping["user_metadata"] + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def proto(self) -> RegistryProto: + r = RegistryProto() + last_updated_timestamps = [] + + def process_project(project: Project): + nonlocal r, last_updated_timestamps 
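+            # Serialize this project into the registry proto; the listers
+            # below may read through the registry cache only when the cached
+            # project is at least as fresh as the row fetched from the
+            # database.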
+ project_name = project.name + last_updated_timestamp = project.last_updated_timestamp + + try: + cached_project = self.get_project(project_name, True) + except ProjectObjectNotFoundException: + cached_project = None + + allow_cache = False + + if cached_project is not None: + allow_cache = ( + last_updated_timestamp <= cached_project.last_updated_timestamp + ) + + r.projects.extend([project.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) + + for lister, registry_proto_field in [ + (self.list_entities, r.entities), + (self.list_feature_views, r.feature_views), + (self.list_data_sources, r.data_sources), + (self.list_on_demand_feature_views, r.on_demand_feature_views), + (self.list_stream_feature_views, r.stream_feature_views), + (self.list_feature_services, r.feature_services), + (self.list_saved_datasets, r.saved_datasets), + (self.list_validation_references, r.validation_references), + (self.list_permissions, r.permissions), + ]: + objs: List[Any] = lister(project_name, allow_cache) # type: ignore + if objs: + obj_protos = [obj.to_proto() for obj in objs] + for obj_proto in obj_protos: + if "spec" in obj_proto.DESCRIPTOR.fields_by_name: + obj_proto.spec.project = project_name + else: + obj_proto.project = project_name + registry_proto_field.extend(obj_protos) + + # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, + # the registry proto only has a single infra field, which we're currently setting as the "last" project. + r.infra.CopyFrom(self.get_infra(project_name).to_proto()) + + projects_list = self.list_projects(allow_cache=False) + if self.thread_pool_executor_worker_count == 0: + for project in projects_list: + process_project(project) + else: + with ThreadPoolExecutor( + max_workers=self.thread_pool_executor_worker_count + ) as executor: + executor.map(process_project, projects_list) + + if last_updated_timestamps: + r.last_updated.FromDatetime(max(last_updated_timestamps)) + + return r + + def commit(self): + # This method is a no-op since we're always writing values eagerly to the db. 
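+        # Every _apply_object/_delete_object call already runs inside its own
+        # transaction, so there is no staged state left to flush.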
+        pass
+
+    def _initialize_project_if_not_exists(self, project_name: str):
+        try:
+            self.get_project(project_name, allow_cache=True)
+            return
+        except ProjectObjectNotFoundException:
+            try:
+                self.get_project(project_name, allow_cache=False)
+                return
+            except ProjectObjectNotFoundException:
+                self.apply_project(Project(name=project_name), commit=True)
+
+    def _apply_object(
+        self,
+        table: Table,
+        project: str,
+        id_field_name: str,
+        obj: Any,
+        proto_field_name: str,
+        name: Optional[str] = None,
+    ):
+        if not self.purge_feast_metadata:
+            self._maybe_init_project_metadata(project)
+        # Initializing the project is necessary because a FeatureStore object can apply objects individually, without the "feast apply" CLI option.
+        if not isinstance(obj, Project):
+            self._initialize_project_if_not_exists(project_name=project)
+        name = name or (obj.name if hasattr(obj, "name") else None)
+        assert name, f"name needs to be provided for {obj}"
+
+        with self.write_engine.begin() as conn:
+            update_datetime = _utc_now()
+            update_time = int(update_datetime.timestamp())
+            stmt = select(table).where(
+                getattr(table.c, id_field_name) == name, table.c.project_id == project
+            )
+            row = conn.execute(stmt).first()
+            if hasattr(obj, "last_updated_timestamp"):
+                obj.last_updated_timestamp = update_datetime
+
+            if row:
+                if proto_field_name in [
+                    "entity_proto",
+                    "saved_dataset_proto",
+                    "feature_view_proto",
+                    "feature_service_proto",
+                    "permission_proto",
+                    "project_proto",
+                ]:
+                    deserialized_proto = self.deserialize_registry_values(
+                        row._mapping[proto_field_name], type(obj)
+                    )
+                    obj.created_timestamp = (
+                        deserialized_proto.meta.created_timestamp.ToDatetime().replace(
+                            tzinfo=timezone.utc
+                        )
+                    )
+                    if isinstance(obj, (FeatureView, StreamFeatureView)):
+                        obj.update_materialization_intervals(
+                            type(obj)
+                            .from_proto(deserialized_proto)
+                            .materialization_intervals
+                        )
+                values = {
+                    proto_field_name: obj.to_proto().SerializeToString(),
+                    "last_updated_timestamp": update_time,
+                }
+                update_stmt = (
+                    update(table)
+                    .where(
+                        getattr(table.c, id_field_name) == name,
+                        table.c.project_id == project,
+                    )
+                    .values(
+                        values,
+                    )
+                )
+                conn.execute(update_stmt)
+            else:
+                obj_proto = obj.to_proto()
+
+                if hasattr(obj_proto, "meta") and hasattr(
+                    obj_proto.meta, "created_timestamp"
+                ):
+                    if not obj_proto.meta.HasField("created_timestamp"):
+                        obj_proto.meta.created_timestamp.FromDatetime(update_datetime)
+
+                values = {
+                    id_field_name: name,
+                    proto_field_name: obj_proto.SerializeToString(),
+                    "last_updated_timestamp": update_time,
+                    "project_id": project,
+                }
+                insert_stmt = insert(table).values(
+                    values,
+                )
+                conn.execute(insert_stmt)
+
+        if not isinstance(obj, Project):
+            self.apply_project(
+                self.get_project(name=project, allow_cache=False), commit=True
+            )
+        if not self.purge_feast_metadata:
+            self._set_last_updated_metadata(update_datetime, project)
+
+    def _maybe_init_project_metadata(self, project):
+        # Initialize project metadata if needed
+        with self.write_engine.begin() as conn:
+            update_datetime = _utc_now()
+            update_time = int(update_datetime.timestamp())
+            stmt = select(feast_metadata).where(
+                feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value,
+                feast_metadata.c.project_id == project,
+            )
+            row = conn.execute(stmt).first()
+            if not row:
+                new_project_uuid = f"{uuid.uuid4()}"
+                values = {
+                    "metadata_key": FeastMetadataKeys.PROJECT_UUID.value,
+                    "metadata_value": new_project_uuid,
+                    "last_updated_timestamp": update_time,
+                    "project_id": project,
+                }
+                insert_stmt = 
insert(feast_metadata).values(values) + conn.execute(insert_stmt) + + def _delete_object( + self, + table: Table, + name: str, + project: str, + id_field_name: str, + not_found_exception: Optional[Callable], + ): + with self.write_engine.begin() as conn: + stmt = delete(table).where( + getattr(table.c, id_field_name) == name, table.c.project_id == project + ) + rows = conn.execute(stmt) + if rows.rowcount < 1 and not_found_exception: + raise not_found_exception(name, project) + self.apply_project( + self.get_project(name=project, allow_cache=False), commit=True + ) + if not self.purge_feast_metadata: + self._set_last_updated_metadata(_utc_now(), project) + + return rows.rowcount + + def _get_object( + self, + table: Table, + name: str, + project: str, + proto_class: Any, + python_class: Any, + id_field_name: str, + proto_field_name: str, + not_found_exception: Optional[Callable], + ): + with self.read_engine.begin() as conn: + stmt = select(table).where( + getattr(table.c, id_field_name) == name, table.c.project_id == project + ) + row = conn.execute(stmt).first() + if row: + _proto = proto_class.FromString(row._mapping[proto_field_name]) + return python_class.from_proto(_proto) + if not_found_exception: + raise not_found_exception(name, project) + else: + return None + + def _list_objects( + self, + table: Table, + project: str, + proto_class: Any, + python_class: Any, + proto_field_name: str, + tags: Optional[dict[str, str]] = None, + ): + with self.read_engine.begin() as conn: + stmt = select(table).where(table.c.project_id == project) + rows = conn.execute(stmt).all() + if rows: + objects = [] + for row in rows: + obj = python_class.from_proto( + proto_class.FromString(row._mapping[proto_field_name]) + ) + if utils.has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def _set_last_updated_metadata(self, last_updated: datetime, project: str): + with self.write_engine.begin() as conn: + stmt = select(feast_metadata).where( + feast_metadata.c.metadata_key + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + feast_metadata.c.project_id == project, + ) + row = conn.execute(stmt).first() + + update_time = int(last_updated.timestamp()) + + values = { + "metadata_key": FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + "metadata_value": f"{update_time}", + "last_updated_timestamp": update_time, + "project_id": project, + } + if row: + update_stmt = ( + update(feast_metadata) + .where( + feast_metadata.c.metadata_key + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + feast_metadata.c.project_id == project, + ) + .values(values) + ) + conn.execute(update_stmt) + else: + insert_stmt = insert(feast_metadata).values( + values, + ) + conn.execute(insert_stmt) + + def _get_last_updated_metadata(self, project: str): + with self.read_engine.begin() as conn: + stmt = select(feast_metadata).where( + feast_metadata.c.metadata_key + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + feast_metadata.c.project_id == project, + ) + row = conn.execute(stmt).first() + if not row: + return None + update_time = int(row._mapping["last_updated_timestamp"]) + + return datetime.fromtimestamp(update_time, tz=timezone.utc) + + def _get_permission(self, name: str, project: str) -> Permission: + return self._get_object( + table=permissions, + name=name, + project=project, + proto_class=PermissionProto, + python_class=Permission, + id_field_name="permission_name", + proto_field_name="permission_proto", + not_found_exception=PermissionNotFoundException, + ) + + def _list_permissions( + 
self, project: str, tags: Optional[dict[str, str]] + ) -> List[Permission]: + return self._list_objects( + permissions, + project, + PermissionProto, + Permission, + "permission_proto", + tags=tags, + ) + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + return self._apply_object( + permissions, project, "permission_name", permission, "permission_proto" + ) + + def delete_permission(self, name: str, project: str, commit: bool = True): + with self.write_engine.begin() as conn: + stmt = delete(permissions).where( + permissions.c.permission_name == name, + permissions.c.project_id == project, + ) + rows = conn.execute(stmt) + if rows.rowcount < 1: + raise PermissionNotFoundException(name, project) + + def _list_projects( + self, + tags: Optional[dict[str, str]], + ) -> List[Project]: + with self.read_engine.begin() as conn: + stmt = select(projects) + rows = conn.execute(stmt).all() + if rows: + objects = [] + for row in rows: + obj = Project.from_proto( + ProjectProto.FromString(row._mapping["project_proto"]) + ) + if utils.has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def _get_project( + self, + name: str, + ) -> Project: + return self._get_object( + table=projects, + name=name, + project=name, + proto_class=ProjectProto, + python_class=Project, + id_field_name="project_name", + proto_field_name="project_proto", + not_found_exception=ProjectObjectNotFoundException, + ) + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + return self._apply_object( + projects, project.name, "project_name", project, "project_proto" + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + project = self.get_project(name, allow_cache=False) + if project: + with self.write_engine.begin() as conn: + for t in { + managed_infra, + saved_datasets, + validation_references, + feature_services, + feature_views, + on_demand_feature_views, + stream_feature_views, + data_sources, + entities, + permissions, + feast_metadata, + projects, + }: + stmt = delete(t).where(t.c.project_id == name) + conn.execute(stmt) + return + + raise ProjectNotFoundException(name) diff --git a/sdk/python/tests/unit/test_on_demand_python_transformation.py b/sdk/python/tests/unit/test_on_demand_python_transformation.py index eb29c645e53..fa88fcbb24d 100644 --- a/sdk/python/tests/unit/test_on_demand_python_transformation.py +++ b/sdk/python/tests/unit/test_on_demand_python_transformation.py @@ -1117,6 +1117,181 @@ def python_stored_writes_feature_view( "current_datetime": [None], } + def test_materialize_with_odfv_writes(self): + with tempfile.TemporaryDirectory() as data_dir: + self.store = FeatureStore( + config=RepoConfig( + project="test_on_demand_python_transformation", + registry=os.path.join(data_dir, "registry.db"), + provider="local", + entity_key_serialization_version=3, + online_store=SqliteOnlineStoreConfig( + path=os.path.join(data_dir, "online.db") + ), + ) + ) + + end_date = datetime.now().replace(microsecond=0, second=0, minute=0) + start_date = end_date - timedelta(days=15) + + driver_entities = [1001, 1002, 1003, 1004, 1005] + driver_df = create_driver_hourly_stats_df( + driver_entities, start_date, end_date + ) + driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") + driver_df.to_parquet( + path=driver_stats_path, allow_truncated_timestamps=True + ) + + driver = Entity(name="driver", join_keys=["driver_id"]) + + driver_stats_source = FileSource( + name="driver_hourly_stats_source", 
+ path=driver_stats_path, + timestamp_field="event_timestamp", + ) + + driver_stats_fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=1), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + online=True, + source=driver_stats_source, + tags={}, + ) + + input_request_source = RequestSource( + name="vals_to_add", + schema=[ + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + ) + + @on_demand_feature_view( + entities=[driver], + sources=[ + driver_stats_fv[["conv_rate", "acc_rate"]], + input_request_source, + ], + schema=[ + Field(name="conv_rate_plus_acc", dtype=Float64), + Field(name="current_datetime", dtype=UnixTimestamp), + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + Field(name="string_constant", dtype=String), + ], + mode="python", + write_to_online_store=True, + ) + def python_stored_writes_feature_view( + inputs: dict[str, Any], + ) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_acc": [ + conv_rate + acc_rate + for conv_rate, acc_rate in zip( + inputs["conv_rate"], inputs["acc_rate"] + ) + ], + "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], + "counter": [c + 1 for c in inputs["counter"]], + "input_datetime": [d for d in inputs["input_datetime"]], + "string_constant": ["test_constant"], + } + return output + + @on_demand_feature_view( + entities=[driver], + sources=[ + driver_stats_fv[["conv_rate", "acc_rate"]], + input_request_source, + ], + schema=[ + Field(name="conv_rate_plus_acc", dtype=Float64), + Field(name="current_datetime", dtype=UnixTimestamp), + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + Field(name="string_constant", dtype=String), + ], + mode="python", + write_to_online_store=False, + ) + def python_no_writes_feature_view( + inputs: dict[str, Any], + ) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_acc": [ + conv_rate + acc_rate + for conv_rate, acc_rate in zip( + inputs["conv_rate"], inputs["acc_rate"] + ) + ], + "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], + "counter": [c + 1 for c in inputs["counter"]], + "input_datetime": [d for d in inputs["input_datetime"]], + "string_constant": ["test_constant"], + } + return output + + self.store.apply( + [ + driver, + driver_stats_source, + driver_stats_fv, + python_stored_writes_feature_view, + python_no_writes_feature_view, + ] + ) + + feature_views_to_materialize = self.store._get_feature_views_to_materialize( + None + ) + + odfv_names = [ + fv.name + for fv in feature_views_to_materialize + if hasattr(fv, "write_to_online_store") + ] + assert "python_stored_writes_feature_view" in odfv_names + assert "python_no_writes_feature_view" not in odfv_names + + regular_fv_names = [ + fv.name + for fv in feature_views_to_materialize + if not hasattr(fv, "write_to_online_store") + ] + assert "driver_hourly_stats" in regular_fv_names + + materialize_end_date = datetime.now().replace( + microsecond=0, second=0, minute=0 + ) + materialize_start_date = materialize_end_date - timedelta(days=1) + + self.store.materialize(materialize_start_date, materialize_end_date) + + specific_feature_views_to_materialize = ( + self.store._get_feature_views_to_materialize( + ["driver_hourly_stats", "python_stored_writes_feature_view"] + ) + ) + assert len(specific_feature_views_to_materialize) == 2 + + try: + 
self.store._get_feature_views_to_materialize( + ["python_no_writes_feature_view"] + ) + assert False, ( + "Should have raised ValueError for ODFV without write_to_online_store" + ) + except ValueError as e: + assert "not configured for write_to_online_store" in str(e) + def test_stored_writes_with_explode(self): with tempfile.TemporaryDirectory() as data_dir: self.store = FeatureStore( From d254cd99f6c14d398f98e48fa421c35ed1b2aec2 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Wed, 18 Jun 2025 22:43:28 -0400 Subject: [PATCH 2/7] feat: Enable materialization for ODFV Transform on Write Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/infra/passthrough_provider.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index b532ac563d4..d4b586f5c93 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -420,17 +420,24 @@ def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table) def materialize_single_feature_view( self, config: RepoConfig, - feature_view: FeatureView, + feature_view: Union[FeatureView, OnDemandFeatureView], start_date: datetime, end_date: datetime, registry: BaseRegistry, project: str, tqdm_builder: Callable[[int], tqdm], ) -> None: + if isinstance(feature_view, OnDemandFeatureView): + if not feature_view.write_to_online_store: + raise ValueError( + f"OnDemandFeatureView {feature_view.name} does not have write_to_online_store enabled" + ) + return assert ( isinstance(feature_view, BatchFeatureView) or isinstance(feature_view, StreamFeatureView) or isinstance(feature_view, FeatureView) + or isinstance(feature_view, OnDemandFeatureView) ), f"Unexpected type for {feature_view.name}: {type(feature_view)}" task = MaterializationTask( project=project, From b7f3f96a85c804aae91aa9bcc12769ae1763dfb9 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Wed, 18 Jun 2025 23:34:35 -0400 Subject: [PATCH 3/7] adding more explicit tests Signed-off-by: Francisco Javier Arceo --- .../test_on_demand_python_transformation.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/sdk/python/tests/unit/test_on_demand_python_transformation.py b/sdk/python/tests/unit/test_on_demand_python_transformation.py index fa88fcbb24d..7c492d04d47 100644 --- a/sdk/python/tests/unit/test_on_demand_python_transformation.py +++ b/sdk/python/tests/unit/test_on_demand_python_transformation.py @@ -1282,6 +1282,35 @@ def python_no_writes_feature_view( ) assert len(specific_feature_views_to_materialize) == 2 + # materialize some data into the online store for the python_stored_writes_feature_view + self.store.materialize( + materialize_start_date, + materialize_end_date, + ["python_stored_writes_feature_view"], + ) + # validate data is loaded to online store + online_response = self.store.get_online_features( + entity_rows=[{"driver_id": 1001}], + features=[ + "python_stored_writes_feature_view:conv_rate_plus_acc", + "python_stored_writes_feature_view:current_datetime", + "python_stored_writes_feature_view:counter", + "python_stored_writes_feature_view:input_datetime", + "python_stored_writes_feature_view:string_constant", + ], + ).to_dict() + assert sorted(list(online_response.keys())) == sorted( + [ + "driver_id", + "conv_rate_plus_acc", + "counter", + "current_datetime", + "input_datetime", + "string_constant", + ] + ) + assert online_response["driver_id"] == [1001] + try: 
self.store._get_feature_views_to_materialize( ["python_no_writes_feature_view"] From 83e93d153fb751593048150510d3bd7adc3837d4 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 19 Jun 2025 21:33:09 -0400 Subject: [PATCH 4/7] undoing registry Signed-off-by: Francisco Javier Arceo --- .../feast/infra/registry/base_registry.py | 1848 ++++++++--------- 1 file changed, 924 insertions(+), 924 deletions(-) diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index c6780ef546d..f2374edf1b2 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -1,924 +1,924 @@ -# Copyright 2019 The Feast Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import warnings -from abc import ABC, abstractmethod -from collections import defaultdict -from datetime import datetime -from typing import Any, Dict, List, Optional, Union - -from google.protobuf.json_format import MessageToJson -from google.protobuf.message import Message - -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object import Infra -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto -from feast.protos.feast.core.FeatureService_pb2 import ( - FeatureService as FeatureServiceProto, -) -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto -from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( - OnDemandFeatureView as OnDemandFeatureViewProto, -) -from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto -from feast.protos.feast.core.Project_pb2 import Project as ProjectProto -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto -from feast.protos.feast.core.StreamFeatureView_pb2 import ( - StreamFeatureView as StreamFeatureViewProto, -) -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.transformation.pandas_transformation import PandasTransformation -from feast.transformation.substrait_transformation import SubstraitTransformation - - -class BaseRegistry(ABC): - """ - The interface that Feast uses to apply, list, retrieve, and delete Feast objects (e.g. entities, - feature views, and data sources). 
- """ - - # Entity operations - @abstractmethod - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - """ - Registers a single entity with Feast - - Args: - entity: Entity that will be registered - project: Feast project that this entity belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def delete_entity(self, name: str, project: str, commit: bool = True): - """ - Deletes an entity or raises an exception if not found. - - Args: - name: Name of entity - project: Feast project that this entity belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - """ - Retrieves an entity. - - Args: - name: Name of entity - project: Feast project that this entity belongs to - allow_cache: Whether to allow returning this entity from a cached registry - - Returns: - Returns either the specified entity, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - """ - Retrieve a list of entities from the registry - - Args: - allow_cache: Whether to allow returning entities from a cached registry - project: Filter entities based on project name - tags: Filter by tags - - Returns: - List of entities - """ - raise NotImplementedError - - # Data source operations - @abstractmethod - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - """ - Registers a single data source with Feast - - Args: - data_source: A data source that will be registered - project: Feast project that this data source belongs to - commit: Whether to immediately commit to the registry - """ - raise NotImplementedError - - @abstractmethod - def delete_data_source(self, name: str, project: str, commit: bool = True): - """ - Deletes a data source or raises an exception if not found. - - Args: - name: Name of data source - project: Feast project that this data source belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - """ - Retrieves a data source. 
- - Args: - name: Name of data source - project: Feast project that this data source belongs to - allow_cache: Whether to allow returning this data source from a cached registry - - Returns: - Returns either the specified data source, or raises an exception if none is found - """ - raise NotImplementedError - - @abstractmethod - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - """ - Retrieve a list of data sources from the registry - - Args: - project: Filter data source based on project name - allow_cache: Whether to allow returning data sources from a cached registry - tags: Filter by tags - - Returns: - List of data sources - """ - raise NotImplementedError - - # Feature service operations - @abstractmethod - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - """ - Registers a single feature service with Feast - - Args: - feature_service: A feature service that will be registered - project: Feast project that this entity belongs to - """ - raise NotImplementedError - - @abstractmethod - def delete_feature_service(self, name: str, project: str, commit: bool = True): - """ - Deletes a feature service or raises an exception if not found. - - Args: - name: Name of feature service - project: Feast project that this feature service belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - """ - Retrieves a feature service. - - Args: - name: Name of feature service - project: Feast project that this feature service belongs to - allow_cache: Whether to allow returning this feature service from a cached registry - - Returns: - Returns either the specified feature service, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - """ - Retrieve a list of feature services from the registry - - Args: - allow_cache: Whether to allow returning entities from a cached registry - project: Filter entities based on project name - tags: Filter by tags - - Returns: - List of feature services - """ - raise NotImplementedError - - # Feature view operations - @abstractmethod - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - """ - Registers a single feature view with Feast - - Args: - feature_view: Feature view that will be registered - project: Feast project that this feature view belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def delete_feature_view(self, name: str, project: str, commit: bool = True): - """ - Deletes a feature view or raises an exception if not found. - - Args: - name: Name of feature view - project: Feast project that this feature view belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - # stream feature view operations - @abstractmethod - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> StreamFeatureView: - """ - Retrieves a stream feature view. 
- - Args: - name: Name of stream feature view - project: Feast project that this feature view belongs to - allow_cache: Allow returning feature view from the cached registry - - Returns: - Returns either the specified feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - """ - Retrieve a list of stream feature views from the registry - - Args: - project: Filter stream feature views based on project name - allow_cache: Whether to allow returning stream feature views from a cached registry - tags: Filter by tags - - Returns: - List of stream feature views - """ - raise NotImplementedError - - # on demand feature view operations - @abstractmethod - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: - """ - Retrieves an on demand feature view. - - Args: - name: Name of on demand feature view - project: Feast project that this on demand feature view belongs to - allow_cache: Whether to allow returning this on demand feature view from a cached registry - - Returns: - Returns either the specified on demand feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_on_demand_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - """ - Retrieve a list of on demand feature views from the registry - - Args: - project: Filter on demand feature views based on project name - allow_cache: Whether to allow returning on demand feature views from a cached registry - tags: Filter by tags - - Returns: - List of on demand feature views - """ - raise NotImplementedError - - # regular feature view operations - @abstractmethod - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - """ - Retrieves a feature view. - - Args: - name: Name of feature view - project: Feast project that this feature view belongs to - allow_cache: Allow returning feature view from the cached registry - - Returns: - Returns either the specified feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - """ - Retrieve a list of feature views from the registry - - Args: - allow_cache: Allow returning feature views from the cached registry - project: Filter feature views based on project name - tags: Filter by tags - - Returns: - List of feature views - """ - raise NotImplementedError - - @abstractmethod - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - """ - Retrieves a feature view of any type. 
- - Args: - name: Name of feature view - project: Feast project that this feature view belongs to - allow_cache: Allow returning feature view from the cached registry - - Returns: - Returns either the specified feature view, or raises an exception if - none is found - """ - raise NotImplementedError - - @abstractmethod - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - """ - Retrieve a list of feature views of all types from the registry - - Args: - allow_cache: Allow returning feature views from the cached registry - project: Filter feature views based on project name - tags: Filter by tags - - Returns: - List of feature views - """ - raise NotImplementedError - - @abstractmethod - def apply_materialization( - self, - feature_view: Union[FeatureView, OnDemandFeatureView], - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - """ - Updates materialization intervals tracked for a single feature view in Feast - - Args: - feature_view: Feature view that will be updated with an additional materialization interval tracked - project: Feast project that this feature view belongs to - start_date (datetime): Start date of the materialization interval to track - end_date (datetime): End date of the materialization interval to track - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - # Saved dataset operations - @abstractmethod - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - """ - Stores a saved dataset metadata with Feast - - Args: - saved_dataset: SavedDataset that will be added / updated to registry - project: Feast project that this dataset belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - """ - Retrieves a saved dataset. - - Args: - name: Name of dataset - project: Feast project that this dataset belongs to - allow_cache: Whether to allow returning this dataset from a cached registry - - Returns: - Returns either the specified SavedDataset, or raises an exception if - none is found - """ - raise NotImplementedError - - def delete_saved_dataset(self, name: str, project: str, commit: bool = True): - """ - Delete a saved dataset. 
- - Args: - name: Name of dataset - project: Feast project that this dataset belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - """ - Retrieves a list of all saved datasets in specified project - - Args: - project: Feast project - allow_cache: Whether to allow returning this dataset from a cached registry - tags: Filter by tags - - Returns: - Returns the list of SavedDatasets - """ - raise NotImplementedError - - # Validation reference operations - @abstractmethod - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - """ - Persist a validation reference - - Args: - validation_reference: ValidationReference that will be added / updated to registry - project: Feast project that this dataset belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - """ - Deletes a validation reference or raises an exception if not found. - - Args: - name: Name of validation reference - project: Feast project that this object belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - """ - Retrieves a validation reference. - - Args: - name: Name of dataset - project: Feast project that this dataset belongs to - allow_cache: Whether to allow returning this dataset from a cached registry - - Returns: - Returns either the specified ValidationReference, or raises an exception if - none is found - """ - raise NotImplementedError - - # TODO: Needs to be implemented. - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - """ - Retrieve a list of validation references from the registry - - Args: - project: Filter validation references based on project name - allow_cache: Allow returning validation references from the cached registry - tags: Filter by tags - - Returns: - List of request validation references - """ - raise NotImplementedError - - @abstractmethod - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - """ - Retrieves project metadata - - Args: - project: Filter metadata based on project name - allow_cache: Allow returning feature views from the cached registry - - Returns: - List of project metadata - """ - raise NotImplementedError - - @abstractmethod - def update_infra(self, infra: Infra, project: str, commit: bool = True): - """ - Updates the stored Infra object. - - Args: - infra: The new Infra object to be stored. - project: Feast project that the Infra object refers to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - """ - Retrieves the stored Infra object. - - Args: - project: Feast project that the Infra object refers to - allow_cache: Whether to allow returning this entity from a cached registry - - Returns: - The stored Infra object. 
- """ - raise NotImplementedError - - @abstractmethod - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): ... - - @abstractmethod - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: ... - - # Permission operations - @abstractmethod - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - """ - Registers a single permission with Feast - - Args: - permission: A permission that will be registered - project: Feast project that this permission belongs to - commit: Whether to immediately commit to the registry - """ - raise NotImplementedError - - @abstractmethod - def delete_permission(self, name: str, project: str, commit: bool = True): - """ - Deletes a permission or raises an exception if not found. - - Args: - name: Name of permission - project: Feast project that this permission belongs to - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - """ - Retrieves a permission. - - Args: - name: Name of permission - project: Feast project that this permission belongs to - allow_cache: Whether to allow returning this permission from a cached registry - - Returns: - Returns either the specified permission, or raises an exception if none is found - """ - raise NotImplementedError - - @abstractmethod - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - """ - Retrieve a list of permissions from the registry - - Args: - project: Filter permission based on project name - allow_cache: Whether to allow returning permissions from a cached registry - - Returns: - List of permissions - """ - raise NotImplementedError - - @abstractmethod - def apply_project( - self, - project: Project, - commit: bool = True, - ): - """ - Registers a project with Feast - - Args: - project: A project that will be registered - commit: Whether to immediately commit to the registry - """ - raise NotImplementedError - - @abstractmethod - def delete_project( - self, - name: str, - commit: bool = True, - ): - """ - Deletes a project or raises an ProjectNotFoundException exception if not found. - - Args: - project: Feast project name that needs to be deleted - commit: Whether the change should be persisted immediately - """ - raise NotImplementedError - - @abstractmethod - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - """ - Retrieves a project. - - Args: - name: Feast project name - allow_cache: Whether to allow returning this permission from a cached registry - - Returns: - Returns either the specified project, or raises ProjectObjectNotFoundException exception if none is found - """ - raise NotImplementedError - - @abstractmethod - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - """ - Retrieve a list of projects from the registry - - Args: - allow_cache: Whether to allow returning permissions from a cached registry - - Returns: - List of project - """ - raise NotImplementedError - - @abstractmethod - def proto(self) -> RegistryProto: - """ - Retrieves a proto version of the registry. - - Returns: - The registry proto object. 
- """ - raise NotImplementedError - - @abstractmethod - def commit(self): - """Commits the state of the registry cache to the remote registry store.""" - raise NotImplementedError - - @abstractmethod - def refresh(self, project: Optional[str] = None): - """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" - raise NotImplementedError - - @staticmethod - def _message_to_sorted_dict(message: Message) -> Dict[str, Any]: - return json.loads(MessageToJson(message, sort_keys=True)) - - def to_dict(self, project: str) -> Dict[str, List[Any]]: - """Returns a dictionary representation of the registry contents for the specified project. - - For each list in the dictionary, the elements are sorted by name, so this - method can be used to compare two registries. - - Args: - project: Feast project to convert to a dict - """ - registry_dict: Dict[str, Any] = defaultdict(list) - registry_dict["project"] = project - for project_metadata in sorted(self.list_project_metadata(project=project)): - registry_dict["projectMetadata"].append( - self._message_to_sorted_dict(project_metadata.to_proto()) - ) - for data_source in sorted( - self.list_data_sources(project=project), key=lambda ds: ds.name - ): - registry_dict["dataSources"].append( - self._message_to_sorted_dict(data_source.to_proto()) - ) - for entity in sorted( - self.list_entities(project=project), - key=lambda entity: entity.name, - ): - registry_dict["entities"].append( - self._message_to_sorted_dict(entity.to_proto()) - ) - for feature_view in sorted( - self.list_feature_views(project=project), - key=lambda feature_view: feature_view.name, - ): - registry_dict["featureViews"].append( - self._message_to_sorted_dict(feature_view.to_proto()) - ) - for feature_service in sorted( - self.list_feature_services(project=project), - key=lambda feature_service: feature_service.name, - ): - registry_dict["featureServices"].append( - self._message_to_sorted_dict(feature_service.to_proto()) - ) - for on_demand_feature_view in sorted( - self.list_on_demand_feature_views(project=project), - key=lambda on_demand_feature_view: on_demand_feature_view.name, - ): - odfv_dict = self._message_to_sorted_dict(on_demand_feature_view.to_proto()) - # We are logging a warning because the registry object may be read from a proto that is not updated - # i.e., we have to submit dual writes but in order to ensure the read behavior succeeds we have to load - # both objects to compare any changes in the registry - warnings.warn( - "We will be deprecating the usage of spec.userDefinedFunction in a future release please upgrade cautiously.", - DeprecationWarning, - ) - if on_demand_feature_view.feature_transformation: - if isinstance( - on_demand_feature_view.feature_transformation, PandasTransformation - ): - if "userDefinedFunction" not in odfv_dict["spec"]: - odfv_dict["spec"]["userDefinedFunction"] = {} - odfv_dict["spec"]["userDefinedFunction"]["body"] = ( - on_demand_feature_view.feature_transformation.udf_string - ) - odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][ - "body" - ] = on_demand_feature_view.feature_transformation.udf_string - elif isinstance( - on_demand_feature_view.feature_transformation, - SubstraitTransformation, - ): - odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ - "body" - ] = on_demand_feature_view.feature_transformation.substrait_plan - else: - odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][ - "body" - ] = None - 
odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ - "body" - ] = None - registry_dict["onDemandFeatureViews"].append(odfv_dict) - for stream_feature_view in sorted( - self.list_stream_feature_views(project=project), - key=lambda stream_feature_view: stream_feature_view.name, - ): - sfv_dict = self._message_to_sorted_dict(stream_feature_view.to_proto()) - - sfv_dict["spec"]["userDefinedFunction"]["body"] = ( - stream_feature_view.udf_string - ) - registry_dict["streamFeatureViews"].append(sfv_dict) - - for saved_dataset in sorted( - self.list_saved_datasets(project=project), key=lambda item: item.name - ): - registry_dict["savedDatasets"].append( - self._message_to_sorted_dict(saved_dataset.to_proto()) - ) - for infra_object in sorted(self.get_infra(project=project).infra_objects): - registry_dict["infra"].append( - self._message_to_sorted_dict(infra_object.to_proto()) - ) - for permission in sorted( - self.list_permissions(project=project), key=lambda ds: ds.name - ): - registry_dict["permissions"].append( - self._message_to_sorted_dict(permission.to_proto()) - ) - - return registry_dict - - @staticmethod - def deserialize_registry_values(serialized_proto, feast_obj_type) -> Any: - if feast_obj_type == Entity: - return EntityProto.FromString(serialized_proto) - if feast_obj_type == SavedDataset: - return SavedDatasetProto.FromString(serialized_proto) - if feast_obj_type == FeatureView: - return FeatureViewProto.FromString(serialized_proto) - if feast_obj_type == StreamFeatureView: - return StreamFeatureViewProto.FromString(serialized_proto) - if feast_obj_type == OnDemandFeatureView: - return OnDemandFeatureViewProto.FromString(serialized_proto) - if feast_obj_type == FeatureService: - return FeatureServiceProto.FromString(serialized_proto) - if feast_obj_type == Permission: - return PermissionProto.FromString(serialized_proto) - if feast_obj_type == Project: - return ProjectProto.FromString(serialized_proto) - return None +# Copyright 2019 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from google.protobuf.json_format import MessageToJson +from google.protobuf.message import Message + +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto +from feast.protos.feast.core.FeatureService_pb2 import ( + FeatureService as FeatureServiceProto, +) +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( + OnDemandFeatureView as OnDemandFeatureViewProto, +) +from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.transformation.pandas_transformation import PandasTransformation +from feast.transformation.substrait_transformation import SubstraitTransformation + + +class BaseRegistry(ABC): + """ + The interface that Feast uses to apply, list, retrieve, and delete Feast objects (e.g. entities, + feature views, and data sources). + """ + + # Entity operations + @abstractmethod + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + """ + Registers a single entity with Feast + + Args: + entity: Entity that will be registered + project: Feast project that this entity belongs to + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def delete_entity(self, name: str, project: str, commit: bool = True): + """ + Deletes an entity or raises an exception if not found. + + Args: + name: Name of entity + project: Feast project that this entity belongs to + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + """ + Retrieves an entity. 
+
+        Args:
+            name: Name of entity
+            project: Feast project that this entity belongs to
+            allow_cache: Whether to allow returning this entity from a cached registry
+
+        Returns:
+            Returns either the specified entity, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_entities(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Entity]:
+        """
+        Retrieve a list of entities from the registry
+
+        Args:
+            allow_cache: Whether to allow returning entities from a cached registry
+            project: Filter entities based on project name
+            tags: Filter by tags
+
+        Returns:
+            List of entities
+        """
+        raise NotImplementedError
+
+    # Data source operations
+    @abstractmethod
+    def apply_data_source(
+        self, data_source: DataSource, project: str, commit: bool = True
+    ):
+        """
+        Registers a single data source with Feast
+
+        Args:
+            data_source: A data source that will be registered
+            project: Feast project that this data source belongs to
+            commit: Whether to immediately commit to the registry
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_data_source(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a data source or raises an exception if not found.
+
+        Args:
+            name: Name of data source
+            project: Feast project that this data source belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_data_source(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> DataSource:
+        """
+        Retrieves a data source.
+
+        Args:
+            name: Name of data source
+            project: Feast project that this data source belongs to
+            allow_cache: Whether to allow returning this data source from a cached registry
+
+        Returns:
+            Returns either the specified data source, or raises an exception if none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_data_sources(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[DataSource]:
+        """
+        Retrieve a list of data sources from the registry
+
+        Args:
+            project: Filter data sources based on project name
+            allow_cache: Whether to allow returning data sources from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of data sources
+        """
+        raise NotImplementedError
+
+    # Feature service operations
+    @abstractmethod
+    def apply_feature_service(
+        self, feature_service: FeatureService, project: str, commit: bool = True
+    ):
+        """
+        Registers a single feature service with Feast
+
+        Args:
+            feature_service: A feature service that will be registered
+            project: Feast project that this feature service belongs to
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_feature_service(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a feature service or raises an exception if not found.
+
+        Args:
+            name: Name of feature service
+            project: Feast project that this feature service belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_feature_service(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> FeatureService:
+        """
+        Retrieves a feature service.
+
+        Args:
+            name: Name of feature service
+            project: Feast project that this feature service belongs to
+            allow_cache: Whether to allow returning this feature service from a cached registry
+
+        Returns:
+            Returns either the specified feature service, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_feature_services(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[FeatureService]:
+        """
+        Retrieve a list of feature services from the registry
+
+        Args:
+            allow_cache: Whether to allow returning feature services from a cached registry
+            project: Filter feature services based on project name
+            tags: Filter by tags
+
+        Returns:
+            List of feature services
+        """
+        raise NotImplementedError
+
+    # Feature view operations
+    @abstractmethod
+    def apply_feature_view(
+        self, feature_view: BaseFeatureView, project: str, commit: bool = True
+    ):
+        """
+        Registers a single feature view with Feast
+
+        Args:
+            feature_view: Feature view that will be registered
+            project: Feast project that this feature view belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_feature_view(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a feature view or raises an exception if not found.
+
+        Args:
+            name: Name of feature view
+            project: Feast project that this feature view belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    # stream feature view operations
+    @abstractmethod
+    def get_stream_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> StreamFeatureView:
+        """
+        Retrieves a stream feature view.
+
+        Args:
+            name: Name of stream feature view
+            project: Feast project that this feature view belongs to
+            allow_cache: Allow returning feature view from the cached registry
+
+        Returns:
+            Returns either the specified feature view, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_stream_feature_views(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[StreamFeatureView]:
+        """
+        Retrieve a list of stream feature views from the registry
+
+        Args:
+            project: Filter stream feature views based on project name
+            allow_cache: Whether to allow returning stream feature views from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of stream feature views
+        """
+        raise NotImplementedError
+
+    # on demand feature view operations
+    @abstractmethod
+    def get_on_demand_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> OnDemandFeatureView:
+        """
+        Retrieves an on demand feature view.
+ + Args: + name: Name of on demand feature view + project: Feast project that this on demand feature view belongs to + allow_cache: Whether to allow returning this on demand feature view from a cached registry + + Returns: + Returns either the specified on demand feature view, or raises an exception if + none is found + """ + raise NotImplementedError + + @abstractmethod + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + """ + Retrieve a list of on demand feature views from the registry + + Args: + project: Filter on demand feature views based on project name + allow_cache: Whether to allow returning on demand feature views from a cached registry + tags: Filter by tags + + Returns: + List of on demand feature views + """ + raise NotImplementedError + + # regular feature view operations + @abstractmethod + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + """ + Retrieves a feature view. + + Args: + name: Name of feature view + project: Feast project that this feature view belongs to + allow_cache: Allow returning feature view from the cached registry + + Returns: + Returns either the specified feature view, or raises an exception if + none is found + """ + raise NotImplementedError + + @abstractmethod + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + """ + Retrieve a list of feature views from the registry + + Args: + allow_cache: Allow returning feature views from the cached registry + project: Filter feature views based on project name + tags: Filter by tags + + Returns: + List of feature views + """ + raise NotImplementedError + + @abstractmethod + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + """ + Retrieves a feature view of any type. 
+ + Args: + name: Name of feature view + project: Feast project that this feature view belongs to + allow_cache: Allow returning feature view from the cached registry + + Returns: + Returns either the specified feature view, or raises an exception if + none is found + """ + raise NotImplementedError + + @abstractmethod + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + """ + Retrieve a list of feature views of all types from the registry + + Args: + allow_cache: Allow returning feature views from the cached registry + project: Filter feature views based on project name + tags: Filter by tags + + Returns: + List of feature views + """ + raise NotImplementedError + + @abstractmethod + def apply_materialization( + self, + feature_view: FeatureView, + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + """ + Updates materialization intervals tracked for a single feature view in Feast + + Args: + feature_view: Feature view that will be updated with an additional materialization interval tracked + project: Feast project that this feature view belongs to + start_date (datetime): Start date of the materialization interval to track + end_date (datetime): End date of the materialization interval to track + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + # Saved dataset operations + @abstractmethod + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + """ + Stores a saved dataset metadata with Feast + + Args: + saved_dataset: SavedDataset that will be added / updated to registry + project: Feast project that this dataset belongs to + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + """ + Retrieves a saved dataset. + + Args: + name: Name of dataset + project: Feast project that this dataset belongs to + allow_cache: Whether to allow returning this dataset from a cached registry + + Returns: + Returns either the specified SavedDataset, or raises an exception if + none is found + """ + raise NotImplementedError + + def delete_saved_dataset(self, name: str, project: str, commit: bool = True): + """ + Delete a saved dataset. 
+
+        Args:
+            name: Name of dataset
+            project: Feast project that this dataset belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_saved_datasets(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[SavedDataset]:
+        """
+        Retrieves a list of all saved datasets in the specified project
+
+        Args:
+            project: Feast project
+            allow_cache: Whether to allow returning this dataset from a cached registry
+            tags: Filter by tags
+
+        Returns:
+            Returns the list of SavedDatasets
+        """
+        raise NotImplementedError
+
+    # Validation reference operations
+    @abstractmethod
+    def apply_validation_reference(
+        self,
+        validation_reference: ValidationReference,
+        project: str,
+        commit: bool = True,
+    ):
+        """
+        Persist a validation reference
+
+        Args:
+            validation_reference: ValidationReference that will be added / updated to registry
+            project: Feast project that this validation reference belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_validation_reference(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a validation reference or raises an exception if not found.
+
+        Args:
+            name: Name of validation reference
+            project: Feast project that this object belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_validation_reference(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> ValidationReference:
+        """
+        Retrieves a validation reference.
+
+        Args:
+            name: Name of validation reference
+            project: Feast project that this validation reference belongs to
+            allow_cache: Whether to allow returning this validation reference from a cached registry
+
+        Returns:
+            Returns either the specified ValidationReference, or raises an exception if
+            none is found
+        """
+        raise NotImplementedError
+
+    # TODO: Needs to be implemented.
+    def list_validation_references(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[ValidationReference]:
+        """
+        Retrieve a list of validation references from the registry
+
+        Args:
+            project: Filter validation references based on project name
+            allow_cache: Allow returning validation references from the cached registry
+            tags: Filter by tags
+
+        Returns:
+            List of validation references
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_project_metadata(
+        self, project: str, allow_cache: bool = False
+    ) -> List[ProjectMetadata]:
+        """
+        Retrieves project metadata
+
+        Args:
+            project: Filter metadata based on project name
+            allow_cache: Allow returning project metadata from the cached registry
+
+        Returns:
+            List of project metadata
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def update_infra(self, infra: Infra, project: str, commit: bool = True):
+        """
+        Updates the stored Infra object.
+
+        Args:
+            infra: The new Infra object to be stored.
+            project: Feast project that the Infra object refers to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
+        """
+        Retrieves the stored Infra object.
+
+        Args:
+            project: Feast project that the Infra object refers to
+            allow_cache: Whether to allow returning the Infra object from a cached registry
+
+        Returns:
+            The stored Infra object.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_user_metadata(
+        self,
+        project: str,
+        feature_view: BaseFeatureView,
+        metadata_bytes: Optional[bytes],
+    ): ...
+
+    @abstractmethod
+    def get_user_metadata(
+        self, project: str, feature_view: BaseFeatureView
+    ) -> Optional[bytes]: ...
+
+    # Permission operations
+    @abstractmethod
+    def apply_permission(
+        self, permission: Permission, project: str, commit: bool = True
+    ):
+        """
+        Registers a single permission with Feast
+
+        Args:
+            permission: A permission that will be registered
+            project: Feast project that this permission belongs to
+            commit: Whether to immediately commit to the registry
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_permission(self, name: str, project: str, commit: bool = True):
+        """
+        Deletes a permission or raises an exception if not found.
+
+        Args:
+            name: Name of permission
+            project: Feast project that this permission belongs to
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_permission(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> Permission:
+        """
+        Retrieves a permission.
+
+        Args:
+            name: Name of permission
+            project: Feast project that this permission belongs to
+            allow_cache: Whether to allow returning this permission from a cached registry
+
+        Returns:
+            Returns either the specified permission, or raises an exception if none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_permissions(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Permission]:
+        """
+        Retrieve a list of permissions from the registry
+
+        Args:
+            project: Filter permissions based on project name
+            allow_cache: Whether to allow returning permissions from a cached registry
+
+        Returns:
+            List of permissions
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_project(
+        self,
+        project: Project,
+        commit: bool = True,
+    ):
+        """
+        Registers a project with Feast
+
+        Args:
+            project: A project that will be registered
+            commit: Whether to immediately commit to the registry
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_project(
+        self,
+        name: str,
+        commit: bool = True,
+    ):
+        """
+        Deletes a project or raises a ProjectNotFoundException if not found.
+
+        Args:
+            name: Name of the Feast project to be deleted
+            commit: Whether the change should be persisted immediately
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_project(
+        self,
+        name: str,
+        allow_cache: bool = False,
+    ) -> Project:
+        """
+        Retrieves a project.
+
+        Args:
+            name: Feast project name
+            allow_cache: Whether to allow returning this project from a cached registry
+
+        Returns:
+            Returns either the specified project, or raises a ProjectObjectNotFoundException if none is found
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_projects(
+        self,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Project]:
+        """
+        Retrieve a list of projects from the registry
+
+        Args:
+            allow_cache: Whether to allow returning projects from a cached registry
+
+        Returns:
+            List of projects
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def proto(self) -> RegistryProto:
+        """
+        Retrieves a proto version of the registry.
+
+        Returns:
+            The registry proto object.
+ """ + raise NotImplementedError + + @abstractmethod + def commit(self): + """Commits the state of the registry cache to the remote registry store.""" + raise NotImplementedError + + @abstractmethod + def refresh(self, project: Optional[str] = None): + """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" + raise NotImplementedError + + @staticmethod + def _message_to_sorted_dict(message: Message) -> Dict[str, Any]: + return json.loads(MessageToJson(message, sort_keys=True)) + + def to_dict(self, project: str) -> Dict[str, List[Any]]: + """Returns a dictionary representation of the registry contents for the specified project. + + For each list in the dictionary, the elements are sorted by name, so this + method can be used to compare two registries. + + Args: + project: Feast project to convert to a dict + """ + registry_dict: Dict[str, Any] = defaultdict(list) + registry_dict["project"] = project + for project_metadata in sorted(self.list_project_metadata(project=project)): + registry_dict["projectMetadata"].append( + self._message_to_sorted_dict(project_metadata.to_proto()) + ) + for data_source in sorted( + self.list_data_sources(project=project), key=lambda ds: ds.name + ): + registry_dict["dataSources"].append( + self._message_to_sorted_dict(data_source.to_proto()) + ) + for entity in sorted( + self.list_entities(project=project), + key=lambda entity: entity.name, + ): + registry_dict["entities"].append( + self._message_to_sorted_dict(entity.to_proto()) + ) + for feature_view in sorted( + self.list_feature_views(project=project), + key=lambda feature_view: feature_view.name, + ): + registry_dict["featureViews"].append( + self._message_to_sorted_dict(feature_view.to_proto()) + ) + for feature_service in sorted( + self.list_feature_services(project=project), + key=lambda feature_service: feature_service.name, + ): + registry_dict["featureServices"].append( + self._message_to_sorted_dict(feature_service.to_proto()) + ) + for on_demand_feature_view in sorted( + self.list_on_demand_feature_views(project=project), + key=lambda on_demand_feature_view: on_demand_feature_view.name, + ): + odfv_dict = self._message_to_sorted_dict(on_demand_feature_view.to_proto()) + # We are logging a warning because the registry object may be read from a proto that is not updated + # i.e., we have to submit dual writes but in order to ensure the read behavior succeeds we have to load + # both objects to compare any changes in the registry + warnings.warn( + "We will be deprecating the usage of spec.userDefinedFunction in a future release please upgrade cautiously.", + DeprecationWarning, + ) + if on_demand_feature_view.feature_transformation: + if isinstance( + on_demand_feature_view.feature_transformation, PandasTransformation + ): + if "userDefinedFunction" not in odfv_dict["spec"]: + odfv_dict["spec"]["userDefinedFunction"] = {} + odfv_dict["spec"]["userDefinedFunction"]["body"] = ( + on_demand_feature_view.feature_transformation.udf_string + ) + odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][ + "body" + ] = on_demand_feature_view.feature_transformation.udf_string + elif isinstance( + on_demand_feature_view.feature_transformation, + SubstraitTransformation, + ): + odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ + "body" + ] = on_demand_feature_view.feature_transformation.substrait_plan + else: + odfv_dict["spec"]["featureTransformation"]["userDefinedFunction"][ + "body" + ] = None + 
odfv_dict["spec"]["featureTransformation"]["substraitPlan"][ + "body" + ] = None + registry_dict["onDemandFeatureViews"].append(odfv_dict) + for stream_feature_view in sorted( + self.list_stream_feature_views(project=project), + key=lambda stream_feature_view: stream_feature_view.name, + ): + sfv_dict = self._message_to_sorted_dict(stream_feature_view.to_proto()) + + sfv_dict["spec"]["userDefinedFunction"]["body"] = ( + stream_feature_view.udf_string + ) + registry_dict["streamFeatureViews"].append(sfv_dict) + + for saved_dataset in sorted( + self.list_saved_datasets(project=project), key=lambda item: item.name + ): + registry_dict["savedDatasets"].append( + self._message_to_sorted_dict(saved_dataset.to_proto()) + ) + for infra_object in sorted(self.get_infra(project=project).infra_objects): + registry_dict["infra"].append( + self._message_to_sorted_dict(infra_object.to_proto()) + ) + for permission in sorted( + self.list_permissions(project=project), key=lambda ds: ds.name + ): + registry_dict["permissions"].append( + self._message_to_sorted_dict(permission.to_proto()) + ) + + return registry_dict + + @staticmethod + def deserialize_registry_values(serialized_proto, feast_obj_type) -> Any: + if feast_obj_type == Entity: + return EntityProto.FromString(serialized_proto) + if feast_obj_type == SavedDataset: + return SavedDatasetProto.FromString(serialized_proto) + if feast_obj_type == FeatureView: + return FeatureViewProto.FromString(serialized_proto) + if feast_obj_type == StreamFeatureView: + return StreamFeatureViewProto.FromString(serialized_proto) + if feast_obj_type == OnDemandFeatureView: + return OnDemandFeatureViewProto.FromString(serialized_proto) + if feast_obj_type == FeatureService: + return FeatureServiceProto.FromString(serialized_proto) + if feast_obj_type == Permission: + return PermissionProto.FromString(serialized_proto) + if feast_obj_type == Project: + return ProjectProto.FromString(serialized_proto) + return None From e78494e936c8402e7b04a1641e19d03b0d3e03ee Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 19 Jun 2025 21:34:12 -0400 Subject: [PATCH 5/7] reducing changes Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/infra/registry/base_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index f2374edf1b2..85810f1fbc1 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from google.protobuf.json_format import MessageToJson from google.protobuf.message import Message @@ -432,7 +432,7 @@ def list_all_feature_views( @abstractmethod def apply_materialization( self, - feature_view: FeatureView, + feature_view: Union[FeatureView, OnDemandFeatureView], project: str, start_date: datetime, end_date: datetime, From 05d7804494feeb2583c8334f7f8368e6c6d3f2a3 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 19 Jun 2025 21:37:38 -0400 Subject: [PATCH 6/7] removing line changes Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/infra/registry/registry.py | 2180 +++++++++---------- 1 file changed, 1090 insertions(+), 1090 deletions(-) diff --git a/sdk/python/feast/infra/registry/registry.py 
b/sdk/python/feast/infra/registry/registry.py index acb82546b4f..0cfbc77b24e 100644 --- a/sdk/python/feast/infra/registry/registry.py +++ b/sdk/python/feast/infra/registry/registry.py @@ -1,1090 +1,1090 @@ -# Copyright 2019 The Feast Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from datetime import datetime, timedelta, timezone -from enum import Enum -from pathlib import Path -from threading import Lock -from typing import Any, Dict, List, Optional, Union -from urllib.parse import urlparse - -from google.protobuf.internal.containers import RepeatedCompositeFieldContainer -from google.protobuf.message import Message - -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.errors import ( - ConflictingFeatureViewNames, - DataSourceNotFoundException, - EntityNotFoundException, - FeatureServiceNotFoundException, - FeatureViewNotFoundException, - PermissionNotFoundException, - ProjectNotFoundException, - ProjectObjectNotFoundException, - ValidationReferenceNotFound, -) -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.importer import import_class -from feast.infra.infra_object import Infra -from feast.infra.registry import proto_registry_utils -from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.registry.registry_store import NoopRegistryStore -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.auth_model import AuthConfig, NoAuthConfig -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.repo_config import RegistryConfig -from feast.repo_contents import RepoContents -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.utils import _utc_now - -REGISTRY_SCHEMA_VERSION = "1" - -REGISTRY_STORE_CLASS_FOR_TYPE = { - "GCSRegistryStore": "feast.infra.registry.gcs.GCSRegistryStore", - "S3RegistryStore": "feast.infra.registry.s3.S3RegistryStore", - "FileRegistryStore": "feast.infra.registry.file.FileRegistryStore", - "AzureRegistryStore": "feast.infra.registry.contrib.azure.azure_registry_store.AzBlobRegistryStore", -} - -REGISTRY_STORE_CLASS_FOR_SCHEME = { - "gs": "GCSRegistryStore", - "s3": "S3RegistryStore", - "file": "FileRegistryStore", - "": "FileRegistryStore", -} - - -class FeastObjectType(Enum): - PROJECT = "project" - DATA_SOURCE = "data source" - ENTITY = "entity" - FEATURE_VIEW = "feature view" - ON_DEMAND_FEATURE_VIEW = "on demand feature view" - STREAM_FEATURE_VIEW = "stream feature view" - FEATURE_SERVICE = "feature service" - PERMISSION = "permission" - - @staticmethod - def get_objects_from_registry( - registry: "BaseRegistry", project: str - ) -> Dict["FeastObjectType", List[Any]]: - return { - 
FeastObjectType.PROJECT: [ - project_obj - for project_obj in registry.list_projects() - if project_obj.name == project - ], - FeastObjectType.DATA_SOURCE: registry.list_data_sources(project=project), - FeastObjectType.ENTITY: registry.list_entities(project=project), - FeastObjectType.FEATURE_VIEW: registry.list_feature_views(project=project), - FeastObjectType.ON_DEMAND_FEATURE_VIEW: registry.list_on_demand_feature_views( - project=project - ), - FeastObjectType.STREAM_FEATURE_VIEW: registry.list_stream_feature_views( - project=project, - ), - FeastObjectType.FEATURE_SERVICE: registry.list_feature_services( - project=project - ), - FeastObjectType.PERMISSION: registry.list_permissions(project=project), - } - - @staticmethod - def get_objects_from_repo_contents( - repo_contents: RepoContents, - ) -> Dict["FeastObjectType", List[Any]]: - return { - FeastObjectType.PROJECT: repo_contents.projects, - FeastObjectType.DATA_SOURCE: repo_contents.data_sources, - FeastObjectType.ENTITY: repo_contents.entities, - FeastObjectType.FEATURE_VIEW: repo_contents.feature_views, - FeastObjectType.ON_DEMAND_FEATURE_VIEW: repo_contents.on_demand_feature_views, - FeastObjectType.STREAM_FEATURE_VIEW: repo_contents.stream_feature_views, - FeastObjectType.FEATURE_SERVICE: repo_contents.feature_services, - FeastObjectType.PERMISSION: repo_contents.permissions, - } - - -FEAST_OBJECT_TYPES = [feast_object_type for feast_object_type in FeastObjectType] - -logger = logging.getLogger(__name__) - - -def get_registry_store_class_from_type(registry_store_type: str): - if not registry_store_type.endswith("RegistryStore"): - raise Exception('Registry store class name should end with "RegistryStore"') - if registry_store_type in REGISTRY_STORE_CLASS_FOR_TYPE: - registry_store_type = REGISTRY_STORE_CLASS_FOR_TYPE[registry_store_type] - module_name, registry_store_class_name = registry_store_type.rsplit(".", 1) - - return import_class(module_name, registry_store_class_name, "RegistryStore") - - -def get_registry_store_class_from_scheme(registry_path: str): - uri = urlparse(registry_path) - if uri.scheme not in REGISTRY_STORE_CLASS_FOR_SCHEME: - raise Exception( - f"Registry path {registry_path} has unsupported scheme {uri.scheme}. " - f"Supported schemes are file, s3 and gs." - ) - else: - registry_store_type = REGISTRY_STORE_CLASS_FOR_SCHEME[uri.scheme] - return get_registry_store_class_from_type(registry_store_type) - - -class Registry(BaseRegistry): - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - pass - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - pass - - # The cached_registry_proto object is used for both reads and writes. In particular, - # all write operations refresh the cache and modify it in memory; the write must - # then be persisted to the underlying RegistryStore with a call to commit(). - cached_registry_proto: RegistryProto - cached_registry_proto_created: datetime - cached_registry_proto_ttl: timedelta - - def __init__( - self, - project: str, - registry_config: Optional[RegistryConfig], - repo_path: Optional[Path], - auth_config: AuthConfig = NoAuthConfig(), - ): - """ - Create the Registry object. - - Args: - registry_config: RegistryConfig object containing the destination path and cache ttl, - repo_path: Path to the base of the Feast repository - or where it will be created if it does not exist yet. 
- """ - - self._refresh_lock = Lock() - self._auth_config = auth_config - - registry_proto = RegistryProto() - registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - - self.purge_feast_metadata = ( - registry_config.purge_feast_metadata - if registry_config is not None - else False - ) - - if registry_config: - registry_store_type = registry_config.registry_store_type - registry_path = registry_config.path - if registry_store_type is None: - cls = get_registry_store_class_from_scheme(registry_path) - else: - cls = get_registry_store_class_from_type(str(registry_store_type)) - - self._registry_store = cls(registry_config, repo_path) - self.cached_registry_proto_ttl = timedelta( - seconds=( - registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 - ) - ) - - try: - registry_proto = self._registry_store.get_registry_proto() - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - # Sync feast_metadata to projects table - # when purge_feast_metadata is set to True, Delete data from - # feast_metadata table and list_project_metadata will not return any data - self._sync_feast_metadata_to_projects_table() - except FileNotFoundError: - logger.info("Registry file not found. Creating new registry.") - self.commit() - - def _sync_feast_metadata_to_projects_table(self): - """ - Sync feast_metadata to projects table - """ - feast_metadata_projects = [] - projects_set = [] - # List of project in project_metadata - for project_metadata in self.cached_registry_proto.project_metadata: - project = ProjectMetadata.from_proto(project_metadata) - feast_metadata_projects.append(project.project_name) - if len(feast_metadata_projects) > 0: - # List of project in projects - for project_metadata in self.cached_registry_proto.projects: - project = Project.from_proto(project_metadata) - projects_set.append(project.name) - - # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects) - set(projects_set) - # Sync feast_metadata to projects table - for project_name in projects_to_sync: - project = Project(name=project_name) - self.cached_registry_proto.projects.append(project.to_proto()) - - if self.purge_feast_metadata: - self.cached_registry_proto.project_metadata = [] - - def clone(self) -> "Registry": - new_registry = Registry("project", None, None, self._auth_config) - new_registry.cached_registry_proto_ttl = timedelta(seconds=0) - new_registry.cached_registry_proto = ( - self.cached_registry_proto.__deepcopy__() - if self.cached_registry_proto - else RegistryProto() - ) - new_registry.cached_registry_proto_created = _utc_now() - new_registry._registry_store = NoopRegistryStore() - return new_registry - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - self.cached_registry_proto.infra.CopyFrom(infra.to_proto()) - if commit: - self.commit() - - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return Infra.from_proto(registry_proto.infra) - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - entity.is_valid() - - now = _utc_now() - if not entity.created_timestamp: - entity.created_timestamp = now - entity.last_updated_timestamp = 
now - - entity_proto = entity.to_proto() - entity_proto.spec.project = project - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_entity_proto in enumerate( - self.cached_registry_proto.entities - ): - if ( - existing_entity_proto.spec.name == entity_proto.spec.name - and existing_entity_proto.spec.project == project - ): - entity.created_timestamp = ( - existing_entity_proto.meta.created_timestamp.ToDatetime() - ) - entity_proto = entity.to_proto() - entity_proto.spec.project = project - del self.cached_registry_proto.entities[idx] - break - self.cached_registry_proto.entities.append(entity_proto) - if commit: - self.commit() - - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_entities(registry_proto, project, tags) - - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_data_sources(registry_proto, project, tags) - - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - registry = self._prepare_registry_for_changes(project) - for idx, existing_data_source_proto in enumerate(registry.data_sources): - if existing_data_source_proto.name == data_source.name: - del registry.data_sources[idx] - data_source_proto = data_source.to_proto() - data_source_proto.project = project - data_source_proto.data_source_class_type = ( - f"{data_source.__class__.__module__}.{data_source.__class__.__name__}" - ) - self.cached_registry_proto.data_sources.append(data_source_proto) - if commit: - self.commit() - - def delete_data_source(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, data_source_proto in enumerate( - self.cached_registry_proto.data_sources - ): - if data_source_proto.name == name: - del self.cached_registry_proto.data_sources[idx] - if commit: - self.commit() - return - raise DataSourceNotFoundException(name) - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - now = _utc_now() - if not feature_service.created_timestamp: - feature_service.created_timestamp = now - feature_service.last_updated_timestamp = now - - feature_service_proto = feature_service.to_proto() - feature_service_proto.spec.project = project - - registry = self._prepare_registry_for_changes(project) - - for idx, existing_feature_service_proto in enumerate(registry.feature_services): - if ( - existing_feature_service_proto.spec.name - == feature_service_proto.spec.name - and existing_feature_service_proto.spec.project == project - ): - feature_service.created_timestamp = ( - existing_feature_service_proto.meta.created_timestamp.ToDatetime() - ) - feature_service_proto = feature_service.to_proto() - feature_service_proto.spec.project = project - del registry.feature_services[idx] - self.cached_registry_proto.feature_services.append(feature_service_proto) - if commit: - self.commit() - - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - registry_proto = 
self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_feature_services(registry_proto, project, tags) - - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_feature_service(registry_proto, name, project) - - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_entity(registry_proto, name, project) - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - feature_view.ensure_valid() - - now = _utc_now() - if not feature_view.created_timestamp: - feature_view.created_timestamp = now - feature_view.last_updated_timestamp = now - - feature_view_proto = feature_view.to_proto() - feature_view_proto.spec.project = project - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - self._check_conflicting_feature_view_names(feature_view) - existing_feature_views_of_same_type: RepeatedCompositeFieldContainer - if isinstance(feature_view, StreamFeatureView): - existing_feature_views_of_same_type = ( - self.cached_registry_proto.stream_feature_views - ) - elif isinstance(feature_view, FeatureView): - existing_feature_views_of_same_type = ( - self.cached_registry_proto.feature_views - ) - elif isinstance(feature_view, OnDemandFeatureView): - existing_feature_views_of_same_type = ( - self.cached_registry_proto.on_demand_feature_views - ) - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - - for idx, existing_feature_view_proto in enumerate( - existing_feature_views_of_same_type - ): - if ( - existing_feature_view_proto.spec.name == feature_view_proto.spec.name - and existing_feature_view_proto.spec.project == project - ): - if ( - feature_view.__class__.from_proto(existing_feature_view_proto) - == feature_view - ): - return - else: - existing_feature_view = type(feature_view).from_proto( - existing_feature_view_proto - ) - feature_view.created_timestamp = ( - existing_feature_view.created_timestamp - ) - if isinstance(feature_view, (FeatureView, StreamFeatureView)): - feature_view.update_materialization_intervals( - existing_feature_view.materialization_intervals - ) - feature_view_proto = feature_view.to_proto() - feature_view_proto.spec.project = project - del existing_feature_views_of_same_type[idx] - break - - existing_feature_views_of_same_type.append(feature_view_proto) - if commit: - self.commit() - - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_stream_feature_views( - registry_proto, project, tags - ) - - def list_on_demand_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_on_demand_feature_views( - registry_proto, project, tags - ) - - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: 
- registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_on_demand_feature_view( - registry_proto, name, project - ) - - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_data_source(registry_proto, name, project) - - def apply_materialization( - self, - feature_view: Union[FeatureView, OnDemandFeatureView], - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_feature_view_proto in enumerate( - self.cached_registry_proto.feature_views - ): - if ( - existing_feature_view_proto.spec.name == feature_view.name - and existing_feature_view_proto.spec.project == project - ): - existing_feature_view = FeatureView.from_proto( - existing_feature_view_proto - ) - existing_feature_view.materialization_intervals.append( - (start_date, end_date) - ) - existing_feature_view.last_updated_timestamp = _utc_now() - feature_view_proto = existing_feature_view.to_proto() - feature_view_proto.spec.project = project - del self.cached_registry_proto.feature_views[idx] - self.cached_registry_proto.feature_views.append(feature_view_proto) - if commit: - self.commit() - return - - for idx, existing_stream_feature_view_proto in enumerate( - self.cached_registry_proto.stream_feature_views - ): - if ( - existing_stream_feature_view_proto.spec.name == feature_view.name - and existing_stream_feature_view_proto.spec.project == project - ): - existing_stream_feature_view = StreamFeatureView.from_proto( - existing_stream_feature_view_proto - ) - existing_stream_feature_view.materialization_intervals.append( - (start_date, end_date) - ) - existing_stream_feature_view.last_updated_timestamp = _utc_now() - stream_feature_view_proto = existing_stream_feature_view.to_proto() - stream_feature_view_proto.spec.project = project - del self.cached_registry_proto.stream_feature_views[idx] - self.cached_registry_proto.stream_feature_views.append( - stream_feature_view_proto - ) - if commit: - self.commit() - return - - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_all_feature_views( - registry_proto, project, tags - ) - - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_any_feature_view(registry_proto, name, project) - - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_feature_views(registry_proto, project, tags) - - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_feature_view(registry_proto, name, project) - - def get_stream_feature_view( - self, name: 
str, project: str, allow_cache: bool = False - ) -> StreamFeatureView: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_stream_feature_view( - registry_proto, name, project - ) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, feature_service_proto in enumerate( - self.cached_registry_proto.feature_services - ): - if ( - feature_service_proto.spec.name == name - and feature_service_proto.spec.project == project - ): - del self.cached_registry_proto.feature_services[idx] - if commit: - self.commit() - return - raise FeatureServiceNotFoundException(name, project) - - def delete_feature_view(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_feature_view_proto in enumerate( - self.cached_registry_proto.feature_views - ): - if ( - existing_feature_view_proto.spec.name == name - and existing_feature_view_proto.spec.project == project - ): - del self.cached_registry_proto.feature_views[idx] - if commit: - self.commit() - return - - for idx, existing_on_demand_feature_view_proto in enumerate( - self.cached_registry_proto.on_demand_feature_views - ): - if ( - existing_on_demand_feature_view_proto.spec.name == name - and existing_on_demand_feature_view_proto.spec.project == project - ): - del self.cached_registry_proto.on_demand_feature_views[idx] - if commit: - self.commit() - return - - for idx, existing_stream_feature_view_proto in enumerate( - self.cached_registry_proto.stream_feature_views - ): - if ( - existing_stream_feature_view_proto.spec.name == name - and existing_stream_feature_view_proto.spec.project == project - ): - del self.cached_registry_proto.stream_feature_views[idx] - if commit: - self.commit() - return - - raise FeatureViewNotFoundException(name, project) - - def delete_entity(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_entity_proto in enumerate( - self.cached_registry_proto.entities - ): - if ( - existing_entity_proto.spec.name == name - and existing_entity_proto.spec.project == project - ): - del self.cached_registry_proto.entities[idx] - if commit: - self.commit() - return - - raise EntityNotFoundException(name, project) - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - now = _utc_now() - if not saved_dataset.created_timestamp: - saved_dataset.created_timestamp = now - saved_dataset.last_updated_timestamp = now - - saved_dataset_proto = saved_dataset.to_proto() - saved_dataset_proto.spec.project = project - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, existing_saved_dataset_proto in enumerate( - self.cached_registry_proto.saved_datasets - ): - if ( - existing_saved_dataset_proto.spec.name == saved_dataset_proto.spec.name - and existing_saved_dataset_proto.spec.project == project - ): - saved_dataset.created_timestamp = ( - existing_saved_dataset_proto.meta.created_timestamp.ToDatetime() - ) - saved_dataset.min_event_timestamp = ( - existing_saved_dataset_proto.meta.min_event_timestamp.ToDatetime() - ) - saved_dataset.max_event_timestamp = ( - existing_saved_dataset_proto.meta.max_event_timestamp.ToDatetime() - ) - saved_dataset_proto = 
saved_dataset.to_proto() - saved_dataset_proto.spec.project = project - del self.cached_registry_proto.saved_datasets[idx] - break - - self.cached_registry_proto.saved_datasets.append(saved_dataset_proto) - if commit: - self.commit() - - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_saved_dataset(registry_proto, name, project) - - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_saved_datasets(registry_proto, project, tags) - - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - validation_reference_proto = validation_reference.to_proto() - validation_reference_proto.project = project - - registry_proto = self._prepare_registry_for_changes(project) - for idx, existing_validation_reference in enumerate( - registry_proto.validation_references - ): - if ( - existing_validation_reference.name == validation_reference_proto.name - and existing_validation_reference.project == project - ): - del registry_proto.validation_references[idx] - break - - registry_proto.validation_references.append(validation_reference_proto) - if commit: - self.commit() - - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_validation_reference( - registry_proto, name, project - ) - - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_validation_references( - registry_proto, project, tags - ) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - for idx, existing_validation_reference in enumerate( - self.cached_registry_proto.validation_references - ): - if ( - existing_validation_reference.name == name - and existing_validation_reference.project == project - ): - del self.cached_registry_proto.validation_references[idx] - if commit: - self.commit() - return - raise ValidationReferenceNotFound(name, project=project) - - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.list_project_metadata(registry_proto, project) - - def commit(self): - """Commits the state of the registry cache to the remote registry store.""" - if self.cached_registry_proto: - self._registry_store.update_registry_proto(self.cached_registry_proto) - - def refresh(self, project: Optional[str] = None): - """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" - self._get_registry_proto(project=project, allow_cache=False) - - def teardown(self): - """Tears down (removes) the registry.""" - 
self._registry_store.teardown() - - def proto(self) -> RegistryProto: - return self.cached_registry_proto or RegistryProto() - - def _prepare_registry_for_changes(self, project_name: str): - """Prepares the Registry for changes by refreshing the cache if necessary.""" - - assert self.cached_registry_proto is not None - - try: - # Check if the project exists in the registry cache - self.get_project(name=project_name, allow_cache=True) - return self.cached_registry_proto - except ProjectObjectNotFoundException: - # If the project does not exist in cache, refresh cache from store - registry_proto = self._registry_store.get_registry_proto() - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - - try: - # Check if the project exists in the registry cache after refresh from store - self.get_project(name=project_name) - except ProjectObjectNotFoundException: - # If the project still does not exist, create it - project_proto = Project(name=project_name).to_proto() - self.cached_registry_proto.projects.append(project_proto) - if not self.purge_feast_metadata: - project_metadata_proto = ProjectMetadata( - project_name=project_name - ).to_proto() - self.cached_registry_proto.project_metadata.append( - project_metadata_proto - ) - self.commit() - return self.cached_registry_proto - - def _get_registry_proto( - self, project: Optional[str], allow_cache: bool = False - ) -> RegistryProto: - """Returns the cached or remote registry state - - Args: - project: Name of the Feast project (optional) - allow_cache: Whether to allow the use of the registry cache when fetching the RegistryProto - - Returns: Returns a RegistryProto object which represents the state of the registry - """ - with self._refresh_lock: - expired = (self.cached_registry_proto_created is None) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl - ) - ) - ) - - if allow_cache and not expired: - return self.cached_registry_proto - logger.info("Registry cache expired, so refreshing") - registry_proto = self._registry_store.get_registry_proto() - self.cached_registry_proto = registry_proto - self.cached_registry_proto_created = _utc_now() - return registry_proto - - def _check_conflicting_feature_view_names(self, feature_view: BaseFeatureView): - name_to_fv_protos = self._existing_feature_view_names_to_fvs() - if feature_view.name in name_to_fv_protos: - if not isinstance( - name_to_fv_protos.get(feature_view.name), feature_view.proto_class - ): - raise ConflictingFeatureViewNames(feature_view.name) - - def _existing_feature_view_names_to_fvs(self) -> Dict[str, Message]: - assert self.cached_registry_proto - odfvs = { - fv.spec.name: fv - for fv in self.cached_registry_proto.on_demand_feature_views - } - fvs = {fv.spec.name: fv for fv in self.cached_registry_proto.feature_views} - sfv = { - fv.spec.name: fv for fv in self.cached_registry_proto.stream_feature_views - } - return {**odfvs, **fvs, **sfv} - - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - registry_proto = self._get_registry_proto( - project=project, allow_cache=allow_cache - ) - return proto_registry_utils.get_permission(registry_proto, name, project) - - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - registry_proto = self._get_registry_proto( - project=project, 
allow_cache=allow_cache - ) - return proto_registry_utils.list_permissions(registry_proto, project, tags) - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - now = _utc_now() - if not permission.created_timestamp: - permission.created_timestamp = now - permission.last_updated_timestamp = now - - registry = self._prepare_registry_for_changes(project) - for idx, existing_permission_proto in enumerate(registry.permissions): - if ( - existing_permission_proto.spec.name == permission.name - and existing_permission_proto.spec.project == project - ): - permission.created_timestamp = ( - existing_permission_proto.meta.created_timestamp.ToDatetime() - ) - del registry.permissions[idx] - - permission_proto = permission.to_proto() - permission_proto.spec.project = project - self.cached_registry_proto.permissions.append(permission_proto) - if commit: - self.commit() - - def delete_permission(self, name: str, project: str, commit: bool = True): - self._prepare_registry_for_changes(project) - assert self.cached_registry_proto - - for idx, permission_proto in enumerate(self.cached_registry_proto.permissions): - if ( - permission_proto.spec.name == name - and permission_proto.spec.project == project - ): - del self.cached_registry_proto.permissions[idx] - if commit: - self.commit() - return - raise PermissionNotFoundException(name, project) - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - registry = self.cached_registry_proto - - for idx, existing_project_proto in enumerate(registry.projects): - if existing_project_proto.spec.name == project.name: - project.created_timestamp = ( - existing_project_proto.meta.created_timestamp.ToDatetime().replace( - tzinfo=timezone.utc - ) - ) - del registry.projects[idx] - - project_proto = project.to_proto() - self.cached_registry_proto.projects.append(project_proto) - if commit: - self.commit() - - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - registry_proto = self._get_registry_proto(project=name, allow_cache=allow_cache) - return proto_registry_utils.get_project(registry_proto, name) - - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - registry_proto = self._get_registry_proto(project=None, allow_cache=allow_cache) - return proto_registry_utils.list_projects( - registry_proto=registry_proto, tags=tags - ) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - assert self.cached_registry_proto - - for idx, project_proto in enumerate(self.cached_registry_proto.projects): - if project_proto.spec.name == name: - list_validation_references = self.list_validation_references(name) - for validation_reference in list_validation_references: - self.delete_validation_reference(validation_reference.name, name) - - list_saved_datasets = self.list_saved_datasets(name) - for saved_dataset in list_saved_datasets: - self.delete_saved_dataset(saved_dataset.name, name) - - list_feature_services = self.list_feature_services(name) - for feature_service in list_feature_services: - self.delete_feature_service(feature_service.name, name) - - list_on_demand_feature_views = self.list_on_demand_feature_views(name) - for on_demand_feature_view in list_on_demand_feature_views: - self.delete_feature_view(on_demand_feature_view.name, name) - - list_stream_feature_views = self.list_stream_feature_views(name) - for stream_feature_view in list_stream_feature_views: - 
self.delete_feature_view(stream_feature_view.name, name) - - list_feature_views = self.list_feature_views(name) - for feature_view in list_feature_views: - self.delete_feature_view(feature_view.name, name) - - list_data_sources = self.list_data_sources(name) - for data_source in list_data_sources: - self.delete_data_source(data_source.name, name) - - list_entities = self.list_entities(name) - for entity in list_entities: - self.delete_entity(entity.name, name) - list_permissions = self.list_permissions(name) - for permission in list_permissions: - self.delete_permission(permission.name, name) - del self.cached_registry_proto.projects[idx] - if commit: - self.commit() - return - raise ProjectNotFoundException(name) +# Copyright 2019 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from datetime import datetime, timedelta, timezone +from enum import Enum +from pathlib import Path +from threading import Lock +from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer +from google.protobuf.message import Message + +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.errors import ( + ConflictingFeatureViewNames, + DataSourceNotFoundException, + EntityNotFoundException, + FeatureServiceNotFoundException, + FeatureViewNotFoundException, + PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, + ValidationReferenceNotFound, +) +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.importer import import_class +from feast.infra.infra_object import Infra +from feast.infra.registry import proto_registry_utils +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.registry.registry_store import NoopRegistryStore +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.auth_model import AuthConfig, NoAuthConfig +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.repo_config import RegistryConfig +from feast.repo_contents import RepoContents +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.utils import _utc_now + +REGISTRY_SCHEMA_VERSION = "1" + +REGISTRY_STORE_CLASS_FOR_TYPE = { + "GCSRegistryStore": "feast.infra.registry.gcs.GCSRegistryStore", + "S3RegistryStore": "feast.infra.registry.s3.S3RegistryStore", + "FileRegistryStore": "feast.infra.registry.file.FileRegistryStore", + "AzureRegistryStore": "feast.infra.registry.contrib.azure.azure_registry_store.AzBlobRegistryStore", +} + +REGISTRY_STORE_CLASS_FOR_SCHEME = { + "gs": "GCSRegistryStore", + "s3": "S3RegistryStore", + "file": 
"FileRegistryStore", + "": "FileRegistryStore", +} + + +class FeastObjectType(Enum): + PROJECT = "project" + DATA_SOURCE = "data source" + ENTITY = "entity" + FEATURE_VIEW = "feature view" + ON_DEMAND_FEATURE_VIEW = "on demand feature view" + STREAM_FEATURE_VIEW = "stream feature view" + FEATURE_SERVICE = "feature service" + PERMISSION = "permission" + + @staticmethod + def get_objects_from_registry( + registry: "BaseRegistry", project: str + ) -> Dict["FeastObjectType", List[Any]]: + return { + FeastObjectType.PROJECT: [ + project_obj + for project_obj in registry.list_projects() + if project_obj.name == project + ], + FeastObjectType.DATA_SOURCE: registry.list_data_sources(project=project), + FeastObjectType.ENTITY: registry.list_entities(project=project), + FeastObjectType.FEATURE_VIEW: registry.list_feature_views(project=project), + FeastObjectType.ON_DEMAND_FEATURE_VIEW: registry.list_on_demand_feature_views( + project=project + ), + FeastObjectType.STREAM_FEATURE_VIEW: registry.list_stream_feature_views( + project=project, + ), + FeastObjectType.FEATURE_SERVICE: registry.list_feature_services( + project=project + ), + FeastObjectType.PERMISSION: registry.list_permissions(project=project), + } + + @staticmethod + def get_objects_from_repo_contents( + repo_contents: RepoContents, + ) -> Dict["FeastObjectType", List[Any]]: + return { + FeastObjectType.PROJECT: repo_contents.projects, + FeastObjectType.DATA_SOURCE: repo_contents.data_sources, + FeastObjectType.ENTITY: repo_contents.entities, + FeastObjectType.FEATURE_VIEW: repo_contents.feature_views, + FeastObjectType.ON_DEMAND_FEATURE_VIEW: repo_contents.on_demand_feature_views, + FeastObjectType.STREAM_FEATURE_VIEW: repo_contents.stream_feature_views, + FeastObjectType.FEATURE_SERVICE: repo_contents.feature_services, + FeastObjectType.PERMISSION: repo_contents.permissions, + } + + +FEAST_OBJECT_TYPES = [feast_object_type for feast_object_type in FeastObjectType] + +logger = logging.getLogger(__name__) + + +def get_registry_store_class_from_type(registry_store_type: str): + if not registry_store_type.endswith("RegistryStore"): + raise Exception('Registry store class name should end with "RegistryStore"') + if registry_store_type in REGISTRY_STORE_CLASS_FOR_TYPE: + registry_store_type = REGISTRY_STORE_CLASS_FOR_TYPE[registry_store_type] + module_name, registry_store_class_name = registry_store_type.rsplit(".", 1) + + return import_class(module_name, registry_store_class_name, "RegistryStore") + + +def get_registry_store_class_from_scheme(registry_path: str): + uri = urlparse(registry_path) + if uri.scheme not in REGISTRY_STORE_CLASS_FOR_SCHEME: + raise Exception( + f"Registry path {registry_path} has unsupported scheme {uri.scheme}. " + f"Supported schemes are file, s3 and gs." + ) + else: + registry_store_type = REGISTRY_STORE_CLASS_FOR_SCHEME[uri.scheme] + return get_registry_store_class_from_type(registry_store_type) + + +class Registry(BaseRegistry): + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + pass + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + pass + + # The cached_registry_proto object is used for both reads and writes. In particular, + # all write operations refresh the cache and modify it in memory; the write must + # then be persisted to the underlying RegistryStore with a call to commit(). 
+ cached_registry_proto: RegistryProto
+ cached_registry_proto_created: datetime
+ cached_registry_proto_ttl: timedelta
+
+ def __init__(
+ self,
+ project: str,
+ registry_config: Optional[RegistryConfig],
+ repo_path: Optional[Path],
+ auth_config: AuthConfig = NoAuthConfig(),
+ ):
+ """
+ Create the Registry object.
+
+ Args:
+ project: Feast project name.
+ registry_config: RegistryConfig object containing the destination path and cache ttl.
+ repo_path: Path to the base of the Feast repository
+ or where it will be created if it does not exist yet.
+ auth_config: Authorization configuration used when accessing the registry; defaults to NoAuthConfig().
+ """
+
+ self._refresh_lock = Lock()
+ self._auth_config = auth_config
+
+ registry_proto = RegistryProto()
+ registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION
+ self.cached_registry_proto = registry_proto
+ self.cached_registry_proto_created = _utc_now()
+
+ self.purge_feast_metadata = (
+ registry_config.purge_feast_metadata
+ if registry_config is not None
+ else False
+ )
+
+ if registry_config:
+ registry_store_type = registry_config.registry_store_type
+ registry_path = registry_config.path
+ if registry_store_type is None:
+ cls = get_registry_store_class_from_scheme(registry_path)
+ else:
+ cls = get_registry_store_class_from_type(str(registry_store_type))
+
+ self._registry_store = cls(registry_config, repo_path)
+ self.cached_registry_proto_ttl = timedelta(
+ seconds=(
+ registry_config.cache_ttl_seconds
+ if registry_config.cache_ttl_seconds is not None
+ else 0
+ )
+ )
+
+ try:
+ registry_proto = self._registry_store.get_registry_proto()
+ self.cached_registry_proto = registry_proto
+ self.cached_registry_proto_created = _utc_now()
+ # Sync feast_metadata to the projects table. When purge_feast_metadata is
+ # set to True, data is deleted from the feast_metadata table and
+ # list_project_metadata will not return any data.
+ self._sync_feast_metadata_to_projects_table()
+ except FileNotFoundError:
+ logger.info("Registry file not found. Creating new registry.")
+ self.commit()
+
+ def _sync_feast_metadata_to_projects_table(self):
+ """
+ Sync feast_metadata to the projects table.
+ """
+ feast_metadata_projects = []
+ projects_set = []
+ # Collect project names from project_metadata
+ for project_metadata in self.cached_registry_proto.project_metadata:
+ project = ProjectMetadata.from_proto(project_metadata)
+ feast_metadata_projects.append(project.project_name)
+ if len(feast_metadata_projects) > 0:
+ # Collect project names from projects
+ for project_metadata in self.cached_registry_proto.projects:
+ project = Project.from_proto(project_metadata)
+ projects_set.append(project.name)
+
+ # Find projects present in feast_metadata_projects but missing from projects
+ projects_to_sync = set(feast_metadata_projects) - set(projects_set)
+ # Sync those feast_metadata entries into the projects table
+ for project_name in projects_to_sync:
+ project = Project(name=project_name)
+ self.cached_registry_proto.projects.append(project.to_proto())
+
+ if self.purge_feast_metadata:
+ self.cached_registry_proto.project_metadata = []
+
+ def clone(self) -> "Registry":
+ new_registry = Registry("project", None, None, self._auth_config)
+ new_registry.cached_registry_proto_ttl = timedelta(seconds=0)
+ new_registry.cached_registry_proto = (
+ self.cached_registry_proto.__deepcopy__()
+ if self.cached_registry_proto
+ else RegistryProto()
+ )
+ new_registry.cached_registry_proto_created = _utc_now()
+ new_registry._registry_store = NoopRegistryStore()
+ return new_registry
+
+ def update_infra(self, infra: Infra, project: str, commit: bool = True):
+ self._prepare_registry_for_changes(project)
+ assert self.cached_registry_proto
+
+ self.cached_registry_proto.infra.CopyFrom(infra.to_proto())
+ if commit:
+ self.commit()
+
+ def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
+ registry_proto = self._get_registry_proto(
+ project=project, allow_cache=allow_cache
+ )
+ return Infra.from_proto(registry_proto.infra)
+
+ def apply_entity(self, entity: Entity, project: str, commit: bool = True):
+ entity.is_valid()
+
+ now = _utc_now()
+ if not entity.created_timestamp:
+ entity.created_timestamp = now
+ entity.last_updated_timestamp = now
+
+ entity_proto = entity.to_proto()
+ entity_proto.spec.project = project
+ self._prepare_registry_for_changes(project)
+ assert self.cached_registry_proto
+
+ for idx, existing_entity_proto in enumerate(
+ self.cached_registry_proto.entities
+ ):
+ if (
+ existing_entity_proto.spec.name == entity_proto.spec.name
+ and existing_entity_proto.spec.project == project
+ ):
+ entity.created_timestamp = (
+ existing_entity_proto.meta.created_timestamp.ToDatetime()
+ )
+ entity_proto = entity.to_proto()
+ entity_proto.spec.project = project
+ del self.cached_registry_proto.entities[idx]
+ break
+ self.cached_registry_proto.entities.append(entity_proto)
+ if commit:
+ self.commit()
+
+ def list_entities(
+ self,
+ project: str,
+ allow_cache: bool = False,
+ tags: Optional[dict[str, str]] = None,
+ ) -> List[Entity]:
+ registry_proto = self._get_registry_proto(
+ project=project, allow_cache=allow_cache
+ )
+ return proto_registry_utils.list_entities(registry_proto, project, tags)
+
+ def list_data_sources(
+ self,
+ project: str,
+ allow_cache: bool = False,
+ tags: Optional[dict[str, str]] = None,
+ ) -> List[DataSource]:
+ registry_proto = self._get_registry_proto(
+ project=project, allow_cache=allow_cache
+ )
+ return proto_registry_utils.list_data_sources(registry_proto, project, tags)
+
+ def apply_data_source(
+ self, data_source: DataSource,
project: str, commit: bool = True + ): + registry = self._prepare_registry_for_changes(project) + for idx, existing_data_source_proto in enumerate(registry.data_sources): + if existing_data_source_proto.name == data_source.name: + del registry.data_sources[idx] + data_source_proto = data_source.to_proto() + data_source_proto.project = project + data_source_proto.data_source_class_type = ( + f"{data_source.__class__.__module__}.{data_source.__class__.__name__}" + ) + self.cached_registry_proto.data_sources.append(data_source_proto) + if commit: + self.commit() + + def delete_data_source(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, data_source_proto in enumerate( + self.cached_registry_proto.data_sources + ): + if data_source_proto.name == name: + del self.cached_registry_proto.data_sources[idx] + if commit: + self.commit() + return + raise DataSourceNotFoundException(name) + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + now = _utc_now() + if not feature_service.created_timestamp: + feature_service.created_timestamp = now + feature_service.last_updated_timestamp = now + + feature_service_proto = feature_service.to_proto() + feature_service_proto.spec.project = project + + registry = self._prepare_registry_for_changes(project) + + for idx, existing_feature_service_proto in enumerate(registry.feature_services): + if ( + existing_feature_service_proto.spec.name + == feature_service_proto.spec.name + and existing_feature_service_proto.spec.project == project + ): + feature_service.created_timestamp = ( + existing_feature_service_proto.meta.created_timestamp.ToDatetime() + ) + feature_service_proto = feature_service.to_proto() + feature_service_proto.spec.project = project + del registry.feature_services[idx] + self.cached_registry_proto.feature_services.append(feature_service_proto) + if commit: + self.commit() + + def list_feature_services( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureService]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_feature_services(registry_proto, project, tags) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_feature_service(registry_proto, name, project) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_entity(registry_proto, name, project) + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + feature_view.ensure_valid() + + now = _utc_now() + if not feature_view.created_timestamp: + feature_view.created_timestamp = now + feature_view.last_updated_timestamp = now + + feature_view_proto = feature_view.to_proto() + feature_view_proto.spec.project = project + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + self._check_conflicting_feature_view_names(feature_view) + existing_feature_views_of_same_type: RepeatedCompositeFieldContainer + if isinstance(feature_view, StreamFeatureView): + existing_feature_views_of_same_type = ( + 
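# StreamFeatureView is matched before FeatureView because StreamFeatureView subclasses FeatureView. +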
self.cached_registry_proto.stream_feature_views + ) + elif isinstance(feature_view, FeatureView): + existing_feature_views_of_same_type = ( + self.cached_registry_proto.feature_views + ) + elif isinstance(feature_view, OnDemandFeatureView): + existing_feature_views_of_same_type = ( + self.cached_registry_proto.on_demand_feature_views + ) + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + + for idx, existing_feature_view_proto in enumerate( + existing_feature_views_of_same_type + ): + if ( + existing_feature_view_proto.spec.name == feature_view_proto.spec.name + and existing_feature_view_proto.spec.project == project + ): + if ( + feature_view.__class__.from_proto(existing_feature_view_proto) + == feature_view + ): + return + else: + existing_feature_view = type(feature_view).from_proto( + existing_feature_view_proto + ) + feature_view.created_timestamp = ( + existing_feature_view.created_timestamp + ) + if isinstance(feature_view, (FeatureView, StreamFeatureView)): + feature_view.update_materialization_intervals( + existing_feature_view.materialization_intervals + ) + feature_view_proto = feature_view.to_proto() + feature_view_proto.spec.project = project + del existing_feature_views_of_same_type[idx] + break + + existing_feature_views_of_same_type.append(feature_view_proto) + if commit: + self.commit() + + def list_stream_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[StreamFeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_stream_feature_views( + registry_proto, project, tags + ) + + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_on_demand_feature_views( + registry_proto, project, tags + ) + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_on_demand_feature_view( + registry_proto, name, project + ) + + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_data_source(registry_proto, name, project) + + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_feature_view_proto in enumerate( + self.cached_registry_proto.feature_views + ): + if ( + existing_feature_view_proto.spec.name == feature_view.name + and existing_feature_view_proto.spec.project == project + ): + existing_feature_view = FeatureView.from_proto( + existing_feature_view_proto + ) + existing_feature_view.materialization_intervals.append( + (start_date, end_date) + ) + existing_feature_view.last_updated_timestamp = _utc_now() + feature_view_proto = existing_feature_view.to_proto() + feature_view_proto.spec.project = project + del self.cached_registry_proto.feature_views[idx] + 
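# Re-append the updated proto so the cached registry records the new materialization interval. +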
self.cached_registry_proto.feature_views.append(feature_view_proto) + if commit: + self.commit() + return + + for idx, existing_stream_feature_view_proto in enumerate( + self.cached_registry_proto.stream_feature_views + ): + if ( + existing_stream_feature_view_proto.spec.name == feature_view.name + and existing_stream_feature_view_proto.spec.project == project + ): + existing_stream_feature_view = StreamFeatureView.from_proto( + existing_stream_feature_view_proto + ) + existing_stream_feature_view.materialization_intervals.append( + (start_date, end_date) + ) + existing_stream_feature_view.last_updated_timestamp = _utc_now() + stream_feature_view_proto = existing_stream_feature_view.to_proto() + stream_feature_view_proto.spec.project = project + del self.cached_registry_proto.stream_feature_views[idx] + self.cached_registry_proto.stream_feature_views.append( + stream_feature_view_proto + ) + if commit: + self.commit() + return + + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_all_feature_views( + registry_proto, project, tags + ) + + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_any_feature_view(registry_proto, name, project) + + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_feature_views(registry_proto, project, tags) + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_feature_view(registry_proto, name, project) + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> StreamFeatureView: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_stream_feature_view( + registry_proto, name, project + ) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, feature_service_proto in enumerate( + self.cached_registry_proto.feature_services + ): + if ( + feature_service_proto.spec.name == name + and feature_service_proto.spec.project == project + ): + del self.cached_registry_proto.feature_services[idx] + if commit: + self.commit() + return + raise FeatureServiceNotFoundException(name, project) + + def delete_feature_view(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_feature_view_proto in enumerate( + self.cached_registry_proto.feature_views + ): + if ( + existing_feature_view_proto.spec.name == name + and existing_feature_view_proto.spec.project == project + ): + del self.cached_registry_proto.feature_views[idx] + if commit: + self.commit() + return + + for idx, existing_on_demand_feature_view_proto in enumerate( + 
self.cached_registry_proto.on_demand_feature_views + ): + if ( + existing_on_demand_feature_view_proto.spec.name == name + and existing_on_demand_feature_view_proto.spec.project == project + ): + del self.cached_registry_proto.on_demand_feature_views[idx] + if commit: + self.commit() + return + + for idx, existing_stream_feature_view_proto in enumerate( + self.cached_registry_proto.stream_feature_views + ): + if ( + existing_stream_feature_view_proto.spec.name == name + and existing_stream_feature_view_proto.spec.project == project + ): + del self.cached_registry_proto.stream_feature_views[idx] + if commit: + self.commit() + return + + raise FeatureViewNotFoundException(name, project) + + def delete_entity(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_entity_proto in enumerate( + self.cached_registry_proto.entities + ): + if ( + existing_entity_proto.spec.name == name + and existing_entity_proto.spec.project == project + ): + del self.cached_registry_proto.entities[idx] + if commit: + self.commit() + return + + raise EntityNotFoundException(name, project) + + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + now = _utc_now() + if not saved_dataset.created_timestamp: + saved_dataset.created_timestamp = now + saved_dataset.last_updated_timestamp = now + + saved_dataset_proto = saved_dataset.to_proto() + saved_dataset_proto.spec.project = project + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, existing_saved_dataset_proto in enumerate( + self.cached_registry_proto.saved_datasets + ): + if ( + existing_saved_dataset_proto.spec.name == saved_dataset_proto.spec.name + and existing_saved_dataset_proto.spec.project == project + ): + saved_dataset.created_timestamp = ( + existing_saved_dataset_proto.meta.created_timestamp.ToDatetime() + ) + saved_dataset.min_event_timestamp = ( + existing_saved_dataset_proto.meta.min_event_timestamp.ToDatetime() + ) + saved_dataset.max_event_timestamp = ( + existing_saved_dataset_proto.meta.max_event_timestamp.ToDatetime() + ) + saved_dataset_proto = saved_dataset.to_proto() + saved_dataset_proto.spec.project = project + del self.cached_registry_proto.saved_datasets[idx] + break + + self.cached_registry_proto.saved_datasets.append(saved_dataset_proto) + if commit: + self.commit() + + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_saved_dataset(registry_proto, name, project) + + def list_saved_datasets( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[SavedDataset]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_saved_datasets(registry_proto, project, tags) + + def apply_validation_reference( + self, + validation_reference: ValidationReference, + project: str, + commit: bool = True, + ): + validation_reference_proto = validation_reference.to_proto() + validation_reference_proto.project = project + + registry_proto = self._prepare_registry_for_changes(project) + for idx, existing_validation_reference in enumerate( + registry_proto.validation_references + ): + if ( + existing_validation_reference.name == validation_reference_proto.name + and 
existing_validation_reference.project == project + ): + del registry_proto.validation_references[idx] + break + + registry_proto.validation_references.append(validation_reference_proto) + if commit: + self.commit() + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_validation_reference( + registry_proto, name, project + ) + + def list_validation_references( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[ValidationReference]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_validation_references( + registry_proto, project, tags + ) + + def delete_validation_reference(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + for idx, existing_validation_reference in enumerate( + self.cached_registry_proto.validation_references + ): + if ( + existing_validation_reference.name == name + and existing_validation_reference.project == project + ): + del self.cached_registry_proto.validation_references[idx] + if commit: + self.commit() + return + raise ValidationReferenceNotFound(name, project=project) + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_project_metadata(registry_proto, project) + + def commit(self): + """Commits the state of the registry cache to the remote registry store.""" + if self.cached_registry_proto: + self._registry_store.update_registry_proto(self.cached_registry_proto) + + def refresh(self, project: Optional[str] = None): + """Refreshes the state of the registry cache by fetching the registry state from the remote registry store.""" + self._get_registry_proto(project=project, allow_cache=False) + + def teardown(self): + """Tears down (removes) the registry.""" + self._registry_store.teardown() + + def proto(self) -> RegistryProto: + return self.cached_registry_proto or RegistryProto() + + def _prepare_registry_for_changes(self, project_name: str): + """Prepares the Registry for changes by refreshing the cache if necessary.""" + + assert self.cached_registry_proto is not None + + try: + # Check if the project exists in the registry cache + self.get_project(name=project_name, allow_cache=True) + return self.cached_registry_proto + except ProjectObjectNotFoundException: + # If the project does not exist in cache, refresh cache from store + registry_proto = self._registry_store.get_registry_proto() + self.cached_registry_proto = registry_proto + self.cached_registry_proto_created = _utc_now() + + try: + # Check if the project exists in the registry cache after refresh from store + self.get_project(name=project_name) + except ProjectObjectNotFoundException: + # If the project still does not exist, create it + project_proto = Project(name=project_name).to_proto() + self.cached_registry_proto.projects.append(project_proto) + if not self.purge_feast_metadata: + project_metadata_proto = ProjectMetadata( + project_name=project_name + ).to_proto() + self.cached_registry_proto.project_metadata.append( + project_metadata_proto + ) + self.commit() + return self.cached_registry_proto + + def 
_get_registry_proto( + self, project: Optional[str], allow_cache: bool = False + ) -> RegistryProto: + """Returns the cached or remote registry state + + Args: + project: Name of the Feast project (optional) + allow_cache: Whether to allow the use of the registry cache when fetching the RegistryProto + + Returns: Returns a RegistryProto object which represents the state of the registry + """ + with self._refresh_lock: + expired = (self.cached_registry_proto_created is None) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + _utc_now() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) + ) + ) + + if allow_cache and not expired: + return self.cached_registry_proto + logger.info("Registry cache expired, so refreshing") + registry_proto = self._registry_store.get_registry_proto() + self.cached_registry_proto = registry_proto + self.cached_registry_proto_created = _utc_now() + return registry_proto + + def _check_conflicting_feature_view_names(self, feature_view: BaseFeatureView): + name_to_fv_protos = self._existing_feature_view_names_to_fvs() + if feature_view.name in name_to_fv_protos: + if not isinstance( + name_to_fv_protos.get(feature_view.name), feature_view.proto_class + ): + raise ConflictingFeatureViewNames(feature_view.name) + + def _existing_feature_view_names_to_fvs(self) -> Dict[str, Message]: + assert self.cached_registry_proto + odfvs = { + fv.spec.name: fv + for fv in self.cached_registry_proto.on_demand_feature_views + } + fvs = {fv.spec.name: fv for fv in self.cached_registry_proto.feature_views} + sfv = { + fv.spec.name: fv for fv in self.cached_registry_proto.stream_feature_views + } + return {**odfvs, **fvs, **sfv} + + def get_permission( + self, name: str, project: str, allow_cache: bool = False + ) -> Permission: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.get_permission(registry_proto, name, project) + + def list_permissions( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Permission]: + registry_proto = self._get_registry_proto( + project=project, allow_cache=allow_cache + ) + return proto_registry_utils.list_permissions(registry_proto, project, tags) + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + now = _utc_now() + if not permission.created_timestamp: + permission.created_timestamp = now + permission.last_updated_timestamp = now + + registry = self._prepare_registry_for_changes(project) + for idx, existing_permission_proto in enumerate(registry.permissions): + if ( + existing_permission_proto.spec.name == permission.name + and existing_permission_proto.spec.project == project + ): + permission.created_timestamp = ( + existing_permission_proto.meta.created_timestamp.ToDatetime() + ) + del registry.permissions[idx] + + permission_proto = permission.to_proto() + permission_proto.spec.project = project + self.cached_registry_proto.permissions.append(permission_proto) + if commit: + self.commit() + + def delete_permission(self, name: str, project: str, commit: bool = True): + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto + + for idx, permission_proto in enumerate(self.cached_registry_proto.permissions): + if ( + permission_proto.spec.name == name + and permission_proto.spec.project == project + ): + del self.cached_registry_proto.permissions[idx] + if commit: + self.commit() + 
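# A condensed sketch of the TTL rule in _get_registry_proto above: the cache is
# stale when it was never populated, or when a positive TTL has elapsed; a TTL of
# zero means the cache never expires. The helper name is illustrative only.
from datetime import datetime, timedelta, timezone
from typing import Optional


def cache_expired(created: Optional[datetime], ttl: timedelta, now: datetime) -> bool:
    if created is None:
        return True  # never populated
    return ttl.total_seconds() > 0 and now > created + ttl


now = datetime.now(timezone.utc)
assert cache_expired(None, timedelta(seconds=60), now)
assert cache_expired(now - timedelta(seconds=61), timedelta(seconds=60), now)
assert not cache_expired(now - timedelta(days=365), timedelta(0), now)  # 0 => infinite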
return + raise PermissionNotFoundException(name, project) + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + registry = self.cached_registry_proto + + for idx, existing_project_proto in enumerate(registry.projects): + if existing_project_proto.spec.name == project.name: + project.created_timestamp = ( + existing_project_proto.meta.created_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ) + ) + del registry.projects[idx] + + project_proto = project.to_proto() + self.cached_registry_proto.projects.append(project_proto) + if commit: + self.commit() + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + registry_proto = self._get_registry_proto(project=name, allow_cache=allow_cache) + return proto_registry_utils.get_project(registry_proto, name) + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + registry_proto = self._get_registry_proto(project=None, allow_cache=allow_cache) + return proto_registry_utils.list_projects( + registry_proto=registry_proto, tags=tags + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + assert self.cached_registry_proto + + for idx, project_proto in enumerate(self.cached_registry_proto.projects): + if project_proto.spec.name == name: + list_validation_references = self.list_validation_references(name) + for validation_reference in list_validation_references: + self.delete_validation_reference(validation_reference.name, name) + + list_saved_datasets = self.list_saved_datasets(name) + for saved_dataset in list_saved_datasets: + self.delete_saved_dataset(saved_dataset.name, name) + + list_feature_services = self.list_feature_services(name) + for feature_service in list_feature_services: + self.delete_feature_service(feature_service.name, name) + + list_on_demand_feature_views = self.list_on_demand_feature_views(name) + for on_demand_feature_view in list_on_demand_feature_views: + self.delete_feature_view(on_demand_feature_view.name, name) + + list_stream_feature_views = self.list_stream_feature_views(name) + for stream_feature_view in list_stream_feature_views: + self.delete_feature_view(stream_feature_view.name, name) + + list_feature_views = self.list_feature_views(name) + for feature_view in list_feature_views: + self.delete_feature_view(feature_view.name, name) + + list_data_sources = self.list_data_sources(name) + for data_source in list_data_sources: + self.delete_data_source(data_source.name, name) + + list_entities = self.list_entities(name) + for entity in list_entities: + self.delete_entity(entity.name, name) + list_permissions = self.list_permissions(name) + for permission in list_permissions: + self.delete_permission(permission.name, name) + del self.cached_registry_proto.projects[idx] + if commit: + self.commit() + return + raise ProjectNotFoundException(name) From 9bb8ffdb38244413571946fa9a9878c0fab18fcc Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 19 Jun 2025 21:44:46 -0400 Subject: [PATCH 7/7] updated files Signed-off-by: Francisco Javier Arceo --- sdk/python/feast/infra/provider.py | 1076 +++---- sdk/python/feast/infra/registry/remote.py | 1188 ++++---- sdk/python/feast/infra/registry/snowflake.py | 2750 +++++++++--------- sdk/python/feast/infra/registry/sql.py | 2516 ++++++++-------- 4 files changed, 3765 insertions(+), 3765 deletions(-) diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 39686895a87..4f7b0d4b5c1 100644 --- 
a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -1,538 +1,538 @@ -from abc import ABC, abstractmethod -from datetime import datetime -from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) - -import pandas as pd -import pyarrow -from tqdm import tqdm - -from feast import FeatureService, errors -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.feature_view import FeatureView -from feast.importer import import_class -from feast.infra.infra_object import Infra -from feast.infra.offline_stores.offline_store import RetrievalJob -from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.supported_async_methods import ProviderAsyncMethods -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.online_response import OnlineResponse -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import RepeatedValue -from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.repo_config import RepoConfig -from feast.saved_dataset import SavedDataset - -PROVIDERS_CLASS_FOR_TYPE = { - "gcp": "feast.infra.passthrough_provider.PassthroughProvider", - "aws": "feast.infra.passthrough_provider.PassthroughProvider", - "local": "feast.infra.passthrough_provider.PassthroughProvider", - "azure": "feast.infra.passthrough_provider.PassthroughProvider", -} - - -class Provider(ABC): - """ - A provider defines an implementation of a feature store object. It orchestrates the various - components of a feature store, such as the offline store, online store, and materialization - engine. It is configured through a RepoConfig object. - """ - - @abstractmethod - def __init__(self, config: RepoConfig): - pass - - @property - def async_supported(self) -> ProviderAsyncMethods: - return ProviderAsyncMethods() - - @abstractmethod - def update_infra( - self, - project: str, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[Union[FeatureView, OnDemandFeatureView]], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, - ): - """ - Reconciles cloud resources with the specified set of Feast objects. - - Args: - project: Feast project to which the objects belong. - tables_to_delete: Feature views whose corresponding infrastructure should be deleted. - tables_to_keep: Feature views whose corresponding infrastructure should not be deleted, and - may need to be updated. - entities_to_delete: Entities whose corresponding infrastructure should be deleted. - entities_to_keep: Entities whose corresponding infrastructure should not be deleted, and - may need to be updated. - partial: If true, tables_to_delete and tables_to_keep are not exhaustive lists, so - infrastructure corresponding to other feature views should be not be touched. - """ - pass - - def plan_infra( - self, config: RepoConfig, desired_registry_proto: RegistryProto - ) -> Infra: - """ - Returns the Infra required to support the desired registry. - - Args: - config: The RepoConfig for the current FeatureStore. - desired_registry_proto: The desired registry, in proto form. 
- """ - return Infra() - - @abstractmethod - def teardown_infra( - self, - project: str, - tables: Sequence[FeatureView], - entities: Sequence[Entity], - ): - """ - Tears down all cloud resources for the specified set of Feast objects. - - Args: - project: Feast project to which the objects belong. - tables: Feature views whose corresponding infrastructure should be deleted. - entities: Entities whose corresponding infrastructure should be deleted. - """ - pass - - @abstractmethod - def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - ) -> None: - """ - Writes a batch of feature rows to the online store. - - If a tz-naive timestamp is passed to this method, it is assumed to be UTC. - - Args: - config: The config for the current feature store. - table: Feature view to which these feature rows correspond. - data: A list of quadruplets containing feature data. Each quadruplet contains an entity - key, a dict containing feature values, an event timestamp for the row, and the created - timestamp for the row if it exists. - progress: Function to be called once a batch of rows is written to the online store, used - to show progress. - """ - pass - - @abstractmethod - async def online_write_batch_async( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - ) -> None: - """ - Writes a batch of feature rows to the online store asynchronously. - - If a tz-naive timestamp is passed to this method, it is assumed to be UTC. - - Args: - config: The config for the current feature store. - table: Feature view to which these feature rows correspond. - data: A list of quadruplets containing feature data. Each quadruplet contains an entity - key, a dict containing feature values, an event timestamp for the row, and the created - timestamp for the row if it exists. - progress: Function to be called once a batch of rows is written to the online store, used - to show progress. - """ - pass - - def ingest_df( - self, - feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], - df: pd.DataFrame, - field_mapping: Optional[Dict] = None, - ): - """ - Persists a dataframe to the online store. - - Args: - feature_view: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. - field_mapping: A dictionary mapping dataframe column names to feature names. - """ - pass - - async def ingest_df_async( - self, - feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], - df: pd.DataFrame, - field_mapping: Optional[Dict] = None, - ): - """ - Persists a dataframe to the online store asynchronously. - - Args: - feature_view: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. - field_mapping: A dictionary mapping dataframe column names to feature names. - """ - pass - - def ingest_df_to_offline_store( - self, - feature_view: FeatureView, - df: pyarrow.Table, - ): - """ - Persists a dataframe to the offline store. - - Args: - feature_view: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. 
- """ - pass - - @abstractmethod - def materialize_single_feature_view( - self, - config: RepoConfig, - feature_view: Union[FeatureView, OnDemandFeatureView], - start_date: datetime, - end_date: datetime, - registry: BaseRegistry, - project: str, - tqdm_builder: Callable[[int], tqdm], - ) -> None: - """ - Writes latest feature values in the specified time range to the online store. - - Args: - config: The config for the current feature store. - feature_view: The feature view to materialize. - start_date: The start of the time range. - end_date: The end of the time range. - registry: The registry for the current feature store. - project: Feast project to which the objects belong. - tqdm_builder: A function to monitor the progress of materialization. - """ - pass - - @abstractmethod - def get_historical_features( - self, - config: RepoConfig, - feature_views: List[Union[FeatureView, OnDemandFeatureView]], - feature_refs: List[str], - entity_df: Union[pd.DataFrame, str], - registry: BaseRegistry, - project: str, - full_feature_names: bool, - ) -> RetrievalJob: - """ - Retrieves the point-in-time correct historical feature values for the specified entity rows. - - Args: - config: The config for the current feature store. - feature_views: A list containing all feature views that are referenced in the entity rows. - feature_refs: The features to be retrieved. - entity_df: A collection of rows containing all entity columns on which features need to be joined, - as well as the timestamp column used for point-in-time joins. Either a pandas dataframe can be - provided or a SQL query. - registry: The registry for the current feature store. - project: Feast project to which the feature views belong. - full_feature_names: If True, feature names will be prefixed with the corresponding feature view name, - changing them from the format "feature" to "feature_view__feature" (e.g. "daily_transactions" - changes to "customer_fv__daily_transactions"). - - Returns: - A RetrievalJob that can be executed to get the features. - """ - pass - - @abstractmethod - def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - """ - Reads features values for the given entity keys. - - Args: - config: The config for the current feature store. - table: The feature view whose feature values should be read. - entity_keys: The list of entity keys for which feature values should be read. - requested_features: The list of features that should be read. - - Returns: - A list of the same length as entity_keys. Each item in the list is a tuple where the first - item is the event timestamp for the row, and the second item is a dict mapping feature names - to values, which are returned in proto format. 
- """ - pass - - @abstractmethod - def get_online_features( - self, - config: RepoConfig, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], - ], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, - ) -> OnlineResponse: - pass - - @abstractmethod - async def get_online_features_async( - self, - config: RepoConfig, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], - ], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, - ) -> OnlineResponse: - pass - - @abstractmethod - async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - """ - Reads features values for the given entity keys asynchronously. - - Args: - config: The config for the current feature store. - table: The feature view whose feature values should be read. - entity_keys: The list of entity keys for which feature values should be read. - requested_features: The list of features that should be read. - - Returns: - A list of the same length as entity_keys. Each item in the list is a tuple where the first - item is the event timestamp for the row, and the second item is a dict mapping feature names - to values, which are returned in proto format. - """ - pass - - @abstractmethod - def retrieve_saved_dataset( - self, config: RepoConfig, dataset: SavedDataset - ) -> RetrievalJob: - """ - Reads a saved dataset. - - Args: - config: The config for the current feature store. - dataset: A SavedDataset object containing all parameters necessary for retrieving the dataset. - - Returns: - A RetrievalJob that can be executed to get the saved dataset. - """ - pass - - @abstractmethod - def write_feature_service_logs( - self, - feature_service: FeatureService, - logs: Union[pyarrow.Table, Path], - config: RepoConfig, - registry: BaseRegistry, - ): - """ - Writes features and entities logged by a feature server to the offline store. - - The schema of the logs table is inferred from the specified feature service. Only feature - services with configured logging are accepted. - - Args: - feature_service: The feature service to be logged. - logs: The logs, either as an arrow table or as a path to a parquet directory. - config: The config for the current feature store. - registry: The registry for the current feature store. - """ - pass - - @abstractmethod - def retrieve_feature_service_logs( - self, - feature_service: FeatureService, - start_date: datetime, - end_date: datetime, - config: RepoConfig, - registry: BaseRegistry, - ) -> RetrievalJob: - """ - Reads logged features for the specified time window. - - Args: - feature_service: The feature service whose logs should be retrieved. - start_date: The start of the window. - end_date: The end of the window. - config: The config for the current feature store. - registry: The registry for the current feature store. - - Returns: - A RetrievalJob that can be executed to get the feature service logs. 
- """ - pass - - def get_feature_server_endpoint(self) -> Optional[str]: - """Returns endpoint for the feature server, if it exists.""" - return None - - @abstractmethod - def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_features: Optional[List[str]], - query: List[float], - top_k: int, - distance_metric: Optional[str] = None, - ) -> List[ - Tuple[ - Optional[datetime], - Optional[EntityKeyProto], - Optional[ValueProto], - Optional[ValueProto], - Optional[ValueProto], - ], - ]: - """ - Searches for the top-k most similar documents in the online document store. - - Args: - distance_metric: distance metric to use for the search. - config: The config for the current feature store. - table: The feature view whose embeddings should be searched. - requested_features: the requested document feature names. - query: The query embedding to search for. - top_k: The number of documents to return. - - Returns: - A list of dictionaries, where each dictionary contains the document feature. - """ - pass - - @abstractmethod - def retrieve_online_documents_v2( - self, - config: RepoConfig, - table: FeatureView, - requested_features: List[str], - query: Optional[List[float]], - top_k: int, - distance_metric: Optional[str] = None, - query_string: Optional[str] = None, - ) -> List[ - Tuple[ - Optional[datetime], - Optional[EntityKeyProto], - Optional[Dict[str, ValueProto]], - ] - ]: - """ - Searches for the top-k most similar documents in the online document store. - - Args: - distance_metric: distance metric to use for the search. - config: The config for the current feature store. - table: The feature view whose embeddings should be searched. - requested_features: the requested document feature names. - query: The query embedding to search for (optional). - top_k: The number of documents to return. - query_string: The query string to search for using keyword search (bm25) (optional) - - Returns: - A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary - of feature key value pairs - """ - pass - - @abstractmethod - def validate_data_source( - self, - config: RepoConfig, - data_source: DataSource, - ): - """ - Validates the underlying data source. - - Args: - config: Configuration object used to configure a feature store. - data_source: DataSource object that needs to be validated - """ - pass - - @abstractmethod - def get_table_column_names_and_types_from_data_source( - self, config: RepoConfig, data_source: DataSource - ) -> Iterable[Tuple[str, str]]: - """ - Returns the list of column names and raw column types for a DataSource. - - Args: - config: Configuration object used to configure a feature store. - data_source: DataSource object - """ - pass - - @abstractmethod - async def initialize(self, config: RepoConfig) -> None: - pass - - @abstractmethod - async def close(self) -> None: - pass - - -def get_provider(config: RepoConfig) -> Provider: - if "." not in config.provider: - if config.provider not in PROVIDERS_CLASS_FOR_TYPE: - raise errors.FeastProviderNotImplementedError(config.provider) - - provider = PROVIDERS_CLASS_FOR_TYPE[config.provider] - else: - provider = config.provider - - # Split provider into module and class names by finding the right-most dot. 
- # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' - module_name, class_name = provider.rsplit(".", 1) - - cls = import_class(module_name, class_name, "Provider") - - return cls(config) +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +import pandas as pd +import pyarrow +from tqdm import tqdm + +from feast import FeatureService, errors +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_view import FeatureView +from feast.importer import import_class +from feast.infra.infra_object import Infra +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.supported_async_methods import ProviderAsyncMethods +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.online_response import OnlineResponse +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import RepeatedValue +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.repo_config import RepoConfig +from feast.saved_dataset import SavedDataset + +PROVIDERS_CLASS_FOR_TYPE = { + "gcp": "feast.infra.passthrough_provider.PassthroughProvider", + "aws": "feast.infra.passthrough_provider.PassthroughProvider", + "local": "feast.infra.passthrough_provider.PassthroughProvider", + "azure": "feast.infra.passthrough_provider.PassthroughProvider", +} + + +class Provider(ABC): + """ + A provider defines an implementation of a feature store object. It orchestrates the various + components of a feature store, such as the offline store, online store, and materialization + engine. It is configured through a RepoConfig object. + """ + + @abstractmethod + def __init__(self, config: RepoConfig): + pass + + @property + def async_supported(self) -> ProviderAsyncMethods: + return ProviderAsyncMethods() + + @abstractmethod + def update_infra( + self, + project: str, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[Union[FeatureView, OnDemandFeatureView]], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, + ): + """ + Reconciles cloud resources with the specified set of Feast objects. + + Args: + project: Feast project to which the objects belong. + tables_to_delete: Feature views whose corresponding infrastructure should be deleted. + tables_to_keep: Feature views whose corresponding infrastructure should not be deleted, and + may need to be updated. + entities_to_delete: Entities whose corresponding infrastructure should be deleted. + entities_to_keep: Entities whose corresponding infrastructure should not be deleted, and + may need to be updated. + partial: If true, tables_to_delete and tables_to_keep are not exhaustive lists, so + infrastructure corresponding to other feature views should not be touched. + """ + pass + + def plan_infra( + self, config: RepoConfig, desired_registry_proto: RegistryProto + ) -> Infra: + """ + Returns the Infra required to support the desired registry. + + Args: + config: The RepoConfig for the current FeatureStore. + desired_registry_proto: The desired registry, in proto form.
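# A minimal sketch of the default plan_infra contract documented above: providers
# that do not manage their own infrastructure return an empty Infra object, so a
# plan produces nothing to reconcile.
from feast.infra.infra_object import Infra

infra = Infra()
print(len(infra.infra_objects))  # 0: no infrastructure objects to create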
+ """ + return Infra() + + @abstractmethod + def teardown_infra( + self, + project: str, + tables: Sequence[FeatureView], + entities: Sequence[Entity], + ): + """ + Tears down all cloud resources for the specified set of Feast objects. + + Args: + project: Feast project to which the objects belong. + tables: Feature views whose corresponding infrastructure should be deleted. + entities: Entities whose corresponding infrastructure should be deleted. + """ + pass + + @abstractmethod + def online_write_batch( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + """ + Writes a batch of feature rows to the online store. + + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. + + Args: + config: The config for the current feature store. + table: Feature view to which these feature rows correspond. + data: A list of quadruplets containing feature data. Each quadruplet contains an entity + key, a dict containing feature values, an event timestamp for the row, and the created + timestamp for the row if it exists. + progress: Function to be called once a batch of rows is written to the online store, used + to show progress. + """ + pass + + @abstractmethod + async def online_write_batch_async( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + """ + Writes a batch of feature rows to the online store asynchronously. + + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. + + Args: + config: The config for the current feature store. + table: Feature view to which these feature rows correspond. + data: A list of quadruplets containing feature data. Each quadruplet contains an entity + key, a dict containing feature values, an event timestamp for the row, and the created + timestamp for the row if it exists. + progress: Function to be called once a batch of rows is written to the online store, used + to show progress. + """ + pass + + def ingest_df( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + """ + Persists a dataframe to the online store. + + Args: + feature_view: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + field_mapping: A dictionary mapping dataframe column names to feature names. + """ + pass + + async def ingest_df_async( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + """ + Persists a dataframe to the online store asynchronously. + + Args: + feature_view: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + field_mapping: A dictionary mapping dataframe column names to feature names. + """ + pass + + def ingest_df_to_offline_store( + self, + feature_view: FeatureView, + df: pyarrow.Table, + ): + """ + Persists a dataframe to the offline store. + + Args: + feature_view: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. 
+ """ + pass + + @abstractmethod + def materialize_single_feature_view( + self, + config: RepoConfig, + feature_view: Union[FeatureView, OnDemandFeatureView], + start_date: datetime, + end_date: datetime, + registry: BaseRegistry, + project: str, + tqdm_builder: Callable[[int], tqdm], + ) -> None: + """ + Writes latest feature values in the specified time range to the online store. + + Args: + config: The config for the current feature store. + feature_view: The feature view to materialize. + start_date: The start of the time range. + end_date: The end of the time range. + registry: The registry for the current feature store. + project: Feast project to which the objects belong. + tqdm_builder: A function to monitor the progress of materialization. + """ + pass + + @abstractmethod + def get_historical_features( + self, + config: RepoConfig, + feature_views: List[Union[FeatureView, OnDemandFeatureView]], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: BaseRegistry, + project: str, + full_feature_names: bool, + ) -> RetrievalJob: + """ + Retrieves the point-in-time correct historical feature values for the specified entity rows. + + Args: + config: The config for the current feature store. + feature_views: A list containing all feature views that are referenced in the entity rows. + feature_refs: The features to be retrieved. + entity_df: A collection of rows containing all entity columns on which features need to be joined, + as well as the timestamp column used for point-in-time joins. Either a pandas dataframe can be + provided or a SQL query. + registry: The registry for the current feature store. + project: Feast project to which the feature views belong. + full_feature_names: If True, feature names will be prefixed with the corresponding feature view name, + changing them from the format "feature" to "feature_view__feature" (e.g. "daily_transactions" + changes to "customer_fv__daily_transactions"). + + Returns: + A RetrievalJob that can be executed to get the features. + """ + pass + + @abstractmethod + def online_read( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + """ + Reads features values for the given entity keys. + + Args: + config: The config for the current feature store. + table: The feature view whose feature values should be read. + entity_keys: The list of entity keys for which feature values should be read. + requested_features: The list of features that should be read. + + Returns: + A list of the same length as entity_keys. Each item in the list is a tuple where the first + item is the event timestamp for the row, and the second item is a dict mapping feature names + to values, which are returned in proto format. 
+ """ + pass + + @abstractmethod + def get_online_features( + self, + config: RepoConfig, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], + ], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> OnlineResponse: + pass + + @abstractmethod + async def get_online_features_async( + self, + config: RepoConfig, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], + ], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, + ) -> OnlineResponse: + pass + + @abstractmethod + async def online_read_async( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + """ + Reads features values for the given entity keys asynchronously. + + Args: + config: The config for the current feature store. + table: The feature view whose feature values should be read. + entity_keys: The list of entity keys for which feature values should be read. + requested_features: The list of features that should be read. + + Returns: + A list of the same length as entity_keys. Each item in the list is a tuple where the first + item is the event timestamp for the row, and the second item is a dict mapping feature names + to values, which are returned in proto format. + """ + pass + + @abstractmethod + def retrieve_saved_dataset( + self, config: RepoConfig, dataset: SavedDataset + ) -> RetrievalJob: + """ + Reads a saved dataset. + + Args: + config: The config for the current feature store. + dataset: A SavedDataset object containing all parameters necessary for retrieving the dataset. + + Returns: + A RetrievalJob that can be executed to get the saved dataset. + """ + pass + + @abstractmethod + def write_feature_service_logs( + self, + feature_service: FeatureService, + logs: Union[pyarrow.Table, Path], + config: RepoConfig, + registry: BaseRegistry, + ): + """ + Writes features and entities logged by a feature server to the offline store. + + The schema of the logs table is inferred from the specified feature service. Only feature + services with configured logging are accepted. + + Args: + feature_service: The feature service to be logged. + logs: The logs, either as an arrow table or as a path to a parquet directory. + config: The config for the current feature store. + registry: The registry for the current feature store. + """ + pass + + @abstractmethod + def retrieve_feature_service_logs( + self, + feature_service: FeatureService, + start_date: datetime, + end_date: datetime, + config: RepoConfig, + registry: BaseRegistry, + ) -> RetrievalJob: + """ + Reads logged features for the specified time window. + + Args: + feature_service: The feature service whose logs should be retrieved. + start_date: The start of the window. + end_date: The end of the window. + config: The config for the current feature store. + registry: The registry for the current feature store. + + Returns: + A RetrievalJob that can be executed to get the feature service logs. 
+ """ + pass + + def get_feature_server_endpoint(self) -> Optional[str]: + """Returns endpoint for the feature server, if it exists.""" + return None + + @abstractmethod + def retrieve_online_documents( + self, + config: RepoConfig, + table: FeatureView, + requested_features: Optional[List[str]], + query: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ], + ]: + """ + Searches for the top-k most similar documents in the online document store. + + Args: + distance_metric: distance metric to use for the search. + config: The config for the current feature store. + table: The feature view whose embeddings should be searched. + requested_features: the requested document feature names. + query: The query embedding to search for. + top_k: The number of documents to return. + + Returns: + A list of dictionaries, where each dictionary contains the document feature. + """ + pass + + @abstractmethod + def retrieve_online_documents_v2( + self, + config: RepoConfig, + table: FeatureView, + requested_features: List[str], + query: Optional[List[float]], + top_k: int, + distance_metric: Optional[str] = None, + query_string: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[Dict[str, ValueProto]], + ] + ]: + """ + Searches for the top-k most similar documents in the online document store. + + Args: + distance_metric: distance metric to use for the search. + config: The config for the current feature store. + table: The feature view whose embeddings should be searched. + requested_features: the requested document feature names. + query: The query embedding to search for (optional). + top_k: The number of documents to return. + query_string: The query string to search for using keyword search (bm25) (optional) + + Returns: + A list of dictionaries, where each dictionary contains the datetime, entitykey, and a dictionary + of feature key value pairs + """ + pass + + @abstractmethod + def validate_data_source( + self, + config: RepoConfig, + data_source: DataSource, + ): + """ + Validates the underlying data source. + + Args: + config: Configuration object used to configure a feature store. + data_source: DataSource object that needs to be validated + """ + pass + + @abstractmethod + def get_table_column_names_and_types_from_data_source( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + """ + Returns the list of column names and raw column types for a DataSource. + + Args: + config: Configuration object used to configure a feature store. + data_source: DataSource object + """ + pass + + @abstractmethod + async def initialize(self, config: RepoConfig) -> None: + pass + + @abstractmethod + async def close(self) -> None: + pass + + +def get_provider(config: RepoConfig) -> Provider: + if "." not in config.provider: + if config.provider not in PROVIDERS_CLASS_FOR_TYPE: + raise errors.FeastProviderNotImplementedError(config.provider) + + provider = PROVIDERS_CLASS_FOR_TYPE[config.provider] + else: + provider = config.provider + + # Split provider into module and class names by finding the right-most dot. 
+ # For example, provider 'foo.bar.MyProvider' will be parsed into 'foo.bar' and 'MyProvider' + module_name, class_name = provider.rsplit(".", 1) + + cls = import_class(module_name, class_name, "Provider") + + return cls(config) diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index 78f901ca202..4122586046f 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -1,594 +1,594 @@ -import os -from datetime import datetime -from pathlib import Path -from typing import List, Optional, Union - -import grpc -from google.protobuf.empty_pb2 import Empty -from google.protobuf.timestamp_pb2 import Timestamp -from pydantic import StrictStr - -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object import Infra -from feast.infra.registry.base_registry import BaseRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.auth_model import AuthConfig, NoAuthConfig -from feast.permissions.client.grpc_client_auth_interceptor import ( - GrpcClientAuthHeaderInterceptor, -) -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc -from feast.repo_config import RegistryConfig -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView - - -def extract_base_feature_view( - any_feature_view: RegistryServer_pb2.AnyFeatureView, -) -> BaseFeatureView: - feature_view_type = any_feature_view.WhichOneof("any_feature_view") - if feature_view_type == "feature_view": - feature_view = FeatureView.from_proto(any_feature_view.feature_view) - elif feature_view_type == "on_demand_feature_view": - feature_view = OnDemandFeatureView.from_proto( - any_feature_view.on_demand_feature_view - ) - elif feature_view_type == "stream_feature_view": - feature_view = StreamFeatureView.from_proto( - any_feature_view.stream_feature_view - ) - - return feature_view - - -class RemoteRegistryConfig(RegistryConfig): - registry_type: StrictStr = "remote" - """ str: Provider name or a class name that implements Registry.""" - - path: StrictStr = "" - """ str: Path to metadata store. - If registry_type is 'remote', then this is a URL for registry server """ - - cert: StrictStr = "" - """ str: Path to the public certificate when the registry server starts in TLS(SSL) mode. This may be needed if the registry server started with a self-signed certificate, typically this file ends with `*.crt`, `*.cer`, or `*.pem`. - If registry_type is 'remote', then this configuration is needed to connect to remote registry server in TLS mode. If the remote registry started in non-tls mode then this configuration is not needed.""" - - is_tls: bool = False - """ bool: Set to `True` if you plan to connect to a registry server running in TLS (SSL) mode. - If you intend to add the public certificate to the trust store instead of passing it via the `cert` parameter, this field must be set to `True`. 
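# A hypothetical TLS connection using the config fields described above. The host
# and certificate path are examples; with a CA-trusted certificate, is_tls=True
# alone is enough and the CA bundle is picked up from SSL_CERT_FILE or
# REQUESTS_CA_BUNDLE.
from feast.infra.registry.remote import RemoteRegistry, RemoteRegistryConfig

config = RemoteRegistryConfig(
    registry_type="remote",
    path="registry.example.com:443",
    cert="/etc/feast/registry.crt",  # the server's public certificate (PEM)
    is_tls=True,
)
registry = RemoteRegistry(config, project="my_project", repo_path=None)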
- """ - - -class RemoteRegistry(BaseRegistry): - def __init__( - self, - registry_config: Union[RegistryConfig, RemoteRegistryConfig], - project: str, - repo_path: Optional[Path], - auth_config: AuthConfig = NoAuthConfig(), - ): - self.auth_config = auth_config - assert isinstance(registry_config, RemoteRegistryConfig) - self.channel = self._create_grpc_channel(registry_config) - - auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) - self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor) - self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel) - - def _create_grpc_channel(self, registry_config): - assert isinstance(registry_config, RemoteRegistryConfig) - if registry_config.cert or registry_config.is_tls: - cafile = os.getenv("SSL_CERT_FILE") or os.getenv("REQUESTS_CA_BUNDLE") - if not cafile and not registry_config.cert: - raise EnvironmentError( - "SSL_CERT_FILE or REQUESTS_CA_BUNDLE environment variable must be set to use secure TLS, or set the cert parameter in the feature_store.yaml file under the remote registry configuration." - ) - with open( - registry_config.cert if registry_config.cert else cafile, "rb" - ) as cert_file: - trusted_certs = cert_file.read() - tls_credentials = grpc.ssl_channel_credentials( - root_certificates=trusted_certs - ) - return grpc.secure_channel(registry_config.path, tls_credentials) - else: - # Create an insecure gRPC channel - return grpc.insecure_channel(registry_config.path) - - def close(self): - if self.channel: - self.channel.close() - - def __del__(self): - self.close() - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - request = RegistryServer_pb2.ApplyEntityRequest( - entity=entity.to_proto(), project=project, commit=commit - ) - self.stub.ApplyEntity(request) - - def delete_entity(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteEntityRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteEntity(request) - - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - request = RegistryServer_pb2.GetEntityRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetEntity(request) - return Entity.from_proto(response) - - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - request = RegistryServer_pb2.ListEntitiesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListEntities(request) - return [Entity.from_proto(entity) for entity in response.entities] - - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - request = RegistryServer_pb2.ApplyDataSourceRequest( - data_source=data_source.to_proto(), project=project, commit=commit - ) - self.stub.ApplyDataSource(request) - - def delete_data_source(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteDataSourceRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteDataSource(request) - - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - request = RegistryServer_pb2.GetDataSourceRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = 
self.stub.GetDataSource(request) - return DataSource.from_proto(response) - - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - request = RegistryServer_pb2.ListDataSourcesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListDataSources(request) - return [ - DataSource.from_proto(data_source) for data_source in response.data_sources - ] - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - request = RegistryServer_pb2.ApplyFeatureServiceRequest( - feature_service=feature_service.to_proto(), project=project, commit=commit - ) - self.stub.ApplyFeatureService(request) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteFeatureServiceRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteFeatureService(request) - - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - request = RegistryServer_pb2.GetFeatureServiceRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetFeatureService(request) - return FeatureService.from_proto(response) - - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - request = RegistryServer_pb2.ListFeatureServicesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListFeatureServices(request) - return [ - FeatureService.from_proto(feature_service) - for feature_service in response.feature_services - ] - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - if isinstance(feature_view, StreamFeatureView): - arg_name = "stream_feature_view" - elif isinstance(feature_view, FeatureView): - arg_name = "feature_view" - elif isinstance(feature_view, OnDemandFeatureView): - arg_name = "on_demand_feature_view" - - request = RegistryServer_pb2.ApplyFeatureViewRequest( - feature_view=( - feature_view.to_proto() if arg_name == "feature_view" else None - ), - stream_feature_view=( - feature_view.to_proto() if arg_name == "stream_feature_view" else None - ), - on_demand_feature_view=( - feature_view.to_proto() - if arg_name == "on_demand_feature_view" - else None - ), - project=project, - commit=commit, - ) - - self.stub.ApplyFeatureView(request) - - def delete_feature_view(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteFeatureViewRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteFeatureView(request) - - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> StreamFeatureView: - request = RegistryServer_pb2.GetStreamFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetStreamFeatureView(request) - return StreamFeatureView.from_proto(response) - - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - request = RegistryServer_pb2.ListStreamFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListStreamFeatureViews(request) - return [ - StreamFeatureView.from_proto(stream_feature_view) - for stream_feature_view in 
response.stream_feature_views - ] - - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: - request = RegistryServer_pb2.GetOnDemandFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetOnDemandFeatureView(request) - return OnDemandFeatureView.from_proto(response) - - def list_on_demand_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - request = RegistryServer_pb2.ListOnDemandFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListOnDemandFeatureViews(request) - return [ - OnDemandFeatureView.from_proto(on_demand_feature_view) - for on_demand_feature_view in response.on_demand_feature_views - ] - - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - request = RegistryServer_pb2.GetAnyFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - - response: RegistryServer_pb2.GetAnyFeatureViewResponse = ( - self.stub.GetAnyFeatureView(request) - ) - any_feature_view = response.any_feature_view - return extract_base_feature_view(any_feature_view) - - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - request = RegistryServer_pb2.ListAllFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - - response: RegistryServer_pb2.ListAllFeatureViewsResponse = ( - self.stub.ListAllFeatureViews(request) - ) - return [ - extract_base_feature_view(any_feature_view) - for any_feature_view in response.feature_views - ] - - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - request = RegistryServer_pb2.GetFeatureViewRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetFeatureView(request) - return FeatureView.from_proto(response) - - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - request = RegistryServer_pb2.ListFeatureViewsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListFeatureViews(request) - - return [ - FeatureView.from_proto(feature_view) - for feature_view in response.feature_views - ] - - def apply_materialization( - self, - feature_view: Union[FeatureView, OnDemandFeatureView], - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - start_date_timestamp = Timestamp() - end_date_timestamp = Timestamp() - start_date_timestamp.FromDatetime(start_date) - end_date_timestamp.FromDatetime(end_date) - - # TODO: for this to work for stream feature views, ApplyMaterializationRequest needs to be updated - request = RegistryServer_pb2.ApplyMaterializationRequest( - feature_view=feature_view.to_proto(), - project=project, - start_date=start_date_timestamp, - end_date=end_date_timestamp, - commit=commit, - ) - self.stub.ApplyMaterialization(request) - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - request = RegistryServer_pb2.ApplySavedDatasetRequest( - saved_dataset=saved_dataset.to_proto(), project=project, commit=commit - ) - self.stub.ApplyFeatureService(request) - - def delete_saved_dataset(self, name: str, project: 
str, commit: bool = True): - request = RegistryServer_pb2.DeleteSavedDatasetRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteSavedDataset(request) - - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - request = RegistryServer_pb2.GetSavedDatasetRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetSavedDataset(request) - return SavedDataset.from_proto(response) - - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - request = RegistryServer_pb2.ListSavedDatasetsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListSavedDatasets(request) - return [ - SavedDataset.from_proto(saved_dataset) - for saved_dataset in response.saved_datasets - ] - - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - request = RegistryServer_pb2.ApplyValidationReferenceRequest( - validation_reference=validation_reference.to_proto(), - project=project, - commit=commit, - ) - self.stub.ApplyValidationReference(request) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeleteValidationReferenceRequest( - name=name, project=project, commit=commit - ) - self.stub.DeleteValidationReference(request) - - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - request = RegistryServer_pb2.GetValidationReferenceRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetValidationReference(request) - return ValidationReference.from_proto(response) - - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - request = RegistryServer_pb2.ListValidationReferencesRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListValidationReferences(request) - return [ - ValidationReference.from_proto(validation_reference) - for validation_reference in response.validation_references - ] - - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - request = RegistryServer_pb2.ListProjectMetadataRequest( - project=project, allow_cache=allow_cache - ) - response = self.stub.ListProjectMetadata(request) - return [ProjectMetadata.from_proto(pm) for pm in response.project_metadata] - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - request = RegistryServer_pb2.UpdateInfraRequest( - infra=infra.to_proto(), project=project, commit=commit - ) - self.stub.UpdateInfra(request) - - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - request = RegistryServer_pb2.GetInfraRequest( - project=project, allow_cache=allow_cache - ) - response = self.stub.GetInfra(request) - return Infra.from_proto(response) - - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - pass - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - pass - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - permission_proto = permission.to_proto() - permission_proto.spec.project = project - - 
request = RegistryServer_pb2.ApplyPermissionRequest( - permission=permission_proto, project=project, commit=commit - ) - self.stub.ApplyPermission(request) - - def delete_permission(self, name: str, project: str, commit: bool = True): - request = RegistryServer_pb2.DeletePermissionRequest( - name=name, project=project, commit=commit - ) - self.stub.DeletePermission(request) - - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - request = RegistryServer_pb2.GetPermissionRequest( - name=name, project=project, allow_cache=allow_cache - ) - response = self.stub.GetPermission(request) - - return Permission.from_proto(response) - - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - request = RegistryServer_pb2.ListPermissionsRequest( - project=project, allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListPermissions(request) - return [ - Permission.from_proto(permission) for permission in response.permissions - ] - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - project_proto = project.to_proto() - - request = RegistryServer_pb2.ApplyProjectRequest( - project=project_proto, commit=commit - ) - self.stub.ApplyProject(request) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - request = RegistryServer_pb2.DeleteProjectRequest(name=name, commit=commit) - self.stub.DeleteProject(request) - - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - request = RegistryServer_pb2.GetProjectRequest( - name=name, allow_cache=allow_cache - ) - response = self.stub.GetProject(request) - - return Project.from_proto(response) - - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - request = RegistryServer_pb2.ListProjectsRequest( - allow_cache=allow_cache, tags=tags - ) - response = self.stub.ListProjects(request) - return [Project.from_proto(project) for project in response.projects] - - def proto(self) -> RegistryProto: - return self.stub.Proto(Empty()) - - def commit(self): - self.stub.Commit(Empty()) - - def refresh(self, project: Optional[str] = None): - request = RegistryServer_pb2.RefreshRequest(project=str(project)) - self.stub.Refresh(request) - - def teardown(self): - pass +import os +from datetime import datetime +from pathlib import Path +from typing import List, Optional, Union + +import grpc +from google.protobuf.empty_pb2 import Empty +from google.protobuf.timestamp_pb2 import Timestamp +from pydantic import StrictStr + +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.auth_model import AuthConfig, NoAuthConfig +from feast.permissions.client.grpc_client_auth_interceptor import ( + GrpcClientAuthHeaderInterceptor, +) +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc +from feast.repo_config 
import RegistryConfig
+from feast.saved_dataset import SavedDataset, ValidationReference
+from feast.stream_feature_view import StreamFeatureView
+
+
+def extract_base_feature_view(
+    any_feature_view: RegistryServer_pb2.AnyFeatureView,
+) -> BaseFeatureView:
+    feature_view_type = any_feature_view.WhichOneof("any_feature_view")
+    if feature_view_type == "feature_view":
+        feature_view = FeatureView.from_proto(any_feature_view.feature_view)
+    elif feature_view_type == "on_demand_feature_view":
+        feature_view = OnDemandFeatureView.from_proto(
+            any_feature_view.on_demand_feature_view
+        )
+    elif feature_view_type == "stream_feature_view":
+        feature_view = StreamFeatureView.from_proto(
+            any_feature_view.stream_feature_view
+        )
+
+    return feature_view
+
+
+class RemoteRegistryConfig(RegistryConfig):
+    registry_type: StrictStr = "remote"
+    """ str: Provider name or a class name that implements Registry."""
+
+    path: StrictStr = ""
+    """ str: Path to metadata store.
+    If registry_type is 'remote', then this is a URL for the registry server """
+
+    cert: StrictStr = ""
+    """ str: Path to the public certificate when the registry server starts in TLS (SSL) mode. This may be needed if the registry server started with a self-signed certificate; typically this file ends with `*.crt`, `*.cer`, or `*.pem`.
+    If registry_type is 'remote', then this configuration is needed to connect to a remote registry server in TLS mode. If the remote registry started in non-TLS mode, this configuration is not needed."""
+
+    is_tls: bool = False
+    """ bool: Set to `True` if you plan to connect to a registry server running in TLS (SSL) mode.
+    If you intend to add the public certificate to the trust store instead of passing it via the `cert` parameter, this field must be set to `True`.
+    """
+
+
+class RemoteRegistry(BaseRegistry):
+    def __init__(
+        self,
+        registry_config: Union[RegistryConfig, RemoteRegistryConfig],
+        project: str,
+        repo_path: Optional[Path],
+        auth_config: AuthConfig = NoAuthConfig(),
+    ):
+        self.auth_config = auth_config
+        assert isinstance(registry_config, RemoteRegistryConfig)
+        self.channel = self._create_grpc_channel(registry_config)
+
+        auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
+        self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor)
+        self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel)
+
+    def _create_grpc_channel(self, registry_config):
+        assert isinstance(registry_config, RemoteRegistryConfig)
+        if registry_config.cert or registry_config.is_tls:
+            cafile = os.getenv("SSL_CERT_FILE") or os.getenv("REQUESTS_CA_BUNDLE")
+            if not cafile and not registry_config.cert:
+                raise EnvironmentError(
+                    "SSL_CERT_FILE or REQUESTS_CA_BUNDLE environment variable must be set to use secure TLS, or set the cert parameter in the feature_store.yaml file under the remote registry configuration."
+ ) + with open( + registry_config.cert if registry_config.cert else cafile, "rb" + ) as cert_file: + trusted_certs = cert_file.read() + tls_credentials = grpc.ssl_channel_credentials( + root_certificates=trusted_certs + ) + return grpc.secure_channel(registry_config.path, tls_credentials) + else: + # Create an insecure gRPC channel + return grpc.insecure_channel(registry_config.path) + + def close(self): + if self.channel: + self.channel.close() + + def __del__(self): + self.close() + + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + request = RegistryServer_pb2.ApplyEntityRequest( + entity=entity.to_proto(), project=project, commit=commit + ) + self.stub.ApplyEntity(request) + + def delete_entity(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteEntityRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteEntity(request) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + request = RegistryServer_pb2.GetEntityRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetEntity(request) + return Entity.from_proto(response) + + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: + request = RegistryServer_pb2.ListEntitiesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListEntities(request) + return [Entity.from_proto(entity) for entity in response.entities] + + def apply_data_source( + self, data_source: DataSource, project: str, commit: bool = True + ): + request = RegistryServer_pb2.ApplyDataSourceRequest( + data_source=data_source.to_proto(), project=project, commit=commit + ) + self.stub.ApplyDataSource(request) + + def delete_data_source(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteDataSourceRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteDataSource(request) + + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + request = RegistryServer_pb2.GetDataSourceRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetDataSource(request) + return DataSource.from_proto(response) + + def list_data_sources( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[DataSource]: + request = RegistryServer_pb2.ListDataSourcesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListDataSources(request) + return [ + DataSource.from_proto(data_source) for data_source in response.data_sources + ] + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + request = RegistryServer_pb2.ApplyFeatureServiceRequest( + feature_service=feature_service.to_proto(), project=project, commit=commit + ) + self.stub.ApplyFeatureService(request) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteFeatureServiceRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteFeatureService(request) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + request = RegistryServer_pb2.GetFeatureServiceRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetFeatureService(request) + 
return FeatureService.from_proto(response) + + def list_feature_services( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureService]: + request = RegistryServer_pb2.ListFeatureServicesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListFeatureServices(request) + return [ + FeatureService.from_proto(feature_service) + for feature_service in response.feature_services + ] + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + if isinstance(feature_view, StreamFeatureView): + arg_name = "stream_feature_view" + elif isinstance(feature_view, FeatureView): + arg_name = "feature_view" + elif isinstance(feature_view, OnDemandFeatureView): + arg_name = "on_demand_feature_view" + + request = RegistryServer_pb2.ApplyFeatureViewRequest( + feature_view=( + feature_view.to_proto() if arg_name == "feature_view" else None + ), + stream_feature_view=( + feature_view.to_proto() if arg_name == "stream_feature_view" else None + ), + on_demand_feature_view=( + feature_view.to_proto() + if arg_name == "on_demand_feature_view" + else None + ), + project=project, + commit=commit, + ) + + self.stub.ApplyFeatureView(request) + + def delete_feature_view(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteFeatureViewRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteFeatureView(request) + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> StreamFeatureView: + request = RegistryServer_pb2.GetStreamFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetStreamFeatureView(request) + return StreamFeatureView.from_proto(response) + + def list_stream_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[StreamFeatureView]: + request = RegistryServer_pb2.ListStreamFeatureViewsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListStreamFeatureViews(request) + return [ + StreamFeatureView.from_proto(stream_feature_view) + for stream_feature_view in response.stream_feature_views + ] + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + request = RegistryServer_pb2.GetOnDemandFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetOnDemandFeatureView(request) + return OnDemandFeatureView.from_proto(response) + + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + request = RegistryServer_pb2.ListOnDemandFeatureViewsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListOnDemandFeatureViews(request) + return [ + OnDemandFeatureView.from_proto(on_demand_feature_view) + for on_demand_feature_view in response.on_demand_feature_views + ] + + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + request = RegistryServer_pb2.GetAnyFeatureViewRequest( + name=name, project=project, allow_cache=allow_cache + ) + + response: RegistryServer_pb2.GetAnyFeatureViewResponse = ( + self.stub.GetAnyFeatureView(request) + ) + any_feature_view = response.any_feature_view + return 
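extract_base_feature_view(any_feature_view)

(For context, a minimal sketch of how a caller might consume get_any_feature_view, which can return any of the three concrete view types. Illustrative assumptions, not part of this patch: the registry instance, the project name "demo", and the view name "driver_stats".)

    # Illustrative only: resolve a view by name, then branch on its concrete type.
    view = registry.get_any_feature_view("driver_stats", project="demo")
    if isinstance(view, OnDemandFeatureView):
        # ODFVs with write_to_online_store=True are the ones this patch makes
        # eligible for materialization.
        print(view.name, view.write_to_online_store)
    elif isinstance(view, StreamFeatureView):
        print(view.name, "stream feature view")
    else:
        print(view.name, view.ttl)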
+
+    def list_all_feature_views(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[BaseFeatureView]:
+        request = RegistryServer_pb2.ListAllFeatureViewsRequest(
+            project=project, allow_cache=allow_cache, tags=tags
+        )
+
+        response: RegistryServer_pb2.ListAllFeatureViewsResponse = (
+            self.stub.ListAllFeatureViews(request)
+        )
+        return [
+            extract_base_feature_view(any_feature_view)
+            for any_feature_view in response.feature_views
+        ]
+
+    def get_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> FeatureView:
+        request = RegistryServer_pb2.GetFeatureViewRequest(
+            name=name, project=project, allow_cache=allow_cache
+        )
+        response = self.stub.GetFeatureView(request)
+        return FeatureView.from_proto(response)
+
+    def list_feature_views(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[FeatureView]:
+        request = RegistryServer_pb2.ListFeatureViewsRequest(
+            project=project, allow_cache=allow_cache, tags=tags
+        )
+        response = self.stub.ListFeatureViews(request)
+
+        return [
+            FeatureView.from_proto(feature_view)
+            for feature_view in response.feature_views
+        ]
+
+    def apply_materialization(
+        self,
+        feature_view: Union[FeatureView, OnDemandFeatureView],
+        project: str,
+        start_date: datetime,
+        end_date: datetime,
+        commit: bool = True,
+    ):
+        start_date_timestamp = Timestamp()
+        end_date_timestamp = Timestamp()
+        start_date_timestamp.FromDatetime(start_date)
+        end_date_timestamp.FromDatetime(end_date)
+
+        # TODO: for this to work for stream feature views, ApplyMaterializationRequest needs to be updated
+        request = RegistryServer_pb2.ApplyMaterializationRequest(
+            feature_view=feature_view.to_proto(),
+            project=project,
+            start_date=start_date_timestamp,
+            end_date=end_date_timestamp,
+            commit=commit,
+        )
+        self.stub.ApplyMaterialization(request)
+
+    def apply_saved_dataset(
+        self,
+        saved_dataset: SavedDataset,
+        project: str,
+        commit: bool = True,
+    ):
+        request = RegistryServer_pb2.ApplySavedDatasetRequest(
+            saved_dataset=saved_dataset.to_proto(), project=project, commit=commit
+        )
+        self.stub.ApplySavedDataset(request)
+
+    def delete_saved_dataset(self, name: str, project: str, commit: bool = True):
+        request = RegistryServer_pb2.DeleteSavedDatasetRequest(
+            name=name, project=project, commit=commit
+        )
+        self.stub.DeleteSavedDataset(request)
+
+    def get_saved_dataset(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> SavedDataset:
+        request = RegistryServer_pb2.GetSavedDatasetRequest(
+            name=name, project=project, allow_cache=allow_cache
+        )
+        response = self.stub.GetSavedDataset(request)
+        return SavedDataset.from_proto(response)
+
+    def list_saved_datasets(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[SavedDataset]:
+        request = RegistryServer_pb2.ListSavedDatasetsRequest(
+            project=project, allow_cache=allow_cache, tags=tags
+        )
+        response = self.stub.ListSavedDatasets(request)
+        return [
+            SavedDataset.from_proto(saved_dataset)
+            for saved_dataset in response.saved_datasets
+        ]
+
+    def apply_validation_reference(
+        self,
+        validation_reference: ValidationReference,
+        project: str,
+        commit: bool = True,
+    ):
+        request = RegistryServer_pb2.ApplyValidationReferenceRequest(
+            validation_reference=validation_reference.to_proto(),
+            project=project,
+            commit=commit,
+        )
+        self.stub.ApplyValidationReference(request)
+
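(As a usage note, a minimal sketch of constructing this client against a TLS-enabled registry server; the host, port, and certificate path are illustrative assumptions, not values from this patch.)

    # Hypothetical setup; see the RemoteRegistryConfig docstrings above.
    config = RemoteRegistryConfig(
        registry_type="remote",
        path="registry.example.com:6570",  # placeholder host:port
        cert="/tmp/registry.crt",  # or set is_tls=True to rely on the trust store
    )
    registry = RemoteRegistry(config, project="demo", repo_path=None)
    odfvs = registry.list_on_demand_feature_views(project="demo")

+    def 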
delete_validation_reference(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeleteValidationReferenceRequest( + name=name, project=project, commit=commit + ) + self.stub.DeleteValidationReference(request) + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + request = RegistryServer_pb2.GetValidationReferenceRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetValidationReference(request) + return ValidationReference.from_proto(response) + + def list_validation_references( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[ValidationReference]: + request = RegistryServer_pb2.ListValidationReferencesRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListValidationReferences(request) + return [ + ValidationReference.from_proto(validation_reference) + for validation_reference in response.validation_references + ] + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + request = RegistryServer_pb2.ListProjectMetadataRequest( + project=project, allow_cache=allow_cache + ) + response = self.stub.ListProjectMetadata(request) + return [ProjectMetadata.from_proto(pm) for pm in response.project_metadata] + + def update_infra(self, infra: Infra, project: str, commit: bool = True): + request = RegistryServer_pb2.UpdateInfraRequest( + infra=infra.to_proto(), project=project, commit=commit + ) + self.stub.UpdateInfra(request) + + def get_infra(self, project: str, allow_cache: bool = False) -> Infra: + request = RegistryServer_pb2.GetInfraRequest( + project=project, allow_cache=allow_cache + ) + response = self.stub.GetInfra(request) + return Infra.from_proto(response) + + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + pass + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + pass + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + permission_proto = permission.to_proto() + permission_proto.spec.project = project + + request = RegistryServer_pb2.ApplyPermissionRequest( + permission=permission_proto, project=project, commit=commit + ) + self.stub.ApplyPermission(request) + + def delete_permission(self, name: str, project: str, commit: bool = True): + request = RegistryServer_pb2.DeletePermissionRequest( + name=name, project=project, commit=commit + ) + self.stub.DeletePermission(request) + + def get_permission( + self, name: str, project: str, allow_cache: bool = False + ) -> Permission: + request = RegistryServer_pb2.GetPermissionRequest( + name=name, project=project, allow_cache=allow_cache + ) + response = self.stub.GetPermission(request) + + return Permission.from_proto(response) + + def list_permissions( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Permission]: + request = RegistryServer_pb2.ListPermissionsRequest( + project=project, allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListPermissions(request) + return [ + Permission.from_proto(permission) for permission in response.permissions + ] + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + project_proto = project.to_proto() + + request = RegistryServer_pb2.ApplyProjectRequest( + 
project=project_proto, commit=commit + ) + self.stub.ApplyProject(request) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + request = RegistryServer_pb2.DeleteProjectRequest(name=name, commit=commit) + self.stub.DeleteProject(request) + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + request = RegistryServer_pb2.GetProjectRequest( + name=name, allow_cache=allow_cache + ) + response = self.stub.GetProject(request) + + return Project.from_proto(response) + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + request = RegistryServer_pb2.ListProjectsRequest( + allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListProjects(request) + return [Project.from_proto(project) for project in response.projects] + + def proto(self) -> RegistryProto: + return self.stub.Proto(Empty()) + + def commit(self): + self.stub.Commit(Empty()) + + def refresh(self, project: Optional[str] = None): + request = RegistryServer_pb2.RefreshRequest(project=str(project)) + self.stub.Refresh(request) + + def teardown(self): + pass diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 71a8aa067ec..e46231ca7a0 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -1,1375 +1,1375 @@ -import logging -import os -import uuid -from binascii import hexlify -from datetime import datetime, timedelta, timezone -from enum import Enum -from threading import Lock -from typing import Any, Callable, List, Literal, Optional, Union, cast - -from pydantic import ConfigDict, Field, StrictStr - -import feast -from feast.base_feature_view import BaseFeatureView -from feast.data_source import DataSource -from feast.entity import Entity -from feast.errors import ( - DataSourceObjectNotFoundException, - EntityNotFoundException, - FeatureServiceNotFoundException, - FeatureViewNotFoundException, - PermissionNotFoundException, - ProjectNotFoundException, - ProjectObjectNotFoundException, - SavedDatasetNotFound, - ValidationReferenceNotFound, -) -from feast.feature_service import FeatureService -from feast.feature_view import FeatureView -from feast.infra.infra_object import Infra -from feast.infra.registry import proto_registry_utils -from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.utils.snowflake.snowflake_utils import ( - GetSnowflakeConnection, - execute_snowflake_statement, -) -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto -from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto -from feast.protos.feast.core.FeatureService_pb2 import ( - FeatureService as FeatureServiceProto, -) -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto -from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto -from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( - OnDemandFeatureView as OnDemandFeatureViewProto, -) -from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto -from feast.protos.feast.core.Project_pb2 import Project as ProjectProto -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from 
feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto -from feast.protos.feast.core.StreamFeatureView_pb2 import ( - StreamFeatureView as StreamFeatureViewProto, -) -from feast.protos.feast.core.ValidationProfile_pb2 import ( - ValidationReference as ValidationReferenceProto, -) -from feast.repo_config import RegistryConfig -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.utils import _utc_now, has_all_tags - -logger = logging.getLogger(__name__) - - -class FeastMetadataKeys(Enum): - LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" - PROJECT_UUID = "project_uuid" - - -class SnowflakeRegistryConfig(RegistryConfig): - """Registry config for Snowflake""" - - registry_type: Literal["snowflake.registry"] = "snowflake.registry" - """ Registry type selector """ - - type: Literal["snowflake.registry"] = "snowflake.registry" - """ Registry type selector """ - - config_path: Optional[str] = os.path.expanduser("~/.snowsql/config") - """ Snowflake snowsql config path -- absolute path required (Cant use ~)""" - - connection_name: Optional[str] = None - """ Snowflake connector connection name -- typically defined in ~/.snowflake/connections.toml """ - - account: Optional[str] = None - """ Snowflake deployment identifier -- drop .snowflakecomputing.com """ - - user: Optional[str] = None - """ Snowflake user name """ - - password: Optional[str] = None - """ Snowflake password """ - - role: Optional[str] = None - """ Snowflake role name """ - - warehouse: Optional[str] = None - """ Snowflake warehouse name """ - - authenticator: Optional[str] = None - """ Snowflake authenticator name """ - - private_key: Optional[str] = None - """ Snowflake private key file path""" - - private_key_content: Optional[bytes] = None - """ Snowflake private key stored as bytes""" - - private_key_passphrase: Optional[str] = None - """ Snowflake private key file passphrase""" - - database: StrictStr - """ Snowflake database name """ - - schema_: Optional[str] = Field("PUBLIC", alias="schema") - """ Snowflake schema name """ - model_config = ConfigDict(populate_by_name=True) - - -class SnowflakeRegistry(BaseRegistry): - def __init__( - self, - registry_config, - project: str, - repo_path, - ): - assert registry_config is not None and isinstance( - registry_config, SnowflakeRegistryConfig - ), "SnowflakeRegistry needs a valid registry_config, a path does not work" - - self.registry_config = registry_config - self.registry_path = ( - f'"{self.registry_config.database}"."{self.registry_config.schema_}"' - ) - - with GetSnowflakeConnection(self.registry_config) as conn: - sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql" - with open(sql_function_file, "r") as file: - sql_file = file.read() - sql_cmds = sql_file.split(";") - for command in sql_cmds: - query = command.replace("REGISTRY_PATH", f"{self.registry_path}") - execute_snowflake_statement(conn, query) - - self.purge_feast_metadata = registry_config.purge_feast_metadata - self._sync_feast_metadata_to_projects_table() - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() - self._refresh_lock = Lock() - self.cached_registry_proto_ttl = timedelta( - seconds=( - registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 - ) - ) - self.project = 
project - - def _sync_feast_metadata_to_projects_table(self): - feast_metadata_projects: set = [] - projects_set: set = [] - - with GetSnowflakeConnection(self.registry_config) as conn: - query = ( - f'SELECT DISTINCT project_id FROM {self.registry_path}."FEAST_METADATA"' - ) - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - for row in df.iterrows(): - feast_metadata_projects.add(row[1]["PROJECT_ID"]) - - if len(feast_metadata_projects) > 0: - with GetSnowflakeConnection(self.registry_config) as conn: - query = f'SELECT project_id FROM {self.registry_path}."PROJECTS"' - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - for row in df.iterrows(): - projects_set.add(row[1]["PROJECT_ID"]) - - # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects) - set(projects_set) - for project_name in projects_to_sync: - self.apply_project(Project(name=project_name), commit=True) - - if self.purge_feast_metadata: - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - DELETE FROM {self.registry_path}."FEAST_METADATA" - """ - execute_snowflake_statement(conn, query) - - def refresh(self, project: Optional[str] = None): - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() - - def _refresh_cached_registry_if_necessary(self): - with self._refresh_lock: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl - ) - ) - ) - - if expired: - logger.info("Registry cache expired, so refreshing") - self.refresh() - - def teardown(self): - with GetSnowflakeConnection(self.registry_config) as conn: - sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" - with open(sql_function_file, "r") as file: - sqlFile = file.read() - sqlCommands = sqlFile.split(";") - for command in sqlCommands: - query = command.replace("REGISTRY_PATH", f"{self.registry_path}") - execute_snowflake_statement(conn, query) - - # apply operations - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - return self._apply_object( - "DATA_SOURCES", - project, - "DATA_SOURCE_NAME", - data_source, - "DATA_SOURCE_PROTO", - ) - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - return self._apply_object( - "ENTITIES", project, "ENTITY_NAME", entity, "ENTITY_PROTO" - ) - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - return self._apply_object( - "FEATURE_SERVICES", - project, - "FEATURE_SERVICE_NAME", - feature_service, - "FEATURE_SERVICE_PROTO", - ) - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1] - return self._apply_object( - fv_table_str, - project, - f"{fv_column_name}_NAME", - feature_view, - f"{fv_column_name}_PROTO", - ) - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - return self._apply_object( - "SAVED_DATASETS", - project, - "SAVED_DATASET_NAME", - saved_dataset, - "SAVED_DATASET_PROTO", - ) - - def apply_validation_reference( - self, - validation_reference: 
ValidationReference, - project: str, - commit: bool = True, - ): - return self._apply_object( - "VALIDATION_REFERENCES", - project, - "VALIDATION_REFERENCE_NAME", - validation_reference, - "VALIDATION_REFERENCE_PROTO", - ) - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - self._apply_object( - "MANAGED_INFRA", - project, - "INFRA_NAME", - infra, - "INFRA_PROTO", - name="infra_obj", - ) - - def _initialize_project_if_not_exists(self, project_name: str): - try: - self.get_project(project_name, allow_cache=True) - return - except ProjectObjectNotFoundException: - try: - self.get_project(project_name, allow_cache=False) - return - except ProjectObjectNotFoundException: - self.apply_project(Project(name=project_name), commit=True) - - def _apply_object( - self, - table: str, - project: str, - id_field_name: str, - obj: Any, - proto_field_name: str, - name: Optional[str] = None, - ): - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option - if not isinstance(obj, Project): - self._initialize_project_if_not_exists(project_name=project) - - name = name or (obj.name if hasattr(obj, "name") else None) - assert name, f"name needs to be provided for {obj}" - - update_datetime = _utc_now() - if hasattr(obj, "last_updated_timestamp"): - obj.last_updated_timestamp = update_datetime - - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - project_id - FROM - {self.registry_path}."{table}" - WHERE - project_id = '{project}' - AND {id_field_name.lower()} = '{name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - proto = hexlify(obj.to_proto().SerializeToString()).__str__()[1:] - query = f""" - UPDATE {self.registry_path}."{table}" - SET - {proto_field_name} = TO_BINARY({proto}), - last_updated_timestamp = CURRENT_TIMESTAMP() - WHERE - {id_field_name.lower()} = '{name}' - """ - execute_snowflake_statement(conn, query) - - else: - obj_proto = obj.to_proto() - - if hasattr(obj_proto, "meta") and hasattr( - obj_proto.meta, "created_timestamp" - ): - obj_proto.meta.created_timestamp.FromDatetime(update_datetime) - - proto = hexlify(obj_proto.SerializeToString()).__str__()[1:] - if table == "FEATURE_VIEWS": - query = f""" - INSERT INTO {self.registry_path}."{table}" - VALUES - ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '', '') - """ - elif "_FEATURE_VIEWS" in table: - query = f""" - INSERT INTO {self.registry_path}."{table}" - VALUES - ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '') - """ - else: - query = f""" - INSERT INTO {self.registry_path}."{table}" - VALUES - ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto})) - """ - execute_snowflake_statement(conn, query) - - if not isinstance(obj, Project): - self.apply_project( - self.get_project(name=project, allow_cache=False), commit=True - ) - - if not self.purge_feast_metadata: - self._set_last_updated_metadata(update_datetime, project) - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - return self._apply_object( - "PERMISSIONS", - project, - "PERMISSION_NAME", - permission, - "PERMISSION_PROTO", - ) - - # delete operations - def delete_data_source(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "DATA_SOURCES", - name, - project, - "DATA_SOURCE_NAME", - 
DataSourceObjectNotFoundException, - ) - - def delete_entity(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "ENTITIES", name, project, "ENTITY_NAME", EntityNotFoundException - ) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "FEATURE_SERVICES", - name, - project, - "FEATURE_SERVICE_NAME", - FeatureServiceNotFoundException, - ) - - # can you have featureviews with the same name - def delete_feature_view(self, name: str, project: str, commit: bool = True): - deleted_count = 0 - for table in { - "FEATURE_VIEWS", - "ON_DEMAND_FEATURE_VIEWS", - "STREAM_FEATURE_VIEWS", - }: - deleted_count += self._delete_object( - table, name, project, "FEATURE_VIEW_NAME", None - ) - if deleted_count == 0: - raise FeatureViewNotFoundException(name, project) - - def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): - self._delete_object( - "SAVED_DATASETS", - name, - project, - "SAVED_DATASET_NAME", - SavedDatasetNotFound, - ) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - self._delete_object( - "VALIDATION_REFERENCES", - name, - project, - "VALIDATION_REFERENCE_NAME", - ValidationReferenceNotFound, - ) - - def _delete_object( - self, - table: str, - name: str, - project: str, - id_field_name: str, - not_found_exception: Optional[Callable], - ): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - DELETE FROM {self.registry_path}."{table}" - WHERE - project_id = '{project}' - AND {id_field_name.lower()} = '{name}' - """ - cursor = execute_snowflake_statement(conn, query) - - if cursor.rowcount < 1 and not_found_exception: # type: ignore - raise not_found_exception(name, project) - self._set_last_updated_metadata(_utc_now(), project) - - return cursor.rowcount - - def delete_permission(self, name: str, project: str, commit: bool = True): - return self._delete_object( - "PERMISSIONS", - name, - project, - "PERMISSION_NAME", - PermissionNotFoundException, - ) - - # get operations - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_data_source( - self.cached_registry_proto, name, project - ) - return self._get_object( - "DATA_SOURCES", - name, - project, - DataSourceProto, - DataSource, - "DATA_SOURCE_NAME", - "DATA_SOURCE_PROTO", - DataSourceObjectNotFoundException, - ) - - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_entity( - self.cached_registry_proto, name, project - ) - return self._get_object( - "ENTITIES", - name, - project, - EntityProto, - Entity, - "ENTITY_NAME", - "ENTITY_PROTO", - EntityNotFoundException, - ) - - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_service( - self.cached_registry_proto, name, project - ) - return self._get_object( - "FEATURE_SERVICES", - name, - project, - FeatureServiceProto, - FeatureService, - "FEATURE_SERVICE_NAME", - "FEATURE_SERVICE_PROTO", - FeatureServiceNotFoundException, - ) - - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - if allow_cache: - 
self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_view( - self.cached_registry_proto, name, project - ) - return self._get_object( - "FEATURE_VIEWS", - name, - project, - FeatureViewProto, - FeatureView, - "FEATURE_VIEW_NAME", - "FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - - def get_any_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> BaseFeatureView: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_any_feature_view( - self.cached_registry_proto, name, project - ) - fv = self._get_object( - "FEATURE_VIEWS", - name, - project, - FeatureViewProto, - FeatureView, - "FEATURE_VIEW_NAME", - "FEATURE_VIEW_PROTO", - None, - ) - - if not fv: - fv = self._get_object( - "STREAM_FEATURE_VIEWS", - name, - project, - StreamFeatureViewProto, - StreamFeatureView, - "STREAM_FEATURE_VIEW_NAME", - "STREAM_FEATURE_VIEW_PROTO", - None, - ) - if not fv: - fv = self._get_object( - "ON_DEMAND_FEATURE_VIEWS", - name, - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "ON_DEMAND_FEATURE_VIEW_NAME", - "ON_DEMAND_FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - return fv - - def list_all_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[BaseFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_all_feature_views( - self.cached_registry_proto, project, tags - ) - - return ( - cast( - list[BaseFeatureView], - self.list_feature_views(project, allow_cache, tags), - ) - + cast( - list[BaseFeatureView], - self.list_stream_feature_views(project, allow_cache, tags), - ) - + cast( - list[BaseFeatureView], - self.list_on_demand_feature_views(project, allow_cache, tags), - ) - ) - - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - infra_object = self._get_object( - "MANAGED_INFRA", - "infra_obj", - project, - InfraProto, - Infra, - "INFRA_NAME", - "INFRA_PROTO", - None, - ) - infra_object = infra_object or InfraProto() - return Infra.from_proto(infra_object) - - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> OnDemandFeatureView: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_on_demand_feature_view( - self.cached_registry_proto, name, project - ) - return self._get_object( - "ON_DEMAND_FEATURE_VIEWS", - name, - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "ON_DEMAND_FEATURE_VIEW_NAME", - "ON_DEMAND_FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_saved_dataset( - self.cached_registry_proto, name, project - ) - return self._get_object( - "SAVED_DATASETS", - name, - project, - SavedDatasetProto, - SavedDataset, - "SAVED_DATASET_NAME", - "SAVED_DATASET_PROTO", - SavedDatasetNotFound, - ) - - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ): - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_stream_feature_view( - self.cached_registry_proto, name, project - ) - return self._get_object( - "STREAM_FEATURE_VIEWS", - name, - project, - StreamFeatureViewProto, - StreamFeatureView, - "STREAM_FEATURE_VIEW_NAME", - 
"STREAM_FEATURE_VIEW_PROTO", - FeatureViewNotFoundException, - ) - - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_validation_reference( - self.cached_registry_proto, name, project - ) - return self._get_object( - "VALIDATION_REFERENCES", - name, - project, - ValidationReferenceProto, - ValidationReference, - "VALIDATION_REFERENCE_NAME", - "VALIDATION_REFERENCE_PROTO", - ValidationReferenceNotFound, - ) - - def _get_object( - self, - table: str, - name: str, - project: str, - proto_class: Any, - python_class: Any, - id_field_name: str, - proto_field_name: str, - not_found_exception: Optional[Callable], - ): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - {proto_field_name} - FROM - {self.registry_path}."{table}" - WHERE - project_id = '{project}' - AND {id_field_name.lower()} = '{name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - _proto = proto_class.FromString(df.squeeze()) - return python_class.from_proto(_proto) - elif not_found_exception: - raise not_found_exception(name, project) - else: - return None - - def get_permission( - self, name: str, project: str, allow_cache: bool = False - ) -> Permission: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_permission( - self.cached_registry_proto, name, project - ) - return self._get_object( - "PERMISSIONS", - name, - project, - PermissionProto, - Permission, - "PERMISSION_NAME", - "PERMISSION_PROTO", - PermissionNotFoundException, - ) - - # list operations - def list_data_sources( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[DataSource]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_data_sources( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "DATA_SOURCES", - project, - DataSourceProto, - DataSource, - "DATA_SOURCE_PROTO", - tags=tags, - ) - - def list_entities( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Entity]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_entities( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO", tags=tags - ) - - def list_feature_services( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureService]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_feature_services( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "FEATURE_SERVICES", - project, - FeatureServiceProto, - FeatureService, - "FEATURE_SERVICE_PROTO", - tags=tags, - ) - - def list_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[FeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_feature_views( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "FEATURE_VIEWS", - project, - FeatureViewProto, - FeatureView, - "FEATURE_VIEW_PROTO", - tags=tags, - ) - - def list_on_demand_feature_views( - self, - 
project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[OnDemandFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_on_demand_feature_views( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "ON_DEMAND_FEATURE_VIEWS", - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "ON_DEMAND_FEATURE_VIEW_PROTO", - tags=tags, - ) - - def list_saved_datasets( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[SavedDataset]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_saved_datasets( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "SAVED_DATASETS", - project, - SavedDatasetProto, - SavedDataset, - "SAVED_DATASET_PROTO", - tags=tags, - ) - - def list_stream_feature_views( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[StreamFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_stream_feature_views( - self.cached_registry_proto, project, tags - ) - return self._list_objects( - "STREAM_FEATURE_VIEWS", - project, - StreamFeatureViewProto, - StreamFeatureView, - "STREAM_FEATURE_VIEW_PROTO", - tags=tags, - ) - - def list_validation_references( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[ValidationReference]: - return self._list_objects( - "VALIDATION_REFERENCES", - project, - ValidationReferenceProto, - ValidationReference, - "VALIDATION_REFERENCE_PROTO", - tags=tags, - ) - - def _list_objects( - self, - table: str, - project: str, - proto_class: Any, - python_class: Any, - proto_field_name: str, - tags: Optional[dict[str, str]] = None, - ): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - {proto_field_name} - FROM - {self.registry_path}."{table}" - WHERE - project_id = '{project}' - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - if not df.empty: - objects = [] - for row in df.iterrows(): - obj = python_class.from_proto( - proto_class.FromString(row[1][proto_field_name]) - ) - if has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def list_permissions( - self, - project: str, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Permission]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_permissions( - self.cached_registry_proto, project - ) - return self._list_objects( - "PERMISSIONS", - project, - PermissionProto, - Permission, - "PERMISSION_PROTO", - tags, - ) - - def apply_materialization( - self, - feature_view: Union[FeatureView, OnDemandFeatureView], - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1] - python_class, proto_class = self._infer_fv_classes(feature_view) - - if python_class in {OnDemandFeatureView}: - raise ValueError( - f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" - ) - fv: Union[FeatureView, StreamFeatureView] = self._get_object( - fv_table_str, - feature_view.name, - project, - proto_class, - python_class, - f"{fv_column_name}_NAME", - f"{fv_column_name}_PROTO", - 
FeatureViewNotFoundException, - ) - fv.materialization_intervals.append((start_date, end_date)) - self._apply_object( - fv_table_str, - project, - f"{fv_column_name}_NAME", - fv, - f"{fv_column_name}_PROTO", - ) - - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_project_metadata( - self.cached_registry_proto, project - ) - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - metadata_key, - metadata_value - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - project_metadata = ProjectMetadata(project_name=project) - for row in df.iterrows(): - if row[1]["METADATA_KEY"] == FeastMetadataKeys.PROJECT_UUID.value: - project_metadata.project_uuid = row[1]["METADATA_VALUE"] - break - # TODO(adchia): Add other project metadata in a structured way - return [project_metadata] - return [] - - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1].lower() - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - project_id - FROM - {self.registry_path}."{fv_table_str}" - WHERE - project_id = '{project}' - AND {fv_column_name}_name = '{feature_view.name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - if metadata_bytes: - metadata_hex = hexlify(metadata_bytes).__str__()[1:] - else: - metadata_hex = "''" - query = f""" - UPDATE {self.registry_path}."{fv_table_str}" - SET - user_metadata = TO_BINARY({metadata_hex}), - last_updated_timestamp = CURRENT_TIMESTAMP() - WHERE - project_id = '{project}' - AND {fv_column_name}_name = '{feature_view.name}' - """ - execute_snowflake_statement(conn, query) - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - fv_table_str = self._infer_fv_table(feature_view) - fv_column_name = fv_table_str[:-1].lower() - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - user_metadata - FROM - {self.registry_path}."{fv_table_str}" - WHERE - {fv_column_name}_name = '{feature_view.name}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if not df.empty: - return df.squeeze() - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def proto(self) -> RegistryProto: - r = RegistryProto() - last_updated_timestamps = [] - - def process_project(project: Project): - nonlocal r, last_updated_timestamps - project_name = project.name - last_updated_timestamp = project.last_updated_timestamp - - try: - cached_project = self.get_project(project_name, True) - except ProjectObjectNotFoundException: - cached_project = None - - allow_cache = False - - if cached_project is not None: - allow_cache = ( - last_updated_timestamp <= cached_project.last_updated_timestamp - ) - - r.projects.extend([project.to_proto()]) - last_updated_timestamps.append(last_updated_timestamp) - - for lister, registry_proto_field in [ - (self.list_entities, r.entities), - (self.list_feature_views, r.feature_views), - (self.list_data_sources, 
r.data_sources), - (self.list_on_demand_feature_views, r.on_demand_feature_views), - (self.list_stream_feature_views, r.stream_feature_views), - (self.list_feature_services, r.feature_services), - (self.list_saved_datasets, r.saved_datasets), - (self.list_validation_references, r.validation_references), - (self.list_permissions, r.permissions), - ]: - objs: List[Any] = lister(project_name, allow_cache) # type: ignore - if objs: - obj_protos = [obj.to_proto() for obj in objs] - for obj_proto in obj_protos: - if "spec" in obj_proto.DESCRIPTOR.fields_by_name: - obj_proto.spec.project = project_name - else: - obj_proto.project = project_name - registry_proto_field.extend(obj_protos) - - # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, - # the registry proto only has a single infra field, which we're currently setting as the "last" project. - r.infra.CopyFrom(self.get_infra(project_name).to_proto()) - - projects_list = self.list_projects(allow_cache=False) - for project in projects_list: - process_project(project) - - if last_updated_timestamps: - r.last_updated.FromDatetime(max(last_updated_timestamps)) - - return r - - def _get_last_updated_metadata(self, project: str): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - metadata_value - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if df.empty: - return None - - return datetime.fromtimestamp(int(df.squeeze()), tz=timezone.utc) - - def _infer_fv_classes(self, feature_view): - if isinstance(feature_view, StreamFeatureView): - python_class, proto_class = StreamFeatureView, StreamFeatureViewProto - elif isinstance(feature_view, FeatureView): - python_class, proto_class = FeatureView, FeatureViewProto - elif isinstance(feature_view, OnDemandFeatureView): - python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return python_class, proto_class - - def _infer_fv_table(self, feature_view) -> str: - if isinstance(feature_view, StreamFeatureView): - table = "STREAM_FEATURE_VIEWS" - elif isinstance(feature_view, FeatureView): - table = "FEATURE_VIEWS" - elif isinstance(feature_view, OnDemandFeatureView): - table = "ON_DEMAND_FEATURE_VIEWS" - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return table - - def _maybe_init_project_metadata(self, project): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - metadata_value - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.PROJECT_UUID.value}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - if df.empty: - new_project_uuid = f"{uuid.uuid4()}" - query = f""" - INSERT INTO {self.registry_path}."FEAST_METADATA" - VALUES - ('{project}', '{FeastMetadataKeys.PROJECT_UUID.value}', '{new_project_uuid}', CURRENT_TIMESTAMP()) - """ - execute_snowflake_statement(conn, query) - - def _set_last_updated_metadata(self, last_updated: datetime, project: str): - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT - project_id - FROM - {self.registry_path}."FEAST_METADATA" - WHERE - project_id = '{project}' - AND metadata_key = 
'{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' - LIMIT 1 - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - update_time = int(last_updated.timestamp()) - if not df.empty: - query = f""" - UPDATE {self.registry_path}."FEAST_METADATA" - SET - project_id = '{project}', - metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', - metadata_value = '{update_time}', - last_updated_timestamp = CURRENT_TIMESTAMP() - WHERE - project_id = '{project}' - AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' - """ - execute_snowflake_statement(conn, query) - - else: - query = f""" - INSERT INTO {self.registry_path}."FEAST_METADATA" - VALUES - ('{project}', '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', '{update_time}', CURRENT_TIMESTAMP()) - """ - execute_snowflake_statement(conn, query) - - def commit(self): - pass - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - return self._apply_object( - "PROJECTS", project.name, "project_name", project, "project_proto" - ) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - project = self.get_project(name, allow_cache=False) - if project: - with GetSnowflakeConnection(self.registry_config) as conn: - for table in { - "MANAGED_INFRA", - "SAVED_DATASETS", - "VALIDATION_REFERENCES", - "FEATURE_SERVICES", - "FEATURE_VIEWS", - "ON_DEMAND_FEATURE_VIEWS", - "STREAM_FEATURE_VIEWS", - "DATA_SOURCES", - "ENTITIES", - "PERMISSIONS", - "FEAST_METADATA", - "PROJECTS", - }: - query = f""" - DELETE FROM {self.registry_path}."{table}" - WHERE - project_id = '{project}' - """ - execute_snowflake_statement(conn, query) - return - - raise ProjectNotFoundException(name) - - def _get_project( - self, - name: str, - ) -> Project: - return self._get_object( - table="PROJECTS", - name=name, - project=name, - proto_class=ProjectProto, - python_class=Project, - id_field_name="project_name", - proto_field_name="project_proto", - not_found_exception=ProjectObjectNotFoundException, - ) - - def get_project( - self, - name: str, - allow_cache: bool = False, - ) -> Project: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_project(self.cached_registry_proto, name) - return self._get_project(name) - - def _list_projects( - self, - tags: Optional[dict[str, str]], - ) -> List[Project]: - with GetSnowflakeConnection(self.registry_config) as conn: - query = f""" - SELECT project_proto FROM {self.registry_path}."PROJECTS" - """ - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - if not df.empty: - objects = [] - for row in df.iterrows(): - obj = Project.from_proto( - ProjectProto.FromString(row[1]["project_proto"]) - ) - if has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None, - ) -> List[Project]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_projects(self.cached_registry_proto, tags) - return self._list_projects(tags) +import logging +import os +import uuid +from binascii import hexlify +from datetime import datetime, timedelta, timezone +from enum import Enum +from threading import Lock +from typing import Any, Callable, List, Literal, Optional, Union, cast + +from pydantic import ConfigDict, Field, StrictStr + +import feast +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity 
import Entity
+from feast.errors import (
+    DataSourceObjectNotFoundException,
+    EntityNotFoundException,
+    FeatureServiceNotFoundException,
+    FeatureViewNotFoundException,
+    PermissionNotFoundException,
+    ProjectNotFoundException,
+    ProjectObjectNotFoundException,
+    SavedDatasetNotFound,
+    ValidationReferenceNotFound,
+)
+from feast.feature_service import FeatureService
+from feast.feature_view import FeatureView
+from feast.infra.infra_object import Infra
+from feast.infra.registry import proto_registry_utils
+from feast.infra.registry.base_registry import BaseRegistry
+from feast.infra.utils.snowflake.snowflake_utils import (
+    GetSnowflakeConnection,
+    execute_snowflake_statement,
+)
+from feast.on_demand_feature_view import OnDemandFeatureView
+from feast.permissions.permission import Permission
+from feast.project import Project
+from feast.project_metadata import ProjectMetadata
+from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
+from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto
+from feast.protos.feast.core.FeatureService_pb2 import (
+    FeatureService as FeatureServiceProto,
+)
+from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
+from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto
+from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
+    OnDemandFeatureView as OnDemandFeatureViewProto,
+)
+from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto
+from feast.protos.feast.core.Project_pb2 import Project as ProjectProto
+from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
+from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto
+from feast.protos.feast.core.StreamFeatureView_pb2 import (
+    StreamFeatureView as StreamFeatureViewProto,
+)
+from feast.protos.feast.core.ValidationProfile_pb2 import (
+    ValidationReference as ValidationReferenceProto,
+)
+from feast.repo_config import RegistryConfig
+from feast.saved_dataset import SavedDataset, ValidationReference
+from feast.stream_feature_view import StreamFeatureView
+from feast.utils import _utc_now, has_all_tags
+
+logger = logging.getLogger(__name__)
+
+
+class FeastMetadataKeys(Enum):
+    LAST_UPDATED_TIMESTAMP = "last_updated_timestamp"
+    PROJECT_UUID = "project_uuid"
+
+
+class SnowflakeRegistryConfig(RegistryConfig):
+    """Registry config for Snowflake"""
+
+    registry_type: Literal["snowflake.registry"] = "snowflake.registry"
+    """ Registry type selector """
+
+    type: Literal["snowflake.registry"] = "snowflake.registry"
+    """ Registry type selector """
+
+    config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
+    """ Snowflake snowsql config path -- absolute path required (can't use ~)"""
+
+    connection_name: Optional[str] = None
+    """ Snowflake connector connection name -- typically defined in ~/.snowflake/connections.toml """
+
+    account: Optional[str] = None
+    """ Snowflake deployment identifier -- drop .snowflakecomputing.com """
+
+    user: Optional[str] = None
+    """ Snowflake user name """
+
+    password: Optional[str] = None
+    """ Snowflake password """
+
+    role: Optional[str] = None
+    """ Snowflake role name """
+
+    warehouse: Optional[str] = None
+    """ Snowflake warehouse name """
+
+    authenticator: Optional[str] = None
+    """ Snowflake authenticator name """
+
+    private_key: Optional[str] = None
+    """ Snowflake private key file path"""
+
+    private_key_content: Optional[bytes] = None
+    """ Snowflake private key stored as bytes"""
+
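+    # Key-pair authentication: private_key (file path) or private_key_content (bytes), with the passphrase below if the key is encrypted.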
+
+    private_key_passphrase: Optional[str] = None
+    """ Snowflake private key file passphrase"""
+
+    database: StrictStr
+    """ Snowflake database name """
+
+    schema_: Optional[str] = Field("PUBLIC", alias="schema")
+    """ Snowflake schema name """
+
+    model_config = ConfigDict(populate_by_name=True)
+
+
+class SnowflakeRegistry(BaseRegistry):
+    """Registry implementation backed by Snowflake tables that store serialized Feast protos."""
+
+    def __init__(
+        self,
+        registry_config,
+        project: str,
+        repo_path,
+    ):
+        assert registry_config is not None and isinstance(
+            registry_config, SnowflakeRegistryConfig
+        ), "SnowflakeRegistry needs a valid registry_config; a plain path does not work"
+
+        self.registry_config = registry_config
+        self.registry_path = (
+            f'"{self.registry_config.database}"."{self.registry_config.schema_}"'
+        )
+
+        # Create the registry tables by running the bundled creation script.
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql"
+            with open(sql_function_file, "r") as file:
+                sql_file = file.read()
+                sql_cmds = sql_file.split(";")
+                for command in sql_cmds:
+                    query = command.replace("REGISTRY_PATH", f"{self.registry_path}")
+                    execute_snowflake_statement(conn, query)
+
+        self.purge_feast_metadata = registry_config.purge_feast_metadata
+        self._sync_feast_metadata_to_projects_table()
+        if not self.purge_feast_metadata:
+            self._maybe_init_project_metadata(project)
+
+        self.cached_registry_proto = self.proto()
+        self.cached_registry_proto_created = _utc_now()
+        self._refresh_lock = Lock()
+        self.cached_registry_proto_ttl = timedelta(
+            seconds=(
+                registry_config.cache_ttl_seconds
+                if registry_config.cache_ttl_seconds is not None
+                else 0
+            )
+        )
+        self.project = project
+
+    def _sync_feast_metadata_to_projects_table(self):
+        # Use real sets here; list literals would break the .add() calls below.
+        feast_metadata_projects: set = set()
+        projects_set: set = set()
+
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            query = (
+                f'SELECT DISTINCT project_id FROM {self.registry_path}."FEAST_METADATA"'
+            )
+            df = execute_snowflake_statement(conn, query).fetch_pandas_all()
+
+            for row in df.iterrows():
+                feast_metadata_projects.add(row[1]["PROJECT_ID"])
+
+        if len(feast_metadata_projects) > 0:
+            with GetSnowflakeConnection(self.registry_config) as conn:
+                query = f'SELECT project_id FROM {self.registry_path}."PROJECTS"'
+                df = execute_snowflake_statement(conn, query).fetch_pandas_all()
+
+                for row in df.iterrows():
+                    projects_set.add(row[1]["PROJECT_ID"])
+
+            # Find projects recorded in FEAST_METADATA but missing from PROJECTS
+            projects_to_sync = feast_metadata_projects - projects_set
+            for project_name in projects_to_sync:
+                self.apply_project(Project(name=project_name), commit=True)
+
+        if self.purge_feast_metadata:
+            with GetSnowflakeConnection(self.registry_config) as conn:
+                query = f"""
+                    DELETE FROM {self.registry_path}."FEAST_METADATA"
+                """
+                execute_snowflake_statement(conn, query)
+
+    def refresh(self, project: Optional[str] = None):
+        self.cached_registry_proto = self.proto()
+        self.cached_registry_proto_created = _utc_now()
+
+    def _refresh_cached_registry_if_necessary(self):
+        with self._refresh_lock:
+            expired = (
+                self.cached_registry_proto is None
+                or self.cached_registry_proto_created is None
+            ) or (
+                self.cached_registry_proto_ttl.total_seconds()
+                > 0  # 0 ttl means infinity
+                and (
+                    _utc_now()
+                    > (
+                        self.cached_registry_proto_created
+                        + self.cached_registry_proto_ttl
+                    )
+                )
+            )
+
+            if expired:
+                logger.info("Registry cache expired, so refreshing")
+                self.refresh()
+
+    def teardown(self):
+        # Drop the registry tables by replaying the bundled deletion script.
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql"
+            with open(sql_function_file, "r") as file:
+                sql_file = file.read()
+                sql_cmds = sql_file.split(";")
+                for command in sql_cmds:
+                    query = command.replace("REGISTRY_PATH", f"{self.registry_path}")
+                    execute_snowflake_statement(conn, query)
+
+    # apply operations
+    def apply_data_source(
+        self, data_source: DataSource, project: str, commit: bool = True
+    ):
+        return self._apply_object(
+            "DATA_SOURCES",
+            project,
+            "DATA_SOURCE_NAME",
+            data_source,
+            "DATA_SOURCE_PROTO",
+        )
+
+    def apply_entity(self, entity: Entity, project: str, commit: bool = True):
+        return self._apply_object(
+            "ENTITIES", project, "ENTITY_NAME", entity, "ENTITY_PROTO"
+        )
+
+    def apply_feature_service(
+        self, feature_service: FeatureService, project: str, commit: bool = True
+    ):
+        return self._apply_object(
+            "FEATURE_SERVICES",
+            project,
+            "FEATURE_SERVICE_NAME",
+            feature_service,
+            "FEATURE_SERVICE_PROTO",
+        )
+
+    def apply_feature_view(
+        self, feature_view: BaseFeatureView, project: str, commit: bool = True
+    ):
+        fv_table_str = self._infer_fv_table(feature_view)
+        fv_column_name = fv_table_str[:-1]
+        return self._apply_object(
+            fv_table_str,
+            project,
+            f"{fv_column_name}_NAME",
+            feature_view,
+            f"{fv_column_name}_PROTO",
+        )
+
+    def apply_saved_dataset(
+        self,
+        saved_dataset: SavedDataset,
+        project: str,
+        commit: bool = True,
+    ):
+        return self._apply_object(
+            "SAVED_DATASETS",
+            project,
+            "SAVED_DATASET_NAME",
+            saved_dataset,
+            "SAVED_DATASET_PROTO",
+        )
+
+    def apply_validation_reference(
+        self,
+        validation_reference: ValidationReference,
+        project: str,
+        commit: bool = True,
+    ):
+        return self._apply_object(
+            "VALIDATION_REFERENCES",
+            project,
+            "VALIDATION_REFERENCE_NAME",
+            validation_reference,
+            "VALIDATION_REFERENCE_PROTO",
+        )
+
+    def update_infra(self, infra: Infra, project: str, commit: bool = True):
+        self._apply_object(
+            "MANAGED_INFRA",
+            project,
+            "INFRA_NAME",
+            infra,
+            "INFRA_PROTO",
+            name="infra_obj",
+        )
+
+    def _initialize_project_if_not_exists(self, project_name: str):
+        try:
+            self.get_project(project_name, allow_cache=True)
+            return
+        except ProjectObjectNotFoundException:
+            try:
+                self.get_project(project_name, allow_cache=False)
+                return
+            except ProjectObjectNotFoundException:
+                self.apply_project(Project(name=project_name), commit=True)
+
+    def _apply_object(
+        self,
+        table: str,
+        project: str,
+        id_field_name: str,
+        obj: Any,
+        proto_field_name: str,
+        name: Optional[str] = None,
+    ):
+        if not self.purge_feast_metadata:
+            self._maybe_init_project_metadata(project)
+        # Initializing the project is necessary because a FeatureStore object can
+        # apply objects individually, without going through the "feast apply" CLI.
+        if not isinstance(obj, Project):
+            self._initialize_project_if_not_exists(project_name=project)
+
+        name = name or (obj.name if hasattr(obj, "name") else None)
+        assert name, f"name needs to be provided for {obj}"
+
+        update_datetime = _utc_now()
+        if hasattr(obj, "last_updated_timestamp"):
+            obj.last_updated_timestamp = update_datetime
+
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            query = f"""
+                SELECT
+                    project_id
+                FROM
+                    {self.registry_path}."{table}"
+                WHERE
+                    project_id = '{project}'
+                    AND {id_field_name.lower()} = '{name}'
+                LIMIT 1
+                """
+            df = execute_snowflake_statement(conn, query).fetch_pandas_all()
+
+            if not df.empty:
+                # str(hexlify(...))[1:] turns b'<hex>' into '<hex>', a quoted hex literal for TO_BINARY.
+                proto = hexlify(obj.to_proto().SerializeToString()).__str__()[1:]
+                query = f"""
+                    UPDATE 
{self.registry_path}."{table}"
+                    SET
+                        {proto_field_name} = TO_BINARY({proto}),
+                        last_updated_timestamp = CURRENT_TIMESTAMP()
+                    WHERE
+                        {id_field_name.lower()} = '{name}'
+                    """
+                execute_snowflake_statement(conn, query)
+
+            else:
+                obj_proto = obj.to_proto()
+
+                if hasattr(obj_proto, "meta") and hasattr(
+                    obj_proto.meta, "created_timestamp"
+                ):
+                    obj_proto.meta.created_timestamp.FromDatetime(update_datetime)
+
+                proto = hexlify(obj_proto.SerializeToString()).__str__()[1:]
+                # The trailing '' placeholders fill the extra nullable columns
+                # (user metadata, and for FEATURE_VIEWS also materialized intervals).
+                if table == "FEATURE_VIEWS":
+                    query = f"""
+                        INSERT INTO {self.registry_path}."{table}"
+                        VALUES
+                            ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '', '')
+                        """
+                elif "_FEATURE_VIEWS" in table:
+                    query = f"""
+                        INSERT INTO {self.registry_path}."{table}"
+                        VALUES
+                            ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '')
+                        """
+                else:
+                    query = f"""
+                        INSERT INTO {self.registry_path}."{table}"
+                        VALUES
+                            ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}))
+                        """
+                execute_snowflake_statement(conn, query)
+
+        if not isinstance(obj, Project):
+            self.apply_project(
+                self.get_project(name=project, allow_cache=False), commit=True
+            )
+
+        if not self.purge_feast_metadata:
+            self._set_last_updated_metadata(update_datetime, project)
+
+    def apply_permission(
+        self, permission: Permission, project: str, commit: bool = True
+    ):
+        return self._apply_object(
+            "PERMISSIONS",
+            project,
+            "PERMISSION_NAME",
+            permission,
+            "PERMISSION_PROTO",
+        )
+
+    # delete operations
+    def delete_data_source(self, name: str, project: str, commit: bool = True):
+        return self._delete_object(
+            "DATA_SOURCES",
+            name,
+            project,
+            "DATA_SOURCE_NAME",
+            DataSourceObjectNotFoundException,
+        )
+
+    def delete_entity(self, name: str, project: str, commit: bool = True):
+        return self._delete_object(
+            "ENTITIES", name, project, "ENTITY_NAME", EntityNotFoundException
+        )
+
+    def delete_feature_service(self, name: str, project: str, commit: bool = True):
+        return self._delete_object(
+            "FEATURE_SERVICES",
+            name,
+            project,
+            "FEATURE_SERVICE_NAME",
+            FeatureServiceNotFoundException,
+        )
+
+    # A feature view name may exist in any of the feature view tables, so attempt deletion from each.
+    def delete_feature_view(self, name: str, project: str, commit: bool = True):
+        deleted_count = 0
+        for table in {
+            "FEATURE_VIEWS",
+            "ON_DEMAND_FEATURE_VIEWS",
+            "STREAM_FEATURE_VIEWS",
+        }:
+            deleted_count += self._delete_object(
+                table, name, project, "FEATURE_VIEW_NAME", None
+            )
+        if deleted_count == 0:
+            raise FeatureViewNotFoundException(name, project)
+
+    def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False):
+        self._delete_object(
+            "SAVED_DATASETS",
+            name,
+            project,
+            "SAVED_DATASET_NAME",
+            SavedDatasetNotFound,
+        )
+
+    def delete_validation_reference(self, name: str, project: str, commit: bool = True):
+        self._delete_object(
+            "VALIDATION_REFERENCES",
+            name,
+            project,
+            "VALIDATION_REFERENCE_NAME",
+            ValidationReferenceNotFound,
+        )
+
+    def _delete_object(
+        self,
+        table: str,
+        name: str,
+        project: str,
+        id_field_name: str,
+        not_found_exception: Optional[Callable],
+    ):
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            query = f"""
+                DELETE FROM {self.registry_path}."{table}"
+                WHERE
+                    project_id = '{project}'
+                    AND {id_field_name.lower()} = '{name}'
+                """
+            cursor = execute_snowflake_statement(conn, query)
+
+            if cursor.rowcount < 1 and not_found_exception:  # type: ignore
+                raise not_found_exception(name, project)
+            self._set_last_updated_metadata(_utc_now(), project)
+
+            return cursor.rowcount
+
+    def 
delete_permission(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "PERMISSIONS", + name, + project, + "PERMISSION_NAME", + PermissionNotFoundException, + ) + + # get operations + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_data_source( + self.cached_registry_proto, name, project + ) + return self._get_object( + "DATA_SOURCES", + name, + project, + DataSourceProto, + DataSource, + "DATA_SOURCE_NAME", + "DATA_SOURCE_PROTO", + DataSourceObjectNotFoundException, + ) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_entity( + self.cached_registry_proto, name, project + ) + return self._get_object( + "ENTITIES", + name, + project, + EntityProto, + Entity, + "ENTITY_NAME", + "ENTITY_PROTO", + EntityNotFoundException, + ) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_service( + self.cached_registry_proto, name, project + ) + return self._get_object( + "FEATURE_SERVICES", + name, + project, + FeatureServiceProto, + FeatureService, + "FEATURE_SERVICE_NAME", + "FEATURE_SERVICE_PROTO", + FeatureServiceNotFoundException, + ) + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "FEATURE_VIEWS", + name, + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_NAME", + "FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_any_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> BaseFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_any_feature_view( + self.cached_registry_proto, name, project + ) + fv = self._get_object( + "FEATURE_VIEWS", + name, + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_NAME", + "FEATURE_VIEW_PROTO", + None, + ) + + if not fv: + fv = self._get_object( + "STREAM_FEATURE_VIEWS", + name, + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_NAME", + "STREAM_FEATURE_VIEW_PROTO", + None, + ) + if not fv: + fv = self._get_object( + "ON_DEMAND_FEATURE_VIEWS", + name, + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_NAME", + "ON_DEMAND_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + return fv + + def list_all_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[BaseFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_all_feature_views( + self.cached_registry_proto, project, tags + ) + + return ( + cast( + list[BaseFeatureView], + self.list_feature_views(project, allow_cache, tags), + ) + + cast( + list[BaseFeatureView], + self.list_stream_feature_views(project, allow_cache, tags), + ) + + cast( + list[BaseFeatureView], + self.list_on_demand_feature_views(project, allow_cache, tags), + ) + ) + + def get_infra(self, project: str, 
allow_cache: bool = False) -> Infra:
+        infra_object = self._get_object(
+            "MANAGED_INFRA",
+            "infra_obj",
+            project,
+            InfraProto,
+            Infra,
+            "INFRA_NAME",
+            "INFRA_PROTO",
+            None,
+        )
+        # _get_object already deserializes into an Infra instance (or None when
+        # nothing is stored), so fall back to an empty Infra rather than re-parsing.
+        return infra_object or Infra()
+
+    def get_on_demand_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> OnDemandFeatureView:
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.get_on_demand_feature_view(
+                self.cached_registry_proto, name, project
+            )
+        return self._get_object(
+            "ON_DEMAND_FEATURE_VIEWS",
+            name,
+            project,
+            OnDemandFeatureViewProto,
+            OnDemandFeatureView,
+            "ON_DEMAND_FEATURE_VIEW_NAME",
+            "ON_DEMAND_FEATURE_VIEW_PROTO",
+            FeatureViewNotFoundException,
+        )
+
+    def get_saved_dataset(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> SavedDataset:
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.get_saved_dataset(
+                self.cached_registry_proto, name, project
+            )
+        return self._get_object(
+            "SAVED_DATASETS",
+            name,
+            project,
+            SavedDatasetProto,
+            SavedDataset,
+            "SAVED_DATASET_NAME",
+            "SAVED_DATASET_PROTO",
+            SavedDatasetNotFound,
+        )
+
+    def get_stream_feature_view(
+        self, name: str, project: str, allow_cache: bool = False
+    ):
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.get_stream_feature_view(
+                self.cached_registry_proto, name, project
+            )
+        return self._get_object(
+            "STREAM_FEATURE_VIEWS",
+            name,
+            project,
+            StreamFeatureViewProto,
+            StreamFeatureView,
+            "STREAM_FEATURE_VIEW_NAME",
+            "STREAM_FEATURE_VIEW_PROTO",
+            FeatureViewNotFoundException,
+        )
+
+    def get_validation_reference(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> ValidationReference:
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.get_validation_reference(
+                self.cached_registry_proto, name, project
+            )
+        return self._get_object(
+            "VALIDATION_REFERENCES",
+            name,
+            project,
+            ValidationReferenceProto,
+            ValidationReference,
+            "VALIDATION_REFERENCE_NAME",
+            "VALIDATION_REFERENCE_PROTO",
+            ValidationReferenceNotFound,
+        )
+
+    def _get_object(
+        self,
+        table: str,
+        name: str,
+        project: str,
+        proto_class: Any,
+        python_class: Any,
+        id_field_name: str,
+        proto_field_name: str,
+        not_found_exception: Optional[Callable],
+    ):
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            query = f"""
+                SELECT
+                    {proto_field_name}
+                FROM
+                    {self.registry_path}."{table}"
+                WHERE
+                    project_id = '{project}'
+                    AND {id_field_name.lower()} = '{name}'
+                LIMIT 1
+                """
+            df = execute_snowflake_statement(conn, query).fetch_pandas_all()
+
+        if not df.empty:
+            _proto = proto_class.FromString(df.squeeze())
+            return python_class.from_proto(_proto)
+        elif not_found_exception:
+            raise not_found_exception(name, project)
+        else:
+            return None
+
+    def get_permission(
+        self, name: str, project: str, allow_cache: bool = False
+    ) -> Permission:
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.get_permission(
+                self.cached_registry_proto, name, project
+            )
+        return self._get_object(
+            "PERMISSIONS",
+            name,
+            project,
+            PermissionProto,
+            Permission,
+            "PERMISSION_NAME",
+            "PERMISSION_PROTO",
+            PermissionNotFoundException,
+        )
+
+    # list operations
+    def list_data_sources(
+        self,
+        project: str,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[DataSource]:
+        
if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_data_sources( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "DATA_SOURCES", + project, + DataSourceProto, + DataSource, + "DATA_SOURCE_PROTO", + tags=tags, + ) + + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_entities( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO", tags=tags + ) + + def list_feature_services( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureService]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_services( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "FEATURE_SERVICES", + project, + FeatureServiceProto, + FeatureService, + "FEATURE_SERVICE_PROTO", + tags=tags, + ) + + def list_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[FeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_views( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "FEATURE_VIEWS", + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_PROTO", + tags=tags, + ) + + def list_on_demand_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[OnDemandFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_on_demand_feature_views( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "ON_DEMAND_FEATURE_VIEWS", + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_PROTO", + tags=tags, + ) + + def list_saved_datasets( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[SavedDataset]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_saved_datasets( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "SAVED_DATASETS", + project, + SavedDatasetProto, + SavedDataset, + "SAVED_DATASET_PROTO", + tags=tags, + ) + + def list_stream_feature_views( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[StreamFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_stream_feature_views( + self.cached_registry_proto, project, tags + ) + return self._list_objects( + "STREAM_FEATURE_VIEWS", + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_PROTO", + tags=tags, + ) + + def list_validation_references( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[ValidationReference]: + return self._list_objects( + "VALIDATION_REFERENCES", + project, + ValidationReferenceProto, + ValidationReference, + "VALIDATION_REFERENCE_PROTO", + tags=tags, + ) + + def _list_objects( + self, + table: str, + project: str, + proto_class: Any, + python_class: Any, + proto_field_name: str, + tags: Optional[dict[str, str]] = 
None, + ): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + {proto_field_name} + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + if not df.empty: + objects = [] + for row in df.iterrows(): + obj = python_class.from_proto( + proto_class.FromString(row[1][proto_field_name]) + ) + if has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def list_permissions( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Permission]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_permissions( + self.cached_registry_proto, project + ) + return self._list_objects( + "PERMISSIONS", + project, + PermissionProto, + Permission, + "PERMISSION_PROTO", + tags, + ) + + def apply_materialization( + self, + feature_view: Union[FeatureView, OnDemandFeatureView], + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1] + python_class, proto_class = self._infer_fv_classes(feature_view) + + if python_class in {OnDemandFeatureView}: + raise ValueError( + f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" + ) + fv: Union[FeatureView, StreamFeatureView] = self._get_object( + fv_table_str, + feature_view.name, + project, + proto_class, + python_class, + f"{fv_column_name}_NAME", + f"{fv_column_name}_PROTO", + FeatureViewNotFoundException, + ) + fv.materialization_intervals.append((start_date, end_date)) + self._apply_object( + fv_table_str, + project, + f"{fv_column_name}_NAME", + fv, + f"{fv_column_name}_PROTO", + ) + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_project_metadata( + self.cached_registry_proto, project + ) + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_key, + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + project_metadata = ProjectMetadata(project_name=project) + for row in df.iterrows(): + if row[1]["METADATA_KEY"] == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata.project_uuid = row[1]["METADATA_VALUE"] + break + # TODO(adchia): Add other project metadata in a structured way + return [project_metadata] + return [] + + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1].lower() + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."{fv_table_str}" + WHERE + project_id = '{project}' + AND {fv_column_name}_name = '{feature_view.name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + if metadata_bytes: + metadata_hex = hexlify(metadata_bytes).__str__()[1:] + else: + metadata_hex = "''" + query = f""" + UPDATE {self.registry_path}."{fv_table_str}" + SET + user_metadata = TO_BINARY({metadata_hex}), + last_updated_timestamp = 
CURRENT_TIMESTAMP() + WHERE + project_id = '{project}' + AND {fv_column_name}_name = '{feature_view.name}' + """ + execute_snowflake_statement(conn, query) + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1].lower() + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + user_metadata + FROM + {self.registry_path}."{fv_table_str}" + WHERE + {fv_column_name}_name = '{feature_view.name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + return df.squeeze() + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def proto(self) -> RegistryProto: + r = RegistryProto() + last_updated_timestamps = [] + + def process_project(project: Project): + nonlocal r, last_updated_timestamps + project_name = project.name + last_updated_timestamp = project.last_updated_timestamp + + try: + cached_project = self.get_project(project_name, True) + except ProjectObjectNotFoundException: + cached_project = None + + allow_cache = False + + if cached_project is not None: + allow_cache = ( + last_updated_timestamp <= cached_project.last_updated_timestamp + ) + + r.projects.extend([project.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) + + for lister, registry_proto_field in [ + (self.list_entities, r.entities), + (self.list_feature_views, r.feature_views), + (self.list_data_sources, r.data_sources), + (self.list_on_demand_feature_views, r.on_demand_feature_views), + (self.list_stream_feature_views, r.stream_feature_views), + (self.list_feature_services, r.feature_services), + (self.list_saved_datasets, r.saved_datasets), + (self.list_validation_references, r.validation_references), + (self.list_permissions, r.permissions), + ]: + objs: List[Any] = lister(project_name, allow_cache) # type: ignore + if objs: + obj_protos = [obj.to_proto() for obj in objs] + for obj_proto in obj_protos: + if "spec" in obj_proto.DESCRIPTOR.fields_by_name: + obj_proto.spec.project = project_name + else: + obj_proto.project = project_name + registry_proto_field.extend(obj_protos) + + # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, + # the registry proto only has a single infra field, which we're currently setting as the "last" project. 
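+            # CopyFrom replaces r.infra on every pass, so it ends up holding the last processed project's infra.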
+ r.infra.CopyFrom(self.get_infra(project_name).to_proto()) + + projects_list = self.list_projects(allow_cache=False) + for project in projects_list: + process_project(project) + + if last_updated_timestamps: + r.last_updated.FromDatetime(max(last_updated_timestamps)) + + return r + + def _get_last_updated_metadata(self, project: str): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if df.empty: + return None + + return datetime.fromtimestamp(int(df.squeeze()), tz=timezone.utc) + + def _infer_fv_classes(self, feature_view): + if isinstance(feature_view, StreamFeatureView): + python_class, proto_class = StreamFeatureView, StreamFeatureViewProto + elif isinstance(feature_view, FeatureView): + python_class, proto_class = FeatureView, FeatureViewProto + elif isinstance(feature_view, OnDemandFeatureView): + python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return python_class, proto_class + + def _infer_fv_table(self, feature_view) -> str: + if isinstance(feature_view, StreamFeatureView): + table = "STREAM_FEATURE_VIEWS" + elif isinstance(feature_view, FeatureView): + table = "FEATURE_VIEWS" + elif isinstance(feature_view, OnDemandFeatureView): + table = "ON_DEMAND_FEATURE_VIEWS" + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return table + + def _maybe_init_project_metadata(self, project): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.PROJECT_UUID.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if df.empty: + new_project_uuid = f"{uuid.uuid4()}" + query = f""" + INSERT INTO {self.registry_path}."FEAST_METADATA" + VALUES + ('{project}', '{FeastMetadataKeys.PROJECT_UUID.value}', '{new_project_uuid}', CURRENT_TIMESTAMP()) + """ + execute_snowflake_statement(conn, query) + + def _set_last_updated_metadata(self, last_updated: datetime, project: str): + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + update_time = int(last_updated.timestamp()) + if not df.empty: + query = f""" + UPDATE {self.registry_path}."FEAST_METADATA" + SET + project_id = '{project}', + metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', + metadata_value = '{update_time}', + last_updated_timestamp = CURRENT_TIMESTAMP() + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + """ + execute_snowflake_statement(conn, query) + + else: + query = f""" + INSERT INTO {self.registry_path}."FEAST_METADATA" + VALUES + ('{project}', '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', '{update_time}', CURRENT_TIMESTAMP()) + """ + execute_snowflake_statement(conn, query) + + def commit(self): + pass + + def apply_project( + self, + 
project: Project,
+        commit: bool = True,
+    ):
+        return self._apply_object(
+            "PROJECTS", project.name, "project_name", project, "project_proto"
+        )
+
+    def delete_project(
+        self,
+        name: str,
+        commit: bool = True,
+    ):
+        project = self.get_project(name, allow_cache=False)
+        if project:
+            with GetSnowflakeConnection(self.registry_config) as conn:
+                for table in {
+                    "MANAGED_INFRA",
+                    "SAVED_DATASETS",
+                    "VALIDATION_REFERENCES",
+                    "FEATURE_SERVICES",
+                    "FEATURE_VIEWS",
+                    "ON_DEMAND_FEATURE_VIEWS",
+                    "STREAM_FEATURE_VIEWS",
+                    "DATA_SOURCES",
+                    "ENTITIES",
+                    "PERMISSIONS",
+                    "FEAST_METADATA",
+                    "PROJECTS",
+                }:
+                    # Filter by the project *name*; interpolating the Project object would never match.
+                    query = f"""
+                        DELETE FROM {self.registry_path}."{table}"
+                        WHERE
+                            project_id = '{name}'
+                        """
+                    execute_snowflake_statement(conn, query)
+            return
+
+        raise ProjectNotFoundException(name)
+
+    def _get_project(
+        self,
+        name: str,
+    ) -> Project:
+        return self._get_object(
+            table="PROJECTS",
+            name=name,
+            project=name,
+            proto_class=ProjectProto,
+            python_class=Project,
+            id_field_name="project_name",
+            proto_field_name="project_proto",
+            not_found_exception=ProjectObjectNotFoundException,
+        )
+
+    def get_project(
+        self,
+        name: str,
+        allow_cache: bool = False,
+    ) -> Project:
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.get_project(self.cached_registry_proto, name)
+        return self._get_project(name)
+
+    def _list_projects(
+        self,
+        tags: Optional[dict[str, str]],
+    ) -> List[Project]:
+        with GetSnowflakeConnection(self.registry_config) as conn:
+            query = f"""
+            SELECT project_proto FROM {self.registry_path}."PROJECTS"
+            """
+            df = execute_snowflake_statement(conn, query).fetch_pandas_all()
+            if not df.empty:
+                objects = []
+                for row in df.iterrows():
+                    obj = Project.from_proto(
+                        ProjectProto.FromString(row[1]["project_proto"])
+                    )
+                    if has_all_tags(obj.tags, tags):
+                        objects.append(obj)
+                return objects
+        return []
+
+    def list_projects(
+        self,
+        allow_cache: bool = False,
+        tags: Optional[dict[str, str]] = None,
+    ) -> List[Project]:
+        if allow_cache:
+            self._refresh_cached_registry_if_necessary()
+            return proto_registry_utils.list_projects(self.cached_registry_proto, tags)
+        return self._list_projects(tags)
diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py
index 36b8174d2b6..68dcd893f9d 100644
--- a/sdk/python/feast/infra/registry/sql.py
+++ b/sdk/python/feast/infra/registry/sql.py
@@ -1,1258 +1,1258 @@
-import logging
-import uuid
-from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime, timezone
-from enum import Enum
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union, cast
-
-from pydantic import StrictInt, StrictStr
-from sqlalchemy import (  # type: ignore
-    BigInteger,
-    Column,
-    Index,
-    LargeBinary,
-    MetaData,
-    String,
-    Table,
-    create_engine,
-    delete,
-    insert,
-    select,
-    update,
-)
-from sqlalchemy.engine import Engine
-
-from feast import utils
-from feast.base_feature_view import BaseFeatureView
-from feast.data_source import DataSource
-from feast.entity import Entity
-from feast.errors import (
-    DataSourceObjectNotFoundException,
-    EntityNotFoundException,
-    FeatureServiceNotFoundException,
-    FeatureViewNotFoundException,
-    PermissionNotFoundException,
-    ProjectNotFoundException,
-    ProjectObjectNotFoundException,
-    SavedDatasetNotFound,
-    ValidationReferenceNotFound,
-)
-from feast.feature_service import FeatureService
-from feast.feature_view import FeatureView
-from feast.infra.infra_object 
import Infra -from feast.infra.registry.caching_registry import CachingRegistry -from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.permission import Permission -from feast.project import Project -from feast.project_metadata import ProjectMetadata -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto -from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto -from feast.protos.feast.core.FeatureService_pb2 import ( - FeatureService as FeatureServiceProto, -) -from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto -from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto -from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( - OnDemandFeatureView as OnDemandFeatureViewProto, -) -from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto -from feast.protos.feast.core.Project_pb2 import Project as ProjectProto -from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto -from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto -from feast.protos.feast.core.StreamFeatureView_pb2 import ( - StreamFeatureView as StreamFeatureViewProto, -) -from feast.protos.feast.core.ValidationProfile_pb2 import ( - ValidationReference as ValidationReferenceProto, -) -from feast.repo_config import RegistryConfig -from feast.saved_dataset import SavedDataset, ValidationReference -from feast.stream_feature_view import StreamFeatureView -from feast.utils import _utc_now - -metadata = MetaData() - - -projects = Table( - "projects", - metadata, - Column("project_id", String(255), primary_key=True), - Column("project_name", String(255), nullable=False), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("project_proto", LargeBinary, nullable=False), -) - -Index("idx_projects_project_id", projects.c.project_id) - -entities = Table( - "entities", - metadata, - Column("entity_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("entity_proto", LargeBinary, nullable=False), -) - -Index("idx_entities_project_id", entities.c.project_id) - -data_sources = Table( - "data_sources", - metadata, - Column("data_source_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("data_source_proto", LargeBinary, nullable=False), -) - -Index("idx_data_sources_project_id", data_sources.c.project_id) - -feature_views = Table( - "feature_views", - metadata, - Column("feature_view_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("materialized_intervals", LargeBinary, nullable=True), - Column("feature_view_proto", LargeBinary, nullable=False), - Column("user_metadata", LargeBinary, nullable=True), -) - -Index("idx_feature_views_project_id", feature_views.c.project_id) - -stream_feature_views = Table( - "stream_feature_views", - metadata, - Column("feature_view_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("feature_view_proto", LargeBinary, nullable=False), - Column("user_metadata", LargeBinary, nullable=True), -) - -Index("idx_stream_feature_views_project_id", 
stream_feature_views.c.project_id) - -on_demand_feature_views = Table( - "on_demand_feature_views", - metadata, - Column("feature_view_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("feature_view_proto", LargeBinary, nullable=False), - Column("user_metadata", LargeBinary, nullable=True), -) - -Index("idx_on_demand_feature_views_project_id", on_demand_feature_views.c.project_id) - -feature_services = Table( - "feature_services", - metadata, - Column("feature_service_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("feature_service_proto", LargeBinary, nullable=False), -) - -Index("idx_feature_services_project_id", feature_services.c.project_id) - -saved_datasets = Table( - "saved_datasets", - metadata, - Column("saved_dataset_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("saved_dataset_proto", LargeBinary, nullable=False), -) - -Index("idx_saved_datasets_project_id", saved_datasets.c.project_id) - -validation_references = Table( - "validation_references", - metadata, - Column("validation_reference_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("validation_reference_proto", LargeBinary, nullable=False), -) -Index("idx_validation_references_project_id", validation_references.c.project_id) - -managed_infra = Table( - "managed_infra", - metadata, - Column("infra_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("infra_proto", LargeBinary, nullable=False), -) - -Index("idx_managed_infra_project_id", managed_infra.c.project_id) - -permissions = Table( - "permissions", - metadata, - Column("permission_name", String(255), primary_key=True), - Column("project_id", String(255), primary_key=True), - Column("last_updated_timestamp", BigInteger, nullable=False), - Column("permission_proto", LargeBinary, nullable=False), -) - -Index("idx_permissions_project_id", permissions.c.project_id) - - -class FeastMetadataKeys(Enum): - LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" - PROJECT_UUID = "project_uuid" - - -feast_metadata = Table( - "feast_metadata", - metadata, - Column("project_id", String(255), primary_key=True), - Column("metadata_key", String(50), primary_key=True), - Column("metadata_value", String(50), nullable=False), - Column("last_updated_timestamp", BigInteger, nullable=False), -) - -Index("idx_feast_metadata_project_id", feast_metadata.c.project_id) - -logger = logging.getLogger(__name__) - - -class SqlRegistryConfig(RegistryConfig): - registry_type: StrictStr = "sql" - """ str: Provider name or a class name that implements Registry.""" - - path: StrictStr = "" - """ str: Path to metadata store. - If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """ - - read_path: Optional[StrictStr] = None - """ str: Read Path to metadata store if different from path. - If registry_type is 'sql', then this is a Read Endpoint for database URL. If not set, path will be used for read and write. 
""" - - sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False} - """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ - - cache_mode: StrictStr = "sync" - """ str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)""" - - thread_pool_executor_worker_count: StrictInt = 0 - """ int: Number of worker threads to use for asynchronous caching in SQL Registry. If set to 0, it doesn't use ThreadPoolExecutor. """ - - -class SqlRegistry(CachingRegistry): - def __init__( - self, - registry_config, - project: str, - repo_path: Optional[Path], - ): - assert registry_config is not None and isinstance( - registry_config, SqlRegistryConfig - ), "SqlRegistry needs a valid registry_config" - - self.registry_config = registry_config - - self.write_engine: Engine = create_engine( - registry_config.path, **registry_config.sqlalchemy_config_kwargs - ) - if registry_config.read_path: - self.read_engine: Engine = create_engine( - registry_config.read_path, - **registry_config.sqlalchemy_config_kwargs, - ) - else: - self.read_engine = self.write_engine - metadata.create_all(self.write_engine) - self.thread_pool_executor_worker_count = ( - registry_config.thread_pool_executor_worker_count - ) - self.purge_feast_metadata = registry_config.purge_feast_metadata - # Sync feast_metadata to projects table - # when purge_feast_metadata is set to True, Delete data from - # feast_metadata table and list_project_metadata will not return any data - self._sync_feast_metadata_to_projects_table() - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - super().__init__( - project=project, - cache_ttl_seconds=registry_config.cache_ttl_seconds, - cache_mode=registry_config.cache_mode, - ) - - def _sync_feast_metadata_to_projects_table(self): - feast_metadata_projects: dict = {} - projects_set: set = [] - with self.read_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value - ) - rows = conn.execute(stmt).all() - for row in rows: - feast_metadata_projects[row._mapping["project_id"]] = int( - row._mapping["last_updated_timestamp"] - ) - - if len(feast_metadata_projects) > 0: - with self.read_engine.begin() as conn: - stmt = select(projects) - rows = conn.execute(stmt).all() - for row in rows: - projects_set.append(row._mapping["project_id"]) - - # Find object in feast_metadata_projects but not in projects - projects_to_sync = set(feast_metadata_projects.keys()) - set(projects_set) - for project_name in projects_to_sync: - self.apply_project( - Project( - name=project_name, - created_timestamp=datetime.fromtimestamp( - feast_metadata_projects[project_name], tz=timezone.utc - ), - ), - commit=True, - ) - - if self.purge_feast_metadata: - with self.write_engine.begin() as conn: - for project_name in feast_metadata_projects: - stmt = delete(feast_metadata).where( - feast_metadata.c.project_id == project_name - ) - conn.execute(stmt) - - def teardown(self): - for t in { - entities, - data_sources, - feature_views, - feature_services, - on_demand_feature_views, - saved_datasets, - validation_references, - permissions, - }: - with self.write_engine.begin() as conn: - stmt = delete(t) - conn.execute(stmt) - - def _get_stream_feature_view(self, name: str, project: str): - return self._get_object( - table=stream_feature_views, - name=name, - project=project, - proto_class=StreamFeatureViewProto, - python_class=StreamFeatureView, - 
id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - - def _list_stream_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[StreamFeatureView]: - return self._list_objects( - stream_feature_views, - project, - StreamFeatureViewProto, - StreamFeatureView, - "feature_view_proto", - tags=tags, - ) - - def apply_entity(self, entity: Entity, project: str, commit: bool = True): - return self._apply_object( - table=entities, - project=project, - id_field_name="entity_name", - obj=entity, - proto_field_name="entity_proto", - ) - - def _get_entity(self, name: str, project: str) -> Entity: - return self._get_object( - table=entities, - name=name, - project=project, - proto_class=EntityProto, - python_class=Entity, - id_field_name="entity_name", - proto_field_name="entity_proto", - not_found_exception=EntityNotFoundException, - ) - - def _get_any_feature_view(self, name: str, project: str) -> BaseFeatureView: - fv = self._get_object( - table=feature_views, - name=name, - project=project, - proto_class=FeatureViewProto, - python_class=FeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=None, - ) - - if not fv: - fv = self._get_object( - table=on_demand_feature_views, - name=name, - project=project, - proto_class=OnDemandFeatureViewProto, - python_class=OnDemandFeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=None, - ) - - if not fv: - fv = self._get_object( - table=stream_feature_views, - name=name, - project=project, - proto_class=StreamFeatureViewProto, - python_class=StreamFeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - return fv - - def _list_all_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[BaseFeatureView]: - return ( - cast( - list[BaseFeatureView], - self._list_feature_views(project=project, tags=tags), - ) - + cast( - list[BaseFeatureView], - self._list_stream_feature_views(project=project, tags=tags), - ) - + cast( - list[BaseFeatureView], - self._list_on_demand_feature_views(project=project, tags=tags), - ) - ) - - def _get_feature_view(self, name: str, project: str) -> FeatureView: - return self._get_object( - table=feature_views, - name=name, - project=project, - proto_class=FeatureViewProto, - python_class=FeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - - def _get_on_demand_feature_view( - self, name: str, project: str - ) -> OnDemandFeatureView: - return self._get_object( - table=on_demand_feature_views, - name=name, - project=project, - proto_class=OnDemandFeatureViewProto, - python_class=OnDemandFeatureView, - id_field_name="feature_view_name", - proto_field_name="feature_view_proto", - not_found_exception=FeatureViewNotFoundException, - ) - - def _get_feature_service(self, name: str, project: str) -> FeatureService: - return self._get_object( - table=feature_services, - name=name, - project=project, - proto_class=FeatureServiceProto, - python_class=FeatureService, - id_field_name="feature_service_name", - proto_field_name="feature_service_proto", - not_found_exception=FeatureServiceNotFoundException, - ) - - def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: - return self._get_object( - 
table=saved_datasets, - name=name, - project=project, - proto_class=SavedDatasetProto, - python_class=SavedDataset, - id_field_name="saved_dataset_name", - proto_field_name="saved_dataset_proto", - not_found_exception=SavedDatasetNotFound, - ) - - def _get_validation_reference(self, name: str, project: str) -> ValidationReference: - return self._get_object( - table=validation_references, - name=name, - project=project, - proto_class=ValidationReferenceProto, - python_class=ValidationReference, - id_field_name="validation_reference_name", - proto_field_name="validation_reference_proto", - not_found_exception=ValidationReferenceNotFound, - ) - - def _list_validation_references( - self, project: str, tags: Optional[dict[str, str]] = None - ) -> List[ValidationReference]: - return self._list_objects( - table=validation_references, - project=project, - proto_class=ValidationReferenceProto, - python_class=ValidationReference, - proto_field_name="validation_reference_proto", - tags=tags, - ) - - def _list_entities( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[Entity]: - return self._list_objects( - entities, project, EntityProto, Entity, "entity_proto", tags=tags - ) - - def delete_entity(self, name: str, project: str, commit: bool = True): - return self._delete_object( - entities, name, project, "entity_name", EntityNotFoundException - ) - - def delete_feature_view(self, name: str, project: str, commit: bool = True): - deleted_count = 0 - for table in { - feature_views, - on_demand_feature_views, - stream_feature_views, - }: - deleted_count += self._delete_object( - table, name, project, "feature_view_name", None - ) - if deleted_count == 0: - raise FeatureViewNotFoundException(name, project) - - def delete_feature_service(self, name: str, project: str, commit: bool = True): - return self._delete_object( - feature_services, - name, - project, - "feature_service_name", - FeatureServiceNotFoundException, - ) - - def _get_data_source(self, name: str, project: str) -> DataSource: - return self._get_object( - table=data_sources, - name=name, - project=project, - proto_class=DataSourceProto, - python_class=DataSource, - id_field_name="data_source_name", - proto_field_name="data_source_proto", - not_found_exception=DataSourceObjectNotFoundException, - ) - - def _list_data_sources( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[DataSource]: - return self._list_objects( - data_sources, - project, - DataSourceProto, - DataSource, - "data_source_proto", - tags=tags, - ) - - def apply_data_source( - self, data_source: DataSource, project: str, commit: bool = True - ): - return self._apply_object( - data_sources, project, "data_source_name", data_source, "data_source_proto" - ) - - def apply_feature_view( - self, feature_view: BaseFeatureView, project: str, commit: bool = True - ): - fv_table = self._infer_fv_table(feature_view) - - return self._apply_object( - fv_table, project, "feature_view_name", feature_view, "feature_view_proto" - ) - - def apply_feature_service( - self, feature_service: FeatureService, project: str, commit: bool = True - ): - return self._apply_object( - feature_services, - project, - "feature_service_name", - feature_service, - "feature_service_proto", - ) - - def delete_data_source(self, name: str, project: str, commit: bool = True): - with self.write_engine.begin() as conn: - stmt = delete(data_sources).where( - data_sources.c.data_source_name == name, - data_sources.c.project_id == project, - ) - rows = conn.execute(stmt) - if rows.rowcount < 
1: - raise DataSourceObjectNotFoundException(name, project) - - def _list_feature_services( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[FeatureService]: - return self._list_objects( - feature_services, - project, - FeatureServiceProto, - FeatureService, - "feature_service_proto", - tags=tags, - ) - - def _list_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[FeatureView]: - return self._list_objects( - feature_views, - project, - FeatureViewProto, - FeatureView, - "feature_view_proto", - tags=tags, - ) - - def _list_saved_datasets( - self, project: str, tags: Optional[dict[str, str]] = None - ) -> List[SavedDataset]: - return self._list_objects( - saved_datasets, - project, - SavedDatasetProto, - SavedDataset, - "saved_dataset_proto", - tags=tags, - ) - - def _list_on_demand_feature_views( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[OnDemandFeatureView]: - return self._list_objects( - on_demand_feature_views, - project, - OnDemandFeatureViewProto, - OnDemandFeatureView, - "feature_view_proto", - tags=tags, - ) - - def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: - with self.read_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.project_id == project, - ) - rows = conn.execute(stmt).all() - if rows: - project_metadata = ProjectMetadata(project_name=project) - for row in rows: - if ( - row._mapping["metadata_key"] - == FeastMetadataKeys.PROJECT_UUID.value - ): - project_metadata.project_uuid = row._mapping["metadata_value"] - break - # TODO(adchia): Add other project metadata in a structured way - return [project_metadata] - return [] - - def apply_saved_dataset( - self, - saved_dataset: SavedDataset, - project: str, - commit: bool = True, - ): - return self._apply_object( - saved_datasets, - project, - "saved_dataset_name", - saved_dataset, - "saved_dataset_proto", - ) - - def apply_validation_reference( - self, - validation_reference: ValidationReference, - project: str, - commit: bool = True, - ): - return self._apply_object( - validation_references, - project, - "validation_reference_name", - validation_reference, - "validation_reference_proto", - ) - - def apply_materialization( - self, - feature_view: Union[FeatureView, OnDemandFeatureView], - project: str, - start_date: datetime, - end_date: datetime, - commit: bool = True, - ): - table = self._infer_fv_table(feature_view) - python_class, proto_class = self._infer_fv_classes(feature_view) - - if python_class in {OnDemandFeatureView}: - raise ValueError( - f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" - ) - fv: Union[FeatureView, StreamFeatureView] = self._get_object( - table, - feature_view.name, - project, - proto_class, - python_class, - "feature_view_name", - "feature_view_proto", - FeatureViewNotFoundException, - ) - fv.materialization_intervals.append((start_date, end_date)) - self._apply_object( - table, project, "feature_view_name", fv, "feature_view_proto" - ) - - def delete_validation_reference(self, name: str, project: str, commit: bool = True): - self._delete_object( - validation_references, - name, - project, - "validation_reference_name", - ValidationReferenceNotFound, - ) - - def update_infra(self, infra: Infra, project: str, commit: bool = True): - self._apply_object( - table=managed_infra, - project=project, - id_field_name="infra_name", - obj=infra, - proto_field_name="infra_proto", - name="infra_obj", - ) - - def _get_infra(self, project: str) -> 
Infra: - infra_object = self._get_object( - table=managed_infra, - name="infra_obj", - project=project, - proto_class=InfraProto, - python_class=Infra, - id_field_name="infra_name", - proto_field_name="infra_proto", - not_found_exception=None, - ) - if infra_object: - return infra_object - return Infra() - - def apply_user_metadata( - self, - project: str, - feature_view: BaseFeatureView, - metadata_bytes: Optional[bytes], - ): - table = self._infer_fv_table(feature_view) - - name = feature_view.name - with self.write_engine.begin() as conn: - stmt = select(table).where( - getattr(table.c, "feature_view_name") == name, - table.c.project_id == project, - ) - row = conn.execute(stmt).first() - update_datetime = _utc_now() - update_time = int(update_datetime.timestamp()) - if row: - values = { - "user_metadata": metadata_bytes, - "last_updated_timestamp": update_time, - } - update_stmt = ( - update(table) - .where( - getattr(table.c, "feature_view_name") == name, - table.c.project_id == project, - ) - .values( - values, - ) - ) - conn.execute(update_stmt) - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def _infer_fv_table(self, feature_view): - if isinstance(feature_view, StreamFeatureView): - table = stream_feature_views - elif isinstance(feature_view, FeatureView): - table = feature_views - elif isinstance(feature_view, OnDemandFeatureView): - table = on_demand_feature_views - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return table - - def _infer_fv_classes(self, feature_view): - if isinstance(feature_view, StreamFeatureView): - python_class, proto_class = StreamFeatureView, StreamFeatureViewProto - elif isinstance(feature_view, FeatureView): - python_class, proto_class = FeatureView, FeatureViewProto - elif isinstance(feature_view, OnDemandFeatureView): - python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") - return python_class, proto_class - - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: - table = self._infer_fv_table(feature_view) - - name = feature_view.name - with self.read_engine.begin() as conn: - stmt = select(table).where(getattr(table.c, "feature_view_name") == name) - row = conn.execute(stmt).first() - if row: - return row._mapping["user_metadata"] - else: - raise FeatureViewNotFoundException(feature_view.name, project=project) - - def proto(self) -> RegistryProto: - r = RegistryProto() - last_updated_timestamps = [] - - def process_project(project: Project): - nonlocal r, last_updated_timestamps - project_name = project.name - last_updated_timestamp = project.last_updated_timestamp - - try: - cached_project = self.get_project(project_name, True) - except ProjectObjectNotFoundException: - cached_project = None - - allow_cache = False - - if cached_project is not None: - allow_cache = ( - last_updated_timestamp <= cached_project.last_updated_timestamp - ) - - r.projects.extend([project.to_proto()]) - last_updated_timestamps.append(last_updated_timestamp) - - for lister, registry_proto_field in [ - (self.list_entities, r.entities), - (self.list_feature_views, r.feature_views), - (self.list_data_sources, r.data_sources), - (self.list_on_demand_feature_views, r.on_demand_feature_views), - (self.list_stream_feature_views, r.stream_feature_views), - (self.list_feature_services, r.feature_services), - (self.list_saved_datasets, r.saved_datasets), - 
(self.list_validation_references, r.validation_references), - (self.list_permissions, r.permissions), - ]: - objs: List[Any] = lister(project_name, allow_cache) # type: ignore - if objs: - obj_protos = [obj.to_proto() for obj in objs] - for obj_proto in obj_protos: - if "spec" in obj_proto.DESCRIPTOR.fields_by_name: - obj_proto.spec.project = project_name - else: - obj_proto.project = project_name - registry_proto_field.extend(obj_protos) - - # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, - # the registry proto only has a single infra field, which we're currently setting as the "last" project. - r.infra.CopyFrom(self.get_infra(project_name).to_proto()) - - projects_list = self.list_projects(allow_cache=False) - if self.thread_pool_executor_worker_count == 0: - for project in projects_list: - process_project(project) - else: - with ThreadPoolExecutor( - max_workers=self.thread_pool_executor_worker_count - ) as executor: - executor.map(process_project, projects_list) - - if last_updated_timestamps: - r.last_updated.FromDatetime(max(last_updated_timestamps)) - - return r - - def commit(self): - # This method is a no-op since we're always writing values eagerly to the db. - pass - - def _initialize_project_if_not_exists(self, project_name: str): - try: - self.get_project(project_name, allow_cache=True) - return - except ProjectObjectNotFoundException: - try: - self.get_project(project_name, allow_cache=False) - return - except ProjectObjectNotFoundException: - self.apply_project(Project(name=project_name), commit=True) - - def _apply_object( - self, - table: Table, - project: str, - id_field_name: str, - obj: Any, - proto_field_name: str, - name: Optional[str] = None, - ): - if not self.purge_feast_metadata: - self._maybe_init_project_metadata(project) - # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option - if not isinstance(obj, Project): - self._initialize_project_if_not_exists(project_name=project) - name = name or (obj.name if hasattr(obj, "name") else None) - assert name, f"name needs to be provided for {obj}" - - with self.write_engine.begin() as conn: - update_datetime = _utc_now() - update_time = int(update_datetime.timestamp()) - stmt = select(table).where( - getattr(table.c, id_field_name) == name, table.c.project_id == project - ) - row = conn.execute(stmt).first() - if hasattr(obj, "last_updated_timestamp"): - obj.last_updated_timestamp = update_datetime - - if row: - if proto_field_name in [ - "entity_proto", - "saved_dataset_proto", - "feature_view_proto", - "feature_service_proto", - "permission_proto", - "project_proto", - ]: - deserialized_proto = self.deserialize_registry_values( - row._mapping[proto_field_name], type(obj) - ) - obj.created_timestamp = ( - deserialized_proto.meta.created_timestamp.ToDatetime().replace( - tzinfo=timezone.utc - ) - ) - if isinstance(obj, (FeatureView, StreamFeatureView)): - obj.update_materialization_intervals( - type(obj) - .from_proto(deserialized_proto) - .materialization_intervals - ) - values = { - proto_field_name: obj.to_proto().SerializeToString(), - "last_updated_timestamp": update_time, - } - update_stmt = ( - update(table) - .where( - getattr(table.c, id_field_name) == name, - table.c.project_id == project, - ) - .values( - values, - ) - ) - conn.execute(update_stmt) - else: - obj_proto = obj.to_proto() - - if hasattr(obj_proto, "meta") and hasattr( - obj_proto.meta, "created_timestamp" - ): - if not 
obj_proto.meta.HasField("created_timestamp"): - obj_proto.meta.created_timestamp.FromDatetime(update_datetime) - - values = { - id_field_name: name, - proto_field_name: obj_proto.SerializeToString(), - "last_updated_timestamp": update_time, - "project_id": project, - } - insert_stmt = insert(table).values( - values, - ) - conn.execute(insert_stmt) - - if not isinstance(obj, Project): - self.apply_project( - self.get_project(name=project, allow_cache=False), commit=True - ) - if not self.purge_feast_metadata: - self._set_last_updated_metadata(update_datetime, project) - - def _maybe_init_project_metadata(self, project): - # Initialize project metadata if needed - with self.write_engine.begin() as conn: - update_datetime = _utc_now() - update_time = int(update_datetime.timestamp()) - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - if not row: - new_project_uuid = f"{uuid.uuid4()}" - values = { - "metadata_key": FeastMetadataKeys.PROJECT_UUID.value, - "metadata_value": new_project_uuid, - "last_updated_timestamp": update_time, - "project_id": project, - } - insert_stmt = insert(feast_metadata).values(values) - conn.execute(insert_stmt) - - def _delete_object( - self, - table: Table, - name: str, - project: str, - id_field_name: str, - not_found_exception: Optional[Callable], - ): - with self.write_engine.begin() as conn: - stmt = delete(table).where( - getattr(table.c, id_field_name) == name, table.c.project_id == project - ) - rows = conn.execute(stmt) - if rows.rowcount < 1 and not_found_exception: - raise not_found_exception(name, project) - self.apply_project( - self.get_project(name=project, allow_cache=False), commit=True - ) - if not self.purge_feast_metadata: - self._set_last_updated_metadata(_utc_now(), project) - - return rows.rowcount - - def _get_object( - self, - table: Table, - name: str, - project: str, - proto_class: Any, - python_class: Any, - id_field_name: str, - proto_field_name: str, - not_found_exception: Optional[Callable], - ): - with self.read_engine.begin() as conn: - stmt = select(table).where( - getattr(table.c, id_field_name) == name, table.c.project_id == project - ) - row = conn.execute(stmt).first() - if row: - _proto = proto_class.FromString(row._mapping[proto_field_name]) - return python_class.from_proto(_proto) - if not_found_exception: - raise not_found_exception(name, project) - else: - return None - - def _list_objects( - self, - table: Table, - project: str, - proto_class: Any, - python_class: Any, - proto_field_name: str, - tags: Optional[dict[str, str]] = None, - ): - with self.read_engine.begin() as conn: - stmt = select(table).where(table.c.project_id == project) - rows = conn.execute(stmt).all() - if rows: - objects = [] - for row in rows: - obj = python_class.from_proto( - proto_class.FromString(row._mapping[proto_field_name]) - ) - if utils.has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def _set_last_updated_metadata(self, last_updated: datetime, project: str): - with self.write_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - - update_time = int(last_updated.timestamp()) - - values = { - "metadata_key": FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - "metadata_value": f"{update_time}", 
- "last_updated_timestamp": update_time, - "project_id": project, - } - if row: - update_stmt = ( - update(feast_metadata) - .where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - .values(values) - ) - conn.execute(update_stmt) - else: - insert_stmt = insert(feast_metadata).values( - values, - ) - conn.execute(insert_stmt) - - def _get_last_updated_metadata(self, project: str): - with self.read_engine.begin() as conn: - stmt = select(feast_metadata).where( - feast_metadata.c.metadata_key - == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, - feast_metadata.c.project_id == project, - ) - row = conn.execute(stmt).first() - if not row: - return None - update_time = int(row._mapping["last_updated_timestamp"]) - - return datetime.fromtimestamp(update_time, tz=timezone.utc) - - def _get_permission(self, name: str, project: str) -> Permission: - return self._get_object( - table=permissions, - name=name, - project=project, - proto_class=PermissionProto, - python_class=Permission, - id_field_name="permission_name", - proto_field_name="permission_proto", - not_found_exception=PermissionNotFoundException, - ) - - def _list_permissions( - self, project: str, tags: Optional[dict[str, str]] - ) -> List[Permission]: - return self._list_objects( - permissions, - project, - PermissionProto, - Permission, - "permission_proto", - tags=tags, - ) - - def apply_permission( - self, permission: Permission, project: str, commit: bool = True - ): - return self._apply_object( - permissions, project, "permission_name", permission, "permission_proto" - ) - - def delete_permission(self, name: str, project: str, commit: bool = True): - with self.write_engine.begin() as conn: - stmt = delete(permissions).where( - permissions.c.permission_name == name, - permissions.c.project_id == project, - ) - rows = conn.execute(stmt) - if rows.rowcount < 1: - raise PermissionNotFoundException(name, project) - - def _list_projects( - self, - tags: Optional[dict[str, str]], - ) -> List[Project]: - with self.read_engine.begin() as conn: - stmt = select(projects) - rows = conn.execute(stmt).all() - if rows: - objects = [] - for row in rows: - obj = Project.from_proto( - ProjectProto.FromString(row._mapping["project_proto"]) - ) - if utils.has_all_tags(obj.tags, tags): - objects.append(obj) - return objects - return [] - - def _get_project( - self, - name: str, - ) -> Project: - return self._get_object( - table=projects, - name=name, - project=name, - proto_class=ProjectProto, - python_class=Project, - id_field_name="project_name", - proto_field_name="project_proto", - not_found_exception=ProjectObjectNotFoundException, - ) - - def apply_project( - self, - project: Project, - commit: bool = True, - ): - return self._apply_object( - projects, project.name, "project_name", project, "project_proto" - ) - - def delete_project( - self, - name: str, - commit: bool = True, - ): - project = self.get_project(name, allow_cache=False) - if project: - with self.write_engine.begin() as conn: - for t in { - managed_infra, - saved_datasets, - validation_references, - feature_services, - feature_views, - on_demand_feature_views, - stream_feature_views, - data_sources, - entities, - permissions, - feast_metadata, - projects, - }: - stmt = delete(t).where(t.c.project_id == name) - conn.execute(stmt) - return - - raise ProjectNotFoundException(name) +import logging +import uuid +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone 
+from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union, cast + +from pydantic import StrictInt, StrictStr +from sqlalchemy import ( # type: ignore + BigInteger, + Column, + Index, + LargeBinary, + MetaData, + String, + Table, + create_engine, + delete, + insert, + select, + update, +) +from sqlalchemy.engine import Engine + +from feast import utils +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.errors import ( + DataSourceObjectNotFoundException, + EntityNotFoundException, + FeatureServiceNotFoundException, + FeatureViewNotFoundException, + PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, + SavedDatasetNotFound, + ValidationReferenceNotFound, +) +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry.caching_registry import CachingRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.permission import Permission +from feast.project import Project +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto +from feast.protos.feast.core.FeatureService_pb2 import ( + FeatureService as FeatureServiceProto, +) +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto +from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( + OnDemandFeatureView as OnDemandFeatureViewProto, +) +from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) +from feast.protos.feast.core.ValidationProfile_pb2 import ( + ValidationReference as ValidationReferenceProto, +) +from feast.repo_config import RegistryConfig +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView +from feast.utils import _utc_now + +metadata = MetaData() + + +projects = Table( + "projects", + metadata, + Column("project_id", String(255), primary_key=True), + Column("project_name", String(255), nullable=False), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("project_proto", LargeBinary, nullable=False), +) + +Index("idx_projects_project_id", projects.c.project_id) + +entities = Table( + "entities", + metadata, + Column("entity_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("entity_proto", LargeBinary, nullable=False), +) + +Index("idx_entities_project_id", entities.c.project_id) + +data_sources = Table( + "data_sources", + metadata, + Column("data_source_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("data_source_proto", LargeBinary, nullable=False), +) + 
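+# NOTE: each registry table below follows the same layout: an object-name column
+# plus project_id form the composite primary key, the object itself is stored as
+# a serialized protobuf blob in a LargeBinary column, and last_updated_timestamp
+# records freshness as a unix epoch. The per-table index on project_id speeds up
+# the project-scoped scans issued by _list_objects().
+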
+Index("idx_data_sources_project_id", data_sources.c.project_id) + +feature_views = Table( + "feature_views", + metadata, + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("materialized_intervals", LargeBinary, nullable=True), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +Index("idx_feature_views_project_id", feature_views.c.project_id) + +stream_feature_views = Table( + "stream_feature_views", + metadata, + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +Index("idx_stream_feature_views_project_id", stream_feature_views.c.project_id) + +on_demand_feature_views = Table( + "on_demand_feature_views", + metadata, + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +Index("idx_on_demand_feature_views_project_id", on_demand_feature_views.c.project_id) + +feature_services = Table( + "feature_services", + metadata, + Column("feature_service_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_service_proto", LargeBinary, nullable=False), +) + +Index("idx_feature_services_project_id", feature_services.c.project_id) + +saved_datasets = Table( + "saved_datasets", + metadata, + Column("saved_dataset_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("saved_dataset_proto", LargeBinary, nullable=False), +) + +Index("idx_saved_datasets_project_id", saved_datasets.c.project_id) + +validation_references = Table( + "validation_references", + metadata, + Column("validation_reference_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("validation_reference_proto", LargeBinary, nullable=False), +) +Index("idx_validation_references_project_id", validation_references.c.project_id) + +managed_infra = Table( + "managed_infra", + metadata, + Column("infra_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("infra_proto", LargeBinary, nullable=False), +) + +Index("idx_managed_infra_project_id", managed_infra.c.project_id) + +permissions = Table( + "permissions", + metadata, + Column("permission_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("permission_proto", LargeBinary, nullable=False), +) + +Index("idx_permissions_project_id", permissions.c.project_id) + + +class FeastMetadataKeys(Enum): + LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" + PROJECT_UUID = "project_uuid" + + +feast_metadata = Table( + "feast_metadata", + metadata, 
+    Column("project_id", String(255), primary_key=True),
+    Column("metadata_key", String(50), primary_key=True),
+    Column("metadata_value", String(50), nullable=False),
+    Column("last_updated_timestamp", BigInteger, nullable=False),
+)
+
+Index("idx_feast_metadata_project_id", feast_metadata.c.project_id)
+
+logger = logging.getLogger(__name__)
+
+
+class SqlRegistryConfig(RegistryConfig):
+    registry_type: StrictStr = "sql"
+    """ str: The registry type, or a class name that implements Registry."""
+
+    path: StrictStr = ""
+    """ str: Path to metadata store.
+    If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """
+
+    read_path: Optional[StrictStr] = None
+    """ str: Read path to the metadata store, if different from path.
+    If registry_type is 'sql', then this is a read endpoint database URL. If not set, path is used for both reads and writes. """
+
+    sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False}
+    """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """
+
+    cache_mode: StrictStr = "sync"
+    """ str: Cache mode. Possible options are "sync" and "thread" (asynchronous caching using the threading library). """
+
+    thread_pool_executor_worker_count: StrictInt = 0
+    """ int: Number of worker threads to use for asynchronous caching in the SQL registry. If set to 0, ThreadPoolExecutor is not used. """
+
+
+class SqlRegistry(CachingRegistry):
+    def __init__(
+        self,
+        registry_config,
+        project: str,
+        repo_path: Optional[Path],
+    ):
+        assert registry_config is not None and isinstance(
+            registry_config, SqlRegistryConfig
+        ), "SqlRegistry needs a valid registry_config"
+
+        self.registry_config = registry_config
+
+        self.write_engine: Engine = create_engine(
+            registry_config.path, **registry_config.sqlalchemy_config_kwargs
+        )
+        if registry_config.read_path:
+            self.read_engine: Engine = create_engine(
+                registry_config.read_path,
+                **registry_config.sqlalchemy_config_kwargs,
+            )
+        else:
+            self.read_engine = self.write_engine
+        metadata.create_all(self.write_engine)
+        self.thread_pool_executor_worker_count = (
+            registry_config.thread_pool_executor_worker_count
+        )
+        self.purge_feast_metadata = registry_config.purge_feast_metadata
+        # Sync feast_metadata to the projects table. When purge_feast_metadata
+        # is set to True, data is deleted from the feast_metadata table and
+        # list_project_metadata will not return any data.
+        self._sync_feast_metadata_to_projects_table()
+        if not self.purge_feast_metadata:
+            self._maybe_init_project_metadata(project)
+        super().__init__(
+            project=project,
+            cache_ttl_seconds=registry_config.cache_ttl_seconds,
+            cache_mode=registry_config.cache_mode,
+        )
+
+    def _sync_feast_metadata_to_projects_table(self):
+        feast_metadata_projects: dict = {}
+        projects_set: set = set()
+        with self.read_engine.begin() as conn:
+            stmt = select(feast_metadata).where(
+                feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value
+            )
+            rows = conn.execute(stmt).all()
+            for row in rows:
+                feast_metadata_projects[row._mapping["project_id"]] = int(
+                    row._mapping["last_updated_timestamp"]
+                )
+
+        if len(feast_metadata_projects) > 0:
+            with self.read_engine.begin() as conn:
+                stmt = select(projects)
+                rows = conn.execute(stmt).all()
+                for row in rows:
+                    projects_set.add(row._mapping["project_id"])
+
+            # Find projects recorded in feast_metadata but missing from the projects table
+            projects_to_sync = set(feast_metadata_projects.keys()) - projects_set
+            for project_name in projects_to_sync:
+                self.apply_project(
+                    Project(
+                        name=project_name,
+
created_timestamp=datetime.fromtimestamp( + feast_metadata_projects[project_name], tz=timezone.utc + ), + ), + commit=True, + ) + + if self.purge_feast_metadata: + with self.write_engine.begin() as conn: + for project_name in feast_metadata_projects: + stmt = delete(feast_metadata).where( + feast_metadata.c.project_id == project_name + ) + conn.execute(stmt) + + def teardown(self): + for t in { + entities, + data_sources, + feature_views, + feature_services, + on_demand_feature_views, + saved_datasets, + validation_references, + permissions, + }: + with self.write_engine.begin() as conn: + stmt = delete(t) + conn.execute(stmt) + + def _get_stream_feature_view(self, name: str, project: str): + return self._get_object( + table=stream_feature_views, + name=name, + project=project, + proto_class=StreamFeatureViewProto, + python_class=StreamFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + + def _list_stream_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[StreamFeatureView]: + return self._list_objects( + stream_feature_views, + project, + StreamFeatureViewProto, + StreamFeatureView, + "feature_view_proto", + tags=tags, + ) + + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + return self._apply_object( + table=entities, + project=project, + id_field_name="entity_name", + obj=entity, + proto_field_name="entity_proto", + ) + + def _get_entity(self, name: str, project: str) -> Entity: + return self._get_object( + table=entities, + name=name, + project=project, + proto_class=EntityProto, + python_class=Entity, + id_field_name="entity_name", + proto_field_name="entity_proto", + not_found_exception=EntityNotFoundException, + ) + + def _get_any_feature_view(self, name: str, project: str) -> BaseFeatureView: + fv = self._get_object( + table=feature_views, + name=name, + project=project, + proto_class=FeatureViewProto, + python_class=FeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=None, + ) + + if not fv: + fv = self._get_object( + table=on_demand_feature_views, + name=name, + project=project, + proto_class=OnDemandFeatureViewProto, + python_class=OnDemandFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=None, + ) + + if not fv: + fv = self._get_object( + table=stream_feature_views, + name=name, + project=project, + proto_class=StreamFeatureViewProto, + python_class=StreamFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + return fv + + def _list_all_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[BaseFeatureView]: + return ( + cast( + list[BaseFeatureView], + self._list_feature_views(project=project, tags=tags), + ) + + cast( + list[BaseFeatureView], + self._list_stream_feature_views(project=project, tags=tags), + ) + + cast( + list[BaseFeatureView], + self._list_on_demand_feature_views(project=project, tags=tags), + ) + ) + + def _get_feature_view(self, name: str, project: str) -> FeatureView: + return self._get_object( + table=feature_views, + name=name, + project=project, + proto_class=FeatureViewProto, + python_class=FeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + + def 
_get_on_demand_feature_view( + self, name: str, project: str + ) -> OnDemandFeatureView: + return self._get_object( + table=on_demand_feature_views, + name=name, + project=project, + proto_class=OnDemandFeatureViewProto, + python_class=OnDemandFeatureView, + id_field_name="feature_view_name", + proto_field_name="feature_view_proto", + not_found_exception=FeatureViewNotFoundException, + ) + + def _get_feature_service(self, name: str, project: str) -> FeatureService: + return self._get_object( + table=feature_services, + name=name, + project=project, + proto_class=FeatureServiceProto, + python_class=FeatureService, + id_field_name="feature_service_name", + proto_field_name="feature_service_proto", + not_found_exception=FeatureServiceNotFoundException, + ) + + def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: + return self._get_object( + table=saved_datasets, + name=name, + project=project, + proto_class=SavedDatasetProto, + python_class=SavedDataset, + id_field_name="saved_dataset_name", + proto_field_name="saved_dataset_proto", + not_found_exception=SavedDatasetNotFound, + ) + + def _get_validation_reference(self, name: str, project: str) -> ValidationReference: + return self._get_object( + table=validation_references, + name=name, + project=project, + proto_class=ValidationReferenceProto, + python_class=ValidationReference, + id_field_name="validation_reference_name", + proto_field_name="validation_reference_proto", + not_found_exception=ValidationReferenceNotFound, + ) + + def _list_validation_references( + self, project: str, tags: Optional[dict[str, str]] = None + ) -> List[ValidationReference]: + return self._list_objects( + table=validation_references, + project=project, + proto_class=ValidationReferenceProto, + python_class=ValidationReference, + proto_field_name="validation_reference_proto", + tags=tags, + ) + + def _list_entities( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[Entity]: + return self._list_objects( + entities, project, EntityProto, Entity, "entity_proto", tags=tags + ) + + def delete_entity(self, name: str, project: str, commit: bool = True): + return self._delete_object( + entities, name, project, "entity_name", EntityNotFoundException + ) + + def delete_feature_view(self, name: str, project: str, commit: bool = True): + deleted_count = 0 + for table in { + feature_views, + on_demand_feature_views, + stream_feature_views, + }: + deleted_count += self._delete_object( + table, name, project, "feature_view_name", None + ) + if deleted_count == 0: + raise FeatureViewNotFoundException(name, project) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + return self._delete_object( + feature_services, + name, + project, + "feature_service_name", + FeatureServiceNotFoundException, + ) + + def _get_data_source(self, name: str, project: str) -> DataSource: + return self._get_object( + table=data_sources, + name=name, + project=project, + proto_class=DataSourceProto, + python_class=DataSource, + id_field_name="data_source_name", + proto_field_name="data_source_proto", + not_found_exception=DataSourceObjectNotFoundException, + ) + + def _list_data_sources( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[DataSource]: + return self._list_objects( + data_sources, + project, + DataSourceProto, + DataSource, + "data_source_proto", + tags=tags, + ) + + def apply_data_source( + self, data_source: DataSource, project: str, commit: bool = True + ): + return self._apply_object( + data_sources, 
project, "data_source_name", data_source, "data_source_proto"
+        )
+
+    def apply_feature_view(
+        self, feature_view: BaseFeatureView, project: str, commit: bool = True
+    ):
+        fv_table = self._infer_fv_table(feature_view)
+
+        return self._apply_object(
+            fv_table, project, "feature_view_name", feature_view, "feature_view_proto"
+        )
+
+    def apply_feature_service(
+        self, feature_service: FeatureService, project: str, commit: bool = True
+    ):
+        return self._apply_object(
+            feature_services,
+            project,
+            "feature_service_name",
+            feature_service,
+            "feature_service_proto",
+        )
+
+    def delete_data_source(self, name: str, project: str, commit: bool = True):
+        with self.write_engine.begin() as conn:
+            stmt = delete(data_sources).where(
+                data_sources.c.data_source_name == name,
+                data_sources.c.project_id == project,
+            )
+            rows = conn.execute(stmt)
+            if rows.rowcount < 1:
+                raise DataSourceObjectNotFoundException(name, project)
+
+    def _list_feature_services(
+        self, project: str, tags: Optional[dict[str, str]]
+    ) -> List[FeatureService]:
+        return self._list_objects(
+            feature_services,
+            project,
+            FeatureServiceProto,
+            FeatureService,
+            "feature_service_proto",
+            tags=tags,
+        )
+
+    def _list_feature_views(
+        self, project: str, tags: Optional[dict[str, str]]
+    ) -> List[FeatureView]:
+        return self._list_objects(
+            feature_views,
+            project,
+            FeatureViewProto,
+            FeatureView,
+            "feature_view_proto",
+            tags=tags,
+        )
+
+    def _list_saved_datasets(
+        self, project: str, tags: Optional[dict[str, str]] = None
+    ) -> List[SavedDataset]:
+        return self._list_objects(
+            saved_datasets,
+            project,
+            SavedDatasetProto,
+            SavedDataset,
+            "saved_dataset_proto",
+            tags=tags,
+        )
+
+    def _list_on_demand_feature_views(
+        self, project: str, tags: Optional[dict[str, str]]
+    ) -> List[OnDemandFeatureView]:
+        return self._list_objects(
+            on_demand_feature_views,
+            project,
+            OnDemandFeatureViewProto,
+            OnDemandFeatureView,
+            "feature_view_proto",
+            tags=tags,
+        )
+
+    def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
+        with self.read_engine.begin() as conn:
+            stmt = select(feast_metadata).where(
+                feast_metadata.c.project_id == project,
+            )
+            rows = conn.execute(stmt).all()
+            if rows:
+                project_metadata = ProjectMetadata(project_name=project)
+                for row in rows:
+                    if (
+                        row._mapping["metadata_key"]
+                        == FeastMetadataKeys.PROJECT_UUID.value
+                    ):
+                        project_metadata.project_uuid = row._mapping["metadata_value"]
+                        break
+                    # TODO(adchia): Add other project metadata in a structured way
+                return [project_metadata]
+            return []
+
+    def apply_saved_dataset(
+        self,
+        saved_dataset: SavedDataset,
+        project: str,
+        commit: bool = True,
+    ):
+        return self._apply_object(
+            saved_datasets,
+            project,
+            "saved_dataset_name",
+            saved_dataset,
+            "saved_dataset_proto",
+        )
+
+    def apply_validation_reference(
+        self,
+        validation_reference: ValidationReference,
+        project: str,
+        commit: bool = True,
+    ):
+        return self._apply_object(
+            validation_references,
+            project,
+            "validation_reference_name",
+            validation_reference,
+            "validation_reference_proto",
+        )
+
+    def apply_materialization(
+        self,
+        feature_view: Union[FeatureView, OnDemandFeatureView],
+        project: str,
+        start_date: datetime,
+        end_date: datetime,
+        commit: bool = True,
+    ):
+        table = self._infer_fv_table(feature_view)
+        python_class, proto_class = self._infer_fv_classes(feature_view)
+
+        if python_class in {OnDemandFeatureView}:
+            raise ValueError(
+                f"Cannot apply materialization for feature view {feature_view.name} of type {python_class.__name__}"
+            )
+        fv:
Union[FeatureView, StreamFeatureView] = self._get_object(
+            table,
+            feature_view.name,
+            project,
+            proto_class,
+            python_class,
+            "feature_view_name",
+            "feature_view_proto",
+            FeatureViewNotFoundException,
+        )
+        fv.materialization_intervals.append((start_date, end_date))
+        self._apply_object(
+            table, project, "feature_view_name", fv, "feature_view_proto"
+        )
+
+    def delete_validation_reference(self, name: str, project: str, commit: bool = True):
+        self._delete_object(
+            validation_references,
+            name,
+            project,
+            "validation_reference_name",
+            ValidationReferenceNotFound,
+        )
+
+    def update_infra(self, infra: Infra, project: str, commit: bool = True):
+        self._apply_object(
+            table=managed_infra,
+            project=project,
+            id_field_name="infra_name",
+            obj=infra,
+            proto_field_name="infra_proto",
+            name="infra_obj",
+        )
+
+    def _get_infra(self, project: str) -> Infra:
+        infra_object = self._get_object(
+            table=managed_infra,
+            name="infra_obj",
+            project=project,
+            proto_class=InfraProto,
+            python_class=Infra,
+            id_field_name="infra_name",
+            proto_field_name="infra_proto",
+            not_found_exception=None,
+        )
+        if infra_object:
+            return infra_object
+        return Infra()
+
+    def apply_user_metadata(
+        self,
+        project: str,
+        feature_view: BaseFeatureView,
+        metadata_bytes: Optional[bytes],
+    ):
+        table = self._infer_fv_table(feature_view)
+
+        name = feature_view.name
+        with self.write_engine.begin() as conn:
+            stmt = select(table).where(
+                getattr(table.c, "feature_view_name") == name,
+                table.c.project_id == project,
+            )
+            row = conn.execute(stmt).first()
+            update_datetime = _utc_now()
+            update_time = int(update_datetime.timestamp())
+            if row:
+                values = {
+                    "user_metadata": metadata_bytes,
+                    "last_updated_timestamp": update_time,
+                }
+                update_stmt = (
+                    update(table)
+                    .where(
+                        getattr(table.c, "feature_view_name") == name,
+                        table.c.project_id == project,
+                    )
+                    .values(
+                        values,
+                    )
+                )
+                conn.execute(update_stmt)
+            else:
+                raise FeatureViewNotFoundException(feature_view.name, project=project)
+
+    def _infer_fv_table(self, feature_view):
+        if isinstance(feature_view, StreamFeatureView):
+            table = stream_feature_views
+        elif isinstance(feature_view, FeatureView):
+            table = feature_views
+        elif isinstance(feature_view, OnDemandFeatureView):
+            table = on_demand_feature_views
+        else:
+            raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
+        return table
+
+    def _infer_fv_classes(self, feature_view):
+        if isinstance(feature_view, StreamFeatureView):
+            python_class, proto_class = StreamFeatureView, StreamFeatureViewProto
+        elif isinstance(feature_view, FeatureView):
+            python_class, proto_class = FeatureView, FeatureViewProto
+        elif isinstance(feature_view, OnDemandFeatureView):
+            python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto
+        else:
+            raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
+        return python_class, proto_class
+
+    def get_user_metadata(
+        self, project: str, feature_view: BaseFeatureView
+    ) -> Optional[bytes]:
+        table = self._infer_fv_table(feature_view)
+
+        name = feature_view.name
+        with self.read_engine.begin() as conn:
+            stmt = select(table).where(
+                getattr(table.c, "feature_view_name") == name,
+                table.c.project_id == project,
+            )
+            row = conn.execute(stmt).first()
+            if row:
+                return row._mapping["user_metadata"]
+            else:
+                raise FeatureViewNotFoundException(feature_view.name, project=project)
+
+    def proto(self) -> RegistryProto:
+        r = RegistryProto()
+        last_updated_timestamps = []
+
+        def process_project(project: Project):
+            nonlocal r, last_updated_timestamps
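+            # Serve this project's objects from the per-object caches only when
+            # the cached project entry is at least as fresh as the project row
+            # just read from the database; otherwise allow_cache stays False and
+            # the listers below read the registry directly.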
+ project_name = project.name + last_updated_timestamp = project.last_updated_timestamp + + try: + cached_project = self.get_project(project_name, True) + except ProjectObjectNotFoundException: + cached_project = None + + allow_cache = False + + if cached_project is not None: + allow_cache = ( + last_updated_timestamp <= cached_project.last_updated_timestamp + ) + + r.projects.extend([project.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) + + for lister, registry_proto_field in [ + (self.list_entities, r.entities), + (self.list_feature_views, r.feature_views), + (self.list_data_sources, r.data_sources), + (self.list_on_demand_feature_views, r.on_demand_feature_views), + (self.list_stream_feature_views, r.stream_feature_views), + (self.list_feature_services, r.feature_services), + (self.list_saved_datasets, r.saved_datasets), + (self.list_validation_references, r.validation_references), + (self.list_permissions, r.permissions), + ]: + objs: List[Any] = lister(project_name, allow_cache) # type: ignore + if objs: + obj_protos = [obj.to_proto() for obj in objs] + for obj_proto in obj_protos: + if "spec" in obj_proto.DESCRIPTOR.fields_by_name: + obj_proto.spec.project = project_name + else: + obj_proto.project = project_name + registry_proto_field.extend(obj_protos) + + # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, + # the registry proto only has a single infra field, which we're currently setting as the "last" project. + r.infra.CopyFrom(self.get_infra(project_name).to_proto()) + + projects_list = self.list_projects(allow_cache=False) + if self.thread_pool_executor_worker_count == 0: + for project in projects_list: + process_project(project) + else: + with ThreadPoolExecutor( + max_workers=self.thread_pool_executor_worker_count + ) as executor: + executor.map(process_project, projects_list) + + if last_updated_timestamps: + r.last_updated.FromDatetime(max(last_updated_timestamps)) + + return r + + def commit(self): + # This method is a no-op since we're always writing values eagerly to the db. 
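+        # Each _apply_object()/_delete_object() call runs inside its own
+        # self.write_engine.begin() transaction, so there is no buffered state
+        # left to flush here. An explicit commit() from a caller (hypothetical
+        # example below) is therefore harmless but unnecessary:
+        #
+        #   registry.apply_entity(entity, project="my_project")  # persisted immediately
+        #   registry.commit()  # no-op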
+        pass
+
+    def _initialize_project_if_not_exists(self, project_name: str):
+        try:
+            self.get_project(project_name, allow_cache=True)
+            return
+        except ProjectObjectNotFoundException:
+            try:
+                self.get_project(project_name, allow_cache=False)
+                return
+            except ProjectObjectNotFoundException:
+                self.apply_project(Project(name=project_name), commit=True)
+
+    def _apply_object(
+        self,
+        table: Table,
+        project: str,
+        id_field_name: str,
+        obj: Any,
+        proto_field_name: str,
+        name: Optional[str] = None,
+    ):
+        if not self.purge_feast_metadata:
+            self._maybe_init_project_metadata(project)
+        # Initializing the project is necessary because a FeatureStore object can
+        # apply objects individually, without the "feast apply" CLI command.
+        if not isinstance(obj, Project):
+            self._initialize_project_if_not_exists(project_name=project)
+        name = name or (obj.name if hasattr(obj, "name") else None)
+        assert name, f"name needs to be provided for {obj}"
+
+        with self.write_engine.begin() as conn:
+            update_datetime = _utc_now()
+            update_time = int(update_datetime.timestamp())
+            stmt = select(table).where(
+                getattr(table.c, id_field_name) == name, table.c.project_id == project
+            )
+            row = conn.execute(stmt).first()
+            if hasattr(obj, "last_updated_timestamp"):
+                obj.last_updated_timestamp = update_datetime
+
+            if row:
+                if proto_field_name in [
+                    "entity_proto",
+                    "saved_dataset_proto",
+                    "feature_view_proto",
+                    "feature_service_proto",
+                    "permission_proto",
+                    "project_proto",
+                ]:
+                    deserialized_proto = self.deserialize_registry_values(
+                        row._mapping[proto_field_name], type(obj)
+                    )
+                    obj.created_timestamp = (
+                        deserialized_proto.meta.created_timestamp.ToDatetime().replace(
+                            tzinfo=timezone.utc
+                        )
+                    )
+                    if isinstance(obj, (FeatureView, StreamFeatureView)):
+                        obj.update_materialization_intervals(
+                            type(obj)
+                            .from_proto(deserialized_proto)
+                            .materialization_intervals
+                        )
+                values = {
+                    proto_field_name: obj.to_proto().SerializeToString(),
+                    "last_updated_timestamp": update_time,
+                }
+                update_stmt = (
+                    update(table)
+                    .where(
+                        getattr(table.c, id_field_name) == name,
+                        table.c.project_id == project,
+                    )
+                    .values(
+                        values,
+                    )
+                )
+                conn.execute(update_stmt)
+            else:
+                obj_proto = obj.to_proto()
+
+                if hasattr(obj_proto, "meta") and hasattr(
+                    obj_proto.meta, "created_timestamp"
+                ):
+                    if not obj_proto.meta.HasField("created_timestamp"):
+                        obj_proto.meta.created_timestamp.FromDatetime(update_datetime)
+
+                values = {
+                    id_field_name: name,
+                    proto_field_name: obj_proto.SerializeToString(),
+                    "last_updated_timestamp": update_time,
+                    "project_id": project,
+                }
+                insert_stmt = insert(table).values(
+                    values,
+                )
+                conn.execute(insert_stmt)
+
+        if not isinstance(obj, Project):
+            self.apply_project(
+                self.get_project(name=project, allow_cache=False), commit=True
+            )
+        if not self.purge_feast_metadata:
+            self._set_last_updated_metadata(update_datetime, project)
+
+    def _maybe_init_project_metadata(self, project):
+        # Initialize project metadata if needed
+        with self.write_engine.begin() as conn:
+            update_datetime = _utc_now()
+            update_time = int(update_datetime.timestamp())
+            stmt = select(feast_metadata).where(
+                feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value,
+                feast_metadata.c.project_id == project,
+            )
+            row = conn.execute(stmt).first()
+            if not row:
+                new_project_uuid = f"{uuid.uuid4()}"
+                values = {
+                    "metadata_key": FeastMetadataKeys.PROJECT_UUID.value,
+                    "metadata_value": new_project_uuid,
+                    "last_updated_timestamp": update_time,
+                    "project_id": project,
+                }
+                insert_stmt =
insert(feast_metadata).values(values) + conn.execute(insert_stmt) + + def _delete_object( + self, + table: Table, + name: str, + project: str, + id_field_name: str, + not_found_exception: Optional[Callable], + ): + with self.write_engine.begin() as conn: + stmt = delete(table).where( + getattr(table.c, id_field_name) == name, table.c.project_id == project + ) + rows = conn.execute(stmt) + if rows.rowcount < 1 and not_found_exception: + raise not_found_exception(name, project) + self.apply_project( + self.get_project(name=project, allow_cache=False), commit=True + ) + if not self.purge_feast_metadata: + self._set_last_updated_metadata(_utc_now(), project) + + return rows.rowcount + + def _get_object( + self, + table: Table, + name: str, + project: str, + proto_class: Any, + python_class: Any, + id_field_name: str, + proto_field_name: str, + not_found_exception: Optional[Callable], + ): + with self.read_engine.begin() as conn: + stmt = select(table).where( + getattr(table.c, id_field_name) == name, table.c.project_id == project + ) + row = conn.execute(stmt).first() + if row: + _proto = proto_class.FromString(row._mapping[proto_field_name]) + return python_class.from_proto(_proto) + if not_found_exception: + raise not_found_exception(name, project) + else: + return None + + def _list_objects( + self, + table: Table, + project: str, + proto_class: Any, + python_class: Any, + proto_field_name: str, + tags: Optional[dict[str, str]] = None, + ): + with self.read_engine.begin() as conn: + stmt = select(table).where(table.c.project_id == project) + rows = conn.execute(stmt).all() + if rows: + objects = [] + for row in rows: + obj = python_class.from_proto( + proto_class.FromString(row._mapping[proto_field_name]) + ) + if utils.has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def _set_last_updated_metadata(self, last_updated: datetime, project: str): + with self.write_engine.begin() as conn: + stmt = select(feast_metadata).where( + feast_metadata.c.metadata_key + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + feast_metadata.c.project_id == project, + ) + row = conn.execute(stmt).first() + + update_time = int(last_updated.timestamp()) + + values = { + "metadata_key": FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + "metadata_value": f"{update_time}", + "last_updated_timestamp": update_time, + "project_id": project, + } + if row: + update_stmt = ( + update(feast_metadata) + .where( + feast_metadata.c.metadata_key + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + feast_metadata.c.project_id == project, + ) + .values(values) + ) + conn.execute(update_stmt) + else: + insert_stmt = insert(feast_metadata).values( + values, + ) + conn.execute(insert_stmt) + + def _get_last_updated_metadata(self, project: str): + with self.read_engine.begin() as conn: + stmt = select(feast_metadata).where( + feast_metadata.c.metadata_key + == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, + feast_metadata.c.project_id == project, + ) + row = conn.execute(stmt).first() + if not row: + return None + update_time = int(row._mapping["last_updated_timestamp"]) + + return datetime.fromtimestamp(update_time, tz=timezone.utc) + + def _get_permission(self, name: str, project: str) -> Permission: + return self._get_object( + table=permissions, + name=name, + project=project, + proto_class=PermissionProto, + python_class=Permission, + id_field_name="permission_name", + proto_field_name="permission_proto", + not_found_exception=PermissionNotFoundException, + ) + + def _list_permissions( + 
self, project: str, tags: Optional[dict[str, str]] + ) -> List[Permission]: + return self._list_objects( + permissions, + project, + PermissionProto, + Permission, + "permission_proto", + tags=tags, + ) + + def apply_permission( + self, permission: Permission, project: str, commit: bool = True + ): + return self._apply_object( + permissions, project, "permission_name", permission, "permission_proto" + ) + + def delete_permission(self, name: str, project: str, commit: bool = True): + with self.write_engine.begin() as conn: + stmt = delete(permissions).where( + permissions.c.permission_name == name, + permissions.c.project_id == project, + ) + rows = conn.execute(stmt) + if rows.rowcount < 1: + raise PermissionNotFoundException(name, project) + + def _list_projects( + self, + tags: Optional[dict[str, str]], + ) -> List[Project]: + with self.read_engine.begin() as conn: + stmt = select(projects) + rows = conn.execute(stmt).all() + if rows: + objects = [] + for row in rows: + obj = Project.from_proto( + ProjectProto.FromString(row._mapping["project_proto"]) + ) + if utils.has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def _get_project( + self, + name: str, + ) -> Project: + return self._get_object( + table=projects, + name=name, + project=name, + proto_class=ProjectProto, + python_class=Project, + id_field_name="project_name", + proto_field_name="project_proto", + not_found_exception=ProjectObjectNotFoundException, + ) + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + return self._apply_object( + projects, project.name, "project_name", project, "project_proto" + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + project = self.get_project(name, allow_cache=False) + if project: + with self.write_engine.begin() as conn: + for t in { + managed_infra, + saved_datasets, + validation_references, + feature_services, + feature_views, + on_demand_feature_views, + stream_feature_views, + data_sources, + entities, + permissions, + feast_metadata, + projects, + }: + stmt = delete(t).where(t.c.project_id == name) + conn.execute(stmt) + return + + raise ProjectNotFoundException(name)
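+
+# Usage sketch (illustrative only; names and values below are hypothetical).
+# The registry is constructed from a SqlRegistryConfig whose path is any
+# SQLAlchemy-compatible database URL:
+#
+#   config = SqlRegistryConfig(
+#       registry_type="sql",
+#       path="sqlite:///registry.db",
+#       cache_ttl_seconds=60,
+#       cache_mode="sync",
+#   )
+#   registry = SqlRegistry(config, project="my_project", repo_path=None)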