diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 022ab8d818..9d59af0120 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -137,7 +137,7 @@ def __init__(self, db_config=None, **kwargs): self._column_names: List[str] = [] self._create_docs_table() self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed self._logger.info(f'{self.__class__.__name__} has been initialized') @property @@ -281,7 +281,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): self._send_docs_to_sqlite(docs_validated) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: """ @@ -318,9 +318,6 @@ def _find_batched( def _find( self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: - if self.num_docs() == 0: - return _FindResult(documents=[], scores=[]) # type: ignore - query_batched = np.expand_dims(query, axis=0) docs, scores = self._find_batched( queries=query_batched, limit=limit, search_field=search_field @@ -385,7 +382,7 @@ def _del_items(self, doc_ids: Sequence[str]): self._delete_docs_from_sqlite(doc_ids) self._sqlite_conn.commit() - self._num_docs = self._get_num_docs_sqlite() + self._num_docs = 0 # recompute again when needed def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]: """Get Documents from the hnswlib index, by `id`. @@ -410,6 +407,8 @@ def num_docs(self) -> int: """ Get the number of documents. """ + if self._num_docs == 0: + self._num_docs = self._get_num_docs_sqlite() return self._num_docs ############################################### @@ -605,7 +604,7 @@ def _search_and_filter( documents and their corresponding scores. """ # If there are no documents or hashed_ids is an empty set, return an empty _FindResultBatched - if self.num_docs() == 0 or (hashed_ids is not None and len(hashed_ids) == 0): + if hashed_ids is not None and len(hashed_ids) == 0: return _FindResultBatched(documents=[], scores=[]) # type: ignore # Set limit as the minimum of the provided limit and the total number of documents @@ -628,8 +627,11 @@ def accept_hashed_ids(id): # If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit k = min(limit, len(hashed_ids)) if hashed_ids else limit - - labels, distances = index.knn_query(queries, k=k, **extra_kwargs) + try: + labels, distances = index.knn_query(queries, k=k, **extra_kwargs) + except RuntimeError: # logic to avoid calling num_docs in most of the cases which comes at performance cost when many docs are indexed + k = min(k, self.num_docs()) + labels, distances = index.knn_query(queries, k=k, **extra_kwargs) result_das = [ self._get_docs_sqlite_hashed_id(