diff --git a/pgvector/psycopg/__init__.py b/pgvector/psycopg/__init__.py index ecb8fe6..c85bd02 100644 --- a/pgvector/psycopg/__init__.py +++ b/pgvector/psycopg/__init__.py @@ -4,7 +4,7 @@ from psycopg.types import TypeInfo from ..utils import from_db, from_db_binary, to_db, to_db_binary -__all__ = ['register_vector'] +__all__ = ["register_vector"] class VectorDumper(Dumper): @@ -43,14 +43,27 @@ def load(self, data): return from_db_binary(data) +def _register_vector_adapters(context): + adapters = context.adapters + adapters.register_dumper("numpy.ndarray", VectorDumper) + adapters.register_dumper("numpy.ndarray", VectorBinaryDumper) + adapters.register_loader("vector", VectorLoader) + adapters.register_loader("vector", VectorBinaryLoader) + + def register_vector(context): - info = TypeInfo.fetch(context, 'vector') + info = TypeInfo.fetch(context, "vector") if info is None: - raise psycopg.ProgrammingError('vector type not found in the database') + raise psycopg.ProgrammingError("vector type not found in the database") info.register(context) - adapters = context.adapters - adapters.register_dumper('numpy.ndarray', VectorDumper) - adapters.register_dumper('numpy.ndarray', VectorBinaryDumper) - adapters.register_loader(info.oid, VectorLoader) - adapters.register_loader(info.oid, VectorBinaryLoader) + _register_vector_adapters(context) + + +async def async_register_vector(async_context): + info = await TypeInfo.fetch(async_context, "vector") + if info is None: + raise psycopg.ProgrammingError("vector type not found in the database") + info.register(async_context) + + _register_vector_adapters(async_context)