From f95b623180483b490221f0f48c2861dfe6fc2735 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Wed, 26 Jul 2023 13:00:55 +0200 Subject: [PATCH 1/2] refactor: do not recompute every time num_docs Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/hnswlib.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 6048a5fec8..2841b26e3b 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -137,7 +137,7 @@ def __init__(self, db_config=None, **kwargs): self._column_names: List[str] = [] self._create_docs_table() self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed self._logger.info(f'{self.__class__.__name__} has been initialized') @property @@ -281,7 +281,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): self._send_docs_to_sqlite(docs_validated) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: """ @@ -318,9 +318,6 @@ def _find_batched( def _find( self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: - if self.num_docs() == 0: - return _FindResult(documents=[], scores=[]) # type: ignore - query_batched = np.expand_dims(query, axis=0) docs, scores = self._find_batched( queries=query_batched, limit=limit, search_field=search_field @@ -385,7 +382,7 @@ def _del_items(self, doc_ids: Sequence[str]): self._delete_docs_from_sqlite(doc_ids) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]: """Get Documents from the hnswlib index, by `id`. @@ -410,6 +407,8 @@ def num_docs(self) -> int: """ Get the number of documents. """ + if self._num_docs == 0: + self._num_docs = self._get_num_docs_sqlite() return self._num_docs ############################################### @@ -605,11 +604,11 @@ def _search_and_filter( documents and their corresponding scores. """ # If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched - if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0): + if hashed_ids is not None and len(hashed_ids) == 0: return _FindResultBatched(documents=[], scores=[]) # type: ignore # Set limit as the minimum of the provided limit and the total number of documents - limit = min(limit, self.num_docs()) + limit = limit # Ensure the search field is in the HNSW indices if search_field not in self._hnsw_indices: From 3ca5780b9f5c89769e8ead133475151d205105d3 Mon Sep 17 00:00:00 2001 From: jupyterjazz Date: Wed, 26 Jul 2023 13:28:53 +0200 Subject: [PATCH 2/2] fix: k more than num docs Signed-off-by: jupyterjazz --- docarray/index/backends/hnswlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 2841b26e3b..8b40c043c9 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -608,7 +608,7 @@ def _search_and_filter( return _FindResultBatched(documents=[], scores=[]) # type: ignore # Set limit as the minimum of the provided limit and the total number of documents - limit = limit + limit = min(limit, self.num_docs()) # Ensure the search field is in the HNSW indices if search_field not in self._hnsw_indices: