From 689ef35e9a9d6a4febcbbaea8ce5208ce8ed69a4 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 26 Apr 2023 18:48:39 -0700 Subject: [PATCH 1/3] Added Pgvector struct --- lib/pgvector.ex | 66 +++++++++++++++++++++++++++++++ lib/pgvector/extensions/vector.ex | 31 ++------------- test/ecto_test.exs | 2 +- test/pgvector_test.exs | 18 +++++++++ test/postgrex_test.exs | 4 +- 5 files changed, 90 insertions(+), 31 deletions(-) create mode 100644 lib/pgvector.ex create mode 100644 test/pgvector_test.exs diff --git a/lib/pgvector.ex b/lib/pgvector.ex new file mode 100644 index 0000000..a1cb853 --- /dev/null +++ b/lib/pgvector.ex @@ -0,0 +1,66 @@ +defmodule Pgvector do + defstruct [:data] + + def new(binary) when is_binary(binary) do + %Pgvector{data: binary} + end + + def new(list) when is_list(list) do + dim = list |> length() + bin = for v <- list, into: "", do: <> + data = <> + new(data) + end + + if Code.ensure_loaded?(Nx) do + def new(t) when is_struct(t, Nx.Tensor) do + if Nx.rank(t) != 1 do + raise ArgumentError, "expected rank to be 1" + end + dim = t |> Nx.size() + bin = t |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() + data = <> + new(data) + end + + defp f32_native_to_big(binary) do + if System.endianness() == :big do + binary + else + for <>, into: "", do: <> + end + end + end + + def to_binary(vector) when is_struct(vector, Pgvector) do + vector.data + end + + def to_list(vector) when is_struct(vector, Pgvector) do + <> = vector.data + for <>, do: v + end + + if Code.ensure_loaded?(Nx) do + def to_tensor(vector) when is_struct(vector, Pgvector) do + <> = vector.data + bin |> f32_big_to_native() |> Nx.from_binary(:f32) + end + + defp f32_big_to_native(binary) do + if System.endianness() == :big do + binary + else + for <>, into: "", do: <> + end + end + end +end + +defimpl Inspect, for: Pgvector do + import Inspect.Algebra + + def inspect(vec, opts) do + concat(["Pgvector.new(", Inspect.List.inspect(Pgvector.to_list(vec), opts), ")"]) + end +end diff --git a/lib/pgvector/extensions/vector.ex b/lib/pgvector/extensions/vector.ex index 26c88d0..3360f37 100644 --- a/lib/pgvector/extensions/vector.ex +++ b/lib/pgvector/extensions/vector.ex @@ -10,40 +10,15 @@ defmodule Pgvector.Extensions.Vector do def encode(_) do quote do vec -> - data = unquote(__MODULE__).encode_vector(vec) + data = vec |> Pgvector.new() |> Pgvector.to_binary() [<> | data] end end def decode(_) do quote do - <<_len::int32(), dim::uint16, 0::uint16, bin::binary-size(dim)-unit(32)>> -> - for <>, do: v - end - end - - def encode_vector(list) when is_list(list) do - dim = list |> length() - bin = for v <- list, into: "", do: <> - [<> | bin] - end - - if Code.ensure_loaded?(Nx) do - def encode_vector(tensor) when is_struct(tensor, Nx.Tensor) do - if Nx.rank(tensor) != 1 do - raise ArgumentError, "expected rank to be 1" - end - dim = tensor |> Nx.size() - bin = tensor |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() - [<> | bin] - end - - defp f32_native_to_big(bin) do - if System.endianness() == :big do - bin - else - for <>, into: "", do: <> - end + <> -> + bin |> Pgvector.new() end end end diff --git a/test/ecto_test.exs b/test/ecto_test.exs index 52b70ea..963d862 100644 --- a/test/ecto_test.exs +++ b/test/ecto_test.exs @@ -33,7 +33,7 @@ defmodule EctoTest do items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5) assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] - assert Enum.map(items, fn v -> v.embedding end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] + assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5) assert Enum.map(items, fn v -> v.id end) == [2, 3, 1] diff --git a/test/pgvector_test.exs b/test/pgvector_test.exs new file mode 100644 index 0000000..c38ee75 --- /dev/null +++ b/test/pgvector_test.exs @@ -0,0 +1,18 @@ +defmodule PgvectorTest do + use ExUnit.Case + + test "list" do + list = [1.0, 2.0, 3.0] + assert list == (list |> Pgvector.new() |> Pgvector.to_list()) + end + + test "tensor" do + tensor = Nx.tensor([1.0, 2.0, 3.0]) + assert tensor == (tensor |> Pgvector.new() |> Pgvector.to_tensor()) + end + + test "inspect" do + vector = Pgvector.new([1, 2, 3]) + assert "Pgvector.new([1.0, 2.0, 3.0])" == inspect(vector) + end +end diff --git a/test/postgrex_test.exs b/test/postgrex_test.exs index e63e7ac..a6629b2 100644 --- a/test/postgrex_test.exs +++ b/test/postgrex_test.exs @@ -26,7 +26,7 @@ defmodule PostgrexTest do assert ["id", "embedding"] == result.columns assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2] - assert Enum.map(result.rows, fn v -> Enum.at(v, 1) end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] + assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] Postgrex.query!(pid, "CREATE INDEX my_index ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)", []) end @@ -34,7 +34,7 @@ defmodule PostgrexTest do test "tensor", %{pid: pid} = _context do embedding = Nx.tensor([1.0, 2.0, 3.0]) result = Postgrex.query!(pid, "SELECT $1::vector", [embedding]) - assert result.rows == [[Nx.to_list(embedding)]] + assert Enum.map(result.rows, fn v -> Enum.at(v, 0) |> Pgvector.to_tensor() end) == [embedding] end test "tensor rank", %{pid: pid} = _context do From 1416c8ddea7914d90b27caa469cccf6a0ee7563f Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Wed, 26 Apr 2023 19:01:07 -0700 Subject: [PATCH 2/3] Improved code --- lib/pgvector.ex | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/lib/pgvector.ex b/lib/pgvector.ex index a1cb853..ead630e 100644 --- a/lib/pgvector.ex +++ b/lib/pgvector.ex @@ -8,19 +8,17 @@ defmodule Pgvector do def new(list) when is_list(list) do dim = list |> length() bin = for v <- list, into: "", do: <> - data = <> - new(data) + new(<>) end if Code.ensure_loaded?(Nx) do - def new(t) when is_struct(t, Nx.Tensor) do - if Nx.rank(t) != 1 do + def new(tensor) when is_struct(tensor, Nx.Tensor) do + if Nx.rank(tensor) != 1 do raise ArgumentError, "expected rank to be 1" end - dim = t |> Nx.size() - bin = t |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() - data = <> - new(data) + dim = tensor |> Nx.size() + bin = tensor |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() + new(<>) end defp f32_native_to_big(binary) do From 7ae2e90066df1abe7037ab2067681080c74d738e Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 29 May 2023 01:09:46 -0700 Subject: [PATCH 3/3] Improved code --- README.md | 8 +++---- lib/pgvector.ex | 26 +++++++++++++++++++++ lib/pgvector/ecto/vector.ex | 26 --------------------- lib/pgvector/extensions/vector.ex | 9 ++++++- test/ecto_test.exs | 39 ++++++++++++++++++++++--------- test/pgvector_test.exs | 4 ++-- test/postgrex_test.exs | 32 +++++++++++++++++++------ 7 files changed, 93 insertions(+), 51 deletions(-) delete mode 100644 lib/pgvector/ecto/vector.ex diff --git a/README.md b/README.md index 482c3e5..76f9ce2 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ Update the model ```elixir schema "items" do - field :embedding, Pgvector.Ecto.Vector + field :embedding, Pgvector end ``` @@ -82,7 +82,7 @@ Insert a vector ```elixir alias MyApp.{Repo, Item} -Repo.insert(%Item{embedding: [1, 2, 3]}) +Repo.insert(%Item{embedding: Pgvector.new([1, 2, 3])}) ``` Get the nearest neighbors @@ -127,13 +127,13 @@ Postgrex.query!(pid, "CREATE TABLE items (embedding vector(3))", []) Insert a vector ```elixir -Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1)", [[1, 2, 3]]) +Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1)", [Pgvector.new([1, 2, 3])]) ``` Get the nearest neighbors ```elixir -Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [[1, 2, 3]]) +Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [Pgvector.new([1, 2, 3])]) ``` Add an approximate index diff --git a/lib/pgvector.ex b/lib/pgvector.ex index ead630e..4b5c4bc 100644 --- a/lib/pgvector.ex +++ b/lib/pgvector.ex @@ -1,6 +1,13 @@ defmodule Pgvector do + @moduledoc """ + todo + """ + defstruct [:data] + @doc """ + todo + """ def new(binary) when is_binary(binary) do %Pgvector{data: binary} end @@ -16,6 +23,7 @@ defmodule Pgvector do if Nx.rank(tensor) != 1 do raise ArgumentError, "expected rank to be 1" end + dim = tensor |> Nx.size() bin = tensor |> Nx.as_type(:f32) |> Nx.to_binary() |> f32_native_to_big() new(<>) @@ -53,6 +61,24 @@ defmodule Pgvector do end end end + + if Code.ensure_loaded?(Ecto) do + use Ecto.Type + + def type, do: :vector + + def cast(value) do + {:ok, value |> Pgvector.new()} + end + + def load(data) do + {:ok, data} + end + + def dump(value) do + {:ok, value} + end + end end defimpl Inspect, for: Pgvector do diff --git a/lib/pgvector/ecto/vector.ex b/lib/pgvector/ecto/vector.ex deleted file mode 100644 index 38496d3..0000000 --- a/lib/pgvector/ecto/vector.ex +++ /dev/null @@ -1,26 +0,0 @@ -if Code.ensure_loaded?(Ecto) do - defmodule Pgvector.Ecto.Vector do - use Ecto.Type - - def type, do: :vector - - def cast(value) when is_list(value) do - {:ok, value} - end - def cast(_), do: :error - - def load(data) do - {:ok, data} - end - - def dump(value) when is_list(value) do - {:ok, value} - end - if Code.ensure_loaded?(Nx) do - def dump(value) when is_struct(value, Nx.Tensor) do - {:ok, value} - end - end - def dump(_), do: :error - end -end diff --git a/lib/pgvector/extensions/vector.ex b/lib/pgvector/extensions/vector.ex index 3360f37..cd6a1e7 100644 --- a/lib/pgvector/extensions/vector.ex +++ b/lib/pgvector/extensions/vector.ex @@ -10,11 +10,18 @@ defmodule Pgvector.Extensions.Vector do def encode(_) do quote do vec -> - data = vec |> Pgvector.new() |> Pgvector.to_binary() + data = vec |> Pgvector.to_binary() [<> | data] end end + def decode(:copy) do + quote do + <> -> + bin |> :binary.copy() |> Pgvector.new() + end + end + def decode(_) do quote do <> -> diff --git a/test/ecto_test.exs b/test/ecto_test.exs index 963d862..59b9dda 100644 --- a/test/ecto_test.exs +++ b/test/ecto_test.exs @@ -1,4 +1,8 @@ -Postgrex.Types.define(EctoApp.PostgrexTypes, [Pgvector.Extensions.Vector] ++ Ecto.Adapters.Postgres.extensions(), []) +Postgrex.Types.define( + EctoApp.PostgrexTypes, + [Pgvector.Extensions.Vector] ++ Ecto.Adapters.Postgres.extensions(), + [] +) defmodule Repo do use Ecto.Repo, @@ -10,7 +14,7 @@ defmodule Item do use Ecto.Schema schema "items" do - field :embedding, Pgvector.Ecto.Vector + field(:embedding, Pgvector) end end @@ -25,23 +29,36 @@ defmodule EctoTest do Ecto.Adapters.SQL.query!(Repo, "CREATE EXTENSION IF NOT EXISTS vector") Ecto.Adapters.SQL.query!(Repo, "DROP TABLE IF EXISTS items") - Ecto.Adapters.SQL.query!(Repo, "CREATE TABLE items (id bigserial primary key, embedding vector(3))") - Repo.insert(%Item{embedding: [1, 1, 1]}) - Repo.insert(%Item{embedding: [2, 2, 3]}) - Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32)}) + Ecto.Adapters.SQL.query!( + Repo, + "CREATE TABLE items (id bigserial primary key, embedding vector(3))" + ) - items = Repo.all(from i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5) + Repo.insert(%Item{embedding: Pgvector.new([1, 1, 1])}) + Repo.insert(%Item{embedding: Pgvector.new([2, 2, 3])}) + Repo.insert(%Item{embedding: Nx.tensor([1, 1, 2], type: :f32) |> Pgvector.new()}) + + items = Repo.all(from(i in Item, order_by: l2_distance(i.embedding, [1, 1, 1]), limit: 5)) assert Enum.map(items, fn v -> v.id end) == [1, 3, 2] - assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 3.0]] - items = Repo.all(from i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5) + assert Enum.map(items, fn v -> v.embedding |> Pgvector.to_list() end) == [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [2.0, 2.0, 3.0] + ] + + items = + Repo.all(from(i in Item, order_by: max_inner_product(i.embedding, [1, 1, 1]), limit: 5)) + assert Enum.map(items, fn v -> v.id end) == [2, 3, 1] - items = Repo.all(from i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5) + items = Repo.all(from(i in Item, order_by: cosine_distance(i.embedding, [1, 1, 1]), limit: 5)) assert Enum.map(items, fn v -> v.id end) == [1, 2, 3] - items = Repo.all(from i in Item, order_by: (1 - cosine_distance(i.embedding, [1, 1, 1])), limit: 5) + items = + Repo.all(from(i in Item, order_by: 1 - cosine_distance(i.embedding, [1, 1, 1]), limit: 5)) + assert Enum.map(items, fn v -> v.id end) == [3, 2, 1] end end diff --git a/test/pgvector_test.exs b/test/pgvector_test.exs index c38ee75..39af857 100644 --- a/test/pgvector_test.exs +++ b/test/pgvector_test.exs @@ -3,12 +3,12 @@ defmodule PgvectorTest do test "list" do list = [1.0, 2.0, 3.0] - assert list == (list |> Pgvector.new() |> Pgvector.to_list()) + assert list == list |> Pgvector.new() |> Pgvector.to_list() end test "tensor" do tensor = Nx.tensor([1.0, 2.0, 3.0]) - assert tensor == (tensor |> Pgvector.new() |> Pgvector.to_tensor()) + assert tensor == tensor |> Pgvector.new() |> Pgvector.to_tensor() end test "inspect" do diff --git a/test/postgrex_test.exs b/test/postgrex_test.exs index a6629b2..339c4cc 100644 --- a/test/postgrex_test.exs +++ b/test/postgrex_test.exs @@ -7,7 +7,9 @@ defmodule PostgrexTest do use ExUnit.Case setup_all do - {:ok, pid} = Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes) + {:ok, pid} = + Postgrex.start_link(database: "pgvector_elixir_test", types: PostgrexApp.PostgrexTypes) + Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", []) Postgrex.query!(pid, "DROP TABLE IF EXISTS items", []) Postgrex.query!(pid, "CREATE TABLE items (id bigserial primary key, embedding vector(3))", []) @@ -20,26 +22,42 @@ defmodule PostgrexTest do end test "works", %{pid: pid} = _context do - Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1), ($2), ($3)", [[1, 1, 1], [2, 2, 2], Nx.tensor([1, 1, 2], type: :f32)]) + Postgrex.query!(pid, "INSERT INTO items (embedding) VALUES ($1), ($2), ($3)", [ + Pgvector.new([1, 1, 1]), + Pgvector.new([2, 2, 2]), + Nx.tensor([1, 1, 2], type: :f32) |> Pgvector.new() + ]) - result = Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [[1, 1, 1]]) + result = + Postgrex.query!(pid, "SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5", [ + Pgvector.new([1, 1, 1]) + ]) assert ["id", "embedding"] == result.columns assert Enum.map(result.rows, fn v -> Enum.at(v, 0) end) == [1, 3, 2] - assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [[1.0, 1.0, 1.0], [1.0, 1.0, 2.0], [2.0, 2.0, 2.0]] - Postgrex.query!(pid, "CREATE INDEX my_index ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)", []) + assert Enum.map(result.rows, fn v -> Enum.at(v, 1) |> Pgvector.to_list() end) == [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [2.0, 2.0, 2.0] + ] + + Postgrex.query!( + pid, + "CREATE INDEX my_index ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)", + [] + ) end test "tensor", %{pid: pid} = _context do embedding = Nx.tensor([1.0, 2.0, 3.0]) - result = Postgrex.query!(pid, "SELECT $1::vector", [embedding]) + result = Postgrex.query!(pid, "SELECT $1::vector", [embedding |> Pgvector.new()]) assert Enum.map(result.rows, fn v -> Enum.at(v, 0) |> Pgvector.to_tensor() end) == [embedding] end test "tensor rank", %{pid: pid} = _context do assert_raise ArgumentError, "expected rank to be 1", fn -> - Postgrex.query!(pid, "SELECT $1::vector", [Nx.tensor([[1, 2, 3]])]) + Postgrex.query!(pid, "SELECT $1::vector", [Nx.tensor([[1, 2, 3]]) |> Pgvector.new()]) end end end