diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b7ca7815..c6074ba2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,4 +3,4 @@ # For syntax help see: # https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax -* @googleapis/senseai-eco @googleapis/langchain-csql-pg \ No newline at end of file +* @googleapis/senseai-eco-team @googleapis/langchain-csql-pg-team \ No newline at end of file diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml index cedea1bf..7d31931e 100644 --- a/.github/blunderbuss.yml +++ b/.github/blunderbuss.yml @@ -1,4 +1,4 @@ assign_issues: - - googleapis/langchain-csql-pg + - googleapis/langchain-csql-pg-team assign_prs: - - googleapis/langchain-csql-pg + - googleapis/langchain-csql-pg-team diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 93a38cf3..05a3fdf6 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -48,12 +48,6 @@ "matchCurrentVersion": "<=2.2.6", "enabled": false }, - { - "description": "Disable numpy updates for python <=3.9 in requirements.txt", - "matchPackageNames": ["numpy"], - "matchCurrentVersion": "<=2.0.2", - "enabled": false - }, { "description": "Disable numpy updates for python 3.10 in pyproject.toml", "matchFileNames": ["pyproject.toml"], @@ -61,25 +55,11 @@ "matchCurrentValue": ">=1.24.4, <=2.2.6", "enabled": false }, - { - "description": "Disable numpy updates for python <=3.9 in pyproject.toml", - "matchFileNames": ["pyproject.toml"], - "matchPackageNames": ["numpy"], - "matchCurrentValue": ">=1.24.4, <=2.0.2", - "enabled": false - }, { "description": "Use feat commit type for LangChain Postgres dependency updates", "matchPackageNames": ["langchain-postgres"], "semanticCommitType": "feat", "groupName": "langchain-postgres" }, - { - "description": "Disable isort updates for python <=3.9 in pyproject.toml", - "matchFileNames": ["pyproject.toml"], - "matchPackageNames": ["isort"], - 
"matchCurrentValue": "==6.1.0", - "enabled": false - } ], } diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 569858cd..be1330fa 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,9 +10,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version: "3.10" - name: Install nox @@ -26,9 +26,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: python-version: "3.10" - name: Install nox diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8a33dc15..27ce3719 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -31,10 +31,10 @@ jobs: steps: - name: Checkout Repository - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml index 28b65edb..f27201de 100644 --- a/.github/workflows/schedule_reporter.yml +++ b/.github/workflows/schedule_reporter.yml @@ -24,6 +24,6 @@ jobs: issues: 'write' checks: 'read' contents: 'read' - uses: 
googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@d5dd58511d12347d8a233d492b3ad15e6d2b3721 + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@d2a71adee9dd1ff0d225edef933e40b052aa3386 with: trigger_names: "pg-integration-test-nightly,pg-continuous-test-on-merge" diff --git a/.kokoro/trampoline_v2.sh b/.kokoro/trampoline_v2.sh index 35fa5292..d03f92df 100755 --- a/.kokoro/trampoline_v2.sh +++ b/.kokoro/trampoline_v2.sh @@ -26,8 +26,8 @@ # To run this script, first download few files from gcs to /dev/shm. # (/dev/shm is passed into the container as KOKORO_GFILE_DIR). # -# gsutil cp gs://cloud-devrel-kokoro-resources/python-docs-samples/secrets_viewer_service_account.json /dev/shm -# gsutil cp gs://cloud-devrel-kokoro-resources/python-docs-samples/automl_secrets.txt /dev/shm +# gcloud storage cp gs://cloud-devrel-kokoro-resources/python-docs-samples/secrets_viewer_service_account.json /dev/shm +# gcloud storage cp gs://cloud-devrel-kokoro-resources/python-docs-samples/automl_secrets.txt /dev/shm # # Then run the script. 
# .kokoro/trampoline_v2.sh diff --git a/.repo-metadata.json b/.repo-metadata.json index 9a566027..1afa8aee 100644 --- a/.repo-metadata.json +++ b/.repo-metadata.json @@ -9,5 +9,5 @@ "repo": "googleapis/langchain-google-cloud-sql-pg-python", "distribution_name": "langchain-google-cloud-sql-pg", "requires_billing": true, - "codeowner_team": "@googleapis/senseai-eco" + "codeowner_team": "@googleapis/senseai-eco-team" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 445b05ea..8765db3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## [0.15.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.14.1...v0.15.0) (2026-01-08) + + +### ⚠ BREAKING CHANGES + +* Refactor PostgresVectorStore and PostgresEngine to depend on PGVectorstore and PGEngine respectively ([#316](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/316)) + +### Features + +* **deps:** Update langchain-postgres to v0.0.16 ([#366](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/366)) ([e773505](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/e773505453683dad5681e6155831b710cbc7fcc1)) +* Disable support for python 3.9 and enable support for python3.13 ([#378](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/378)) ([b97060e](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/b97060e1fd69f1902c370c90218b1e61b72050b8)) +* Update Langgraph dependency to v1 ([#379](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/379)) ([7a841b3](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/7a841b357c998bce7c6aede0e2e5fed8fa48f198)) + + +### Documentation + +* Add Hybrid Search documentation ([#329](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/329)) ([14098ca](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/14098ca7a6cf7116e6edbcb7a5c6c3ccbce76b4a)) 
+ + +### Code Refactoring + +* Refactor PostgresVectorStore and PostgresEngine to depend on PGVectorstore and PGEngine respectively ([#316](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/316)) ([7917d62](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/7917d62c3f9ea2c6ca8ab8d6284cfa2c7e535401)) + ## [0.14.1](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.14.0...v0.14.1) (2025-07-11) diff --git a/DEVELOPER.md b/DEVELOPER.md index 899f62b6..751df2e7 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -42,11 +42,11 @@ These tests are registered as required tests in `.github/sync-repo-settings.yaml #### Trigger Setup -Cloud Build triggers (for Python versions 3.9 to 3.11) were created with the following specs: +Cloud Build triggers (for Python versions 3.10 to 3.13) were created with the following specs: ```YAML name: pg-integration-test-pr-py39 -description: Run integration tests on PR for Python 3.9 +description: Run integration tests on PR for Python 3.10 filename: integration.cloudbuild.yaml github: name: langchain-google-cloud-sql-pg-python @@ -64,7 +64,7 @@ substitutions: _DATABASE_ID: _INSTANCE_ID: _REGION: us-central1 - _VERSION: "3.9" + _VERSION: "3.10" ``` Use `gcloud builds triggers import --source=trigger.yaml` to create triggers via the command line diff --git a/README.rst b/README.rst index 9839b661..d1e258c5 100644 --- a/README.rst +++ b/README.rst @@ -56,7 +56,7 @@ dependencies. Supported Python Versions ^^^^^^^^^^^^^^^^^^^^^^^^^ -Python >= 3.9 +Python >= 3.10 Mac/Linux ^^^^^^^^^ diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index ddc5ce30..6420d40f 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -422,7 +422,40 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Delete texts" + "### Get document\n", + "\n", + "Get documents from the vectorstore using filters and parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "documents_with_apple = await store.aget(\n", + " where_document={\"$ilike\": \"%apple%\"}, include=\"documents\"\n", + ")\n", + "paginated_ids = await store.aget(limit=3, offset=3)\n", + "\n", + "print(documents_with_apple[\"documents\"])\n", + "print(paginated_ids[\"ids\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Delete documents\n", + "\n", + "Documents can be deleted using IDs or metadata filters." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Delete by IDs" ] }, { @@ -434,6 +467,46 @@ "await store.adelete([ids[1]])" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Delete by metadata filter\n", + "You can delete documents based on metadata filters. This is useful for bulk deletion operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Delete all documents with a specific metadata value\n", + "await store.adelete(filter={\"source\": \"documentation\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Delete documents matching complex filter criteria\n", + "await store.adelete(\n", + " filter={\"$and\": [{\"category\": \"obsolete\"}, {\"year\": {\"$lt\": 2020}}]}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Delete by both IDs and filter (must match both criteria)\n", + "await store.adelete(ids=[\"id1\", \"id2\"], filter={\"status\": \"archived\"})" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -573,6 +646,15 @@ "### Search for documents with metadata filter" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For v0.16.0+\n", + "\n", + "Metadata filtering on the `metadata_json_column` is now 
supported in the `PostgresVectorStore`." + ] + }, + { "cell_type": "code", "execution_count": null, @@ -592,9 +674,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### For v0.15.0+\n", - "\n", - "**Important Update:** Support for string filters has been deprecated. Please use dictionaries to add filters." + "**Important Update:** From v0.15.0, support for string filters has been deprecated. Please use dictionaries to add filters." ] }, { @@ -722,7 +802,9 @@ "\n", "- **`metadata_columns=[\"name\", \"category\", \"price_usd\", \"quantity\", \"sku\", \"image_url\"]`**: These columns are treated as metadata for each product. Metadata provides additional information about a product, such as its name, category, price, quantity available, SKU (Stock Keeping Unit), and an image URL. This information is useful for displaying product details in search results or for filtering and categorization.\n", "\n", - "- **`metadata_json_column=\"metadata\"`**: The `metadata` column can store any additional information about the products in a flexible JSON format. This allows for storing varied and complex data that doesn't fit into the standard columns.\n" + "- **`metadata_json_column=\"metadata\"`**: The `metadata` column can store any additional information about the products in a flexible JSON format. 
This allows for storing varied and complex data that doesn't fit into the standard columns.\n", + "Note that filtering on fields within the JSON but not in `metadata_columns` will be less efficient.\n", + "\n" ] }, { diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 18414b8e..bc0b0d8f 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -62,7 +62,7 @@ substitutions: _DATABASE_PORT: "5432" _DATABASE_ID: test-database _REGION: us-central1 - _VERSION: "3.9" + _VERSION: "3.10" _IP_ADDRESS: "127.0.0.1" options: diff --git a/pyproject.toml b/pyproject.toml index cfb4d532..6fa3d320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ dynamic = ["version"] description = "LangChain integrations for Google Cloud SQL for PostgreSQL" readme = "README.rst" license = {file = "LICENSE"} -requires-python = ">=3.9" +requires-python = ">=3.10" authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] @@ -13,8 +13,7 @@ dependencies = [ "cloud-sql-python-connector[asyncpg] >= 1.10.0, <2.0.0", "numpy>=1.24.4, <3.0.0; python_version >= '3.11'", "numpy>=1.24.4, <=2.2.6; python_version == '3.10'", - "numpy>=1.24.4, <=2.0.2; python_version <= '3.9'", - "langchain-postgres>=0.0.15", + "langchain-postgres>=0.0.16", ] classifiers = [ @@ -22,10 +21,10 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] [tool.setuptools.dynamic] @@ -39,18 +38,17 @@ Changelog = "https://github.com/googleapis/langchain-google-cloud-sql-pg-python/ [project.optional-dependencies] langgraph = [ - "langgraph-checkpoint>=2.0.9, <3.0.0" + "langgraph-checkpoint>=4.0.0, <4.1.0" ] test = [ - "black[jupyter]==25.9.0", - "isort==6.1.0; 
python_version == '3.9'", - "isort==7.0.0; python_version >= '3.10'", - "mypy==1.18.2", + "black[jupyter]==26.1.0", + "isort==8.0.0", + "mypy==1.19.1", "pytest-asyncio==0.26.0", "pytest==8.4.2", "pytest-cov==7.0.0", - "langchain-tests==0.3.22", - "langgraph==0.6.10" + "langchain-tests==1.1.2", + "langgraph==1.0.7" ] [build-system] @@ -64,7 +62,7 @@ target-version = ['py39'] profile = "black" [tool.mypy] -python_version = 3.9 +python_version = "3.10" warn_unused_configs = true disallow_incomplete_defs = true diff --git a/requirements.txt b/requirements.txt index 02eaa6f8..b5b40541 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ -cloud-sql-python-connector[asyncpg]==1.18.4 -numpy==2.3.3; python_version >= "3.11" +cloud-sql-python-connector[asyncpg]==1.20.0 +numpy==2.4.1; python_version >= "3.11" numpy==2.2.6; python_version == "3.10" -numpy==2.0.2; python_version <= "3.9" -langgraph==0.6.10 -langchain-postgres==0.0.15 +langgraph==1.0.10 +langchain-postgres==0.0.17 diff --git a/samples/index_tuning_sample/requirements.txt b/samples/index_tuning_sample/requirements.txt index d2c09cdf..95c01a93 100644 --- a/samples/index_tuning_sample/requirements.txt +++ b/samples/index_tuning_sample/requirements.txt @@ -1,3 +1,3 @@ langchain-community==0.4.1 -langchain-google-cloud-sql-pg==0.14.1 -langchain-google-vertexai==3.0.1 +langchain-google-cloud-sql-pg==0.15.0 +langchain-google-vertexai==3.2.2 diff --git a/samples/langchain_on_vertexai/clean_up.py b/samples/langchain_on_vertexai/clean_up.py index 45e57ae5..42c3866a 100644 --- a/samples/langchain_on_vertexai/clean_up.py +++ b/samples/langchain_on_vertexai/clean_up.py @@ -13,6 +13,7 @@ # limitations under the License. 
import asyncio import os +from typing import Any, Coroutine from config import ( CHAT_TABLE_NAME, @@ -32,6 +33,15 @@ TEST_NAME = os.getenv("DISPLAY_NAME") +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._default_loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._default_loop) + ) + return await coro + + async def delete_tables(): engine = await PostgresEngine.afrom_instance( PROJECT_ID, @@ -42,12 +52,14 @@ async def delete_tables(): password=PASSWORD, ) - async with engine._pool.connect() as conn: - await conn.execute(text("COMMIT")) - await conn.execute(text(f"DROP TABLE IF EXISTS {TABLE_NAME}")) - await conn.execute(text(f"DROP TABLE IF EXISTS {CHAT_TABLE_NAME}")) + async def _logic(): + async with engine._pool.connect() as conn: + await conn.execute(text("COMMIT")) + await conn.execute(text(f"DROP TABLE IF EXISTS {TABLE_NAME}")) + await conn.execute(text(f"DROP TABLE IF EXISTS {CHAT_TABLE_NAME}")) + + await run_on_background(engine, _logic()) await engine.close() - await engine._connector.close_async() def delete_engines(): diff --git a/samples/langchain_on_vertexai/create_embeddings.py b/samples/langchain_on_vertexai/create_embeddings.py index 105a86df..370d8262 100644 --- a/samples/langchain_on_vertexai/create_embeddings.py +++ b/samples/langchain_on_vertexai/create_embeddings.py @@ -13,6 +13,7 @@ # limitations under the License. 
import asyncio import uuid +from typing import Any, Coroutine from config import ( CHAT_TABLE_NAME, @@ -32,6 +33,15 @@ from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._default_loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._default_loop) + ) + return await coro + + async def create_databases(): engine = await PostgresEngine.afrom_instance( PROJECT_ID, @@ -41,10 +51,14 @@ async def create_databases(): user=USER, password=PASSWORD, ) - async with engine._pool.connect() as conn: - await conn.execute(text("COMMIT")) - await conn.execute(text(f'DROP DATABASE IF EXISTS "{DATABASE}"')) - await conn.execute(text(f'CREATE DATABASE "{DATABASE}"')) + + async def _logic(): + async with engine._pool.connect() as conn: + await conn.execute(text("COMMIT")) + await conn.execute(text(f'DROP DATABASE IF EXISTS "{DATABASE}"')) + await conn.execute(text(f'CREATE DATABASE "{DATABASE}"')) + + await run_on_background(engine, _logic()) await engine.close() @@ -95,7 +109,7 @@ async def grant_select(engine): engine, table_name=TABLE_NAME, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=PROJECT_ID + model_name="text-embedding-005", project=PROJECT_ID ), ) diff --git a/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py b/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py index 472b9da9..efd7fb58 100644 --- a/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py +++ b/samples/langchain_on_vertexai/prebuilt_langchain_agent_template.py @@ -65,7 +65,7 @@ def similarity_search(query: str) -> list[Document]: engine, table_name=TABLE_NAME, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=PROJECT_ID + model_name="text-embedding-005", project=PROJECT_ID ), ) 
retriever = vector_store.as_retriever() @@ -91,9 +91,9 @@ def similarity_search(query: str) -> list[Document]: DISPLAY_NAME = os.getenv("DISPLAY_NAME") or "PrebuiltAgent" remote_app = reasoning_engines.ReasoningEngine.create( - reasoning_engines.LangchainAgent( + reasoning_engines.LangchainAgent( # type: ignore[arg-type] model="gemini-2.0-flash-001", - tools=[similarity_search], + tools=[similarity_search], # type: ignore[list-item] model_kwargs={ "temperature": 0.1, }, @@ -104,4 +104,4 @@ def similarity_search(query: str) -> list[Document]: extra_packages=["config.py"], ) -print(remote_app.query(input="movies about engineers")) +print(remote_app.query(input="movies about engineers")) # type: ignore[attr-defined] diff --git a/samples/langchain_on_vertexai/requirements.txt b/samples/langchain_on_vertexai/requirements.txt index f841a4c3..aa80c386 100644 --- a/samples/langchain_on_vertexai/requirements.txt +++ b/samples/langchain_on_vertexai/requirements.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.120.0 -google-cloud-resource-manager==1.14.2 +google-cloud-aiplatform[reasoningengine,langchain]==1.134.0 +google-cloud-resource-manager==1.16.0 langchain-community==0.3.31 -langchain-google-cloud-sql-pg==0.14.1 +langchain-google-cloud-sql-pg==0.15.0 langchain-google-vertexai==2.1.2 diff --git a/samples/langchain_on_vertexai/retriever_agent_with_history_template.py b/samples/langchain_on_vertexai/retriever_agent_with_history_template.py index 7d8a520e..bba06a16 100644 --- a/samples/langchain_on_vertexai/retriever_agent_with_history_template.py +++ b/samples/langchain_on_vertexai/retriever_agent_with_history_template.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Optional +from typing import Any, Optional import vertexai # type: ignore from config import ( @@ -91,7 +91,7 @@ def set_up(self): engine, table_name=self.table, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=self.project + model_name="text-embedding-005", project=self.project ), ) retriever = vector_store.as_retriever() @@ -132,7 +132,7 @@ def set_up(self): history_messages_key="chat_history", ) - def query(self, input: str, session_id: str) -> str: + def query(self, input: str, session_id: str, **kwargs: Any) -> str: # type: ignore[override] """Query the application. Args: @@ -192,4 +192,4 @@ def query(self, input: str, session_id: str) -> str: extra_packages=["config.py"], ) -print(remote_app.query(input="movies about engineers", session_id="abc123")) +print(remote_app.query(input="movies about engineers", session_id="abc123")) # type: ignore diff --git a/samples/langchain_on_vertexai/retriever_chain_template.py b/samples/langchain_on_vertexai/retriever_chain_template.py index d05780c3..0d322ba8 100644 --- a/samples/langchain_on_vertexai/retriever_chain_template.py +++ b/samples/langchain_on_vertexai/retriever_chain_template.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Optional +from typing import Any, Optional import vertexai # type: ignore from config import ( @@ -97,7 +97,7 @@ def set_up(self): engine, table_name=self.table, embedding_service=VertexAIEmbeddings( - model_name="textembedding-gecko@latest", project=self.project + model_name="text-embedding-005", project=self.project ), ) retriever = vector_store.as_retriever() @@ -106,7 +106,7 @@ def set_up(self): # an LLM to generate a response self.chain = create_retrieval_chain(retriever, combine_docs_chain) - def query(self, input: str) -> str: + def query(self, input: str, **kwargs: Any) -> str: # type: ignore[override] """Query the application. Args: @@ -161,4 +161,4 @@ def query(self, input: str) -> str: extra_packages=["config.py"], ) -print(remote_app.query(input="movies about engineers")) +print(remote_app.query(input="movies about engineers")) # type: ignore diff --git a/samples/requirements.txt b/samples/requirements.txt index b6b27ad1..0cded642 100644 --- a/samples/requirements.txt +++ b/samples/requirements.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.97.0 -google-cloud-resource-manager==1.14.2 +google-cloud-aiplatform[reasoningengine,langchain]==1.134.0 +google-cloud-resource-manager==1.16.0 langchain-community==0.3.29 -langchain-google-cloud-sql-pg==0.14.1 -langchain-google-vertexai==2.0.27 +langchain-google-cloud-sql-pg==0.15.0 +langchain-google-vertexai==2.1.2 diff --git a/src/langchain_google_cloud_sql_pg/async_checkpoint.py b/src/langchain_google_cloud_sql_pg/async_checkpoint.py index fc875991..32eef521 100644 --- a/src/langchain_google_cloud_sql_pg/async_checkpoint.py +++ b/src/langchain_google_cloud_sql_pg/async_checkpoint.py @@ -276,7 +276,9 @@ async def aput( async with self.pool.connect() as conn: type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint) - serialized_metadata = self.jsonplus_serde.dumps(metadata) + serialized_metadata = json.dumps(metadata, ensure_ascii=False).encode( + 
"utf-8", "ignore" + ) await conn.execute( text(query), { @@ -409,7 +411,7 @@ async def alist( (value["type"], value["checkpoint"]) ), metadata=( - self.jsonplus_serde.loads(value["metadata"]) # type: ignore + json.loads(value["metadata"]) # type: ignore if value["metadata"] is not None else {} ), @@ -494,7 +496,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: }, checkpoint=self.serde.loads_typed((value["type"], value["checkpoint"])), metadata=( - self.jsonplus_serde.loads(value["metadata"]) # type: ignore + json.loads(value["metadata"]) # type: ignore if value["metadata"] is not None else {} ), diff --git a/src/langchain_google_cloud_sql_pg/version.py b/src/langchain_google_cloud_sql_pg/version.py index f735a04c..a9b14a39 100644 --- a/src/langchain_google_cloud_sql_pg/version.py +++ b/src/langchain_google_cloud_sql_pg/version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.14.1" +__version__ = "0.15.0" diff --git a/tests/test_async_chatmessagehistory.py b/tests/test_async_chatmessagehistory.py index e5443b11..585661a1 100644 --- a/tests/test_async_chatmessagehistory.py +++ b/tests/test_async_chatmessagehistory.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import os import uuid +from typing import Any, Coroutine import pytest import pytest_asyncio @@ -33,10 +35,23 @@ table_name_async = "message_store" + str(uuid.uuid4()) +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) @pytest_asyncio.fixture @@ -47,7 +62,10 @@ async def async_engine(): instance=instance_id, database=db_name, ) - await async_engine._ainit_chat_history_table(table_name=table_name_async) + await run_on_background( + async_engine, + async_engine._ainit_chat_history_table(table_name=table_name_async), + ) yield async_engine # use default table for AsyncPostgresChatMessageHistory query = f'DROP TABLE IF EXISTS "{table_name_async}"' @@ -59,14 +77,19 @@ async def async_engine(): async def test_chat_message_history_async( async_engine: PostgresEngine, ) -> None: - history = await AsyncPostgresChatMessageHistory.create( - engine=async_engine, session_id="test", table_name=table_name_async + history = await run_on_background( + async_engine, + AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name_async + ), ) msg1 = HumanMessage(content="hi!") msg2 = AIMessage(content="whats up?") - await history.aadd_message(msg1) - await history.aadd_message(msg2) - messages = await history._aget_messages() + + await run_on_background(async_engine, history.aadd_message(msg1)) + await 
run_on_background(async_engine, history.aadd_message(msg2)) + + messages = await run_on_background(async_engine, history._aget_messages()) # verify messages are correct assert messages[0].content == "hi!" @@ -75,48 +98,71 @@ async def test_chat_message_history_async( assert type(messages[1]) is AIMessage # verify clear() clears message history - await history.aclear() - assert len(await history._aget_messages()) == 0 + await run_on_background(async_engine, history.aclear()) + messages_after_clear = await run_on_background( + async_engine, history._aget_messages() + ) + assert len(messages_after_clear) == 0 @pytest.mark.asyncio async def test_chat_message_history_sync_messages( async_engine: PostgresEngine, ) -> None: - history1 = await AsyncPostgresChatMessageHistory.create( - engine=async_engine, session_id="test", table_name=table_name_async + history1 = await run_on_background( + async_engine, + AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name_async + ), ) - history2 = await AsyncPostgresChatMessageHistory.create( - engine=async_engine, session_id="test", table_name=table_name_async + history2 = await run_on_background( + async_engine, + AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name_async + ), ) msg1 = HumanMessage(content="hi!") msg2 = AIMessage(content="whats up?") - await history1.aadd_message(msg1) - await history2.aadd_message(msg2) + await run_on_background(async_engine, history1.aadd_message(msg1)) + await run_on_background(async_engine, history2.aadd_message(msg2)) + + len_history1 = len(await run_on_background(async_engine, history1._aget_messages())) + len_history2 = len(await run_on_background(async_engine, history2._aget_messages())) - assert len(await history1._aget_messages()) == 2 - assert len(await history2._aget_messages()) == 2 + assert len_history1 == 2 + assert len_history2 == 2 # verify clear() clears message history - await 
history2.aclear() - assert len(await history2._aget_messages()) == 0 + await run_on_background(async_engine, history2.aclear()) + len_history2_after_clear = len( + await run_on_background(async_engine, history2._aget_messages()) + ) + assert len_history2_after_clear == 0 @pytest.mark.asyncio async def test_chat_table_async(async_engine): with pytest.raises(ValueError): - await AsyncPostgresChatMessageHistory.create( - engine=async_engine, session_id="test", table_name="doesnotexist" + await run_on_background( + async_engine, + AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name="doesnotexist" + ), ) @pytest.mark.asyncio async def test_chat_schema_async(async_engine): table_name = "test_table" + str(uuid.uuid4()) - await async_engine._ainit_document_table(table_name=table_name) + await run_on_background( + async_engine, async_engine._ainit_document_table(table_name=table_name) + ) with pytest.raises(IndexError): - await AsyncPostgresChatMessageHistory.create( - engine=async_engine, session_id="test", table_name=table_name + await run_on_background( + async_engine, + AsyncPostgresChatMessageHistory.create( + engine=async_engine, session_id="test", table_name=table_name + ), ) query = f'DROP TABLE IF EXISTS "{table_name}"' diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py index 821b27c0..00d26b29 100644 --- a/tests/test_async_checkpoint.py +++ b/tests/test_async_checkpoint.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import os import re import uuid -from typing import Any, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Coroutine, List, Literal, Optional, Sequence, Tuple, Union import pytest import pytest_asyncio @@ -107,18 +108,33 @@ def _AnyIdToolMessage(**kwargs: Any) -> ToolMessage: return message +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + return result_map.fetchall() + + return await run_on_background(engine, _impl()) @pytest_asyncio.fixture @@ -139,10 +155,15 @@ async def async_engine(): @pytest_asyncio.fixture async def checkpointer(async_engine): - await async_engine._ainit_checkpoint_table(table_name=table_name) - checkpointer = await AsyncPostgresSaver.create( + await run_on_background( + async_engine, async_engine._ainit_checkpoint_table(table_name=table_name) + ) + checkpointer = await run_on_background( async_engine, - table_name, # serde=JsonPlusSerializer + AsyncPostgresSaver.create( + async_engine, + 
table_name, # serde=JsonPlusSerializer + ), ) yield checkpointer @@ -160,7 +181,9 @@ async def test_checkpoint_async( } } # Verify if updated configuration after storing the checkpoint is correct - next_config = await checkpointer.aput(write_config, checkpoint, {}, {}) + next_config = await run_on_background( + async_engine, checkpointer.aput(write_config, checkpoint, {}, {}) + ) assert dict(next_config) == test_config # Verify if the checkpoint is stored correctly in the database @@ -258,7 +281,9 @@ async def test_checkpoint_aput_writes( ("test_channel1", {}), ("test_channel2", {}), ] - await checkpointer.aput_writes(config, writes, task_id="1") + await run_on_background( + async_engine, checkpointer.aput_writes(config, writes, task_id="1") + ) results = await afetch(async_engine, f'SELECT * FROM "{table_name_writes}"') assert len(results) == 2 @@ -277,9 +302,19 @@ async def test_checkpoint_alist( checkpoints = test_data["checkpoints"] metadata = test_data["metadata"] - await checkpointer.aput(configs[1], checkpoints[1], metadata[0], {}) - await checkpointer.aput(configs[2], checkpoints[2], metadata[1], {}) - await checkpointer.aput(configs[3], checkpoints[3], metadata[2], {}) + await run_on_background( + async_engine, checkpointer.aput(configs[1], checkpoints[1], metadata[0], {}) + ) + await run_on_background( + async_engine, checkpointer.aput(configs[2], checkpoints[2], metadata[1], {}) + ) + await run_on_background( + async_engine, checkpointer.aput(configs[3], checkpoints[3], metadata[2], {}) + ) + + # Helper to consume async iterator on background thread + async def consume_alist(config, filter): + return [c async for c in checkpointer.alist(config, filter=filter)] # call method / assertions query_1 = {"source": "input"} # search by 1 key @@ -290,26 +325,35 @@ async def test_checkpoint_alist( query_3: dict[str, Any] = {} # search by no keys, return all checkpoints query_4 = {"source": "update", "step": 1} # no match - search_results_1 = [c async for c in 
checkpointer.alist(None, filter=query_1)] + search_results_1 = await run_on_background( + async_engine, consume_alist(None, filter=query_1) + ) assert len(search_results_1) == 1 print(metadata[0]) print(search_results_1[0].metadata) assert search_results_1[0].metadata == metadata[0] - search_results_2 = [c async for c in checkpointer.alist(None, filter=query_2)] + search_results_2 = await run_on_background( + async_engine, consume_alist(None, filter=query_2) + ) assert len(search_results_2) == 1 assert search_results_2[0].metadata == metadata[1] - search_results_3 = [c async for c in checkpointer.alist(None, filter=query_3)] + search_results_3 = await run_on_background( + async_engine, consume_alist(None, filter=query_3) + ) assert len(search_results_3) == 3 - search_results_4 = [c async for c in checkpointer.alist(None, filter=query_4)] + search_results_4 = await run_on_background( + async_engine, consume_alist(None, filter=query_4) + ) assert len(search_results_4) == 0 # search by config (defaults to checkpoints across all namespaces) - search_results_5 = [ - c async for c in checkpointer.alist({"configurable": {"thread_id": "thread-2"}}) - ] + search_results_5 = await run_on_background( + async_engine, + consume_alist({"configurable": {"thread_id": "thread-2"}}, filter=None), + ) assert len(search_results_5) == 2 assert { search_results_5[0].config["configurable"]["checkpoint_ns"], @@ -353,6 +397,7 @@ def _llm_type(self) -> str: @pytest.mark.asyncio async def test_checkpoint_with_agent( + async_engine: PostgresEngine, checkpointer: AsyncPostgresSaver, ) -> None: # from the tests in https://github.com/langchain-ai/langgraph/blob/909190cede6a80bb94a2d4cfe7dedc49ef0d4127/libs/langgraph/tests/test_prebuilt.py @@ -360,8 +405,9 @@ async def test_checkpoint_with_agent( agent = create_react_agent(model, [], checkpointer=checkpointer) inputs = [HumanMessage("hi?")] - response = await agent.ainvoke( - {"messages": inputs}, config=thread_agent_config, debug=True + response 
= await run_on_background( + async_engine, + agent.ainvoke({"messages": inputs}, config=thread_agent_config, debug=True), ) expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]} assert response == expected_response @@ -372,7 +418,9 @@ def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: message.id = AnyStr() return message - saved = await checkpointer.aget_tuple(thread_agent_config) + saved = await run_on_background( + async_engine, checkpointer.aget_tuple(thread_agent_config) + ) assert saved is not None assert ( _AnyIdHumanMessage(content="hi?") @@ -392,6 +440,7 @@ def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: @pytest.mark.asyncio async def test_checkpoint_aget_tuple( + async_engine: PostgresEngine, checkpointer: AsyncPostgresSaver, test_data: dict[str, Any], ) -> None: @@ -399,30 +448,48 @@ async def test_checkpoint_aget_tuple( checkpoints = test_data["checkpoints"] metadata = test_data["metadata"] - new_config = await checkpointer.aput(configs[1], checkpoints[1], metadata[0], {}) + new_config = await run_on_background( + async_engine, checkpointer.aput(configs[1], checkpoints[1], metadata[0], {}) + ) # Matching checkpoint - search_results_1 = await checkpointer.aget_tuple(new_config) + search_results_1 = await run_on_background( + async_engine, checkpointer.aget_tuple(new_config) + ) assert search_results_1.metadata == metadata[0] # type: ignore # No matching checkpoint - assert await checkpointer.aget_tuple(configs[0]) is None + assert ( + await run_on_background(async_engine, checkpointer.aget_tuple(configs[0])) + is None + ) @pytest.mark.asyncio async def test_metadata( + async_engine: PostgresEngine, checkpointer: AsyncPostgresSaver, test_data: dict[str, Any], ) -> None: - config = await checkpointer.aput( - test_data["configs"][0], - test_data["checkpoints"][0], - {"my_key": "abc"}, # type: ignore - {}, + # Wrap aput + config = await run_on_background( + async_engine, + checkpointer.aput( + test_data["configs"][0], + 
test_data["checkpoints"][0], + {"my_key": "abc"}, # type: ignore + {}, + ), + ) + tuple_result = await run_on_background( + async_engine, checkpointer.aget_tuple(config) + ) + assert tuple_result.metadata["my_key"] == "abc" # type: ignore + + async def consume_alist(config, filter): + return [c async for c in checkpointer.alist(config, filter=filter)] + + alist_results = await run_on_background( + async_engine, consume_alist(None, filter={"my_key": "abc"}) ) - assert (await checkpointer.aget_tuple(config)).metadata["my_key"] == "abc" # type: ignore - assert [c async for c in checkpointer.alist(None, filter={"my_key": "abc"})][ - 0 - ].metadata[ - "my_key" # type: ignore - ] == "abc" # type: ignore + assert alist_results[0].metadata["my_key"] == "abc" # type: ignore diff --git a/tests/test_async_loader.py b/tests/test_async_loader.py index c29a82f7..61316519 100644 --- a/tests/test_async_loader.py +++ b/tests/test_async_loader.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import json import os import uuid +from typing import Any, Coroutine import pytest import pytest_asyncio @@ -34,10 +36,23 @@ table_name = "test-table" + str(uuid.uuid4()) +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _action(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _action()) @pytest.mark.asyncio(scope="class") @@ -45,7 +60,6 @@ class TestLoaderAsync: @pytest_asyncio.fixture(scope="class") async def engine(self): - PostgresEngine._connector = None engine = await PostgresEngine.afrom_instance( project_id=project_id, instance=instance_id, @@ -56,37 +70,50 @@ async def engine(self): await engine.close() - async def _collect_async_items(self, docs_generator): - """Collects items from an async generator.""" - docs = [] - async for doc in docs_generator: - docs.append(doc) - return docs + async def _collect_async_items(self, engine, docs_generator): + """Collects items from an async generator, running on background loop.""" + + async def _consume(): + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + return await run_on_background(engine, _consume()) async def _cleanup_table(self, engine): await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') async def test_create_loader_with_invalid_parameters(self, engine): with pytest.raises(ValueError): - await AsyncPostgresLoader.create( - engine=engine, + await run_on_background( + engine, + 
AsyncPostgresLoader.create( + engine=engine, + ), ) with pytest.raises(ValueError): def fake_formatter(): return None - await AsyncPostgresLoader.create( - engine=engine, - table_name=table_name, - format="text", - formatter=fake_formatter, + await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + format="text", + formatter=fake_formatter, + ), ) with pytest.raises(ValueError): - await AsyncPostgresLoader.create( - engine=engine, - table_name=table_name, - format="fake_format", + await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + format="fake_format", + ), ) async def test_load_from_query_default(self, engine): @@ -110,12 +137,15 @@ async def test_load_from_query_default(self, engine): """ await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - table_name=table_name, + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -153,20 +183,23 @@ async def test_load_from_query_customized_content_customized_metadata(self, engi """ await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "fruit_name", - "variety", - "quantity_in_stock", - "price_per_unit", - "organic", - ], - metadata_columns=["fruit_id"], + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await 
self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -205,19 +238,20 @@ async def test_load_from_query_customized_content_default_metadata(self, engine) """ await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ), ) - documents = [] - async for docs in loader.alazy_load(): - documents.append(docs) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -230,18 +264,21 @@ async def test_load_from_query_customized_content_default_metadata(self, engine) ) ] - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], - format="JSON", + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -280,13 +317,16 @@ async def test_load_from_query_default_content_customized_metadata(self, engine) """ await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - metadata_columns=["fruit_name", "organic"], + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + 
metadata_columns=["fruit_name", "organic"], + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -317,16 +357,19 @@ async def test_load_from_query_with_langchain_metadata(self, engine): VALUES ('Apple', 'Granny Smith', 150, 1, '{metadata}');""" await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - metadata_columns=[ - "fruit_name", - "langchain_metadata", - ], + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "fruit_name", + "langchain_metadata", + ], + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -362,15 +405,18 @@ async def test_load_from_query_with_json(self, engine): VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - metadata_columns=[ - "variety", - ], + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -411,18 +457,21 @@ def my_formatter(row, content_columns): str(row[column]) for column in content_columns if column in row ) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], - formatter=my_formatter, + loader = await 
run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -458,18 +507,21 @@ async def test_load_from_query_customized_content_default_metadata_custom_page_c """ await aexecute(engine, insert_query) - loader = await AsyncPostgresLoader.create( - engine=engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], - format="YAML", + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ), ) - documents = await self._collect_async_items(loader.alazy_load()) + documents = await self._collect_async_items(engine, loader.alazy_load()) assert documents == [ Document( @@ -487,7 +539,7 @@ async def test_load_from_query_customized_content_default_metadata_custom_page_c async def test_save_doc_with_default_metadata(self, engine): await self._cleanup_table(engine) - await engine._ainit_document_table(table_name) + await run_on_background(engine, engine._ainit_document_table(table_name)) test_docs = [ Document( page_content="Apple Granny Smith 150 0.99 1", @@ -502,16 +554,21 @@ async def test_save_doc_with_default_metadata(self, engine): metadata={"fruit_id": 3}, ), ] - saver = await AsyncPostgresDocumentSaver.create( - engine=engine, table_name=table_name + saver = await run_on_background( + engine, + AsyncPostgresDocumentSaver.create(engine=engine, table_name=table_name), + ) + loader = await run_on_background( + engine, AsyncPostgresLoader.create(engine=engine, table_name=table_name) ) - 
loader = await AsyncPostgresLoader.create(engine=engine, table_name=table_name) - await saver.aadd_documents(test_docs) - docs = await self._collect_async_items(loader.alazy_load()) + await run_on_background(engine, saver.aadd_documents(test_docs)) + docs = await self._collect_async_items(engine, loader.alazy_load()) assert docs == test_docs - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + + schema = await run_on_background(engine, engine._aload_table_schema(table_name)) + assert schema.columns.keys() == [ "page_content", "langchain_metadata", ] @@ -520,13 +577,16 @@ async def test_save_doc_with_default_metadata(self, engine): @pytest.mark.parametrize("store_metadata", [True, False]) async def test_save_doc_with_customized_metadata(self, engine, store_metadata): table_name = "test-table" + str(uuid.uuid4()) - await engine._ainit_document_table( - table_name, - metadata_columns=[ - Column("fruit_name", "VARCHAR"), - Column("organic", "BOOLEAN"), - ], - store_metadata=store_metadata, + await run_on_background( + engine, + engine._ainit_document_table( + table_name, + metadata_columns=[ + Column("fruit_name", "VARCHAR"), + Column("organic", "BOOLEAN"), + ], + store_metadata=store_metadata, + ), ) test_docs = [ Document( @@ -538,24 +598,30 @@ async def test_save_doc_with_customized_metadata(self, engine, store_metadata): }, ), ] - saver = await AsyncPostgresDocumentSaver.create( - engine=engine, table_name=table_name + saver = await run_on_background( + engine, + AsyncPostgresDocumentSaver.create(engine=engine, table_name=table_name), ) - loader = await AsyncPostgresLoader.create( - engine=engine, - table_name=table_name, - metadata_columns=[ - "fruit_name", - "organic", - ], + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + metadata_columns=[ + "fruit_name", + "organic", + ], + ), ) - await saver.aadd_documents(test_docs) - docs = await 
self._collect_async_items(loader.alazy_load()) + await run_on_background(engine, saver.aadd_documents(test_docs)) + docs = await self._collect_async_items(engine, loader.alazy_load()) + + schema = await run_on_background(engine, engine._aload_table_schema(table_name)) if store_metadata: docs == test_docs - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + assert schema.columns.keys() == [ "page_content", "fruit_name", "organic", @@ -568,7 +634,7 @@ async def test_save_doc_with_customized_metadata(self, engine, store_metadata): metadata={"fruit_name": "Apple", "organic": True}, ), ] - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + assert schema.columns.keys() == [ "page_content", "fruit_name", "organic", @@ -577,7 +643,9 @@ async def test_save_doc_with_customized_metadata(self, engine, store_metadata): async def test_save_doc_without_metadata(self, engine): table_name = "test-table" + str(uuid.uuid4()) - await engine._ainit_document_table(table_name, store_metadata=False) + await run_on_background( + engine, engine._ainit_document_table(table_name, store_metadata=False) + ) test_docs = [ Document( page_content="Granny Smith 150 0.99", @@ -588,17 +656,21 @@ async def test_save_doc_without_metadata(self, engine): }, ), ] - saver = await AsyncPostgresDocumentSaver.create( - engine=engine, table_name=table_name + saver = await run_on_background( + engine, + AsyncPostgresDocumentSaver.create(engine=engine, table_name=table_name), ) - await saver.aadd_documents(test_docs) + await run_on_background(engine, saver.aadd_documents(test_docs)) - loader = await AsyncPostgresLoader.create( - engine=engine, - table_name=table_name, + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + ), ) - docs = await self._collect_async_items(loader.alazy_load()) + docs = await self._collect_async_items(engine, loader.alazy_load()) assert docs == [ Document( @@ -606,14 
+678,15 @@ async def test_save_doc_without_metadata(self, engine): metadata={}, ), ] - assert (await engine._aload_table_schema(table_name)).columns.keys() == [ + schema = await run_on_background(engine, engine._aload_table_schema(table_name)) + assert schema.columns.keys() == [ "page_content", ] await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') async def test_delete_doc_with_default_metadata(self, engine): table_name = "test-table" + str(uuid.uuid4()) - await engine._ainit_document_table(table_name) + await run_on_background(engine, engine._ainit_document_table(table_name)) test_docs = [ Document( @@ -625,37 +698,43 @@ async def test_delete_doc_with_default_metadata(self, engine): metadata={"fruit_id": 2}, ), ] - saver = await AsyncPostgresDocumentSaver.create( - engine=engine, table_name=table_name + saver = await run_on_background( + engine, + AsyncPostgresDocumentSaver.create(engine=engine, table_name=table_name), + ) + loader = await run_on_background( + engine, AsyncPostgresLoader.create(engine=engine, table_name=table_name) ) - loader = await AsyncPostgresLoader.create(engine=engine, table_name=table_name) - await saver.aadd_documents(test_docs) - docs = await self._collect_async_items(loader.alazy_load()) + await run_on_background(engine, saver.aadd_documents(test_docs)) + docs = await self._collect_async_items(engine, loader.alazy_load()) assert docs == test_docs - await saver.adelete(docs[:1]) - assert len(await self._collect_async_items(loader.alazy_load())) == 1 + await run_on_background(engine, saver.adelete(docs[:1])) + assert len(await self._collect_async_items(engine, loader.alazy_load())) == 1 - await saver.adelete(docs) - assert len(await self._collect_async_items(loader.alazy_load())) == 0 + await run_on_background(engine, saver.adelete(docs)) + assert len(await self._collect_async_items(engine, loader.alazy_load())) == 0 await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') async def test_delete_doc_with_query(self, 
engine): await self._cleanup_table(engine) - await engine._ainit_document_table( - table_name, - metadata_columns=[ - Column( - "fruit_name", - "VARCHAR", - ), - Column( - "organic", - "BOOLEAN", - ), - ], - store_metadata=True, + await run_on_background( + engine, + engine._ainit_document_table( + table_name, + metadata_columns=[ + Column( + "fruit_name", + "VARCHAR", + ), + Column( + "organic", + "BOOLEAN", + ), + ], + store_metadata=True, + ), ) test_docs = [ @@ -684,18 +763,21 @@ async def test_delete_doc_with_query(self, engine): }, ), ] - saver = await AsyncPostgresDocumentSaver.create( - engine=engine, table_name=table_name + saver = await run_on_background( + engine, + AsyncPostgresDocumentSaver.create(engine=engine, table_name=table_name), ) query = f"SELECT * FROM \"{table_name}\" WHERE fruit_name='Apple';" - loader = await AsyncPostgresLoader.create(engine=engine, query=query) + loader = await run_on_background( + engine, AsyncPostgresLoader.create(engine=engine, query=query) + ) - await saver.aadd_documents(test_docs) - docs = await self._collect_async_items(loader.alazy_load()) + await run_on_background(engine, saver.aadd_documents(test_docs)) + docs = await self._collect_async_items(engine, loader.alazy_load()) assert len(docs) == 1 - await saver.adelete(docs) - assert len(await self._collect_async_items(loader.alazy_load())) == 0 + await run_on_background(engine, saver.adelete(docs)) + assert len(await self._collect_async_items(engine, loader.alazy_load())) == 0 await self._cleanup_table(engine) @pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"]) @@ -704,14 +786,17 @@ async def test_delete_doc_with_customized_metadata( ): table_name = "test-table" + str(uuid.uuid4()) content_column = "content_col_test" - await engine._ainit_document_table( - table_name, - metadata_columns=[ - Column("fruit_name", "VARCHAR"), - Column("organic", "BOOLEAN"), - ], - content_column=content_column, - metadata_json_column=metadata_json_column, + 
await run_on_background( + engine, + engine._ainit_document_table( + table_name, + metadata_columns=[ + Column("fruit_name", "VARCHAR"), + Column("organic", "BOOLEAN"), + ], + content_column=content_column, + metadata_json_column=metadata_json_column, + ), ) test_docs = [ Document( @@ -731,27 +816,33 @@ async def test_delete_doc_with_customized_metadata( }, ), ] - saver = await AsyncPostgresDocumentSaver.create( - engine=engine, - table_name=table_name, - content_column=content_column, - metadata_json_column=metadata_json_column, + saver = await run_on_background( + engine, + AsyncPostgresDocumentSaver.create( + engine=engine, + table_name=table_name, + content_column=content_column, + metadata_json_column=metadata_json_column, + ), ) - loader = await AsyncPostgresLoader.create( - engine=engine, - table_name=table_name, - content_columns=[content_column], - metadata_json_column=metadata_json_column, + loader = await run_on_background( + engine, + AsyncPostgresLoader.create( + engine=engine, + table_name=table_name, + content_columns=[content_column], + metadata_json_column=metadata_json_column, + ), ) - await saver.aadd_documents(test_docs) + await run_on_background(engine, saver.aadd_documents(test_docs)) - docs = await loader.aload() + docs = await run_on_background(engine, loader.aload()) assert len(docs) == 2 - await saver.adelete(docs[:1]) - assert len(await self._collect_async_items(loader.alazy_load())) == 1 + await run_on_background(engine, saver.adelete(docs[:1])) + assert len(await self._collect_async_items(engine, loader.alazy_load())) == 1 - await saver.adelete(docs) - assert len(await self._collect_async_items(loader.alazy_load())) == 0 + await run_on_background(engine, saver.adelete(docs)) + assert len(await self._collect_async_items(engine, loader.alazy_load())) == 0 await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index d0e85d0b..6bcd58f5 100644 --- 
a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -50,18 +51,35 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + # Run on background loop + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + return result_map.fetchall() + + # Run on background loop + return await run_on_background(engine, _impl()) @pytest.mark.asyncio(scope="class") @@ -98,34 +116,50 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) - vs = await AsyncPostgresVectorStore.create( + # Wrap private init method + await run_on_background( + 
engine, engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + ) + # Wrap creation of the async vectorstore + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=DEFAULT_TABLE, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + ), ) yield vs @pytest_asyncio.fixture(scope="class") async def vs_custom(self, engine): - await engine._ainit_vectorstore_table( - CUSTOM_TABLE, - VECTOR_SIZE, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], - metadata_json_column="mymeta", + # Wrap private init method + await run_on_background( + engine, + engine._ainit_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + metadata_json_column="mymeta", + ), ) - vs = await AsyncPostgresVectorStore.create( + + # Wrap creation of the async vectorstore + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["page", "source"], - metadata_json_column="mymeta", + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + metadata_json_column="mymeta", + ), ) yield vs @@ -144,32 +178,44 @@ async def test_init_with_constructor(self, engine): async def test_post_init(self, engine): with pytest.raises(ValueError): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="noname", - 
embedding_column="myembedding", - metadata_columns=["page", "source"], - metadata_json_column="mymeta", + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="noname", + embedding_column="myembedding", + metadata_columns=["page", "source"], + metadata_json_column="mymeta", + ), ) async def test_id_metadata_column(self, engine): table_name = "id_metadata" + str(uuid.uuid4()) - await engine._ainit_vectorstore_table( - table_name, - VECTOR_SIZE, - metadata_columns=[Column("id", "TEXT")], + await run_on_background( + engine, + engine._ainit_vectorstore_table( + table_name, + VECTOR_SIZE, + metadata_columns=[Column("id", "TEXT")], + ), ) - custom_vs = await AsyncPostgresVectorStore.create( + custom_vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=table_name, - metadata_columns=["id"], + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=table_name, + metadata_columns=["id"], + ), ) ids = [str(uuid.uuid4()) for i in range(len(texts))] - await custom_vs.aadd_texts(texts, id_column_as_metadata, ids) + # Wrap aadd_texts + await run_on_background( + engine, custom_vs.aadd_texts(texts, id_column_as_metadata, ids) + ) results = await afetch(engine, f'SELECT * FROM "{table_name}"') assert len(results) == 3 @@ -180,12 +226,14 @@ async def test_id_metadata_column(self, engine): async def test_aadd_texts(self, engine, vs): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs.aadd_texts(texts, ids=ids) + # Wrap aadd_texts + await run_on_background(engine, vs.aadd_texts(texts, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs.aadd_texts(texts, metadatas, ids) + # Wrap aadd_texts + await run_on_background(engine, vs.aadd_texts(texts, metadatas, ids)) results = await afetch(engine, 
f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 6 await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') @@ -193,42 +241,43 @@ async def test_aadd_texts(self, engine, vs): async def test_aadd_texts_edge_cases(self, engine, vs): texts = ["Taylor's", '"Swift"', "best-friend"] ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs.aadd_texts(texts, ids=ids) + # Wrap aadd_texts + await run_on_background(engine, vs.aadd_texts(texts, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') async def test_aadd_docs(self, engine, vs): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs.aadd_documents(docs, ids=ids) + # Wrap aadd_documents + await run_on_background(engine, vs.aadd_documents(docs, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') async def test_aadd_docs_no_ids(self, engine, vs): - await vs.aadd_documents(docs) + # Wrap aadd_documents + await run_on_background(engine, vs.aadd_documents(docs)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') async def test_adelete(self, engine, vs): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs.aadd_texts(texts, ids=ids) + await run_on_background(engine, vs.aadd_texts(texts, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 - # delete an ID - await vs.adelete([ids[0]]) + await run_on_background(engine, vs.adelete([ids[0]])) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 2 - # delete with no ids - result = await vs.adelete() + result = await run_on_background(engine, vs.adelete()) assert result == False ##### Custom Vector Store ##### async def 
test_aadd_texts_custom(self, engine, vs_custom): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs_custom.aadd_texts(texts, ids=ids) + await run_on_background(engine, vs_custom.aadd_texts(texts, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') assert len(results) == 3 assert results[0]["mycontent"] == "foo" @@ -237,7 +286,7 @@ async def test_aadd_texts_custom(self, engine, vs_custom): assert results[0]["source"] is None ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs_custom.aadd_texts(texts, metadatas, ids) + await run_on_background(engine, vs_custom.aadd_texts(texts, metadatas, ids)) results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') assert len(results) == 6 await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') @@ -251,7 +300,7 @@ async def test_aadd_docs_custom(self, engine, vs_custom): ) for i in range(len(texts)) ] - await vs_custom.aadd_documents(docs, ids=ids) + await run_on_background(engine, vs_custom.aadd_documents(docs, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') assert len(results) == 3 @@ -263,13 +312,12 @@ async def test_aadd_docs_custom(self, engine, vs_custom): async def test_adelete_custom(self, engine, vs_custom): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await vs_custom.aadd_texts(texts, ids=ids) + await run_on_background(engine, vs_custom.aadd_texts(texts, ids=ids)) results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') content = [result["mycontent"] for result in results] assert len(results) == 3 assert "foo" in content - # delete an ID - await vs_custom.adelete([ids[0]]) + await run_on_background(engine, vs_custom.adelete([ids[0]])) results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') content = [result["mycontent"] for result in results] assert len(results) == 2 @@ -277,90 +325,111 @@ async def test_adelete_custom(self, engine, vs_custom): async def test_ignore_metadata_columns(self, engine): 
column_to_ignore = "source" - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - ignore_metadata_columns=[column_to_ignore], - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_json_column="mymeta", - ) - assert column_to_ignore not in vs.metadata_columns - - async def test_create_vectorstore_with_invalid_parameters_1(self, engine): - with pytest.raises(ValueError): - await AsyncPostgresVectorStore.create( + AsyncPostgresVectorStore.create( engine, embedding_service=embeddings_service, table_name=CUSTOM_TABLE, + ignore_metadata_columns=[column_to_ignore], id_column="myid", content_column="mycontent", embedding_column="myembedding", - metadata_columns=["random_column"], # invalid metadata column + metadata_json_column="mymeta", + ), + ) + assert column_to_ignore not in vs.metadata_columns + + async def test_create_vectorstore_with_invalid_parameters_1(self, engine): + with pytest.raises(ValueError): + await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["random_column"], # invalid metadata column + ), ) async def test_create_vectorstore_with_invalid_parameters_2(self, engine): with pytest.raises(ValueError): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="langchain_id", # invalid content column type - embedding_column="myembedding", - metadata_columns=["random_column"], + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="langchain_id", # invalid content column type + embedding_column="myembedding", + 
metadata_columns=["random_column"], + ), ) async def test_create_vectorstore_with_invalid_parameters_3(self, engine): with pytest.raises(ValueError): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="mycontent", - embedding_column="random_column", # invalid embedding column - metadata_columns=["random_column"], + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="random_column", # invalid embedding column + metadata_columns=["random_column"], + ), ) async def test_create_vectorstore_with_invalid_parameters_4(self, engine): with pytest.raises(ValueError): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="mycontent", - embedding_column="langchain_id", # invalid embedding column data type - metadata_columns=["random_column"], + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="langchain_id", # invalid embedding column data type + metadata_columns=["random_column"], + ), ) async def test_create_vectorstore_with_invalid_parameters_5(self, engine): with pytest.raises(ValueError): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="mycontent", - embedding_column="langchain_id", - metadata_columns=["random_column"], - ignore_metadata_columns=[ - "one", - "two", - ], # invalid use of metadata_columns and ignore columns + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + 
id_column="myid", + content_column="mycontent", + embedding_column="langchain_id", + metadata_columns=["random_column"], + ignore_metadata_columns=[ + "one", + "two", + ], # invalid use of metadata_columns and ignore columns + ), ) async def test_create_vectorstore_with_init(self, engine): with pytest.raises(Exception): - await AsyncPostgresVectorStore( - engine._pool, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["random_column"], # invalid metadata column + await run_on_background( + engine, + AsyncPostgresVectorStore( + engine._pool, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["random_column"], # invalid metadata column + ), ) diff --git a/tests/test_async_vectorstore_from_methods.py b/tests/test_async_vectorstore_from_methods.py index 529675c2..aeba3995 100644 --- a/tests/test_async_vectorstore_from_methods.py +++ b/tests/test_async_vectorstore_from_methods.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import os import uuid -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -51,18 +52,33 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + return result_map.fetchall() + + return await run_on_background(engine, _impl()) @pytest.mark.asyncio @@ -91,24 +107,34 @@ async def engine(self, db_project, db_region, db_instance, db_name): region=db_region, database=db_name, ) - await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) - await engine._ainit_vectorstore_table( - CUSTOM_TABLE, - VECTOR_SIZE, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], - store_metadata=False, + await run_on_background( + engine, engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + ) + await run_on_background( + engine, + 
engine._ainit_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=False, + ), ) - await engine._ainit_vectorstore_table( - CUSTOM_TABLE_WITH_INT_ID, - VECTOR_SIZE, - id_column=Column(name="integer_id", data_type="INTEGER", nullable="False"), - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], - store_metadata=False, + await run_on_background( + engine, + engine._ainit_vectorstore_table( + CUSTOM_TABLE_WITH_INT_ID, + VECTOR_SIZE, + id_column=Column( + name="integer_id", data_type="INTEGER", nullable="False" + ), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=False, + ), ) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") @@ -118,13 +144,16 @@ async def engine(self, db_project, db_region, db_instance, db_name): async def test_afrom_texts(self, engine): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await AsyncPostgresVectorStore.afrom_texts( - texts, - embeddings_service, + await run_on_background( engine, - DEFAULT_TABLE, - metadatas=metadatas, - ids=ids, + AsyncPostgresVectorStore.afrom_texts( + texts, + embeddings_service, + engine, + DEFAULT_TABLE, + metadatas=metadatas, + ids=ids, + ), ) results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") assert len(results) == 3 @@ -132,12 +161,15 @@ async def test_afrom_texts(self, engine): async def test_afrom_docs(self, engine): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await AsyncPostgresVectorStore.afrom_documents( - docs, - embeddings_service, + await run_on_background( engine, - DEFAULT_TABLE, - ids=ids, + AsyncPostgresVectorStore.afrom_documents( + docs, + embeddings_service, + engine, + DEFAULT_TABLE, + 
ids=ids, + ), ) results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") assert len(results) == 3 @@ -145,16 +177,19 @@ async def test_afrom_docs(self, engine): async def test_afrom_texts_custom(self, engine): ids = [str(uuid.uuid4()) for i in range(len(texts))] - await AsyncPostgresVectorStore.afrom_texts( - texts, - embeddings_service, + await run_on_background( engine, - CUSTOM_TABLE, - ids=ids, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["page", "source"], + AsyncPostgresVectorStore.afrom_texts( + texts, + embeddings_service, + engine, + CUSTOM_TABLE, + ids=ids, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + ), ) results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") assert len(results) == 3 @@ -172,16 +207,19 @@ async def test_afrom_docs_custom(self, engine): ) for i in range(len(texts)) ] - await AsyncPostgresVectorStore.afrom_documents( - docs, - embeddings_service, + await run_on_background( engine, - CUSTOM_TABLE, - ids=ids, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["page", "source"], + AsyncPostgresVectorStore.afrom_documents( + docs, + embeddings_service, + engine, + CUSTOM_TABLE, + ids=ids, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + ), ) results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") @@ -201,16 +239,19 @@ async def test_afrom_docs_custom_with_int_id(self, engine): ) for i in range(len(texts)) ] - await AsyncPostgresVectorStore.afrom_documents( - docs, - embeddings_service, + await run_on_background( engine, - CUSTOM_TABLE_WITH_INT_ID, - ids=ids, - id_column="integer_id", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["page", "source"], + AsyncPostgresVectorStore.afrom_documents( + docs, + 
embeddings_service, + engine, + CUSTOM_TABLE_WITH_INT_ID, + ids=ids, + id_column="integer_id", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + ), ) results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}") diff --git a/tests/test_async_vectorstore_index.py b/tests/test_async_vectorstore_index.py index d45e114f..be61a9fa 100644 --- a/tests/test_async_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -13,8 +13,10 @@ # limitations under the License. +import asyncio import os import uuid +from typing import Any, Coroutine import pytest import pytest_asyncio @@ -60,10 +62,23 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) @pytest.mark.asyncio(scope="class") @@ -100,74 +115,90 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine._ainit_vectorstore_table( - DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + await run_on_background( + engine, + engine._ainit_vectorstore_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ), ) - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=DEFAULT_TABLE, + 
AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + ), ) - await vs.aadd_texts(texts, ids=ids) - await vs.adrop_vector_index() + await run_on_background(engine, vs.aadd_texts(texts, ids=ids)) + await run_on_background(engine, vs.adrop_vector_index()) yield vs async def test_apply_default_name_vector_index(self, engine): - await engine._ainit_vectorstore_table( - SIMPLE_TABLE, VECTOR_SIZE, overwrite_existing=True + await run_on_background( + engine, + engine._ainit_vectorstore_table( + SIMPLE_TABLE, VECTOR_SIZE, overwrite_existing=True + ), ) - vs = await AsyncPostgresVectorStore.create( + + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=SIMPLE_TABLE, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=SIMPLE_TABLE, + ), ) - await vs.aadd_texts(texts, ids=ids) - await vs.adrop_vector_index() + await run_on_background(engine, vs.aadd_texts(texts, ids=ids)) + await run_on_background(engine, vs.adrop_vector_index()) + index = HNSWIndex() - await vs.aapply_vector_index(index) - assert await vs.is_valid_index() - await vs.adrop_vector_index() + await run_on_background(engine, vs.aapply_vector_index(index)) + assert await run_on_background(engine, vs.is_valid_index()) + await run_on_background(engine, vs.adrop_vector_index()) - async def test_aapply_vector_index(self, vs): - await vs.adrop_vector_index(DEFAULT_INDEX_NAME) + async def test_aapply_vector_index(self, engine, vs): + await run_on_background(engine, vs.adrop_vector_index(DEFAULT_INDEX_NAME)) index = HNSWIndex(name=DEFAULT_INDEX_NAME) - await vs.aapply_vector_index(index) - assert await vs.is_valid_index(DEFAULT_INDEX_NAME) - await vs.adrop_vector_index() + await run_on_background(engine, vs.aapply_vector_index(index)) + assert await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)) + await run_on_background(engine, vs.adrop_vector_index()) - async 
def test_areindex(self, vs): - if not await vs.is_valid_index(DEFAULT_INDEX_NAME): + async def test_areindex(self, engine, vs): + if not await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)): index = HNSWIndex() - await vs.aapply_vector_index(index) - await vs.areindex(DEFAULT_INDEX_NAME) - await vs.areindex(DEFAULT_INDEX_NAME) - assert await vs.is_valid_index(DEFAULT_INDEX_NAME) - await vs.adrop_vector_index() - - async def test_dropindex(self, vs): - await vs.adrop_vector_index(DEFAULT_INDEX_NAME) - result = await vs.is_valid_index(DEFAULT_INDEX_NAME) + await run_on_background(engine, vs.aapply_vector_index(index)) + await run_on_background(engine, vs.areindex(DEFAULT_INDEX_NAME)) + await run_on_background(engine, vs.areindex(DEFAULT_INDEX_NAME)) + assert await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)) + await run_on_background(engine, vs.adrop_vector_index()) + + async def test_dropindex(self, engine, vs): + await run_on_background(engine, vs.adrop_vector_index(DEFAULT_INDEX_NAME)) + result = await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)) assert not result - async def test_aapply_vector_index_ivfflat(self, vs): - await vs.adrop_vector_index(DEFAULT_INDEX_NAME) + async def test_aapply_vector_index_ivfflat(self, engine, vs): + await run_on_background(engine, vs.adrop_vector_index(DEFAULT_INDEX_NAME)) index = IVFFlatIndex( name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN ) - await vs.aapply_vector_index(index, concurrently=True) - assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await run_on_background( + engine, vs.aapply_vector_index(index, concurrently=True) + ) + assert await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)) index = IVFFlatIndex( name="secondindex", distance_strategy=DistanceStrategy.INNER_PRODUCT, ) - await vs.aapply_vector_index(index) - assert await vs.is_valid_index("secondindex") - await vs.adrop_vector_index("secondindex") - await 
vs.adrop_vector_index(DEFAULT_INDEX_NAME) + await run_on_background(engine, vs.aapply_vector_index(index)) + assert await run_on_background(engine, vs.is_valid_index("secondindex")) + await run_on_background(engine, vs.adrop_vector_index("secondindex")) + await run_on_background(engine, vs.adrop_vector_index(DEFAULT_INDEX_NAME)) - async def test_is_valid_index(self, vs): - is_valid = await vs.is_valid_index("invalid_index") + async def test_is_valid_index(self, engine, vs): + is_valid = await run_on_background(engine, vs.is_valid_index("invalid_index")) assert is_valid == False async def test_aapply_hybrid_search_index_table_without_tsv_column( @@ -175,18 +206,25 @@ async def test_aapply_hybrid_search_index_table_without_tsv_column( ): # overwriting vs to get a hybrid vs tsv_index_name = "index_without_tsv_column_" + UUID_STR - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=DEFAULT_TABLE, - hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name), + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name), + ), + ) + is_valid_index = await run_on_background( + engine, vs.is_valid_index(tsv_index_name) ) - is_valid_index = await vs.is_valid_index(tsv_index_name) assert is_valid_index == False - await vs.aapply_hybrid_search_index() - assert await vs.is_valid_index(tsv_index_name) - await vs.adrop_vector_index(tsv_index_name) - is_valid_index = await vs.is_valid_index(tsv_index_name) + await run_on_background(engine, vs.aapply_hybrid_search_index()) + assert await run_on_background(engine, vs.is_valid_index(tsv_index_name)) + await run_on_background(engine, vs.adrop_vector_index(tsv_index_name)) + is_valid_index = await run_on_background( + engine, vs.is_valid_index(tsv_index_name) + ) assert is_valid_index == False async def 
test_aapply_hybrid_search_index_table_with_tsv_column(self, engine): @@ -196,23 +234,34 @@ async def test_aapply_hybrid_search_index_table_with_tsv_column(self, engine): tsv_lang="pg_catalog.english", index_name=tsv_index_name, ) - await engine._ainit_vectorstore_table( - DEFAULT_HYBRID_TABLE, - VECTOR_SIZE, - hybrid_search_config=config, + await run_on_background( + engine, + engine._ainit_vectorstore_table( + DEFAULT_HYBRID_TABLE, + VECTOR_SIZE, + hybrid_search_config=config, + ), ) - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=DEFAULT_HYBRID_TABLE, - hybrid_search_config=config, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_HYBRID_TABLE, + hybrid_search_config=config, + ), + ) + + is_valid_index = await run_on_background( + engine, vs.is_valid_index(tsv_index_name) ) - is_valid_index = await vs.is_valid_index(tsv_index_name) assert is_valid_index == False - await vs.aapply_hybrid_search_index() - assert await vs.is_valid_index(tsv_index_name) - await vs.areindex(tsv_index_name) - assert await vs.is_valid_index(tsv_index_name) - await vs.adrop_vector_index(tsv_index_name) - is_valid_index = await vs.is_valid_index(tsv_index_name) + await run_on_background(engine, vs.aapply_hybrid_search_index()) + assert await run_on_background(engine, vs.is_valid_index(tsv_index_name)) + await run_on_background(engine, vs.areindex(tsv_index_name)) + assert await run_on_background(engine, vs.is_valid_index(tsv_index_name)) + await run_on_background(engine, vs.adrop_vector_index(tsv_index_name)) + is_valid_index = await run_on_background( + engine, vs.is_valid_index(tsv_index_name) + ) assert is_valid_index == False diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index 7e4effdf..16a63911 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ 
-12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid +from typing import Any, Coroutine import pytest import pytest_asyncio @@ -73,13 +75,26 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute( engine: PostgresEngine, query: str, ) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) @pytest.mark.asyncio(scope="class") @@ -118,78 +133,98 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine._ainit_vectorstore_table( - DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False + await run_on_background( + engine, + engine._ainit_vectorstore_table( + DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False + ), ) - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=DEFAULT_TABLE, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + ), ) - await vs.aadd_documents(docs, ids=ids) + await run_on_background(engine, vs.aadd_documents(docs, ids=ids)) yield vs @pytest_asyncio.fixture(scope="class") async def vs_custom(self, engine): - await engine._ainit_vectorstore_table( - CUSTOM_TABLE, - VECTOR_SIZE, - id_column="myid", - content_column="mycontent", - 
embedding_column="myembedding", - metadata_columns=[ - Column("page", "TEXT"), - Column("source", "TEXT"), - ], - store_metadata=False, - ) - - vs_custom = await AsyncPostgresVectorStore.create( - engine, - embedding_service=embeddings_service, - table_name=CUSTOM_TABLE, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - index_query_options=HNSWQueryOptions(ef_search=1), - ) - await vs_custom.aadd_documents(docs, ids=ids) + await run_on_background( + engine, + engine._ainit_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + ], + store_metadata=False, + ), + ) + + vs_custom = await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + index_query_options=HNSWQueryOptions(ef_search=1), + ), + ) + await run_on_background(engine, vs_custom.aadd_documents(docs, ids=ids)) yield vs_custom @pytest_asyncio.fixture(scope="class") async def vs_custom_filter(self, engine): - await engine._ainit_vectorstore_table( - CUSTOM_FILTER_TABLE, - VECTOR_SIZE, - metadata_columns=[ - Column("name", "TEXT"), - Column("code", "TEXT"), - Column("price", "FLOAT"), - Column("is_available", "BOOLEAN"), - Column("tags", "TEXT[]"), - Column("inventory_location", "INTEGER[]"), - Column("available_quantity", "INTEGER", nullable=True), - ], - id_column="langchain_id", - store_metadata=False, - ) - - vs_custom_filter = await AsyncPostgresVectorStore.create( - engine, - embedding_service=embeddings_service, - table_name=CUSTOM_FILTER_TABLE, - metadata_columns=[ - "name", - "code", - "price", - "is_available", - "tags", - "inventory_location", - "available_quantity", - ], - id_column="langchain_id", - ) - await 
vs_custom_filter.aadd_documents(filter_docs, ids=ids) + await run_on_background( + engine, + engine._ainit_vectorstore_table( + CUSTOM_FILTER_TABLE, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", "BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + ), + ) + + vs_custom_filter = await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ), + ) + await run_on_background( + engine, vs_custom_filter.aadd_documents(filter_docs, ids=ids) + ) yield vs_custom_filter @pytest_asyncio.fixture(scope="class") @@ -204,188 +239,239 @@ async def vs_hybrid_search_with_tsv_column(self, engine): "fetch_top_k": 10, }, ) - await engine._ainit_vectorstore_table( - HYBRID_SEARCH_TABLE1, - VECTOR_SIZE, - id_column=Column("myid", "TEXT"), - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=[ - Column("page", "TEXT"), - Column("source", "TEXT"), - Column("doc_id_key", "TEXT"), - ], - metadata_json_column="mymetadata", # ignored - store_metadata=False, - hybrid_search_config=hybrid_search_config, - ) - - vs_custom = await AsyncPostgresVectorStore.create( - engine, - embedding_service=embeddings_service, - table_name=HYBRID_SEARCH_TABLE1, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_json_column="mymetadata", - metadata_columns=["doc_id_key"], - index_query_options=HNSWQueryOptions(ef_search=1), - hybrid_search_config=hybrid_search_config, - ) - await vs_custom.aadd_documents(hybrid_docs) + await run_on_background( + 
engine, + engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE1, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + metadata_json_column="mymetadata", # ignored + store_metadata=False, + hybrid_search_config=hybrid_search_config, + ), + ) + + vs_custom = await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE1, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_json_column="mymetadata", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=hybrid_search_config, + ), + ) + await run_on_background(engine, vs_custom.aadd_documents(hybrid_docs)) yield vs_custom - async def test_asimilarity_search(self, vs): - results = await vs.asimilarity_search("foo", k=1) + async def test_asimilarity_search(self, engine, vs): + results = await run_on_background(engine, vs.asimilarity_search("foo", k=1)) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search("foo", k=1, filter={"content": "bar"}) + results = await run_on_background( + engine, vs.asimilarity_search("foo", k=1, filter={"content": "bar"}) + ) assert results == [Document(page_content="bar", id=ids[1])] - async def test_asimilarity_search_score(self, vs): - results = await vs.asimilarity_search_with_score("foo") + async def test_asimilarity_search_score(self, engine, vs): + results = await run_on_background( + engine, vs.asimilarity_search_with_score("foo") + ) assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_asimilarity_search_by_vector(self, vs): + async def test_asimilarity_search_by_vector(self, 
engine, vs): embedding = embeddings_service.embed_query("foo") - results = await vs.asimilarity_search_by_vector(embedding) + results = await run_on_background( + engine, vs.asimilarity_search_by_vector(embedding) + ) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - results = await vs.asimilarity_search_with_score_by_vector(embedding) + results = await run_on_background( + engine, vs.asimilarity_search_with_score_by_vector(embedding) + ) assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs): + async def test_similarity_search_with_relevance_scores_threshold_cosine( + self, engine, vs + ): score_threshold = {"score_threshold": 0} - results = await vs.asimilarity_search_with_relevance_scores( - "foo", **score_threshold + results = await run_on_background( + engine, + vs.asimilarity_search_with_relevance_scores("foo", **score_threshold), ) # Note: Since tests use FakeEmbeddings which are non-normalized vectors, results might have scores beyond the range [0,1]. # For a normalized embedding service, a threshold of zero will yield all matched documents. 
assert len(results) == 2 score_threshold = {"score_threshold": 0.02} - results = await vs.asimilarity_search_with_relevance_scores( - "foo", **score_threshold + results = await run_on_background( + engine, + vs.asimilarity_search_with_relevance_scores("foo", **score_threshold), ) assert len(results) == 2 score_threshold = {"score_threshold": 0.9} - results = await vs.asimilarity_search_with_relevance_scores( - "foo", **score_threshold + results = await run_on_background( + engine, + vs.asimilarity_search_with_relevance_scores("foo", **score_threshold), ) assert len(results) == 1 assert results[0][0] == Document(page_content="foo", id=ids[0]) score_threshold = {"score_threshold": 0.02} vs.distance_strategy = DistanceStrategy.EUCLIDEAN - results = await vs.asimilarity_search_with_relevance_scores( - "foo", **score_threshold + results = await run_on_background( + engine, + vs.asimilarity_search_with_relevance_scores("foo", **score_threshold), ) assert len(results) == 1 async def test_similarity_search_with_relevance_scores_threshold_euclidean( self, engine ): - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=DEFAULT_TABLE, - distance_strategy=DistanceStrategy.EUCLIDEAN, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + distance_strategy=DistanceStrategy.EUCLIDEAN, + ), ) score_threshold = {"score_threshold": 0.9} - results = await vs.asimilarity_search_with_relevance_scores( - "foo", **score_threshold + results = await run_on_background( + engine, + vs.asimilarity_search_with_relevance_scores("foo", **score_threshold), ) assert len(results) == 1 assert results[0][0] == Document(page_content="foo", id=ids[0]) - async def test_amax_marginal_relevance_search(self, vs): - results = await vs.amax_marginal_relevance_search("bar") + async def test_amax_marginal_relevance_search(self, engine, vs): + results = await 
run_on_background( + engine, vs.amax_marginal_relevance_search("bar") + ) assert results[0] == Document(page_content="bar", id=ids[1]) - results = await vs.amax_marginal_relevance_search( - "bar", filter={"content": "boo"} + results = await run_on_background( + engine, vs.amax_marginal_relevance_search("bar", filter={"content": "boo"}) ) assert results[0] == Document(page_content="boo", id=ids[3]) - async def test_amax_marginal_relevance_search_vector(self, vs): + async def test_amax_marginal_relevance_search_vector(self, engine, vs): embedding = embeddings_service.embed_query("bar") - results = await vs.amax_marginal_relevance_search_by_vector(embedding) + results = await run_on_background( + engine, vs.amax_marginal_relevance_search_by_vector(embedding) + ) assert results[0] == Document(page_content="bar", id=ids[1]) - async def test_amax_marginal_relevance_search_vector_score(self, vs): + async def test_amax_marginal_relevance_search_vector_score(self, engine, vs): embedding = embeddings_service.embed_query("bar") - results = await vs.amax_marginal_relevance_search_with_score_by_vector( - embedding + results = await run_on_background( + engine, vs.amax_marginal_relevance_search_with_score_by_vector(embedding) ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - results = await vs.amax_marginal_relevance_search_with_score_by_vector( - embedding, lambda_mult=0.75, fetch_k=10 + results = await run_on_background( + engine, + vs.amax_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 + ), ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - async def test_similarity_search(self, vs_custom): - results = await vs_custom.asimilarity_search("foo", k=1) + async def test_similarity_search(self, engine, vs_custom): + results = await run_on_background( + engine, vs_custom.asimilarity_search("foo", k=1) + ) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = await 
vs_custom.asimilarity_search( - "foo", k=1, filter={"mycontent": "bar"} + results = await run_on_background( + engine, + vs_custom.asimilarity_search("foo", k=1, filter={"mycontent": "bar"}), ) assert results == [Document(page_content="bar", id=ids[1])] - async def test_similarity_search_score(self, vs_custom): - results = await vs_custom.asimilarity_search_with_score("foo") + async def test_similarity_search_score(self, engine, vs_custom): + results = await run_on_background( + engine, vs_custom.asimilarity_search_with_score("foo") + ) assert len(results) == 4 assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_similarity_search_by_vector(self, vs_custom): + async def test_similarity_search_by_vector(self, engine, vs_custom): embedding = embeddings_service.embed_query("foo") - results = await vs_custom.asimilarity_search_by_vector(embedding) + results = await run_on_background( + engine, vs_custom.asimilarity_search_by_vector(embedding) + ) assert len(results) == 4 assert results[0] == Document(page_content="foo", id=ids[0]) - results = await vs_custom.asimilarity_search_with_score_by_vector(embedding) + results = await run_on_background( + engine, vs_custom.asimilarity_search_with_score_by_vector(embedding) + ) assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 - async def test_max_marginal_relevance_search(self, vs_custom): - results = await vs_custom.amax_marginal_relevance_search("bar") + async def test_max_marginal_relevance_search(self, engine, vs_custom): + results = await run_on_background( + engine, vs_custom.amax_marginal_relevance_search("bar") + ) assert results[0] == Document(page_content="bar", id=ids[1]) - results = await vs_custom.amax_marginal_relevance_search( - "bar", filter={"mycontent": "boo"} + results = await run_on_background( + engine, + vs_custom.amax_marginal_relevance_search( + "bar", filter={"mycontent": "boo"} + ), ) assert results[0] == 
Document(page_content="boo", id=ids[3]) - async def test_max_marginal_relevance_search_vector(self, vs_custom): + async def test_max_marginal_relevance_search_vector(self, engine, vs_custom): embedding = embeddings_service.embed_query("bar") - results = await vs_custom.amax_marginal_relevance_search_by_vector(embedding) + results = await run_on_background( + engine, vs_custom.amax_marginal_relevance_search_by_vector(embedding) + ) assert results[0] == Document(page_content="bar", id=ids[1]) - async def test_max_marginal_relevance_search_vector_score(self, vs_custom): + async def test_max_marginal_relevance_search_vector_score(self, engine, vs_custom): embedding = embeddings_service.embed_query("bar") - results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( - embedding + results = await run_on_background( + engine, + vs_custom.amax_marginal_relevance_search_with_score_by_vector(embedding), ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( - embedding, lambda_mult=0.75, fetch_k=10 + results = await run_on_background( + engine, + vs_custom.amax_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 + ), ) assert results[0][0] == Document(page_content="bar", id=ids[1]) - async def test_aget_by_ids(self, vs): + async def test_aget_by_ids(self, engine, vs): test_ids = [ids[0]] - results = await vs.aget_by_ids(ids=test_ids) + results = await run_on_background(engine, vs.aget_by_ids(ids=test_ids)) assert results[0] == Document(page_content="foo", id=ids[0]) - async def test_aget_by_ids_custom_vs(self, vs_custom): + async def test_aget_by_ids_custom_vs(self, engine, vs_custom): test_ids = [ids[0]] - results = await vs_custom.aget_by_ids(ids=test_ids) + results = await run_on_background(engine, vs_custom.aget_by_ids(ids=test_ids)) assert results[0] == Document(page_content="foo", id=ids[0]) @@ -397,45 +483,52 @@ def 
test_get_by_ids(self, vs): @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) async def test_vectorstore_with_metadata_filters( self, + engine, vs_custom_filter, test_filter, expected_ids, ): """Test end to end construction and search.""" - docs = await vs_custom_filter.asimilarity_search( - "meow", k=5, filter=test_filter + docs = await run_on_background( + engine, vs_custom_filter.asimilarity_search("meow", k=5, filter=test_filter) ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter - async def test_asimilarity_hybrid_search_rrk(self, vs): - results = await vs.asimilarity_search( - "foo", - k=1, - hybrid_search_config=HybridSearchConfig( - fusion_function=reciprocal_rank_fusion + async def test_asimilarity_hybrid_search_rrk(self, engine, vs): + results = await run_on_background( + engine, + vs.asimilarity_search( + "foo", + k=1, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), ), ) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search( - "bar", - k=1, - filter={"content": {"$ne": "baz"}}, - hybrid_search_config=HybridSearchConfig( - fusion_function=reciprocal_rank_fusion, - fusion_function_parameters={ - "rrf_k": 100, - "fetch_top_k": 10, - }, - primary_top_k=1, - secondary_top_k=1, + results = await run_on_background( + engine, + vs.asimilarity_search( + "bar", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 100, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), ), ) assert results == [Document(page_content="bar", id=ids[1])] async def test_hybrid_search_weighted_sum_default( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test hybrid search with default weighted sum (0.5 vector, 0.5 FTS).""" query = "apple" # 
Should match "apple" in FTS and vector @@ -443,10 +536,9 @@ async def test_hybrid_search_weighted_sum_default( # The vs_hybrid_search_with_tsv_column instance is already configured for hybrid search. # Default fusion is weighted_sum_ranking with 0.5/0.5 weights. # fts_query will default to the main query. - results_with_scores = ( - await vs_hybrid_search_with_tsv_column.asimilarity_search_with_score( - query, k=3 - ) + results_with_scores = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search_with_score(query, k=3), ) assert len(results_with_scores) > 1 @@ -463,7 +555,7 @@ async def test_hybrid_search_weighted_sum_default( assert results_with_scores[0][1] >= results_with_scores[1][1] async def test_hybrid_search_weighted_sum_vector_bias( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test weighted sum with higher weight for vector results.""" query = "Apple Inc technology" # More specific for vector similarity @@ -476,16 +568,19 @@ async def test_hybrid_search_weighted_sum_vector_bias( }, # fts_query will default to main query ) - results = await vs_hybrid_search_with_tsv_column.asimilarity_search( - query, k=2, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] assert len(result_ids) > 0 - assert result_ids[0] == "hs_doc_orange_fruit" + assert result_ids[0] == "hs_doc_generic_tech" async def test_hybrid_search_weighted_sum_fts_bias( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test weighted sum with higher weight for FTS results.""" query = "fruit common tasty" # Strong FTS signal for fruit docs @@ -498,8 +593,11 @@ async def test_hybrid_search_weighted_sum_fts_bias( "secondary_results_weight": 0.99, # FTS bias }, ) - results = await 
vs_hybrid_search_with_tsv_column.asimilarity_search( - query, k=2, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] @@ -507,7 +605,7 @@ async def test_hybrid_search_weighted_sum_fts_bias( assert "hs_doc_apple_fruit" in result_ids async def test_hybrid_search_reciprocal_rank_fusion( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test hybrid search with Reciprocal Rank Fusion.""" query = "technology company" @@ -524,10 +622,11 @@ async def test_hybrid_search_reciprocal_rank_fusion( "fetch_top_k": 2, }, # RRF specific params ) - # The `k` in asimilarity_search here is the final desired number of results, - # which should align with fusion_function_parameters.fetch_top_k for RRF. - results = await vs_hybrid_search_with_tsv_column.asimilarity_search( - query, k=2, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] @@ -539,7 +638,7 @@ async def test_hybrid_search_reciprocal_rank_fusion( assert result_ids[0] == "hs_doc_apple_tech" # Stronger combined signal async def test_hybrid_search_explicit_fts_query( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test hybrid search when fts_query in HybridSearchConfig is different from main query.""" main_vector_query = "Apple Inc." 
# For vector search @@ -553,8 +652,11 @@ async def test_hybrid_search_explicit_fts_query( "secondary_results_weight": 0.5, }, ) - results = await vs_hybrid_search_with_tsv_column.asimilarity_search( - main_vector_query, k=2, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + main_vector_query, k=2, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] @@ -569,7 +671,9 @@ async def test_hybrid_search_explicit_fts_query( or "hs_doc_orange_fruit" in result_ids ) - async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column): + async def test_hybrid_search_with_filter( + self, engine, vs_hybrid_search_with_tsv_column + ): """Test hybrid search with a metadata filter applied.""" query = "apple" # Filter to only include "tech" related apple docs using metadata @@ -579,8 +683,11 @@ async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column) config = HybridSearchConfig( tsv_column="mycontent_tsv", ) - results = await vs_hybrid_search_with_tsv_column.asimilarity_search( - query, k=2, filter=doc_filter, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, filter=doc_filter, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] @@ -588,7 +695,7 @@ async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column) assert result_ids[0] == "hs_doc_apple_tech" async def test_hybrid_search_fts_empty_results( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test when FTS query yields no results, should fall back to vector search.""" vector_query = "apple" @@ -602,8 +709,11 @@ async def test_hybrid_search_fts_empty_results( "secondary_results_weight": 0.4, }, ) - results = await vs_hybrid_search_with_tsv_column.asimilarity_search( 
- vector_query, k=2, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query, k=2, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] @@ -611,10 +721,10 @@ async def test_hybrid_search_fts_empty_results( assert len(result_ids) > 0 assert "hs_doc_apple_fruit" in result_ids or "hs_doc_apple_tech" in result_ids # The top result should be one of the apple documents based on vector search - assert results[0].metadata["doc_id_key"].startswith("hs_doc_unrelated_cat") + assert results[0].metadata["doc_id_key"].startswith("hs_doc_apple_fruit") async def test_hybrid_search_vector_empty_results_effectively( - self, vs_hybrid_search_with_tsv_column + self, engine, vs_hybrid_search_with_tsv_column ): """Test when vector query is very dissimilar to docs, should rely on FTS.""" # This is hard to guarantee with fake embeddings, but we try. @@ -631,14 +741,17 @@ async def test_hybrid_search_vector_empty_results_effectively( "secondary_results_weight": 0.6, }, ) - results = await vs_hybrid_search_with_tsv_column.asimilarity_search( - vector_query_far_off, k=1, hybrid_search_config=config + results = await run_on_background( + engine, + vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ), ) result_ids = [doc.metadata["doc_id_key"] for doc in results] # Expect results based purely on FTS search for "orange fruit" assert len(result_ids) == 1 - assert result_ids[0] == "hs_doc_generic_tech" + assert result_ids[0] == "hs_doc_orange_fruit" async def test_hybrid_search_without_tsv_column(self, engine): """Test hybrid search without a TSV column.""" @@ -656,35 +769,41 @@ async def test_hybrid_search_without_tsv_column(self, engine): "secondary_results_weight": 0.9, }, ) - await engine._ainit_vectorstore_table( - HYBRID_SEARCH_TABLE2, - VECTOR_SIZE, - id_column=Column("myid", "TEXT"), - 
content_column="mycontent", - embedding_column="myembedding", - metadata_columns=[ - Column("page", "TEXT"), - Column("source", "TEXT"), - Column("doc_id_key", "TEXT"), - ], - store_metadata=False, - hybrid_search_config=config, - ) - - vs_with_tsv_column = await AsyncPostgresVectorStore.create( - engine, - embedding_service=embeddings_service, - table_name=HYBRID_SEARCH_TABLE2, - id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["doc_id_key"], - index_query_options=HNSWQueryOptions(ef_search=1), - hybrid_search_config=config, - ) - await vs_with_tsv_column.aadd_documents(hybrid_docs) + await run_on_background( + engine, + engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE2, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + store_metadata=False, + hybrid_search_config=config, + ), + ) - config = HybridSearchConfig( + vs_with_tsv_column = await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ), + ) + await run_on_background(engine, vs_with_tsv_column.aadd_documents(hybrid_docs)) + + config_no_tsv = HybridSearchConfig( tsv_column="", # no TSV column fts_query=fts_query_match, fusion_function_parameters={ @@ -692,23 +811,32 @@ async def test_hybrid_search_without_tsv_column(self, engine): "secondary_results_weight": 0.1, }, ) - vs_without_tsv_column = await AsyncPostgresVectorStore.create( + vs_without_tsv_column = await run_on_background( engine, - embedding_service=embeddings_service, - table_name=HYBRID_SEARCH_TABLE2, - 
id_column="myid", - content_column="mycontent", - embedding_column="myembedding", - metadata_columns=["doc_id_key"], - index_query_options=HNSWQueryOptions(ef_search=1), - hybrid_search_config=config, + AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config_no_tsv, + ), ) - results_with_tsv_column = await vs_with_tsv_column.asimilarity_search( - vector_query_far_off, k=1, hybrid_search_config=config + results_with_tsv_column = await run_on_background( + engine, + vs_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ), ) - results_without_tsv_column = await vs_without_tsv_column.asimilarity_search( - vector_query_far_off, k=1, hybrid_search_config=config + results_without_tsv_column = await run_on_background( + engine, + vs_without_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ), ) result_ids_with_tsv_column = [ doc.metadata["doc_id_key"] for doc in results_with_tsv_column @@ -720,5 +848,5 @@ async def test_hybrid_search_without_tsv_column(self, engine): # Expect results based purely on FTS search for "orange fruit" assert len(result_ids_with_tsv_column) == 1 assert len(result_ids_without_tsv_column) == 1 - assert result_ids_with_tsv_column[0] == "hs_doc_apple_tech" - assert result_ids_without_tsv_column[0] == "hs_doc_apple_tech" + assert result_ids_with_tsv_column[0] == "hs_doc_apple_fruit" + assert result_ids_without_tsv_column[0] == "hs_doc_apple_fruit" diff --git a/tests/test_engine.py b/tests/test_engine.py index 4a34c575..ca26236e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import os import uuid -from typing import Sequence +from typing import Any, Coroutine, Sequence import asyncpg # type: ignore import pytest @@ -52,27 +53,36 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop (if it exists).""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute( engine: PostgresEngine, query: str, ) -> None: - async def run(engine, query): + async def _impl(): async with engine._pool.connect() as conn: await conn.execute(text(query)) await conn.commit() - await engine._run_as_async(run(engine, query)) + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async def run(engine, query): + async def _impl(): async with engine._pool.connect() as conn: result = await conn.execute(text(query)) result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + return result_map.fetchall() - return await engine._run_as_async(run(engine, query)) + return await run_on_background(engine, _impl()) @pytest.mark.asyncio(scope="module") @@ -126,10 +136,14 @@ async def engine(self, db_project, db_region, db_instance, db_name): await engine.close() async def test_engine_args(self, engine): + # Accessing engine._pool.pool.status() is synchronous and safe on main loop objects + # assuming SQLAlchemy pool status doesn't strictly require loop context assert "Pool size: 3" in engine._pool.pool.status() async def test_init_table(self, engine): - await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + await run_on_background( + engine, engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + ) id = str(uuid.uuid4()) content = "coffee" 
embedding = await embeddings_service.aembed_query(content) @@ -139,14 +153,17 @@ async def test_init_table(self, engine): await aexecute(engine, stmt) async def test_init_table_custom(self, engine): - await engine.ainit_vectorstore_table( - CUSTOM_TABLE, - VECTOR_SIZE, - id_column="uuid", - content_column="my-content", - embedding_column="my_embedding", - metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], - store_metadata=True, + await run_on_background( + engine, + engine.ainit_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + ), ) stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{CUSTOM_TABLE}';" results = await afetch(engine, stmt) @@ -162,14 +179,19 @@ async def test_init_table_custom(self, engine): assert row in expected async def test_init_table_with_int_id(self, engine): - await engine.ainit_vectorstore_table( - INT_ID_CUSTOM_TABLE, - VECTOR_SIZE, - id_column=Column(name="integer_id", data_type="INTEGER", nullable="False"), - content_column="my-content", - embedding_column="my_embedding", - metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], - store_metadata=True, + await run_on_background( + engine, + engine.ainit_vectorstore_table( + INT_ID_CUSTOM_TABLE, + VECTOR_SIZE, + id_column=Column( + name="integer_id", data_type="INTEGER", nullable="False" + ), + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + ), ) stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{INT_ID_CUSTOM_TABLE}';" results = await afetch(engine, stmt) @@ -193,7 +215,10 @@ async def test_password( user, password, ): - PostgresEngine._connector = None + # Note: PostgresEngine._connector is 
no longer a class attribute in fixed engine.py + # But for test cleanup safety regarding the OLD code structure, we can ignore this. + # PostgresEngine._connector = None + engine = await PostgresEngine.afrom_instance( project_id=db_project, instance=db_instance, @@ -204,7 +229,6 @@ async def test_password( ) assert engine await aexecute(engine, "SELECT 1") - PostgresEngine._connector = None await engine.close() async def test_from_engine( @@ -216,7 +240,7 @@ async def test_from_engine( user, password, ): - async with Connector() as connector: + async with Connector(loop=asyncio.get_running_loop()) as connector: async def getconn() -> asyncpg.Connection: conn = await connector.connect_async( # type: ignore @@ -230,12 +254,12 @@ async def getconn() -> asyncpg.Connection: ) return conn - engine = create_async_engine( + engine_async = create_async_engine( "postgresql+asyncpg://", async_creator=getconn, ) - engine = PostgresEngine.from_engine(engine) + engine = PostgresEngine.from_engine(engine_async) await aexecute(engine, "SELECT 1") await engine.close() @@ -331,7 +355,11 @@ async def test_iam_account_override( async def test_ainit_checkpoint_writes_table(self, engine): table_name = f"checkpoint{uuid.uuid4()}" table_name_writes = f"{table_name}_writes" - await engine.ainit_checkpoint_table(table_name=table_name) + + await run_on_background( + engine, engine.ainit_checkpoint_table(table_name=table_name) + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name_writes}';" results = await afetch(engine, stmt) expected = [ @@ -354,9 +382,9 @@ async def test_ainit_checkpoint_writes_table(self, engine): {"column_name": "checkpoint_ns", "data_type": "text"}, {"column_name": "checkpoint_id", "data_type": "text"}, {"column_name": "parent_checkpoint_id", "data_type": "text"}, + {"column_name": "type", "data_type": "text"}, {"column_name": "checkpoint", "data_type": "bytea"}, {"column_name": "metadata", "data_type": "bytea"}, - 
{"column_name": "type", "data_type": "text"}, ] for row in results: assert row in expected @@ -364,15 +392,18 @@ async def test_ainit_checkpoint_writes_table(self, engine): await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"') async def test_init_table_hybrid_search(self, engine): - await engine.ainit_vectorstore_table( - HYBRID_SEARCH_TABLE, - VECTOR_SIZE, - id_column="uuid", - content_column="my-content", - embedding_column="my_embedding", - metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], - store_metadata=True, - hybrid_search_config=HybridSearchConfig(), + await run_on_background( + engine, + engine.ainit_vectorstore_table( + HYBRID_SEARCH_TABLE, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ), ) stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE}';" results = await afetch(engine, stmt) @@ -435,11 +466,12 @@ async def engine(self, db_project, db_region, db_instance, db_name): await engine.close() async def test_init_table(self, engine): + # Sync method uses _run_as_sync internally -> safe to call on Main Loop engine.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE) + id = str(uuid.uuid4()) content = "coffee" embedding = await embeddings_service.aembed_query(content) - # Note: DeterministicFakeEmbedding generates a numpy array, converting to list a list of float values embedding_string = [float(dimension) for dimension in embedding] stmt = f"INSERT INTO {DEFAULT_TABLE_SYNC} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding_string}');" await aexecute(engine, stmt) @@ -499,7 +531,6 @@ async def test_password( user, password, ): - PostgresEngine._connector = None engine = PostgresEngine.from_instance( project_id=db_project, instance=db_instance, 
@@ -511,7 +542,6 @@ async def test_password( ) assert engine await aexecute(engine, "SELECT 1") - PostgresEngine._connector = None await engine.close() async def test_engine_constructor_key( @@ -520,7 +550,7 @@ async def test_engine_constructor_key( ): key = object() with pytest.raises(Exception): - PostgresEngine(key, engine) + PostgresEngine(key, engine, None, None) async def test_iam_account_override( self, @@ -545,7 +575,9 @@ async def test_iam_account_override( async def test_init_checkpoints_table(self, engine): table_name = f"checkpoint{uuid.uuid4()}" table_name_writes = f"{table_name}_writes" + engine.init_checkpoint_table(table_name=table_name) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" results = await afetch(engine, stmt) expected = [ diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index 4e82cab6..ca0c6786 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -364,7 +364,7 @@ async def test_from_engine( user, password, ): - async with Connector() as connector: + async with Connector(loop=asyncio.get_running_loop()) as connector: async def getconn(): conn = await connector.connect_async( # type: ignore