From 4939705458b0d36b031cab97b34239fd5f0471b0 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 09:36:36 +0200 Subject: [PATCH 01/10] fix: equality for docvec if tensors are involved Signed-off-by: Johannes Messner --- docarray/array/doc_vec/column_storage.py | 9 +++- docarray/helper.py | 50 +++++++++++++++++++ tests/units/array/stack/test_array_stacked.py | 19 +++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/docarray/array/doc_vec/column_storage.py b/docarray/array/doc_vec/column_storage.py index e525c8aee0d..b2317556eb4 100644 --- a/docarray/array/doc_vec/column_storage.py +++ b/docarray/array/doc_vec/column_storage.py @@ -14,6 +14,7 @@ ) from docarray.array.list_advance_indexing import ListAdvancedIndexing +from docarray.helper import _is_tensor, _tensor_equals from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -102,7 +103,13 @@ def __eq__(self, other: Any) -> bool: for key_self in col_map_self.keys(): if key_self == 'id': continue - if col_map_self[key_self] != col_map_other[key_self]: + + val1, val2 = col_map_self[key_self], col_map_other[key_self] + if _is_tensor(val1) or _is_tensor(val2): + values_are_equal = _tensor_equals(val1, val2) + else: + values_are_equal = val1 == val2 + if not values_are_equal: return False return True diff --git a/docarray/helper.py b/docarray/helper.py index 5db06eb6d6f..07b783cc3c8 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -15,7 +15,10 @@ Union, ) +import numpy as np + from docarray.utils._internal._typing import safe_issubclass +from docarray.utils._internal.misc import is_tf_available, is_torch_available if TYPE_CHECKING: from docarray import BaseDoc @@ -256,3 +259,50 @@ def _shallow_copy_doc(doc): setattr(shallow_copy, field_name, val) return shallow_copy + + +def _is_tensor(x: Any) -> bool: + """ + Determines whether `x` is either np.ndarray, torch.tensor, or tf.tensor + """ + + if is_torch_available(): + import 
torch + + if isinstance(x, torch.Tensor): + return True + + if is_tf_available(): + import tensorflow as tf + + if tf.is_tensor(x): + return True + + if isinstance(x, np.ndarray): + return True + + return False + + +def _tensor_equals(tens1: Any, tens2: Any) -> bool: + """ + Determines if two {torch, tf, np} tensors are equal. + If at least one of them is not a tensor, of if they are tensors of different frameworks, False is returned. + """ + if is_torch_available(): + import torch + + if isinstance(tens1, torch.Tensor) and isinstance(tens2, torch.Tensor): + return torch.equal(tens1, tens2) + + if is_tf_available(): + import tensorflow as tf + + if tf.is_tensor(tens1) and tf.is_tensor(tens2): + return tf.math.reduce_all(tf.equal(tens1, tens2)) + + are_np_arrays = isinstance(tens1, np.ndarray) and isinstance(tens2, np.ndarray) + if are_np_arrays: + return np.array_equal(tens1, tens2) + + return False diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 4976aaddd31..14514e1bec0 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -598,6 +598,25 @@ class Text(BaseDoc): assert da == da2.to_doc_vec() +@pytest.mark.parametrize('tensor_type', [TorchTensor, NdArray]) +def test_doc_vec_equality_tensor(tensor_type): + class Text(BaseDoc): + tens: tensor_type + + da = DocVec[Text]( + [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=tensor_type + ) + da2 = DocVec[Text]( + [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=tensor_type + ) + assert da == da2 + + da2 = DocVec[Text]( + [Text(tens=[1, 2, 3, 4, 5]) for _ in range(10)], tensor_type=tensor_type + ) + assert da != da2 + + def test_doc_vec_nested(batch_nested_doc): batch, Doc, Inner = batch_nested_doc batch2 = DocVec[Doc]([Doc(inner=Inner(hello='hello')) for _ in range(10)]) From e0be74156e8b3218af83548d395078269a90a665 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 
2023 09:38:27 +0200 Subject: [PATCH 02/10] test: add tf tests Signed-off-by: Johannes Messner --- tests/units/array/stack/test_array_stacked.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 14514e1bec0..c8bcc47aeb7 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -617,6 +617,27 @@ class Text(BaseDoc): assert da != da2 +@pytest.mark.tensorflow +def test_doc_vec_equality_tf(tensor_type): + from docarray.typing import TensorflowTensor + + class Text(BaseDoc): + tens: TensorflowTensor + + da = DocVec[Text]( + [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorflowTensor + ) + da2 = DocVec[Text]( + [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorflowTensor + ) + assert da == da2 + + da2 = DocVec[Text]( + [Text(tens=[1, 2, 3, 4, 5]) for _ in range(10)], tensor_type=TensorflowTensor + ) + assert da != da2 + + def test_doc_vec_nested(batch_nested_doc): batch, Doc, Inner = batch_nested_doc batch2 = DocVec[Doc]([Doc(inner=Inner(hello='hello')) for _ in range(10)]) From 98920a68c8437d8ba061af6bc4a46c148823528d Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 09:58:23 +0200 Subject: [PATCH 03/10] fix: add type ignore for mypy imports Signed-off-by: Johannes Messner --- docarray/helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/helper.py b/docarray/helper.py index 07b783cc3c8..fec1afc07c8 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -273,7 +273,7 @@ def _is_tensor(x: Any) -> bool: return True if is_tf_available(): - import tensorflow as tf + import tensorflow as tf # type: ignore if tf.is_tensor(x): return True @@ -296,7 +296,7 @@ def _tensor_equals(tens1: Any, tens2: Any) -> bool: return torch.equal(tens1, tens2) if is_tf_available(): - import tensorflow as tf + import tensorflow as tf # 
type: ignore if tf.is_tensor(tens1) and tf.is_tensor(tens2): return tf.math.reduce_all(tf.equal(tens1, tens2)) From 8bf163223e445809ed2729bb1502e140cae0dd17 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 10:03:24 +0200 Subject: [PATCH 04/10] test: remove non existing fixture Signed-off-by: Johannes Messner --- tests/units/array/stack/test_array_stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index c8bcc47aeb7..3d4a3069840 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -618,7 +618,7 @@ class Text(BaseDoc): @pytest.mark.tensorflow -def test_doc_vec_equality_tf(tensor_type): +def test_doc_vec_equality_tf(): from docarray.typing import TensorflowTensor class Text(BaseDoc): From 8ac18465ddc0bc0d431c82760cb80078cf7b2efc Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 10:26:59 +0200 Subject: [PATCH 05/10] test: fix faulty import Signed-off-by: Johannes Messner --- tests/units/array/stack/test_array_stacked.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 3d4a3069840..35509338068 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -619,21 +619,21 @@ class Text(BaseDoc): @pytest.mark.tensorflow def test_doc_vec_equality_tf(): - from docarray.typing import TensorflowTensor + from docarray.typing import TensorFlowTensor class Text(BaseDoc): - tens: TensorflowTensor + tens: TensorFlowTensor da = DocVec[Text]( - [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorflowTensor + [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorFlowTensor ) da2 = DocVec[Text]( - [Text(tens=[1, 2, 3, 4]) for _ in range(10)], 
tensor_type=TensorflowTensor + [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorFlowTensor ) assert da == da2 da2 = DocVec[Text]( - [Text(tens=[1, 2, 3, 4, 5]) for _ in range(10)], tensor_type=TensorflowTensor + [Text(tens=[1, 2, 3, 4, 5]) for _ in range(10)], tensor_type=TensorFlowTensor ) assert da != da2 From 6f4abed409db9cb0fcf77fa50b5bd96131f1db85 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 10:44:51 +0200 Subject: [PATCH 06/10] fix: compare tf tensors of different shapes Signed-off-by: Johannes Messner --- docarray/helper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docarray/helper.py b/docarray/helper.py index fec1afc07c8..12bb73faf11 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -299,7 +299,9 @@ def _tensor_equals(tens1: Any, tens2: Any) -> bool: import tensorflow as tf # type: ignore if tf.is_tensor(tens1) and tf.is_tensor(tens2): - return tf.math.reduce_all(tf.equal(tens1, tens2)) + return tens1.shape == tens2.shape and tf.math.reduce_all( + tf.equal(tens1, tens2) + ) are_np_arrays = isinstance(tens1, np.ndarray) and isinstance(tens2, np.ndarray) if are_np_arrays: From a76a37b4bf7fce6188e109ec604d91f0822e342b Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 11:08:56 +0200 Subject: [PATCH 07/10] fix: unwrap tf tensor when needed Signed-off-by: Johannes Messner --- docarray/helper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docarray/helper.py b/docarray/helper.py index 12bb73faf11..6f7c0797d1b 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -275,7 +275,8 @@ def _is_tensor(x: Any) -> bool: if is_tf_available(): import tensorflow as tf # type: ignore - if tf.is_tensor(x): + t = getattr(x, 'tensor', None) + if tf.is_tensor(t): return True if isinstance(x, np.ndarray): @@ -298,10 +299,9 @@ def _tensor_equals(tens1: Any, tens2: Any) -> bool: if is_tf_available(): import tensorflow as tf # type: ignore - if 
tf.is_tensor(tens1) and tf.is_tensor(tens2): - return tens1.shape == tens2.shape and tf.math.reduce_all( - tf.equal(tens1, tens2) - ) + t1, t2 = getattr(tens1, 'tensor', None), getattr(tens2, 'tensor', None) + if tf.is_tensor(t1) and tf.is_tensor(t2): + return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t1)) are_np_arrays = isinstance(tens1, np.ndarray) and isinstance(tens2, np.ndarray) if are_np_arrays: From e4622b6f1aaf70f59588e16e05dbc4ac7c123493 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 11:22:07 +0200 Subject: [PATCH 08/10] fix: mypy Signed-off-by: Johannes Messner --- docarray/helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/helper.py b/docarray/helper.py index 6f7c0797d1b..1d2536d91ef 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -301,7 +301,7 @@ def _tensor_equals(tens1: Any, tens2: Any) -> bool: t1, t2 = getattr(tens1, 'tensor', None), getattr(tens2, 'tensor', None) if tf.is_tensor(t1) and tf.is_tensor(t2): - return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t1)) + return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t1)) # type: ignore are_np_arrays = isinstance(tens1, np.ndarray) and isinstance(tens2, np.ndarray) if are_np_arrays: From 62a3b1dd94d1d750d4d8335ba9603d6668208056 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 12:17:53 +0200 Subject: [PATCH 09/10] refactor: move stuff to comp backends Signed-off-by: Johannes Messner --- docarray/array/doc_vec/column_storage.py | 7 +-- docarray/computation/abstract_comp_backend.py | 13 +++++ docarray/computation/numpy_backend.py | 15 ++++++ docarray/computation/tensorflow_backend.py | 15 ++++++ docarray/computation/torch_backend.py | 15 ++++++ docarray/helper.py | 52 ------------------- 6 files changed, 62 insertions(+), 55 deletions(-) diff --git a/docarray/array/doc_vec/column_storage.py b/docarray/array/doc_vec/column_storage.py index b2317556eb4..539e9fd42af 100644 --- 
a/docarray/array/doc_vec/column_storage.py +++ b/docarray/array/doc_vec/column_storage.py @@ -14,7 +14,6 @@ ) from docarray.array.list_advance_indexing import ListAdvancedIndexing -from docarray.helper import _is_tensor, _tensor_equals from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -105,8 +104,10 @@ def __eq__(self, other: Any) -> bool: continue val1, val2 = col_map_self[key_self], col_map_other[key_self] - if _is_tensor(val1) or _is_tensor(val2): - values_are_equal = _tensor_equals(val1, val2) + if isinstance(val1, AbstractTensor): + values_are_equal = val1.get_comp_backend().equal(val1, val2) + elif isinstance(val2, AbstractTensor): + values_are_equal = val2.get_comp_backend().equal(val1, val2) else: values_are_equal = val1 == val2 if not values_are_equal: diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 8e2be24cbfb..afaf4564e61 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -157,6 +157,19 @@ def minmax_normalize( """ ... + @classmethod + @abstractmethod + def equal(cls, tensor1: 'TTensor', tensor2: 'TTensor') -> bool: + """ + Check if two tensors are equal. + + :param tensor1: the first tensor + :param tensor2: the second tensor + :return: True if two tensors are equal, False otherwise. + If one or more of the inputs is not a tensor of this framework, return False. + """ + ... 
+ class Retrieval(ABC, typing.Generic[TTensorRetrieval]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index 30d50cc0174..913f42d429e 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -111,6 +111,21 @@ def minmax_normalize( return np.clip(r, *((a, b) if a < b else (b, a))) + @classmethod + def equal(cls, tensor1: 'np.ndarray', tensor2: 'np.ndarray') -> bool: + """ + Check if two tensors are equal. + + :param tensor1: the first array + :param tensor2: the second array + :return: True if two arrays are equal, False otherwise. + If one or more of the inputs is not an ndarray, return False. + """ + are_np_arrays = isinstance(tensor1, np.ndarray) and isinstance( + tensor2, np.ndarray + ) + return are_np_arrays and np.array_equal(tensor1, tensor2) + class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/tensorflow_backend.py b/docarray/computation/tensorflow_backend.py index fc963cdb48b..91bfba81c89 100644 --- a/docarray/computation/tensorflow_backend.py +++ b/docarray/computation/tensorflow_backend.py @@ -121,6 +121,21 @@ def minmax_normalize( normalized = tnp.clip(i, *((a, b) if a < b else (b, a))) return cls._cast_output(tf.cast(normalized, tensor.tensor.dtype)) + @classmethod + def equal(cls, tensor1: 'TensorFlowTensor', tensor2: 'TensorFlowTensor') -> bool: + """ + Check if two tensors are equal. + + :param tensor1: the first tensor + :param tensor2: the second tensor + :return: True if two tensors are equal, False otherwise. + If one or more of the inputs is not a TensorFlowTensor, return False. 
+ """ + t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None) + if tf.is_tensor(t1) and tf.is_tensor(t2): + return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t1)) + return False + class Retrieval(AbstractComputationalBackend.Retrieval[TensorFlowTensor]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index be6d4ea03fd..97f0abbb3b5 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -113,6 +113,21 @@ def reshape(cls, tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tenso """ return tensor.reshape(shape) + @classmethod + def equal(cls, tensor1: 'torch.Tensor', tensor2: 'torch.Tensor') -> bool: + """ + Check if two tensors are equal. + + :param tensor1: the first tensor + :param tensor2: the second tensor + :return: True if two tensors are equal, False otherwise. + If one or more of the inputs is not a torch.Tensor, return False. 
+ """ + are_torch = isinstance(tensor1, torch.Tensor) and isinstance( + tensor2, torch.Tensor + ) + return are_torch and torch.equal(tensor1, tensor2) + @classmethod def detach(cls, tensor: 'torch.Tensor') -> 'torch.Tensor': """ diff --git a/docarray/helper.py b/docarray/helper.py index 1d2536d91ef..5db06eb6d6f 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -15,10 +15,7 @@ Union, ) -import numpy as np - from docarray.utils._internal._typing import safe_issubclass -from docarray.utils._internal.misc import is_tf_available, is_torch_available if TYPE_CHECKING: from docarray import BaseDoc @@ -259,52 +256,3 @@ def _shallow_copy_doc(doc): setattr(shallow_copy, field_name, val) return shallow_copy - - -def _is_tensor(x: Any) -> bool: - """ - Determines whether `x` is either np.ndarray, torch.tensor, or tf.tensor - """ - - if is_torch_available(): - import torch - - if isinstance(x, torch.Tensor): - return True - - if is_tf_available(): - import tensorflow as tf # type: ignore - - t = getattr(x, 'tensor', None) - if tf.is_tensor(t): - return True - - if isinstance(x, np.ndarray): - return True - - return False - - -def _tensor_equals(tens1: Any, tens2: Any) -> bool: - """ - Determines if two {torch, tf, np} tensors are equal. - If at least one of them is not a tensor, of if they are tensors of different frameworks, False is returned. 
- """ - if is_torch_available(): - import torch - - if isinstance(tens1, torch.Tensor) and isinstance(tens2, torch.Tensor): - return torch.equal(tens1, tens2) - - if is_tf_available(): - import tensorflow as tf # type: ignore - - t1, t2 = getattr(tens1, 'tensor', None), getattr(tens2, 'tensor', None) - if tf.is_tensor(t1) and tf.is_tensor(t2): - return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t1)) # type: ignore - - are_np_arrays = isinstance(tens1, np.ndarray) and isinstance(tens2, np.ndarray) - if are_np_arrays: - return np.array_equal(tens1, tens2) - - return False From a89cae6c9437d3a80a72fd807c88652304eb6f6b Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 20 Jun 2023 12:42:21 +0200 Subject: [PATCH 10/10] fix: mypy Signed-off-by: Johannes Messner --- docarray/computation/tensorflow_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docarray/computation/tensorflow_backend.py b/docarray/computation/tensorflow_backend.py index 91bfba81c89..27609b737e1 100644 --- a/docarray/computation/tensorflow_backend.py +++ b/docarray/computation/tensorflow_backend.py @@ -133,7 +133,8 @@ def equal(cls, tensor1: 'TensorFlowTensor', tensor2: 'TensorFlowTensor') -> bool """ t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None) if tf.is_tensor(t1) and tf.is_tensor(t2): - return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t1)) + # mypy doesn't know that tf.is_tensor implies that t1, t2 are not None + return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t2)) # type: ignore return False class Retrieval(AbstractComputationalBackend.Retrieval[TensorFlowTensor]):