From e96f7d6681ffb3fdda0ed8e1dcc9eeb8eecc1a8d Mon Sep 17 00:00:00 2001
From: Shizoqua
Date: Sat, 3 Jan 2026 00:13:03 +0100
Subject: [PATCH 1/6] Fix optional-dependency errors in image utils
Signed-off-by: Shizoqua
Signed-off-by: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
---
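Note: the image_utils hunks below add the same fail-fast guard to both helpers so a missing Pillow surfaces as a clear ImportError instead of a NameError deeper in the call. A minimal sketch of the pattern, assuming the module tracks availability in a module-level flag (the flag name `_image_dependencies_available` and the error text are what the new unit test monkeypatches and matches; the rest is illustrative, not the verbatim feast module):

    try:
        from PIL import Image  # Pillow is an optional extra
        _image_dependencies_available = True
    except ImportError:
        Image = None  # type: ignore[assignment]
        _image_dependencies_available = False

    def _check_image_dependencies() -> None:
        # Fail fast with an actionable message before Image is touched.
        if not _image_dependencies_available:
            raise ImportError("Image processing dependencies are not installed")

    def validate_image_format(image_bytes: bytes) -> bool:
        _check_image_dependencies()
        ...  # real implementation opens the image via PIL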
sdk/python/feast/image_utils.py | 2 +
sdk/python/pytest.ini | 4 -
sdk/python/tests/conftest.py | 247 ++++++++++--------
.../unit/test_image_utils_optional_deps.py | 33 +++
4 files changed, 179 insertions(+), 107 deletions(-)
create mode 100644 sdk/python/tests/unit/test_image_utils_optional_deps.py
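The conftest.py changes follow the same theme: imports that pull in integration-only extras are wrapped in try/except ModuleNotFoundError and degrade to inert stubs, so plain unit-test collection survives their absence, while pytest_generate_tests becomes an early-return guard that skips environment parametrization outright when the extras are missing. An abridged sketch of the shape (one import shown; the diff wraps several):

    import pytest

    try:
        from tests.integration.feature_repos.repo_configuration import (
            AVAILABLE_ONLINE_STORES,
        )

        _integration_test_deps_available = True
    except ModuleNotFoundError:
        _integration_test_deps_available = False
        AVAILABLE_ONLINE_STORES = {}  # type: ignore[assignment]

    def pytest_generate_tests(metafunc):
        if "environment" not in metafunc.fixturenames:
            return
        if not _integration_test_deps_available:
            pytest.skip("Integration test dependencies are not available")
        # ... parametrize the "environment" fixture as before ...

The pytest.ini `env = IS_TEST=True` block is dropped in favor of `os.environ.setdefault("IS_TEST", "True")` in conftest.py, presumably so the pytest-env plugin is no longer a hard requirement.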
diff --git a/sdk/python/feast/image_utils.py b/sdk/python/feast/image_utils.py
index 88622284e93..6c540473cd4 100644
--- a/sdk/python/feast/image_utils.py
+++ b/sdk/python/feast/image_utils.py
@@ -241,6 +241,7 @@ def validate_image_format(image_bytes: bytes) -> bool:
Returns:
True if valid image, False otherwise
"""
+ _check_image_dependencies()
try:
with Image.open(io.BytesIO(image_bytes)) as img:
img.verify()
@@ -259,6 +260,7 @@ def get_image_metadata(image_bytes: bytes) -> dict:
Raises:
ValueError: If image cannot be processed
"""
+ _check_image_dependencies()
try:
with Image.open(io.BytesIO(image_bytes)) as img:
return {
diff --git a/sdk/python/pytest.ini b/sdk/python/pytest.ini
index d79459c0d0e..f5d5647d9ff 100644
--- a/sdk/python/pytest.ini
+++ b/sdk/python/pytest.ini
@@ -6,12 +6,8 @@ markers =
universal_online_stores: mark a test as using all online stores.
rbac_remote_integration_test: mark a integration test related to rbac and remote functionality.
-env =
- IS_TEST=True
-
filterwarnings =
error::_pytest.warning_types.PytestConfigWarning
- error::_pytest.warning_types.PytestUnhandledCoroutineWarning
ignore::DeprecationWarning:pyspark.sql.pandas.*:
ignore::DeprecationWarning:pyspark.sql.connect.*:
ignore::DeprecationWarning:httpx.*:
diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index a57846c7e2e..46beb19ceaa 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -36,28 +36,67 @@
create_document_dataset,
create_image_dataset,
)
-from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
- IntegrationTestRepoConfig,
-)
-from tests.integration.feature_repos.repo_configuration import ( # noqa: E402
- AVAILABLE_OFFLINE_STORES,
- AVAILABLE_ONLINE_STORES,
- OFFLINE_STORE_TO_PROVIDER_CONFIG,
- Environment,
- TestData,
- construct_test_environment,
- construct_universal_feature_views,
- construct_universal_test_data,
-)
-from tests.integration.feature_repos.universal.data_sources.file import ( # noqa: E402
- FileDataSourceCreator,
-)
-from tests.integration.feature_repos.universal.entities import ( # noqa: E402
- customer,
- driver,
- location,
-)
-from tests.utils.auth_permissions_util import default_store
+try:
+ from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
+ IntegrationTestRepoConfig,
+ )
+ from tests.integration.feature_repos.repo_configuration import ( # noqa: E402
+ AVAILABLE_OFFLINE_STORES,
+ AVAILABLE_ONLINE_STORES,
+ OFFLINE_STORE_TO_PROVIDER_CONFIG,
+ Environment,
+ TestData,
+ construct_test_environment,
+ construct_universal_feature_views,
+ construct_universal_test_data,
+ )
+ from tests.integration.feature_repos.universal.data_sources.file import ( # noqa: E402
+ FileDataSourceCreator,
+ )
+ from tests.integration.feature_repos.universal.entities import ( # noqa: E402
+ customer,
+ driver,
+ location,
+ )
+
+ _integration_test_deps_available = True
+except ModuleNotFoundError:
+ _integration_test_deps_available = False
+
+ IntegrationTestRepoConfig = None # type: ignore[assignment]
+ AVAILABLE_OFFLINE_STORES = [] # type: ignore[assignment]
+ AVAILABLE_ONLINE_STORES = {} # type: ignore[assignment]
+ OFFLINE_STORE_TO_PROVIDER_CONFIG = {} # type: ignore[assignment]
+ Environment = Any # type: ignore[assignment]
+ TestData = Any # type: ignore[assignment]
+
+ def construct_test_environment(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Integration test dependencies are not available")
+
+ def construct_universal_feature_views(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Integration test dependencies are not available")
+
+ def construct_universal_test_data(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Integration test dependencies are not available")
+
+ class FileDataSourceCreator: # type: ignore[no-redef]
+ pass
+
+ def customer(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Integration test dependencies are not available")
+
+ def driver(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Integration test dependencies are not available")
+
+ def location(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Integration test dependencies are not available")
+
+try:
+ from tests.utils.auth_permissions_util import default_store
+except ModuleNotFoundError:
+
+ def default_store(*args, **kwargs): # type: ignore[no-redef]
+ raise RuntimeError("Auth test dependencies are not available")
from tests.utils.http_server import check_port_open, free_port # noqa: E402
from tests.utils.ssl_certifcates_util import (
combine_trust_stores,
@@ -67,6 +106,8 @@
logger = logging.getLogger(__name__)
+os.environ.setdefault("IS_TEST", "True")
+
level = logging.INFO
logging.basicConfig(
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
@@ -85,7 +126,7 @@
def pytest_configure(config):
- if platform in ["darwin", "windows"]:
+ if platform == "darwin" or platform.startswith("win"):
multiprocessing.set_start_method("spawn", force=True)
else:
multiprocessing.set_start_method("fork")
@@ -239,92 +280,92 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
See more examples at https://docs.pytest.org/en/6.2.x/example/parametrize.html#paramexamples
- We also utilize indirect parametrization here. Since `environment` is a fixture,
- when we call metafunc.parametrize("environment", ..., indirect=True) we actually
- parametrizing this "environment" fixture and not the test itself.
- Moreover, by utilizing `_config_cache` we are able to share `environment` fixture between different tests.
- In order for pytest to group tests together (and share environment fixture)
- parameter should point to the same Python object (hence, we use _config_cache dict to store those objects).
"""
- if "environment" in metafunc.fixturenames:
- markers = {m.name: m for m in metafunc.definition.own_markers}
- offline_stores = None
- if "universal_offline_stores" in markers:
- # Offline stores can be explicitly requested
- if "only" in markers["universal_offline_stores"].kwargs:
- offline_stores = [
- OFFLINE_STORE_TO_PROVIDER_CONFIG.get(store_name)
- for store_name in markers["universal_offline_stores"].kwargs["only"]
- if store_name in OFFLINE_STORE_TO_PROVIDER_CONFIG
- ]
- else:
- offline_stores = AVAILABLE_OFFLINE_STORES
+ if "environment" not in metafunc.fixturenames:
+ return
+
+ if not _integration_test_deps_available:
+ pytest.skip("Integration test dependencies are not available")
+
+ markers = {m.name: m for m in metafunc.definition.iter_markers()}
+
+ offline_stores = None
+ if "universal_offline_stores" in markers:
+ # Offline stores can be explicitly requested
+ if "only" in markers["universal_offline_stores"].kwargs:
+ offline_stores = [
+ OFFLINE_STORE_TO_PROVIDER_CONFIG.get(store_name)
+ for store_name in markers["universal_offline_stores"].kwargs["only"]
+ if store_name in OFFLINE_STORE_TO_PROVIDER_CONFIG
+ ]
else:
- # default offline store for testing online store dimension
- offline_stores = [("local", FileDataSourceCreator)]
-
- online_stores = None
- if "universal_online_stores" in markers:
- # Online stores can be explicitly requested
- if "only" in markers["universal_online_stores"].kwargs:
- online_stores = [
- AVAILABLE_ONLINE_STORES.get(store_name)
- for store_name in markers["universal_online_stores"].kwargs["only"]
- if store_name in AVAILABLE_ONLINE_STORES
- ]
- else:
- online_stores = AVAILABLE_ONLINE_STORES.values()
-
- if online_stores is None:
- # No online stores requested -> setting the default or first available
+ offline_stores = AVAILABLE_OFFLINE_STORES
+ else:
+ # default offline store for testing online store dimension
+ offline_stores = [("local", FileDataSourceCreator)]
+
+ online_stores = None
+ if "universal_online_stores" in markers:
+ # Online stores can be explicitly requested
+ if "only" in markers["universal_online_stores"].kwargs:
online_stores = [
- AVAILABLE_ONLINE_STORES.get(
- "redis",
- AVAILABLE_ONLINE_STORES.get(
- "sqlite", next(iter(AVAILABLE_ONLINE_STORES.values()))
- ),
- )
+ AVAILABLE_ONLINE_STORES.get(store_name)
+ for store_name in markers["universal_online_stores"].kwargs["only"]
+ if store_name in AVAILABLE_ONLINE_STORES
]
-
- extra_dimensions: List[Dict[str, Any]] = [{}]
-
- if "python_server" in metafunc.fixturenames:
- extra_dimensions.extend([{"python_feature_server": True}])
-
- configs = []
- if offline_stores:
- for provider, offline_store_creator in offline_stores:
- for online_store, online_store_creator in online_stores:
- for dim in extra_dimensions:
- config = {
- "provider": provider,
- "offline_store_creator": offline_store_creator,
- "online_store": online_store,
- "online_store_creator": online_store_creator,
- **dim,
- }
-
- c = IntegrationTestRepoConfig(**config)
-
- if c not in _config_cache:
- marks = [
- pytest.mark.xdist_group(name=m)
- for m in c.offline_store_creator.xdist_groups()
- ]
- # Check if there are any test markers associated with the creator and add them.
- if c.offline_store_creator.test_markers():
- marks.extend(c.offline_store_creator.test_markers())
-
- _config_cache[c] = pytest.param(c, marks=marks)
-
- configs.append(_config_cache[c])
else:
- # No offline stores requested -> setting the default or first available
- offline_stores = [("local", FileDataSourceCreator)]
+ online_stores = AVAILABLE_ONLINE_STORES.values()
- metafunc.parametrize(
- "environment", configs, indirect=True, ids=[str(c) for c in configs]
- )
+ if online_stores is None:
+ # No online stores requested -> setting the default or first available
+ online_stores = [
+ AVAILABLE_ONLINE_STORES.get(
+ "redis",
+ AVAILABLE_ONLINE_STORES.get(
+ "sqlite", next(iter(AVAILABLE_ONLINE_STORES.values()))
+ ),
+ )
+ ]
+
+ extra_dimensions: List[Dict[str, Any]] = [{}]
+
+ if "python_server" in metafunc.fixturenames:
+ extra_dimensions.extend([{"python_feature_server": True}])
+
+ configs = []
+ if offline_stores:
+ for provider, offline_store_creator in offline_stores:
+ for online_store, online_store_creator in online_stores:
+ for dim in extra_dimensions:
+ config = {
+ "provider": provider,
+ "offline_store_creator": offline_store_creator,
+ "online_store": online_store,
+ "online_store_creator": online_store_creator,
+ **dim,
+ }
+
+ c = IntegrationTestRepoConfig(**config)
+
+ if c not in _config_cache:
+ marks = [
+ pytest.mark.xdist_group(name=m)
+ for m in c.offline_store_creator.xdist_groups()
+ ]
+ # Check if there are any test markers associated with the creator and add them.
+ if c.offline_store_creator.test_markers():
+ marks.extend(c.offline_store_creator.test_markers())
+
+ _config_cache[c] = pytest.param(c, marks=marks)
+
+ configs.append(_config_cache[c])
+ else:
+ # No offline stores requested -> setting the default or first available
+ offline_stores = [("local", FileDataSourceCreator)]
+
+ metafunc.parametrize(
+ "environment", configs, indirect=True, ids=[str(c) for c in configs]
+ )
@pytest.fixture
diff --git a/sdk/python/tests/unit/test_image_utils_optional_deps.py b/sdk/python/tests/unit/test_image_utils_optional_deps.py
new file mode 100644
index 00000000000..3821af52845
--- /dev/null
+++ b/sdk/python/tests/unit/test_image_utils_optional_deps.py
@@ -0,0 +1,33 @@
+# Copyright 2024 The Feast Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+
+def test_validate_image_format_raises_when_deps_missing(monkeypatch):
+ from feast import image_utils
+
+ monkeypatch.setattr(image_utils, "_image_dependencies_available", False)
+
+ with pytest.raises(ImportError, match="Image processing dependencies are not installed"):
+ image_utils.validate_image_format(b"anything")
+
+
+def test_get_image_metadata_raises_when_deps_missing(monkeypatch):
+ from feast import image_utils
+
+ monkeypatch.setattr(image_utils, "_image_dependencies_available", False)
+
+ with pytest.raises(ImportError, match="Image processing dependencies are not installed"):
+ image_utils.get_image_metadata(b"anything")
From 13356c0d098909b17f5796714c3d9cea158f1f1e Mon Sep 17 00:00:00 2001
From: Shizoqua
Date: Sat, 3 Jan 2026 17:14:26 +0100
Subject: [PATCH 2/6] tests: clarify platform detection and marker handling in
conftest
Signed-off-by: Shizoqua
Signed-off-by: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
---
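Context for the two tweaks: on macOS `sys.platform` is "darwin" and on Windows it is "win32"; the literal "windows" never occurs, so the old membership test `platform in ["darwin", "windows"]` silently fell through to "fork" on Windows. The prefix check fixes that:

    import multiprocessing
    import sys

    # "windows" is not a sys.platform value; startswith("win") matches "win32".
    if sys.platform == "darwin" or sys.platform.startswith("win"):
        multiprocessing.set_start_method("spawn", force=True)
    else:
        multiprocessing.set_start_method("fork")

For the markers: `own_markers` holds only markers applied directly to the test item, while `iter_markers()` also walks class- and module-level markers, so preferring `own_markers` when the attribute exists restores the narrower pre-series behavior while keeping `iter_markers()` as a fallback.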
sdk/python/tests/conftest.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index 46beb19ceaa..280dc6ea2cc 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -18,7 +18,7 @@
import tempfile
from datetime import timedelta
from multiprocessing import Process
-from sys import platform
+import sys
from textwrap import dedent
from typing import Any, Dict, List, Tuple, no_type_check
from unittest import mock
@@ -126,7 +126,7 @@ def default_store(*args, **kwargs): # type: ignore[no-redef]
def pytest_configure(config):
- if platform == "darwin" or platform.startswith("win"):
+ if sys.platform == "darwin" or sys.platform.startswith("win"):
multiprocessing.set_start_method("spawn", force=True)
else:
multiprocessing.set_start_method("fork")
@@ -287,7 +287,9 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
if not _integration_test_deps_available:
pytest.skip("Integration test dependencies are not available")
- markers = {m.name: m for m in metafunc.definition.iter_markers()}
+ own_markers = getattr(metafunc.definition, "own_markers", None)
+ marker_iter = own_markers if own_markers is not None else metafunc.definition.iter_markers()
+ markers = {m.name: m for m in marker_iter}
offline_stores = None
if "universal_offline_stores" in markers:
From 1f16439df37ec676433b7befa3335ee75f1e140d Mon Sep 17 00:00:00 2001
From: Shizoqua
Date: Fri, 9 Jan 2026 02:48:27 +0100
Subject: [PATCH 3/6] fix: avoid creating empty Milvus indexes
Signed-off-by: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
Signed-off-by: Shizoqua
---
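Context: judging from the hunk below and the subject, `create_index` previously ran unconditionally after the field loop, so a schema with no indexed vector fields handed Milvus an empty `index_params`, which creates nothing useful and can error outright. The fix tracks whether any index was registered and only then issues the call. A fragment showing the shape (the predicate name is illustrative; `client`, `schema`, and `collection_name` mirror the surrounding method):

    index_params = client.prepare_index_params()
    added_index = False
    for field in schema.fields:
        if is_indexed_vector_field(field):  # illustrative predicate
            index_params.add_index(field_name=field.name, index_type="FLAT")
            added_index = True
    if added_index:
        # Only call create_index when at least one index was registered above.
        client.create_index(
            collection_name=collection_name,
            index_params=index_params,
        )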
.../infra/online_stores/milvus_online_store/milvus.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
index 37c88850ee7..b5fb1a2b88c 100644
--- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
+++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -202,6 +202,7 @@ def _get_or_create_collection(
schema=schema,
)
index_params = self.client.prepare_index_params()
+ added_index = False
for vector_field in schema.fields:
if (
vector_field.dtype
@@ -222,10 +223,12 @@ def _get_or_create_collection(
index_name=f"vector_index_{vector_field.name}",
params={"nlist": config.online_store.nlist},
)
- self.client.create_index(
- collection_name=collection_name,
- index_params=index_params,
- )
+ added_index = True
+ if added_index:
+ self.client.create_index(
+ collection_name=collection_name,
+ index_params=index_params,
+ )
else:
self.client.load_collection(collection_name)
self._collections[collection_name] = self.client.describe_collection(
From 62fde30de07f26e8a533db6e6c92a5231081a9a1 Mon Sep 17 00:00:00 2001
From: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
Date: Fri, 9 Jan 2026 02:58:40 +0100
Subject: [PATCH 4/6] Delete
sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
Signed-off-by: Lanre Shittu <136805224+Shizoqua@users.noreply.github.com>
Signed-off-by: Shizoqua
---
.../milvus_online_store/milvus.py | 772 ------------------
1 file changed, 772 deletions(-)
delete mode 100644 sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
deleted file mode 100644
index b5fb1a2b88c..00000000000
--- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
+++ /dev/null
@@ -1,772 +0,0 @@
-import base64
-from datetime import datetime
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
-
-from pydantic import StrictStr
-from pymilvus import (
- CollectionSchema,
- DataType,
- FieldSchema,
- MilvusClient,
-)
-
-from feast import Entity
-from feast.feature_view import FeatureView
-from feast.infra.infra_object import InfraObject
-from feast.infra.key_encoding_utils import (
- deserialize_entity_key,
- serialize_entity_key,
-)
-from feast.infra.online_stores.online_store import OnlineStore
-from feast.infra.online_stores.vector_store import VectorStoreConfig
-from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
-from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
-from feast.protos.feast.types.Value_pb2 import Value as ValueProto
-from feast.repo_config import FeastConfigBaseModel, RepoConfig
-from feast.type_map import (
- PROTO_VALUE_TO_VALUE_TYPE_MAP,
- VALUE_TYPE_TO_PROTO_VALUE_MAP,
- feast_value_type_to_python_type,
-)
-from feast.types import (
- VALUE_TYPES_TO_FEAST_TYPES,
- Array,
- ComplexFeastType,
- PrimitiveFeastType,
- ValueType,
- from_feast_type,
-)
-from feast.utils import (
- _serialize_vector_to_float_list,
- to_naive_utc,
-)
-
-PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
- PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
- ValueType.IMAGE_BYTES: DataType.VARCHAR,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["double_val"]: DataType.DOUBLE,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_val"]: DataType.INT32,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_val"]: DataType.INT64,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["float_list_val"]: DataType.FLOAT_VECTOR,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_list_val"]: DataType.FLOAT_VECTOR,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_list_val"]: DataType.FLOAT_VECTOR,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["double_list_val"]: DataType.FLOAT_VECTOR,
- PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_list_val"]: DataType.BINARY_VECTOR,
-}
-
-FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING: Dict[
- Union[PrimitiveFeastType, Array, ComplexFeastType], DataType
-] = {}
-
-for value_type, feast_type in VALUE_TYPES_TO_FEAST_TYPES.items():
- if isinstance(feast_type, PrimitiveFeastType):
- milvus_type = PROTO_TO_MILVUS_TYPE_MAPPING.get(value_type)
- if milvus_type:
- FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = milvus_type
- elif isinstance(feast_type, Array):
- base_type = feast_type.base_type
- base_value_type = base_type.to_value_type()
- if base_value_type in [
- ValueType.INT32,
- ValueType.INT64,
- ValueType.FLOAT,
- ValueType.DOUBLE,
- ]:
- FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.FLOAT_VECTOR
- elif base_value_type == ValueType.STRING:
- FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.VARCHAR
- elif base_value_type == ValueType.BOOL:
- FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.BINARY_VECTOR
-
-
-class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
- """
- Configuration for the Milvus online store.
- NOTE: The class *must* end with the `OnlineStoreConfig` suffix.
- """
-
- type: Literal["milvus"] = "milvus"
- path: Optional[StrictStr] = ""
- host: Optional[StrictStr] = "http://localhost"
- port: Optional[int] = 19530
- index_type: Optional[str] = "FLAT"
- metric_type: Optional[str] = "COSINE"
- embedding_dim: Optional[int] = 128
- vector_enabled: Optional[bool] = True
- text_search_enabled: Optional[bool] = False
- nlist: Optional[int] = 128
- username: Optional[StrictStr] = ""
- password: Optional[StrictStr] = ""
-
-
-class MilvusOnlineStore(OnlineStore):
- """
- Milvus implementation of the online store interface.
-
- Attributes:
- _collections: Dictionary to cache Milvus collections.
- """
-
- client: Optional[MilvusClient] = None
- _collections: Dict[str, Any] = {}
-
- def _get_db_path(self, config: RepoConfig) -> str:
- assert (
- config.online_store.type == "milvus"
- or config.online_store.type.endswith("MilvusOnlineStore")
- )
-
- if config.repo_path and not Path(config.online_store.path).is_absolute():
- db_path = str(config.repo_path / config.online_store.path)
- else:
- db_path = config.online_store.path
- return db_path
-
- def _connect(self, config: RepoConfig) -> MilvusClient:
- if not self.client:
- if config.provider == "local" and config.online_store.path:
- db_path = self._get_db_path(config)
- print(f"Connecting to Milvus in local mode using {db_path}")
- self.client = MilvusClient(db_path)
- else:
- print(
- f"Connecting to Milvus remotely at {config.online_store.host}:{config.online_store.port}"
- )
- self.client = MilvusClient(
- uri=f"{config.online_store.host}:{config.online_store.port}",
- token=f"{config.online_store.username}:{config.online_store.password}"
- if config.online_store.username and config.online_store.password
- else "",
- )
- return self.client
-
- def _get_or_create_collection(
- self, config: RepoConfig, table: FeatureView
- ) -> Dict[str, Any]:
- self.client = self._connect(config)
- vector_field_dict = {k.name: k for k in table.schema if k.vector_index}
- collection_name = _table_id(config.project, table)
- if collection_name not in self._collections:
- # Create a composite key by combining entity fields
- composite_key_name = _get_composite_key_name(table)
-
- fields = [
- FieldSchema(
- name=composite_key_name,
- dtype=DataType.VARCHAR,
- max_length=512,
- is_primary=True,
- ),
- FieldSchema(name="event_ts", dtype=DataType.INT64),
- FieldSchema(name="created_ts", dtype=DataType.INT64),
- ]
- fields_to_exclude = [
- "event_ts",
- "created_ts",
- ]
- fields_to_add = [f for f in table.schema if f.name not in fields_to_exclude]
- for field in fields_to_add:
- dtype = FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING.get(field.dtype)
- if dtype:
- if dtype == DataType.FLOAT_VECTOR:
- fields.append(
- FieldSchema(
- name=field.name,
- dtype=dtype,
- dim=config.online_store.embedding_dim,
- )
- )
- else:
- fields.append(
- FieldSchema(
- name=field.name,
- dtype=DataType.VARCHAR,
- max_length=512,
- )
- )
-
- schema = CollectionSchema(
- fields=fields, description="Feast feature view data"
- )
- collection_exists = self.client.has_collection(
- collection_name=collection_name
- )
- if not collection_exists:
- self.client.create_collection(
- collection_name=collection_name,
- dimension=config.online_store.embedding_dim,
- schema=schema,
- )
- index_params = self.client.prepare_index_params()
- added_index = False
- for vector_field in schema.fields:
- if (
- vector_field.dtype
- in [
- DataType.FLOAT_VECTOR,
- DataType.BINARY_VECTOR,
- ]
- and vector_field.name in vector_field_dict
- ):
- metric = vector_field_dict[
- vector_field.name
- ].vector_search_metric
- index_params.add_index(
- collection_name=collection_name,
- field_name=vector_field.name,
- metric_type=metric or config.online_store.metric_type,
- index_type=config.online_store.index_type,
- index_name=f"vector_index_{vector_field.name}",
- params={"nlist": config.online_store.nlist},
- )
- added_index = True
- if added_index:
- self.client.create_index(
- collection_name=collection_name,
- index_params=index_params,
- )
- else:
- self.client.load_collection(collection_name)
- self._collections[collection_name] = self.client.describe_collection(
- collection_name
- )
- return self._collections[collection_name]
-
- def online_write_batch(
- self,
- config: RepoConfig,
- table: FeatureView,
- data: List[
- Tuple[
- EntityKeyProto,
- Dict[str, ValueProto],
- datetime,
- Optional[datetime],
- ]
- ],
- progress: Optional[Callable[[int], Any]],
- ) -> None:
- self.client = self._connect(config)
- collection = self._get_or_create_collection(config, table)
- vector_cols = [f.name for f in table.features if f.vector_index]
- entity_batch_to_insert = []
- unique_entities: dict[str, dict[str, Any]] = {}
- required_fields = {field["name"] for field in collection["fields"]}
- for entity_key, values_dict, timestamp, created_ts in data:
- # need to construct the composite primary key also need to handle the fact that entities are a list
- entity_key_str = serialize_entity_key(
- entity_key,
- entity_key_serialization_version=config.entity_key_serialization_version,
- ).hex()
- # to recover the entity key just run:
- # deserialize_entity_key(bytes.fromhex(entity_key_str), entity_key_serialization_version=3)
- composite_key_name = _get_composite_key_name(table)
-
- timestamp_int = int(to_naive_utc(timestamp).timestamp() * 1e6)
- created_ts_int = (
- int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
- )
- entity_dict = {
- join_key: feast_value_type_to_python_type(value)
- for join_key, value in zip(
- entity_key.join_keys, entity_key.entity_values
- )
- }
- values_dict.update(entity_dict)
- values_dict = _extract_proto_values_to_dict(
- values_dict,
- vector_cols=vector_cols,
- serialize_to_string=True,
- )
-
- single_entity_record = {
- composite_key_name: entity_key_str,
- "event_ts": timestamp_int,
- "created_ts": created_ts_int,
- }
- single_entity_record.update(values_dict)
- # Ensure all required fields exist, setting missing ones to empty strings
- for field in required_fields:
- if field not in single_entity_record:
- single_entity_record[field] = ""
- # Store only the latest event timestamp per entity
- if (
- entity_key_str not in unique_entities
- or unique_entities[entity_key_str]["event_ts"] < timestamp_int
- ):
- unique_entities[entity_key_str] = single_entity_record
-
- if progress:
- progress(1)
-
- entity_batch_to_insert = list(unique_entities.values())
- self.client.upsert(
- collection_name=collection["collection_name"],
- data=entity_batch_to_insert,
- )
-
- def online_read(
- self,
- config: RepoConfig,
- table: FeatureView,
- entity_keys: List[EntityKeyProto],
- requested_features: Optional[List[str]] = None,
- full_feature_names: bool = False,
- ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
- self.client = self._connect(config)
- collection_name = _table_id(config.project, table)
- collection = self._get_or_create_collection(config, table)
-
- composite_key_name = _get_composite_key_name(table)
-
- output_fields = (
- [composite_key_name]
- + (requested_features if requested_features else [])
- + ["created_ts", "event_ts"]
- )
- assert all(
- field in [f["name"] for f in collection["fields"]]
- for field in output_fields
- ), (
- f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"
- )
- composite_entities = []
- for entity_key in entity_keys:
- entity_key_str = serialize_entity_key(
- entity_key,
- entity_key_serialization_version=config.entity_key_serialization_version,
- ).hex()
- composite_entities.append(entity_key_str)
-
- query_filter_for_entities = (
- f"{composite_key_name} in ["
- + ", ".join([f"'{e}'" for e in composite_entities])
- + "]"
- )
- self.client.load_collection(collection_name)
- results = self.client.query(
- collection_name=collection_name,
- filter=query_filter_for_entities,
- output_fields=output_fields,
- )
- # Group hits by composite key.
- grouped_hits: Dict[str, Any] = {}
- for hit in results:
- key = hit.get(composite_key_name)
- grouped_hits.setdefault(key, []).append(hit)
-
- # Map the features to their Feast types.
- feature_name_feast_primitive_type_map = {
- f.name: f.dtype for f in table.features
- }
- if getattr(table, "write_to_online_store", False):
- feature_name_feast_primitive_type_map.update(
- {f.name: f.dtype for f in table.schema}
- )
- # Build a dictionary mapping composite key -> (res_ts, res)
- results_dict: Dict[
- str, Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
- ] = {}
-
- # here we need to map the data stored as characters back into the protobuf value
- for hit in results:
- key = hit.get(composite_key_name)
- # Only take one hit per composite key (adjust if you need aggregation)
- if key not in results_dict:
- res = {}
- res_ts = None
- for field in output_fields:
- val = ValueProto()
- field_value = hit.get(field, None)
- if field_value is None and ":" in field:
- _, field_short = field.split(":", 1)
- field_value = hit.get(field_short)
-
- if field in ["created_ts", "event_ts"]:
- res_ts = datetime.fromtimestamp(field_value / 1e6)
- elif field == composite_key_name:
- # We do not return the composite key value
- pass
- else:
- feature_feast_primitive_type = (
- feature_name_feast_primitive_type_map.get(
- field, PrimitiveFeastType.INVALID
- )
- )
- feature_fv_dtype = from_feast_type(feature_feast_primitive_type)
- proto_attr = VALUE_TYPE_TO_PROTO_VALUE_MAP.get(feature_fv_dtype)
- if proto_attr:
- if proto_attr == "bytes_val":
- setattr(val, proto_attr, field_value.encode())
- elif proto_attr in [
- "int32_val",
- "int64_val",
- "float_val",
- "double_val",
- "string_val",
- ]:
- setattr(
- val,
- proto_attr,
- type(getattr(val, proto_attr))(field_value),
- )
- elif proto_attr in [
- "int32_list_val",
- "int64_list_val",
- "float_list_val",
- "double_list_val",
- ]:
- getattr(val, proto_attr).val.extend(field_value)
- else:
- setattr(val, proto_attr, field_value)
- else:
- raise ValueError(
- f"Unsupported ValueType: {feature_feast_primitive_type} with feature view value {field_value} for feature {field} with value type {proto_attr}"
- )
- # res[field] = val
- key_to_use = field.split(":", 1)[-1] if ":" in field else field
- res[key_to_use] = val
- results_dict[key] = (res_ts, res if res else None)
-
- # Map the results back into a list matching the original order of composite_keys.
- result_list = [
- results_dict.get(key, (None, None)) for key in composite_entities
- ]
-
- return result_list
-
- def update(
- self,
- config: RepoConfig,
- tables_to_delete: Sequence[FeatureView],
- tables_to_keep: Sequence[FeatureView],
- entities_to_delete: Sequence[Entity],
- entities_to_keep: Sequence[Entity],
- partial: bool,
- ):
- self.client = self._connect(config)
- for table in tables_to_keep:
- self._collections = self._get_or_create_collection(config, table)
-
- for table in tables_to_delete:
- collection_name = _table_id(config.project, table)
- if self._collections.get(collection_name, None):
- self.client.drop_collection(collection_name)
- self._collections.pop(collection_name, None)
-
- def plan(
- self, config: RepoConfig, desired_registry_proto: RegistryProto
- ) -> List[InfraObject]:
- raise NotImplementedError
-
- def teardown(
- self,
- config: RepoConfig,
- tables: Sequence[FeatureView],
- entities: Sequence[Entity],
- ):
- self.client = self._connect(config)
- for table in tables:
- collection_name = _table_id(config.project, table)
- if self._collections.get(collection_name, None):
- self.client.drop_collection(collection_name)
- self._collections.pop(collection_name, None)
-
- def retrieve_online_documents_v2(
- self,
- config: RepoConfig,
- table: FeatureView,
- requested_features: List[str],
- embedding: Optional[List[float]],
- top_k: int,
- distance_metric: Optional[str] = None,
- query_string: Optional[str] = None,
- ) -> List[
- Tuple[
- Optional[datetime],
- Optional[EntityKeyProto],
- Optional[Dict[str, ValueProto]],
- ]
- ]:
- """
- Retrieve documents using vector similarity search or keyword search in Milvus.
- Args:
- config: Feast configuration object
- table: FeatureView object as the table to search
- requested_features: List of requested features to retrieve
- embedding: Query embedding to search for (optional)
- top_k: Number of items to return
- distance_metric: Distance metric to use (optional)
- query_string: The query string to search for using keyword search (optional)
- Returns:
- List of tuples containing the event timestamp, entity key, and feature values
- """
- entity_name_feast_primitive_type_map = {
- k.name: k.dtype for k in table.entity_columns
- }
- self.client = self._connect(config)
- collection_name = _table_id(config.project, table)
- collection = self._get_or_create_collection(config, table)
- if not config.online_store.vector_enabled:
- raise ValueError("Vector search is not enabled in the online store config")
-
- if embedding is None and query_string is None:
- raise ValueError("Either embedding or query_string must be provided")
-
- composite_key_name = _get_composite_key_name(table)
-
- output_fields = (
- [composite_key_name]
- + (requested_features if requested_features else [])
- + ["created_ts", "event_ts"]
- )
- assert all(
- field in [f["name"] for f in collection["fields"]]
- for field in output_fields
- ), (
- f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"
- )
-
- # Find the vector search field if we need it
- ann_search_field = None
- if embedding is not None:
- for field in collection["fields"]:
- if (
- field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
- and field["name"] in output_fields
- ):
- ann_search_field = field["name"]
- break
-
- self.client.load_collection(collection_name)
-
- if (
- embedding is not None
- and query_string is not None
- and config.online_store.vector_enabled
- ):
- string_field_list = [
- f.name
- for f in table.features
- if isinstance(f.dtype, PrimitiveFeastType)
- and f.dtype.to_value_type() == ValueType.STRING
- ]
-
- if not string_field_list:
- raise ValueError(
- "No string fields found in the feature view for text search in hybrid mode"
- )
-
- # Create a filter expression for text search
- filter_expressions = []
- for field in string_field_list:
- if field in output_fields:
- filter_expressions.append(f"{field} LIKE '%{query_string}%'")
-
- # Combine filter expressions with OR
- filter_expr = " OR ".join(filter_expressions) if filter_expressions else ""
-
- # Vector search with text filter
- search_params = {
- "metric_type": distance_metric or config.online_store.metric_type,
- "params": {"nprobe": 10},
- }
-
- # For hybrid search, use filter parameter instead of expr
- results = self.client.search(
- collection_name=collection_name,
- data=[embedding],
- anns_field=ann_search_field,
- search_params=search_params,
- limit=top_k,
- output_fields=output_fields,
- filter=filter_expr if filter_expr else None,
- )
-
- elif embedding is not None and config.online_store.vector_enabled:
- # Vector search only
- search_params = {
- "metric_type": distance_metric or config.online_store.metric_type,
- "params": {"nprobe": 10},
- }
-
- results = self.client.search(
- collection_name=collection_name,
- data=[embedding],
- anns_field=ann_search_field,
- search_params=search_params,
- limit=top_k,
- output_fields=output_fields,
- )
-
- elif query_string is not None:
- string_field_list = [
- f.name
- for f in table.features
- if isinstance(f.dtype, PrimitiveFeastType)
- and f.dtype.to_value_type() == ValueType.STRING
- ]
-
- if not string_field_list:
- raise ValueError(
- "No string fields found in the feature view for text search"
- )
-
- filter_expressions = []
- for field in string_field_list:
- if field in output_fields:
- filter_expressions.append(f"{field} LIKE '%{query_string}%'")
-
- filter_expr = " OR ".join(filter_expressions)
-
- if not filter_expr:
- raise ValueError(
- "No text fields found in requested features for search"
- )
-
- query_results = self.client.query(
- collection_name=collection_name,
- filter=filter_expr,
- output_fields=output_fields,
- limit=top_k,
- )
-
- results = [
- [{"entity": entity, "distance": -1.0}] for entity in query_results
- ]
- else:
- raise ValueError(
- "Either vector_enabled must be True for embedding search or query_string must be provided for keyword search"
- )
-
- result_list = []
- for hits in results:
- for hit in hits:
- res = {}
- res_ts = None
- entity_key_bytes = bytes.fromhex(
- hit.get("entity", {}).get(composite_key_name, None)
- )
- entity_key_proto = (
- deserialize_entity_key(entity_key_bytes)
- if entity_key_bytes
- else None
- )
- for field in output_fields:
- val = ValueProto()
- field_value = hit.get("entity", {}).get(field, None)
- # entity_key_proto = None
- if field in ["created_ts", "event_ts"]:
- res_ts = datetime.fromtimestamp(field_value / 1e6)
- elif field == ann_search_field and embedding is not None:
- serialized_embedding = _serialize_vector_to_float_list(
- embedding
- )
- res[ann_search_field] = serialized_embedding
- elif (
- entity_name_feast_primitive_type_map.get(
- field, PrimitiveFeastType.INVALID
- )
- == PrimitiveFeastType.STRING
- ):
- res[field] = ValueProto(string_val=str(field_value))
- elif (
- entity_name_feast_primitive_type_map.get(
- field, PrimitiveFeastType.INVALID
- )
- == PrimitiveFeastType.BYTES
- ):
- try:
- decoded_bytes = base64.b64decode(field_value)
- res[field] = ValueProto(bytes_val=decoded_bytes)
- except Exception:
- res[field] = ValueProto(string_val=str(field_value))
- elif entity_name_feast_primitive_type_map.get(
- field, PrimitiveFeastType.INVALID
- ) in [
- PrimitiveFeastType.INT64,
- PrimitiveFeastType.INT32,
- ]:
- res[field] = ValueProto(int64_val=int(field_value))
- elif field == composite_key_name:
- pass
- elif isinstance(field_value, bytes):
- val.ParseFromString(field_value)
- res[field] = val
- else:
- val.string_val = field_value
- res[field] = val
- distance = hit.get("distance", None)
- res["distance"] = (
- ValueProto(float_val=distance) if distance else ValueProto()
- )
- result_list.append((res_ts, entity_key_proto, res if res else None))
- return result_list
-
-
-def _table_id(project: str, table: FeatureView) -> str:
- return f"{project}_{table.name}"
-
-
-def _get_composite_key_name(table: FeatureView) -> str:
- return "_".join([field.name for field in table.entity_columns]) + "_pk"
-
-
-def _extract_proto_values_to_dict(
- input_dict: Dict[str, Any],
- vector_cols: List[str],
- serialize_to_string=False,
-) -> Dict[str, Any]:
- numeric_vector_list_types = [
- k
- for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys()
- if k is not None and "list" in k and "string" not in k
- ]
- numeric_types = [
- "double_val",
- "float_val",
- "int32_val",
- "int64_val",
- "bool_val",
- ]
- output_dict = {}
- for feature_name, feature_values in input_dict.items():
- for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
- if not isinstance(feature_values, (int, float, str)):
- if feature_values.HasField(proto_val_type):
- if proto_val_type in numeric_vector_list_types:
- if serialize_to_string and feature_name not in vector_cols:
- vector_values = getattr(
- feature_values, proto_val_type
- ).SerializeToString()
- else:
- vector_values = getattr(feature_values, proto_val_type).val
- else:
- if (
- serialize_to_string
- and proto_val_type
- not in ["string_val", "bytes_val"] + numeric_types
- ):
- vector_values = feature_values.SerializeToString().decode()
- elif proto_val_type == "bytes_val":
- byte_data = getattr(feature_values, proto_val_type)
- vector_values = base64.b64encode(byte_data).decode("utf-8")
- else:
- if not isinstance(feature_values, str):
- vector_values = str(
- getattr(feature_values, proto_val_type)
- )
- else:
- vector_values = getattr(feature_values, proto_val_type)
- output_dict[feature_name] = vector_values
- else:
- if serialize_to_string:
- if not isinstance(feature_values, str):
- feature_values = str(feature_values)
- output_dict[feature_name] = feature_values
-
- return output_dict
From 2b2a4f23b10f8129343d32f21c922c79acf3d561 Mon Sep 17 00:00:00 2001
From: Shizoqua
Date: Sat, 10 Jan 2026 04:58:23 +0100
Subject: [PATCH 5/6] Restore Milvus online store file
Signed-off-by: Shizoqua
---
.../milvus_online_store/milvus.py | 769 ++++++++++++++++++
1 file changed, 769 insertions(+)
create mode 100644 sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
new file mode 100644
index 00000000000..37c88850ee7
--- /dev/null
+++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -0,0 +1,769 @@
+import base64
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+from pydantic import StrictStr
+from pymilvus import (
+ CollectionSchema,
+ DataType,
+ FieldSchema,
+ MilvusClient,
+)
+
+from feast import Entity
+from feast.feature_view import FeatureView
+from feast.infra.infra_object import InfraObject
+from feast.infra.key_encoding_utils import (
+ deserialize_entity_key,
+ serialize_entity_key,
+)
+from feast.infra.online_stores.online_store import OnlineStore
+from feast.infra.online_stores.vector_store import VectorStoreConfig
+from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
+from feast.type_map import (
+ PROTO_VALUE_TO_VALUE_TYPE_MAP,
+ VALUE_TYPE_TO_PROTO_VALUE_MAP,
+ feast_value_type_to_python_type,
+)
+from feast.types import (
+ VALUE_TYPES_TO_FEAST_TYPES,
+ Array,
+ ComplexFeastType,
+ PrimitiveFeastType,
+ ValueType,
+ from_feast_type,
+)
+from feast.utils import (
+ _serialize_vector_to_float_list,
+ to_naive_utc,
+)
+
+PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
+ ValueType.IMAGE_BYTES: DataType.VARCHAR,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["double_val"]: DataType.DOUBLE,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_val"]: DataType.INT32,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_val"]: DataType.INT64,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["float_list_val"]: DataType.FLOAT_VECTOR,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_list_val"]: DataType.FLOAT_VECTOR,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_list_val"]: DataType.FLOAT_VECTOR,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["double_list_val"]: DataType.FLOAT_VECTOR,
+ PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_list_val"]: DataType.BINARY_VECTOR,
+}
+
+FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING: Dict[
+ Union[PrimitiveFeastType, Array, ComplexFeastType], DataType
+] = {}
+
+for value_type, feast_type in VALUE_TYPES_TO_FEAST_TYPES.items():
+ if isinstance(feast_type, PrimitiveFeastType):
+ milvus_type = PROTO_TO_MILVUS_TYPE_MAPPING.get(value_type)
+ if milvus_type:
+ FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = milvus_type
+ elif isinstance(feast_type, Array):
+ base_type = feast_type.base_type
+ base_value_type = base_type.to_value_type()
+ if base_value_type in [
+ ValueType.INT32,
+ ValueType.INT64,
+ ValueType.FLOAT,
+ ValueType.DOUBLE,
+ ]:
+ FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.FLOAT_VECTOR
+ elif base_value_type == ValueType.STRING:
+ FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.VARCHAR
+ elif base_value_type == ValueType.BOOL:
+ FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.BINARY_VECTOR
+
+
+class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
+ """
+ Configuration for the Milvus online store.
+ NOTE: The class *must* end with the `OnlineStoreConfig` suffix.
+ """
+
+ type: Literal["milvus"] = "milvus"
+ path: Optional[StrictStr] = ""
+ host: Optional[StrictStr] = "http://localhost"
+ port: Optional[int] = 19530
+ index_type: Optional[str] = "FLAT"
+ metric_type: Optional[str] = "COSINE"
+ embedding_dim: Optional[int] = 128
+ vector_enabled: Optional[bool] = True
+ text_search_enabled: Optional[bool] = False
+ nlist: Optional[int] = 128
+ username: Optional[StrictStr] = ""
+ password: Optional[StrictStr] = ""
+
+
+class MilvusOnlineStore(OnlineStore):
+ """
+ Milvus implementation of the online store interface.
+
+ Attributes:
+ _collections: Dictionary to cache Milvus collections.
+ """
+
+ client: Optional[MilvusClient] = None
+ _collections: Dict[str, Any] = {}
+
+ def _get_db_path(self, config: RepoConfig) -> str:
+ assert (
+ config.online_store.type == "milvus"
+ or config.online_store.type.endswith("MilvusOnlineStore")
+ )
+
+ if config.repo_path and not Path(config.online_store.path).is_absolute():
+ db_path = str(config.repo_path / config.online_store.path)
+ else:
+ db_path = config.online_store.path
+ return db_path
+
+ def _connect(self, config: RepoConfig) -> MilvusClient:
+ if not self.client:
+ if config.provider == "local" and config.online_store.path:
+ db_path = self._get_db_path(config)
+ print(f"Connecting to Milvus in local mode using {db_path}")
+ self.client = MilvusClient(db_path)
+ else:
+ print(
+ f"Connecting to Milvus remotely at {config.online_store.host}:{config.online_store.port}"
+ )
+ self.client = MilvusClient(
+ uri=f"{config.online_store.host}:{config.online_store.port}",
+ token=f"{config.online_store.username}:{config.online_store.password}"
+ if config.online_store.username and config.online_store.password
+ else "",
+ )
+ return self.client
+
+ def _get_or_create_collection(
+ self, config: RepoConfig, table: FeatureView
+ ) -> Dict[str, Any]:
+ self.client = self._connect(config)
+ vector_field_dict = {k.name: k for k in table.schema if k.vector_index}
+ collection_name = _table_id(config.project, table)
+ if collection_name not in self._collections:
+ # Create a composite key by combining entity fields
+ composite_key_name = _get_composite_key_name(table)
+
+ fields = [
+ FieldSchema(
+ name=composite_key_name,
+ dtype=DataType.VARCHAR,
+ max_length=512,
+ is_primary=True,
+ ),
+ FieldSchema(name="event_ts", dtype=DataType.INT64),
+ FieldSchema(name="created_ts", dtype=DataType.INT64),
+ ]
+ fields_to_exclude = [
+ "event_ts",
+ "created_ts",
+ ]
+ fields_to_add = [f for f in table.schema if f.name not in fields_to_exclude]
+ for field in fields_to_add:
+ dtype = FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING.get(field.dtype)
+ if dtype:
+ if dtype == DataType.FLOAT_VECTOR:
+ fields.append(
+ FieldSchema(
+ name=field.name,
+ dtype=dtype,
+ dim=config.online_store.embedding_dim,
+ )
+ )
+ else:
+ fields.append(
+ FieldSchema(
+ name=field.name,
+ dtype=DataType.VARCHAR,
+ max_length=512,
+ )
+ )
+
+ schema = CollectionSchema(
+ fields=fields, description="Feast feature view data"
+ )
+ collection_exists = self.client.has_collection(
+ collection_name=collection_name
+ )
+ if not collection_exists:
+ self.client.create_collection(
+ collection_name=collection_name,
+ dimension=config.online_store.embedding_dim,
+ schema=schema,
+ )
+ index_params = self.client.prepare_index_params()
+ for vector_field in schema.fields:
+ if (
+ vector_field.dtype
+ in [
+ DataType.FLOAT_VECTOR,
+ DataType.BINARY_VECTOR,
+ ]
+ and vector_field.name in vector_field_dict
+ ):
+ metric = vector_field_dict[
+ vector_field.name
+ ].vector_search_metric
+ index_params.add_index(
+ collection_name=collection_name,
+ field_name=vector_field.name,
+ metric_type=metric or config.online_store.metric_type,
+ index_type=config.online_store.index_type,
+ index_name=f"vector_index_{vector_field.name}",
+ params={"nlist": config.online_store.nlist},
+ )
+ self.client.create_index(
+ collection_name=collection_name,
+ index_params=index_params,
+ )
+ else:
+ self.client.load_collection(collection_name)
+ self._collections[collection_name] = self.client.describe_collection(
+ collection_name
+ )
+ return self._collections[collection_name]
+
+ def online_write_batch(
+ self,
+ config: RepoConfig,
+ table: FeatureView,
+ data: List[
+ Tuple[
+ EntityKeyProto,
+ Dict[str, ValueProto],
+ datetime,
+ Optional[datetime],
+ ]
+ ],
+ progress: Optional[Callable[[int], Any]],
+ ) -> None:
+ self.client = self._connect(config)
+ collection = self._get_or_create_collection(config, table)
+ vector_cols = [f.name for f in table.features if f.vector_index]
+ entity_batch_to_insert = []
+ unique_entities: dict[str, dict[str, Any]] = {}
+ required_fields = {field["name"] for field in collection["fields"]}
+ for entity_key, values_dict, timestamp, created_ts in data:
+ # need to construct the composite primary key also need to handle the fact that entities are a list
+ entity_key_str = serialize_entity_key(
+ entity_key,
+ entity_key_serialization_version=config.entity_key_serialization_version,
+ ).hex()
+ # to recover the entity key just run:
+ # deserialize_entity_key(bytes.fromhex(entity_key_str), entity_key_serialization_version=3)
+ composite_key_name = _get_composite_key_name(table)
+
+ timestamp_int = int(to_naive_utc(timestamp).timestamp() * 1e6)
+ created_ts_int = (
+ int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
+ )
+ entity_dict = {
+ join_key: feast_value_type_to_python_type(value)
+ for join_key, value in zip(
+ entity_key.join_keys, entity_key.entity_values
+ )
+ }
+ values_dict.update(entity_dict)
+ values_dict = _extract_proto_values_to_dict(
+ values_dict,
+ vector_cols=vector_cols,
+ serialize_to_string=True,
+ )
+
+ single_entity_record = {
+ composite_key_name: entity_key_str,
+ "event_ts": timestamp_int,
+ "created_ts": created_ts_int,
+ }
+ single_entity_record.update(values_dict)
+ # Ensure all required fields exist, setting missing ones to empty strings
+ for field in required_fields:
+ if field not in single_entity_record:
+ single_entity_record[field] = ""
+ # Store only the latest event timestamp per entity
+ if (
+ entity_key_str not in unique_entities
+ or unique_entities[entity_key_str]["event_ts"] < timestamp_int
+ ):
+ unique_entities[entity_key_str] = single_entity_record
+
+ if progress:
+ progress(1)
+
+ entity_batch_to_insert = list(unique_entities.values())
+ self.client.upsert(
+ collection_name=collection["collection_name"],
+ data=entity_batch_to_insert,
+ )
+
+ def online_read(
+ self,
+ config: RepoConfig,
+ table: FeatureView,
+ entity_keys: List[EntityKeyProto],
+ requested_features: Optional[List[str]] = None,
+ full_feature_names: bool = False,
+ ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+ self.client = self._connect(config)
+ collection_name = _table_id(config.project, table)
+ collection = self._get_or_create_collection(config, table)
+
+ composite_key_name = _get_composite_key_name(table)
+
+ output_fields = (
+ [composite_key_name]
+ + (requested_features if requested_features else [])
+ + ["created_ts", "event_ts"]
+ )
+ assert all(
+ field in [f["name"] for f in collection["fields"]]
+ for field in output_fields
+ ), (
+ f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"
+ )
+ composite_entities = []
+ for entity_key in entity_keys:
+ entity_key_str = serialize_entity_key(
+ entity_key,
+ entity_key_serialization_version=config.entity_key_serialization_version,
+ ).hex()
+ composite_entities.append(entity_key_str)
+
+ query_filter_for_entities = (
+ f"{composite_key_name} in ["
+ + ", ".join([f"'{e}'" for e in composite_entities])
+ + "]"
+ )
+ self.client.load_collection(collection_name)
+ results = self.client.query(
+ collection_name=collection_name,
+ filter=query_filter_for_entities,
+ output_fields=output_fields,
+ )
+ # Group hits by composite key.
+ grouped_hits: Dict[str, Any] = {}
+ for hit in results:
+ key = hit.get(composite_key_name)
+ grouped_hits.setdefault(key, []).append(hit)
+
+ # Map the features to their Feast types.
+ feature_name_feast_primitive_type_map = {
+ f.name: f.dtype for f in table.features
+ }
+ if getattr(table, "write_to_online_store", False):
+ feature_name_feast_primitive_type_map.update(
+ {f.name: f.dtype for f in table.schema}
+ )
+ # Build a dictionary mapping composite key -> (res_ts, res)
+ results_dict: Dict[
+ str, Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
+ ] = {}
+
+ # here we need to map the data stored as characters back into the protobuf value
+ for hit in results:
+ key = hit.get(composite_key_name)
+ # Only take one hit per composite key (adjust if you need aggregation)
+ if key not in results_dict:
+ res = {}
+ res_ts = None
+ for field in output_fields:
+ val = ValueProto()
+ field_value = hit.get(field, None)
+ if field_value is None and ":" in field:
+ _, field_short = field.split(":", 1)
+ field_value = hit.get(field_short)
+
+ if field in ["created_ts", "event_ts"]:
+ res_ts = datetime.fromtimestamp(field_value / 1e6)
+ elif field == composite_key_name:
+ # We do not return the composite key value
+ pass
+ else:
+ feature_feast_primitive_type = (
+ feature_name_feast_primitive_type_map.get(
+ field, PrimitiveFeastType.INVALID
+ )
+ )
+ feature_fv_dtype = from_feast_type(feature_feast_primitive_type)
+ proto_attr = VALUE_TYPE_TO_PROTO_VALUE_MAP.get(feature_fv_dtype)
+ if proto_attr:
+ if proto_attr == "bytes_val":
+ setattr(val, proto_attr, field_value.encode())
+ elif proto_attr in [
+ "int32_val",
+ "int64_val",
+ "float_val",
+ "double_val",
+ "string_val",
+ ]:
+ setattr(
+ val,
+ proto_attr,
+ type(getattr(val, proto_attr))(field_value),
+ )
+ elif proto_attr in [
+ "int32_list_val",
+ "int64_list_val",
+ "float_list_val",
+ "double_list_val",
+ ]:
+ getattr(val, proto_attr).val.extend(field_value)
+ else:
+ setattr(val, proto_attr, field_value)
+ else:
+ raise ValueError(
+ f"Unsupported ValueType: {feature_feast_primitive_type} with feature view value {field_value} for feature {field} with value type {proto_attr}"
+ )
+ # res[field] = val
+ key_to_use = field.split(":", 1)[-1] if ":" in field else field
+ res[key_to_use] = val
+ results_dict[key] = (res_ts, res if res else None)
+
+ # Map the results back into a list matching the original order of composite_keys.
+ result_list = [
+ results_dict.get(key, (None, None)) for key in composite_entities
+ ]
+
+ return result_list
+
+ def update(
+ self,
+ config: RepoConfig,
+ tables_to_delete: Sequence[FeatureView],
+ tables_to_keep: Sequence[FeatureView],
+ entities_to_delete: Sequence[Entity],
+ entities_to_keep: Sequence[Entity],
+ partial: bool,
+ ):
+ self.client = self._connect(config)
+ for table in tables_to_keep:
+ self._collections = self._get_or_create_collection(config, table)
+
+ for table in tables_to_delete:
+ collection_name = _table_id(config.project, table)
+ if self._collections.get(collection_name, None):
+ self.client.drop_collection(collection_name)
+ self._collections.pop(collection_name, None)
+
+ def plan(
+ self, config: RepoConfig, desired_registry_proto: RegistryProto
+ ) -> List[InfraObject]:
+ raise NotImplementedError
+
+ def teardown(
+ self,
+ config: RepoConfig,
+ tables: Sequence[FeatureView],
+ entities: Sequence[Entity],
+ ):
+ self.client = self._connect(config)
+ for table in tables:
+ collection_name = _table_id(config.project, table)
+ if self._collections.get(collection_name, None):
+ self.client.drop_collection(collection_name)
+ self._collections.pop(collection_name, None)
+
+ def retrieve_online_documents_v2(
+ self,
+ config: RepoConfig,
+ table: FeatureView,
+ requested_features: List[str],
+ embedding: Optional[List[float]],
+ top_k: int,
+ distance_metric: Optional[str] = None,
+ query_string: Optional[str] = None,
+ ) -> List[
+ Tuple[
+ Optional[datetime],
+ Optional[EntityKeyProto],
+ Optional[Dict[str, ValueProto]],
+ ]
+ ]:
+ """
+ Retrieve documents using vector similarity search or keyword search in Milvus.
+ Args:
+ config: Feast configuration object
+ table: FeatureView object as the table to search
+ requested_features: List of requested features to retrieve
+ embedding: Query embedding to search for (optional)
+ top_k: Number of items to return
+ distance_metric: Distance metric to use (optional)
+ query_string: The query string to search for using keyword search (optional)
+ Returns:
+ List of tuples containing the event timestamp, entity key, and feature values
+ """
+ entity_name_feast_primitive_type_map = {
+ k.name: k.dtype for k in table.entity_columns
+ }
+ self.client = self._connect(config)
+ collection_name = _table_id(config.project, table)
+ collection = self._get_or_create_collection(config, table)
+ if embedding is not None and not config.online_store.vector_enabled:
+ raise ValueError("Vector search is not enabled in the online store config")
+
+ if embedding is None and query_string is None:
+ raise ValueError("Either embedding or query_string must be provided")
+
+ composite_key_name = _get_composite_key_name(table)
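+ # e.g. entity columns ["driver_id", "customer_id"] -> "driver_id_customer_id_pk" (see _get_composite_key_name below)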
+
+ output_fields = (
+ [composite_key_name]
+ + (requested_features if requested_features else [])
+ + ["created_ts", "event_ts"]
+ )
+ schema_field_names = {f["name"] for f in collection["fields"]}
+ missing_fields = [f for f in output_fields if f not in schema_field_names]
+ assert not missing_fields, (
+ f"field(s) {missing_fields} not found in collection schema"
+ )
+
+ # Find the vector search field if we need it
+ ann_search_field = None
+ if embedding is not None:
+ for field in collection["fields"]:
+ if (
+ field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
+ and field["name"] in output_fields
+ ):
+ ann_search_field = field["name"]
+ break
+
+ self.client.load_collection(collection_name)
+
+ if (
+ embedding is not None
+ and query_string is not None
+ and config.online_store.vector_enabled
+ ):
+ string_field_list = [
+ f.name
+ for f in table.features
+ if isinstance(f.dtype, PrimitiveFeastType)
+ and f.dtype.to_value_type() == ValueType.STRING
+ ]
+
+ if not string_field_list:
+ raise ValueError(
+ "No string fields found in the feature view for text search in hybrid mode"
+ )
+
+ # Create a filter expression for text search
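+ # NOTE: query_string is interpolated into the Milvus filter expression as-is; quoting/escaping is assumed to be handled by the caller.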
+ filter_expressions = []
+ for field in string_field_list:
+ if field in output_fields:
+ filter_expressions.append(f"{field} LIKE '%{query_string}%'")
+
+ # Combine filter expressions with OR
+ filter_expr = " OR ".join(filter_expressions) if filter_expressions else ""
+
+ # Vector search with text filter
+ search_params = {
+ "metric_type": distance_metric or config.online_store.metric_type,
+ "params": {"nprobe": 10},
+ }
+
+ # For hybrid search, use filter parameter instead of expr
+ results = self.client.search(
+ collection_name=collection_name,
+ data=[embedding],
+ anns_field=ann_search_field,
+ search_params=search_params,
+ limit=top_k,
+ output_fields=output_fields,
+ filter=filter_expr if filter_expr else None,
+ )
+
+ elif embedding is not None and config.online_store.vector_enabled:
+ # Vector search only
+ search_params = {
+ "metric_type": distance_metric or config.online_store.metric_type,
+ "params": {"nprobe": 10},
+ }
+
+ results = self.client.search(
+ collection_name=collection_name,
+ data=[embedding],
+ anns_field=ann_search_field,
+ search_params=search_params,
+ limit=top_k,
+ output_fields=output_fields,
+ )
+
+ elif query_string is not None:
+ string_field_list = [
+ f.name
+ for f in table.features
+ if isinstance(f.dtype, PrimitiveFeastType)
+ and f.dtype.to_value_type() == ValueType.STRING
+ ]
+
+ if not string_field_list:
+ raise ValueError(
+ "No string fields found in the feature view for text search"
+ )
+
+ filter_expressions = []
+ for field in string_field_list:
+ if field in output_fields:
+ filter_expressions.append(f"{field} LIKE '%{query_string}%'")
+
+ filter_expr = " OR ".join(filter_expressions)
+
+ if not filter_expr:
+ raise ValueError(
+ "No text fields found in requested features for search"
+ )
+
+ query_results = self.client.query(
+ collection_name=collection_name,
+ filter=filter_expr,
+ output_fields=output_fields,
+ limit=top_k,
+ )
+
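+ # Reshape query() rows into the [[hit, ...]] structure that search() returns; -1.0 is a sentinel distance for keyword-only matches.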
+ results = [
+ [{"entity": entity, "distance": -1.0}] for entity in query_results
+ ]
+ else:
+ raise ValueError(
+ "Either vector_enabled must be True for embedding search or query_string must be provided for keyword search"
+ )
+
+ result_list = []
+ for hits in results:
+ for hit in hits:
+ res = {}
+ res_ts = None
+ entity_key_hex = hit.get("entity", {}).get(composite_key_name, None)
+ entity_key_proto = (
+ deserialize_entity_key(bytes.fromhex(entity_key_hex))
+ if entity_key_hex
+ else None
+ )
+ for field in output_fields:
+ val = ValueProto()
+ field_value = hit.get("entity", {}).get(field, None)
+ if field in ["created_ts", "event_ts"]:
+ if field_value is not None:
+ res_ts = datetime.fromtimestamp(field_value / 1e6)
+ elif field == ann_search_field and embedding is not None:
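+ # Echo the query embedding back for the vector field instead of fetching the stored vector from Milvus.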
+ serialized_embedding = _serialize_vector_to_float_list(
+ embedding
+ )
+ res[ann_search_field] = serialized_embedding
+ elif (
+ entity_name_feast_primitive_type_map.get(
+ field, PrimitiveFeastType.INVALID
+ )
+ == PrimitiveFeastType.STRING
+ ):
+ res[field] = ValueProto(string_val=str(field_value))
+ elif (
+ entity_name_feast_primitive_type_map.get(
+ field, PrimitiveFeastType.INVALID
+ )
+ == PrimitiveFeastType.BYTES
+ ):
+ try:
+ decoded_bytes = base64.b64decode(field_value)
+ res[field] = ValueProto(bytes_val=decoded_bytes)
+ except Exception:
+ res[field] = ValueProto(string_val=str(field_value))
+ elif entity_name_feast_primitive_type_map.get(
+ field, PrimitiveFeastType.INVALID
+ ) in [
+ PrimitiveFeastType.INT64,
+ PrimitiveFeastType.INT32,
+ ]:
+ res[field] = ValueProto(int64_val=int(field_value))
+ elif field == composite_key_name:
+ pass
+ elif isinstance(field_value, bytes):
+ val.ParseFromString(field_value)
+ res[field] = val
+ else:
+ val.string_val = str(field_value)
+ res[field] = val
+ distance = hit.get("distance", None)
+ res["distance"] = (
+ ValueProto(float_val=distance)
+ if distance is not None
+ else ValueProto()
+ )
+ result_list.append((res_ts, entity_key_proto, res if res else None))
+ return result_list
+
+
+def _table_id(project: str, table: FeatureView) -> str:
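+ """Milvus collection name for a feature view: "<project>_<feature view name>"."""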
+ return f"{project}_{table.name}"
+
+
+def _get_composite_key_name(table: FeatureView) -> str:
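+ """Column name holding the serialized entity key, e.g. "driver_id_customer_id_pk"."""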
+ return "_".join([field.name for field in table.entity_columns]) + "_pk"
+
+
+def _extract_proto_values_to_dict(
+ input_dict: Dict[str, Any],
+ vector_cols: List[str],
+ serialize_to_string: bool = False,
+) -> Dict[str, Any]:
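+ """Flatten a dict of ValueProto objects into plain Python values.
+
+ Numeric list protos keep their raw list values, unless serialize_to_string
+ is set and the column is not a vector column, in which case they are
+ serialized to bytes. bytes_val values become base64-encoded strings so
+ they round-trip through Milvus; other scalars are returned as-is or
+ stringified.
+ """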
+ numeric_vector_list_types = [
+ k
+ for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys()
+ if k is not None and "list" in k and "string" not in k
+ ]
+ numeric_types = [
+ "double_val",
+ "float_val",
+ "int32_val",
+ "int64_val",
+ "bool_val",
+ ]
+ output_dict = {}
+ for feature_name, feature_values in input_dict.items():
+ for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
+ if not isinstance(feature_values, (int, float, str)):
+ if feature_values.HasField(proto_val_type):
+ if proto_val_type in numeric_vector_list_types:
+ if serialize_to_string and feature_name not in vector_cols:
+ vector_values = getattr(
+ feature_values, proto_val_type
+ ).SerializeToString()
+ else:
+ vector_values = getattr(feature_values, proto_val_type).val
+ else:
+ if (
+ serialize_to_string
+ and proto_val_type
+ not in ["string_val", "bytes_val"] + numeric_types
+ ):
+ vector_values = feature_values.SerializeToString().decode()
+ elif proto_val_type == "bytes_val":
+ byte_data = getattr(feature_values, proto_val_type)
+ vector_values = base64.b64encode(byte_data).decode("utf-8")
+ else:
+ if not isinstance(feature_values, str):
+ vector_values = str(
+ getattr(feature_values, proto_val_type)
+ )
+ else:
+ vector_values = getattr(feature_values, proto_val_type)
+ output_dict[feature_name] = vector_values
+ else:
+ if serialize_to_string:
+ if not isinstance(feature_values, str):
+ feature_values = str(feature_values)
+ output_dict[feature_name] = feature_values
+
+ return output_dict
From e3efbf314e2479a702e074bcd45e5386d41afa19 Mon Sep 17 00:00:00 2001
From: Shizoqua
Date: Mon, 12 Jan 2026 21:11:17 +0100
Subject: [PATCH 6/6] test: fix ruff imports and stabilize CLI runner env
Signed-off-by: Shizoqua
---
sdk/python/tests/conftest.py | 10 ++++++++--
sdk/python/tests/utils/auth_permissions_util.py | 15 +++++++++++----
sdk/python/tests/utils/cli_repo_creator.py | 11 ++++++++++-
3 files changed, 29 insertions(+), 7 deletions(-)
diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index 280dc6ea2cc..df2f3fd731d 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -15,10 +15,10 @@
import multiprocessing
import os
import random
+import sys
import tempfile
from datetime import timedelta
from multiprocessing import Process
-import sys
from textwrap import dedent
from typing import Any, Dict, List, Tuple, no_type_check
from unittest import mock
@@ -36,6 +36,7 @@
create_document_dataset,
create_image_dataset,
)
+
try:
from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
IntegrationTestRepoConfig,
@@ -91,12 +92,15 @@ def driver(*args, **kwargs): # type: ignore[no-redef]
def location(*args, **kwargs): # type: ignore[no-redef]
raise RuntimeError("Integration test dependencies are not available")
+
try:
from tests.utils.auth_permissions_util import default_store
except ModuleNotFoundError:
def default_store(*args, **kwargs): # type: ignore[no-redef]
raise RuntimeError("Auth test dependencies are not available")
+
+
from tests.utils.http_server import check_port_open, free_port # noqa: E402
from tests.utils.ssl_certifcates_util import (
combine_trust_stores,
@@ -288,7 +292,9 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
pytest.skip("Integration test dependencies are not available")
own_markers = getattr(metafunc.definition, "own_markers", None)
- marker_iter = own_markers if own_markers is not None else metafunc.definition.iter_markers()
+ marker_iter = (
+ own_markers if own_markers is not None else metafunc.definition.iter_markers()
+ )
markers = {m.name: m for m in marker_iter}
offline_stores = None
diff --git a/sdk/python/tests/utils/auth_permissions_util.py b/sdk/python/tests/utils/auth_permissions_util.py
index c332a5ab8d3..e04686286a2 100644
--- a/sdk/python/tests/utils/auth_permissions_util.py
+++ b/sdk/python/tests/utils/auth_permissions_util.py
@@ -1,5 +1,6 @@
import os
import subprocess
+from pathlib import Path
import yaml
from keycloak import KeycloakAdmin
@@ -36,16 +37,22 @@ def default_store(
permissions: list[Permission],
):
runner = CliRunner()
- result = runner.run(["init", PROJECT_NAME], cwd=temp_dir)
+ result = runner.run(["init", PROJECT_NAME], cwd=Path(temp_dir))
repo_path = os.path.join(temp_dir, PROJECT_NAME, "feature_repo")
- assert result.returncode == 0
+ assert result.returncode == 0, (
+ f"feast init failed. stdout:\n{result.stdout.decode(errors='ignore')}\n"
+ f"stderr:\n{result.stderr.decode(errors='ignore')}\n"
+ )
include_auth_config(
file_path=f"{repo_path}/feature_store.yaml", auth_config=auth_config
)
- result = runner.run(["--chdir", repo_path, "apply"], cwd=temp_dir)
- assert result.returncode == 0
+ result = runner.run(["--chdir", repo_path, "apply"], cwd=Path(temp_dir))
+ assert result.returncode == 0, (
+ f"feast apply failed. stdout:\n{result.stdout.decode(errors='ignore')}\n"
+ f"stderr:\n{result.stderr.decode(errors='ignore')}\n"
+ )
fs = FeatureStore(repo_path=repo_path)
diff --git a/sdk/python/tests/utils/cli_repo_creator.py b/sdk/python/tests/utils/cli_repo_creator.py
index ea1d7fcf10b..7922e7454df 100644
--- a/sdk/python/tests/utils/cli_repo_creator.py
+++ b/sdk/python/tests/utils/cli_repo_creator.py
@@ -1,3 +1,4 @@
+import os
import random
import string
import subprocess
@@ -33,11 +34,18 @@ class CliRunner:
"""
def run(self, args: List[str], cwd: Path) -> subprocess.CompletedProcess:
+ env = os.environ.copy()
+ env.setdefault("IS_TEST", "True")
return subprocess.run(
- [sys.executable, cli.__file__] + args, cwd=cwd, capture_output=True
+ [sys.executable, cli.__file__] + args,
+ cwd=cwd,
+ capture_output=True,
+ env=env,
)
def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]:
+ env = os.environ.copy()
+ env.setdefault("IS_TEST", "True")
try:
return (
0,
@@ -45,6 +53,7 @@ def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]:
[sys.executable, cli.__file__] + args,
cwd=cwd,
stderr=subprocess.STDOUT,
+ env=env,
),
)
except subprocess.CalledProcessError as e: