From cd7f971bd79b4ef44e51da2dc2b75797f48ef5be Mon Sep 17 00:00:00 2001
From: ntkathole
Date: Tue, 18 Feb 2025 19:06:21 +0530
Subject: [PATCH] fix: Skip refresh if already in progress or if lock is already held

Signed-off-by: ntkathole
---
Review notes (text between "---" and the first "diff --git" is ignored by git am):
- refresh() no longer checks self._refresh_lock.locked(). The sync path in
  _refresh_cached_registry_if_necessary() calls refresh() while holding that
  lock, so the check made every expiry-triggered refresh a no-op.
- The expiry test now asserts that the cached proto is actually replaced, and
  the lock-held test releases the lock in a finally block.
- The blob ids on the "index" lines are carried over from the original patch
  and no longer match the edited contents; regenerate with git format-patch
  before relying on them (e.g. for "git am -3").

 .../feast/infra/registry/caching_registry.py  |  25 ++-
 .../tests/unit/infra/registry/__init__.py     |   0
 .../unit/infra/registry/test_registry.py      | 203 ++++++++++++++++++
 3 files changed, 225 insertions(+), 3 deletions(-)
 create mode 100644 sdk/python/tests/unit/infra/registry/__init__.py
 create mode 100644 sdk/python/tests/unit/infra/registry/test_registry.py

diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py
index 042eee06ab7..23ab80ee1d8 100644
--- a/sdk/python/feast/infra/registry/caching_registry.py
+++ b/sdk/python/feast/infra/registry/caching_registry.py
@@ -425,12 +425,24 @@ def list_projects(
         return self._list_projects(tags)
 
     def refresh(self, project: Optional[str] = None):
-        self.cached_registry_proto = self.proto()
-        self.cached_registry_proto_created = _utc_now()
+        # Do not gate on self._refresh_lock.locked() here: the sync path in
+        # _refresh_cached_registry_if_necessary() calls refresh() while holding
+        # that lock, so such a check would skip every expiry-triggered refresh.
+        try:
+            self.cached_registry_proto = self.proto()
+            self.cached_registry_proto_created = _utc_now()
+        except Exception as e:
+            logger.error(f"Error while refreshing registry: {e}", exc_info=True)
 
     def _refresh_cached_registry_if_necessary(self):
         if self.cache_mode == "sync":
-            with self._refresh_lock:
+            # Try acquiring the lock without blocking
+            if not self._refresh_lock.acquire(blocking=False):
+                logger.info(
+                    "Skipping refresh if lock is already held by another thread"
+                )
+                return
+            try:
                 if self.cached_registry_proto == RegistryProto():
                     # Avoids the need to refresh the registry when cache is not populated yet
                     # Specially during the __init__ phase
@@ -454,6 +466,13 @@ def _refresh_cached_registry_if_necessary(self):
                 if expired:
                     logger.info("Registry cache expired, so refreshing")
                     self.refresh()
+            except Exception as e:
+                logger.error(
+                    f"Error in _refresh_cached_registry_if_necessary: {e}",
+                    exc_info=True,
+                )
+            finally:
+                self._refresh_lock.release()  # Always release the lock safely
 
     def _start_thread_async_refresh(self, cache_ttl_seconds):
         self.refresh()
diff --git a/sdk/python/tests/unit/infra/registry/__init__.py b/sdk/python/tests/unit/infra/registry/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/sdk/python/tests/unit/infra/registry/test_registry.py b/sdk/python/tests/unit/infra/registry/test_registry.py
new file mode 100644
index 00000000000..65dea2ff680
--- /dev/null
+++ b/sdk/python/tests/unit/infra/registry/test_registry.py
@@ -0,0 +1,203 @@
+from datetime import datetime, timedelta, timezone
+from unittest.mock import patch
+
+import pytest
+
+from feast.infra.registry.caching_registry import CachingRegistry
+
+
+class TestCachingRegistry(CachingRegistry):
+    """Test subclass that implements abstract methods as no-ops"""
+
+    # Helper subclass, not a test case: keep pytest from trying to collect it
+    __test__ = False
+
+    def _get_any_feature_view(self, *args, **kwargs):
+        pass
+
+    def _get_data_source(self, *args, **kwargs):
+        pass
+
+    def _get_entity(self, *args, **kwargs):
+        pass
+
+    def _get_feature_service(self, *args, **kwargs):
+        pass
+
+    def _get_feature_view(self, *args, **kwargs):
+        pass
+
+    def _get_infra(self, *args, **kwargs):
+        pass
+
+    def _get_on_demand_feature_view(self, *args, **kwargs):
+        pass
+
+    def _get_permission(self, *args, **kwargs):
+        pass
+
+    def _get_project(self, *args, **kwargs):
+        pass
+
+    def _get_saved_dataset(self, *args, **kwargs):
+        pass
+
+    def _get_stream_feature_view(self, *args, **kwargs):
+        pass
+
+    def _get_validation_reference(self, *args, **kwargs):
+        pass
+
+    def _list_all_feature_views(self, *args, **kwargs):
+        pass
+
+    def _list_data_sources(self, *args, **kwargs):
+        pass
+
+    def _list_entities(self, *args, **kwargs):
+        pass
+
+    def _list_feature_services(self, *args, **kwargs):
+        pass
+
+    def _list_feature_views(self, *args, **kwargs):
+        pass
+
+    def _list_on_demand_feature_views(self, *args, **kwargs):
+        pass
+
+    def _list_permissions(self, *args, **kwargs):
+        pass
+
+    def _list_project_metadata(self, *args, **kwargs):
+        pass
+
+    def _list_projects(self, *args, **kwargs):
+        pass
+
+    def _list_saved_datasets(self, *args, **kwargs):
+        pass
+
+    def _list_stream_feature_views(self, *args, **kwargs):
+        pass
+
+    def _list_validation_references(self, *args, **kwargs):
+        pass
+
+    def apply_data_source(self, *args, **kwargs):
+        pass
+
+    def apply_entity(self, *args, **kwargs):
+        pass
+
+    def apply_feature_service(self, *args, **kwargs):
+        pass
+
+    def apply_feature_view(self, *args, **kwargs):
+        pass
+
+    def apply_materialization(self, *args, **kwargs):
+        pass
+
+    def apply_permission(self, *args, **kwargs):
+        pass
+
+    def apply_project(self, *args, **kwargs):
+        pass
+
+    def apply_saved_dataset(self, *args, **kwargs):
+        pass
+
+    def apply_user_metadata(self, *args, **kwargs):
+        pass
+
+    def apply_validation_reference(self, *args, **kwargs):
+        pass
+
+    def commit(self, *args, **kwargs):
+        pass
+
+    def delete_data_source(self, *args, **kwargs):
+        pass
+
+    def delete_entity(self, *args, **kwargs):
+        pass
+
+    def delete_feature_service(self, *args, **kwargs):
+        pass
+
+    def delete_feature_view(self, *args, **kwargs):
+        pass
+
+    def delete_permission(self, *args, **kwargs):
+        pass
+
+    def delete_project(self, *args, **kwargs):
+        pass
+
+    def delete_validation_reference(self, *args, **kwargs):
+        pass
+
+    def get_user_metadata(self, *args, **kwargs):
+        pass
+
+    def proto(self, *args, **kwargs):
+        pass
+
+    def update_infra(self, *args, **kwargs):
+        pass
+
+
+@pytest.fixture
+def registry():
+    """Fixture to create a real instance of CachingRegistry"""
+    return TestCachingRegistry(
+        project="test_example", cache_ttl_seconds=2, cache_mode="sync"
+    )
+
+
+def test_cache_expiry_triggers_refresh(registry):
+    """Test that an expired cache triggers a refresh that reloads the proto"""
+    # Set cache creation time to a value that is expired
+    registry.cached_registry_proto = "some_cached_data"
+    registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
+        seconds=5
+    )
+
+    # Make proto() return a recognizable value so a real reload is observable
+    with patch.object(TestCachingRegistry, "proto", return_value="refreshed_data"):
+        with patch.object(
+            CachingRegistry, "refresh", wraps=registry.refresh
+        ) as mock_refresh:
+            registry._refresh_cached_registry_if_necessary()
+
+    # refresh() runs while the sync path holds _refresh_lock, so it must not
+    # bail out on a held lock: the cached proto has to be replaced
+    mock_refresh.assert_called_once()
+    assert registry.cached_registry_proto == "refreshed_data"
+    # The lock must be released again once the check has finished
+    assert not registry._refresh_lock.locked()
+
+
+def test_skip_refresh_if_lock_held(registry):
+    """Test that refresh is skipped if the lock is already held by another thread"""
+    registry.cached_registry_proto = "some_cached_data"
+    registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
+        seconds=5
+    )
+
+    # Acquire the lock manually to simulate another thread holding it
+    registry._refresh_lock.acquire()
+    try:
+        with patch.object(
+            CachingRegistry, "refresh", wraps=registry.refresh
+        ) as mock_refresh:
+            registry._refresh_cached_registry_if_necessary()
+
+            # Since the lock was already held, refresh should NOT be called
+            mock_refresh.assert_not_called()
+            # The skipped call must not release a lock it never acquired
+            assert registry._refresh_lock.locked()
+    finally:
+        # Release even if an assertion fails so the lock does not stay held
+        registry._refresh_lock.release()