Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError):
def __init__(self, store_name: str, version: int):
super().__init__(
f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. "
f"Currently only SQLite, PostgreSQL, and MySQL support version-qualified feature references. "
f"Currently only SQLite, PostgreSQL, MySQL, and FAISS support version-qualified feature references. "
)


Expand Down
119 changes: 75 additions & 44 deletions sdk/python/feast/infra/online_stores/faiss_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from feast import Entity, FeatureView, RepoConfig
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.helpers import compute_table_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
Expand Down Expand Up @@ -43,16 +44,21 @@ def teardown(self):
self.entity_keys = {}


def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
return compute_table_id(project, table, enable_versioning)


class FaissOnlineStore(OnlineStore):
_index: Optional[faiss.IndexIVFFlat] = None
_in_memory_store: InMemoryStore = InMemoryStore()
_config: Optional[FaissOnlineStoreConfig] = None
_logger: logging.Logger = logging.getLogger(__name__)

def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat:
if self._index is None or self._config is None:
raise ValueError("Index is not initialized")
return self._index
def __init__(self):
super().__init__()
self._indices: Dict[str, faiss.IndexIVFFlat] = {}
self._in_memory_stores: Dict[str, InMemoryStore] = {}
self._config: Optional[FaissOnlineStoreConfig] = None

def _get_index(self, table_key: str) -> Optional[faiss.IndexIVFFlat]:
return self._indices.get(table_key)

def update(
self,
Expand All @@ -63,32 +69,45 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
feature_views = tables_to_keep
if not feature_views:
return

feature_names = [f.name for f in feature_views[0].features]
dimension = len(feature_names)

self._config = FaissOnlineStoreConfig(**config.online_store.dict())
if self._index is None or not partial:
quantizer = faiss.IndexFlatL2(dimension)
self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
self._index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
)
self._in_memory_store = InMemoryStore()
versioning = config.registry.enable_online_feature_view_versioning

for table in tables_to_delete:
table_key = _table_id(config.project, table, versioning)
self._indices.pop(table_key, None)
self._in_memory_stores.pop(table_key, None)

for table in tables_to_keep:
table_key = _table_id(config.project, table, versioning)
feature_names = [f.name for f in table.features]
dimension = len(feature_names)

if table_key not in self._indices or not partial:
quantizer = faiss.IndexFlatL2(dimension)
index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(
np.float32
)
)
self._indices[table_key] = index
self._in_memory_stores[table_key] = InMemoryStore()

self._in_memory_store.update(feature_names, {})
self._in_memory_stores[table_key].update(feature_names, {})

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
self._index = None
self._in_memory_store.teardown()
versioning = config.registry.enable_online_feature_view_versioning
for table in tables:
table_key = _table_id(config.project, table, versioning)
self._indices.pop(table_key, None)
store = self._in_memory_stores.pop(table_key, None)
if store is not None:
store.teardown()

def online_read(
self,
Expand All @@ -97,23 +116,28 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
if self._index is None:
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)
in_memory_store = self._in_memory_stores.get(table_key)

if index is None or in_memory_store is None:
return [(None, None)] * len(entity_keys)

results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []
for entity_key in entity_keys:
serialized_key = serialize_entity_key(
entity_key, config.entity_key_serialization_version
).hex()
idx = self._in_memory_store.entity_keys.get(serialized_key, -1)
idx = in_memory_store.entity_keys.get(serialized_key, -1)
if idx == -1:
results.append((None, None))
else:
feature_vector = self._index.reconstruct(int(idx))
feature_vector = index.reconstruct(int(idx))
feature_dict = {
name: ValueProto(double_val=value)
for name, value in zip(
self._in_memory_store.feature_names, feature_vector
in_memory_store.feature_names, feature_vector
)
}
results.append((None, feature_dict))
Expand All @@ -128,8 +152,16 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
if self._index is None:
self._logger.warning("Index is not initialized. Skipping write operation.")
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)
in_memory_store = self._in_memory_stores.get(table_key)

if index is None or in_memory_store is None:
self._logger.warning(
"Index for table '%s' is not initialized. Skipping write operation.",
table_key,
)
return

feature_vectors = []
Expand All @@ -142,7 +174,7 @@ def online_write_batch(
feature_vector = np.array(
[
feature_dict[name].double_val
for name in self._in_memory_store.feature_names
for name in in_memory_store.feature_names
],
dtype=np.float32,
)
Expand All @@ -153,21 +185,17 @@ def online_write_batch(
feature_vectors_array = np.array(feature_vectors)

existing_indices = [
self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
]
mask = np.array(existing_indices) != -1
if np.any(mask):
self._index.remove_ids(
np.array([idx for idx in existing_indices if idx != -1])
)
index.remove_ids(np.array([idx for idx in existing_indices if idx != -1]))

new_indices = np.arange(
self._index.ntotal, self._index.ntotal + len(feature_vectors_array)
)
self._index.add(feature_vectors_array)
new_indices = np.arange(index.ntotal, index.ntotal + len(feature_vectors_array))
index.add(feature_vectors_array)

for sk, idx in zip(serialized_keys, new_indices):
self._in_memory_store.entity_keys[sk] = idx
in_memory_store.entity_keys[sk] = idx

if progress:
progress(len(data))
Expand All @@ -189,12 +217,16 @@ def retrieve_online_documents(
Optional[ValueProto],
]
]:
if self._index is None:
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)

if index is None:
self._logger.warning("Index is not initialized. Returning empty result.")
return []

query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1)
distances, indices = self._index.search(query_vector, top_k)
distances, indices = index.search(query_vector, top_k)

results: List[
Tuple[
Expand All @@ -209,7 +241,7 @@ def retrieve_online_documents(
if idx == -1:
continue

feature_vector = self._index.reconstruct(int(idx))
feature_vector = index.reconstruct(int(idx))

timestamp = Timestamp()
timestamp.GetCurrentTime()
Expand Down Expand Up @@ -237,5 +269,4 @@ async def online_read_async(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
# Implement async read if needed
raise NotImplementedError("Async read is not implemented for FaissOnlineStore")
6 changes: 6 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ def _check_versioned_read_support(self, grouped_refs):
supported_types.append(PostgreSQLOnlineStore)
except ImportError:
pass
try:
from feast.infra.online_stores.faiss_online_store import FaissOnlineStore

supported_types.append(FaissOnlineStore)
except ImportError:
pass

if isinstance(self, tuple(supported_types)):
return
Expand Down
Loading
Loading