From 6e8814beb8bc7347d2c47e76b2709c62c66310b5 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Tue, 5 Sep 2023 15:39:46 -0500 Subject: [PATCH 01/10] Add `semiring(A @ B @ C)` that applies semiring to both matmuls --- graphblas/core/infix.py | 15 ++++++++++++++- graphblas/core/matrix.py | 22 ++++++++++++++++++++-- graphblas/core/vector.py | 21 +++++++++++++++++++-- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 88fc52dbe..097e5fe62 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -258,6 +258,11 @@ def __init__(self, left, right, *, method_name, size): self.method_name = method_name self._size = size + __matmul__ = Vector.__matmul__ + __rmatmul__ = Vector.__rmatmul__ + inner = Vector.inner + vxm = Vector.vxm + utils._output_types[VectorEwiseAddExpr] = Vector utils._output_types[VectorEwiseMultExpr] = Vector @@ -396,6 +401,11 @@ def __init__(self, left, right, *, nrows, ncols): self._nrows = nrows self._ncols = ncols + __matmul__ = Matrix.__matmul__ + __rmatmul__ = Matrix.__rmatmul__ + mxm = Matrix.mxm + mxv = Matrix.mxv + utils._output_types[MatrixEwiseAddExpr] = Matrix utils._output_types[MatrixEwiseMultExpr] = Matrix @@ -489,7 +499,10 @@ def _matmul_infix_expr(left, right, *, within): ) # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, any_pair[bool]) + # Be careful not to accidentally reify an expression. + expr = getattr(left if type(left) is left_type else left.dup(clear=True), method)( + right if type(right) is right_type else right.dup(clear=True), any_pair[bool] + ) if expr.output_type is Vector: return VectorMatMulExpr(left, right, method_name=method, size=expr._size) if expr.output_type is Matrix: diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index d820ca424..065fe67f8 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -2187,10 +2187,18 @@ def mxv(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ v) """ + from .infix import MatrixMatMulExpr, VectorMatMulExpr + method_name = "mxv" - other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") + if type(self) is MatrixMatMulExpr: + self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) + if type(other) is VectorMatMulExpr: + other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, "GrB_mxv", @@ -2230,12 +2238,22 @@ def mxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ B) """ + from .infix import MatrixMatMulExpr + method_name = "mxm" other = self._expect_type( - other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, ) op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") + if type(self) is MatrixMatMulExpr: + self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) + if type(other) is MatrixMatMulExpr: + other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, "GrB_mxm", diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index cd5b992ba..15758a924 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -1285,14 +1285,23 @@ def vxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(v @ A) """ + from .infix import MatrixMatMulExpr, VectorMatMulExpr from .matrix import Matrix, TransposedMatrix method_name = "vxm" other = self._expect_type( - other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, ) op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") + if type(self) is VectorMatMulExpr: + self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) + if type(other) is MatrixMatMulExpr: + other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, "GrB_vxm", @@ -1634,10 +1643,18 @@ def inner(self, other, op=semiring.plus_times): `Matrix Multiplication <../user_guide/operations.html#matrix-multiply>`__ family of functions. """ + from .infix import VectorMatMulExpr + method_name = "inner" - other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") + if type(self) is VectorMatMulExpr: + self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) + if type(other) is VectorMatMulExpr: + other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = ScalarExpression( method_name, "GrB_vxm", From 7d73cb66ca4bb8b1fe59ffd826b61667978a5e9c Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Fri, 8 Sep 2023 00:22:59 -0500 Subject: [PATCH 02/10] WIP: begin implementing ewise add too --- graphblas/core/infix.py | 86 +++++++++++++- graphblas/core/matrix.py | 168 ++++++++++++++++++++++++++-- graphblas/core/operator/__init__.py | 1 + graphblas/core/operator/utils.py | 25 +++++ graphblas/core/vector.py | 166 +++++++++++++++++++++++++-- graphblas/tests/test_matrix.py | 28 ++++- graphblas/tests/test_vector.py | 28 +++-- 7 files changed, 465 insertions(+), 37 deletions(-) diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 097e5fe62..39bf2fbc1 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -7,6 +7,7 @@ from .expr import InfixExprBase from .mask import Mask from .matrix import Matrix, MatrixExpression, TransposedMatrix +from .recorder import skip_record from .scalar import Scalar, ScalarExpression from .utils import output_type, wrapdoc from .vector import Vector, VectorExpression @@ -238,6 +239,32 @@ class VectorEwiseAddExpr(VectorInfixExpr): _to_expr = _ewise_add_to_expr + ewise_add = Vector.ewise_add + ewise_mult = Vector.ewise_mult + ewise_union = Vector.ewise_union + + def __and__(self, other): + 1 / 0 + raise TypeError("XXX") + + def __rand__(self, other): + 1 / 0 + raise TypeError("XXX") + + def __or__(self, other): + if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): + 1 / 0 + raise TypeError("XXX") + 1 / 0 + return _ewise_infix_expr(self, other, method="ewise_add", within="__or__") + + def __ror__(self, other): + if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): + 1 / 0 + raise TypeError("XXX") + 1 / 0 + return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") + class VectorEwiseMultExpr(VectorInfixExpr): __slots__ = () @@ -247,6 +274,27 @@ class VectorEwiseMultExpr(VectorInfixExpr): _to_expr = _ewise_mult_to_expr + ewise_add = Vector.ewise_add + ewise_mult = Vector.ewise_mult + ewise_union = Vector.ewise_union + + def __and__(self, other): + if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): + raise TypeError("XXX") + return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__") + + def __rand__(self, other): + if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): + 1 / 0 + raise TypeError("XXX") + return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") + + def __or__(self, other): + raise TypeError("XXX") + + def __ror__(self, other): + raise TypeError("XXX") + class VectorMatMulExpr(VectorInfixExpr): __slots__ = "method_name" @@ -380,6 +428,14 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr + __and__ = VectorEwiseAddExpr.__and__ + __or__ = VectorEwiseAddExpr.__or__ + __rand__ = VectorEwiseAddExpr.__rand__ + __ror__ = VectorEwiseAddExpr.__ror__ + ewise_add = Matrix.ewise_add + ewise_mult = Matrix.ewise_mult + ewise_union = Matrix.ewise_union + class MatrixEwiseMultExpr(MatrixInfixExpr): __slots__ = () @@ -389,6 +445,14 @@ class MatrixEwiseMultExpr(MatrixInfixExpr): _to_expr = _ewise_mult_to_expr + __and__ = VectorEwiseMultExpr.__and__ + __or__ = VectorEwiseMultExpr.__or__ + __rand__ = VectorEwiseMultExpr.__rand__ + __ror__ = VectorEwiseMultExpr.__ror__ + ewise_add = Matrix.ewise_add + ewise_mult = Matrix.ewise_mult + ewise_union = Matrix.ewise_union + class MatrixMatMulExpr(MatrixInfixExpr): __slots__ = () @@ -412,6 +476,11 @@ def __init__(self, left, right, *, nrows, ncols): utils._output_types[MatrixMatMulExpr] = Matrix +def _dummy(obj, obj_type): + with skip_record: + return obj_type(BOOL, *obj.shape, name="") + + def _ewise_infix_expr(left, right, *, method, within): left_type = output_type(left) right_type = output_type(right) @@ -419,7 +488,9 @@ def _ewise_infix_expr(left, right, *, method, within): types = {Vector, Matrix, TransposedMatrix} if left_type in types and right_type in types: # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, binary.any) + expr = getattr( + _dummy(left, left_type) if isinstance(left, InfixExprBase) else left, method + )(_dummy(right, right_type) if isinstance(right, InfixExprBase) else right, binary.first) if expr.output_type is Vector: if method == "ewise_mult": return VectorEwiseMultExpr(left, right) @@ -437,13 +508,17 @@ def _ewise_infix_expr(left, right, *, method, within): right._expect_type(left, tuple(types), within=within, argname="left") elif left_type is Scalar: # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, binary.any) + expr = getattr( + _dummy(left, left_type) if isinstance(left, InfixExprBase) else left, method + )(_dummy(right, right_type) if isinstance(right, InfixExprBase) else right, binary.first) if method == "ewise_mult": return ScalarEwiseMultExpr(left, right) return ScalarEwiseAddExpr(left, right) elif right_type is Scalar: # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(right, method)(left, binary.any) + expr = getattr( + _dummy(right, right_type) if isinstance(right, InfixExprBase) else right, method + )(_dummy(left, left_type) if isinstance(left, InfixExprBase) else left, binary.first) if method == "ewise_mult": return ScalarEwiseMultExpr(right, left) return ScalarEwiseAddExpr(right, left) @@ -499,9 +574,8 @@ def _matmul_infix_expr(left, right, *, within): ) # Create dummy expression to check compatibility of dimensions, etc. - # Be careful not to accidentally reify an expression. - expr = getattr(left if type(left) is left_type else left.dup(clear=True), method)( - right if type(right) is right_type else right.dup(clear=True), any_pair[bool] + expr = getattr(_dummy(left, left_type) if isinstance(left, InfixExprBase) else left, method)( + _dummy(right, right_type) if isinstance(right, InfixExprBase) else right, any_pair[BOOL] ) if expr.output_type is Vector: return VectorMatMulExpr(left, right, method_name=method, size=expr._size) diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 065fe67f8..a2ae97c78 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -12,7 +12,14 @@ from .descriptor import lookup as descriptor_lookup from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _get_typed_op_from_exprs, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -1938,17 +1945,36 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax C << monoid.max(A | B) """ + from .infix import ( + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ) + method_name = "ewise_add" other = self._expect_type( other, - (Matrix, TransposedMatrix, Vector), + ( + Matrix, + TransposedMatrix, + Vector, + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ), within=method_name, argname="other", op=op, ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + op = _get_typed_op_from_exprs(op, self, other, kind="binary") # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if type(self) is MatrixEwiseMultExpr: + raise TypeError("XXX") + if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: + raise TypeError("XXX") if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -1957,6 +1983,16 @@ def ewise_add(self, other, op=monoid.plus): f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) " f"must equal Vector.size (={other._size})." ) + if type(self) is MatrixEwiseAddExpr: + 1 / 0 + self = self._expect_type( + op(self), Matrix, within=method_name, argname="self", op=op + ) + if type(other) is VectorEwiseAddExpr: + 1 / 0 + other = self._expect_type( + op(other), Vector, within=method_name, argname="other", op=op + ) return MatrixExpression( method_name, None, @@ -1965,6 +2001,12 @@ def ewise_add(self, other, op=monoid.plus): ncols=self._ncols, op=op, ) + if type(self) is MatrixEwiseAddExpr: + 1 / 0 + self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) + if type(other) is MatrixEwiseAddExpr: + 1 / 0 + other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, f"GrB_Matrix_eWiseAdd_{op.opclass}", @@ -2006,13 +2048,38 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax C << binary.gt(A & B) """ + from .infix import ( + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ) + method_name = "ewise_mult" other = self._expect_type( - other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op + other, + ( + Matrix, + TransposedMatrix, + Vector, + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ), + within=method_name, + argname="other", + op=op, ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + op = _get_typed_op_from_exprs(op, self, other, kind="binary") # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if type(self) is MatrixEwiseAddExpr: + 1 / 0 + raise TypeError("XXX") + if type(other) in {MatrixEwiseAddExpr, VectorEwiseAddExpr}: + 1 / 0 + raise TypeError("XXX") if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -2021,6 +2088,16 @@ def ewise_mult(self, other, op=binary.times): f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) " f"must equal Vector.size (={other._size})." ) + if type(self) is MatrixEwiseMultExpr: + 1 / 0 + self = self._expect_type( + op(self), Matrix, within=method_name, argname="self", op=op + ) + if type(other) is VectorEwiseMultExpr: + 1 / 0 + other = self._expect_type( + op(other), Vector, within=method_name, argname="other", op=op + ) return MatrixExpression( method_name, None, @@ -2029,6 +2106,10 @@ def ewise_mult(self, other, op=binary.times): ncols=self._ncols, op=op, ) + if type(self) is MatrixEwiseMultExpr: + self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) + if type(other) is MatrixEwiseMultExpr: + other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, f"GrB_Matrix_eWiseMult_{op.opclass}", @@ -2074,9 +2155,28 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax C << binary.div(A | B, left_default=1, right_default=1) """ + from .infix import ( + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ) + method_name = "ewise_union" other = self._expect_type( - other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op + other, + ( + Matrix, + TransposedMatrix, + Vector, + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ), + within=method_name, + argname="other", + op=op, ) dtype = self.dtype if self.dtype._is_udt else None if type(left_default) is not Scalar: @@ -2111,12 +2211,23 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - scalar_dtype = unify(left.dtype, right.dtype) - nonscalar_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + + op = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + if op is not op2: + scalar_dtype = unify(op.type2, op2.type) + nonscalar_dtype = unify(op.type, op2.type2) + op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + 1 / 0 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + if type(self) is MatrixEwiseMultExpr: + 1 / 0 + raise TypeError("XXX") + if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: + 1 / 0 + raise TypeError("XXX") expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 1: # Broadcast rowwise from the right @@ -2126,6 +2237,24 @@ def ewise_union(self, other, op, left_default, right_default): f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) " f"must equal Vector.size (={other._size})." ) + if type(self) is MatrixEwiseAddExpr: + 1 / 0 + self = self._expect_type( + op(self, left_default=left, right_default=right), + Matrix, + within=method_name, + argname="self", + op=op, + ) + if type(other) is VectorEwiseAddExpr: + 1 / 0 + other = self._expect_type( + op(other, left_default=left, right_default=right), + Vector, + within=method_name, + argname="other", + op=op, + ) return MatrixExpression( method_name, None, @@ -2135,6 +2264,24 @@ def ewise_union(self, other, op, left_default, right_default): ncols=self._ncols, op=op, ) + if type(self) is MatrixEwiseAddExpr: + 1 / 0 + self = self._expect_type( + op(self, left_default=left, right_default=right), + Matrix, + within=method_name, + argname="self", + op=op, + ) + if type(other) is MatrixEwiseAddExpr: + 1 / 0 + other = self._expect_type( + op(other, left_default=left, right_default=right), + Matrix, + within=method_name, + argname="other", + op=op, + ) if backend == "suitesparse": expr = MatrixExpression( method_name, @@ -2196,6 +2343,7 @@ def mxv(self, other, op=semiring.plus_times): op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is MatrixMatMulExpr: + 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is VectorMatMulExpr: other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) @@ -2251,8 +2399,10 @@ def mxm(self, other, op=semiring.plus_times): op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is MatrixMatMulExpr: + 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is MatrixMatMulExpr: + 1 / 0 other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, diff --git a/graphblas/core/operator/__init__.py b/graphblas/core/operator/__init__.py index 509e84a04..d59c835b3 100644 --- a/graphblas/core/operator/__init__.py +++ b/graphblas/core/operator/__init__.py @@ -6,6 +6,7 @@ from .semiring import ParameterizedSemiring, Semiring from .unary import ParameterizedUnaryOp, UnaryOp from .utils import ( + _get_typed_op_from_exprs, aggregator_from_string, binary_from_string, get_semiring, diff --git a/graphblas/core/operator/utils.py b/graphblas/core/operator/utils.py index 00df31db8..cd0b82d3c 100644 --- a/graphblas/core/operator/utils.py +++ b/graphblas/core/operator/utils.py @@ -2,6 +2,7 @@ from ... import backend, binary, config, indexunary, monoid, op, select, semiring, unary from ...dtypes import UINT64, lookup_dtype, unify +from ..expr import InfixExprBase from .base import ( _SS_OPERATORS, OpBase, @@ -132,6 +133,30 @@ def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scala raise TypeError(f"Unable to get typed operator from object with type {type(op)}") +def _get_typed_op_from_exprs(op, left, right, *, kind=None): + if isinstance(left, InfixExprBase): + left_op = _get_typed_op_from_exprs(op, left.left, left.right, kind=kind) + left_dtype = left_op.type + else: + left_op = None + left_dtype = left.dtype + if isinstance(right, InfixExprBase): + right_op = _get_typed_op_from_exprs(op, right.left, right.right, kind=kind) + if right_op is left_op: + return right_op + right_dtype = right_op.type2 + else: + right_dtype = right.dtype + return get_typed_op( + op, + left_dtype, + right_dtype, + is_left_scalar=left._is_scalar, + is_right_scalar=right._is_scalar, + kind=kind, + ) + + def get_semiring(monoid, binaryop, name=None): """Get or create a Semiring object from a monoid and binaryop. diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index 15758a924..3510aede5 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -11,7 +11,14 @@ from .descriptor import lookup as descriptor_lookup from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _get_typed_op_from_exprs, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -1038,15 +1045,37 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax w << monoid.max(u | v) """ + from .infix import ( + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ) from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_add" other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op + other, + ( + Vector, + Matrix, + TransposedMatrix, + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ), + within=method_name, + argname="other", + op=op, ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + op = _get_typed_op_from_exprs(op, self, other, kind="binary") # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if type(self) is VectorEwiseMultExpr: + raise TypeError("XXX") + if type(other) in {VectorEwiseMultExpr, MatrixEwiseMultExpr}: + raise TypeError("XXX") if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1056,6 +1085,16 @@ def ewise_add(self, other, op=monoid.plus): f"to columns of Matrix in {method_name}. Matrix.nrows (={other._nrows}) " f"must equal Vector.size (={self._size})." ) + if type(self) is VectorEwiseAddExpr: + 1 / 0 + self = self._expect_type( + op(self), Vector, within=method_name, argname="self", op=op + ) + if type(other) is MatrixEwiseAddExpr: + 1 / 0 + other = self._expect_type( + op(other), Matrix, within=method_name, argname="other", op=op + ) return MatrixExpression( method_name, None, @@ -1064,6 +1103,12 @@ def ewise_add(self, other, op=monoid.plus): ncols=other._ncols, op=op, ) + if type(self) is VectorEwiseAddExpr: + 1 / 0 + self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) + if type(other) is VectorEwiseAddExpr: + 1 / 0 + other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, f"GrB_Vector_eWiseAdd_{op.opclass}", @@ -1103,15 +1148,39 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax w << binary.gt(u & v) """ + from .infix import ( + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ) from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_mult" other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op + other, + ( + Vector, + Matrix, + TransposedMatrix, + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ), + within=method_name, + argname="other", + op=op, ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + op = _get_typed_op_from_exprs(op, self, other, kind="binary") # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if type(self) is VectorEwiseAddExpr: + 1 / 0 + raise TypeError("XXX") + if type(other) in {VectorEwiseAddExpr, MatrixEwiseAddExpr}: + 1 / 0 + raise TypeError("XXX") if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1120,6 +1189,16 @@ def ewise_mult(self, other, op=binary.times): f"to columns of Matrix in {method_name}. Matrix.nrows (={other._nrows}) " f"must equal Vector.size (={self._size})." ) + if type(self) is VectorEwiseMultExpr: + 1 / 0 + self = self._expect_type( + op(self), Vector, within=method_name, argname="self", op=op + ) + if type(other) is MatrixEwiseMultExpr: + 1 / 0 + other = self._expect_type( + op(other), Matrix, within=method_name, argname="other", op=op + ) return MatrixExpression( method_name, None, @@ -1128,6 +1207,10 @@ def ewise_mult(self, other, op=binary.times): ncols=other._ncols, op=op, ) + if type(self) is VectorEwiseMultExpr: + self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) + if type(other) is VectorEwiseMultExpr: + other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, f"GrB_Vector_eWiseMult_{op.opclass}", @@ -1171,11 +1254,29 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax w << binary.div(u | v, left_default=1, right_default=1) """ + from .infix import ( + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ) from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_union" other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op + other, + ( + Vector, + Matrix, + TransposedMatrix, + MatrixEwiseAddExpr, + MatrixEwiseMultExpr, + VectorEwiseAddExpr, + VectorEwiseMultExpr, + ), + within=method_name, + argname="other", + op=op, ) dtype = self.dtype if self.dtype._is_udt else None if type(left_default) is not Scalar: @@ -1210,12 +1311,23 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - scalar_dtype = unify(left.dtype, right.dtype) - nonscalar_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + + op = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + if op is not op2: + scalar_dtype = unify(op.type2, op2.type) + nonscalar_dtype = unify(op.type, op2.type2) + op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + 1 / 0 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + if type(self) is VectorEwiseMultExpr: + 1 / 0 + raise TypeError("XXX") + if type(other) in {VectorEwiseMultExpr, MatrixEwiseMultExpr}: + 1 / 0 + raise TypeError("XXX") expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 2: # Broadcast columnwise from the left @@ -1225,6 +1337,24 @@ def ewise_union(self, other, op, left_default, right_default): f"to columns of Matrix in {method_name}. Matrix.nrows (={other._nrows}) " f"must equal Vector.size (={self._size})." ) + if type(self) is VectorEwiseAddExpr: + 1 / 0 + self = self._expect_type( + op(self, left_default=left, right_default=right), + Vector, + within=method_name, + argname="self", + op=op, + ) + if type(other) is MatrixEwiseAddExpr: + 1 / 0 + other = self._expect_type( + op(other, left_default=left, right_default=right), + Matrix, + within=method_name, + argname="other", + op=op, + ) return MatrixExpression( method_name, None, @@ -1234,6 +1364,24 @@ def ewise_union(self, other, op, left_default, right_default): ncols=other._ncols, op=op, ) + if type(self) is VectorEwiseAddExpr: + 1 / 0 + self = self._expect_type( + op(self, left_default=left, right_default=right), + Vector, + within=method_name, + argname="self", + op=op, + ) + if type(other) is VectorEwiseAddExpr: + 1 / 0 + other = self._expect_type( + op(other, left_default=left, right_default=right), + Vector, + within=method_name, + argname="other", + op=op, + ) if backend == "suitesparse": expr = VectorExpression( method_name, diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index fe85bb9bf..69ba38774 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -2805,6 +2805,8 @@ def test_ss_nbytes(A): @autocompute def test_auto(A, v): + from graphblas.core.infix import MatrixEwiseMultExpr + expected = binary.land[bool](A & A).new() B = A.dup(dtype=bool) for expr in [(B & B), binary.land[bool](A & A)]: @@ -2827,14 +2829,26 @@ def test_auto(A, v): "__and__", "__or__", # "kronecker", + "__rand__", + "__ror__", ]: + # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() val2 = getattr(expected, method)(expr) - val3 = getattr(expr, method)(expected) - val4 = getattr(expr, method)(expr) - assert val1.isequal(val2) - assert val1.isequal(val3) - assert val1.isequal(val4) + if method in {"__or__", "__ror__"} and type(expr) is MatrixEwiseMultExpr: + # Doing e.g. `plus(A & B | C)` isn't allowed--make user be explicit + with pytest.raises(TypeError): + assert val1.isequal(val2) + with pytest.raises(TypeError): + val3 = getattr(expr, method)(expected) + with pytest.raises(TypeError): + val4 = getattr(expr, method)(expr) + else: + assert val1.isequal(val2) + val3 = getattr(expr, method)(expected) + assert val1.isequal(val3) + val4 = getattr(expr, method)(expr) + assert val1.isequal(val4) for method in ["reduce_rowwise", "reduce_columnwise", "reduce_scalar"]: s1 = getattr(expected, method)(monoid.lor).new() s2 = getattr(expr, method)(monoid.lor) @@ -3136,6 +3150,10 @@ def test_ss_reshape(A): def test_autocompute_argument_messages(A, v): with pytest.raises(TypeError, match="autocompute"): A.ewise_mult(A & A) + with pytest.raises(TypeError, match="autocompute"): + A.ewise_mult(binary.plus(A & A)) + with pytest.raises(TypeError, match="autocompute"): + A.ewise_mult(A + A) with pytest.raises(TypeError, match="autocompute"): A.mxv(A @ v) diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 2571f288b..d08c9cb3b 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -1532,6 +1532,8 @@ def test_outer(v): @autocompute def test_auto(v): + from graphblas.core.infix import VectorEwiseMultExpr + v = v.dup(dtype=bool) expected = binary.land(v & v).new() assert 0 not in expected @@ -1579,16 +1581,26 @@ def test_auto(v): "__rand__", "__ror__", ]: + # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() val2 = getattr(expected, method)(expr) - val3 = getattr(expr, method)(expected) - val4 = getattr(expr, method)(expr) - assert val1.isequal(val2) - assert val1.isequal(val3) - assert val1.isequal(val4) - assert val1.isequal(val2.new()) - assert val1.isequal(val3.new()) - assert val1.isequal(val4.new()) + if method in {"__or__", "__ror__"} and type(expr) is VectorEwiseMultExpr: + # Doing e.g. `plus(x & y | z)` isn't allowed--make user be explicit + with pytest.raises(TypeError): + assert val1.isequal(val2) + with pytest.raises(TypeError): + val3 = getattr(expr, method)(expected) + with pytest.raises(TypeError): + val4 = getattr(expr, method)(expr) + else: + assert val1.isequal(val2) + assert val1.isequal(val2.new()) + val3 = getattr(expr, method)(expected) + assert val1.isequal(val3) + assert val1.isequal(val3.new()) + val4 = getattr(expr, method)(expr) + assert val1.isequal(val4) + assert val1.isequal(val4.new()) s1 = expected.reduce(monoid.lor).new() s2 = expr.reduce(monoid.lor) assert s1.isequal(s2.new()) From 6285d83a6e86151ed9175781cf50195efdce27fb Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Fri, 8 Sep 2023 00:42:19 -0500 Subject: [PATCH 03/10] Fix imports --- graphblas/core/infix.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 39bf2fbc1..6042321f3 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -2,12 +2,11 @@ from ..dtypes import BOOL from ..monoid import land, lor from ..semiring import any_pair -from . import automethods, utils +from . import automethods, recorder, utils from .base import _expect_op, _expect_type from .expr import InfixExprBase from .mask import Mask from .matrix import Matrix, MatrixExpression, TransposedMatrix -from .recorder import skip_record from .scalar import Scalar, ScalarExpression from .utils import output_type, wrapdoc from .vector import Vector, VectorExpression @@ -477,7 +476,7 @@ def __init__(self, left, right, *, nrows, ncols): def _dummy(obj, obj_type): - with skip_record: + with recorder.skip_record: return obj_type(BOOL, *obj.shape, name="") From fa61888577180c603608e736bbecbcc3617ab1a0 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Thu, 14 Sep 2023 10:54:12 +0200 Subject: [PATCH 04/10] checkpoint --- graphblas/core/base.py | 18 +++++++--- graphblas/core/infix.py | 47 +++++++------------------ graphblas/core/matrix.py | 17 ++++++--- graphblas/core/vector.py | 39 +++++++++++---------- graphblas/tests/test_infix.py | 66 ++++++++++++++++++++++++++++++++++- 5 files changed, 123 insertions(+), 64 deletions(-) diff --git a/graphblas/core/base.py b/graphblas/core/base.py index 42a4de9a1..011e7bbf5 100644 --- a/graphblas/core/base.py +++ b/graphblas/core/base.py @@ -263,23 +263,33 @@ def __call__( ) def __or__(self, other): - from .infix import _ewise_infix_expr + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr, _ewise_infix_expr + if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): + raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_add", within="__or__") def __ror__(self, other): - from .infix import _ewise_infix_expr + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr, _ewise_infix_expr + if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): + 1 / 0 + raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") def __and__(self, other): - from .infix import _ewise_infix_expr + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr, _ewise_infix_expr + if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): + raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__") def __rand__(self, other): - from .infix import _ewise_infix_expr + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr, _ewise_infix_expr + if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): + 1 / 0 + raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") def __matmul__(self, other): diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 6042321f3..f4115a487 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -238,31 +238,17 @@ class VectorEwiseAddExpr(VectorInfixExpr): _to_expr = _ewise_add_to_expr + __or__ = Vector.__or__ + __ror__ = Vector.__ror__ ewise_add = Vector.ewise_add ewise_mult = Vector.ewise_mult ewise_union = Vector.ewise_union - def __and__(self, other): - 1 / 0 + def __and__(self, other, *, within="__and__"): raise TypeError("XXX") def __rand__(self, other): - 1 / 0 - raise TypeError("XXX") - - def __or__(self, other): - if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): - 1 / 0 - raise TypeError("XXX") - 1 / 0 - return _ewise_infix_expr(self, other, method="ewise_add", within="__or__") - - def __ror__(self, other): - if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): - 1 / 0 - raise TypeError("XXX") - 1 / 0 - return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") + self.__and__(other, within="__rand__") class VectorEwiseMultExpr(VectorInfixExpr): @@ -273,21 +259,12 @@ class VectorEwiseMultExpr(VectorInfixExpr): _to_expr = _ewise_mult_to_expr + __and__ = Vector.__and__ + __rand__ = Vector.__rand__ ewise_add = Vector.ewise_add ewise_mult = Vector.ewise_mult ewise_union = Vector.ewise_union - def __and__(self, other): - if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): - raise TypeError("XXX") - return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__") - - def __rand__(self, other): - if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): - 1 / 0 - raise TypeError("XXX") - return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") - def __or__(self, other): raise TypeError("XXX") @@ -427,10 +404,10 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr - __and__ = VectorEwiseAddExpr.__and__ - __or__ = VectorEwiseAddExpr.__or__ - __rand__ = VectorEwiseAddExpr.__rand__ - __ror__ = VectorEwiseAddExpr.__ror__ + __and__ = VectorEwiseMultExpr.__and__ + __rand__ = VectorEwiseMultExpr.__rand__ + __or__ = Matrix.__or__ + __ror__ = Matrix.__ror__ ewise_add = Matrix.ewise_add ewise_mult = Matrix.ewise_mult ewise_union = Matrix.ewise_union @@ -444,9 +421,9 @@ class MatrixEwiseMultExpr(MatrixInfixExpr): _to_expr = _ewise_mult_to_expr - __and__ = VectorEwiseMultExpr.__and__ + __and__ = Matrix.__and__ + __rand__ = Matrix.__rand__ __or__ = VectorEwiseMultExpr.__or__ - __rand__ = VectorEwiseMultExpr.__rand__ __ror__ = VectorEwiseMultExpr.__ror__ ewise_add = Matrix.ewise_add ewise_mult = Matrix.ewise_mult diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index a2ae97c78..a4d846768 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -1972,8 +1972,10 @@ def ewise_add(self, other, op=monoid.plus): # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if type(self) is MatrixEwiseMultExpr: + 1 / 0 raise TypeError("XXX") if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: + 1 / 0 raise TypeError("XXX") if other.ndim == 1: # Broadcast rowwise from the right @@ -2107,8 +2109,10 @@ def ewise_mult(self, other, op=binary.times): op=op, ) if type(self) is MatrixEwiseMultExpr: + 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is MatrixEwiseMultExpr: + 1 / 0 other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, @@ -2212,13 +2216,15 @@ def ewise_union(self, other, op, left_default, right_default): else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - op = _get_typed_op_from_exprs(op, self, right, kind="binary") + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") - if op is not op2: - scalar_dtype = unify(op.type2, op2.type) - nonscalar_dtype = unify(op.type, op2.type2) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") 1 / 0 + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop @@ -2346,6 +2352,7 @@ def mxv(self, other, op=semiring.plus_times): 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is VectorMatMulExpr: + 1 / 0 other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index 3510aede5..b14d00555 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -68,13 +68,13 @@ def _v_union_m(updater, left, right, left_default, right_default, op): updater << temp.ewise_union(right, op, left_default=left_default, right_default=right_default) -def _v_union_v(updater, left, right, left_default, right_default, op, dtype): +def _v_union_v(updater, left, right, left_default, right_default, op): mask = updater.kwargs.get("mask") opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(mask=mask, **opts) << binary.second(right, left_default) new_left(mask=mask, **opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(mask=mask, **opts) << binary.second(left, right_default) new_right(mask=mask, **opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -1104,10 +1104,8 @@ def ewise_add(self, other, op=monoid.plus): op=op, ) if type(self) is VectorEwiseAddExpr: - 1 / 0 self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) if type(other) is VectorEwiseAddExpr: - 1 / 0 other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, @@ -1176,10 +1174,8 @@ def ewise_mult(self, other, op=binary.times): # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if type(self) is VectorEwiseAddExpr: - 1 / 0 raise TypeError("XXX") if type(other) in {VectorEwiseAddExpr, MatrixEwiseAddExpr}: - 1 / 0 raise TypeError("XXX") if other.ndim == 2: # Broadcast columnwise from the left @@ -1278,7 +1274,9 @@ def ewise_union(self, other, op, left_default, right_default): argname="other", op=op, ) - dtype = self.dtype if self.dtype._is_udt else None + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -1295,6 +1293,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -1312,21 +1312,21 @@ def ewise_union(self, other, op, left_default, right_default): else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - op = _get_typed_op_from_exprs(op, self, right, kind="binary") + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") - if op is not op2: - scalar_dtype = unify(op.type2, op2.type) - nonscalar_dtype = unify(op.type, op2.type2) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") 1 / 0 + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop if type(self) is VectorEwiseMultExpr: - 1 / 0 raise TypeError("XXX") if type(other) in {VectorEwiseMultExpr, MatrixEwiseMultExpr}: - 1 / 0 raise TypeError("XXX") expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 2: @@ -1365,7 +1365,6 @@ def ewise_union(self, other, op, left_default, right_default): op=op, ) if type(self) is VectorEwiseAddExpr: - 1 / 0 self = self._expect_type( op(self, left_default=left, right_default=right), Vector, @@ -1374,7 +1373,6 @@ def ewise_union(self, other, op, left_default, right_default): op=op, ) if type(other) is VectorEwiseAddExpr: - 1 / 0 other = self._expect_type( op(other, left_default=left, right_default=right), Vector, @@ -1391,11 +1389,10 @@ def ewise_union(self, other, op, left_default, right_default): expr_repr=expr_repr, ) else: - dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True) expr = VectorExpression( method_name, None, - [self, left, other, right, _v_union_v, (self, other, left, right, op, dtype)], + [self, left, other, right, _v_union_v, (self, other, left, right, op)], expr_repr=expr_repr, size=self._size, op=op, @@ -1448,8 +1445,10 @@ def vxm(self, other, op=semiring.plus_times): self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is VectorMatMulExpr: self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) + 1 / 0 if type(other) is MatrixMatMulExpr: other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) + 1 / 0 expr = VectorExpression( method_name, "GrB_vxm", @@ -1801,8 +1800,10 @@ def inner(self, other, op=semiring.plus_times): self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is VectorMatMulExpr: self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) + 1 / 0 if type(other) is VectorMatMulExpr: other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) + 1 / 0 expr = ScalarExpression( method_name, "GrB_vxm", diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index 72e1c8a42..614f33423 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -1,6 +1,6 @@ import pytest -from graphblas import monoid, op +from graphblas import binary, monoid, op from graphblas.exceptions import DimensionMismatch from .conftest import autocompute @@ -367,3 +367,67 @@ def test_infix_expr_value_types(): expr._value = None assert expr._value is None assert expr._expr._value is None + + +@autocompute +def test_multi_infix_ewise(): + v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . + v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 + v3 = Vector.from_coo([2, 0], [1, 2], size=3) # 2 . 1 + result = binary.plus((v1 | v2) | v3).new() + expected = Vector.from_scalar(3, size=3) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3)).new() + assert result.isequal(expected) + result = monoid.min(v1 | v2 | v3).new() + expected = Vector.from_scalar(1, size=3) + assert result.isequal(expected) + result = monoid.max((v1 & v2) & v3).new() + expected = Vector(int, size=3) + assert result.isequal(expected) + result = monoid.max(v1 & (v2 & v3)).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & v1).new() + expected = Vector.from_coo([1], [1], size=3) + assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() + expected = Vector.from_scalar(13, size=3) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1 | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) | (v2 & v3) + + with pytest.raises(TypeError, match="XXX"): # TODO + v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 & v3) + + # We don't (yet) differentiate between infix and methods + with pytest.raises(TypeError, match="XXX"): # TODO + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).ewise_mult(v3) From 9b1fdb1cfa9686156a76846c832d5678e8fed262 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Thu, 14 Sep 2023 12:47:55 +0200 Subject: [PATCH 05/10] a little more --- graphblas/core/matrix.py | 17 +++++++++-------- graphblas/core/vector.py | 5 ----- graphblas/tests/test_infix.py | 21 +++++++++++++++++++++ pyproject.toml | 3 +++ 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index a4d846768..80587b220 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -74,13 +74,13 @@ def _m_mult_v(updater, left, right, op): updater << left.mxm(right.diag(name="M_temp"), get_semiring(monoid.any, op)) -def _m_union_m(updater, left, right, left_default, right_default, op, dtype): +def _m_union_m(updater, left, right, left_default, right_default, op): mask = updater.kwargs.get("mask") opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(mask=mask, **opts) << binary.second(right, left_default) new_left(mask=mask, **opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(mask=mask, **opts) << binary.second(left, right_default) new_right(mask=mask, **opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -1986,12 +1986,10 @@ def ewise_add(self, other, op=monoid.plus): f"must equal Vector.size (={other._size})." ) if type(self) is MatrixEwiseAddExpr: - 1 / 0 self = self._expect_type( op(self), Matrix, within=method_name, argname="self", op=op ) if type(other) is VectorEwiseAddExpr: - 1 / 0 other = self._expect_type( op(other), Vector, within=method_name, argname="other", op=op ) @@ -2182,7 +2180,9 @@ def ewise_union(self, other, op, left_default, right_default): argname="other", op=op, ) - dtype = self.dtype if self.dtype._is_udt else None + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -2199,6 +2199,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -2299,11 +2301,10 @@ def ewise_union(self, other, op, left_default, right_default): expr_repr=expr_repr, ) else: - dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True) expr = MatrixExpression( method_name, None, - [self, left, other, right, _m_union_m, (self, other, left, right, op, dtype)], + [self, left, other, right, _m_union_m, (self, other, left, right, op)], expr_repr=expr_repr, nrows=self._nrows, ncols=self._ncols, diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index b14d00555..c600a42ca 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -1086,12 +1086,10 @@ def ewise_add(self, other, op=monoid.plus): f"must equal Vector.size (={self._size})." ) if type(self) is VectorEwiseAddExpr: - 1 / 0 self = self._expect_type( op(self), Vector, within=method_name, argname="self", op=op ) if type(other) is MatrixEwiseAddExpr: - 1 / 0 other = self._expect_type( op(other), Matrix, within=method_name, argname="other", op=op ) @@ -1318,7 +1316,6 @@ def ewise_union(self, other, op, left_default, right_default): left_dtype = unify(op1.type, op2.type, is_right_scalar=True) right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) op = get_typed_op(op, left_dtype, right_dtype, kind="binary") - 1 / 0 else: op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") @@ -1800,10 +1797,8 @@ def inner(self, other, op=semiring.plus_times): self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is VectorMatMulExpr: self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) - 1 / 0 if type(other) is VectorMatMulExpr: other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) - 1 / 0 expr = ScalarExpression( method_name, "GrB_vxm", diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index 614f33423..5d8f2df3f 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -371,9 +371,11 @@ def test_infix_expr_value_types(): @autocompute def test_multi_infix_ewise(): + D0 = Vector.from_scalar(0, 3).diag() v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 v3 = Vector.from_coo([2, 0], [1, 2], size=3) # 2 . 1 + # ewise_add result = binary.plus((v1 | v2) | v3).new() expected = Vector.from_scalar(3, size=3) assert result.isequal(expected) @@ -382,6 +384,7 @@ def test_multi_infix_ewise(): result = monoid.min(v1 | v2 | v3).new() expected = Vector.from_scalar(1, size=3) assert result.isequal(expected) + # ewise_mult result = monoid.max((v1 & v2) & v3).new() expected = Vector(int, size=3) assert result.isequal(expected) @@ -390,11 +393,29 @@ def test_multi_infix_ewise(): result = monoid.min((v1 & v2) & v1).new() expected = Vector.from_coo([1], [1], size=3) assert result.isequal(expected) + # ewise_union result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() expected = Vector.from_scalar(13, size=3) assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new() + expected = Vector.from_scalar(13.0, size=3) + assert result.isequal(expected) result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() assert result.isequal(expected) + # inner + assert op.plus_plus(v1 @ v1).value == 6 + assert op.plus_plus(v1 @ (v1 @ D0)).value == 6 + assert op.plus_plus((D0 @ v1) @ v1).value == 6 + # matrix-vector ewise_add + result = binary.plus((D0 | v1) | v2).new() + expected = binary.plus(binary.plus(D0 | v1) | v2).new() + assert result.isequal(expected) + result = binary.plus(D0 | (v1 | v2)).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | D0).new() + assert result.isequal(expected.T) + result = binary.plus(v1 | (v2 | D0)).new() + assert result.isequal(expected.T) with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | v3 diff --git a/pyproject.toml b/pyproject.toml index 619ce18f2..ff970cc0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,6 +235,9 @@ ignore-words-list = "coo,ba" # https://github.com/charliermarsh/ruff/ line-length = 100 target-version = "py39" +unfixable = [ + "F841" # unused-variable (Note: can leave useless expression) +] select = [ # Have we enabled too many checks that they'll become a nuisance? We'll see... "F", # pyflakes From 0a26c3e1e1a335ac67d4b25e0f809fa43eb5656c Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Fri, 29 Sep 2023 17:32:21 -0500 Subject: [PATCH 06/10] bump pre-commit --- .pre-commit-config.yaml | 10 +++++----- scripts/check_versions.sh | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a945fe49a..16f677471 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: - id: isort # Let's keep `pyupgrade` even though `ruff --fix` probably does most of it - repo: https://github.com/asottile/pyupgrade - rev: v3.12.0 + rev: v3.13.0 hooks: - id: pyupgrade args: [--py39-plus] @@ -66,7 +66,7 @@ repos: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.290 + rev: v0.0.291 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -80,7 +80,7 @@ repos: # These versions need updated manually - flake8==6.1.0 - flake8-bugbear==23.9.16 - - flake8-simplify==0.20.0 + - flake8-simplify==0.21.0 - repo: https://github.com/asottile/yesqa rev: v1.5.0 hooks: @@ -94,7 +94,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.290 + rev: v0.0.291 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint @@ -110,7 +110,7 @@ repos: - id: pyroma args: [-n, "10", .] - repo: https://github.com/shellcheck-py/shellcheck-py - rev: "v0.9.0.5" + rev: "v0.9.0.6" hooks: - id: shellcheck - repo: local diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index a76fee1d2..9051ebe6e 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -4,12 +4,12 @@ # This may be helpful when updating dependency versions in CI. # Tip: add `--json` for more information. conda search 'flake8-bugbear[channel=conda-forge]>=23.9.16' -conda search 'flake8-simplify[channel=conda-forge]>=0.20.0' +conda search 'flake8-simplify[channel=conda-forge]>=0.21.0' conda search 'numpy[channel=conda-forge]>=1.26.0' conda search 'pandas[channel=conda-forge]>=2.1.1' -conda search 'scipy[channel=conda-forge]>=1.11.2' +conda search 'scipy[channel=conda-forge]>=1.11.3' conda search 'networkx[channel=conda-forge]>=3.1' -conda search 'awkward[channel=conda-forge]>=2.4.3' +conda search 'awkward[channel=conda-forge]>=2.4.4' conda search 'sparse[channel=conda-forge]>=0.14.0' conda search 'fast_matrix_market[channel=conda-forge]>=1.7.3' conda search 'numba[channel=conda-forge]>=0.57.1' From 539633033e35a7ecdc8e52a1da430a4ea070775a Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Sun, 15 Oct 2023 17:16:55 -0500 Subject: [PATCH 07/10] More tests --- graphblas/core/matrix.py | 6 ------ graphblas/core/vector.py | 6 ------ graphblas/tests/test_infix.py | 40 ++++++++++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 367f8495c..37a3ebd92 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -2089,12 +2089,10 @@ def ewise_mult(self, other, op=binary.times): f"must equal Vector.size (={other._size})." ) if type(self) is MatrixEwiseMultExpr: - 1 / 0 self = self._expect_type( op(self), Matrix, within=method_name, argname="self", op=op ) if type(other) is VectorEwiseMultExpr: - 1 / 0 other = self._expect_type( op(other), Vector, within=method_name, argname="other", op=op ) @@ -2246,7 +2244,6 @@ def ewise_union(self, other, op, left_default, right_default): f"must equal Vector.size (={other._size})." ) if type(self) is MatrixEwiseAddExpr: - 1 / 0 self = self._expect_type( op(self, left_default=left, right_default=right), Matrix, @@ -2255,7 +2252,6 @@ def ewise_union(self, other, op, left_default, right_default): op=op, ) if type(other) is VectorEwiseAddExpr: - 1 / 0 other = self._expect_type( op(other, left_default=left, right_default=right), Vector, @@ -2350,10 +2346,8 @@ def mxv(self, other, op=semiring.plus_times): op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is MatrixMatMulExpr: - 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is VectorMatMulExpr: - 1 / 0 other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index c600a42ca..0a80f9118 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -1184,12 +1184,10 @@ def ewise_mult(self, other, op=binary.times): f"must equal Vector.size (={self._size})." ) if type(self) is VectorEwiseMultExpr: - 1 / 0 self = self._expect_type( op(self), Vector, within=method_name, argname="self", op=op ) if type(other) is MatrixEwiseMultExpr: - 1 / 0 other = self._expect_type( op(other), Matrix, within=method_name, argname="other", op=op ) @@ -1335,7 +1333,6 @@ def ewise_union(self, other, op, left_default, right_default): f"must equal Vector.size (={self._size})." ) if type(self) is VectorEwiseAddExpr: - 1 / 0 self = self._expect_type( op(self, left_default=left, right_default=right), Vector, @@ -1344,7 +1341,6 @@ def ewise_union(self, other, op, left_default, right_default): op=op, ) if type(other) is MatrixEwiseAddExpr: - 1 / 0 other = self._expect_type( op(other, left_default=left, right_default=right), Matrix, @@ -1442,10 +1438,8 @@ def vxm(self, other, op=semiring.plus_times): self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is VectorMatMulExpr: self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) - 1 / 0 if type(other) is MatrixMatMulExpr: other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) - 1 / 0 expr = VectorExpression( method_name, "GrB_vxm", diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index 5d8f2df3f..e2efef109 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -370,7 +370,7 @@ def test_infix_expr_value_types(): @autocompute -def test_multi_infix_ewise(): +def test_multi_infix(): D0 = Vector.from_scalar(0, 3).diag() v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 @@ -416,6 +416,44 @@ def test_multi_infix_ewise(): assert result.isequal(expected.T) result = binary.plus(v1 | (v2 | D0)).new() assert result.isequal(expected.T) + # matrix-vector ewise_mult + result = binary.plus((D0 & v1) & v2).new() + expected = binary.plus(binary.plus(D0 & v1) & v2).new() + assert result.isequal(expected) + assert result.nvals > 0 + result = binary.plus(D0 & (v1 & v2)).new() + assert result.isequal(expected) + result = binary.plus((v1 & v2) & D0).new() + assert result.isequal(expected.T) + result = binary.plus(v1 & (v2 & D0)).new() + assert result.isequal(expected.T) + # matrix-vector ewise_union + kwargs = {"left_default": 10, "right_default": 20} + result = binary.plus((D0 | v1) | v2, **kwargs).new() + expected = binary.plus(binary.plus(D0 | v1, **kwargs) | v2, **kwargs).new() + assert result.isequal(expected) + result = binary.plus(D0 | (v1 | v2), **kwargs).new() + expected = binary.plus(D0 | binary.plus(v1 | v2, **kwargs), **kwargs).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | D0, **kwargs).new() + expected = binary.plus(binary.plus(v1 | v2, **kwargs) | D0, **kwargs).new() + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | D0), **kwargs).new() + expected = binary.plus(v1 | binary.plus(v2 | D0, **kwargs), **kwargs).new() + assert result.isequal(expected) + # vxm, mxv + result = op.plus_plus((D0 @ v1) @ D0).new() + assert result.isequal(v1) + result = op.plus_plus(D0 @ (v1 @ D0)).new() + assert result.isequal(v1) + result = op.plus_plus(v1 @ (D0 @ D0)).new() + assert result.isequal(v1) + result = op.plus_plus((D0 @ D0) @ v1).new() + assert result.isequal(v1) + result = op.plus_plus((v1 @ D0) @ D0).new() + assert result.isequal(v1) + result = op.plus_plus(D0 @ (D0 @ v1)).new() + assert result.isequal(v1) with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | v3 From 343081e0dab0e2ac5cd54ff22556eb7f5228d4db Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Sat, 28 Oct 2023 09:55:46 -0500 Subject: [PATCH 08/10] more tests --- .pre-commit-config.yaml | 6 +-- graphblas/core/base.py | 2 - graphblas/core/infix.py | 4 +- graphblas/core/matrix.py | 15 ------ graphblas/tests/test_infix.py | 94 ++++++++++++++++++++++++++++++++++- scripts/check_versions.sh | 2 +- 6 files changed, 99 insertions(+), 24 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b2e08e638..3766e2e7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,12 +61,12 @@ repos: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.10.0 + rev: 23.10.1 hooks: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.1 + rev: v0.1.3 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -94,7 +94,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.1 + rev: v0.1.3 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/graphblas/core/base.py b/graphblas/core/base.py index 011e7bbf5..e384799ba 100644 --- a/graphblas/core/base.py +++ b/graphblas/core/base.py @@ -273,7 +273,6 @@ def __ror__(self, other): from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr, _ewise_infix_expr if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): - 1 / 0 raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") @@ -288,7 +287,6 @@ def __rand__(self, other): from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr, _ewise_infix_expr if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): - 1 / 0 raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index f4115a487..56de3c6cc 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -404,8 +404,8 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr - __and__ = VectorEwiseMultExpr.__and__ - __rand__ = VectorEwiseMultExpr.__rand__ + __and__ = VectorEwiseAddExpr.__and__ + __rand__ = VectorEwiseAddExpr.__rand__ __or__ = Matrix.__or__ __ror__ = Matrix.__ror__ ewise_add = Matrix.ewise_add diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 37a3ebd92..505395be1 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -1972,10 +1972,8 @@ def ewise_add(self, other, op=monoid.plus): # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if type(self) is MatrixEwiseMultExpr: - 1 / 0 raise TypeError("XXX") if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: - 1 / 0 raise TypeError("XXX") if other.ndim == 1: # Broadcast rowwise from the right @@ -2002,10 +2000,8 @@ def ewise_add(self, other, op=monoid.plus): op=op, ) if type(self) is MatrixEwiseAddExpr: - 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is MatrixEwiseAddExpr: - 1 / 0 other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, @@ -2075,10 +2071,8 @@ def ewise_mult(self, other, op=binary.times): # Per the spec, op may be a semiring, but this is weird, so don't. self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if type(self) is MatrixEwiseAddExpr: - 1 / 0 raise TypeError("XXX") if type(other) in {MatrixEwiseAddExpr, VectorEwiseAddExpr}: - 1 / 0 raise TypeError("XXX") if other.ndim == 1: # Broadcast rowwise from the right @@ -2105,10 +2099,8 @@ def ewise_mult(self, other, op=binary.times): op=op, ) if type(self) is MatrixEwiseMultExpr: - 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is MatrixEwiseMultExpr: - 1 / 0 other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, @@ -2222,17 +2214,14 @@ def ewise_union(self, other, op, left_default, right_default): left_dtype = unify(op1.type, op2.type, is_right_scalar=True) right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) op = get_typed_op(op, left_dtype, right_dtype, kind="binary") - 1 / 0 else: op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop if type(self) is MatrixEwiseMultExpr: - 1 / 0 raise TypeError("XXX") if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: - 1 / 0 raise TypeError("XXX") expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 1: @@ -2269,7 +2258,6 @@ def ewise_union(self, other, op, left_default, right_default): op=op, ) if type(self) is MatrixEwiseAddExpr: - 1 / 0 self = self._expect_type( op(self, left_default=left, right_default=right), Matrix, @@ -2278,7 +2266,6 @@ def ewise_union(self, other, op, left_default, right_default): op=op, ) if type(other) is MatrixEwiseAddExpr: - 1 / 0 other = self._expect_type( op(other, left_default=left, right_default=right), Matrix, @@ -2401,10 +2388,8 @@ def mxm(self, other, op=semiring.plus_times): op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") if type(self) is MatrixMatMulExpr: - 1 / 0 self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) if type(other) is MatrixMatMulExpr: - 1 / 0 other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index e2efef109..cc28f1134 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -370,7 +370,7 @@ def test_infix_expr_value_types(): @autocompute -def test_multi_infix(): +def test_multi_infix_vector(): D0 = Vector.from_scalar(0, 3).diag() v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 @@ -457,23 +457,115 @@ def test_multi_infix(): with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).__ror__(v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | (v2 & v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | (v2 | v3) with pytest.raises(TypeError, match="XXX"): # TODO v1 | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__ror__(v2 & v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 | v2) | (v2 & v3) with pytest.raises(TypeError, match="XXX"): # TODO v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__rand__(v2 | v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 | v2) & (v2 | v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) & (v2 | v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).__rand__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 & v3) + + # We don't (yet) differentiate between infix and methods + with pytest.raises(TypeError, match="XXX"): # TODO + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).ewise_mult(v3) + + +@autocompute +def test_multi_infix_matrix(): + # Adapted from test_multi_infix_vector + D0 = Vector.from_scalar(0, 3).diag() + v1 = Matrix.from_coo([0, 1], [0, 0], [1, 2], nrows=3) # 1 2 . + v2 = Matrix.from_coo([1, 2], [0, 0], [1, 2], nrows=3) # . 1 2 + v3 = Matrix.from_coo([2, 0], [0, 0], [1, 2], nrows=3) # 2 . 1 + # ewise_add + result = binary.plus((v1 | v2) | v3).new() + expected = Matrix.from_scalar(3, 3, 1) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3)).new() + assert result.isequal(expected) + result = monoid.min(v1 | v2 | v3).new() + expected = Matrix.from_scalar(1, 3, 1) + assert result.isequal(expected) + # ewise_mult + result = monoid.max((v1 & v2) & v3).new() + expected = Matrix(int, 3, 1) + assert result.isequal(expected) + result = monoid.max(v1 & (v2 & v3)).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & v1).new() + expected = Matrix.from_coo([1], [0], [1], nrows=3) + assert result.isequal(expected) + # ewise_union + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() + expected = Matrix.from_scalar(13, 3, 1) + assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new() + expected = Matrix.from_scalar(13.0, 3, 1) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + # mxm + assert op.plus_plus(v1.T @ v1)[0, 0].value == 6 + assert op.plus_plus(v1 @ (v1.T @ D0))[0, 0].value == 2 + assert op.plus_plus((v1.T @ D0) @ v1)[0, 0].value == 6 + + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).__ror__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1 | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__ror__(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) | (v2 & v3) + + with pytest.raises(TypeError, match="XXX"): # TODO + v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__rand__(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).__rand__(v3) with pytest.raises(TypeError, match="XXX"): # TODO (v1 | v2) & (v2 & v3) diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index 7c09bc168..b9a065829 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -6,7 +6,7 @@ conda search 'flake8-bugbear[channel=conda-forge]>=23.9.16' conda search 'flake8-simplify[channel=conda-forge]>=0.21.0' conda search 'numpy[channel=conda-forge]>=1.26.0' -conda search 'pandas[channel=conda-forge]>=2.1.1' +conda search 'pandas[channel=conda-forge]>=2.1.2' conda search 'scipy[channel=conda-forge]>=1.11.3' conda search 'networkx[channel=conda-forge]>=3.2' conda search 'awkward[channel=conda-forge]>=2.4.6' From ecaf6ac6e6211a4a088efdf26228703fe555dd19 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Sun, 29 Oct 2023 14:55:58 -0500 Subject: [PATCH 09/10] Getting very close I think --- graphblas/core/base.py | 16 +- graphblas/core/infix.py | 83 +++++--- graphblas/core/matrix.py | 314 ++++++++++++++---------------- graphblas/core/operator/base.py | 4 +- graphblas/core/operator/binary.py | 4 +- graphblas/core/operator/monoid.py | 26 +-- graphblas/core/scalar.py | 82 +++++++- graphblas/core/vector.py | 311 ++++++++++++++--------------- graphblas/tests/test_infix.py | 242 ++++++++++++++++++++--- graphblas/tests/test_matrix.py | 9 +- graphblas/tests/test_scalar.py | 4 +- graphblas/tests/test_vector.py | 8 +- 12 files changed, 664 insertions(+), 439 deletions(-) diff --git a/graphblas/core/base.py b/graphblas/core/base.py index e384799ba..5658e99c1 100644 --- a/graphblas/core/base.py +++ b/graphblas/core/base.py @@ -263,30 +263,30 @@ def __call__( ) def __or__(self, other): - from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr, _ewise_infix_expr + from .infix import _ewise_infix_expr, _ewise_mult_expr_types - if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): + if isinstance(other, _ewise_mult_expr_types): raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_add", within="__or__") def __ror__(self, other): - from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr, _ewise_infix_expr + from .infix import _ewise_infix_expr, _ewise_mult_expr_types - if isinstance(other, (VectorEwiseMultExpr, MatrixEwiseMultExpr)): + if isinstance(other, _ewise_mult_expr_types): raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") def __and__(self, other): - from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr, _ewise_infix_expr + from .infix import _ewise_add_expr_types, _ewise_infix_expr - if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): + if isinstance(other, _ewise_add_expr_types): raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__") def __rand__(self, other): - from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr, _ewise_infix_expr + from .infix import _ewise_add_expr_types, _ewise_infix_expr - if isinstance(other, (VectorEwiseAddExpr, MatrixEwiseAddExpr)): + if isinstance(other, _ewise_add_expr_types): raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 56de3c6cc..e1dc15bbe 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -125,6 +125,18 @@ class ScalarEwiseAddExpr(ScalarInfixExpr): _to_expr = _ewise_add_to_expr + __or__ = Scalar.__or__ + __ror__ = Scalar.__ror__ + _ewise_add = Scalar._ewise_add + _ewise_mult = Scalar._ewise_mult + _ewise_union = Scalar._ewise_union + + def __and__(self, other, *, within="__and__"): + raise TypeError("XXX") + + def __rand__(self, other): + self.__and__(other, within="__rand__") + class ScalarEwiseMultExpr(ScalarInfixExpr): __slots__ = () @@ -134,6 +146,18 @@ class ScalarEwiseMultExpr(ScalarInfixExpr): _to_expr = _ewise_mult_to_expr + __and__ = Scalar.__and__ + __rand__ = Scalar.__rand__ + _ewise_add = Scalar._ewise_add + _ewise_mult = Scalar._ewise_mult + _ewise_union = Scalar._ewise_union + + def __or__(self, other): + raise TypeError("XXX") + + def __ror__(self, other): + raise TypeError("XXX") + class ScalarMatMulExpr(ScalarInfixExpr): __slots__ = () @@ -238,17 +262,13 @@ class VectorEwiseAddExpr(VectorInfixExpr): _to_expr = _ewise_add_to_expr + __and__ = ScalarEwiseAddExpr.__and__ # raises + __rand__ = ScalarEwiseAddExpr.__rand__ # raises __or__ = Vector.__or__ __ror__ = Vector.__ror__ - ewise_add = Vector.ewise_add - ewise_mult = Vector.ewise_mult - ewise_union = Vector.ewise_union - - def __and__(self, other, *, within="__and__"): - raise TypeError("XXX") - - def __rand__(self, other): - self.__and__(other, within="__rand__") + _ewise_add = Vector._ewise_add + _ewise_mult = Vector._ewise_mult + _ewise_union = Vector._ewise_union class VectorEwiseMultExpr(VectorInfixExpr): @@ -261,15 +281,11 @@ class VectorEwiseMultExpr(VectorInfixExpr): __and__ = Vector.__and__ __rand__ = Vector.__rand__ - ewise_add = Vector.ewise_add - ewise_mult = Vector.ewise_mult - ewise_union = Vector.ewise_union - - def __or__(self, other): - raise TypeError("XXX") - - def __ror__(self, other): - raise TypeError("XXX") + __or__ = ScalarEwiseMultExpr.__or__ # raises + __ror__ = ScalarEwiseMultExpr.__ror__ # raises + _ewise_add = Vector._ewise_add + _ewise_mult = Vector._ewise_mult + _ewise_union = Vector._ewise_union class VectorMatMulExpr(VectorInfixExpr): @@ -284,8 +300,8 @@ def __init__(self, left, right, *, method_name, size): __matmul__ = Vector.__matmul__ __rmatmul__ = Vector.__rmatmul__ - inner = Vector.inner - vxm = Vector.vxm + _inner = Vector._inner + _vxm = Vector._vxm utils._output_types[VectorEwiseAddExpr] = Vector @@ -404,13 +420,13 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr - __and__ = VectorEwiseAddExpr.__and__ - __rand__ = VectorEwiseAddExpr.__rand__ + __and__ = VectorEwiseAddExpr.__and__ # raises + __rand__ = VectorEwiseAddExpr.__rand__ # raises __or__ = Matrix.__or__ __ror__ = Matrix.__ror__ - ewise_add = Matrix.ewise_add - ewise_mult = Matrix.ewise_mult - ewise_union = Matrix.ewise_union + _ewise_add = Matrix._ewise_add + _ewise_mult = Matrix._ewise_mult + _ewise_union = Matrix._ewise_union class MatrixEwiseMultExpr(MatrixInfixExpr): @@ -423,11 +439,11 @@ class MatrixEwiseMultExpr(MatrixInfixExpr): __and__ = Matrix.__and__ __rand__ = Matrix.__rand__ - __or__ = VectorEwiseMultExpr.__or__ - __ror__ = VectorEwiseMultExpr.__ror__ - ewise_add = Matrix.ewise_add - ewise_mult = Matrix.ewise_mult - ewise_union = Matrix.ewise_union + __or__ = VectorEwiseMultExpr.__or__ # raises + __ror__ = VectorEwiseMultExpr.__ror__ # raises + _ewise_add = Matrix._ewise_add + _ewise_mult = Matrix._ewise_mult + _ewise_union = Matrix._ewise_union class MatrixMatMulExpr(MatrixInfixExpr): @@ -443,8 +459,8 @@ def __init__(self, left, right, *, nrows, ncols): __matmul__ = Matrix.__matmul__ __rmatmul__ = Matrix.__rmatmul__ - mxm = Matrix.mxm - mxv = Matrix.mxv + _mxm = Matrix._mxm + _mxv = Matrix._mxv utils._output_types[MatrixEwiseAddExpr] = Matrix @@ -560,5 +576,8 @@ def _matmul_infix_expr(left, right, *, within): return ScalarMatMulExpr(left, right) +_ewise_add_expr_types = (MatrixEwiseAddExpr, VectorEwiseAddExpr, ScalarEwiseAddExpr) +_ewise_mult_expr_types = (MatrixEwiseMultExpr, VectorEwiseMultExpr, ScalarEwiseMultExpr) + # Import infixmethods, which has side effects from . import infixmethods # noqa: E402, F401 isort:skip diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 505395be1..34789d68d 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -10,7 +10,7 @@ from . import _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, _check_mask, call from .descriptor import lookup as descriptor_lookup -from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater +from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, InfixExprBase, Updater from .mask import Mask, StructuralMask, ValueMask from .operator import ( UNKNOWN_OPCLASS, @@ -1945,36 +1945,39 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax C << monoid.max(A | B) """ - from .infix import ( - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ) + return self._ewise_add(other, op) + def _ewise_add(self, other, op=monoid.plus, is_infix=False): method_name = "ewise_add" - other = self._expect_type( - other, - ( - Matrix, - TransposedMatrix, - Vector, - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ), - within=method_name, - argname="other", - op=op, - ) - op = _get_typed_op_from_exprs(op, self, other, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") - if type(self) is MatrixEwiseMultExpr: - raise TypeError("XXX") - if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: - raise TypeError("XXX") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, MatrixEwiseAddExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -1983,14 +1986,6 @@ def ewise_add(self, other, op=monoid.plus): f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) " f"must equal Vector.size (={other._size})." ) - if type(self) is MatrixEwiseAddExpr: - self = self._expect_type( - op(self), Matrix, within=method_name, argname="self", op=op - ) - if type(other) is VectorEwiseAddExpr: - other = self._expect_type( - op(other), Vector, within=method_name, argname="other", op=op - ) return MatrixExpression( method_name, None, @@ -1999,10 +1994,6 @@ def ewise_add(self, other, op=monoid.plus): ncols=self._ncols, op=op, ) - if type(self) is MatrixEwiseAddExpr: - self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) - if type(other) is MatrixEwiseAddExpr: - other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, f"GrB_Matrix_eWiseAdd_{op.opclass}", @@ -2044,36 +2035,39 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax C << binary.gt(A & B) """ - from .infix import ( - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ) + return self._ewise_mult(other, op) + def _ewise_mult(self, other, op=binary.times, is_infix=False): method_name = "ewise_mult" - other = self._expect_type( - other, - ( - Matrix, - TransposedMatrix, - Vector, - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ), - within=method_name, - argname="other", - op=op, - ) - op = _get_typed_op_from_exprs(op, self, other, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") - if type(self) is MatrixEwiseAddExpr: - raise TypeError("XXX") - if type(other) in {MatrixEwiseAddExpr, VectorEwiseAddExpr}: - raise TypeError("XXX") + if is_infix: + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseMultExpr, VectorEwiseMultExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, MatrixEwiseMultExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -2082,14 +2076,6 @@ def ewise_mult(self, other, op=binary.times): f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) " f"must equal Vector.size (={other._size})." ) - if type(self) is MatrixEwiseMultExpr: - self = self._expect_type( - op(self), Matrix, within=method_name, argname="self", op=op - ) - if type(other) is VectorEwiseMultExpr: - other = self._expect_type( - op(other), Vector, within=method_name, argname="other", op=op - ) return MatrixExpression( method_name, None, @@ -2098,10 +2084,6 @@ def ewise_mult(self, other, op=binary.times): ncols=self._ncols, op=op, ) - if type(self) is MatrixEwiseMultExpr: - self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) - if type(other) is MatrixEwiseMultExpr: - other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) expr = MatrixExpression( method_name, f"GrB_Matrix_eWiseMult_{op.opclass}", @@ -2147,30 +2129,31 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax C << binary.div(A | B, left_default=1, right_default=1) """ - from .infix import ( - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ) + return self._ewise_union(other, op, left_default, right_default) + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): method_name = "ewise_union" - other = self._expect_type( - other, - ( - Matrix, - TransposedMatrix, - Vector, - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ), - within=method_name, - argname="other", - op=op, - ) - temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + left_dtype = temp_op.type dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: @@ -2208,8 +2191,12 @@ def ewise_union(self, other, op, left_default, right_default): else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") - op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + if is_infix: + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + else: + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") if op1 is not op2: left_dtype = unify(op1.type, op2.type, is_right_scalar=True) right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) @@ -2219,10 +2206,13 @@ def ewise_union(self, other, op, left_default, right_default): self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop - if type(self) is MatrixEwiseMultExpr: - raise TypeError("XXX") - if type(other) in {MatrixEwiseMultExpr, VectorEwiseMultExpr}: - raise TypeError("XXX") + + if is_infix: + if isinstance(self, MatrixEwiseAddExpr): + self = op(self, left_default=left, right_default=right).new() + if isinstance(other, InfixExprBase): + other = op(other, left_default=left, right_default=right).new() + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 1: # Broadcast rowwise from the right @@ -2232,22 +2222,6 @@ def ewise_union(self, other, op, left_default, right_default): f"to rows of Matrix in {method_name}. Matrix.ncols (={self._ncols}) " f"must equal Vector.size (={other._size})." ) - if type(self) is MatrixEwiseAddExpr: - self = self._expect_type( - op(self, left_default=left, right_default=right), - Matrix, - within=method_name, - argname="self", - op=op, - ) - if type(other) is VectorEwiseAddExpr: - other = self._expect_type( - op(other, left_default=left, right_default=right), - Vector, - within=method_name, - argname="other", - op=op, - ) return MatrixExpression( method_name, None, @@ -2257,22 +2231,6 @@ def ewise_union(self, other, op, left_default, right_default): ncols=self._ncols, op=op, ) - if type(self) is MatrixEwiseAddExpr: - self = self._expect_type( - op(self, left_default=left, right_default=right), - Matrix, - within=method_name, - argname="self", - op=op, - ) - if type(other) is MatrixEwiseAddExpr: - other = self._expect_type( - op(other, left_default=left, right_default=right), - Matrix, - within=method_name, - argname="other", - op=op, - ) if backend == "suitesparse": expr = MatrixExpression( method_name, @@ -2324,18 +2282,27 @@ def mxv(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ v) """ - from .infix import MatrixMatMulExpr, VectorMatMulExpr + return self._mxv(other, op) + def _mxv(self, other, op=semiring.plus_times, is_infix=False): method_name = "mxv" - other = self._expect_type( - other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") - if type(self) is MatrixMatMulExpr: - self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) - if type(other) is VectorMatMulExpr: - other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) + if is_infix: + from .infix import MatrixMatMulExpr, VectorMatMulExpr + + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, MatrixMatMulExpr): + self = op(self).new() + if isinstance(other, VectorMatMulExpr): + other = op(other).new() + else: + other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = VectorExpression( method_name, "GrB_mxv", @@ -2375,22 +2342,33 @@ def mxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ B) """ - from .infix import MatrixMatMulExpr + return self._mxm(other, op) + def _mxm(self, other, op=semiring.plus_times, is_infix=False): method_name = "mxm" - other = self._expect_type( - other, - (Matrix, TransposedMatrix, MatrixMatMulExpr), - within=method_name, - argname="other", - op=op, - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") - if type(self) is MatrixMatMulExpr: - self = self._expect_type(op(self), Matrix, within=method_name, argname="self", op=op) - if type(other) is MatrixMatMulExpr: - other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) + if is_infix: + from .infix import MatrixMatMulExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, MatrixMatMulExpr): + self = op(self).new() + if isinstance(other, MatrixMatMulExpr): + other = op(other).new() + else: + other = self._expect_type( + other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = MatrixExpression( method_name, "GrB_mxm", @@ -4006,6 +3984,12 @@ def to_dicts(self, order="rowwise"): reposition = Matrix.reposition power = Matrix.power + _ewise_add = Matrix._ewise_add + _ewise_mult = Matrix._ewise_mult + _ewise_union = Matrix._ewise_union + _mxv = Matrix._mxv + _mxm = Matrix._mxm + # Operator sugar __or__ = Matrix.__or__ __ror__ = Matrix.__ror__ diff --git a/graphblas/core/operator/base.py b/graphblas/core/operator/base.py index d66aa2f4a..59482b47d 100644 --- a/graphblas/core/operator/base.py +++ b/graphblas/core/operator/base.py @@ -111,7 +111,9 @@ def _call_op(op, left, right=None, thunk=None, **kwargs): if right is None and thunk is None: if isinstance(left, InfixExprBase): # op(A & B), op(A | B), op(A @ B) - return getattr(left.left, left.method_name)(left.right, op, **kwargs) + return getattr(left.left, f"_{left.method_name}")( + left.right, op, is_infix=True, **kwargs + ) if find_opclass(op)[1] == "Semiring": raise TypeError( f"Bad type when calling {op!r}. Got type: {type(left)}.\n" diff --git a/graphblas/core/operator/binary.py b/graphblas/core/operator/binary.py index 676ed0970..278ee3183 100644 --- a/graphblas/core/operator/binary.py +++ b/graphblas/core/operator/binary.py @@ -94,7 +94,9 @@ def __call__(self, left, right=None, *, left_default=None, right_default=None): f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " "are Vectors or Matrices, and left_default and right_default are scalars." ) - return left.left.ewise_union(left.right, self, left_default, right_default) + return left.left._ewise_union( + left.right, self, left_default, right_default, is_infix=True + ) return _call_op(self, left, right) @property diff --git a/graphblas/core/operator/monoid.py b/graphblas/core/operator/monoid.py index fc327b4a7..21d2b7cac 100644 --- a/graphblas/core/operator/monoid.py +++ b/graphblas/core/operator/monoid.py @@ -19,10 +19,9 @@ ) from ...exceptions import check_status_carg from .. import ffi, lib -from ..expr import InfixExprBase from ..utils import libget -from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _hasop -from .binary import BinaryOp, ParameterizedBinaryOp +from .base import OpBase, ParameterizedUdf, TypedOpBase, _hasop +from .binary import BinaryOp, ParameterizedBinaryOp, TypedBuiltinBinaryOp ffi_new = ffi.new @@ -36,25 +35,6 @@ def __init__(self, parent, name, type_, return_type, gb_obj, gb_name): super().__init__(parent, name, type_, return_type, gb_obj, gb_name) self._identity = None - def __call__(self, left, right=None, *, left_default=None, right_default=None): - if left_default is not None or right_default is not None: - if ( - left_default is None - or right_default is None - or right is not None - or not isinstance(left, InfixExprBase) - or left.method_name != "ewise_add" - ): - raise TypeError( - "Specifying `left_default` or `right_default` keyword arguments implies " - "performing `ewise_union` operation with infix notation.\n" - "There is only one valid way to do this:\n\n" - f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " - "are Vectors or Matrices, and left_default and right_default are scalars." - ) - return left.left.ewise_union(left.right, self, left_default, right_default) - return _call_op(self, left, right) - @property def identity(self): if self._identity is None: @@ -84,6 +64,8 @@ def is_idempotent(self): """True if ``monoid(x, x) == x`` for any x.""" return self.parent.is_idempotent + __call__ = TypedBuiltinBinaryOp.__call__ + class TypedUserMonoid(TypedOpBase): __slots__ = "binaryop", "identity" diff --git a/graphblas/core/scalar.py b/graphblas/core/scalar.py index 8a95e1d71..bcbbdadd4 100644 --- a/graphblas/core/scalar.py +++ b/graphblas/core/scalar.py @@ -30,12 +30,12 @@ def _scalar_index(name): return self -def _s_union_s(updater, left, right, left_default, right_default, op, dtype): +def _s_union_s(updater, left, right, left_default, right_default, op): opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(**opts) << binary.second(right, left_default) new_left(**opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(**opts) << binary.second(left, right_default) new_right(**opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -629,7 +629,23 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax c << monoid.max(a | b) """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): method_name = "ewise_add" + if is_infix: + from .infix import ScalarEwiseAddExpr + + # This is a little different than how we handle ewise_add for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseAddExpr): + self = op(self).new() + if isinstance(other, ScalarEwiseAddExpr): + other = op(other).new() + if type(other) is not Scalar: dtype = self.dtype if self.dtype._is_udt else None try: @@ -683,7 +699,23 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax c << binary.gt(a & b) """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): method_name = "ewise_mult" + if is_infix: + from .infix import ScalarEwiseMultExpr + + # This is a little different than how we handle ewise_mult for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseMultExpr): + self = op(self).new() + if isinstance(other, ScalarEwiseMultExpr): + other = op(other).new() + if type(other) is not Scalar: dtype = self.dtype if self.dtype._is_udt else None try: @@ -741,8 +773,25 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax c << binary.div(a | b, left_default=1, right_default=1) """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): method_name = "ewise_union" - dtype = self.dtype if self.dtype._is_udt else None + if is_infix: + from .infix import ScalarEwiseAddExpr + + # This is a little different than how we handle ewise_union for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseAddExpr): + self = op(self, left_default=left_default, right_default=right_default).new() + if isinstance(other, ScalarEwiseAddExpr): + other = op(other, left_default=left_default, right_default=right_default).new() + + right_dtype = self.dtype + dtype = right_dtype if right_dtype._is_udt else None if type(other) is not Scalar: try: other = Scalar.from_value(other, dtype, is_cscalar=False, name="") @@ -755,6 +804,13 @@ def ewise_union(self, other, op, left_default, right_default): extra_message="Literal scalars also accepted.", op=op, ) + else: + other = _as_scalar(other, dtype, is_cscalar=False) # pragma: is_grbscalar + + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -771,6 +827,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -787,12 +845,19 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - defaults_dtype = unify(left.dtype, right.dtype) - args_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, defaults_dtype, args_dtype, kind="binary") + + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if backend == "suitesparse": expr = ScalarExpression( @@ -805,11 +870,10 @@ def ewise_union(self, other, op, left_default, right_default): scalar_as_vector=True, ) else: - dtype = unify(defaults_dtype, args_dtype) expr = ScalarExpression( method_name, None, - [self, left, other, right, _s_union_s, (self, other, left, right, op, dtype)], + [self, left, other, right, _s_union_s, (self, other, left, right, op)], op=op, expr_repr=expr_repr, is_cscalar=False, diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index 0a80f9118..feb95ed02 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -9,7 +9,7 @@ from . import _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, _check_mask, call from .descriptor import lookup as descriptor_lookup -from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater +from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, InfixExprBase, Updater from .mask import Mask, StructuralMask, ValueMask from .operator import ( UNKNOWN_OPCLASS, @@ -1045,37 +1045,41 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax w << monoid.max(u | v) """ - from .infix import ( - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ) + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_add" - other = self._expect_type( - other, - ( - Vector, - Matrix, - TransposedMatrix, - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ), - within=method_name, - argname="other", - op=op, - ) - op = _get_typed_op_from_exprs(op, self, other, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") - if type(self) is VectorEwiseMultExpr: - raise TypeError("XXX") - if type(other) in {VectorEwiseMultExpr, MatrixEwiseMultExpr}: - raise TypeError("XXX") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, VectorEwiseAddExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1085,14 +1089,6 @@ def ewise_add(self, other, op=monoid.plus): f"to columns of Matrix in {method_name}. Matrix.nrows (={other._nrows}) " f"must equal Vector.size (={self._size})." ) - if type(self) is VectorEwiseAddExpr: - self = self._expect_type( - op(self), Vector, within=method_name, argname="self", op=op - ) - if type(other) is MatrixEwiseAddExpr: - other = self._expect_type( - op(other), Matrix, within=method_name, argname="other", op=op - ) return MatrixExpression( method_name, None, @@ -1101,10 +1097,6 @@ def ewise_add(self, other, op=monoid.plus): ncols=other._ncols, op=op, ) - if type(self) is VectorEwiseAddExpr: - self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) - if type(other) is VectorEwiseAddExpr: - other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, f"GrB_Vector_eWiseAdd_{op.opclass}", @@ -1144,37 +1136,40 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax w << binary.gt(u & v) """ - from .infix import ( - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ) + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_mult" - other = self._expect_type( - other, - ( - Vector, - Matrix, - TransposedMatrix, - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ), - within=method_name, - argname="other", - op=op, - ) - op = _get_typed_op_from_exprs(op, self, other, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") - if type(self) is VectorEwiseAddExpr: - raise TypeError("XXX") - if type(other) in {VectorEwiseAddExpr, MatrixEwiseAddExpr}: - raise TypeError("XXX") + if is_infix: + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseMultExpr, VectorEwiseMultExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, VectorEwiseMultExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1183,14 +1178,6 @@ def ewise_mult(self, other, op=binary.times): f"to columns of Matrix in {method_name}. Matrix.nrows (={other._nrows}) " f"must equal Vector.size (={self._size})." ) - if type(self) is VectorEwiseMultExpr: - self = self._expect_type( - op(self), Vector, within=method_name, argname="self", op=op - ) - if type(other) is MatrixEwiseMultExpr: - other = self._expect_type( - op(other), Matrix, within=method_name, argname="other", op=op - ) return MatrixExpression( method_name, None, @@ -1199,10 +1186,6 @@ def ewise_mult(self, other, op=binary.times): ncols=other._ncols, op=op, ) - if type(self) is VectorEwiseMultExpr: - self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) - if type(other) is VectorEwiseMultExpr: - other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) expr = VectorExpression( method_name, f"GrB_Vector_eWiseMult_{op.opclass}", @@ -1246,31 +1229,33 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax w << binary.div(u | v, left_default=1, right_default=1) """ - from .infix import ( - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ) + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_union" - other = self._expect_type( - other, - ( - Vector, - Matrix, - TransposedMatrix, - MatrixEwiseAddExpr, - MatrixEwiseMultExpr, - VectorEwiseAddExpr, - VectorEwiseMultExpr, - ), - within=method_name, - argname="other", - op=op, - ) - temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + left_dtype = temp_op.type dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: @@ -1308,8 +1293,12 @@ def ewise_union(self, other, op, left_default, right_default): else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") - op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + if is_infix: + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + else: + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") if op1 is not op2: left_dtype = unify(op1.type, op2.type, is_right_scalar=True) right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) @@ -1319,10 +1308,13 @@ def ewise_union(self, other, op, left_default, right_default): self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop - if type(self) is VectorEwiseMultExpr: - raise TypeError("XXX") - if type(other) in {VectorEwiseMultExpr, MatrixEwiseMultExpr}: - raise TypeError("XXX") + + if is_infix: + if isinstance(self, VectorEwiseAddExpr): + self = op(self, left_default=left, right_default=right).new() + if isinstance(other, InfixExprBase): + other = op(other, left_default=left, right_default=right).new() + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 2: # Broadcast columnwise from the left @@ -1332,22 +1324,6 @@ def ewise_union(self, other, op, left_default, right_default): f"to columns of Matrix in {method_name}. Matrix.nrows (={other._nrows}) " f"must equal Vector.size (={self._size})." ) - if type(self) is VectorEwiseAddExpr: - self = self._expect_type( - op(self, left_default=left, right_default=right), - Vector, - within=method_name, - argname="self", - op=op, - ) - if type(other) is MatrixEwiseAddExpr: - other = self._expect_type( - op(other, left_default=left, right_default=right), - Matrix, - within=method_name, - argname="other", - op=op, - ) return MatrixExpression( method_name, None, @@ -1357,22 +1333,6 @@ def ewise_union(self, other, op, left_default, right_default): ncols=other._ncols, op=op, ) - if type(self) is VectorEwiseAddExpr: - self = self._expect_type( - op(self, left_default=left, right_default=right), - Vector, - within=method_name, - argname="self", - op=op, - ) - if type(other) is VectorEwiseAddExpr: - other = self._expect_type( - op(other, left_default=left, right_default=right), - Vector, - within=method_name, - argname="other", - op=op, - ) if backend == "suitesparse": expr = VectorExpression( method_name, @@ -1423,23 +1383,35 @@ def vxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(v @ A) """ - from .infix import MatrixMatMulExpr, VectorMatMulExpr + return self._vxm(other, op) + + def _vxm(self, other, op=semiring.plus_times, is_infix=False): from .matrix import Matrix, TransposedMatrix method_name = "vxm" - other = self._expect_type( - other, - (Matrix, TransposedMatrix, MatrixMatMulExpr), - within=method_name, - argname="other", - op=op, - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") - if type(self) is VectorMatMulExpr: - self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) - if type(other) is MatrixMatMulExpr: - other = self._expect_type(op(other), Matrix, within=method_name, argname="other", op=op) + if is_infix: + from .infix import MatrixMatMulExpr, VectorMatMulExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, VectorMatMulExpr): + self = op(self).new() + if isinstance(other, MatrixMatMulExpr): + other = op(other).new() + else: + other = self._expect_type( + other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = VectorExpression( method_name, "GrB_vxm", @@ -1781,18 +1753,27 @@ def inner(self, other, op=semiring.plus_times): `Matrix Multiplication <../user_guide/operations.html#matrix-multiply>`__ family of functions. """ - from .infix import VectorMatMulExpr + return self._inner(other, op) + def _inner(self, other, op=semiring.plus_times, is_infix=False): method_name = "inner" - other = self._expect_type( - other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") - if type(self) is VectorMatMulExpr: - self = self._expect_type(op(self), Vector, within=method_name, argname="self", op=op) - if type(other) is VectorMatMulExpr: - other = self._expect_type(op(other), Vector, within=method_name, argname="other", op=op) + if is_infix: + from .infix import VectorMatMulExpr + + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, VectorMatMulExpr): + self = op(self).new() + if isinstance(other, VectorMatMulExpr): + other = op(other).new() + else: + other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = ScalarExpression( method_name, "GrB_vxm", diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index cc28f1134..d141a08fd 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -369,7 +369,6 @@ def test_infix_expr_value_types(): assert expr._expr._value is None -@autocompute def test_multi_infix_vector(): D0 = Vector.from_scalar(0, 3).diag() v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . @@ -403,12 +402,12 @@ def test_multi_infix_vector(): result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() assert result.isequal(expected) # inner - assert op.plus_plus(v1 @ v1).value == 6 - assert op.plus_plus(v1 @ (v1 @ D0)).value == 6 - assert op.plus_plus((D0 @ v1) @ v1).value == 6 + assert op.plus_plus(v1 @ v1).new().value == 6 + assert op.plus_plus(v1 @ (v1 @ D0)).new().value == 6 + assert op.plus_plus((D0 @ v1) @ v1).new().value == 6 # matrix-vector ewise_add result = binary.plus((D0 | v1) | v2).new() - expected = binary.plus(binary.plus(D0 | v1) | v2).new() + expected = binary.plus(binary.plus(D0 | v1).new() | v2).new() assert result.isequal(expected) result = binary.plus(D0 | (v1 | v2)).new() assert result.isequal(expected) @@ -418,7 +417,7 @@ def test_multi_infix_vector(): assert result.isequal(expected.T) # matrix-vector ewise_mult result = binary.plus((D0 & v1) & v2).new() - expected = binary.plus(binary.plus(D0 & v1) & v2).new() + expected = binary.plus(binary.plus(D0 & v1).new() & v2).new() assert result.isequal(expected) assert result.nvals > 0 result = binary.plus(D0 & (v1 & v2)).new() @@ -430,16 +429,16 @@ def test_multi_infix_vector(): # matrix-vector ewise_union kwargs = {"left_default": 10, "right_default": 20} result = binary.plus((D0 | v1) | v2, **kwargs).new() - expected = binary.plus(binary.plus(D0 | v1, **kwargs) | v2, **kwargs).new() + expected = binary.plus(binary.plus(D0 | v1, **kwargs).new() | v2, **kwargs).new() assert result.isequal(expected) result = binary.plus(D0 | (v1 | v2), **kwargs).new() - expected = binary.plus(D0 | binary.plus(v1 | v2, **kwargs), **kwargs).new() + expected = binary.plus(D0 | binary.plus(v1 | v2, **kwargs).new(), **kwargs).new() assert result.isequal(expected) result = binary.plus((v1 | v2) | D0, **kwargs).new() - expected = binary.plus(binary.plus(v1 | v2, **kwargs) | D0, **kwargs).new() + expected = binary.plus(binary.plus(v1 | v2, **kwargs).new() | D0, **kwargs).new() assert result.isequal(expected) result = binary.plus(v1 | (v2 | D0), **kwargs).new() - expected = binary.plus(v1 | binary.plus(v2 | D0, **kwargs), **kwargs).new() + expected = binary.plus(v1 | binary.plus(v2 | D0, **kwargs).new(), **kwargs).new() assert result.isequal(expected) # vxm, mxv result = op.plus_plus((D0 @ v1) @ D0).new() @@ -485,22 +484,41 @@ def test_multi_infix_vector(): with pytest.raises(TypeError, match="XXX"): # TODO (v1 | v2) & (v2 & v3) - # We don't (yet) differentiate between infix and methods - with pytest.raises(TypeError, match="XXX"): # TODO + # We differentiate between infix and methods + with pytest.raises(TypeError, match="to automatically compute"): v1.ewise_add(v2 & v3) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="Automatic computation"): (v1 & v2).ewise_add(v3) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="to automatically compute"): v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="Automatic computation"): (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="to automatically compute"): v1.ewise_mult(v2 | v3) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="Automatic computation"): (v1 | v2).ewise_mult(v3) @autocompute +def test_multi_infix_vector_auto(): + v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . + v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 + v3 = Vector.from_coo([2, 0], [1, 2], size=3) # 2 . 1 + # We differentiate between infix and methods + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 | v2).ewise_mult(v3) + + def test_multi_infix_matrix(): # Adapted from test_multi_infix_vector D0 = Vector.from_scalar(0, 3).diag() @@ -535,9 +553,9 @@ def test_multi_infix_matrix(): result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() assert result.isequal(expected) # mxm - assert op.plus_plus(v1.T @ v1)[0, 0].value == 6 - assert op.plus_plus(v1 @ (v1.T @ D0))[0, 0].value == 2 - assert op.plus_plus((v1.T @ D0) @ v1)[0, 0].value == 6 + assert op.plus_plus(v1.T @ v1).new()[0, 0].new().value == 6 + assert op.plus_plus(v1 @ (v1.T @ D0)).new()[0, 0].new().value == 2 + assert op.plus_plus((v1.T @ D0) @ v1).new()[0, 0].new().value == 6 with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | v3 @@ -569,16 +587,188 @@ def test_multi_infix_matrix(): with pytest.raises(TypeError, match="XXX"): # TODO (v1 | v2) & (v2 & v3) - # We don't (yet) differentiate between infix and methods - with pytest.raises(TypeError, match="XXX"): # TODO + # We differentiate between infix and methods + with pytest.raises(TypeError, match="to automatically compute"): v1.ewise_add(v2 & v3) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="Automatic computation"): (v1 & v2).ewise_add(v3) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="to automatically compute"): v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="Automatic computation"): (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) - with pytest.raises(TypeError, match="XXX"): # TODO + with pytest.raises(TypeError, match="to automatically compute"): v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 | v2).ewise_mult(v3) + + +@autocompute +def test_multi_infix_matrix_auto(): + v1 = Matrix.from_coo([0, 1], [0, 0], [1, 2], nrows=3) # 1 2 . + v2 = Matrix.from_coo([1, 2], [0, 0], [1, 2], nrows=3) # . 1 2 + v3 = Matrix.from_coo([2, 0], [0, 0], [1, 2], nrows=3) # 2 . 1 + # We differentiate between infix and methods + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 | v2).ewise_mult(v3) + + +def test_multi_infix_scalar(): + # Adapted from test_multi_infix_vector + v1 = Scalar.from_value(1) + v2 = Scalar.from_value(2) + v3 = Scalar(int) + # ewise_add + result = binary.plus((v1 | v2) | v3).new() + expected = 3 + assert result.isequal(expected) + result = binary.plus((1 | v2) | v3).new() + assert result.isequal(expected) + result = binary.plus((1 | v2) | 0).new() + assert result.isequal(expected) + result = binary.plus((v1 | 2) | v3).new() + assert result.isequal(expected) + result = binary.plus((v1 | 2) | 0).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | 0).new() + assert result.isequal(expected) + + result = binary.plus(v1 | (v2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(1 | (v2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(1 | (2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(1 | (v2 | 0)).new() + assert result.isequal(expected) + result = binary.plus(v1 | (2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | 0)).new() + assert result.isequal(expected) + + result = monoid.min(v1 | v2 | v3).new() + expected = 1 + assert result.isequal(expected) + # ewise_mult + result = monoid.max((v1 & v2) & v3).new() + expected = None + assert result.isequal(expected) + result = monoid.max(v1 & (v2 & v3)).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & v1).new() + expected = 1 + assert result.isequal(expected) + + result = monoid.min((1 & v2) & v1).new() + assert result.isequal(expected) + result = monoid.min((1 & v2) & 1).new() + assert result.isequal(expected) + result = monoid.min((v1 & 2) & v1).new() + assert result.isequal(expected) + result = monoid.min((v1 & 2) & 1).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & 1).new() + assert result.isequal(expected) + + result = monoid.min(1 & (v2 & v1)).new() + assert result.isequal(expected) + result = monoid.min(1 & (2 & v1)).new() + assert result.isequal(expected) + result = monoid.min(1 & (v2 & 1)).new() + assert result.isequal(expected) + result = monoid.min(v1 & (2 & v1)).new() + assert result.isequal(expected) + result = monoid.min(v1 & (v2 & 1)).new() + assert result.isequal(expected) + + # ewise_union + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() + expected = 13 + assert result.isequal(expected) + result = binary.plus((1 | v2) | v3, left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus((v1 | 2) | v3, left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new() + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus(1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus(1 | (2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus(v1 | (2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).__ror__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1 | (v2 & v3) with pytest.raises(TypeError, match="XXX"): # TODO + v1.__ror__(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) | (v2 & v3) + + with pytest.raises(TypeError, match="XXX"): # TODO + v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__rand__(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).__rand__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 & v3) + + # We differentiate between infix and methods + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 | v2).ewise_mult(v3) + + +@autocompute +def test_multi_infix_scalar_auto(): + v1 = Scalar.from_value(1) + v2 = Scalar.from_value(2) + v3 = Scalar(int) + # We differentiate between infix and methods + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="only valid for BOOL"): (v1 | v2).ewise_mult(v3) diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 7d7eee775..964b8f612 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -2834,16 +2834,17 @@ def test_auto(A, v): ]: # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() - val2 = getattr(expected, method)(expr) if method in {"__or__", "__ror__"} and type(expr) is MatrixEwiseMultExpr: # Doing e.g. `plus(A & B | C)` isn't allowed--make user be explicit with pytest.raises(TypeError): - assert val1.isequal(val2) + # assert val1.isequal(val2) + val2 = getattr(expected, method)(expr) with pytest.raises(TypeError): val3 = getattr(expr, method)(expected) with pytest.raises(TypeError): val4 = getattr(expr, method)(expr) else: + val2 = getattr(expected, method)(expr) assert val1.isequal(val2) val3 = getattr(expr, method)(expected) assert val1.isequal(val3) @@ -2957,7 +2958,7 @@ def test_expr_is_like_matrix(A): "setdiag", "update", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_union", "_ewise_mult", "_mxm", "_mxv"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Matrix. You may need to " "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` " @@ -3022,7 +3023,7 @@ def test_index_expr_is_like_matrix(A): "resize", "setdiag", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_union", "_ewise_mult", "_mxm", "_mxv"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Matrix. You may need to " "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` " diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index ba9903169..659d16f50 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -360,7 +360,7 @@ def test_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_union", "_ewise_mult", "_ewise_add"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " @@ -402,7 +402,7 @@ def test_index_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_union", "_ewise_mult", "_ewise_add"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index d08c9cb3b..d677ff545 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -1583,16 +1583,16 @@ def test_auto(v): ]: # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() - val2 = getattr(expected, method)(expr) if method in {"__or__", "__ror__"} and type(expr) is VectorEwiseMultExpr: # Doing e.g. `plus(x & y | z)` isn't allowed--make user be explicit with pytest.raises(TypeError): - assert val1.isequal(val2) + val2 = getattr(expected, method)(expr) with pytest.raises(TypeError): val3 = getattr(expr, method)(expected) with pytest.raises(TypeError): val4 = getattr(expr, method)(expr) else: + val2 = getattr(expected, method)(expr) assert val1.isequal(val2) assert val1.isequal(val2.new()) val3 = getattr(expr, method)(expected) @@ -1664,7 +1664,7 @@ def test_expr_is_like_vector(v): "resize", "update", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_inner", "_vxm", "_ewise_add", "_ewise_union", "_ewise_mult"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` " @@ -1713,7 +1713,7 @@ def test_index_expr_is_like_vector(v): "from_values", "resize", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_inner", "_vxm", "_ewise_add", "_ewise_union", "_ewise_mult"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` " From 18a1e1b258c5798c9a374a61e7534916d5b4d67b Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Sun, 29 Oct 2023 22:29:55 -0500 Subject: [PATCH 10/10] Clean up, and handle dimension mismatch more explicitly in infix exprs --- graphblas/core/infix.py | 161 +++++++++++++++++---------------- graphblas/core/scalar.py | 1 - graphblas/tests/test_infix.py | 7 ++ graphblas/tests/test_matrix.py | 5 +- graphblas/tests/test_scalar.py | 4 +- graphblas/tests/test_vector.py | 4 +- 6 files changed, 97 insertions(+), 85 deletions(-) diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index e1dc15bbe..51714633c 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -1,5 +1,6 @@ from .. import backend, binary from ..dtypes import BOOL +from ..exceptions import DimensionMismatch from ..monoid import land, lor from ..semiring import any_pair from . import automethods, recorder, utils @@ -125,17 +126,18 @@ class ScalarEwiseAddExpr(ScalarInfixExpr): _to_expr = _ewise_add_to_expr + # Allow e.g. `plus(x | y | z)` __or__ = Scalar.__or__ __ror__ = Scalar.__ror__ _ewise_add = Scalar._ewise_add - _ewise_mult = Scalar._ewise_mult _ewise_union = Scalar._ewise_union - def __and__(self, other, *, within="__and__"): + # Don't allow e.g. `plus(x | y & z)` + def __and__(self, other): raise TypeError("XXX") def __rand__(self, other): - self.__and__(other, within="__rand__") + raise TypeError("XXX") class ScalarEwiseMultExpr(ScalarInfixExpr): @@ -146,12 +148,12 @@ class ScalarEwiseMultExpr(ScalarInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` __and__ = Scalar.__and__ __rand__ = Scalar.__rand__ - _ewise_add = Scalar._ewise_add _ewise_mult = Scalar._ewise_mult - _ewise_union = Scalar._ewise_union + # Don't allow e.g. `plus(x | y & z)` def __or__(self, other): raise TypeError("XXX") @@ -262,13 +264,14 @@ class VectorEwiseAddExpr(VectorInfixExpr): _to_expr = _ewise_add_to_expr - __and__ = ScalarEwiseAddExpr.__and__ # raises - __rand__ = ScalarEwiseAddExpr.__rand__ # raises + # Allow e.g. `plus(x | y | z)` __or__ = Vector.__or__ __ror__ = Vector.__ror__ _ewise_add = Vector._ewise_add - _ewise_mult = Vector._ewise_mult _ewise_union = Vector._ewise_union + # Don't allow e.g. `plus(x | y & z)` + __and__ = ScalarEwiseAddExpr.__and__ # raises + __rand__ = ScalarEwiseAddExpr.__rand__ # raises class VectorEwiseMultExpr(VectorInfixExpr): @@ -279,13 +282,13 @@ class VectorEwiseMultExpr(VectorInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` __and__ = Vector.__and__ __rand__ = Vector.__rand__ + _ewise_mult = Vector._ewise_mult + # Don't allow e.g. `plus(x | y & z)` __or__ = ScalarEwiseMultExpr.__or__ # raises __ror__ = ScalarEwiseMultExpr.__ror__ # raises - _ewise_add = Vector._ewise_add - _ewise_mult = Vector._ewise_mult - _ewise_union = Vector._ewise_union class VectorMatMulExpr(VectorInfixExpr): @@ -420,13 +423,14 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr - __and__ = VectorEwiseAddExpr.__and__ # raises - __rand__ = VectorEwiseAddExpr.__rand__ # raises + # Allow e.g. `plus(x | y | z)` __or__ = Matrix.__or__ __ror__ = Matrix.__ror__ _ewise_add = Matrix._ewise_add - _ewise_mult = Matrix._ewise_mult _ewise_union = Matrix._ewise_union + # Don't allow e.g. `plus(x | y & z)` + __and__ = VectorEwiseAddExpr.__and__ # raises + __rand__ = VectorEwiseAddExpr.__rand__ # raises class MatrixEwiseMultExpr(MatrixInfixExpr): @@ -437,13 +441,13 @@ class MatrixEwiseMultExpr(MatrixInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` __and__ = Matrix.__and__ __rand__ = Matrix.__rand__ + _ewise_mult = Matrix._ewise_mult + # Don't allow e.g. `plus(x | y & z)` __or__ = VectorEwiseMultExpr.__or__ # raises __ror__ = VectorEwiseMultExpr.__ror__ # raises - _ewise_add = Matrix._ewise_add - _ewise_mult = Matrix._ewise_mult - _ewise_union = Matrix._ewise_union class MatrixMatMulExpr(MatrixInfixExpr): @@ -470,7 +474,15 @@ def __init__(self, left, right, *, nrows, ncols): def _dummy(obj, obj_type): with recorder.skip_record: - return obj_type(BOOL, *obj.shape, name="") + return output_type(obj)(BOOL, *obj.shape, name="") + + +def _mismatched(left, right, method, op): + # Create dummy expression to raise on incompatible dimensions + getattr(_dummy(left) if isinstance(left, InfixExprBase) else left, method)( + _dummy(right) if isinstance(right, InfixExprBase) else right, op + ) + raise DimensionMismatch # pragma: no cover def _ewise_infix_expr(left, right, *, method, within): @@ -479,43 +491,43 @@ def _ewise_infix_expr(left, right, *, method, within): types = {Vector, Matrix, TransposedMatrix} if left_type in types and right_type in types: - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr( - _dummy(left, left_type) if isinstance(left, InfixExprBase) else left, method - )(_dummy(right, right_type) if isinstance(right, InfixExprBase) else right, binary.first) - if expr.output_type is Vector: - if method == "ewise_mult": - return VectorEwiseMultExpr(left, right) - return VectorEwiseAddExpr(left, right) + if left_type is Vector: + if right_type is Vector: + if left._size != right._size: + _mismatched(left, right, method, binary.first) + if method == "ewise_mult": + return VectorEwiseMultExpr(left, right) + return VectorEwiseAddExpr(left, right) + if left._size != right._nrows: + _mismatched(left, right, method, binary.first) + elif right_type is Vector: + if left._ncols != right._size: + _mismatched(left, right, method, binary.first) + elif left.shape != right.shape: + _mismatched(left, right, method, binary.first) if method == "ewise_mult": return MatrixEwiseMultExpr(left, right) return MatrixEwiseAddExpr(left, right) + if within == "__or__" and isinstance(right, Mask): return right.__ror__(left) if within == "__and__" and isinstance(right, Mask): return right.__rand__(left) if left_type in types: left._expect_type(right, tuple(types), within=within, argname="right") - elif right_type in types: + if right_type in types: right._expect_type(left, tuple(types), within=within, argname="left") - elif left_type is Scalar: - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr( - _dummy(left, left_type) if isinstance(left, InfixExprBase) else left, method - )(_dummy(right, right_type) if isinstance(right, InfixExprBase) else right, binary.first) + if left_type is Scalar: if method == "ewise_mult": return ScalarEwiseMultExpr(left, right) return ScalarEwiseAddExpr(left, right) - elif right_type is Scalar: - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr( - _dummy(right, right_type) if isinstance(right, InfixExprBase) else right, method - )(_dummy(left, left_type) if isinstance(left, InfixExprBase) else left, binary.first) + if right_type is Scalar: if method == "ewise_mult": return ScalarEwiseMultExpr(right, left) return ScalarEwiseAddExpr(right, left) - else: # pragma: no cover (sanity) - raise TypeError(f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}") + raise TypeError( # pragma: no cover (sanity) + f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}" + ) def _matmul_infix_expr(left, right, *, within): @@ -524,56 +536,51 @@ def _matmul_infix_expr(left, right, *, within): if left_type is Vector: if right_type is Matrix or right_type is TransposedMatrix: - method = "vxm" - elif right_type is Vector: - method = "inner" - else: - right = left._expect_type( - right, - (Matrix, TransposedMatrix), - within=within, - argname="right", - ) - elif left_type is Matrix or left_type is TransposedMatrix: + if left._size != right._nrows: + _mismatched(left, right, "vxm", any_pair[BOOL]) + return VectorMatMulExpr(left, right, method_name="vxm", size=right._ncols) if right_type is Vector: - method = "mxv" - elif right_type is Matrix or right_type is TransposedMatrix: - method = "mxm" - else: - right = left._expect_type( - right, - (Vector, Matrix, TransposedMatrix), - within=within, - argname="right", - ) - elif right_type is Vector: - left = right._expect_type( + if left._size != right._size: + _mismatched(left, right, "inner", any_pair[BOOL]) + return ScalarMatMulExpr(left, right) + left._expect_type( + right, + (Matrix, TransposedMatrix, Vector), + within=within, + argname="right", + ) + if left_type is Matrix or left_type is TransposedMatrix: + if right_type is Vector: + if left._ncols != right._size: + _mismatched(left, right, "mxv", any_pair[BOOL]) + return VectorMatMulExpr(left, right, method_name="mxv", size=left._nrows) + if right_type is Matrix or right_type is TransposedMatrix: + if left._ncols != right._nrows: + _mismatched(left, right, "mxm", any_pair[BOOL]) + return MatrixMatMulExpr(left, right, nrows=left._nrows, ncols=right._ncols) + left._expect_type( + right, + (Vector, Matrix, TransposedMatrix), + within=within, + argname="right", + ) + if right_type is Vector: + right._expect_type( left, (Matrix, TransposedMatrix), within=within, argname="left", ) - elif right_type is Matrix or right_type is TransposedMatrix: - left = right._expect_type( + if right_type is Matrix or right_type is TransposedMatrix: + right._expect_type( left, (Vector, Matrix, TransposedMatrix), within=within, argname="left", ) - else: # pragma: no cover (sanity) - raise TypeError( - f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}" - ) - - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(_dummy(left, left_type) if isinstance(left, InfixExprBase) else left, method)( - _dummy(right, right_type) if isinstance(right, InfixExprBase) else right, any_pair[BOOL] + raise TypeError( # pragma: no cover (sanity) + f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}" ) - if expr.output_type is Vector: - return VectorMatMulExpr(left, right, method_name=method, size=expr._size) - if expr.output_type is Matrix: - return MatrixMatMulExpr(left, right, nrows=expr._nrows, ncols=expr._ncols) - return ScalarMatMulExpr(left, right) _ewise_add_expr_types = (MatrixEwiseAddExpr, VectorEwiseAddExpr, ScalarEwiseAddExpr) diff --git a/graphblas/core/scalar.py b/graphblas/core/scalar.py index bcbbdadd4..9cdf3043e 100644 --- a/graphblas/core/scalar.py +++ b/graphblas/core/scalar.py @@ -857,7 +857,6 @@ def _ewise_union(self, other, op, left_default, right_default, is_infix=False): self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop - expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if backend == "suitesparse": expr = ScalarExpression( diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index d141a08fd..e688086b9 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -534,6 +534,9 @@ def test_multi_infix_matrix(): result = monoid.min(v1 | v2 | v3).new() expected = Matrix.from_scalar(1, 3, 1) assert result.isequal(expected) + result = binary.plus(v1 | v1 | v1 | v1 | v1).new() + expected = (5 * v1).new() + assert result.isequal(expected) # ewise_mult result = monoid.max((v1 & v2) & v3).new() expected = Matrix(int, 3, 1) @@ -543,6 +546,9 @@ def test_multi_infix_matrix(): result = monoid.min((v1 & v2) & v1).new() expected = Matrix.from_coo([1], [0], [1], nrows=3) assert result.isequal(expected) + result = binary.plus(v1 & v1 & v1 & v1 & v1).new() + expected = (5 * v1).new() + assert result.isequal(expected) # ewise_union result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() expected = Matrix.from_scalar(13, 3, 1) @@ -556,6 +562,7 @@ def test_multi_infix_matrix(): assert op.plus_plus(v1.T @ v1).new()[0, 0].new().value == 6 assert op.plus_plus(v1 @ (v1.T @ D0)).new()[0, 0].new().value == 2 assert op.plus_plus((v1.T @ D0) @ v1).new()[0, 0].new().value == 6 + assert op.plus_plus(D0 @ D0 @ D0 @ D0 @ D0).new().isequal(D0) with pytest.raises(TypeError, match="XXX"): # TODO (v1 & v2) | v3 diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 964b8f612..c716c97a9 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -2837,7 +2837,6 @@ def test_auto(A, v): if method in {"__or__", "__ror__"} and type(expr) is MatrixEwiseMultExpr: # Doing e.g. `plus(A & B | C)` isn't allowed--make user be explicit with pytest.raises(TypeError): - # assert val1.isequal(val2) val2 = getattr(expected, method)(expr) with pytest.raises(TypeError): val3 = getattr(expr, method)(expected) @@ -2958,7 +2957,7 @@ def test_expr_is_like_matrix(A): "setdiag", "update", } - ignore = {"__sizeof__", "_ewise_add", "_ewise_union", "_ewise_mult", "_mxm", "_mxv"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_mxm", "_mxv"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Matrix. You may need to " "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` " @@ -3023,7 +3022,7 @@ def test_index_expr_is_like_matrix(A): "resize", "setdiag", } - ignore = {"__sizeof__", "_ewise_add", "_ewise_union", "_ewise_mult", "_mxm", "_mxv"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_mxm", "_mxv"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Matrix. You may need to " "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` " diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index 659d16f50..aeb19e170 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -360,7 +360,7 @@ def test_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - ignore = {"__sizeof__", "_ewise_union", "_ewise_mult", "_ewise_add"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " @@ -402,7 +402,7 @@ def test_index_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - ignore = {"__sizeof__", "_ewise_union", "_ewise_mult", "_ewise_add"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index d677ff545..1c9a8d38c 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -1664,7 +1664,7 @@ def test_expr_is_like_vector(v): "resize", "update", } - ignore = {"__sizeof__", "_inner", "_vxm", "_ewise_add", "_ewise_union", "_ewise_mult"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_inner", "_vxm"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` " @@ -1713,7 +1713,7 @@ def test_index_expr_is_like_vector(v): "from_values", "resize", } - ignore = {"__sizeof__", "_inner", "_vxm", "_ewise_add", "_ewise_union", "_ewise_mult"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_inner", "_vxm"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` "