diff --git a/README.md b/README.md index 482c3e5..76f9ce2 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ Update the model ```elixir schema "items" do - field :embedding, Pgvector.Ecto.Vector + field :embedding, Pgvector end ``` @@ -82,7 +82,7 @@ Insert a vector ```elixir alias MyApp.{Repo, Item} -Repo.insert(%Item{embedding: [1, 2, 3]}) +Repo.insert(%Item{embedding: Pgvector.new([1, 2, 3])}) ``` Get the nearest neighbors @@ -127,13 +127,13 @@ Postgrex.query!(pid, "CREATE TABLE items (embedding vector(3))", []) Insert a vector ```elixir -Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1)", [[1, 2, 3]]) +Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1)", [Pgvector.new([1, 2, 3])]) ``` Get the nearest neighbors ```elixir -Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [[1, 2, 3]]) +Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [Pgvector.new([1, 2, 3])]) ``` Add an approximate index diff --git a/lib/pgvector.ex b/lib/pgvector.ex new file mode 100644 index 0000000..4b5c4bc --- /dev/null +++ b/lib/pgvector.ex @@ -0,0 +1,90 @@ +defmodule Pgvector do + @moduledoc """ + todo + """ + + defstruct [:data] + + @doc """ + todo + """ + def new(binary) when is_binary(binary) do + %Pgvector{data: binary} + end + + def new(list) when is_list(list) do + dim = list |> length() + bin = for v <- list, into: "", do: <> + new(<>) + end + + if Code.ensure_loaded?(Nx) do + def new(tensor) when is_struct(tensor, Nx.Tensor) do + if Nx.rank(tensor) != 1 do + raise ArgumentError, "expected rank to be 1" + end + + dim = tensor |> Nx.size() + bin = tensor |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() + new(<>) + end + + defp f32_native_to_big(binary) do + if System.endianness() == :big do + binary + else + for <>, into: "", do: <> + end + end + end + + def to_binary(vector) when is_struct(vector, Pgvector) do + vector.data + end + + def to_list(vector) when is_struct(vector, Pgvector) do + <> = vector.data + for <>, do: v + end + + if Code.ensure_loaded?(Nx) do + def to_tensor(vector) when is_struct(vector, Pgvector) do + <> = vector.data + bin |> f32_big_to_native() |> Nx.from_binary(:f32) + end + + defp f32_big_to_native(binary) do + if System.endianness() == :big do + binary + else + for <>, into: "", do: <> + end + end + end + + if Code.ensure_loaded?(Ecto) do + use Ecto.Type + + def type, do: :vector + + def cast(value) do + {:ok, value |> Pgvector.new()} + end + + def load(data) do + {:ok, data} + end + + def dump(value) do + {:ok, value} + end + end +end + +defimpl Inspect, for: Pgvector do + import Inspect.Algebra + + def inspect(vec, opts) do + concat(["Pgvector.new(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"]) + end +end diff --git a/lib/pgvector/ecto/vector.ex b/lib/pgvector/ecto/vector.ex deleted file mode 100644 index 38496d3..0000000 --- a/lib/pgvector/ecto/vector.ex +++ /dev/null @@ -1,26 +0,0 @@ -if Code.ensure_loaded?(Ecto) do - defmodule Pgvector.Ecto.Vector do - use Ecto.Type - - def type, do: :vector - - def cast(value) when is_list(value) do - {:ok, value} - end - def cast(_), do: :error - - def load(data) do - {:ok, data} - end - - def dump(value) when is_list(value) do - {:ok, value} - end - if Code.ensure_loaded?(Nx) do - def dump(value) when is_struct(value, Nx.Tensor) do - {:ok, value} - end - end - def dump(_), do: :error - end -end diff --git a/lib/pgvector/extensions/vector.ex b/lib/pgvector/extensions/vector.ex index 26c88d0..cd6a1e7 100644 --- a/lib/pgvector/extensions/vector.ex +++ b/lib/pgvector/extensions/vector.ex @@ -10,40 +10,22 @@ defmodule Pgvector.Extensions.Vector do def encode(_) do quote do vec -> - data = unquote(__MODULE__).encode_vector(vec) + data = vec |> Pgvector.to_binary() [<> | data] end end - def decode(_) do + def decode(:copy) do quote do - <<_len::int32(), dim::uint16, 0::uint16, bin::binary-size(dim)-unit(32)>> -> - for <>, do: v + <> -> + bin |> :binary.copy() |> Pgvector.new() end end - def encode_vector(list) when is_list(list) do - dim = list |> length() - bin = for v <- list, into: "", do: <> - [<> | bin] - end - - if Code.ensure_loaded?(Nx) do - def encode_vector(tensor) when is_struct(tensor, Nx.Tensor) do - if Nx.rank(tensor) != 1 do - raise ArgumentError, "expected rank to be 1" - end - dim = tensor |> Nx.size() - bin = tensor |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() - [<> | bin] - end - - defp f32_native_to_big(bin) do - if System.endianness() == :big do - bin - else - for <>, into: "", do: <> - end + def decode(_) do + quote do + <> -> + bin |> Pgvector.new() end end end diff --git a/test/ecto_test.exs b/test/ecto_test.exs index 52b70ea..59b9dda 100644 --- a/test/ecto_test.exs +++ b/test/ecto_test.exs @@ -1,4 +1,8 @@ -Postgrex.Types.define(EctoApp.PostgrexTypes, [Pgvector.Extensions.Vector] ++ Ecto.Adapters.Postgres.extensions(), []) +Postgrex.Types.define( + EctoApp.PostgrexTypes, + [Pgvector.Extensions.Vector] ++ Ecto.Adapters.Postgres.extensions(), + [] +) defmodule Repo do use Ecto.Repo, @@ -10,7 +14,7 @@ defmodule Item do use Ecto.Schema schema "items" do - field :embedding, Pgvector.Ecto.Vector + field(:embedding, Pgvector) end end @@ -25,23 +29,36 @@ defmodule EctoTest do Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector") Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS items") - Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE items (id bigserial primary key, embedding vector(3))") - Repo.insert(%Item{embedding: [1, 1, 1]}) - Repo.insert(%Item{embedding: [2, 2, 3]}) - Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32)}) + Ecto.Adapters.SQL.query!( + Repo, + "CREATE TABLE items (id bigserial primary key, embedding vector(3))" + ) - items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5) + Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1])}) + Repo.insert(%Item{embedding: Pgvector.new([2, 2, 3])}) + Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32) |> Pgvector.new()}) + + items = Repo.all(from(i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5)) assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] - assert Enum.map(items, fn v -> v.embedding end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] - items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5) + assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [2.0, 2.0, 3.0] + ] + + items = + Repo.all(from(i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5)) + assert Enum.map(items, fn v -> v.id end) == [2, 3, 1] - items = Repo.all(from i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5) + items = Repo.all(from(i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5)) assert Enum.map(items, fn v -> v.id end) == [1, 2, 3] - items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.embedding, [1, 1, 1])), limit: 5) + items = + Repo.all(from(i in Item, order_by: 1 - cosine_distance(i.embedding, [1, 1, 1]), limit: 5)) + assert Enum.map(items, fn v -> v.id end) == [3, 2, 1] end end diff --git a/test/pgvector_test.exs b/test/pgvector_test.exs new file mode 100644 index 0000000..39af857 --- /dev/null +++ b/test/pgvector_test.exs @@ -0,0 +1,18 @@ +defmodule PgvectorTest do + use ExUnit.Case + + test "list" do + list = [1.0, 2.0, 3.0] + assert list == list |> Pgvector.new() |> Pgvector.to_list() + end + + test "tensor" do + tensor = Nx.tensor([1.0, 2.0, 3.0]) + assert tensor == tensor |> Pgvector.new() |> Pgvector.to_tensor() + end + + test "inspect" do + vector = Pgvector.new([1, 2, 3]) + assert "Pgvector.new([1.0, 2.0, 3.0])" == inspect(vector) + end +end diff --git a/test/postgrex_test.exs b/test/postgrex_test.exs index e63e7ac..339c4cc 100644 --- a/test/postgrex_test.exs +++ b/test/postgrex_test.exs @@ -7,7 +7,9 @@ defmodule PostgrexTest do use ExUnit.Case setup_all do - {:ok, pid} = Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes) + {:ok, pid} = + Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes) + Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", []) Postgrex.query!(pid, "DROP TABLE IF EXISTS items", []) Postgrex.query!(pid, "CREATE TABLE items (id bigserial primary key, embedding vector(3))", []) @@ -20,26 +22,42 @@ defmodule PostgrexTest do end test "works", %{pid: pid} = _context do - Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1), ($2), ($3)", [[1, 1, 1], [2, 2, 2], Nx.tensor([1, 1, 2], type: :f32)]) + Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1), ($2), ($3)", [ + Pgvector.new([1, 1, 1]), + Pgvector.new([2, 2, 2]), + Nx.tensor([1, 1, 2], type: :f32) |> Pgvector.new() + ]) - result = Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [[1, 1, 1]]) + result = + Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [ + Pgvector.new([1, 1, 1]) + ]) assert ["id", "embedding"] == result.columns assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2] - assert Enum.map(result.rows, fn v -> Enum.at(v, 1) end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] - Postgrex.query!(pid, "CREATE INDEX my_index ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)", []) + assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [2.0, 2.0, 2.0] + ] + + Postgrex.query!( + pid, + "CREATE INDEX my_index ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)", + [] + ) end test "tensor", %{pid: pid} = _context do embedding = Nx.tensor([1.0, 2.0, 3.0]) - result = Postgrex.query!(pid, "SELECT $1::vector", [embedding]) - assert result.rows == [[Nx.to_list(embedding)]] + result = Postgrex.query!(pid, "SELECT $1::vector", [embedding |> Pgvector.new()]) + assert Enum.map(result.rows, fn v -> Enum.at(v, 0) |> Pgvector.to_tensor() end) == [embedding] end test "tensor rank", %{pid: pid} = _context do assert_raise ArgumentError, "expected rank to be 1", fn -> - Postgrex.query!(pid, "SELECT $1::vector", [Nx.tensor([[1, 2, 3]])]) + Postgrex.query!(pid, "SELECT $1::vector", [Nx.tensor([[1, 2, 3]]) |> Pgvector.new()]) end end end