diff --git a/.flake8 b/.flake8 index 0dede3f1d..80124c9e8 100644 --- a/.flake8 +++ b/.flake8 @@ -12,5 +12,6 @@ extend-ignore = per-file-ignores = scripts/create_pickle.py:F403,F405, graphblas/tests/*.py:T201, + graphblas/core/agg.py:F401,F403, graphblas/core/ss/matrix.py:SIM113, graphblas/**/__init__.py:F401, diff --git a/.github/workflows/test_and_build.yml b/.github/workflows/test_and_build.yml index 0dfa08859..807123889 100644 --- a/.github/workflows/test_and_build.yml +++ b/.github/workflows/test_and_build.yml @@ -188,7 +188,7 @@ jobs: echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psgver${psgver}" # Once we have wheels for all OSes, we can delete the last two lines. - mamba install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig pyyaml${yamlver} sparse${sparsever} \ + mamba install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig tomli pyyaml${yamlver} sparse${sparsever} \ pandas${pdver} scipy${spver} numpy${npver} awkward${akver} networkx${nxver} numba${numbaver} fast_matrix_market${fmmver} \ ${{ matrix.slowtask == 'pytest_bizarro' && 'black' || '' }} \ ${{ matrix.slowtask == 'notebooks' && 'matplotlib nbconvert jupyter "ipython>=7"' || '' }} \ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 05469a926..8eb2bf10b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,12 @@ repos: hooks: - id: validate-pyproject name: Validate pyproject.toml + # I don't yet trust ruff to do what autoflake does + - repo: https://github.com/myint/autoflake + rev: v2.0.2 + hooks: + - id: autoflake + args: [--in-place] # We can probably remove `isort` if we come to trust `ruff --fix`, # but we'll need to figure out the configuration to do this in `ruff` - repo: https://github.com/pycqa/isort @@ -42,15 +48,15 @@ repos: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 23.3.0 hooks: - id: black - id: black-jupyter - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.259 + rev: v0.0.260 hooks: - id: ruff - args: [--fix-only] + args: [--fix-only, --show-fixes] # Let's keep `flake8` even though `ruff` does much of the same. # `flake8-bugbear` and `flake8-simplify` have caught things missed by `ruff`. - repo: https://github.com/PyCQA/flake8 @@ -60,7 +66,7 @@ repos: additional_dependencies: &flake8_dependencies # These versions need updated manually - flake8==6.0.0 - - flake8-bugbear==23.3.12 + - flake8-bugbear==23.3.23 - flake8-simplify==0.19.3 - repo: https://github.com/asottile/yesqa rev: v1.4.0 @@ -75,7 +81,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.259 + rev: v0.0.260 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/dev-requirements.txt b/dev-requirements.txt index 273980db9..a281672ec 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -19,6 +19,7 @@ pre-commit # For testing packaging pytest-cov +tomli # For debugging icecream ipykernel diff --git a/environment.yml b/environment.yml index 41eb3c43d..875ec5cbd 100644 --- a/environment.yml +++ b/environment.yml @@ -36,6 +36,7 @@ dependencies: # For testing - packaging - pytest-cov + - tomli # For debugging - icecream - ipykernel diff --git a/graphblas/agg/__init__.py b/graphblas/agg/__init__.py index f2dddb851..c1319facb 100644 --- a/graphblas/agg/__init__.py +++ b/graphblas/agg/__init__.py @@ -111,6 +111,6 @@ def __getattr__(key): raise AttributeError(f"module {__name__!r} has no attribute {key!r}") -from ..core import agg # noqa: E402 isort:skip +from ..core import operator # noqa: E402 isort:skip -del agg +del operator diff --git a/graphblas/agg/ss.py b/graphblas/agg/ss.py index c3f06c0a7..e45cbcda0 100644 --- a/graphblas/agg/ss.py +++ b/graphblas/agg/ss.py @@ -1,3 +1,3 @@ -from ..core import agg +from ..core import operator -del agg +del operator diff --git a/graphblas/binary/numpy.py b/graphblas/binary/numpy.py index 21ed568ea..68764db05 100644 --- a/graphblas/binary/numpy.py +++ b/graphblas/binary/numpy.py @@ -143,17 +143,18 @@ def __getattr__(name): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") if _config.get("mapnumpy") and name in _numpy_to_graphblas: if name == "float_power": - from ..core import operator + from ..core.operator import binary + from ..dtypes import FP64 - new_op = operator.BinaryOp(f"numpy.{name}") + new_op = binary.BinaryOp(f"numpy.{name}") builtin_op = _binary.pow for dtype in builtin_op.types: if dtype.name in {"FP32", "FC32", "FC64"}: orig_dtype = dtype else: - orig_dtype = operator.FP64 + orig_dtype = FP64 orig_op = builtin_op[orig_dtype] - cur_op = operator.TypedBuiltinBinaryOp( + cur_op = binary.TypedBuiltinBinaryOp( new_op, new_op.name, dtype, @@ -166,14 +167,12 @@ def __getattr__(name): else: globals()[name] = getattr(_binary, _numpy_to_graphblas[name]) else: - from ..core import operator - numpy_func = getattr(_np, name) def func(x, y): # pragma: no cover (numba) return numpy_func(x, y) - operator.BinaryOp.register_new(f"numpy.{name}", func) + _binary.register_new(f"numpy.{name}", func) rv = globals()[name] if name in _commutative: rv._commutes_to = rv diff --git a/graphblas/core/agg.py b/graphblas/core/agg.py index 3afcbc408..3418daffc 100644 --- a/graphblas/core/agg.py +++ b/graphblas/core/agg.py @@ -1,680 +1,17 @@ -from functools import partial -from operator import getitem +"""graphblas.core.agg namespace is deprecated; please use graphblas.core.operator.agg instead. -import numpy as np +.. deprecated:: 2023.3.0 +`graphblas.core.agg` will be removed in a future release. +Use `graphblas.core.operator.agg` instead. +Will be removed in version 2023.11.0 or later. -from .. import agg, backend, binary, monoid, semiring, unary -from ..dtypes import INT64, lookup_dtype -from .utils import output_type +""" +import warnings +from .operator.agg import * -def _get_types(ops, initdtype): - """Determine the input and output types of an aggregator based on a list of ops.""" - if initdtype is None: - prev = dict(ops[0].types) - else: - op = ops[0] - prev = {key: get_typed_op(op, key, initdtype).return_type for key in op.types} - for op in ops[1:]: - cur = {} - types = op.types - for in_type, out_type in prev.items(): - if out_type not in types: # pragma: no cover (safety) - continue - cur[in_type] = types[out_type] - prev = cur - return prev - - -class Aggregator: - opclass = "Aggregator" - - def __init__( - self, - name, - *, - initval=None, - monoid=None, - semiring=None, - switch=False, - semiring2=None, - finalize=None, - composite=None, - custom=None, - types=None, - any_dtype=None, - ): - self.name = name - self._initval_orig = initval - self._initval = False if initval is None else initval - self._initdtype = lookup_dtype(type(self._initval), self._initval) - self._monoid = monoid - self._semiring = semiring - self._semiring2 = semiring2 - self._switch = switch - self._finalize = finalize - self._composite = composite - self._custom = custom - if types is None: - if monoid is not None: - types = [monoid] - elif semiring is not None: - types = [semiring, semiring2] - if finalize is not None: - types.append(finalize) - initval = self._initval - else: # pragma: no cover (sanity) - raise TypeError("types must be provided for composite and custom aggregators") - self._types_orig = types - self._types = None - self._typed_ops = {} - self._any_dtype = any_dtype - - @property - def types(self): - if self._types is None: - if type(self._semiring) is str: - self._semiring = semiring.from_string(self._semiring) - if type(self._types_orig[0]) is str: # pragma: no branch - self._types_orig[0] = semiring.from_string(self._types_orig[0]) - self._types = _get_types( - self._types_orig, None if self._initval_orig is None else self._initdtype - ) - return self._types - - def __getitem__(self, dtype): - dtype = lookup_dtype(dtype) - if not self._any_dtype and dtype not in self.types: - raise KeyError(f"{self.name} does not work with {dtype}") - if dtype not in self._typed_ops: - self._typed_ops[dtype] = TypedAggregator(self, dtype) - return self._typed_ops[dtype] - - def __contains__(self, dtype): - dtype = lookup_dtype(dtype) - return self._any_dtype or dtype in self.types - - def __repr__(self): - if self.name in agg._deprecated: - return f"agg.ss.{self.name}" - return f"agg.{self.name}" - - def __reduce__(self): - if self.name in agg._deprecated: - return f"agg.ss.{self.name}" - return f"agg.{self.name}" - - def __call__(self, val, *, rowwise=False, columnwise=False): - # Should we expose `allow_empty=` keyword when reducing to Scalar? - from .matrix import Matrix, TransposedMatrix - from .vector import Vector - - typ = output_type(val) - if typ is Vector: - if rowwise or columnwise: - raise ValueError( - "rowwise and columnwise arguments should not be used with Vector input" - ) - return val.reduce(self) - if typ in {Matrix, TransposedMatrix}: - if rowwise: - if columnwise: - raise ValueError("rowwise and columnwise arguments cannot both be True") - return val.reduce_rowwise(self) - if columnwise: - return val.reduce_columnwise(self) - return val.reduce_scalar(self) - raise TypeError( - f"Bad type when calling {self!r}.\n" - " - Expected type: Vector, Matrix, TransposedMatrix.\n" - f" - Got: {type(val)}.\n" - "Calling an Aggregator is syntactic sugar for calling reduce methods. " - f"For example, `A.reduce_scalar({self!r})` is the same as `{self!r}(A)`." - ) - - -class TypedAggregator: - opclass = "Aggregator" - - def __init__(self, agg, dtype): - self.name = agg.name - self.parent = agg - self.type = dtype - if dtype in agg.types: - self.return_type = agg.types[dtype] - elif agg._any_dtype is True: - self.return_type = dtype - else: - self.return_type = agg._any_dtype - - def __repr__(self): - return f"agg.{self.name}[{self.type}]" - - def _new(self, updater, expr, *, in_composite=False): - agg = self.parent - if agg._monoid is not None: - x = expr.args[0] - method = getattr(x, expr.method_name) - if expr.output_type.__name__ == "Scalar": - expr = method(agg._monoid[self.type], allow_empty=not expr._is_cscalar) - else: - expr = method(agg._monoid[self.type]) - updater << expr - if in_composite: - parent = updater.parent - if not parent._is_scalar: - return parent - return parent._as_vector() - return - - opts = updater.opts - if agg._composite is not None: - # Masks are applied throughout the aggregation, including composite aggregations. - # Aggregations done while `in_composite is True` should return the updater parent - # if the result is not a Scalar. If the result is a Scalar, then there can be no - # output mask, and a Vector of size 1 should be returned instead. - results = [] - mask = updater.kwargs.get("mask") - for cur_agg in agg._composite: - cur_agg = cur_agg[self.type] # Hopefully works well enough - arg = expr.construct_output(cur_agg.return_type) - results.append(cur_agg._new(arg(mask=mask, **opts), expr, in_composite=True)) - final_expr = agg._finalize(*results, opts) - if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": - updater << final_expr - elif expr.cfunc_name.startswith("GrB_Vector_reduce") or expr.cfunc_name.startswith( - "GrB_Matrix_reduce" - ): - final = final_expr.new(**opts) - updater << final[0] - else: - raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") - if in_composite: - parent = updater.parent - if not parent._is_scalar: - return parent - return parent._as_vector() - return - - if agg._custom is not None: - return agg._custom(self, updater, expr, opts, in_composite=in_composite) - - semiring = get_typed_op(agg._semiring, self.type, agg._initdtype) - if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": - # Matrix -> Vector - A = expr.args[0] - orig_updater = updater - if agg._finalize is not None: - step1 = expr.construct_output(semiring.return_type) - updater = step1(mask=updater.kwargs.get("mask"), **opts) - if expr.method_name == "reduce_columnwise": - A = A.T - size = A._ncols - init = expr._new_vector(agg._initdtype, size=size) - init(**opts)[...] = agg._initval # O(1) dense vector in SuiteSparse 5 - if agg._switch: - updater << semiring(init @ A.T) - else: - updater << semiring(A @ init) - if agg._finalize is not None: - orig_updater << agg._finalize[semiring.return_type](step1) - if in_composite: - return orig_updater.parent - elif expr.cfunc_name.startswith("GrB_Vector_reduce"): - # Vector -> Scalar - v = expr.args[0] - step1 = expr._new_vector(semiring.return_type, size=1) - init = expr._new_matrix(agg._initdtype, nrows=v._size, ncols=1) - init(**opts)[...] = agg._initval # O(1) dense column vector in SuiteSparse 5 - if agg._switch: - step1(**opts) << semiring(init.T @ v) - else: - step1(**opts) << semiring(v @ init) - if agg._finalize is not None: - finalize = agg._finalize[semiring.return_type] - if step1.dtype == finalize.return_type: - step1(**opts) << finalize(step1) - else: - step1 = finalize(step1).new(finalize.return_type, **opts) - if in_composite: - return step1 - updater << step1[0] - elif expr.cfunc_name.startswith("GrB_Matrix_reduce"): - # Matrix -> Scalar - A = expr.args[0] - # We need to compute in two steps: Matrix -> Vector -> Scalar. - # This has not been benchmarked or optimized. - # We may be able to intelligently choose the faster path. - init1 = expr._new_vector(agg._initdtype, size=A._ncols) - init1(**opts)[...] = agg._initval # O(1) dense vector in SuiteSparse 5 - step1 = expr._new_vector(semiring.return_type, size=A._nrows) - if agg._switch: - step1(**opts) << semiring(init1 @ A.T) - else: - step1(**opts) << semiring(A @ init1) - init2 = expr._new_matrix(agg._initdtype, nrows=A._nrows, ncols=1) - init2(**opts)[...] = agg._initval # O(1) dense vector in SuiteSparse 5 - semiring2 = agg._semiring2[semiring.return_type] - step2 = expr._new_vector(semiring2.return_type, size=1) - step2(**opts) << semiring2(step1 @ init2) - if agg._finalize is not None: - finalize = agg._finalize[semiring2.return_type] - if step2.dtype == finalize.return_type: - step2 << finalize(step2) - else: - step2 = finalize(step2).new(finalize.return_type, **opts) - if in_composite: - return step2 - updater << step2[0] - else: - raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") - - def __reduce__(self): - return (getitem, (self.parent, self.type)) - - __call__ = Aggregator.__call__ - - -# Monoid-only -agg.sum = Aggregator("sum", monoid=monoid.plus) -agg.prod = Aggregator("prod", monoid=monoid.times) -agg.all = Aggregator("all", monoid=monoid.land) -agg.any = Aggregator("any", monoid=monoid.lor) -agg.min = Aggregator("min", monoid=monoid.min) -agg.max = Aggregator("max", monoid=monoid.max) -agg.any_value = Aggregator("any_value", monoid=monoid.any, any_dtype=True) -agg.bitwise_all = Aggregator("bitwise_all", monoid=monoid.band) -agg.bitwise_any = Aggregator("bitwise_any", monoid=monoid.bor) -# Other monoids: bxnor bxor eq lxnor lxor - -# Semiring-only -agg.count = Aggregator( - "count", semiring=semiring.plus_pair, semiring2=semiring.plus_first, any_dtype=INT64 -) -agg.count_nonzero = Aggregator( - "count_nonzero", semiring=semiring.plus_isne, semiring2=semiring.plus_first -) -agg.count_zero = Aggregator( - "count_zero", semiring=semiring.plus_iseq, semiring2=semiring.plus_first -) -agg.sum_of_squares = Aggregator( - "sum_of_squares", initval=2, semiring=semiring.plus_pow, semiring2=semiring.plus_first -) -agg.sum_of_inverses = Aggregator( - "sum_of_inverses", - initval=-1.0, - semiring=semiring.plus_pow, - semiring2=semiring.plus_first, -) -agg.exists = Aggregator( - "exists", semiring=semiring.any_pair, semiring2=semiring.any_pair, any_dtype=INT64 -) - -# Semiring and finalize -agg.hypot = Aggregator( - "hypot", - initval=2, - semiring=semiring.plus_pow, - semiring2=semiring.plus_first, - finalize=unary.sqrt, -) -agg.logaddexp = Aggregator( - "logaddexp", - initval=np.e, - semiring=semiring.plus_pow, - switch=True, - semiring2=semiring.plus_first, - finalize=unary.log, -) -agg.logaddexp2 = Aggregator( - "logaddexp2", - initval=2, - semiring=semiring.plus_pow, - switch=True, - semiring2=semiring.plus_first, - finalize=unary.log2, -) -# Alternatives -# logaddexp = Aggregator('logaddexp', monoid=semiring.numpy.logaddexp) -# logaddexp2 = Aggregator('logaddexp2', monoid=semiring.numpy.logaddexp2) -# hypot as monoid doesn't work if single negative element! -# hypot = Aggregator('hypot', monoid=semiring.numpy.hypot) - -agg.L0norm = agg.count_nonzero -agg.L1norm = Aggregator("L1norm", semiring="plus_absfirst", semiring2=semiring.plus_first) -agg.L2norm = agg.hypot -agg.Linfnorm = Aggregator("Linfnorm", semiring="max_absfirst", semiring2=semiring.max_first) - - -# Composite -def _mean_finalize(c, x, opts): - return binary.truediv(x & c) - - -def _ptp_finalize(max, min, opts): - return binary.minus(max & min) - - -def _varp_finalize(c, x, x2, opts): - # / n - ( / n)**2 - left = binary.truediv(x2 & c).new(**opts) - right = binary.truediv(x & c).new(**opts) - right(**opts) << binary.pow(right, 2) - return binary.minus(left & right) - - -def _vars_finalize(c, x, x2, opts): - # / (n-1) - **2 / (n * (n-1)) - x(**opts) << binary.pow(x, 2) - right = binary.truediv(x & c).new(**opts) - c(**opts) << binary.minus(c, 1) - right(**opts) << binary.truediv(right & c) - left = binary.truediv(x2 & c).new(**opts) - return binary.minus(left & right) - - -def _stdp_finalize(c, x, x2, opts): - val = _varp_finalize(c, x, x2, opts).new(**opts) - return unary.sqrt(val) - - -def _stds_finalize(c, x, x2, opts): - val = _vars_finalize(c, x, x2, opts).new(**opts) - return unary.sqrt(val) - - -def _geometric_mean_finalize(c, x, opts): - right = unary.minv["FP64"](c).new(**opts) - return binary.pow(x & right) - - -def _harmonic_mean_finalize(c, x, opts): - return binary.truediv(c & x) - - -def _root_mean_square_finalize(c, x2, opts): - val = binary.truediv(x2 & c).new(**opts) - return unary.sqrt(val) - - -agg.mean = Aggregator( - "mean", - composite=[agg.count, agg.sum], - finalize=_mean_finalize, - types=[binary.truediv], -) -agg.peak_to_peak = Aggregator( - "peak_to_peak", - composite=[agg.max, agg.min], - finalize=_ptp_finalize, - types=[monoid.min], -) -agg.varp = Aggregator( - "varp", - composite=[agg.count, agg.sum, agg.sum_of_squares], - finalize=_varp_finalize, - types=[binary.truediv], -) -agg.vars = Aggregator( - "vars", - composite=[agg.count, agg.sum, agg.sum_of_squares], - finalize=_vars_finalize, - types=[binary.truediv], +warnings.warn( + "graphblas.core.agg namespace is deprecated; please use graphblas.core.operator.agg instead.", + DeprecationWarning, + stacklevel=1, ) -agg.stdp = Aggregator( - "stdp", - composite=[agg.count, agg.sum, agg.sum_of_squares], - finalize=_stdp_finalize, - types=[binary.truediv, unary.sqrt], -) -agg.stds = Aggregator( - "stds", - composite=[agg.count, agg.sum, agg.sum_of_squares], - finalize=_stds_finalize, - types=[binary.truediv, unary.sqrt], -) -agg.geometric_mean = Aggregator( - "geometric_mean", - composite=[agg.count, agg.prod], - finalize=_geometric_mean_finalize, - types=[binary.truediv], -) -agg.harmonic_mean = Aggregator( - "harmonic_mean", - composite=[agg.count, agg.sum_of_inverses], - finalize=_harmonic_mean_finalize, - types=[agg.sum_of_inverses, binary.truediv], -) -agg.root_mean_square = Aggregator( - "root_mean_square", - composite=[agg.count, agg.sum_of_squares], - finalize=_root_mean_square_finalize, - types=[binary.truediv, unary.sqrt], -) - - -# Special recipes -def _argminmaxij( - agg, - updater, - expr, - opts, - *, - in_composite, - monoid, - col_semiring, - row_semiring, -): - if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": - A = expr.args[0] - if expr.method_name == "reduce_rowwise": - step1 = A.reduce_rowwise(monoid).new(**opts) - - D = step1.diag() - - masked = semiring.any_eq(D @ A).new(**opts) - masked(mask=masked.V, replace=True, **opts) << masked # Could use select - init = expr._new_vector(bool, size=A._ncols) - init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 - updater << row_semiring(masked @ init) - if in_composite: - return updater.parent - else: - step1 = A.reduce_columnwise(monoid).new(**opts) - - D = step1.diag() - - masked = semiring.any_eq(A @ D).new(**opts) - masked(mask=masked.V, replace=True, **opts) << masked # Could use select - init = expr._new_vector(bool, size=A._nrows) - init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 - updater << col_semiring(init @ masked) - if in_composite: - return updater.parent - elif expr.cfunc_name.startswith("GrB_Vector_reduce"): - v = expr.args[0] - step1 = v.reduce(monoid, allow_empty=False).new(**opts) - masked = binary.eq(v, step1).new(**opts) - masked(mask=masked.V, replace=True, **opts) << masked # Could use select - init = expr._new_matrix(bool, nrows=v._size, ncols=1) - init(**opts)[...] = False # O(1) dense column vector in SuiteSparse 5 - step2 = col_semiring(masked @ init).new(**opts) - if in_composite: - return step2 - updater << step2[0] - else: - raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") - - -def _argminmax(agg, updater, expr, opts, *, in_composite, monoid): - if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": - if expr.method_name == "reduce_rowwise": - return _argminmaxij( - agg, - updater, - expr, - opts, - in_composite=in_composite, - monoid=monoid, - row_semiring=semiring._deprecated["min_firstj"], - col_semiring=semiring._deprecated["min_secondj"], - ) - return _argminmaxij( - agg, - updater, - expr, - opts, - in_composite=in_composite, - monoid=monoid, - row_semiring=semiring._deprecated["min_firsti"], - col_semiring=semiring._deprecated["min_secondi"], - ) - if expr.cfunc_name.startswith("GrB_Vector_reduce"): - return _argminmaxij( - agg, - updater, - expr, - opts, - in_composite=in_composite, - monoid=monoid, - row_semiring=semiring._deprecated["min_firsti"], - col_semiring=semiring._deprecated["min_secondi"], - ) - if expr.cfunc_name.startswith("GrB_Matrix_reduce"): - raise ValueError(f"Aggregator {agg.name} may not be used with Matrix.reduce_scalar.") - raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") - - -# These "do the right thing", but don't work with `reduce_scalar` -_argmin = Aggregator( - "argmin", - custom=partial(_argminmax, monoid=monoid.min), - types=[semiring._deprecated["min_firsti"]], -) -_argmax = Aggregator( - "argmax", - custom=partial(_argminmax, monoid=monoid.max), - types=[semiring._deprecated["min_firsti"]], -) - - -def _first_last(agg, updater, expr, opts, *, in_composite, semiring_): - if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": - A = expr.args[0] - if expr.method_name == "reduce_columnwise": - A = A.T - init = expr._new_vector(bool, size=A._ncols) - init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 - step1 = semiring_(A @ init).new(**opts) - Is, Js = step1.to_coo() - - Matrix_ = type(expr._new_matrix(bool)) - P = Matrix_.from_coo(Js, Is, 1, nrows=A._ncols, ncols=A._nrows) - mask = step1.diag() - result = semiring.any_first(A @ P).new(mask=mask.S, **opts).diag(**opts) - - updater << result - if in_composite: - return updater.parent - elif expr.cfunc_name.startswith("GrB_Vector_reduce"): - v = expr.args[0] - init = expr._new_matrix(bool, nrows=v._size, ncols=1) - init(**opts)[...] = False # O(1) dense matrix in SuiteSparse 5 - step1 = semiring_(v @ init).new(**opts) - index = step1[0].new().value - # `==` instead of `is` automatically triggers index.compute() in dask-graphblas: - if index == None: # noqa: E711 - index = 0 - if in_composite: - return v[[index]].new(**opts) - updater << v[index] - else: # GrB_Matrix_reduce - A = expr.args[0] - init1 = expr._new_matrix(bool, nrows=A._ncols, ncols=1) - init1(**opts)[...] = False # O(1) dense matrix in SuiteSparse 5 - step1 = semiring_(A @ init1).new(**opts) - init2 = expr._new_vector(bool, size=A._nrows) - init2(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 - step2 = semiring_(step1.T @ init2).new(**opts) - i = step2[0].new().value - # `==` instead of `is` automatically triggers i.compute() in dask-graphblas: - if i == None: # noqa: E711 - i = j = 0 - else: - j = step1[i, 0].new().value - if in_composite: - return A[i, [j]].new(**opts) - updater << A[i, j] - - -_first = Aggregator( - "first", - custom=partial(_first_last, semiring_=semiring._deprecated["min_secondi"]), - types=[binary.first], - any_dtype=True, -) -_last = Aggregator( - "last", - custom=partial(_first_last, semiring_=semiring._deprecated["max_secondi"]), - types=[binary.second], - any_dtype=True, -) - - -def _first_last_index(agg, updater, expr, opts, *, in_composite, semiring): - if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": - A = expr.args[0] - if expr.method_name == "reduce_columnwise": - A = A.T - init = expr._new_vector(bool, size=A._ncols) - init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 - expr = semiring(A @ init) - updater << expr - if in_composite: - return updater.parent - elif expr.cfunc_name.startswith("GrB_Vector_reduce"): - v = expr.args[0] - init = expr._new_matrix(bool, nrows=v._size, ncols=1) - init(**opts)[...] = False # O(1) dense matrix in SuiteSparse 5 - step1 = semiring(v @ init).new(**opts) - if in_composite: - return step1 - updater << step1[0] - elif expr.cfunc_name.startswith("GrB_Matrix_reduce"): - raise ValueError(f"Aggregator {agg.name} may not be used with Matrix.reduce_scalar.") - else: - raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") - - -_first_index = Aggregator( - "first_index", - custom=partial(_first_last_index, semiring=semiring._deprecated["min_secondi"]), - types=[semiring._deprecated["min_secondi"]], - any_dtype=INT64, -) -_last_index = Aggregator( - "last_index", - custom=partial(_first_last_index, semiring=semiring._deprecated["max_secondi"]), - types=[semiring._deprecated["min_secondi"]], - any_dtype=INT64, -) -agg._deprecated = { - "argmin": _argmin, - "argmax": _argmax, - "first": _first, - "last": _last, - "first_index": _first_index, - "last_index": _last_index, -} -if backend == "suitesparse": - agg.ss.argmin = _argmin - agg.ss.argmax = _argmax - agg.ss.first = _first - agg.ss.last = _last - agg.ss.first_index = _first_index - agg.ss.last_index = _last_index - -agg.Aggregator = Aggregator -agg.TypedAggregator = TypedAggregator - -from .operator import get_typed_op # noqa: E402 isort:skip diff --git a/graphblas/core/operator.py b/graphblas/core/operator.py deleted file mode 100644 index b38add7f1..000000000 --- a/graphblas/core/operator.py +++ /dev/null @@ -1,3585 +0,0 @@ -import inspect -import itertools -import re -from collections.abc import Mapping -from functools import lru_cache, reduce -from operator import getitem, mul -from types import BuiltinFunctionType, FunctionType, ModuleType - -import numba -import numpy as np - -from .. import ( - _STANDARD_OPERATOR_NAMES, - backend, - binary, - config, - indexunary, - monoid, - op, - select, - semiring, - unary, -) -from ..dtypes import ( - BOOL, - FP32, - FP64, - INT8, - INT16, - INT32, - INT64, - UINT8, - UINT16, - UINT32, - UINT64, - _sample_values, - _supports_complex, - lookup_dtype, - unify, -) -from ..exceptions import UdfParseError, check_status_carg -from . import ffi, lib -from .expr import InfixExprBase -from .utils import libget, output_type - -if _supports_complex: - from ..dtypes import FC32, FC64 - -ffi_new = ffi.new -UNKNOWN_OPCLASS = "UnknownOpClass" - -# These now live as e.g. `gb.unary.ss.positioni` -# Deprecations such as `gb.unary.positioni` will be removed in 2023.9.0 or later. -_SS_OPERATORS = { - # unary - "erf", # scipy.special.erf - "erfc", # scipy.special.erfc - "frexpe", # np.frexp[1] - "frexpx", # np.frexp[0] - "lgamma", # scipy.special.loggamma - "tgamma", # scipy.special.gamma - # Positional - # unary - "positioni", - "positioni1", - "positionj", - "positionj1", - # binary - "firsti", - "firsti1", - "firstj", - "firstj1", - "secondi", - "secondi1", - "secondj", - "secondj1", - # semiring - "any_firsti", - "any_firsti1", - "any_firstj", - "any_firstj1", - "any_secondi", - "any_secondi1", - "any_secondj", - "any_secondj1", - "max_firsti", - "max_firsti1", - "max_firstj", - "max_firstj1", - "max_secondi", - "max_secondi1", - "max_secondj", - "max_secondj1", - "min_firsti", - "min_firsti1", - "min_firstj", - "min_firstj1", - "min_secondi", - "min_secondi1", - "min_secondj", - "min_secondj1", - "plus_firsti", - "plus_firsti1", - "plus_firstj", - "plus_firstj1", - "plus_secondi", - "plus_secondi1", - "plus_secondj", - "plus_secondj1", - "times_firsti", - "times_firsti1", - "times_firstj", - "times_firstj1", - "times_secondi", - "times_secondi1", - "times_secondj", - "times_secondj1", -} - - -def _hasop(module, name): - return ( - name in module.__dict__ - or name in module._delayed - or name in getattr(module, "_deprecated", ()) - ) - - -class OpPath: - def __init__(self, parent, name): - self._parent = parent - self._name = name - self._delayed = {} - self._delayed_commutes_to = {} - - def __getattr__(self, key): - if key in self._delayed: - func, kwargs = self._delayed.pop(key) - return func(**kwargs) - self.__getattribute__(key) # raises - - -def _call_op(op, left, right=None, thunk=None, **kwargs): - if right is None and thunk is None: - if isinstance(left, InfixExprBase): - # op(A & B), op(A | B), op(A @ B) - return getattr(left.left, left.method_name)(left.right, op, **kwargs) - if find_opclass(op)[1] == "Semiring": - raise TypeError( - f"Bad type when calling {op!r}. Got type: {type(left)}.\n" - f"Expected an infix expression, such as: {op!r}(A @ B)" - ) - raise TypeError( - f"Bad type when calling {op!r}. Got type: {type(left)}.\n" - "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" - f" - {op!r}(A & B)\n" - f" - {op!r}(A, 1)\n" - f" - {op!r}(1, A)" - ) - - # op(A, 1) -> apply (or select if thunk provided) - from .matrix import Matrix, TransposedMatrix - from .vector import Vector - - if (left_type := output_type(left)) in {Vector, Matrix, TransposedMatrix}: - if thunk is not None: - return left.select(op, thunk=thunk, **kwargs) - return left.apply(op, right=right, **kwargs) - if (right_type := output_type(right)) in {Vector, Matrix, TransposedMatrix}: - return right.apply(op, left=left, **kwargs) - - from .scalar import Scalar, _as_scalar - - if left_type is Scalar: - if thunk is not None: - return left.select(op, thunk=thunk, **kwargs) - return left.apply(op, right=right, **kwargs) - if right_type is Scalar: - return right.apply(op, left=left, **kwargs) - try: - left_scalar = _as_scalar(left, is_cscalar=False) - except Exception: - pass - else: - if thunk is not None: - return left_scalar.select(op, thunk=thunk, **kwargs) - return left_scalar.apply(op, right=right, **kwargs) - raise TypeError( - f"Bad types when calling {op!r}. Got types: {type(left)}, {type(right)}.\n" - "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" - f" - {op!r}(A & B)\n" - f" - {op!r}(A, 1)\n" - f" - {op!r}(1, A)" - ) - - -_udt_mask_cache = {} - - -def _udt_mask(dtype): - """Create mask to determine which bytes of UDTs to use for equality check.""" - if dtype in _udt_mask_cache: - return _udt_mask_cache[dtype] - if dtype.subdtype is not None: - mask = _udt_mask(dtype.subdtype[0]) - N = reduce(mul, dtype.subdtype[1]) - rv = np.concatenate([mask] * N) - elif dtype.names is not None: - prev_offset = mask = None - masks = [] - for name in dtype.names: - dtype2, offset = dtype.fields[name] - if mask is not None: - masks.append(np.pad(mask, (0, offset - prev_offset - mask.size))) - mask = _udt_mask(dtype2) - prev_offset = offset - masks.append(np.pad(mask, (0, dtype.itemsize - prev_offset - mask.size))) - rv = np.concatenate(masks) - else: - rv = np.ones(dtype.itemsize, dtype=bool) - # assert rv.size == dtype.itemsize - _udt_mask_cache[dtype] = rv - return rv - - -class TypedOpBase: - __slots__ = ( - "parent", - "name", - "type", - "return_type", - "gb_obj", - "gb_name", - "_type2", - "__weakref__", - ) - - def __init__(self, parent, name, type_, return_type, gb_obj, gb_name, dtype2=None): - self.parent = parent - self.name = name - self.type = type_ - self.return_type = return_type - self.gb_obj = gb_obj - self.gb_name = gb_name - self._type2 = dtype2 - - def __repr__(self): - classname = self.opclass.lower() - if classname.endswith("op"): - classname = classname[:-2] - dtype2 = "" if self._type2 is None else f", {self._type2.name}" - return f"{classname}.{self.name}[{self.type.name}{dtype2}]" - - @property - def _carg(self): - return self.gb_obj - - @property - def is_positional(self): - return self.parent.is_positional - - def __reduce__(self): - if self._type2 is None or self.type == self._type2: - return (getitem, (self.parent, self.type)) - return (getitem, (self.parent, (self.type, self._type2))) - - -class TypedBuiltinUnaryOp(TypedOpBase): - __slots__ = () - opclass = "UnaryOp" - - def __call__(self, val): - from .matrix import Matrix, TransposedMatrix - from .vector import Vector - - if (typ := output_type(val)) in {Vector, Matrix, TransposedMatrix}: - return val.apply(self) - from .scalar import Scalar, _as_scalar - - if typ is Scalar: - return val.apply(self) - try: - scalar = _as_scalar(val, is_cscalar=False) - except Exception: - pass - else: - return scalar.apply(self) - raise TypeError( - f"Bad type when calling {self!r}.\n" - " - Expected type: Scalar, Vector, Matrix, TransposedMatrix.\n" - f" - Got: {type(val)}.\n" - "Calling a UnaryOp is syntactic sugar for calling apply. " - f"For example, `A.apply({self!r})` is the same as `{self!r}(A)`." - ) - - -class TypedBuiltinIndexUnaryOp(TypedOpBase): - __slots__ = () - opclass = "IndexUnaryOp" - - def __call__(self, val, thunk=None): - if thunk is None: - thunk = False # most basic form of 0 when unifying dtypes - return _call_op(self, val, right=thunk) - - -class TypedBuiltinSelectOp(TypedOpBase): - __slots__ = () - opclass = "SelectOp" - - def __call__(self, val, thunk=None): - if thunk is None: - thunk = False # most basic form of 0 when unifying dtypes - return _call_op(self, val, thunk=thunk) - - -class TypedBuiltinBinaryOp(TypedOpBase): - __slots__ = () - opclass = "BinaryOp" - - def __call__(self, left, right=None, *, left_default=None, right_default=None): - if left_default is not None or right_default is not None: - if ( - left_default is None - or right_default is None - or right is not None - or not isinstance(left, InfixExprBase) - or left.method_name != "ewise_add" - ): - raise TypeError( - "Specifying `left_default` or `right_default` keyword arguments implies " - "performing `ewise_union` operation with infix notation.\n" - "There is only one valid way to do this:\n\n" - f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " - "are Vectors or Matrices, and left_default and right_default are scalars." - ) - return left.left.ewise_union(left.right, self, left_default, right_default) - return _call_op(self, left, right) - - @property - def monoid(self): - rv = getattr(monoid, self.name, None) - if rv is not None and self.type in rv._typed_ops: - return rv[self.type] - - @property - def commutes_to(self): - commutes_to = self.parent.commutes_to - if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): - return commutes_to[self.type] - - @property - def _semiring_commutes_to(self): - commutes_to = self.parent._semiring_commutes_to - if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): - return commutes_to[self.type] - - @property - def is_commutative(self): - return self.commutes_to is self - - @property - def type2(self): - return self.type if self._type2 is None else self._type2 - - -class TypedBuiltinMonoid(TypedOpBase): - __slots__ = "_identity" - opclass = "Monoid" - is_commutative = True - - def __init__(self, parent, name, type_, return_type, gb_obj, gb_name): - super().__init__(parent, name, type_, return_type, gb_obj, gb_name) - self._identity = None - - def __call__(self, left, right=None, *, left_default=None, right_default=None): - if left_default is not None or right_default is not None: - if ( - left_default is None - or right_default is None - or right is not None - or not isinstance(left, InfixExprBase) - or left.method_name != "ewise_add" - ): - raise TypeError( - "Specifying `left_default` or `right_default` keyword arguments implies " - "performing `ewise_union` operation with infix notation.\n" - "There is only one valid way to do this:\n\n" - f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " - "are Vectors or Matrices, and left_default and right_default are scalars." - ) - return left.left.ewise_union(left.right, self, left_default, right_default) - return _call_op(self, left, right) - - @property - def identity(self): - if self._identity is None: - from .recorder import skip_record - from .vector import Vector - - with skip_record: - self._identity = ( - Vector(self.type, size=1, name="").reduce(self, allow_empty=False).new().value - ) - return self._identity - - @property - def binaryop(self): - return getattr(binary, self.name)[self.type] - - @property - def commutes_to(self): - return self - - @property - def type2(self): - return self.type - - @property - def is_idempotent(self): - """True if ``monoid(x, x) == x`` for any x.""" - return self.parent.is_idempotent - - -class TypedBuiltinSemiring(TypedOpBase): - __slots__ = () - opclass = "Semiring" - - def __call__(self, left, right=None): - if right is not None: - raise TypeError( - f"Bad types when calling {self!r}. Got types: {type(left)}, {type(right)}.\n" - f"Expected an infix expression, such as: {self!r}(A @ B)" - ) - return _call_op(self, left) - - @property - def binaryop(self): - name = self.name.split("_", 1)[1] - if name in _SS_OPERATORS: - binop = binary._deprecated[name] - else: - binop = getattr(binary, name) - return binop[self.type] - - @property - def monoid(self): - monoid_name, binary_name = self.name.split("_", 1) - if binary_name in _SS_OPERATORS: - binop = binary._deprecated[binary_name] - else: - binop = getattr(binary, binary_name) - binop = binop[self.type] - val = getattr(monoid, monoid_name) - return val[binop.return_type] - - @property - def commutes_to(self): - binop = self.binaryop - commutes_to = binop._semiring_commutes_to or binop.commutes_to - if commutes_to is None: - return - if commutes_to is binop: - return self - return get_semiring(self.monoid, commutes_to) - - @property - def is_commutative(self): - return self.binaryop.is_commutative - - type2 = TypedBuiltinBinaryOp.type2 - - -class TypedUserUnaryOp(TypedOpBase): - __slots__ = () - opclass = "UnaryOp" - - def __init__(self, parent, name, type_, return_type, gb_obj): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") - - @property - def orig_func(self): - return self.parent.orig_func - - @property - def _numba_func(self): - return self.parent._numba_func - - __call__ = TypedBuiltinUnaryOp.__call__ - - -class TypedUserIndexUnaryOp(TypedOpBase): - __slots__ = () - opclass = "IndexUnaryOp" - - def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) - - @property - def orig_func(self): - return self.parent.orig_func - - @property - def _numba_func(self): - return self.parent._numba_func - - __call__ = TypedBuiltinIndexUnaryOp.__call__ - - -class TypedUserSelectOp(TypedOpBase): - __slots__ = () - opclass = "SelectOp" - - def __init__(self, parent, name, type_, return_type, gb_obj): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") - - @property - def orig_func(self): - return self.parent.orig_func - - @property - def _numba_func(self): - return self.parent._numba_func - - __call__ = TypedBuiltinSelectOp.__call__ - - -class TypedUserBinaryOp(TypedOpBase): - __slots__ = "_monoid" - opclass = "BinaryOp" - - def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) - self._monoid = None - - @property - def monoid(self): - if self._monoid is None: - monoid = self.parent.monoid - if monoid is not None and self.type in monoid: - self._monoid = monoid[self.type] - return self._monoid - - commutes_to = TypedBuiltinBinaryOp.commutes_to - _semiring_commutes_to = TypedBuiltinBinaryOp._semiring_commutes_to - is_commutative = TypedBuiltinBinaryOp.is_commutative - orig_func = TypedUserUnaryOp.orig_func - _numba_func = TypedUserUnaryOp._numba_func - type2 = TypedBuiltinBinaryOp.type2 - __call__ = TypedBuiltinBinaryOp.__call__ - - -class TypedUserMonoid(TypedOpBase): - __slots__ = "binaryop", "identity" - opclass = "Monoid" - is_commutative = True - - def __init__(self, parent, name, type_, return_type, gb_obj, binaryop, identity): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") - self.binaryop = binaryop - self.identity = identity - binaryop._monoid = self - - commutes_to = TypedBuiltinMonoid.commutes_to - type2 = TypedBuiltinMonoid.type2 - is_idempotent = TypedBuiltinMonoid.is_idempotent - __call__ = TypedBuiltinMonoid.__call__ - - -class TypedUserSemiring(TypedOpBase): - __slots__ = "monoid", "binaryop" - opclass = "Semiring" - - def __init__(self, parent, name, type_, return_type, gb_obj, monoid, binaryop, dtype2=None): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) - self.monoid = monoid - self.binaryop = binaryop - - commutes_to = TypedBuiltinSemiring.commutes_to - is_commutative = TypedBuiltinSemiring.is_commutative - type2 = TypedBuiltinBinaryOp.type2 - __call__ = TypedBuiltinSemiring.__call__ - - -def _deserialize_parameterized(parameterized_op, args, kwargs): - return parameterized_op(*args, **kwargs) - - -class ParameterizedUdf: - __slots__ = "name", "__call__", "_anonymous", "__weakref__" - is_positional = False - _custom_dtype = None - - def __init__(self, name, anonymous): - self.name = name - self._anonymous = anonymous - # lru_cache per instance - method = self._call.__get__(self, type(self)) - self.__call__ = lru_cache(maxsize=1024)(method) - - def _call(self, *args, **kwargs): - raise NotImplementedError - - -class ParameterizedUnaryOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - unary = self.func(*args, **kwargs) - unary._parameterized_info = (self, args, kwargs) - return UnaryOp.register_anonymous(unary, self.name, is_udt=self._is_udt) - - def __reduce__(self): - name = f"unary.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return UnaryOp.register_anonymous(func, name, parameterized=True) - if (rv := UnaryOp._find(name)) is not None: - return rv - return UnaryOp.register_new(name, func, parameterized=True) - - -class ParameterizedIndexUnaryOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - indexunary = self.func(*args, **kwargs) - indexunary._parameterized_info = (self, args, kwargs) - return IndexUnaryOp.register_anonymous(indexunary, self.name, is_udt=self._is_udt) - - def __reduce__(self): - name = f"indexunary.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return IndexUnaryOp.register_anonymous(func, name, parameterized=True) - if (rv := IndexUnaryOp._find(name)) is not None: - return rv - return IndexUnaryOp.register_new(name, func, parameterized=True) - - -class ParameterizedSelectOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - sel = self.func(*args, **kwargs) - sel._parameterized_info = (self, args, kwargs) - return SelectOp.register_anonymous(sel, self.name, is_udt=self._is_udt) - - def __reduce__(self): - name = f"select.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return SelectOp.register_anonymous(func, name, parameterized=True) - if (rv := SelectOp._find(name)) is not None: - return rv - return SelectOp.register_new(name, func, parameterized=True) - - -class ParameterizedBinaryOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_monoid", "_cached_call", "_commutes_to", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._monoid = None - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - method = self._call_to_cache.__get__(self, type(self)) - self._cached_call = lru_cache(maxsize=1024)(method) - self.__call__ = self._call - self._commutes_to = None - - def _call_to_cache(self, *args, **kwargs): - binary = self.func(*args, **kwargs) - binary._parameterized_info = (self, args, kwargs) - return BinaryOp.register_anonymous(binary, self.name, is_udt=self._is_udt) - - def _call(self, *args, **kwargs): - binop = self._cached_call(*args, **kwargs) - if self._monoid is not None and binop._monoid is None: - # This is all a bit funky. We try our best to associate a binaryop - # to a monoid. So, if we made a ParameterizedMonoid using this object, - # then try to create a monoid with the given arguments. - binop._monoid = binop # temporary! - try: - # If this call is successful, then it will set `binop._monoid` - self._monoid(*args, **kwargs) # pylint: disable=not-callable - except Exception: - binop._monoid = None - # assert binop._monoid is not binop - if self.is_commutative: - binop._commutes_to = binop - # Don't bother yet with creating `binop.commutes_to` (but we could!) - return binop - - @property - def monoid(self): - return self._monoid - - @property - def commutes_to(self): - if type(self._commutes_to) is str: - self._commutes_to = BinaryOp._find(self._commutes_to) - return self._commutes_to - - is_commutative = TypedBuiltinBinaryOp.is_commutative - - def __reduce__(self): - name = f"binary.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return BinaryOp.register_anonymous(func, name, parameterized=True) - if (rv := BinaryOp._find(name)) is not None: - return rv - return BinaryOp.register_new(name, func, parameterized=True) - - -class ParameterizedMonoid(ParameterizedUdf): - __slots__ = "binaryop", "identity", "_is_idempotent", "__signature__" - is_commutative = True - - def __init__(self, name, binaryop, identity, *, is_idempotent=False, anonymous=False): - if type(binaryop) is not ParameterizedBinaryOp: - raise TypeError("binaryop must be parameterized") - self.binaryop = binaryop - self.__signature__ = binaryop.__signature__ - if callable(identity): - # assume it must be parameterized as well, so signature must match - sig = inspect.signature(identity) - if sig != self.__signature__: - raise ValueError( - "Signatures of binaryop and identity passed to " - f"{type(self).__name__} must be the same. Got:\n" - f" binaryop{self.__signature__}\n" - " !=\n" - f" identity{sig}" - ) - self.identity = identity - self._is_idempotent = is_idempotent - if name is None: - name = binaryop.name - super().__init__(name, anonymous) - binaryop._monoid = self - # clear binaryop cache so it can be associated with this monoid - binaryop._cached_call.cache_clear() - - def _call(self, *args, **kwargs): - binary = self.binaryop(*args, **kwargs) - identity = self.identity - if callable(identity): - identity = identity(*args, **kwargs) - return Monoid.register_anonymous( - binary, identity, self.name, is_idempotent=self._is_idempotent - ) - - commutes_to = TypedBuiltinMonoid.commutes_to - - @property - def is_idempotent(self): - """True if ``monoid(x, x) == x`` for any x.""" - return self._is_idempotent - - def __reduce__(self): - name = f"monoid.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover - return name - return (self._deserialize, (self.name, self.binaryop, self.identity, self._anonymous)) - - @staticmethod - def _deserialize(name, binaryop, identity, anonymous): - if anonymous: - return Monoid.register_anonymous(binaryop, identity, name) - if (rv := Monoid._find(name)) is not None: - return rv - return Monoid.register_new(name, binaryop, identity) - - -class ParameterizedSemiring(ParameterizedUdf): - __slots__ = "monoid", "binaryop", "__signature__" - - def __init__(self, name, monoid, binaryop, *, anonymous=False): - if type(monoid) not in {ParameterizedMonoid, Monoid}: - raise TypeError("monoid must be of type Monoid or ParameterizedMonoid") - if type(binaryop) is ParameterizedBinaryOp: - self.__signature__ = binaryop.__signature__ - if type(monoid) is ParameterizedMonoid and monoid.__signature__ != self.__signature__: - raise ValueError( - "Signatures of monoid and binaryop passed to " - f"{type(self).__name__} must be the same. Got:\n" - f" monoid{monoid.__signature__}\n" - " !=\n" - f" binaryop{self.__signature__}\n\n" - "Perhaps call monoid or binaryop with parameters before creating the semiring." - ) - elif type(binaryop) is BinaryOp: - if type(monoid) is Monoid: - raise TypeError("At least one of monoid or binaryop must be parameterized") - self.__signature__ = monoid.__signature__ - else: - raise TypeError("binaryop must be of type BinaryOp or ParameterizedBinaryOp") - self.monoid = monoid - self.binaryop = binaryop - if name is None: - name = f"{monoid.name}_{binaryop.name}" - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - monoid = self.monoid - if type(monoid) is ParameterizedMonoid: - monoid = monoid(*args, **kwargs) - binary = self.binaryop - if type(binary) is ParameterizedBinaryOp: - binary = binary(*args, **kwargs) - return Semiring.register_anonymous(monoid, binary, self.name) - - commutes_to = TypedBuiltinSemiring.commutes_to - is_commutative = TypedBuiltinSemiring.is_commutative - - def __reduce__(self): - name = f"semiring.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover - return name - return (self._deserialize, (self.name, self.monoid, self.binaryop, self._anonymous)) - - @staticmethod - def _deserialize(name, monoid, binaryop, anonymous): - if anonymous: - return Semiring.register_anonymous(monoid, binaryop, name) - if (rv := Semiring._find(name)) is not None: - return rv - return Semiring.register_new(name, monoid, binaryop) - - -_VARNAMES = tuple(x for x in dir(lib) if x[0] != "_") - - -class OpBase: - __slots__ = ( - "name", - "_typed_ops", - "types", - "coercions", - "_anonymous", - "_udt_types", - "_udt_ops", - "__weakref__", - ) - _parse_config = None - _initialized = False - _module = None - _positional = None - - def __init__(self, name, *, anonymous=False): - self.name = name - self._typed_ops = {} - self.types = {} - self.coercions = {} - self._anonymous = anonymous - self._udt_types = None - self._udt_ops = None - - def __repr__(self): - return f"{self._modname}.{self.name}" - - def __getitem__(self, type_): - if type(type_) is tuple: - dtype1, dtype2 = type_ - dtype1 = lookup_dtype(dtype1) - dtype2 = lookup_dtype(dtype2) - return get_typed_op(self, dtype1, dtype2) - if not self._is_udt: - type_ = lookup_dtype(type_) - if type_ not in self._typed_ops: - if self._udt_types is None: - if self.is_positional: - return self._typed_ops[UINT64] - raise KeyError(f"{self.name} does not work with {type_}") - else: - return self._typed_ops[type_] - # This is a UDT or is able to operate on UDTs such as `first` any `any` - dtype = lookup_dtype(type_) - return self._compile_udt(dtype, dtype) - - def _add(self, op): - self._typed_ops[op.type] = op - self.types[op.type] = op.return_type - - def __delitem__(self, type_): - type_ = lookup_dtype(type_) - del self._typed_ops[type_] - del self.types[type_] - - def __contains__(self, type_): - try: - self[type_] - except (TypeError, KeyError, numba.NumbaError): - return False - return True - - @classmethod - def _remove_nesting(cls, funcname, *, module=None, modname=None, strict=True): - if module is None: - module = cls._module - if modname is None: - modname = cls._modname - if "." not in funcname: - if strict and _hasop(module, funcname): - raise AttributeError(f"{modname}.{funcname} is already defined") - else: - path, funcname = funcname.rsplit(".", 1) - for folder in path.split("."): - if not _hasop(module, folder): - setattr(module, folder, OpPath(module, folder)) - module = getattr(module, folder) - modname = f"{modname}.{folder}" - if not isinstance(module, (OpPath, ModuleType)): - raise AttributeError( - f"{modname} is already defined. Cannot use as a nested path." - ) - if strict and _hasop(module, funcname): - raise AttributeError(f"{path}.{funcname} is already defined") - return module, funcname - - @classmethod - def _find(cls, funcname): - rv = cls._module - for attr in funcname.split("."): - if attr in getattr(rv, "_deprecated", ()): - rv = rv._deprecated[attr] - else: - rv = getattr(rv, attr, None) - if rv is None: - break - return rv - - @classmethod - def _initialize(cls, include_in_ops=True): - """ - include_in_ops determines whether the operators are included in the - `gb.ops` namespace in addition to the defined module. - """ - if cls._initialized: # pragma: no cover (safety) - return - # Read in the parse configs - trim_from_front = cls._parse_config.get("trim_from_front", 0) - delete_exact = cls._parse_config.get("delete_exact", None) - num_underscores = cls._parse_config["num_underscores"] - - for re_str, return_prefix in [ - ("re_exprs", None), - ("re_exprs_return_bool", "BOOL"), - ("re_exprs_return_float", "FP"), - ("re_exprs_return_complex", "FC"), - ]: - if re_str not in cls._parse_config: - continue - if "complex" in re_str and not _supports_complex: - continue - for r in reversed(cls._parse_config[re_str]): - for varname in _VARNAMES: - m = r.match(varname) - if m: - # Parse function into name and datatype - gb_name = m.string - splitname = gb_name[trim_from_front:].split("_") - if delete_exact and delete_exact in splitname: - splitname.remove(delete_exact) - if len(splitname) == num_underscores + 1: - *splitname, type_ = splitname - else: - type_ = None - name = "_".join(splitname).lower() - # Create object for name unless it already exists - if not _hasop(cls._module, name): - if backend == "suitesparse" and name in _SS_OPERATORS: - fullname = f"ss.{name}" - else: - fullname = name - if cls._positional is None: - obj = cls(fullname) - else: - obj = cls(fullname, is_positional=name in cls._positional) - if name in _SS_OPERATORS: - if backend == "suitesparse": - setattr(cls._module.ss, name, obj) - cls._module._deprecated[name] = obj - if include_in_ops and not _hasop(op, name): # pragma: no branch - op._deprecated[name] = obj - if backend == "suitesparse": - setattr(op.ss, name, obj) - else: - setattr(cls._module, name, obj) - if include_in_ops and not _hasop(op, name): - setattr(op, name, obj) - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{fullname}") - elif name in _SS_OPERATORS: - obj = cls._module._deprecated[name] - else: - obj = getattr(cls._module, name) - gb_obj = getattr(lib, varname) - # Determine return type - if return_prefix == "BOOL": - return_type = BOOL - if type_ is None: - type_ = BOOL - else: - if type_ is None: # pragma: no cover - raise TypeError(f"Unable to determine return type for {varname}") - if return_prefix is None: - return_type = type_ - else: - # Grab the number of bits from type_ - num_bits = type_[-2:] - if num_bits not in {"32", "64"}: # pragma: no cover (safety) - raise TypeError(f"Unexpected number of bits: {num_bits}") - return_type = f"{return_prefix}{num_bits}" - builtin_op = cls._typed_class( - obj, - name, - lookup_dtype(type_), - lookup_dtype(return_type), - gb_obj, - gb_name, - ) - obj._add(builtin_op) - - @classmethod - def _deserialize(cls, name, *args): - if (rv := cls._find(name)) is not None: - return rv # Should we verify this is what the user expects? - return cls.register_new(name, *args) - - -def _identity(x): - return x # pragma: no cover (numba) - - -def _one(x): - return 1 # pragma: no cover (numba) - - -class UnaryOp(OpBase): - """Takes one input and returns one output, possibly of a different data type. - - Built-in and registered UnaryOps are located in the ``graphblas.unary`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" - _custom_dtype = None - _module = unary - _modname = "unary" - _typed_class = TypedBuiltinUnaryOp - _parse_config = { - "trim_from_front": 4, - "num_underscores": 1, - "re_exprs": [ - re.compile( - "^GrB_(IDENTITY|AINV|MINV|ABS|BNOT)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" - ), - re.compile( - "^GxB_(LNOT|ONE|POSITIONI1|POSITIONI|POSITIONJ1|POSITIONJ)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(SQRT|LOG|EXP|LOG2|SIN|COS|TAN|ACOS|ASIN|ATAN|SINH|COSH|TANH|ACOSH" - "|ASINH|ATANH|SIGNUM|CEIL|FLOOR|ROUND|TRUNC|EXP2|EXPM1|LOG10|LOG1P)" - "_(FP32|FP64|FC32|FC64)$" - ), - re.compile("^GxB_(LGAMMA|TGAMMA|ERF|ERFC|FREXPX|FREXPE|CBRT)_(FP32|FP64)$"), - re.compile("^GxB_(IDENTITY|AINV|MINV|ONE|CONJ)_(FC32|FC64)$"), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_LNOT$"), - re.compile("^GxB_(ISINF|ISNAN|ISFINITE)_(FP32|FP64|FC32|FC64)$"), - ], - "re_exprs_return_float": [re.compile("^GxB_(CREAL|CIMAG|CARG|ABS)_(FC32|FC64)$")], - } - _positional = {"positioni", "positioni1", "positionj", "positionj1"} - - @classmethod - def _build(cls, name, func, *, anonymous=False, is_udt=False): - if type(func) is not FunctionType: - raise TypeError(f"UDF argument must be a function, not {type(func)}") - if name is None: - name = getattr(func, "__name__", "") - success = False - unary_udf = numba.njit(func) - new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=unary_udf) - return_types = {} - nt = numba.types - if not is_udt: - for type_ in _sample_values: - sig = (type_.numba_type,) - try: - unary_udf.compile(sig) - except numba.TypingError: - continue - ret_type = lookup_dtype(unary_udf.overloads[sig].signature.return_type) - if ret_type != type_ and ( - ("INT" in ret_type.name and "INT" in type_.name) - or ("FP" in ret_type.name and "FP" in type_.name) - or ("FC" in ret_type.name and "FC" in type_.name) - or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) - ): - # Downcast `ret_type` to `type_`. - # This is what users want most of the time, but we can't make a perfect rule. - # There should be a way for users to be explicit. - ret_type = type_ - elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: - ret_type = INT8 - - # Numba is unable to handle BOOL correctly right now, but we have a workaround - # See: https://github.com/numba/numba/issues/5395 - # We're relying on coercion behaving correctly here - input_type = INT8 if type_ == BOOL else type_ - return_type = INT8 if ret_type == BOOL else ret_type - - # Build wrapper because GraphBLAS wants pointers and void return - wrapper_sig = nt.void( - nt.CPointer(return_type.numba_type), - nt.CPointer(input_type.numba_type), - ) - - if type_ == BOOL: - if ret_type == BOOL: - - def unary_wrapper(z, x): - z[0] = bool(unary_udf(bool(x[0]))) # pragma: no cover (numba) - - else: - - def unary_wrapper(z, x): - z[0] = unary_udf(bool(x[0])) # pragma: no cover (numba) - - elif ret_type == BOOL: - - def unary_wrapper(z, x): - z[0] = bool(unary_udf(x[0])) # pragma: no cover (numba) - - else: - - def unary_wrapper(z, x): - z[0] = unary_udf(x[0]) # pragma: no cover (numba) - - unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper) - new_unary = ffi_new("GrB_UnaryOp*") - check_status_carg( - lib.GrB_UnaryOp_new( - new_unary, unary_wrapper.cffi, ret_type.gb_obj, type_.gb_obj - ), - "UnaryOp", - new_unary, - ) - op = TypedUserUnaryOp(new_type_obj, name, type_, ret_type, new_unary[0]) - new_type_obj._add(op) - success = True - return_types[type_] = ret_type - if success or is_udt: - return new_type_obj - raise UdfParseError("Unable to parse function using Numba") - - def _compile_udt(self, dtype, dtype2): - if dtype in self._udt_types: - return self._udt_ops[dtype] - - numba_func = self._numba_func - sig = (dtype.numba_type,) - numba_func.compile(sig) # Should we catch and give additional error message? - ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) - - unary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype) - unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper) - new_unary = ffi_new("GrB_UnaryOp*") - check_status_carg( - lib.GrB_UnaryOp_new(new_unary, unary_wrapper.cffi, ret_type._carg, dtype._carg), - "UnaryOp", - new_unary, - ) - op = TypedUserUnaryOp(self, self.name, dtype, ret_type, new_unary[0]) - self._udt_types[dtype] = ret_type - self._udt_ops[dtype] = op - return op - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register a UnaryOp without registering it in the ``graphblas.unary`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedUnaryOp(name, func, anonymous=True, is_udt=is_udt) - return cls._build(name, func, anonymous=True, is_udt=is_udt) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register a UnaryOp. The name will be used to identify the UnaryOp in the - ``graphblas.unary`` namespace. - - >>> gb.core.operator.UnaryOp.register_new("plus_one", lambda x: x + 1) - >>> dir(gb.unary) - [..., 'plus_one', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "func": func, "parameterized": parameterized}, - ) - elif parameterized: - unary_op = ParameterizedUnaryOp(name, func, is_udt=is_udt) - setattr(module, funcname, unary_op) - else: - unary_op = cls._build(name, func, is_udt=is_udt) - setattr(module, funcname, unary_op) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, unary_op) - if not cls._initialized: # pragma: no cover - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return unary_op - - @classmethod - def _initialize(cls): - if cls._initialized: - return - super()._initialize() - # Update type information with sane coercion - position_dtypes = [ - BOOL, - FP32, - FP64, - INT8, - INT16, - UINT8, - UINT16, - UINT32, - UINT64, - ] - if _supports_complex: - position_dtypes.extend([FC32, FC64]) - for names, *types in [ - # fmt: off - ( - ( - "erf", "erfc", "lgamma", "tgamma", "acos", "acosh", "asin", "asinh", - "atan", "atanh", "ceil", "cos", "cosh", "exp", "exp2", "expm1", "floor", - "log", "log10", "log1p", "log2", "round", "signum", "sin", "sinh", "sqrt", - "tan", "tanh", "trunc", "cbrt", - ), - ((BOOL, INT8, INT16, UINT8, UINT16), FP32), - ((INT32, INT64, UINT32, UINT64), FP64), - ), - ( - ("positioni", "positioni1", "positionj", "positionj1"), - ( - position_dtypes, - INT64, - ), - ), - # fmt: on - ]: - for name in names: - if name in _SS_OPERATORS: - op = unary._deprecated[name] - else: - op = getattr(unary, name) - for input_types, target_type in types: - typed_op = op._typed_ops[target_type] - output_type = op.types[target_type] - for dtype in input_types: - if dtype not in op.types: # pragma: no branch (safety) - op.types[dtype] = output_type - op._typed_ops[dtype] = typed_op - op.coercions[dtype] = target_type - # Allow some functions to work on UDTs - for unop, func in [ - (unary.identity, _identity), - (unary.one, _one), - ]: - unop.orig_func = func - unop._numba_func = numba.njit(func) - unop._udt_types = {} - unop._udt_ops = {} - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self.orig_func = func - self._numba_func = numba_func - self.is_positional = is_positional - self._is_udt = is_udt - if is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserUnaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"unary.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinUnaryOp.__call__ - - -class IndexUnaryOp(OpBase): - """Takes one input and a thunk and returns one output, possibly of a different data type. - Along with the input value, the index(es) of the element are given to the function. - - This is an advanced form of a unary operation that allows, for example, converting - elements of a Vector to their index position to build a ramp structure. Another use - case is returning a boolean value indicating whether the element is part of the upper - triangular structure of a Matrix. - - Built-in and registered IndexUnaryOps are located in the ``graphblas.indexunary`` namespace. - """ - - __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" - _module = indexunary - _modname = "indexunary" - _custom_dtype = None - _typed_class = TypedBuiltinIndexUnaryOp - _typed_user_class = TypedUserIndexUnaryOp - _parse_config = { - "trim_from_front": 4, - "num_underscores": 1, - "re_exprs": [ - re.compile("^GrB_(ROWINDEX|COLINDEX|DIAGINDEX)_(INT32|INT64)$"), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_(TRIL|TRIU|DIAG|OFFDIAG|COLLE|COLGT|ROWLE|ROWGT)$"), - re.compile( - "^GrB_(VALUEEQ|VALUENE|VALUEGT|VALUEGE|VALUELT|VALUELE)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile("^GxB_(VALUEEQ|VALUENE)_(FC32|FC64)$"), - ], - } - _positional = {"tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt", - "rowindex", "colindex"} # fmt: skip - - @classmethod - def _build(cls, name, func, *, is_udt=False, anonymous=False): - if not isinstance(func, FunctionType): - raise TypeError(f"UDF argument must be a function, not {type(func)}") - if name is None: - name = getattr(func, "__name__", "") - success = False - indexunary_udf = numba.njit(func) - new_type_obj = cls( - name, func, anonymous=anonymous, is_udt=is_udt, numba_func=indexunary_udf - ) - return_types = {} - nt = numba.types - if not is_udt: - for type_ in _sample_values: - sig = (type_.numba_type, UINT64.numba_type, UINT64.numba_type, type_.numba_type) - try: - indexunary_udf.compile(sig) - except numba.TypingError: - continue - ret_type = lookup_dtype(indexunary_udf.overloads[sig].signature.return_type) - if ret_type != type_ and ( - ("INT" in ret_type.name and "INT" in type_.name) - or ("FP" in ret_type.name and "FP" in type_.name) - or ("FC" in ret_type.name and "FC" in type_.name) - or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) - ): - # Downcast `ret_type` to `type_`. - # This is what users want most of the time, but we can't make a perfect rule. - # There should be a way for users to be explicit. - ret_type = type_ - elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: - ret_type = INT8 - - # Numba is unable to handle BOOL correctly right now, but we have a workaround - # See: https://github.com/numba/numba/issues/5395 - # We're relying on coercion behaving correctly here - input_type = INT8 if type_ == BOOL else type_ - return_type = INT8 if ret_type == BOOL else ret_type - - # Build wrapper because GraphBLAS wants pointers and void return - wrapper_sig = nt.void( - nt.CPointer(return_type.numba_type), - nt.CPointer(input_type.numba_type), - UINT64.numba_type, - UINT64.numba_type, - nt.CPointer(input_type.numba_type), - ) - - if type_ == BOOL: - if ret_type == BOOL: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = bool(indexunary_udf(bool(x[0]), row, col, bool(y[0]))) - - else: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = indexunary_udf(bool(x[0]), row, col, bool(y[0])) - - elif ret_type == BOOL: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = bool(indexunary_udf(x[0], row, col, y[0])) - - else: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = indexunary_udf(x[0], row, col, y[0]) - - indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper) - new_indexunary = ffi_new("GrB_IndexUnaryOp*") - check_status_carg( - lib.GrB_IndexUnaryOp_new( - new_indexunary, - indexunary_wrapper.cffi, - ret_type.gb_obj, - type_.gb_obj, - type_.gb_obj, - ), - "IndexUnaryOp", - new_indexunary, - ) - op = cls._typed_user_class(new_type_obj, name, type_, ret_type, new_indexunary[0]) - new_type_obj._add(op) - success = True - return_types[type_] = ret_type - if success or is_udt: - return new_type_obj - raise UdfParseError("Unable to parse function using Numba") - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: # pragma: no cover - dtype2 = dtype - dtypes = (dtype, dtype2) - if dtypes in self._udt_types: - return self._udt_ops[dtypes] - - numba_func = self._numba_func - sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type) - numba_func.compile(sig) # Should we catch and give additional error message? - ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) - indexunary_wrapper, wrapper_sig = _get_udt_wrapper( - numba_func, ret_type, dtype, dtype2, include_indexes=True - ) - - indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper) - new_indexunary = ffi_new("GrB_IndexUnaryOp*") - check_status_carg( - lib.GrB_IndexUnaryOp_new( - new_indexunary, indexunary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg - ), - "IndexUnaryOp", - new_indexunary, - ) - op = TypedUserIndexUnaryOp( - self, - self.name, - dtype, - ret_type, - new_indexunary[0], - dtype2=dtype2, - ) - self._udt_types[dtypes] = ret_type - self._udt_ops[dtypes] = op - return op - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register an IndexUnaryOp without registering it in the - ``graphblas.indexunary`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedIndexUnaryOp(name, func, anonymous=True, is_udt=is_udt) - return cls._build(name, func, anonymous=True, is_udt=is_udt) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register an IndexUnaryOp. The name will be used to identify the IndexUnaryOp in the - ``graphblas.indexunary`` namespace. - - If the return type is Boolean, the function will also be registered as a SelectOp - with the same name. - - >>> gb.indexunary.register_new("row_mod", lambda x, i, j, thunk: i % max(thunk, 2)) - >>> dir(gb.indexunary) - [..., 'row_mod', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "func": func, "parameterized": parameterized}, - ) - elif parameterized: - indexunary_op = ParameterizedIndexUnaryOp(name, func, is_udt=is_udt) - setattr(module, funcname, indexunary_op) - else: - indexunary_op = cls._build(name, func, is_udt=is_udt) - setattr(module, funcname, indexunary_op) - # If return type is BOOL, register additionally as a SelectOp - if all(x == BOOL for x in indexunary_op.types.values()): - setattr(select, funcname, SelectOp._from_indexunary(indexunary_op)) - - if not cls._initialized: - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return indexunary_op - - @classmethod - def _initialize(cls): - if cls._initialized: - return - super()._initialize(include_in_ops=False) - # Update type information to include UINT64 for positional ops - for name in ["tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt"]: - op = getattr(indexunary, name) - typed_op = op._typed_ops[BOOL] - output_type = op.types[BOOL] - if UINT64 not in op.types: # pragma: no branch (safety) - op.types[UINT64] = output_type - op._typed_ops[UINT64] = typed_op - op.coercions[UINT64] = BOOL - for name in ["rowindex", "colindex"]: - op = getattr(indexunary, name) - typed_op = op._typed_ops[INT64] - output_type = op.types[INT64] - if UINT64 not in op.types: # pragma: no branch (safety) - op.types[UINT64] = output_type - op._typed_ops[UINT64] = typed_op - op.coercions[UINT64] = INT64 - # Add index->row alias to make it more intuitive which to use for vectors - indexunary.indexle = indexunary.rowle - indexunary.indexgt = indexunary.rowgt - indexunary.index = indexunary.rowindex - # fmt: off - # Add SelectOp when it makes sense - for name in ["tril", "triu", "diag", "offdiag", - "colle", "colgt", "rowle", "rowgt", "indexle", "indexgt", - "valueeq", "valuene", "valuegt", "valuege", "valuelt", "valuele"]: - iop = getattr(indexunary, name) - setattr(select, name, SelectOp._from_indexunary(iop)) - # fmt: on - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self.orig_func = func - self._numba_func = numba_func - self.is_positional = is_positional - self._is_udt = is_udt - if is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"indexunary.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinIndexUnaryOp.__call__ - - -class SelectOp(OpBase): - """Identical to an :class:`IndexUnaryOp `, - but must have a Boolean return type. - - A SelectOp is used exclusively to select a subset of values from a collection where - the function returns True. - - Built-in and registered SelectOps are located in the ``graphblas.select`` namespace. - """ - - __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" - _module = select - _modname = "select" - _custom_dtype = None - _typed_class = TypedBuiltinSelectOp - _typed_user_class = TypedUserSelectOp - - @classmethod - def _from_indexunary(cls, iop): - obj = cls( - iop.name, - iop.orig_func, - anonymous=iop._anonymous, - is_positional=iop.is_positional, - is_udt=iop._is_udt, - numba_func=iop._numba_func, - ) - if not all(x == BOOL for x in iop.types.values()): - raise ValueError("SelectOp must have BOOL return type") - for type_, t in iop._typed_ops.items(): - if iop.orig_func is not None: - op = cls._typed_user_class( - obj, - iop.name, - t.type, - t.return_type, - t.gb_obj, - ) - else: - op = cls._typed_class( - obj, - iop.name, - t.type, - t.return_type, - t.gb_obj, - t.gb_name, - ) - # type is not always equal to t.type, so can't use op._add - # but otherwise perform the same logic - obj._typed_ops[type_] = op - obj.types[type_] = op.return_type - return obj - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register a SelectOp without registering it in the ``graphblas.select`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedSelectOp(name, func, anonymous=True, is_udt=is_udt) - iop = IndexUnaryOp._build(name, func, anonymous=True, is_udt=is_udt) - return SelectOp._from_indexunary(iop) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register a SelectOp. The name will be used to identify the SelectOp in the - ``graphblas.select`` namespace. - - The function will also be registered as a IndexUnaryOp with the same name. - - >>> gb.select.register_new("upper_left_triangle", lambda x, i, j, thunk: i + j <= thunk) - >>> dir(gb.select) - [..., 'upper_left_triangle', ...] - """ - iop = IndexUnaryOp.register_new( - name, func, parameterized=parameterized, is_udt=is_udt, lazy=lazy - ) - if not all(x == BOOL for x in iop.types.values()): - raise ValueError("SelectOp must have BOOL return type") - if lazy: - return getattr(select, iop.name) - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - # IndexUnaryOp adds it boolean-returning objects to SelectOp - IndexUnaryOp._initialize() - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self.orig_func = func - self._numba_func = numba_func - self.is_positional = is_positional - self._is_udt = is_udt - if is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"select.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinSelectOp.__call__ - - -def _floordiv(x, y): - return x // y # pragma: no cover (numba) - - -def _rfloordiv(x, y): - return y // x # pragma: no cover (numba) - - -def _absfirst(x, y): - return np.abs(x) # pragma: no cover (numba) - - -def _abssecond(x, y): - return np.abs(y) # pragma: no cover (numba) - - -def _rpow(x, y): - return y**x # pragma: no cover (numba) - - -def _isclose(rel_tol=1e-7, abs_tol=0.0): - def inner(x, y): # pragma: no cover (numba) - return x == y or abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) - - return inner - - -_MAX_INT64 = np.iinfo(np.int64).max - - -def _binom(N, k): # pragma: no cover (numba) - # Returns 0 if overflow or out-of-bounds - if k > N or k < 0: - return 0 - val = np.int64(1) - for i in range(min(k, N - k)): - if val > _MAX_INT64 // (N - i): # Overflow - return 0 - val *= N - i - val //= i + 1 - return val - - -# Kinda complicated, but works for now -def _register_binom(): - # "Fake" UDT so we only compile once for INT64 - op = BinaryOp.register_new("binom", _binom, is_udt=True) - typed_op = op[INT64, INT64] - # Make this look like a normal operator - for dtype in [UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64]: - op.types[dtype] = INT64 - op._typed_ops[dtype] = typed_op - if dtype != INT64: - op.coercions[dtype] = typed_op - # And make it not look like it operates on UDTs - typed_op._type2 = None - op._is_udt = False - op._udt_types = None - op._udt_ops = None - return op - - -def _first(x, y): - return x # pragma: no cover (numba) - - -def _second(x, y): - return y # pragma: no cover (numba) - - -def _pair(x, y): - return 1 # pragma: no cover (numba) - - -def _first_dtype(op, dtype, dtype2): - if dtype._is_udt or dtype2._is_udt: - return op._compile_udt(dtype, dtype2) - - -def _second_dtype(op, dtype, dtype2): - if dtype._is_udt or dtype2._is_udt: - return op._compile_udt(dtype, dtype2) - - -def _pair_dtype(op, dtype, dtype2): - return op[INT64] - - -def _get_udt_wrapper(numba_func, return_type, dtype, dtype2=None, *, include_indexes=False): - ztype = INT8 if return_type == BOOL else return_type - xtype = INT8 if dtype == BOOL else dtype - nt = numba.types - wrapper_args = [nt.CPointer(ztype.numba_type), nt.CPointer(xtype.numba_type)] - if include_indexes: - wrapper_args.extend([UINT64.numba_type, UINT64.numba_type]) - if dtype2 is not None: - ytype = INT8 if dtype2 == BOOL else dtype2 - wrapper_args.append(nt.CPointer(ytype.numba_type)) - wrapper_sig = nt.void(*wrapper_args) - - zarray = xarray = yarray = BL = BR = yarg = yname = rcidx = "" - if return_type._is_udt: - if return_type.np_type.subdtype is None: - zarray = " z = numba.carray(z_ptr, 1)\n" - zname = "z[0]" - else: - zname = "z_ptr[0]" - BR = "[0]" - else: - zname = "z_ptr[0]" - if return_type == BOOL: - BL = "bool(" - BR = ")" - - if dtype._is_udt: - if dtype.np_type.subdtype is None: - xarray = " x = numba.carray(x_ptr, 1)\n" - xname = "x[0]" - else: - xname = "x_ptr" - elif dtype == BOOL: - xname = "bool(x_ptr[0])" - else: - xname = "x_ptr[0]" - - if dtype2 is not None: - yarg = ", y_ptr" - if dtype2._is_udt: - if dtype2.np_type.subdtype is None: - yarray = " y = numba.carray(y_ptr, 1)\n" - yname = ", y[0]" - else: - yname = ", y_ptr" - elif dtype2 == BOOL: - yname = ", bool(y_ptr[0])" - else: - yname = ", y_ptr[0]" - - if include_indexes: - rcidx = ", row, col" - - d = {"numba": numba, "numba_func": numba_func} - text = ( - f"def wrapper(z_ptr, x_ptr{rcidx}{yarg}):\n" - f"{zarray}{xarray}{yarray}" - f" {zname} = {BL}numba_func({xname}{rcidx}{yname}){BR}\n" - ) - exec(text, d) # pylint: disable=exec-used - return d["wrapper"], wrapper_sig - - -class BinaryOp(OpBase): - """Takes two inputs and returns one output, possibly of a different data type. - - Built-in and registered BinaryOps are located in the ``graphblas.binary`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = ( - "_monoid", - "_commutes_to", - "_semiring_commutes_to", - "orig_func", - "is_positional", - "_is_udt", - "_numba_func", - "_custom_dtype", - ) - _module = binary - _modname = "binary" - _typed_class = TypedBuiltinBinaryOp - _parse_config = { - "trim_from_front": 4, - "num_underscores": 1, - "re_exprs": [ - re.compile( - "^GrB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV|MIN|MAX)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" - ), - re.compile( - "GrB_(BOR|BAND|BXOR|BXNOR)_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" - ), - re.compile( - "^GxB_(POW|RMINUS|RDIV|PAIR|ANY|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" - ), - re.compile("^GxB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV)_(FC32|FC64)$"), - re.compile("^GxB_(ATAN2|HYPOT|FMOD|REMAINDER|LDEXP|COPYSIGN)_(FP32|FP64)$"), - re.compile( - "GxB_(BGET|BSET|BCLR|BSHIFT|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ" - "|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" - "_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" - ), - # These are coerced to 0 or 1, but don't return BOOL - re.compile( - "^GxB_(LOR|LAND|LXOR|LXNOR)_" - "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)$"), - re.compile( - "^GrB_(EQ|NE|GT|LT|GE|LE)_" - "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile("^GxB_(EQ|NE)_(FC32|FC64)$"), - ], - "re_exprs_return_complex": [re.compile("^GxB_(CMPLX)_(FP32|FP64)$")], - } - _commutes = { - # builtins - "cdiv": "rdiv", - "first": "second", - "ge": "le", - "gt": "lt", - "isge": "isle", - "isgt": "islt", - "minus": "rminus", - "pow": "rpow", - # special - "firsti": "secondi", - "firsti1": "secondi1", - "firstj": "secondj", - "firstj1": "secondj1", - # custom - # "absfirst": "abssecond", # handled in graphblas.binary - # "floordiv": "rfloordiv", - "truediv": "rtruediv", - } - _commutes_to_in_semiring = { - "firsti": "secondj", - "firsti1": "secondj1", - "firstj": "secondi", - "firstj1": "secondi1", - } - _commutative = { - # monoids - "any", - "band", - "bor", - "bxnor", - "bxor", - "eq", - "land", - "lor", - "lxnor", - "lxor", - "max", - "min", - "plus", - "times", - # other - "hypot", - "isclose", - "iseq", - "isne", - "ne", - "pair", - } - # Don't commute: atan2, bclr, bget, bset, bshift, cmplx, copysign, fmod, ldexp, remainder - _positional = { - "firsti", - "firsti1", - "firstj", - "firstj1", - "secondi", - "secondi1", - "secondj", - "secondj1", - } - - @classmethod - def _build(cls, name, func, *, is_udt=False, anonymous=False): - if not isinstance(func, FunctionType): - raise TypeError(f"UDF argument must be a function, not {type(func)}") - if name is None: - name = getattr(func, "__name__", "") - success = False - binary_udf = numba.njit(func) - new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=binary_udf) - return_types = {} - nt = numba.types - if not is_udt: - for type_ in _sample_values: - sig = (type_.numba_type, type_.numba_type) - try: - binary_udf.compile(sig) - except numba.TypingError: - continue - ret_type = lookup_dtype(binary_udf.overloads[sig].signature.return_type) - if ret_type != type_ and ( - ("INT" in ret_type.name and "INT" in type_.name) - or ("FP" in ret_type.name and "FP" in type_.name) - or ("FC" in ret_type.name and "FC" in type_.name) - or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) - ): - # Downcast `ret_type` to `type_`. - # This is what users want most of the time, but we can't make a perfect rule. - # There should be a way for users to be explicit. - ret_type = type_ - elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: - ret_type = INT8 - - # Numba is unable to handle BOOL correctly right now, but we have a workaround - # See: https://github.com/numba/numba/issues/5395 - # We're relying on coercion behaving correctly here - input_type = INT8 if type_ == BOOL else type_ - return_type = INT8 if ret_type == BOOL else ret_type - - # Build wrapper because GraphBLAS wants pointers and void return - wrapper_sig = nt.void( - nt.CPointer(return_type.numba_type), - nt.CPointer(input_type.numba_type), - nt.CPointer(input_type.numba_type), - ) - - if type_ == BOOL: - if ret_type == BOOL: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = bool(binary_udf(bool(x[0]), bool(y[0]))) - - else: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = binary_udf(bool(x[0]), bool(y[0])) - - elif ret_type == BOOL: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = bool(binary_udf(x[0], y[0])) - - else: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = binary_udf(x[0], y[0]) - - binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) - new_binary = ffi_new("GrB_BinaryOp*") - check_status_carg( - lib.GrB_BinaryOp_new( - new_binary, - binary_wrapper.cffi, - ret_type.gb_obj, - type_.gb_obj, - type_.gb_obj, - ), - "BinaryOp", - new_binary, - ) - op = TypedUserBinaryOp(new_type_obj, name, type_, ret_type, new_binary[0]) - new_type_obj._add(op) - success = True - return_types[type_] = ret_type - if success or is_udt: - return new_type_obj - raise UdfParseError("Unable to parse function using Numba") - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: - dtype2 = dtype - dtypes = (dtype, dtype2) - if dtypes in self._udt_types: - return self._udt_ops[dtypes] - - nt = numba.types - if self.name == "eq" and not self._anonymous: - # assert dtype.np_type == dtype2.np_type - itemsize = dtype.np_type.itemsize - mask = _udt_mask(dtype.np_type) - ret_type = BOOL - wrapper_sig = nt.void( - nt.CPointer(INT8.numba_type), - nt.CPointer(UINT8.numba_type), - nt.CPointer(UINT8.numba_type), - ) - # PERF: we can probably make this faster - if mask.all(): - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if x[i] != y[i]: - # z_ptr[0] = False - # break - # else: - # z_ptr[0] = True - z_ptr[0] = (x == y).all() - - else: - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if mask[i] and x[i] != y[i]: - # z_ptr[0] = False - # break - # else: - # z_ptr[0] = True - z_ptr[0] = (x[mask] == y[mask]).all() - - elif self.name == "ne" and not self._anonymous: - # assert dtype.np_type == dtype2.np_type - itemsize = dtype.np_type.itemsize - mask = _udt_mask(dtype.np_type) - ret_type = BOOL - wrapper_sig = nt.void( - nt.CPointer(INT8.numba_type), - nt.CPointer(UINT8.numba_type), - nt.CPointer(UINT8.numba_type), - ) - if mask.all(): - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if x[i] != y[i]: - # z_ptr[0] = True - # break - # else: - # z_ptr[0] = False - z_ptr[0] = (x != y).any() - - else: - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if mask[i] and x[i] != y[i]: - # z_ptr[0] = True - # break - # else: - # z_ptr[0] = False - z_ptr[0] = (x[mask] != y[mask]).any() - - else: - numba_func = self._numba_func - sig = (dtype.numba_type, dtype2.numba_type) - numba_func.compile(sig) # Should we catch and give additional error message? - ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) - binary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype, dtype2) - - binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) - new_binary = ffi_new("GrB_BinaryOp*") - check_status_carg( - lib.GrB_BinaryOp_new( - new_binary, binary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg - ), - "BinaryOp", - new_binary, - ) - op = TypedUserBinaryOp( - self, - self.name, - dtype, - ret_type, - new_binary[0], - dtype2=dtype2, - ) - self._udt_types[dtypes] = ret_type - self._udt_ops[dtypes] = op - return op - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register a BinaryOp without registering it in the ``graphblas.binary`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedBinaryOp(name, func, anonymous=True, is_udt=is_udt) - return cls._build(name, func, anonymous=True, is_udt=is_udt) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register a BinaryOp. The name will be used to identify the BinaryOp in the - ``graphblas.binary`` namespace. - - >>> def max_zero(x, y): - r = 0 - if x > r: - r = x - if y > r: - r = y - return r - >>> gb.core.operator.BinaryOp.register_new("max_zero", max_zero) - >>> dir(gb.binary) - [..., 'max_zero', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "func": func, "parameterized": parameterized}, - ) - elif parameterized: - binary_op = ParameterizedBinaryOp(name, func, is_udt=is_udt) - setattr(module, funcname, binary_op) - else: - binary_op = cls._build(name, func, is_udt=is_udt) - setattr(module, funcname, binary_op) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, binary_op) - if not cls._initialized: - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return binary_op - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - super()._initialize() - # Rename div to cdiv - cdiv = binary.cdiv = op.cdiv = BinaryOp("cdiv") - for dtype, ret_type in binary.div.types.items(): - orig_op = binary.div[dtype] - cur_op = TypedBuiltinBinaryOp( - cdiv, "cdiv", dtype, ret_type, orig_op.gb_obj, orig_op.gb_name - ) - cdiv._add(cur_op) - del binary.div - del op.div - # Add truediv which always points to floating point cdiv - # We are effectively hacking cdiv to always return floating point values - # If the inputs are FP32, we use DIV_FP32; use DIV_FP64 for all other input dtypes - truediv = binary.truediv = op.truediv = BinaryOp("truediv") - rtruediv = binary.rtruediv = op.rtruediv = BinaryOp("rtruediv") - for new_op, builtin_op in [(truediv, binary.cdiv), (rtruediv, binary.rdiv)]: - for dtype in builtin_op.types: - if dtype.name in {"FP32", "FC32", "FC64"}: - orig_dtype = dtype - else: - orig_dtype = FP64 - orig_op = builtin_op[orig_dtype] - cur_op = TypedBuiltinBinaryOp( - new_op, - new_op.name, - dtype, - builtin_op.types[orig_dtype], - orig_op.gb_obj, - orig_op.gb_name, - ) - new_op._add(cur_op) - # Add floordiv - # cdiv truncates towards 0, while floordiv truncates towards -inf - BinaryOp.register_new("floordiv", _floordiv, lazy=True) # cast to integer - BinaryOp.register_new("rfloordiv", _rfloordiv, lazy=True) # cast to integer - - # For aggregators - BinaryOp.register_new("absfirst", _absfirst, lazy=True) - BinaryOp.register_new("abssecond", _abssecond, lazy=True) - BinaryOp.register_new("rpow", _rpow, lazy=True) - - # For algorithms - binary._delayed["binom"] = (_register_binom, {}) # Lazy with custom creation - op._delayed["binom"] = binary - - BinaryOp.register_new("isclose", _isclose, parameterized=True) - - # Update type information with sane coercion - position_dtypes = [ - BOOL, - FP32, - FP64, - INT8, - INT16, - UINT8, - UINT16, - UINT32, - UINT64, - ] - if _supports_complex: - position_dtypes.extend([FC32, FC64]) - name_types = [ - # fmt: off - ( - ("atan2", "copysign", "fmod", "hypot", "ldexp", "remainder"), - ((BOOL, INT8, INT16, UINT8, UINT16), FP32), - ((INT32, INT64, UINT32, UINT64), FP64), - ), - ( - ( - "firsti", "firsti1", "firstj", "firstj1", "secondi", "secondi1", - "secondj", "secondj1"), - ( - position_dtypes, - INT64, - ), - ), - ( - ["lxnor"], - ( - ( - FP32, FP64, INT8, INT16, INT32, INT64, - UINT8, UINT16, UINT32, UINT64, - ), - BOOL, - ), - ), - # fmt: on - ] - if _supports_complex: - name_types.append( - ( - ["cmplx"], - ((BOOL, INT8, INT16, UINT8, UINT16), FP32), - ((INT32, INT64, UINT32, UINT64), FP64), - ) - ) - for names, *types in name_types: - for name in names: - if name in _SS_OPERATORS: - cur_op = binary._deprecated[name] - else: - cur_op = getattr(binary, name) - for input_types, target_type in types: - typed_op = cur_op._typed_ops[target_type] - output_type = cur_op.types[target_type] - for dtype in input_types: - if dtype not in cur_op.types: # pragma: no branch (safety) - cur_op.types[dtype] = output_type - cur_op._typed_ops[dtype] = typed_op - cur_op.coercions[dtype] = target_type - # Not valid input dtypes - del binary.ldexp[FP32] - del binary.ldexp[FP64] - # Fill in commutes info - for left_name, right_name in cls._commutes.items(): - if left_name in _SS_OPERATORS: - left = binary._deprecated[left_name] - else: - left = getattr(binary, left_name) - if backend == "suitesparse" and right_name in _SS_OPERATORS: - left._commutes_to = f"ss.{right_name}" - else: - left._commutes_to = right_name - if right_name not in binary._delayed: - if right_name in _SS_OPERATORS: - right = binary._deprecated[right_name] - else: - right = getattr(binary, right_name) - if backend == "suitesparse" and left_name in _SS_OPERATORS: - right._commutes_to = f"ss.{left_name}" - else: - right._commutes_to = left_name - for name in cls._commutative: - cur_op = getattr(binary, name) - cur_op._commutes_to = name - for left_name, right_name in cls._commutes_to_in_semiring.items(): - if left_name in _SS_OPERATORS: - left = binary._deprecated[left_name] - else: # pragma: no cover (safety) - left = getattr(binary, left_name) - if right_name in _SS_OPERATORS: - right = binary._deprecated[right_name] - else: # pragma: no cover (safety) - right = getattr(binary, right_name) - left._semiring_commutes_to = right - right._semiring_commutes_to = left - # Allow some functions to work on UDTs - for binop, func in [ - (binary.first, _first), - (binary.second, _second), - (binary.pair, _pair), - (binary.any, _first), - ]: - binop.orig_func = func - binop._numba_func = numba.njit(func) - binop._udt_types = {} - binop._udt_ops = {} - binary.any._numba_func = binary.first._numba_func - binary.eq._udt_types = {} - binary.eq._udt_ops = {} - binary.ne._udt_types = {} - binary.ne._udt_ops = {} - # Set custom dtype handling - binary.first._custom_dtype = _first_dtype - binary.second._custom_dtype = _second_dtype - binary.pair._custom_dtype = _pair_dtype - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self._monoid = None - self._commutes_to = None - self._semiring_commutes_to = None - self.orig_func = func - self._numba_func = numba_func - self._is_udt = is_udt - self.is_positional = is_positional - self._custom_dtype = None - if is_udt: - self._udt_types = {} # {(dtype, dtype): DataType} - self._udt_ops = {} # {(dtype, dtype): TypedUserBinaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"binary.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinBinaryOp.__call__ - is_commutative = TypedBuiltinBinaryOp.is_commutative - commutes_to = ParameterizedBinaryOp.commutes_to - - @property - def monoid(self): - if self._monoid is None and not self._anonymous: - self._monoid = Monoid._find(self.name) - return self._monoid - - -class Monoid(OpBase): - """Takes two inputs and returns one output, all of the same data type. - - Built-in and registered Monoids are located in the ``graphblas.monoid`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = "_binaryop", "_identity", "_is_idempotent" - is_commutative = True - is_positional = False - _custom_dtype = None - _module = monoid - _modname = "monoid" - _typed_class = TypedBuiltinMonoid - _parse_config = { - "trim_from_front": 4, - "delete_exact": "MONOID", - "num_underscores": 1, - "re_exprs": [ - re.compile( - "^GrB_(MIN|MAX|PLUS|TIMES|LOR|LAND|LXOR|LXNOR)_MONOID" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(ANY)_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)_MONOID$" - ), - re.compile("^GxB_(PLUS|TIMES|ANY)_(FC32|FC64)_MONOID$"), - re.compile("^GxB_(EQ|ANY)_BOOL_MONOID$"), - re.compile("^GxB_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)_MONOID$"), - ], - } - - @classmethod - def _build(cls, name, binaryop, identity, *, is_idempotent=False, anonymous=False): - if type(binaryop) is not BinaryOp: - raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") - if name is None: - name = binaryop.name - new_type_obj = cls( - name, binaryop, identity, is_idempotent=is_idempotent, anonymous=anonymous - ) - if not binaryop._is_udt: - if not isinstance(identity, Mapping): - identities = dict.fromkeys(binaryop.types, identity) - explicit_identities = False - else: - identities = {lookup_dtype(key): val for key, val in identity.items()} - explicit_identities = True - for type_, ident in identities.items(): - ret_type = binaryop[type_].return_type - # If there is a domain mismatch, then DomainMismatch will be raised - # below if identities were explicitly given. - if type_ != ret_type and not explicit_identities: - continue - new_monoid = ffi_new("GrB_Monoid*") - func = libget(f"GrB_Monoid_new_{type_.name}") - zcast = ffi.cast(type_.c_type, ident) - check_status_carg( - func(new_monoid, binaryop[type_].gb_obj, zcast), "Monoid", new_monoid[0] - ) - op = TypedUserMonoid( - new_type_obj, - name, - type_, - ret_type, - new_monoid[0], - binaryop[type_], - ident, - ) - new_type_obj._add(op) - return new_type_obj - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: - dtype2 = dtype - elif dtype != dtype2: - raise TypeError( - "Monoid inputs must be the same dtype (got {dtype} and {dtype2}); " - "unable to coerce when using UDTs." - ) - if dtype in self._udt_types: - return self._udt_ops[dtype] - binaryop = self.binaryop._compile_udt(dtype, dtype2) - from .scalar import Scalar - - ret_type = binaryop.return_type - identity = Scalar.from_value(self._identity, dtype=ret_type, is_cscalar=True) - new_monoid = ffi_new("GrB_Monoid*") - status = lib.GrB_Monoid_new_UDT(new_monoid, binaryop.gb_obj, identity.gb_obj) - check_status_carg(status, "Monoid", new_monoid[0]) - op = TypedUserMonoid( - new_monoid, - self.name, - dtype, - ret_type, - new_monoid[0], - binaryop, - identity, - ) - self._udt_types[dtype] = dtype - self._udt_ops[dtype] = op - return op - - @classmethod - def register_anonymous(cls, binaryop, identity, name=None, *, is_idempotent=False): - """Register a Monoid without registering it in the ``graphblas.monoid`` namespace. - - Because it is not registered in the namespace, the name is optional. - - Parameters - ---------- - binaryop : BinaryOp - Builtin or registered binary operator - identity : - Identity value of the monoid - name : str, optional - Name associated with the monoid - is_idempotent : bool, default False - Does ``op(x, x) == x`` for any x? - - Returns - ------- - Function handle - """ - if type(binaryop) is ParameterizedBinaryOp: - return ParameterizedMonoid( - name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True - ) - return cls._build(name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True) - - @classmethod - def register_new(cls, name, binaryop, identity, *, is_idempotent=False, lazy=False): - """Register a Monoid. The name will be used to identify the Monoid in the - ``graphblas.monoid`` namespace. - - >>> gb.core.operator.Monoid.register_new("max_zero", gb.binary.max_zero, 0) - >>> dir(gb.monoid) - [..., 'max_zero', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "binaryop": binaryop, "identity": identity}, - ) - elif type(binaryop) is ParameterizedBinaryOp: - monoid = ParameterizedMonoid(name, binaryop, identity, is_idempotent=is_idempotent) - setattr(module, funcname, monoid) - else: - monoid = cls._build(name, binaryop, identity, is_idempotent=is_idempotent) - setattr(module, funcname, monoid) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, monoid) - if not cls._initialized: # pragma: no cover - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return monoid - - def __init__(self, name, binaryop=None, identity=None, *, is_idempotent=False, anonymous=False): - super().__init__(name, anonymous=anonymous) - self._binaryop = binaryop - self._identity = identity - self._is_idempotent = is_idempotent - if binaryop is not None: - binaryop._monoid = self - if binaryop._is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserMonoid} - - def __reduce__(self): - if self._anonymous: - return (self.register_anonymous, (self._binaryop, self._identity, self.name)) - if (name := f"monoid.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self._binaryop, self._identity)) - - @property - def binaryop(self): - """The :class:`BinaryOp` associated with the Monoid.""" - if self._binaryop is not None: - return self._binaryop - # Must be builtin - return getattr(binary, self.name) - - @property - def identities(self): - """The per-dtype identity values for the Monoid.""" - return {dtype: val.identity for dtype, val in self._typed_ops.items()} - - @property - def is_idempotent(self): - """True if ``monoid(x, x) == x`` for any x.""" - return self._is_idempotent - - @property - def _is_udt(self): - return self._binaryop is not None and self._binaryop._is_udt - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - super()._initialize() - lor = monoid.lor._typed_ops[BOOL] - land = monoid.land._typed_ops[BOOL] - for cur_op, typed_op in [ - (monoid.max, lor), - (monoid.min, land), - # (monoid.plus, lor), # two choices: lor, or plus[int] - (monoid.times, land), - ]: - if BOOL not in cur_op.types: # pragma: no branch (safety) - cur_op.types[BOOL] = BOOL - cur_op.coercions[BOOL] = BOOL - cur_op._typed_ops[BOOL] = typed_op - - for cur_op in [monoid.lor, monoid.land, monoid.lxnor, monoid.lxor]: - bool_op = cur_op._typed_ops[BOOL] - for dtype in [ - FP32, - FP64, - INT8, - INT16, - INT32, - INT64, - UINT8, - UINT16, - UINT32, - UINT64, - ]: - if dtype in cur_op.types: # pragma: no cover (safety) - continue - cur_op.types[dtype] = BOOL - cur_op.coercions[dtype] = BOOL - cur_op._typed_ops[dtype] = bool_op - - # Builtin monoids that are idempotent; i.e., `op(x, x) == x` for any x - for name in ["any", "band", "bor", "land", "lor", "max", "min"]: - getattr(monoid, name)._is_idempotent = True - # Allow some functions to work on UDTs - any_ = monoid.any - any_._identity = 0 - any_._udt_types = {} - any_._udt_ops = {} - cls._initialized = True - - commutes_to = TypedBuiltinMonoid.commutes_to - __call__ = TypedBuiltinMonoid.__call__ - - -class Semiring(OpBase): - """Combination of a :class:`Monoid` and a :class:`BinaryOp`. - - Semirings are most commonly used for performing matrix multiplication, - with the BinaryOp taking the place of the standard multiplication operator - and the Monoid taking the place of the standard addition operator. - - Built-in and registered Semirings are located in the ``graphblas.semiring`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = "_monoid", "_binaryop" - _module = semiring - _modname = "semiring" - _typed_class = TypedBuiltinSemiring - _parse_config = { - "trim_from_front": 4, - "delete_exact": "SEMIRING", - "num_underscores": 2, - "re_exprs": [ - re.compile( - "^GrB_(PLUS|MIN|MAX)_(PLUS|TIMES|FIRST|SECOND|MIN|MAX)_SEMIRING" - "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(MIN|MAX|PLUS|TIMES|ANY)" - "_(FIRST|SECOND|PAIR|MIN|MAX|PLUS|MINUS|RMINUS|TIMES" - "|DIV|RDIV|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR" - "|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" - "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(PLUS|TIMES|ANY)_(FIRST|SECOND|PAIR|PLUS|MINUS|TIMES|DIV|RDIV|RMINUS)" - "_(FC32|FC64)$" - ), - re.compile( - "^GxB_(BOR|BAND|BXOR|BXNOR)_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)$" - ), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)_(LOR|LAND)_SEMIRING_BOOL$"), - re.compile( - "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(EQ|NE|GT|LT|GE|LE)" - "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(FIRST|SECOND|PAIR|LOR|LAND|LXOR|EQ|GT|LT|GE|LE)_BOOL$" - ), - ], - } - - @classmethod - def _build(cls, name, monoid, binaryop, *, anonymous=False): - if type(monoid) is not Monoid: - raise TypeError(f"monoid must be a Monoid, not {type(monoid)}") - if type(binaryop) is not BinaryOp: - raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") - if name is None: - name = f"{monoid.name}_{binaryop.name}".replace(".", "_") - new_type_obj = cls(name, monoid, binaryop, anonymous=anonymous) - if binaryop._is_udt: - return new_type_obj - for binary_in, binary_func in binaryop._typed_ops.items(): - binary_out = binary_func.return_type - # Unfortunately, we can't have user-defined monoids over bools yet - # because numba can't compile correctly. - if ( - binary_out not in monoid.types - # Are all coercions bad, or just to bool? - or monoid.coercions.get(binary_out, binary_out) != binary_out - ): - continue - new_semiring = ffi_new("GrB_Semiring*") - check_status_carg( - lib.GrB_Semiring_new(new_semiring, monoid[binary_out].gb_obj, binary_func.gb_obj), - "Semiring", - new_semiring, - ) - ret_type = monoid[binary_out].return_type - op = TypedUserSemiring( - new_type_obj, - name, - binary_in, - ret_type, - new_semiring[0], - monoid[binary_out], - binary_func, - ) - new_type_obj._add(op) - return new_type_obj - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: - dtype2 = dtype - dtypes = (dtype, dtype2) - if dtypes in self._udt_types: - return self._udt_ops[dtypes] - binaryop = self.binaryop._compile_udt(dtype, dtype2) - monoid = self.monoid[binaryop.return_type] - ret_type = monoid.return_type - new_semiring = ffi_new("GrB_Semiring*") - status = lib.GrB_Semiring_new(new_semiring, monoid.gb_obj, binaryop.gb_obj) - check_status_carg(status, "Semiring", new_semiring) - op = TypedUserSemiring( - new_semiring, - self.name, - dtype, - ret_type, - new_semiring[0], - monoid, - binaryop, - dtype2=dtype2, - ) - self._udt_types[dtypes] = dtype - self._udt_ops[dtypes] = op - return op - - @classmethod - def register_anonymous(cls, monoid, binaryop, name=None): - """Register a Semiring without registering it in the ``graphblas.semiring`` namespace. - - Because it is not registered in the namespace, the name is optional. - - Parameters - ---------- - monoid : Monoid - Builtin or registered monoid - binaryop : BinaryOp - Builtin or registered binary operator - name : str, optional - Name associated with the semiring - - Returns - ------- - Function handle - """ - if type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: - return ParameterizedSemiring(name, monoid, binaryop, anonymous=True) - return cls._build(name, monoid, binaryop, anonymous=True) - - @classmethod - def register_new(cls, name, monoid, binaryop, *, lazy=False): - """Register a Semiring. The name will be used to identify the Semiring in the - ``graphblas.semiring`` namespace. - - >>> gb.core.operator.Semiring.register_new("max_max", gb.monoid.max, gb.binary.max) - >>> dir(gb.semiring) - [..., 'max_max', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "monoid": monoid, "binaryop": binaryop}, - ) - elif type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: - semiring = ParameterizedSemiring(name, monoid, binaryop) - setattr(module, funcname, semiring) - else: - semiring = cls._build(name, monoid, binaryop) - setattr(module, funcname, semiring) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, semiring) - if not cls._initialized: - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return semiring - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - super()._initialize() - # Rename div to cdiv (truncate towards 0) - div_semirings = { - attr: val - for attr, val in vars(semiring).items() - if type(val) is Semiring and attr.endswith("_div") - } - for orig_name, orig in div_semirings.items(): - name = f"{orig_name[:-3]}cdiv" - cdiv_semiring = Semiring(name) - setattr(semiring, name, cdiv_semiring) - setattr(op, name, cdiv_semiring) - delattr(semiring, orig_name) - delattr(op, orig_name) - for dtype, ret_type in orig.types.items(): - orig_semiring = orig[dtype] - new_semiring = TypedBuiltinSemiring( - cdiv_semiring, - name, - dtype, - ret_type, - orig_semiring.gb_obj, - orig_semiring.gb_name, - ) - cdiv_semiring._add(new_semiring) - # Also add truediv (always floating point) and floordiv (truncate towards -inf) - for orig_name, orig in div_semirings.items(): - cls.register_new(f"{orig_name[:-3]}truediv", orig.monoid, binary.truediv, lazy=True) - cls.register_new(f"{orig_name[:-3]}rtruediv", orig.monoid, "rtruediv", lazy=True) - cls.register_new(f"{orig_name[:-3]}floordiv", orig.monoid, "floordiv", lazy=True) - cls.register_new(f"{orig_name[:-3]}rfloordiv", orig.monoid, "rfloordiv", lazy=True) - # For aggregators - cls.register_new("plus_pow", monoid.plus, binary.pow) - cls.register_new("plus_rpow", monoid.plus, "rpow", lazy=True) - cls.register_new("plus_absfirst", monoid.plus, "absfirst", lazy=True) - cls.register_new("max_absfirst", monoid.max, "absfirst", lazy=True) - cls.register_new("plus_abssecond", monoid.plus, "abssecond", lazy=True) - cls.register_new("max_abssecond", monoid.max, "abssecond", lazy=True) - - # Update type information with sane coercion - for lname in ["any", "eq", "land", "lor", "lxnor", "lxor"]: - target_name = f"{lname}_ne" - source_name = f"{lname}_lxor" - if not _hasop(semiring, target_name): - continue - target_op = getattr(semiring, target_name) - if BOOL not in target_op.types: # pragma: no branch (safety) - source_op = getattr(semiring, source_name) - typed_op = source_op._typed_ops[BOOL] - target_op.types[BOOL] = BOOL - target_op._typed_ops[BOOL] = typed_op - target_op.coercions[dtype] = BOOL - - position_dtypes = [ - BOOL, - FP32, - FP64, - INT8, - INT16, - UINT8, - UINT16, - UINT32, - UINT64, - ] - notbool_dtypes = [ - FP32, - FP64, - INT8, - INT16, - INT32, - INT64, - UINT8, - UINT16, - UINT32, - UINT64, - ] - if _supports_complex: - position_dtypes.extend([FC32, FC64]) - notbool_dtypes.extend([FC32, FC64]) - for lnames, rnames, *types in [ - # fmt: off - ( - ("any", "max", "min", "plus", "times"), - ( - "firsti", "firsti1", "firstj", "firstj1", - "secondi", "secondi1", "secondj", "secondj1", - ), - ( - position_dtypes, - INT64, - ), - ), - ( - ("eq", "land", "lor", "lxnor", "lxor"), - ("first", "pair", "second"), - # TODO: check if FC coercion works here - ( - notbool_dtypes, - BOOL, - ), - ), - ( - ("band", "bor", "bxnor", "bxor"), - ("band", "bor", "bxnor", "bxor"), - ([INT8], UINT16), - ([INT16], UINT32), - ([INT32], UINT64), - ([INT64], UINT64), - ), - ( - ("any", "eq", "land", "lor", "lxnor", "lxor"), - ("eq", "land", "lor", "lxnor", "lxor", "ne"), - ( - ( - FP32, FP64, INT8, INT16, INT32, INT64, - UINT8, UINT16, UINT32, UINT64, - ), - BOOL, - ), - ), - # fmt: on - ]: - for left, right in itertools.product(lnames, rnames): - name = f"{left}_{right}" - if not _hasop(semiring, name): - continue - if name in _SS_OPERATORS: - cur_op = semiring._deprecated[name] - else: - cur_op = getattr(semiring, name) - for input_types, target_type in types: - typed_op = cur_op._typed_ops[target_type] - output_type = cur_op.types[target_type] - for dtype in input_types: - if dtype not in cur_op.types: - cur_op.types[dtype] = output_type - cur_op._typed_ops[dtype] = typed_op - cur_op.coercions[dtype] = target_type - - # Handle a few boolean cases - for opname, targetname in [ - ("max_first", "lor_first"), - ("max_second", "lor_second"), - ("max_land", "lor_land"), - ("max_lor", "lor_lor"), - ("max_lxor", "lor_lxor"), - ("min_first", "land_first"), - ("min_second", "land_second"), - ("min_land", "land_land"), - ("min_lor", "land_lor"), - ("min_lxor", "land_lxor"), - ]: - cur_op = getattr(semiring, opname) - target = getattr(semiring, targetname) - if BOOL in cur_op.types or BOOL not in target.types: # pragma: no cover (safety) - continue - cur_op.types[BOOL] = target.types[BOOL] - cur_op._typed_ops[BOOL] = target._typed_ops[BOOL] - cur_op.coercions[BOOL] = BOOL - cls._initialized = True - - def __init__(self, name, monoid=None, binaryop=None, *, anonymous=False): - super().__init__(name, anonymous=anonymous) - self._monoid = monoid - self._binaryop = binaryop - try: - if self.binaryop._udt_types is not None: - self._udt_types = {} # {(dtype, dtype): DataType} - self._udt_ops = {} # {(dtype, dtype): TypedUserSemiring} - except AttributeError: - # `*_div` semirings raise here, but don't need `_udt_types` - pass - - def __reduce__(self): - if self._anonymous: - return (self.register_anonymous, (self._monoid, self._binaryop, self.name)) - if (name := f"semiring.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self._monoid, self._binaryop)) - - @property - def binaryop(self): - """The :class:`BinaryOp` associated with the Semiring.""" - if self._binaryop is not None: - return self._binaryop - # Must be builtin - name = self.name.split("_")[1] - if name in _SS_OPERATORS: - return binary._deprecated[name] - return getattr(binary, name) - - @property - def monoid(self): - """The :class:`Monoid` associated with the Semiring.""" - if self._monoid is not None: - return self._monoid - # Must be builtin - return getattr(monoid, self.name.split("_")[0].split(".")[-1]) - - @property - def is_positional(self): - return self.binaryop.is_positional - - @property - def _is_udt(self): - return self._binaryop is not None and self._binaryop._is_udt - - @property - def _custom_dtype(self): - return self.binaryop._custom_dtype - - commutes_to = TypedBuiltinSemiring.commutes_to - is_commutative = TypedBuiltinSemiring.is_commutative - __call__ = TypedBuiltinSemiring.__call__ - - -def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scalar=False, kind=None): - if isinstance(op, OpBase): - # UDTs always get compiled - if op._is_udt: - return op._compile_udt(dtype, dtype2) - # Single dtype is simple lookup - if dtype2 is None: - return op[dtype] - # Handle special cases such as first and second (may have UDTs) - if op._custom_dtype is not None and (rv := op._custom_dtype(op, dtype, dtype2)) is not None: - return rv - # Generic case: try to unify the two dtypes - try: - return op[ - unify(dtype, dtype2, is_left_scalar=is_left_scalar, is_right_scalar=is_right_scalar) - ] - except (TypeError, AttributeError): - # Failure to unify implies a dtype is UDT; some builtin operators can handle UDTs - if op.is_positional: - return op[UINT64] - if op._udt_types is None: - raise - return op._compile_udt(dtype, dtype2) - if isinstance(op, ParameterizedUdf): - op = op() # Use default parameters of parameterized UDFs - return get_typed_op( - op, - dtype, - dtype2, - is_left_scalar=is_left_scalar, - is_right_scalar=is_right_scalar, - kind=kind, - ) - if isinstance(op, TypedOpBase): - return op - - from .agg import Aggregator, TypedAggregator - - if isinstance(op, Aggregator): - return op[dtype] - if isinstance(op, TypedAggregator): - return op - if isinstance(op, str): - if kind == "unary": - op = unary_from_string(op) - elif kind == "select": - op = select_from_string(op) - elif kind == "binary": - op = binary_from_string(op) - elif kind == "monoid": - op = monoid_from_string(op) - elif kind == "semiring": - op = semiring_from_string(op) - elif kind == "binary|aggregator": - try: - op = binary_from_string(op) - except ValueError: - try: - op = aggregator_from_string(op) - except ValueError: - raise ValueError( - f"Unknown binary or aggregator string: {op!r}. Example usage: '+[int]'" - ) from None - - else: - raise ValueError( - f"Unable to get op from string {op!r}. `kind=` argument must be provided as " - '"unary", "binary", "monoid", "semiring", "indexunary", "select", ' - 'or "binary|aggregator".' - ) - return get_typed_op( - op, - dtype, - dtype2, - is_left_scalar=is_left_scalar, - is_right_scalar=is_right_scalar, - kind=kind, - ) - if isinstance(op, FunctionType): - if kind == "unary": - op = UnaryOp.register_anonymous(op, is_udt=True) - return op._compile_udt(dtype, dtype2) - if kind.startswith("binary"): - op = BinaryOp.register_anonymous(op, is_udt=True) - return op._compile_udt(dtype, dtype2) - if isinstance(op, BuiltinFunctionType) and op in _builtin_to_op: - return get_typed_op( - _builtin_to_op[op], - dtype, - dtype2, - is_left_scalar=is_left_scalar, - is_right_scalar=is_right_scalar, - kind=kind, - ) - raise TypeError(f"Unable to get typed operator from object with type {type(op)}") - - -def find_opclass(gb_op): - if isinstance(gb_op, OpBase): - opclass = type(gb_op).__name__ - elif isinstance(gb_op, TypedOpBase): - opclass = gb_op.opclass - elif isinstance(gb_op, ParameterizedUdf): - gb_op = gb_op() # Use default parameters of parameterized UDFs - gb_op, opclass = find_opclass(gb_op) - elif isinstance(gb_op, BuiltinFunctionType) and gb_op in _builtin_to_op: - gb_op, opclass = find_opclass(_builtin_to_op[gb_op]) - else: - opclass = UNKNOWN_OPCLASS - return gb_op, opclass - - -def get_semiring(monoid, binaryop, name=None): - """Get or create a Semiring object from a monoid and binaryop. - - If either are typed, then the returned semiring will also be typed. - - See Also - -------- - semiring.register_anonymous - semiring.register_new - semiring.from_string - """ - monoid, opclass = find_opclass(monoid) - switched = False - if opclass == "BinaryOp" and monoid.monoid is not None: - switched = True - monoid = monoid.monoid - elif opclass != "Monoid": - raise TypeError(f"Expected a Monoid for the monoid argument. Got type: {type(monoid)}") - binaryop, opclass = find_opclass(binaryop) - if opclass == "Monoid": - if switched: - raise TypeError( - "Got a BinaryOp for the monoid argument and a Monoid for the binaryop argument. " - "Are the arguments switched? Hint: you can do `mymonoid.binaryop` to get the " - "binaryop from a monoid." - ) - binaryop = binaryop.binaryop - elif opclass != "BinaryOp": - raise TypeError( - f"Expected a BinaryOp for the binaryop argument. Got type: {type(binaryop)}" - ) - if isinstance(monoid, Monoid): - monoid_type = None - else: - monoid_type = monoid.type - monoid = monoid.parent - if isinstance(binaryop, BinaryOp): - binary_type = None - else: - binary_type = binaryop.type - binaryop = binaryop.parent - if monoid._anonymous or binaryop._anonymous: - rv = Semiring.register_anonymous(monoid, binaryop, name=name) - else: - *monoid_prefix, monoid_name = monoid.name.rsplit(".", 1) - *binary_prefix, binary_name = binaryop.name.rsplit(".", 1) - if ( - monoid_prefix - and binary_prefix - and monoid_prefix == binary_prefix - or config.get("mapnumpy") - and ( - monoid_prefix == ["numpy"] - and not binary_prefix - or binary_prefix == ["numpy"] - and not monoid_prefix - ) - or backend == "suitesparse" - and binary_name in _SS_OPERATORS - ): - canonical_name = ( - ".".join(monoid_prefix or binary_prefix) + f".{monoid_name}_{binary_name}" - ) - else: - canonical_name = f"{monoid.name}_{binaryop.name}".replace(".", "_") - if name is None: - name = canonical_name - - module, funcname = Semiring._remove_nesting(canonical_name, strict=False) - rv = ( - getattr(module, funcname) - if funcname in module.__dict__ or funcname in module._delayed - else getattr(module, "_deprecated", {}).get(funcname) - ) - if rv is None and name != canonical_name: - module, funcname = Semiring._remove_nesting(name, strict=False) - rv = ( - getattr(module, funcname) - if funcname in module.__dict__ or funcname in module._delayed - else getattr(module, "_deprecated", {}).get(funcname) - ) - if rv is None: - rv = Semiring.register_new(canonical_name, monoid, binaryop) - elif rv.monoid is not monoid or rv.binaryop is not binaryop: # pragma: no cover - # It's not the object we expect (can this happen?) - rv = Semiring.register_anonymous(monoid, binaryop, name=name) - if name != canonical_name: - module, funcname = Semiring._remove_nesting(name, strict=False) - if not _hasop(module, funcname): # pragma: no branch (safety) - setattr(module, funcname, rv) - - if binary_type is not None: - return rv[binary_type] - if monoid_type is not None: - return rv[monoid_type] - return rv - - -# Now initialize all the things! -try: - UnaryOp._initialize() - IndexUnaryOp._initialize() - SelectOp._initialize() - BinaryOp._initialize() - Monoid._initialize() - Semiring._initialize() -except Exception: # pragma: no cover (debug) - # Exceptions here can often get ignored by Python - import traceback - - traceback.print_exc() - raise - -unary.register_new = UnaryOp.register_new -unary.register_anonymous = UnaryOp.register_anonymous -indexunary.register_new = IndexUnaryOp.register_new -indexunary.register_anonymous = IndexUnaryOp.register_anonymous -select.register_new = SelectOp.register_new -select.register_anonymous = SelectOp.register_anonymous -binary.register_new = BinaryOp.register_new -binary.register_anonymous = BinaryOp.register_anonymous -monoid.register_new = Monoid.register_new -monoid.register_anonymous = Monoid.register_anonymous -semiring.register_new = Semiring.register_new -semiring.register_anonymous = Semiring.register_anonymous -semiring.get_semiring = get_semiring - -select._binary_to_select.update( - { - binary.eq: select.valueeq, - binary.ne: select.valuene, - binary.le: select.valuele, - binary.lt: select.valuelt, - binary.ge: select.valuege, - binary.gt: select.valuegt, - binary.iseq: select.valueeq, - binary.isne: select.valuene, - binary.isle: select.valuele, - binary.islt: select.valuelt, - binary.isge: select.valuege, - binary.isgt: select.valuegt, - } -) - -_builtin_to_op = { - abs: unary.abs, - max: binary.max, - min: binary.min, - # Maybe someday: all, any, pow, sum -} - -_str_to_unary = { - "-": unary.ainv, - "~": unary.lnot, -} -_str_to_select = { - "<": select.valuelt, - ">": select.valuegt, - "<=": select.valuele, - ">=": select.valuege, - "!=": select.valuene, - "==": select.valueeq, - "col<=": select.colle, - "col>": select.colgt, - "row<=": select.rowle, - "row>": select.rowgt, - "index<=": select.indexle, - "index>": select.indexgt, -} -_str_to_binary = { - "<": binary.lt, - ">": binary.gt, - "<=": binary.le, - ">=": binary.ge, - "!=": binary.ne, - "==": binary.eq, - "+": binary.plus, - "-": binary.minus, - "*": binary.times, - "/": binary.truediv, - "//": "floordiv", - "%": "numpy.mod", - "**": binary.pow, - "&": binary.land, - "|": binary.lor, - "^": binary.lxor, -} -_str_to_monoid = { - "==": monoid.eq, - "+": monoid.plus, - "*": monoid.times, - "&": monoid.land, - "|": monoid.lor, - "^": monoid.lxor, -} - - -def _from_string(string, module, mapping, example): - s = string.lower().strip() - base, *dtype = s.split("[") - if len(dtype) > 1: - name = module.__name__.split(".")[-1] - raise ValueError( - f'Bad {name} string: {string!r}. Contains too many "[". Example usage: {example!r}' - ) - if dtype: - dtype = dtype[0] - if not dtype.endswith("]"): - name = module.__name__.split(".")[-1] - raise ValueError( - f'Bad {name} string: {string!r}. Datatype specification does not end with "]". ' - f"Example usage: {example!r}" - ) - dtype = lookup_dtype(dtype[:-1]) - if "]" in base: - name = module.__name__.split(".")[-1] - raise ValueError( - f'Bad {name} string: {string!r}. "]" not matched by "[". Example usage: {example!r}' - ) - if base in mapping: - op = mapping[base] - if type(op) is str: - op = mapping[base] = module.from_string(op) - elif hasattr(module, base): - op = getattr(module, base) - elif hasattr(module, "numpy") and hasattr(module.numpy, base): - op = getattr(module.numpy, base) - else: - *paths, attr = base.split(".") - op = None - cur = module - for path in paths: - cur = getattr(cur, path, None) - if not isinstance(cur, (OpPath, ModuleType)): - cur = None - break - op = getattr(cur, attr, None) - if op is None: - name = module.__name__.split(".")[-1] - raise ValueError(f"Unknown {name} string: {string!r}. Example usage: {example!r}") - if dtype: - op = op[dtype] - return op - - -def unary_from_string(string): - return _from_string(string, unary, _str_to_unary, "abs[int]") - - -def indexunary_from_string(string): - # "select" is a variant of IndexUnary, so the string abbreviations in - # _str_to_select are appropriate to reuse here - return _from_string(string, indexunary, _str_to_select, "row_index") - - -def select_from_string(string): - return _from_string(string, select, _str_to_select, "tril") - - -def binary_from_string(string): - return _from_string(string, binary, _str_to_binary, "+[int]") - - -def monoid_from_string(string): - return _from_string(string, monoid, _str_to_monoid, "+[int]") - - -def semiring_from_string(string): - split = string.split(".") - if len(split) == 1: - try: - return _from_string(string, semiring, {}, "min.+[int]") - except Exception: - pass - if len(split) != 2: - raise ValueError( - f"Bad semiring string: {string!r}. " - 'The monoid and binaryop should be separated by exactly one period, ".". ' - "Example usage: min.+[int]" - ) - cur_monoid = monoid_from_string(split[0]) - cur_binary = binary_from_string(split[1]) - return get_semiring(cur_monoid, cur_binary) - - -def op_from_string(string): - for func in [ - # Note: order matters here - unary_from_string, - binary_from_string, - monoid_from_string, - semiring_from_string, - indexunary_from_string, - select_from_string, - aggregator_from_string, - ]: - try: - return func(string) - except Exception: - pass - raise ValueError(f"Unknown op string: {string!r}. Example usage: 'abs[int]'") - - -unary.from_string = unary_from_string -indexunary.from_string = indexunary_from_string -select.from_string = select_from_string -binary.from_string = binary_from_string -monoid.from_string = monoid_from_string -semiring.from_string = semiring_from_string -op.from_string = op_from_string - -_str_to_agg = { - "+": "sum", - "*": "prod", - "&": "all", - "|": "any", -} - - -def aggregator_from_string(string): - return _from_string(string, agg, _str_to_agg, "sum[int]") - - -from .. import agg # noqa: E402 isort:skip - -agg.from_string = aggregator_from_string diff --git a/graphblas/core/operator/__init__.py b/graphblas/core/operator/__init__.py new file mode 100644 index 000000000..509e84a04 --- /dev/null +++ b/graphblas/core/operator/__init__.py @@ -0,0 +1,21 @@ +from .base import UNKNOWN_OPCLASS, OpBase, OpPath, ParameterizedUdf, TypedOpBase, find_opclass +from .binary import BinaryOp, ParameterizedBinaryOp +from .indexunary import IndexUnaryOp, ParameterizedIndexUnaryOp +from .monoid import Monoid, ParameterizedMonoid +from .select import ParameterizedSelectOp, SelectOp +from .semiring import ParameterizedSemiring, Semiring +from .unary import ParameterizedUnaryOp, UnaryOp +from .utils import ( + aggregator_from_string, + binary_from_string, + get_semiring, + get_typed_op, + indexunary_from_string, + monoid_from_string, + op_from_string, + select_from_string, + semiring_from_string, + unary_from_string, +) + +from .agg import Aggregator # isort:skip diff --git a/graphblas/core/operator/agg.py b/graphblas/core/operator/agg.py new file mode 100644 index 000000000..036149b1f --- /dev/null +++ b/graphblas/core/operator/agg.py @@ -0,0 +1,680 @@ +from functools import partial +from operator import getitem + +import numpy as np + +from ... import agg, backend, binary, monoid, semiring, unary +from ...dtypes import INT64, lookup_dtype +from ..utils import output_type + + +def _get_types(ops, initdtype): + """Determine the input and output types of an aggregator based on a list of ops.""" + if initdtype is None: + prev = dict(ops[0].types) + else: + op = ops[0] + prev = {key: get_typed_op(op, key, initdtype).return_type for key in op.types} + for op in ops[1:]: + cur = {} + types = op.types + for in_type, out_type in prev.items(): + if out_type not in types: # pragma: no cover (safety) + continue + cur[in_type] = types[out_type] + prev = cur + return prev + + +class Aggregator: + opclass = "Aggregator" + + def __init__( + self, + name, + *, + initval=None, + monoid=None, + semiring=None, + switch=False, + semiring2=None, + finalize=None, + composite=None, + custom=None, + types=None, + any_dtype=None, + ): + self.name = name + self._initval_orig = initval + self._initval = False if initval is None else initval + self._initdtype = lookup_dtype(type(self._initval), self._initval) + self._monoid = monoid + self._semiring = semiring + self._semiring2 = semiring2 + self._switch = switch + self._finalize = finalize + self._composite = composite + self._custom = custom + if types is None: + if monoid is not None: + types = [monoid] + elif semiring is not None: + types = [semiring, semiring2] + if finalize is not None: + types.append(finalize) + initval = self._initval + else: # pragma: no cover (sanity) + raise TypeError("types must be provided for composite and custom aggregators") + self._types_orig = types + self._types = None + self._typed_ops = {} + self._any_dtype = any_dtype + + @property + def types(self): + if self._types is None: + if type(self._semiring) is str: + self._semiring = semiring.from_string(self._semiring) + if type(self._types_orig[0]) is str: # pragma: no branch + self._types_orig[0] = semiring.from_string(self._types_orig[0]) + self._types = _get_types( + self._types_orig, None if self._initval_orig is None else self._initdtype + ) + return self._types + + def __getitem__(self, dtype): + dtype = lookup_dtype(dtype) + if not self._any_dtype and dtype not in self.types: + raise KeyError(f"{self.name} does not work with {dtype}") + if dtype not in self._typed_ops: + self._typed_ops[dtype] = TypedAggregator(self, dtype) + return self._typed_ops[dtype] + + def __contains__(self, dtype): + dtype = lookup_dtype(dtype) + return self._any_dtype or dtype in self.types + + def __repr__(self): + if self.name in agg._deprecated: + return f"agg.ss.{self.name}" + return f"agg.{self.name}" + + def __reduce__(self): + if self.name in agg._deprecated: + return f"agg.ss.{self.name}" + return f"agg.{self.name}" + + def __call__(self, val, *, rowwise=False, columnwise=False): + # Should we expose `allow_empty=` keyword when reducing to Scalar? + from ..matrix import Matrix, TransposedMatrix + from ..vector import Vector + + typ = output_type(val) + if typ is Vector: + if rowwise or columnwise: + raise ValueError( + "rowwise and columnwise arguments should not be used with Vector input" + ) + return val.reduce(self) + if typ in {Matrix, TransposedMatrix}: + if rowwise: + if columnwise: + raise ValueError("rowwise and columnwise arguments cannot both be True") + return val.reduce_rowwise(self) + if columnwise: + return val.reduce_columnwise(self) + return val.reduce_scalar(self) + raise TypeError( + f"Bad type when calling {self!r}.\n" + " - Expected type: Vector, Matrix, TransposedMatrix.\n" + f" - Got: {type(val)}.\n" + "Calling an Aggregator is syntactic sugar for calling reduce methods. " + f"For example, `A.reduce_scalar({self!r})` is the same as `{self!r}(A)`." + ) + + +class TypedAggregator: + opclass = "Aggregator" + + def __init__(self, agg, dtype): + self.name = agg.name + self.parent = agg + self.type = dtype + if dtype in agg.types: + self.return_type = agg.types[dtype] + elif agg._any_dtype is True: + self.return_type = dtype + else: + self.return_type = agg._any_dtype + + def __repr__(self): + return f"agg.{self.name}[{self.type}]" + + def _new(self, updater, expr, *, in_composite=False): + agg = self.parent + if agg._monoid is not None: + x = expr.args[0] + method = getattr(x, expr.method_name) + if expr.output_type.__name__ == "Scalar": + expr = method(agg._monoid[self.type], allow_empty=not expr._is_cscalar) + else: + expr = method(agg._monoid[self.type]) + updater << expr + if in_composite: + parent = updater.parent + if not parent._is_scalar: + return parent + return parent._as_vector() + return + + opts = updater.opts + if agg._composite is not None: + # Masks are applied throughout the aggregation, including composite aggregations. + # Aggregations done while `in_composite is True` should return the updater parent + # if the result is not a Scalar. If the result is a Scalar, then there can be no + # output mask, and a Vector of size 1 should be returned instead. + results = [] + mask = updater.kwargs.get("mask") + for cur_agg in agg._composite: + cur_agg = cur_agg[self.type] # Hopefully works well enough + arg = expr.construct_output(cur_agg.return_type) + results.append(cur_agg._new(arg(mask=mask, **opts), expr, in_composite=True)) + final_expr = agg._finalize(*results, opts) + if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": + updater << final_expr + elif expr.cfunc_name.startswith("GrB_Vector_reduce") or expr.cfunc_name.startswith( + "GrB_Matrix_reduce" + ): + final = final_expr.new(**opts) + updater << final[0] + else: + raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") + if in_composite: + parent = updater.parent + if not parent._is_scalar: + return parent + return parent._as_vector() + return + + if agg._custom is not None: + return agg._custom(self, updater, expr, opts, in_composite=in_composite) + + semiring = get_typed_op(agg._semiring, self.type, agg._initdtype) + if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": + # Matrix -> Vector + A = expr.args[0] + orig_updater = updater + if agg._finalize is not None: + step1 = expr.construct_output(semiring.return_type) + updater = step1(mask=updater.kwargs.get("mask"), **opts) + if expr.method_name == "reduce_columnwise": + A = A.T + size = A._ncols + init = expr._new_vector(agg._initdtype, size=size) + init(**opts)[...] = agg._initval # O(1) dense vector in SuiteSparse 5 + if agg._switch: + updater << semiring(init @ A.T) + else: + updater << semiring(A @ init) + if agg._finalize is not None: + orig_updater << agg._finalize[semiring.return_type](step1) + if in_composite: + return orig_updater.parent + elif expr.cfunc_name.startswith("GrB_Vector_reduce"): + # Vector -> Scalar + v = expr.args[0] + step1 = expr._new_vector(semiring.return_type, size=1) + init = expr._new_matrix(agg._initdtype, nrows=v._size, ncols=1) + init(**opts)[...] = agg._initval # O(1) dense column vector in SuiteSparse 5 + if agg._switch: + step1(**opts) << semiring(init.T @ v) + else: + step1(**opts) << semiring(v @ init) + if agg._finalize is not None: + finalize = agg._finalize[semiring.return_type] + if step1.dtype == finalize.return_type: + step1(**opts) << finalize(step1) + else: + step1 = finalize(step1).new(finalize.return_type, **opts) + if in_composite: + return step1 + updater << step1[0] + elif expr.cfunc_name.startswith("GrB_Matrix_reduce"): + # Matrix -> Scalar + A = expr.args[0] + # We need to compute in two steps: Matrix -> Vector -> Scalar. + # This has not been benchmarked or optimized. + # We may be able to intelligently choose the faster path. + init1 = expr._new_vector(agg._initdtype, size=A._ncols) + init1(**opts)[...] = agg._initval # O(1) dense vector in SuiteSparse 5 + step1 = expr._new_vector(semiring.return_type, size=A._nrows) + if agg._switch: + step1(**opts) << semiring(init1 @ A.T) + else: + step1(**opts) << semiring(A @ init1) + init2 = expr._new_matrix(agg._initdtype, nrows=A._nrows, ncols=1) + init2(**opts)[...] = agg._initval # O(1) dense vector in SuiteSparse 5 + semiring2 = agg._semiring2[semiring.return_type] + step2 = expr._new_vector(semiring2.return_type, size=1) + step2(**opts) << semiring2(step1 @ init2) + if agg._finalize is not None: + finalize = agg._finalize[semiring2.return_type] + if step2.dtype == finalize.return_type: + step2 << finalize(step2) + else: + step2 = finalize(step2).new(finalize.return_type, **opts) + if in_composite: + return step2 + updater << step2[0] + else: + raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") + + def __reduce__(self): + return (getitem, (self.parent, self.type)) + + __call__ = Aggregator.__call__ + + +# Monoid-only +agg.sum = Aggregator("sum", monoid=monoid.plus) +agg.prod = Aggregator("prod", monoid=monoid.times) +agg.all = Aggregator("all", monoid=monoid.land) +agg.any = Aggregator("any", monoid=monoid.lor) +agg.min = Aggregator("min", monoid=monoid.min) +agg.max = Aggregator("max", monoid=monoid.max) +agg.any_value = Aggregator("any_value", monoid=monoid.any, any_dtype=True) +agg.bitwise_all = Aggregator("bitwise_all", monoid=monoid.band) +agg.bitwise_any = Aggregator("bitwise_any", monoid=monoid.bor) +# Other monoids: bxnor bxor eq lxnor lxor + +# Semiring-only +agg.count = Aggregator( + "count", semiring=semiring.plus_pair, semiring2=semiring.plus_first, any_dtype=INT64 +) +agg.count_nonzero = Aggregator( + "count_nonzero", semiring=semiring.plus_isne, semiring2=semiring.plus_first +) +agg.count_zero = Aggregator( + "count_zero", semiring=semiring.plus_iseq, semiring2=semiring.plus_first +) +agg.sum_of_squares = Aggregator( + "sum_of_squares", initval=2, semiring=semiring.plus_pow, semiring2=semiring.plus_first +) +agg.sum_of_inverses = Aggregator( + "sum_of_inverses", + initval=-1.0, + semiring=semiring.plus_pow, + semiring2=semiring.plus_first, +) +agg.exists = Aggregator( + "exists", semiring=semiring.any_pair, semiring2=semiring.any_pair, any_dtype=INT64 +) + +# Semiring and finalize +agg.hypot = Aggregator( + "hypot", + initval=2, + semiring=semiring.plus_pow, + semiring2=semiring.plus_first, + finalize=unary.sqrt, +) +agg.logaddexp = Aggregator( + "logaddexp", + initval=np.e, + semiring=semiring.plus_pow, + switch=True, + semiring2=semiring.plus_first, + finalize=unary.log, +) +agg.logaddexp2 = Aggregator( + "logaddexp2", + initval=2, + semiring=semiring.plus_pow, + switch=True, + semiring2=semiring.plus_first, + finalize=unary.log2, +) +# Alternatives +# logaddexp = Aggregator('logaddexp', monoid=semiring.numpy.logaddexp) +# logaddexp2 = Aggregator('logaddexp2', monoid=semiring.numpy.logaddexp2) +# hypot as monoid doesn't work if single negative element! +# hypot = Aggregator('hypot', monoid=semiring.numpy.hypot) + +agg.L0norm = agg.count_nonzero +agg.L1norm = Aggregator("L1norm", semiring="plus_absfirst", semiring2=semiring.plus_first) +agg.L2norm = agg.hypot +agg.Linfnorm = Aggregator("Linfnorm", semiring="max_absfirst", semiring2=semiring.max_first) + + +# Composite +def _mean_finalize(c, x, opts): + return binary.truediv(x & c) + + +def _ptp_finalize(max, min, opts): + return binary.minus(max & min) + + +def _varp_finalize(c, x, x2, opts): + # / n - ( / n)**2 + left = binary.truediv(x2 & c).new(**opts) + right = binary.truediv(x & c).new(**opts) + right(**opts) << binary.pow(right, 2) + return binary.minus(left & right) + + +def _vars_finalize(c, x, x2, opts): + # / (n-1) - **2 / (n * (n-1)) + x(**opts) << binary.pow(x, 2) + right = binary.truediv(x & c).new(**opts) + c(**opts) << binary.minus(c, 1) + right(**opts) << binary.truediv(right & c) + left = binary.truediv(x2 & c).new(**opts) + return binary.minus(left & right) + + +def _stdp_finalize(c, x, x2, opts): + val = _varp_finalize(c, x, x2, opts).new(**opts) + return unary.sqrt(val) + + +def _stds_finalize(c, x, x2, opts): + val = _vars_finalize(c, x, x2, opts).new(**opts) + return unary.sqrt(val) + + +def _geometric_mean_finalize(c, x, opts): + right = unary.minv["FP64"](c).new(**opts) + return binary.pow(x & right) + + +def _harmonic_mean_finalize(c, x, opts): + return binary.truediv(c & x) + + +def _root_mean_square_finalize(c, x2, opts): + val = binary.truediv(x2 & c).new(**opts) + return unary.sqrt(val) + + +agg.mean = Aggregator( + "mean", + composite=[agg.count, agg.sum], + finalize=_mean_finalize, + types=[binary.truediv], +) +agg.peak_to_peak = Aggregator( + "peak_to_peak", + composite=[agg.max, agg.min], + finalize=_ptp_finalize, + types=[monoid.min], +) +agg.varp = Aggregator( + "varp", + composite=[agg.count, agg.sum, agg.sum_of_squares], + finalize=_varp_finalize, + types=[binary.truediv], +) +agg.vars = Aggregator( + "vars", + composite=[agg.count, agg.sum, agg.sum_of_squares], + finalize=_vars_finalize, + types=[binary.truediv], +) +agg.stdp = Aggregator( + "stdp", + composite=[agg.count, agg.sum, agg.sum_of_squares], + finalize=_stdp_finalize, + types=[binary.truediv, unary.sqrt], +) +agg.stds = Aggregator( + "stds", + composite=[agg.count, agg.sum, agg.sum_of_squares], + finalize=_stds_finalize, + types=[binary.truediv, unary.sqrt], +) +agg.geometric_mean = Aggregator( + "geometric_mean", + composite=[agg.count, agg.prod], + finalize=_geometric_mean_finalize, + types=[binary.truediv], +) +agg.harmonic_mean = Aggregator( + "harmonic_mean", + composite=[agg.count, agg.sum_of_inverses], + finalize=_harmonic_mean_finalize, + types=[agg.sum_of_inverses, binary.truediv], +) +agg.root_mean_square = Aggregator( + "root_mean_square", + composite=[agg.count, agg.sum_of_squares], + finalize=_root_mean_square_finalize, + types=[binary.truediv, unary.sqrt], +) + + +# Special recipes +def _argminmaxij( + agg, + updater, + expr, + opts, + *, + in_composite, + monoid, + col_semiring, + row_semiring, +): + if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": + A = expr.args[0] + if expr.method_name == "reduce_rowwise": + step1 = A.reduce_rowwise(monoid).new(**opts) + + D = step1.diag() + + masked = semiring.any_eq(D @ A).new(**opts) + masked(mask=masked.V, replace=True, **opts) << masked # Could use select + init = expr._new_vector(bool, size=A._ncols) + init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 + updater << row_semiring(masked @ init) + if in_composite: + return updater.parent + else: + step1 = A.reduce_columnwise(monoid).new(**opts) + + D = step1.diag() + + masked = semiring.any_eq(A @ D).new(**opts) + masked(mask=masked.V, replace=True, **opts) << masked # Could use select + init = expr._new_vector(bool, size=A._nrows) + init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 + updater << col_semiring(init @ masked) + if in_composite: + return updater.parent + elif expr.cfunc_name.startswith("GrB_Vector_reduce"): + v = expr.args[0] + step1 = v.reduce(monoid, allow_empty=False).new(**opts) + masked = binary.eq(v, step1).new(**opts) + masked(mask=masked.V, replace=True, **opts) << masked # Could use select + init = expr._new_matrix(bool, nrows=v._size, ncols=1) + init(**opts)[...] = False # O(1) dense column vector in SuiteSparse 5 + step2 = col_semiring(masked @ init).new(**opts) + if in_composite: + return step2 + updater << step2[0] + else: + raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") + + +def _argminmax(agg, updater, expr, opts, *, in_composite, monoid): + if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": + if expr.method_name == "reduce_rowwise": + return _argminmaxij( + agg, + updater, + expr, + opts, + in_composite=in_composite, + monoid=monoid, + row_semiring=semiring._deprecated["min_firstj"], + col_semiring=semiring._deprecated["min_secondj"], + ) + return _argminmaxij( + agg, + updater, + expr, + opts, + in_composite=in_composite, + monoid=monoid, + row_semiring=semiring._deprecated["min_firsti"], + col_semiring=semiring._deprecated["min_secondi"], + ) + if expr.cfunc_name.startswith("GrB_Vector_reduce"): + return _argminmaxij( + agg, + updater, + expr, + opts, + in_composite=in_composite, + monoid=monoid, + row_semiring=semiring._deprecated["min_firsti"], + col_semiring=semiring._deprecated["min_secondi"], + ) + if expr.cfunc_name.startswith("GrB_Matrix_reduce"): + raise ValueError(f"Aggregator {agg.name} may not be used with Matrix.reduce_scalar.") + raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") + + +# These "do the right thing", but don't work with `reduce_scalar` +_argmin = Aggregator( + "argmin", + custom=partial(_argminmax, monoid=monoid.min), + types=[semiring._deprecated["min_firsti"]], +) +_argmax = Aggregator( + "argmax", + custom=partial(_argminmax, monoid=monoid.max), + types=[semiring._deprecated["min_firsti"]], +) + + +def _first_last(agg, updater, expr, opts, *, in_composite, semiring_): + if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": + A = expr.args[0] + if expr.method_name == "reduce_columnwise": + A = A.T + init = expr._new_vector(bool, size=A._ncols) + init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 + step1 = semiring_(A @ init).new(**opts) + Is, Js = step1.to_coo() + + Matrix_ = type(expr._new_matrix(bool)) + P = Matrix_.from_coo(Js, Is, 1, nrows=A._ncols, ncols=A._nrows) + mask = step1.diag() + result = semiring.any_first(A @ P).new(mask=mask.S, **opts).diag(**opts) + + updater << result + if in_composite: + return updater.parent + elif expr.cfunc_name.startswith("GrB_Vector_reduce"): + v = expr.args[0] + init = expr._new_matrix(bool, nrows=v._size, ncols=1) + init(**opts)[...] = False # O(1) dense matrix in SuiteSparse 5 + step1 = semiring_(v @ init).new(**opts) + index = step1[0].new().value + # `==` instead of `is` automatically triggers index.compute() in dask-graphblas: + if index == None: # noqa: E711 + index = 0 + if in_composite: + return v[[index]].new(**opts) + updater << v[index] + else: # GrB_Matrix_reduce + A = expr.args[0] + init1 = expr._new_matrix(bool, nrows=A._ncols, ncols=1) + init1(**opts)[...] = False # O(1) dense matrix in SuiteSparse 5 + step1 = semiring_(A @ init1).new(**opts) + init2 = expr._new_vector(bool, size=A._nrows) + init2(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 + step2 = semiring_(step1.T @ init2).new(**opts) + i = step2[0].new().value + # `==` instead of `is` automatically triggers i.compute() in dask-graphblas: + if i == None: # noqa: E711 + i = j = 0 + else: + j = step1[i, 0].new().value + if in_composite: + return A[i, [j]].new(**opts) + updater << A[i, j] + + +_first = Aggregator( + "first", + custom=partial(_first_last, semiring_=semiring._deprecated["min_secondi"]), + types=[binary.first], + any_dtype=True, +) +_last = Aggregator( + "last", + custom=partial(_first_last, semiring_=semiring._deprecated["max_secondi"]), + types=[binary.second], + any_dtype=True, +) + + +def _first_last_index(agg, updater, expr, opts, *, in_composite, semiring): + if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": + A = expr.args[0] + if expr.method_name == "reduce_columnwise": + A = A.T + init = expr._new_vector(bool, size=A._ncols) + init(**opts)[...] = False # O(1) dense vector in SuiteSparse 5 + expr = semiring(A @ init) + updater << expr + if in_composite: + return updater.parent + elif expr.cfunc_name.startswith("GrB_Vector_reduce"): + v = expr.args[0] + init = expr._new_matrix(bool, nrows=v._size, ncols=1) + init(**opts)[...] = False # O(1) dense matrix in SuiteSparse 5 + step1 = semiring(v @ init).new(**opts) + if in_composite: + return step1 + updater << step1[0] + elif expr.cfunc_name.startswith("GrB_Matrix_reduce"): + raise ValueError(f"Aggregator {agg.name} may not be used with Matrix.reduce_scalar.") + else: + raise NotImplementedError(f"{agg.name} with {expr.cfunc_name}") + + +_first_index = Aggregator( + "first_index", + custom=partial(_first_last_index, semiring=semiring._deprecated["min_secondi"]), + types=[semiring._deprecated["min_secondi"]], + any_dtype=INT64, +) +_last_index = Aggregator( + "last_index", + custom=partial(_first_last_index, semiring=semiring._deprecated["max_secondi"]), + types=[semiring._deprecated["min_secondi"]], + any_dtype=INT64, +) +agg._deprecated = { + "argmin": _argmin, + "argmax": _argmax, + "first": _first, + "last": _last, + "first_index": _first_index, + "last_index": _last_index, +} +if backend == "suitesparse": + agg.ss.argmin = _argmin + agg.ss.argmax = _argmax + agg.ss.first = _first + agg.ss.last = _last + agg.ss.first_index = _first_index + agg.ss.last_index = _last_index + +agg.Aggregator = Aggregator +agg.TypedAggregator = TypedAggregator + +from .utils import get_typed_op # noqa: E402 isort:skip diff --git a/graphblas/core/operator/base.py b/graphblas/core/operator/base.py new file mode 100644 index 000000000..ef92b41a4 --- /dev/null +++ b/graphblas/core/operator/base.py @@ -0,0 +1,532 @@ +from functools import lru_cache, reduce +from operator import getitem, mul +from types import BuiltinFunctionType, ModuleType + +import numba +import numpy as np + +from ... import _STANDARD_OPERATOR_NAMES, backend, op +from ...dtypes import BOOL, INT8, UINT64, _supports_complex, lookup_dtype +from .. import lib +from ..expr import InfixExprBase +from ..utils import output_type + +UNKNOWN_OPCLASS = "UnknownOpClass" + +# These now live as e.g. `gb.unary.ss.positioni` +# Deprecations such as `gb.unary.positioni` will be removed in 2023.9.0 or later. +_SS_OPERATORS = { + # unary + "erf", # scipy.special.erf + "erfc", # scipy.special.erfc + "frexpe", # np.frexp[1] + "frexpx", # np.frexp[0] + "lgamma", # scipy.special.loggamma + "tgamma", # scipy.special.gamma + # Positional + # unary + "positioni", + "positioni1", + "positionj", + "positionj1", + # binary + "firsti", + "firsti1", + "firstj", + "firstj1", + "secondi", + "secondi1", + "secondj", + "secondj1", + # semiring + "any_firsti", + "any_firsti1", + "any_firstj", + "any_firstj1", + "any_secondi", + "any_secondi1", + "any_secondj", + "any_secondj1", + "max_firsti", + "max_firsti1", + "max_firstj", + "max_firstj1", + "max_secondi", + "max_secondi1", + "max_secondj", + "max_secondj1", + "min_firsti", + "min_firsti1", + "min_firstj", + "min_firstj1", + "min_secondi", + "min_secondi1", + "min_secondj", + "min_secondj1", + "plus_firsti", + "plus_firsti1", + "plus_firstj", + "plus_firstj1", + "plus_secondi", + "plus_secondi1", + "plus_secondj", + "plus_secondj1", + "times_firsti", + "times_firsti1", + "times_firstj", + "times_firstj1", + "times_secondi", + "times_secondi1", + "times_secondj", + "times_secondj1", +} + + +def _hasop(module, name): + return ( + name in module.__dict__ + or name in module._delayed + or name in getattr(module, "_deprecated", ()) + ) + + +class OpPath: + def __init__(self, parent, name): + self._parent = parent + self._name = name + self._delayed = {} + self._delayed_commutes_to = {} + + def __getattr__(self, key): + if key in self._delayed: + func, kwargs = self._delayed.pop(key) + return func(**kwargs) + self.__getattribute__(key) # raises + + +def _call_op(op, left, right=None, thunk=None, **kwargs): + if right is None and thunk is None: + if isinstance(left, InfixExprBase): + # op(A & B), op(A | B), op(A @ B) + return getattr(left.left, left.method_name)(left.right, op, **kwargs) + if find_opclass(op)[1] == "Semiring": + raise TypeError( + f"Bad type when calling {op!r}. Got type: {type(left)}.\n" + f"Expected an infix expression, such as: {op!r}(A @ B)" + ) + raise TypeError( + f"Bad type when calling {op!r}. Got type: {type(left)}.\n" + "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" + f" - {op!r}(A & B)\n" + f" - {op!r}(A, 1)\n" + f" - {op!r}(1, A)" + ) + + # op(A, 1) -> apply (or select if thunk provided) + from ..matrix import Matrix, TransposedMatrix + from ..vector import Vector + + if (left_type := output_type(left)) in {Vector, Matrix, TransposedMatrix}: + if thunk is not None: + return left.select(op, thunk=thunk, **kwargs) + return left.apply(op, right=right, **kwargs) + if (right_type := output_type(right)) in {Vector, Matrix, TransposedMatrix}: + return right.apply(op, left=left, **kwargs) + + from ..scalar import Scalar, _as_scalar + + if left_type is Scalar: + if thunk is not None: + return left.select(op, thunk=thunk, **kwargs) + return left.apply(op, right=right, **kwargs) + if right_type is Scalar: + return right.apply(op, left=left, **kwargs) + try: + left_scalar = _as_scalar(left, is_cscalar=False) + except Exception: + pass + else: + if thunk is not None: + return left_scalar.select(op, thunk=thunk, **kwargs) + return left_scalar.apply(op, right=right, **kwargs) + raise TypeError( + f"Bad types when calling {op!r}. Got types: {type(left)}, {type(right)}.\n" + "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" + f" - {op!r}(A & B)\n" + f" - {op!r}(A, 1)\n" + f" - {op!r}(1, A)" + ) + + +_udt_mask_cache = {} + + +def _udt_mask(dtype): + """Create mask to determine which bytes of UDTs to use for equality check.""" + if dtype in _udt_mask_cache: + return _udt_mask_cache[dtype] + if dtype.subdtype is not None: + mask = _udt_mask(dtype.subdtype[0]) + N = reduce(mul, dtype.subdtype[1]) + rv = np.concatenate([mask] * N) + elif dtype.names is not None: + prev_offset = mask = None + masks = [] + for name in dtype.names: + dtype2, offset = dtype.fields[name] + if mask is not None: + masks.append(np.pad(mask, (0, offset - prev_offset - mask.size))) + mask = _udt_mask(dtype2) + prev_offset = offset + masks.append(np.pad(mask, (0, dtype.itemsize - prev_offset - mask.size))) + rv = np.concatenate(masks) + else: + rv = np.ones(dtype.itemsize, dtype=bool) + # assert rv.size == dtype.itemsize + _udt_mask_cache[dtype] = rv + return rv + + +def _get_udt_wrapper(numba_func, return_type, dtype, dtype2=None, *, include_indexes=False): + ztype = INT8 if return_type == BOOL else return_type + xtype = INT8 if dtype == BOOL else dtype + nt = numba.types + wrapper_args = [nt.CPointer(ztype.numba_type), nt.CPointer(xtype.numba_type)] + if include_indexes: + wrapper_args.extend([UINT64.numba_type, UINT64.numba_type]) + if dtype2 is not None: + ytype = INT8 if dtype2 == BOOL else dtype2 + wrapper_args.append(nt.CPointer(ytype.numba_type)) + wrapper_sig = nt.void(*wrapper_args) + + zarray = xarray = yarray = BL = BR = yarg = yname = rcidx = "" + if return_type._is_udt: + if return_type.np_type.subdtype is None: + zarray = " z = numba.carray(z_ptr, 1)\n" + zname = "z[0]" + else: + zname = "z_ptr[0]" + BR = "[0]" + else: + zname = "z_ptr[0]" + if return_type == BOOL: + BL = "bool(" + BR = ")" + + if dtype._is_udt: + if dtype.np_type.subdtype is None: + xarray = " x = numba.carray(x_ptr, 1)\n" + xname = "x[0]" + else: + xname = "x_ptr" + elif dtype == BOOL: + xname = "bool(x_ptr[0])" + else: + xname = "x_ptr[0]" + + if dtype2 is not None: + yarg = ", y_ptr" + if dtype2._is_udt: + if dtype2.np_type.subdtype is None: + yarray = " y = numba.carray(y_ptr, 1)\n" + yname = ", y[0]" + else: + yname = ", y_ptr" + elif dtype2 == BOOL: + yname = ", bool(y_ptr[0])" + else: + yname = ", y_ptr[0]" + + if include_indexes: + rcidx = ", row, col" + + d = {"numba": numba, "numba_func": numba_func} + text = ( + f"def wrapper(z_ptr, x_ptr{rcidx}{yarg}):\n" + f"{zarray}{xarray}{yarray}" + f" {zname} = {BL}numba_func({xname}{rcidx}{yname}){BR}\n" + ) + exec(text, d) # pylint: disable=exec-used + return d["wrapper"], wrapper_sig + + +class TypedOpBase: + __slots__ = ( + "parent", + "name", + "type", + "return_type", + "gb_obj", + "gb_name", + "_type2", + "__weakref__", + ) + + def __init__(self, parent, name, type_, return_type, gb_obj, gb_name, dtype2=None): + self.parent = parent + self.name = name + self.type = type_ + self.return_type = return_type + self.gb_obj = gb_obj + self.gb_name = gb_name + self._type2 = dtype2 + + def __repr__(self): + classname = self.opclass.lower() + if classname.endswith("op"): + classname = classname[:-2] + dtype2 = "" if self._type2 is None else f", {self._type2.name}" + return f"{classname}.{self.name}[{self.type.name}{dtype2}]" + + @property + def _carg(self): + return self.gb_obj + + @property + def is_positional(self): + return self.parent.is_positional + + def __reduce__(self): + if self._type2 is None or self.type == self._type2: + return (getitem, (self.parent, self.type)) + return (getitem, (self.parent, (self.type, self._type2))) + + +def _deserialize_parameterized(parameterized_op, args, kwargs): + return parameterized_op(*args, **kwargs) + + +class ParameterizedUdf: + __slots__ = "name", "__call__", "_anonymous", "__weakref__" + is_positional = False + _custom_dtype = None + + def __init__(self, name, anonymous): + self.name = name + self._anonymous = anonymous + # lru_cache per instance + method = self._call.__get__(self, type(self)) + self.__call__ = lru_cache(maxsize=1024)(method) + + def _call(self, *args, **kwargs): + raise NotImplementedError + + +_VARNAMES = tuple(x for x in dir(lib) if x[0] != "_") + + +class OpBase: + __slots__ = ( + "name", + "_typed_ops", + "types", + "coercions", + "_anonymous", + "_udt_types", + "_udt_ops", + "__weakref__", + ) + _parse_config = None + _initialized = False + _module = None + _positional = None + + def __init__(self, name, *, anonymous=False): + self.name = name + self._typed_ops = {} + self.types = {} + self.coercions = {} + self._anonymous = anonymous + self._udt_types = None + self._udt_ops = None + + def __repr__(self): + return f"{self._modname}.{self.name}" + + def __getitem__(self, type_): + if type(type_) is tuple: + from .utils import get_typed_op + + dtype1, dtype2 = type_ + dtype1 = lookup_dtype(dtype1) + dtype2 = lookup_dtype(dtype2) + return get_typed_op(self, dtype1, dtype2) + if not self._is_udt: + type_ = lookup_dtype(type_) + if type_ not in self._typed_ops: + if self._udt_types is None: + if self.is_positional: + return self._typed_ops[UINT64] + raise KeyError(f"{self.name} does not work with {type_}") + else: + return self._typed_ops[type_] + # This is a UDT or is able to operate on UDTs such as `first` any `any` + dtype = lookup_dtype(type_) + return self._compile_udt(dtype, dtype) + + def _add(self, op): + self._typed_ops[op.type] = op + self.types[op.type] = op.return_type + + def __delitem__(self, type_): + type_ = lookup_dtype(type_) + del self._typed_ops[type_] + del self.types[type_] + + def __contains__(self, type_): + try: + self[type_] + except (TypeError, KeyError, numba.NumbaError): + return False + return True + + @classmethod + def _remove_nesting(cls, funcname, *, module=None, modname=None, strict=True): + if module is None: + module = cls._module + if modname is None: + modname = cls._modname + if "." not in funcname: + if strict and _hasop(module, funcname): + raise AttributeError(f"{modname}.{funcname} is already defined") + else: + path, funcname = funcname.rsplit(".", 1) + for folder in path.split("."): + if not _hasop(module, folder): + setattr(module, folder, OpPath(module, folder)) + module = getattr(module, folder) + modname = f"{modname}.{folder}" + if not isinstance(module, (OpPath, ModuleType)): + raise AttributeError( + f"{modname} is already defined. Cannot use as a nested path." + ) + if strict and _hasop(module, funcname): + raise AttributeError(f"{path}.{funcname} is already defined") + return module, funcname + + @classmethod + def _find(cls, funcname): + rv = cls._module + for attr in funcname.split("."): + if attr in getattr(rv, "_deprecated", ()): + rv = rv._deprecated[attr] + else: + rv = getattr(rv, attr, None) + if rv is None: + break + return rv + + @classmethod + def _initialize(cls, include_in_ops=True): + """ + include_in_ops determines whether the operators are included in the + `gb.ops` namespace in addition to the defined module. + """ + if cls._initialized: # pragma: no cover (safety) + return + # Read in the parse configs + trim_from_front = cls._parse_config.get("trim_from_front", 0) + delete_exact = cls._parse_config.get("delete_exact", None) + num_underscores = cls._parse_config["num_underscores"] + + for re_str, return_prefix in [ + ("re_exprs", None), + ("re_exprs_return_bool", "BOOL"), + ("re_exprs_return_float", "FP"), + ("re_exprs_return_complex", "FC"), + ]: + if re_str not in cls._parse_config: + continue + if "complex" in re_str and not _supports_complex: + continue + for r in reversed(cls._parse_config[re_str]): + for varname in _VARNAMES: + m = r.match(varname) + if m: + # Parse function into name and datatype + gb_name = m.string + splitname = gb_name[trim_from_front:].split("_") + if delete_exact and delete_exact in splitname: + splitname.remove(delete_exact) + if len(splitname) == num_underscores + 1: + *splitname, type_ = splitname + else: + type_ = None + name = "_".join(splitname).lower() + # Create object for name unless it already exists + if not _hasop(cls._module, name): + if backend == "suitesparse" and name in _SS_OPERATORS: + fullname = f"ss.{name}" + else: + fullname = name + if cls._positional is None: + obj = cls(fullname) + else: + obj = cls(fullname, is_positional=name in cls._positional) + if name in _SS_OPERATORS: + if backend == "suitesparse": + setattr(cls._module.ss, name, obj) + cls._module._deprecated[name] = obj + if include_in_ops and not _hasop(op, name): # pragma: no branch + op._deprecated[name] = obj + if backend == "suitesparse": + setattr(op.ss, name, obj) + else: + setattr(cls._module, name, obj) + if include_in_ops and not _hasop(op, name): + setattr(op, name, obj) + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{fullname}") + elif name in _SS_OPERATORS: + obj = cls._module._deprecated[name] + else: + obj = getattr(cls._module, name) + gb_obj = getattr(lib, varname) + # Determine return type + if return_prefix == "BOOL": + return_type = BOOL + if type_ is None: + type_ = BOOL + else: + if type_ is None: # pragma: no cover + raise TypeError(f"Unable to determine return type for {varname}") + if return_prefix is None: + return_type = type_ + else: + # Grab the number of bits from type_ + num_bits = type_[-2:] + if num_bits not in {"32", "64"}: # pragma: no cover (safety) + raise TypeError(f"Unexpected number of bits: {num_bits}") + return_type = f"{return_prefix}{num_bits}" + builtin_op = cls._typed_class( + obj, + name, + lookup_dtype(type_), + lookup_dtype(return_type), + gb_obj, + gb_name, + ) + obj._add(builtin_op) + + @classmethod + def _deserialize(cls, name, *args): + if (rv := cls._find(name)) is not None: + return rv # Should we verify this is what the user expects? + return cls.register_new(name, *args) + + +_builtin_to_op = {} # Populated in .utils + + +def find_opclass(gb_op): + if isinstance(gb_op, OpBase): + opclass = type(gb_op).__name__ + elif isinstance(gb_op, TypedOpBase): + opclass = gb_op.opclass + elif isinstance(gb_op, ParameterizedUdf): + gb_op = gb_op() # Use default parameters of parameterized UDFs + gb_op, opclass = find_opclass(gb_op) + elif isinstance(gb_op, BuiltinFunctionType) and gb_op in _builtin_to_op: + gb_op, opclass = find_opclass(_builtin_to_op[gb_op]) + else: + opclass = UNKNOWN_OPCLASS + return gb_op, opclass diff --git a/graphblas/core/operator/binary.py b/graphblas/core/operator/binary.py new file mode 100644 index 000000000..eeb72ea3b --- /dev/null +++ b/graphblas/core/operator/binary.py @@ -0,0 +1,864 @@ +import inspect +import re +from functools import lru_cache +from types import FunctionType + +import numba +import numpy as np + +from ... import _STANDARD_OPERATOR_NAMES, backend, binary, monoid, op +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + _sample_values, + _supports_complex, + lookup_dtype, +) +from ...exceptions import UdfParseError, check_status_carg +from .. import ffi, lib +from ..expr import InfixExprBase +from .base import ( + _SS_OPERATORS, + OpBase, + ParameterizedUdf, + TypedOpBase, + _call_op, + _deserialize_parameterized, + _get_udt_wrapper, + _hasop, + _udt_mask, +) + +if _supports_complex: + from ...dtypes import FC32, FC64 + +ffi_new = ffi.new + + +class TypedBuiltinBinaryOp(TypedOpBase): + __slots__ = () + opclass = "BinaryOp" + + def __call__(self, left, right=None, *, left_default=None, right_default=None): + if left_default is not None or right_default is not None: + if ( + left_default is None + or right_default is None + or right is not None + or not isinstance(left, InfixExprBase) + or left.method_name != "ewise_add" + ): + raise TypeError( + "Specifying `left_default` or `right_default` keyword arguments implies " + "performing `ewise_union` operation with infix notation.\n" + "There is only one valid way to do this:\n\n" + f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " + "are Vectors or Matrices, and left_default and right_default are scalars." + ) + return left.left.ewise_union(left.right, self, left_default, right_default) + return _call_op(self, left, right) + + @property + def monoid(self): + rv = getattr(monoid, self.name, None) + if rv is not None and self.type in rv._typed_ops: + return rv[self.type] + + @property + def commutes_to(self): + commutes_to = self.parent.commutes_to + if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): + return commutes_to[self.type] + + @property + def _semiring_commutes_to(self): + commutes_to = self.parent._semiring_commutes_to + if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): + return commutes_to[self.type] + + @property + def is_commutative(self): + return self.commutes_to is self + + @property + def type2(self): + return self.type if self._type2 is None else self._type2 + + +class TypedUserBinaryOp(TypedOpBase): + __slots__ = "_monoid" + opclass = "BinaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) + self._monoid = None + + @property + def monoid(self): + if self._monoid is None: + monoid = self.parent.monoid + if monoid is not None and self.type in monoid: + self._monoid = monoid[self.type] + return self._monoid + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + commutes_to = TypedBuiltinBinaryOp.commutes_to + _semiring_commutes_to = TypedBuiltinBinaryOp._semiring_commutes_to + is_commutative = TypedBuiltinBinaryOp.is_commutative + type2 = TypedBuiltinBinaryOp.type2 + __call__ = TypedBuiltinBinaryOp.__call__ + + +class ParameterizedBinaryOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_monoid", "_cached_call", "_commutes_to", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._monoid = None + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + method = self._call_to_cache.__get__(self, type(self)) + self._cached_call = lru_cache(maxsize=1024)(method) + self.__call__ = self._call + self._commutes_to = None + + def _call_to_cache(self, *args, **kwargs): + binary = self.func(*args, **kwargs) + binary._parameterized_info = (self, args, kwargs) + return BinaryOp.register_anonymous(binary, self.name, is_udt=self._is_udt) + + def _call(self, *args, **kwargs): + binop = self._cached_call(*args, **kwargs) + if self._monoid is not None and binop._monoid is None: + # This is all a bit funky. We try our best to associate a binaryop + # to a monoid. So, if we made a ParameterizedMonoid using this object, + # then try to create a monoid with the given arguments. + binop._monoid = binop # temporary! + try: + # If this call is successful, then it will set `binop._monoid` + self._monoid(*args, **kwargs) # pylint: disable=not-callable + except Exception: + binop._monoid = None + # assert binop._monoid is not binop + if self.is_commutative: + binop._commutes_to = binop + # Don't bother yet with creating `binop.commutes_to` (but we could!) + return binop + + @property + def monoid(self): + return self._monoid + + @property + def commutes_to(self): + if type(self._commutes_to) is str: + self._commutes_to = BinaryOp._find(self._commutes_to) + return self._commutes_to + + is_commutative = TypedBuiltinBinaryOp.is_commutative + + def __reduce__(self): + name = f"binary.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + if anonymous: + return BinaryOp.register_anonymous(func, name, parameterized=True) + if (rv := BinaryOp._find(name)) is not None: + return rv + return BinaryOp.register_new(name, func, parameterized=True) + + +def _floordiv(x, y): + return x // y # pragma: no cover (numba) + + +def _rfloordiv(x, y): + return y // x # pragma: no cover (numba) + + +def _absfirst(x, y): + return np.abs(x) # pragma: no cover (numba) + + +def _abssecond(x, y): + return np.abs(y) # pragma: no cover (numba) + + +def _rpow(x, y): + return y**x # pragma: no cover (numba) + + +def _isclose(rel_tol=1e-7, abs_tol=0.0): + def inner(x, y): # pragma: no cover (numba) + return x == y or abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) + + return inner + + +_MAX_INT64 = np.iinfo(np.int64).max + + +def _binom(N, k): # pragma: no cover (numba) + # Returns 0 if overflow or out-of-bounds + if k > N or k < 0: + return 0 + val = np.int64(1) + for i in range(min(k, N - k)): + if val > _MAX_INT64 // (N - i): # Overflow + return 0 + val *= N - i + val //= i + 1 + return val + + +# Kinda complicated, but works for now +def _register_binom(): + # "Fake" UDT so we only compile once for INT64 + op = BinaryOp.register_new("binom", _binom, is_udt=True) + typed_op = op[INT64, INT64] + # Make this look like a normal operator + for dtype in [UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64]: + op.types[dtype] = INT64 + op._typed_ops[dtype] = typed_op + if dtype != INT64: + op.coercions[dtype] = typed_op + # And make it not look like it operates on UDTs + typed_op._type2 = None + op._is_udt = False + op._udt_types = None + op._udt_ops = None + return op + + +def _first(x, y): + return x # pragma: no cover (numba) + + +def _second(x, y): + return y # pragma: no cover (numba) + + +def _pair(x, y): + return 1 # pragma: no cover (numba) + + +def _first_dtype(op, dtype, dtype2): + if dtype._is_udt or dtype2._is_udt: + return op._compile_udt(dtype, dtype2) + + +def _second_dtype(op, dtype, dtype2): + if dtype._is_udt or dtype2._is_udt: + return op._compile_udt(dtype, dtype2) + + +def _pair_dtype(op, dtype, dtype2): + return op[INT64] + + +class BinaryOp(OpBase): + """Takes two inputs and returns one output, possibly of a different data type. + + Built-in and registered BinaryOps are located in the ``graphblas.binary`` namespace + as well as in the ``graphblas.ops`` combined namespace. + """ + + __slots__ = ( + "_monoid", + "_commutes_to", + "_semiring_commutes_to", + "orig_func", + "is_positional", + "_is_udt", + "_numba_func", + "_custom_dtype", + ) + _module = binary + _modname = "binary" + _typed_class = TypedBuiltinBinaryOp + _parse_config = { + "trim_from_front": 4, + "num_underscores": 1, + "re_exprs": [ + re.compile( + "^GrB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV|MIN|MAX)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" + ), + re.compile( + "GrB_(BOR|BAND|BXOR|BXNOR)_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" + ), + re.compile( + "^GxB_(POW|RMINUS|RDIV|PAIR|ANY|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" + ), + re.compile("^GxB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV)_(FC32|FC64)$"), + re.compile("^GxB_(ATAN2|HYPOT|FMOD|REMAINDER|LDEXP|COPYSIGN)_(FP32|FP64)$"), + re.compile( + "GxB_(BGET|BSET|BCLR|BSHIFT|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ" + "|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" + "_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" + ), + # These are coerced to 0 or 1, but don't return BOOL + re.compile( + "^GxB_(LOR|LAND|LXOR|LXNOR)_" + "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)$"), + re.compile( + "^GrB_(EQ|NE|GT|LT|GE|LE)_" + "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile("^GxB_(EQ|NE)_(FC32|FC64)$"), + ], + "re_exprs_return_complex": [re.compile("^GxB_(CMPLX)_(FP32|FP64)$")], + } + _commutes = { + # builtins + "cdiv": "rdiv", + "first": "second", + "ge": "le", + "gt": "lt", + "isge": "isle", + "isgt": "islt", + "minus": "rminus", + "pow": "rpow", + # special + "firsti": "secondi", + "firsti1": "secondi1", + "firstj": "secondj", + "firstj1": "secondj1", + # custom + # "absfirst": "abssecond", # handled in graphblas.binary + # "floordiv": "rfloordiv", + "truediv": "rtruediv", + } + _commutes_to_in_semiring = { + "firsti": "secondj", + "firsti1": "secondj1", + "firstj": "secondi", + "firstj1": "secondi1", + } + _commutative = { + # monoids + "any", + "band", + "bor", + "bxnor", + "bxor", + "eq", + "land", + "lor", + "lxnor", + "lxor", + "max", + "min", + "plus", + "times", + # other + "hypot", + "isclose", + "iseq", + "isne", + "ne", + "pair", + } + # Don't commute: atan2, bclr, bget, bset, bshift, cmplx, copysign, fmod, ldexp, remainder + _positional = { + "firsti", + "firsti1", + "firstj", + "firstj1", + "secondi", + "secondi1", + "secondj", + "secondj1", + } + + @classmethod + def _build(cls, name, func, *, is_udt=False, anonymous=False): + if not isinstance(func, FunctionType): + raise TypeError(f"UDF argument must be a function, not {type(func)}") + if name is None: + name = getattr(func, "__name__", "") + success = False + binary_udf = numba.njit(func) + new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=binary_udf) + return_types = {} + nt = numba.types + if not is_udt: + for type_ in _sample_values: + sig = (type_.numba_type, type_.numba_type) + try: + binary_udf.compile(sig) + except numba.TypingError: + continue + ret_type = lookup_dtype(binary_udf.overloads[sig].signature.return_type) + if ret_type != type_ and ( + ("INT" in ret_type.name and "INT" in type_.name) + or ("FP" in ret_type.name and "FP" in type_.name) + or ("FC" in ret_type.name and "FC" in type_.name) + or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) + ): + # Downcast `ret_type` to `type_`. + # This is what users want most of the time, but we can't make a perfect rule. + # There should be a way for users to be explicit. + ret_type = type_ + elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: + ret_type = INT8 + + # Numba is unable to handle BOOL correctly right now, but we have a workaround + # See: https://github.com/numba/numba/issues/5395 + # We're relying on coercion behaving correctly here + input_type = INT8 if type_ == BOOL else type_ + return_type = INT8 if ret_type == BOOL else ret_type + + # Build wrapper because GraphBLAS wants pointers and void return + wrapper_sig = nt.void( + nt.CPointer(return_type.numba_type), + nt.CPointer(input_type.numba_type), + nt.CPointer(input_type.numba_type), + ) + + if type_ == BOOL: + if ret_type == BOOL: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = bool(binary_udf(bool(x[0]), bool(y[0]))) + + else: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = binary_udf(bool(x[0]), bool(y[0])) + + elif ret_type == BOOL: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = bool(binary_udf(x[0], y[0])) + + else: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = binary_udf(x[0], y[0]) + + binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) + new_binary = ffi_new("GrB_BinaryOp*") + check_status_carg( + lib.GrB_BinaryOp_new( + new_binary, + binary_wrapper.cffi, + ret_type.gb_obj, + type_.gb_obj, + type_.gb_obj, + ), + "BinaryOp", + new_binary, + ) + op = TypedUserBinaryOp(new_type_obj, name, type_, ret_type, new_binary[0]) + new_type_obj._add(op) + success = True + return_types[type_] = ret_type + if success or is_udt: + return new_type_obj + raise UdfParseError("Unable to parse function using Numba") + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: + dtype2 = dtype + dtypes = (dtype, dtype2) + if dtypes in self._udt_types: + return self._udt_ops[dtypes] + + nt = numba.types + if self.name == "eq" and not self._anonymous: + # assert dtype.np_type == dtype2.np_type + itemsize = dtype.np_type.itemsize + mask = _udt_mask(dtype.np_type) + ret_type = BOOL + wrapper_sig = nt.void( + nt.CPointer(INT8.numba_type), + nt.CPointer(UINT8.numba_type), + nt.CPointer(UINT8.numba_type), + ) + # PERF: we can probably make this faster + if mask.all(): + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if x[i] != y[i]: + # z_ptr[0] = False + # break + # else: + # z_ptr[0] = True + z_ptr[0] = (x == y).all() + + else: + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if mask[i] and x[i] != y[i]: + # z_ptr[0] = False + # break + # else: + # z_ptr[0] = True + z_ptr[0] = (x[mask] == y[mask]).all() + + elif self.name == "ne" and not self._anonymous: + # assert dtype.np_type == dtype2.np_type + itemsize = dtype.np_type.itemsize + mask = _udt_mask(dtype.np_type) + ret_type = BOOL + wrapper_sig = nt.void( + nt.CPointer(INT8.numba_type), + nt.CPointer(UINT8.numba_type), + nt.CPointer(UINT8.numba_type), + ) + if mask.all(): + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if x[i] != y[i]: + # z_ptr[0] = True + # break + # else: + # z_ptr[0] = False + z_ptr[0] = (x != y).any() + + else: + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if mask[i] and x[i] != y[i]: + # z_ptr[0] = True + # break + # else: + # z_ptr[0] = False + z_ptr[0] = (x[mask] != y[mask]).any() + + else: + numba_func = self._numba_func + sig = (dtype.numba_type, dtype2.numba_type) + numba_func.compile(sig) # Should we catch and give additional error message? + ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) + binary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype, dtype2) + + binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) + new_binary = ffi_new("GrB_BinaryOp*") + check_status_carg( + lib.GrB_BinaryOp_new( + new_binary, binary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg + ), + "BinaryOp", + new_binary, + ) + op = TypedUserBinaryOp( + self, + self.name, + dtype, + ret_type, + new_binary[0], + dtype2=dtype2, + ) + self._udt_types[dtypes] = ret_type + self._udt_ops[dtypes] = op + return op + + @classmethod + def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): + """Register a BinaryOp without registering it in the ``graphblas.binary`` namespace. + + Because it is not registered in the namespace, the name is optional. + """ + if parameterized: + return ParameterizedBinaryOp(name, func, anonymous=True, is_udt=is_udt) + return cls._build(name, func, anonymous=True, is_udt=is_udt) + + @classmethod + def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): + """Register a BinaryOp. The name will be used to identify the BinaryOp in the + ``graphblas.binary`` namespace. + + >>> def max_zero(x, y): + r = 0 + if x > r: + r = x + if y > r: + r = y + return r + >>> gb.core.operator.BinaryOp.register_new("max_zero", max_zero) + >>> dir(gb.binary) + [..., 'max_zero', ...] + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "func": func, "parameterized": parameterized}, + ) + elif parameterized: + binary_op = ParameterizedBinaryOp(name, func, is_udt=is_udt) + setattr(module, funcname, binary_op) + else: + binary_op = cls._build(name, func, is_udt=is_udt) + setattr(module, funcname, binary_op) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, binary_op) + if not cls._initialized: + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return binary_op + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + super()._initialize() + # Rename div to cdiv + cdiv = binary.cdiv = op.cdiv = BinaryOp("cdiv") + for dtype, ret_type in binary.div.types.items(): + orig_op = binary.div[dtype] + cur_op = TypedBuiltinBinaryOp( + cdiv, "cdiv", dtype, ret_type, orig_op.gb_obj, orig_op.gb_name + ) + cdiv._add(cur_op) + del binary.div + del op.div + # Add truediv which always points to floating point cdiv + # We are effectively hacking cdiv to always return floating point values + # If the inputs are FP32, we use DIV_FP32; use DIV_FP64 for all other input dtypes + truediv = binary.truediv = op.truediv = BinaryOp("truediv") + rtruediv = binary.rtruediv = op.rtruediv = BinaryOp("rtruediv") + for new_op, builtin_op in [(truediv, binary.cdiv), (rtruediv, binary.rdiv)]: + for dtype in builtin_op.types: + if dtype.name in {"FP32", "FC32", "FC64"}: + orig_dtype = dtype + else: + orig_dtype = FP64 + orig_op = builtin_op[orig_dtype] + cur_op = TypedBuiltinBinaryOp( + new_op, + new_op.name, + dtype, + builtin_op.types[orig_dtype], + orig_op.gb_obj, + orig_op.gb_name, + ) + new_op._add(cur_op) + # Add floordiv + # cdiv truncates towards 0, while floordiv truncates towards -inf + BinaryOp.register_new("floordiv", _floordiv, lazy=True) # cast to integer + BinaryOp.register_new("rfloordiv", _rfloordiv, lazy=True) # cast to integer + + # For aggregators + BinaryOp.register_new("absfirst", _absfirst, lazy=True) + BinaryOp.register_new("abssecond", _abssecond, lazy=True) + BinaryOp.register_new("rpow", _rpow, lazy=True) + + # For algorithms + binary._delayed["binom"] = (_register_binom, {}) # Lazy with custom creation + op._delayed["binom"] = binary + + BinaryOp.register_new("isclose", _isclose, parameterized=True) + + # Update type information with sane coercion + position_dtypes = [ + BOOL, + FP32, + FP64, + INT8, + INT16, + UINT8, + UINT16, + UINT32, + UINT64, + ] + if _supports_complex: + position_dtypes.extend([FC32, FC64]) + name_types = [ + # fmt: off + ( + ("atan2", "copysign", "fmod", "hypot", "ldexp", "remainder"), + ((BOOL, INT8, INT16, UINT8, UINT16), FP32), + ((INT32, INT64, UINT32, UINT64), FP64), + ), + ( + ( + "firsti", "firsti1", "firstj", "firstj1", "secondi", "secondi1", + "secondj", "secondj1"), + ( + position_dtypes, + INT64, + ), + ), + ( + ["lxnor"], + ( + ( + FP32, FP64, INT8, INT16, INT32, INT64, + UINT8, UINT16, UINT32, UINT64, + ), + BOOL, + ), + ), + # fmt: on + ] + if _supports_complex: + name_types.append( + ( + ["cmplx"], + ((BOOL, INT8, INT16, UINT8, UINT16), FP32), + ((INT32, INT64, UINT32, UINT64), FP64), + ) + ) + for names, *types in name_types: + for name in names: + if name in _SS_OPERATORS: + cur_op = binary._deprecated[name] + else: + cur_op = getattr(binary, name) + for input_types, target_type in types: + typed_op = cur_op._typed_ops[target_type] + output_type = cur_op.types[target_type] + for dtype in input_types: + if dtype not in cur_op.types: # pragma: no branch (safety) + cur_op.types[dtype] = output_type + cur_op._typed_ops[dtype] = typed_op + cur_op.coercions[dtype] = target_type + # Not valid input dtypes + del binary.ldexp[FP32] + del binary.ldexp[FP64] + # Fill in commutes info + for left_name, right_name in cls._commutes.items(): + if left_name in _SS_OPERATORS: + left = binary._deprecated[left_name] + else: + left = getattr(binary, left_name) + if backend == "suitesparse" and right_name in _SS_OPERATORS: + left._commutes_to = f"ss.{right_name}" + else: + left._commutes_to = right_name + if right_name not in binary._delayed: + if right_name in _SS_OPERATORS: + right = binary._deprecated[right_name] + else: + right = getattr(binary, right_name) + if backend == "suitesparse" and left_name in _SS_OPERATORS: + right._commutes_to = f"ss.{left_name}" + else: + right._commutes_to = left_name + for name in cls._commutative: + cur_op = getattr(binary, name) + cur_op._commutes_to = name + for left_name, right_name in cls._commutes_to_in_semiring.items(): + if left_name in _SS_OPERATORS: + left = binary._deprecated[left_name] + else: # pragma: no cover (safety) + left = getattr(binary, left_name) + if right_name in _SS_OPERATORS: + right = binary._deprecated[right_name] + else: # pragma: no cover (safety) + right = getattr(binary, right_name) + left._semiring_commutes_to = right + right._semiring_commutes_to = left + # Allow some functions to work on UDTs + for binop, func in [ + (binary.first, _first), + (binary.second, _second), + (binary.pair, _pair), + (binary.any, _first), + ]: + binop.orig_func = func + binop._numba_func = numba.njit(func) + binop._udt_types = {} + binop._udt_ops = {} + binary.any._numba_func = binary.first._numba_func + binary.eq._udt_types = {} + binary.eq._udt_ops = {} + binary.ne._udt_types = {} + binary.ne._udt_ops = {} + # Set custom dtype handling + binary.first._custom_dtype = _first_dtype + binary.second._custom_dtype = _second_dtype + binary.pair._custom_dtype = _pair_dtype + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self._monoid = None + self._commutes_to = None + self._semiring_commutes_to = None + self.orig_func = func + self._numba_func = numba_func + self._is_udt = is_udt + self.is_positional = is_positional + self._custom_dtype = None + if is_udt: + self._udt_types = {} # {(dtype, dtype): DataType} + self._udt_ops = {} # {(dtype, dtype): TypedUserBinaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"binary.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinBinaryOp.__call__ + is_commutative = TypedBuiltinBinaryOp.is_commutative + commutes_to = ParameterizedBinaryOp.commutes_to + + @property + def monoid(self): + if self._monoid is None and not self._anonymous: + from .monoid import Monoid + + self._monoid = Monoid._find(self.name) + return self._monoid diff --git a/graphblas/core/operator/indexunary.py b/graphblas/core/operator/indexunary.py new file mode 100644 index 000000000..5fdafb62a --- /dev/null +++ b/graphblas/core/operator/indexunary.py @@ -0,0 +1,357 @@ +import inspect +import re +from types import FunctionType + +import numba + +from ... import _STANDARD_OPERATOR_NAMES, indexunary, select +from ...dtypes import BOOL, FP64, INT8, INT64, UINT64, _sample_values, lookup_dtype +from ...exceptions import UdfParseError, check_status_carg +from .. import ffi, lib +from .base import ( + OpBase, + ParameterizedUdf, + TypedOpBase, + _call_op, + _deserialize_parameterized, + _get_udt_wrapper, +) + +ffi_new = ffi.new + + +class TypedBuiltinIndexUnaryOp(TypedOpBase): + __slots__ = () + opclass = "IndexUnaryOp" + + def __call__(self, val, thunk=None): + if thunk is None: + thunk = False # most basic form of 0 when unifying dtypes + return _call_op(self, val, right=thunk) + + +class TypedUserIndexUnaryOp(TypedOpBase): + __slots__ = () + opclass = "IndexUnaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + __call__ = TypedBuiltinIndexUnaryOp.__call__ + + +class ParameterizedIndexUnaryOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + indexunary = self.func(*args, **kwargs) + indexunary._parameterized_info = (self, args, kwargs) + return IndexUnaryOp.register_anonymous(indexunary, self.name, is_udt=self._is_udt) + + def __reduce__(self): + name = f"indexunary.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + if anonymous: + return IndexUnaryOp.register_anonymous(func, name, parameterized=True) + if (rv := IndexUnaryOp._find(name)) is not None: + return rv + return IndexUnaryOp.register_new(name, func, parameterized=True) + + +class IndexUnaryOp(OpBase): + """Takes one input and a thunk and returns one output, possibly of a different data type. + Along with the input value, the index(es) of the element are given to the function. + + This is an advanced form of a unary operation that allows, for example, converting + elements of a Vector to their index position to build a ramp structure. Another use + case is returning a boolean value indicating whether the element is part of the upper + triangular structure of a Matrix. + + Built-in and registered IndexUnaryOps are located in the ``graphblas.indexunary`` namespace. + """ + + __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" + _module = indexunary + _modname = "indexunary" + _custom_dtype = None + _typed_class = TypedBuiltinIndexUnaryOp + _typed_user_class = TypedUserIndexUnaryOp + _parse_config = { + "trim_from_front": 4, + "num_underscores": 1, + "re_exprs": [ + re.compile("^GrB_(ROWINDEX|COLINDEX|DIAGINDEX)_(INT32|INT64)$"), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_(TRIL|TRIU|DIAG|OFFDIAG|COLLE|COLGT|ROWLE|ROWGT)$"), + re.compile( + "^GrB_(VALUEEQ|VALUENE|VALUEGT|VALUEGE|VALUELT|VALUELE)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile("^GxB_(VALUEEQ|VALUENE)_(FC32|FC64)$"), + ], + } + _positional = {"tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt", + "rowindex", "colindex"} # fmt: skip + + @classmethod + def _build(cls, name, func, *, is_udt=False, anonymous=False): + if not isinstance(func, FunctionType): + raise TypeError(f"UDF argument must be a function, not {type(func)}") + if name is None: + name = getattr(func, "__name__", "") + success = False + indexunary_udf = numba.njit(func) + new_type_obj = cls( + name, func, anonymous=anonymous, is_udt=is_udt, numba_func=indexunary_udf + ) + return_types = {} + nt = numba.types + if not is_udt: + for type_ in _sample_values: + sig = (type_.numba_type, UINT64.numba_type, UINT64.numba_type, type_.numba_type) + try: + indexunary_udf.compile(sig) + except numba.TypingError: + continue + ret_type = lookup_dtype(indexunary_udf.overloads[sig].signature.return_type) + if ret_type != type_ and ( + ("INT" in ret_type.name and "INT" in type_.name) + or ("FP" in ret_type.name and "FP" in type_.name) + or ("FC" in ret_type.name and "FC" in type_.name) + or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) + ): + # Downcast `ret_type` to `type_`. + # This is what users want most of the time, but we can't make a perfect rule. + # There should be a way for users to be explicit. + ret_type = type_ + elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: + ret_type = INT8 + + # Numba is unable to handle BOOL correctly right now, but we have a workaround + # See: https://github.com/numba/numba/issues/5395 + # We're relying on coercion behaving correctly here + input_type = INT8 if type_ == BOOL else type_ + return_type = INT8 if ret_type == BOOL else ret_type + + # Build wrapper because GraphBLAS wants pointers and void return + wrapper_sig = nt.void( + nt.CPointer(return_type.numba_type), + nt.CPointer(input_type.numba_type), + UINT64.numba_type, + UINT64.numba_type, + nt.CPointer(input_type.numba_type), + ) + + if type_ == BOOL: + if ret_type == BOOL: + + def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) + z[0] = bool(indexunary_udf(bool(x[0]), row, col, bool(y[0]))) + + else: + + def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) + z[0] = indexunary_udf(bool(x[0]), row, col, bool(y[0])) + + elif ret_type == BOOL: + + def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) + z[0] = bool(indexunary_udf(x[0], row, col, y[0])) + + else: + + def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) + z[0] = indexunary_udf(x[0], row, col, y[0]) + + indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper) + new_indexunary = ffi_new("GrB_IndexUnaryOp*") + check_status_carg( + lib.GrB_IndexUnaryOp_new( + new_indexunary, + indexunary_wrapper.cffi, + ret_type.gb_obj, + type_.gb_obj, + type_.gb_obj, + ), + "IndexUnaryOp", + new_indexunary, + ) + op = cls._typed_user_class(new_type_obj, name, type_, ret_type, new_indexunary[0]) + new_type_obj._add(op) + success = True + return_types[type_] = ret_type + if success or is_udt: + return new_type_obj + raise UdfParseError("Unable to parse function using Numba") + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: # pragma: no cover + dtype2 = dtype + dtypes = (dtype, dtype2) + if dtypes in self._udt_types: + return self._udt_ops[dtypes] + + numba_func = self._numba_func + sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type) + numba_func.compile(sig) # Should we catch and give additional error message? + ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) + indexunary_wrapper, wrapper_sig = _get_udt_wrapper( + numba_func, ret_type, dtype, dtype2, include_indexes=True + ) + + indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper) + new_indexunary = ffi_new("GrB_IndexUnaryOp*") + check_status_carg( + lib.GrB_IndexUnaryOp_new( + new_indexunary, indexunary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg + ), + "IndexUnaryOp", + new_indexunary, + ) + op = TypedUserIndexUnaryOp( + self, + self.name, + dtype, + ret_type, + new_indexunary[0], + dtype2=dtype2, + ) + self._udt_types[dtypes] = ret_type + self._udt_ops[dtypes] = op + return op + + @classmethod + def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): + """Register an IndexUnaryOp without registering it in the + ``graphblas.indexunary`` namespace. + + Because it is not registered in the namespace, the name is optional. + """ + if parameterized: + return ParameterizedIndexUnaryOp(name, func, anonymous=True, is_udt=is_udt) + return cls._build(name, func, anonymous=True, is_udt=is_udt) + + @classmethod + def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): + """Register an IndexUnaryOp. The name will be used to identify the IndexUnaryOp in the + ``graphblas.indexunary`` namespace. + + If the return type is Boolean, the function will also be registered as a SelectOp + with the same name. + + >>> gb.indexunary.register_new("row_mod", lambda x, i, j, thunk: i % max(thunk, 2)) + >>> dir(gb.indexunary) + [..., 'row_mod', ...] + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "func": func, "parameterized": parameterized}, + ) + elif parameterized: + indexunary_op = ParameterizedIndexUnaryOp(name, func, is_udt=is_udt) + setattr(module, funcname, indexunary_op) + else: + indexunary_op = cls._build(name, func, is_udt=is_udt) + setattr(module, funcname, indexunary_op) + # If return type is BOOL, register additionally as a SelectOp + if all(x == BOOL for x in indexunary_op.types.values()): + from .select import SelectOp + + setattr(select, funcname, SelectOp._from_indexunary(indexunary_op)) + + if not cls._initialized: + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return indexunary_op + + @classmethod + def _initialize(cls): + if cls._initialized: + return + from .select import SelectOp + + super()._initialize(include_in_ops=False) + # Update type information to include UINT64 for positional ops + for name in ["tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt"]: + op = getattr(indexunary, name) + typed_op = op._typed_ops[BOOL] + output_type = op.types[BOOL] + if UINT64 not in op.types: # pragma: no branch (safety) + op.types[UINT64] = output_type + op._typed_ops[UINT64] = typed_op + op.coercions[UINT64] = BOOL + for name in ["rowindex", "colindex"]: + op = getattr(indexunary, name) + typed_op = op._typed_ops[INT64] + output_type = op.types[INT64] + if UINT64 not in op.types: # pragma: no branch (safety) + op.types[UINT64] = output_type + op._typed_ops[UINT64] = typed_op + op.coercions[UINT64] = INT64 + # Add index->row alias to make it more intuitive which to use for vectors + indexunary.indexle = indexunary.rowle + indexunary.indexgt = indexunary.rowgt + indexunary.index = indexunary.rowindex + # fmt: off + # Add SelectOp when it makes sense + for name in ["tril", "triu", "diag", "offdiag", + "colle", "colgt", "rowle", "rowgt", "indexle", "indexgt", + "valueeq", "valuene", "valuegt", "valuege", "valuelt", "valuele"]: + iop = getattr(indexunary, name) + setattr(select, name, SelectOp._from_indexunary(iop)) + # fmt: on + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self.orig_func = func + self._numba_func = numba_func + self.is_positional = is_positional + self._is_udt = is_udt + if is_udt: + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"indexunary.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinIndexUnaryOp.__call__ diff --git a/graphblas/core/operator/monoid.py b/graphblas/core/operator/monoid.py new file mode 100644 index 000000000..387652b63 --- /dev/null +++ b/graphblas/core/operator/monoid.py @@ -0,0 +1,417 @@ +import inspect +import re +from collections.abc import Mapping + +from ... import _STANDARD_OPERATOR_NAMES, binary, monoid, op +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + lookup_dtype, +) +from ...exceptions import check_status_carg +from .. import ffi, lib +from ..expr import InfixExprBase +from ..utils import libget +from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _hasop +from .binary import BinaryOp, ParameterizedBinaryOp + +ffi_new = ffi.new + + +class TypedBuiltinMonoid(TypedOpBase): + __slots__ = "_identity" + opclass = "Monoid" + is_commutative = True + + def __init__(self, parent, name, type_, return_type, gb_obj, gb_name): + super().__init__(parent, name, type_, return_type, gb_obj, gb_name) + self._identity = None + + def __call__(self, left, right=None, *, left_default=None, right_default=None): + if left_default is not None or right_default is not None: + if ( + left_default is None + or right_default is None + or right is not None + or not isinstance(left, InfixExprBase) + or left.method_name != "ewise_add" + ): + raise TypeError( + "Specifying `left_default` or `right_default` keyword arguments implies " + "performing `ewise_union` operation with infix notation.\n" + "There is only one valid way to do this:\n\n" + f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " + "are Vectors or Matrices, and left_default and right_default are scalars." + ) + return left.left.ewise_union(left.right, self, left_default, right_default) + return _call_op(self, left, right) + + @property + def identity(self): + if self._identity is None: + from ..recorder import skip_record + from ..vector import Vector + + with skip_record: + self._identity = ( + Vector(self.type, size=1, name="").reduce(self, allow_empty=False).new().value + ) + return self._identity + + @property + def binaryop(self): + return getattr(binary, self.name)[self.type] + + @property + def commutes_to(self): + return self + + @property + def type2(self): + return self.type + + @property + def is_idempotent(self): + """True if ``monoid(x, x) == x`` for any x.""" + return self.parent.is_idempotent + + +class TypedUserMonoid(TypedOpBase): + __slots__ = "binaryop", "identity" + opclass = "Monoid" + is_commutative = True + + def __init__(self, parent, name, type_, return_type, gb_obj, binaryop, identity): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") + self.binaryop = binaryop + self.identity = identity + binaryop._monoid = self + + commutes_to = TypedBuiltinMonoid.commutes_to + type2 = TypedBuiltinMonoid.type2 + is_idempotent = TypedBuiltinMonoid.is_idempotent + __call__ = TypedBuiltinMonoid.__call__ + + +class ParameterizedMonoid(ParameterizedUdf): + __slots__ = "binaryop", "identity", "_is_idempotent", "__signature__" + is_commutative = True + + def __init__(self, name, binaryop, identity, *, is_idempotent=False, anonymous=False): + if type(binaryop) is not ParameterizedBinaryOp: + raise TypeError("binaryop must be parameterized") + self.binaryop = binaryop + self.__signature__ = binaryop.__signature__ + if callable(identity): + # assume it must be parameterized as well, so signature must match + sig = inspect.signature(identity) + if sig != self.__signature__: + raise ValueError( + "Signatures of binaryop and identity passed to " + f"{type(self).__name__} must be the same. Got:\n" + f" binaryop{self.__signature__}\n" + " !=\n" + f" identity{sig}" + ) + self.identity = identity + self._is_idempotent = is_idempotent + if name is None: + name = binaryop.name + super().__init__(name, anonymous) + binaryop._monoid = self + # clear binaryop cache so it can be associated with this monoid + binaryop._cached_call.cache_clear() + + def _call(self, *args, **kwargs): + binary = self.binaryop(*args, **kwargs) + identity = self.identity + if callable(identity): + identity = identity(*args, **kwargs) + return Monoid.register_anonymous( + binary, identity, self.name, is_idempotent=self._is_idempotent + ) + + commutes_to = TypedBuiltinMonoid.commutes_to + + @property + def is_idempotent(self): + """True if ``monoid(x, x) == x`` for any x.""" + return self._is_idempotent + + def __reduce__(self): + name = f"monoid.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover + return name + return (self._deserialize, (self.name, self.binaryop, self.identity, self._anonymous)) + + @staticmethod + def _deserialize(name, binaryop, identity, anonymous): + if anonymous: + return Monoid.register_anonymous(binaryop, identity, name) + if (rv := Monoid._find(name)) is not None: + return rv + return Monoid.register_new(name, binaryop, identity) + + +class Monoid(OpBase): + """Takes two inputs and returns one output, all of the same data type. + + Built-in and registered Monoids are located in the ``graphblas.monoid`` namespace + as well as in the ``graphblas.ops`` combined namespace. + """ + + __slots__ = "_binaryop", "_identity", "_is_idempotent" + is_commutative = True + is_positional = False + _custom_dtype = None + _module = monoid + _modname = "monoid" + _typed_class = TypedBuiltinMonoid + _parse_config = { + "trim_from_front": 4, + "delete_exact": "MONOID", + "num_underscores": 1, + "re_exprs": [ + re.compile( + "^GrB_(MIN|MAX|PLUS|TIMES|LOR|LAND|LXOR|LXNOR)_MONOID" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(ANY)_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)_MONOID$" + ), + re.compile("^GxB_(PLUS|TIMES|ANY)_(FC32|FC64)_MONOID$"), + re.compile("^GxB_(EQ|ANY)_BOOL_MONOID$"), + re.compile("^GxB_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)_MONOID$"), + ], + } + + @classmethod + def _build(cls, name, binaryop, identity, *, is_idempotent=False, anonymous=False): + if type(binaryop) is not BinaryOp: + raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") + if name is None: + name = binaryop.name + new_type_obj = cls( + name, binaryop, identity, is_idempotent=is_idempotent, anonymous=anonymous + ) + if not binaryop._is_udt: + if not isinstance(identity, Mapping): + identities = dict.fromkeys(binaryop.types, identity) + explicit_identities = False + else: + identities = {lookup_dtype(key): val for key, val in identity.items()} + explicit_identities = True + for type_, ident in identities.items(): + ret_type = binaryop[type_].return_type + # If there is a domain mismatch, then DomainMismatch will be raised + # below if identities were explicitly given. + if type_ != ret_type and not explicit_identities: + continue + new_monoid = ffi_new("GrB_Monoid*") + func = libget(f"GrB_Monoid_new_{type_.name}") + zcast = ffi.cast(type_.c_type, ident) + check_status_carg( + func(new_monoid, binaryop[type_].gb_obj, zcast), "Monoid", new_monoid[0] + ) + op = TypedUserMonoid( + new_type_obj, + name, + type_, + ret_type, + new_monoid[0], + binaryop[type_], + ident, + ) + new_type_obj._add(op) + return new_type_obj + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: + dtype2 = dtype + elif dtype != dtype2: + raise TypeError( + "Monoid inputs must be the same dtype (got {dtype} and {dtype2}); " + "unable to coerce when using UDTs." + ) + if dtype in self._udt_types: + return self._udt_ops[dtype] + binaryop = self.binaryop._compile_udt(dtype, dtype2) + from ..scalar import Scalar + + ret_type = binaryop.return_type + identity = Scalar.from_value(self._identity, dtype=ret_type, is_cscalar=True) + new_monoid = ffi_new("GrB_Monoid*") + status = lib.GrB_Monoid_new_UDT(new_monoid, binaryop.gb_obj, identity.gb_obj) + check_status_carg(status, "Monoid", new_monoid[0]) + op = TypedUserMonoid( + new_monoid, + self.name, + dtype, + ret_type, + new_monoid[0], + binaryop, + identity, + ) + self._udt_types[dtype] = dtype + self._udt_ops[dtype] = op + return op + + @classmethod + def register_anonymous(cls, binaryop, identity, name=None, *, is_idempotent=False): + """Register a Monoid without registering it in the ``graphblas.monoid`` namespace. + + Because it is not registered in the namespace, the name is optional. + + Parameters + ---------- + binaryop : BinaryOp + Builtin or registered binary operator + identity : + Identity value of the monoid + name : str, optional + Name associated with the monoid + is_idempotent : bool, default False + Does ``op(x, x) == x`` for any x? + + Returns + ------- + Function handle + """ + if type(binaryop) is ParameterizedBinaryOp: + return ParameterizedMonoid( + name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True + ) + return cls._build(name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True) + + @classmethod + def register_new(cls, name, binaryop, identity, *, is_idempotent=False, lazy=False): + """Register a Monoid. The name will be used to identify the Monoid in the + ``graphblas.monoid`` namespace. + + >>> gb.core.operator.Monoid.register_new("max_zero", gb.binary.max_zero, 0) + >>> dir(gb.monoid) + [..., 'max_zero', ...] + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "binaryop": binaryop, "identity": identity}, + ) + elif type(binaryop) is ParameterizedBinaryOp: + monoid = ParameterizedMonoid(name, binaryop, identity, is_idempotent=is_idempotent) + setattr(module, funcname, monoid) + else: + monoid = cls._build(name, binaryop, identity, is_idempotent=is_idempotent) + setattr(module, funcname, monoid) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, monoid) + if not cls._initialized: # pragma: no cover + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return monoid + + def __init__(self, name, binaryop=None, identity=None, *, is_idempotent=False, anonymous=False): + super().__init__(name, anonymous=anonymous) + self._binaryop = binaryop + self._identity = identity + self._is_idempotent = is_idempotent + if binaryop is not None: + binaryop._monoid = self + if binaryop._is_udt: + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserMonoid} + + def __reduce__(self): + if self._anonymous: + return (self.register_anonymous, (self._binaryop, self._identity, self.name)) + if (name := f"monoid.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self._binaryop, self._identity)) + + @property + def binaryop(self): + """The :class:`BinaryOp` associated with the Monoid.""" + if self._binaryop is not None: + return self._binaryop + # Must be builtin + return getattr(binary, self.name) + + @property + def identities(self): + """The per-dtype identity values for the Monoid.""" + return {dtype: val.identity for dtype, val in self._typed_ops.items()} + + @property + def is_idempotent(self): + """True if ``monoid(x, x) == x`` for any x.""" + return self._is_idempotent + + @property + def _is_udt(self): + return self._binaryop is not None and self._binaryop._is_udt + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + super()._initialize() + lor = monoid.lor._typed_ops[BOOL] + land = monoid.land._typed_ops[BOOL] + for cur_op, typed_op in [ + (monoid.max, lor), + (monoid.min, land), + # (monoid.plus, lor), # two choices: lor, or plus[int] + (monoid.times, land), + ]: + if BOOL not in cur_op.types: # pragma: no branch (safety) + cur_op.types[BOOL] = BOOL + cur_op.coercions[BOOL] = BOOL + cur_op._typed_ops[BOOL] = typed_op + + for cur_op in [monoid.lor, monoid.land, monoid.lxnor, monoid.lxor]: + bool_op = cur_op._typed_ops[BOOL] + for dtype in [ + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + ]: + if dtype in cur_op.types: # pragma: no cover (safety) + continue + cur_op.types[dtype] = BOOL + cur_op.coercions[dtype] = BOOL + cur_op._typed_ops[dtype] = bool_op + + # Builtin monoids that are idempotent; i.e., `op(x, x) == x` for any x + for name in ["any", "band", "bor", "land", "lor", "max", "min"]: + getattr(monoid, name)._is_idempotent = True + # Allow some functions to work on UDTs + any_ = monoid.any + any_._identity = 0 + any_._udt_types = {} + any_._udt_ops = {} + cls._initialized = True + + commutes_to = TypedBuiltinMonoid.commutes_to + __call__ = TypedBuiltinMonoid.__call__ diff --git a/graphblas/core/operator/select.py b/graphblas/core/operator/select.py new file mode 100644 index 000000000..844565f3a --- /dev/null +++ b/graphblas/core/operator/select.py @@ -0,0 +1,187 @@ +import inspect + +from ... import _STANDARD_OPERATOR_NAMES, select +from ...dtypes import BOOL +from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _deserialize_parameterized +from .indexunary import IndexUnaryOp + + +class TypedBuiltinSelectOp(TypedOpBase): + __slots__ = () + opclass = "SelectOp" + + def __call__(self, val, thunk=None): + if thunk is None: + thunk = False # most basic form of 0 when unifying dtypes + return _call_op(self, val, thunk=thunk) + + +class TypedUserSelectOp(TypedOpBase): + __slots__ = () + opclass = "SelectOp" + + def __init__(self, parent, name, type_, return_type, gb_obj): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + __call__ = TypedBuiltinSelectOp.__call__ + + +class ParameterizedSelectOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + sel = self.func(*args, **kwargs) + sel._parameterized_info = (self, args, kwargs) + return SelectOp.register_anonymous(sel, self.name, is_udt=self._is_udt) + + def __reduce__(self): + name = f"select.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + if anonymous: + return SelectOp.register_anonymous(func, name, parameterized=True) + if (rv := SelectOp._find(name)) is not None: + return rv + return SelectOp.register_new(name, func, parameterized=True) + + +class SelectOp(OpBase): + """Identical to an :class:`IndexUnaryOp `, + but must have a Boolean return type. + + A SelectOp is used exclusively to select a subset of values from a collection where + the function returns True. + + Built-in and registered SelectOps are located in the ``graphblas.select`` namespace. + """ + + __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" + _module = select + _modname = "select" + _custom_dtype = None + _typed_class = TypedBuiltinSelectOp + _typed_user_class = TypedUserSelectOp + + @classmethod + def _from_indexunary(cls, iop): + obj = cls( + iop.name, + iop.orig_func, + anonymous=iop._anonymous, + is_positional=iop.is_positional, + is_udt=iop._is_udt, + numba_func=iop._numba_func, + ) + if not all(x == BOOL for x in iop.types.values()): + raise ValueError("SelectOp must have BOOL return type") + for type_, t in iop._typed_ops.items(): + if iop.orig_func is not None: + op = cls._typed_user_class( + obj, + iop.name, + t.type, + t.return_type, + t.gb_obj, + ) + else: + op = cls._typed_class( + obj, + iop.name, + t.type, + t.return_type, + t.gb_obj, + t.gb_name, + ) + # type is not always equal to t.type, so can't use op._add + # but otherwise perform the same logic + obj._typed_ops[type_] = op + obj.types[type_] = op.return_type + return obj + + @classmethod + def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): + """Register a SelectOp without registering it in the ``graphblas.select`` namespace. + + Because it is not registered in the namespace, the name is optional. + """ + if parameterized: + return ParameterizedSelectOp(name, func, anonymous=True, is_udt=is_udt) + iop = IndexUnaryOp._build(name, func, anonymous=True, is_udt=is_udt) + return SelectOp._from_indexunary(iop) + + @classmethod + def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): + """Register a SelectOp. The name will be used to identify the SelectOp in the + ``graphblas.select`` namespace. + + The function will also be registered as a IndexUnaryOp with the same name. + + >>> gb.select.register_new("upper_left_triangle", lambda x, i, j, thunk: i + j <= thunk) + >>> dir(gb.select) + [..., 'upper_left_triangle', ...] + """ + iop = IndexUnaryOp.register_new( + name, func, parameterized=parameterized, is_udt=is_udt, lazy=lazy + ) + if not all(x == BOOL for x in iop.types.values()): + raise ValueError("SelectOp must have BOOL return type") + if lazy: + return getattr(select, iop.name) + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + # IndexUnaryOp adds it boolean-returning objects to SelectOp + IndexUnaryOp._initialize() + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self.orig_func = func + self._numba_func = numba_func + self.is_positional = is_positional + self._is_udt = is_udt + if is_udt: + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"select.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinSelectOp.__call__ diff --git a/graphblas/core/operator/semiring.py b/graphblas/core/operator/semiring.py new file mode 100644 index 000000000..06450e007 --- /dev/null +++ b/graphblas/core/operator/semiring.py @@ -0,0 +1,545 @@ +import itertools +import re + +from ... import _STANDARD_OPERATOR_NAMES, binary, monoid, op, semiring +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + _supports_complex, +) +from ...exceptions import check_status_carg +from .. import ffi, lib +from .base import _SS_OPERATORS, OpBase, ParameterizedUdf, TypedOpBase, _call_op, _hasop +from .binary import BinaryOp, ParameterizedBinaryOp +from .monoid import Monoid, ParameterizedMonoid + +if _supports_complex: + from ...dtypes import FC32, FC64 + +ffi_new = ffi.new + + +class TypedBuiltinSemiring(TypedOpBase): + __slots__ = () + opclass = "Semiring" + + def __call__(self, left, right=None): + if right is not None: + raise TypeError( + f"Bad types when calling {self!r}. Got types: {type(left)}, {type(right)}.\n" + f"Expected an infix expression, such as: {self!r}(A @ B)" + ) + return _call_op(self, left) + + @property + def binaryop(self): + name = self.name.split("_", 1)[1] + if name in _SS_OPERATORS: + binop = binary._deprecated[name] + else: + binop = getattr(binary, name) + return binop[self.type] + + @property + def monoid(self): + monoid_name, binary_name = self.name.split("_", 1) + if binary_name in _SS_OPERATORS: + binop = binary._deprecated[binary_name] + else: + binop = getattr(binary, binary_name) + binop = binop[self.type] + val = getattr(monoid, monoid_name) + return val[binop.return_type] + + @property + def commutes_to(self): + binop = self.binaryop + commutes_to = binop._semiring_commutes_to or binop.commutes_to + if commutes_to is None: + return + if commutes_to is binop: + return self + from .utils import get_semiring + + return get_semiring(self.monoid, commutes_to) + + @property + def is_commutative(self): + return self.binaryop.is_commutative + + @property + def type2(self): + return self.type if self._type2 is None else self._type2 + + +class TypedUserSemiring(TypedOpBase): + __slots__ = "monoid", "binaryop" + opclass = "Semiring" + + def __init__(self, parent, name, type_, return_type, gb_obj, monoid, binaryop, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) + self.monoid = monoid + self.binaryop = binaryop + + commutes_to = TypedBuiltinSemiring.commutes_to + is_commutative = TypedBuiltinSemiring.is_commutative + type2 = TypedBuiltinSemiring.type2 + __call__ = TypedBuiltinSemiring.__call__ + + +class ParameterizedSemiring(ParameterizedUdf): + __slots__ = "monoid", "binaryop", "__signature__" + + def __init__(self, name, monoid, binaryop, *, anonymous=False): + if type(monoid) not in {ParameterizedMonoid, Monoid}: + raise TypeError("monoid must be of type Monoid or ParameterizedMonoid") + if type(binaryop) is ParameterizedBinaryOp: + self.__signature__ = binaryop.__signature__ + if type(monoid) is ParameterizedMonoid and monoid.__signature__ != self.__signature__: + raise ValueError( + "Signatures of monoid and binaryop passed to " + f"{type(self).__name__} must be the same. Got:\n" + f" monoid{monoid.__signature__}\n" + " !=\n" + f" binaryop{self.__signature__}\n\n" + "Perhaps call monoid or binaryop with parameters before creating the semiring." + ) + elif type(binaryop) is BinaryOp: + if type(monoid) is Monoid: + raise TypeError("At least one of monoid or binaryop must be parameterized") + self.__signature__ = monoid.__signature__ + else: + raise TypeError("binaryop must be of type BinaryOp or ParameterizedBinaryOp") + self.monoid = monoid + self.binaryop = binaryop + if name is None: + name = f"{monoid.name}_{binaryop.name}" + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + monoid = self.monoid + if type(monoid) is ParameterizedMonoid: + monoid = monoid(*args, **kwargs) + binary = self.binaryop + if type(binary) is ParameterizedBinaryOp: + binary = binary(*args, **kwargs) + return Semiring.register_anonymous(monoid, binary, self.name) + + commutes_to = TypedBuiltinSemiring.commutes_to + is_commutative = TypedBuiltinSemiring.is_commutative + + def __reduce__(self): + name = f"semiring.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover + return name + return (self._deserialize, (self.name, self.monoid, self.binaryop, self._anonymous)) + + @staticmethod + def _deserialize(name, monoid, binaryop, anonymous): + if anonymous: + return Semiring.register_anonymous(monoid, binaryop, name) + if (rv := Semiring._find(name)) is not None: + return rv + return Semiring.register_new(name, monoid, binaryop) + + +class Semiring(OpBase): + """Combination of a :class:`Monoid` and a :class:`BinaryOp`. + + Semirings are most commonly used for performing matrix multiplication, + with the BinaryOp taking the place of the standard multiplication operator + and the Monoid taking the place of the standard addition operator. + + Built-in and registered Semirings are located in the ``graphblas.semiring`` namespace + as well as in the ``graphblas.ops`` combined namespace. + """ + + __slots__ = "_monoid", "_binaryop" + _module = semiring + _modname = "semiring" + _typed_class = TypedBuiltinSemiring + _parse_config = { + "trim_from_front": 4, + "delete_exact": "SEMIRING", + "num_underscores": 2, + "re_exprs": [ + re.compile( + "^GrB_(PLUS|MIN|MAX)_(PLUS|TIMES|FIRST|SECOND|MIN|MAX)_SEMIRING" + "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(MIN|MAX|PLUS|TIMES|ANY)" + "_(FIRST|SECOND|PAIR|MIN|MAX|PLUS|MINUS|RMINUS|TIMES" + "|DIV|RDIV|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR" + "|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" + "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(PLUS|TIMES|ANY)_(FIRST|SECOND|PAIR|PLUS|MINUS|TIMES|DIV|RDIV|RMINUS)" + "_(FC32|FC64)$" + ), + re.compile( + "^GxB_(BOR|BAND|BXOR|BXNOR)_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)$" + ), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)_(LOR|LAND)_SEMIRING_BOOL$"), + re.compile( + "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(EQ|NE|GT|LT|GE|LE)" + "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(FIRST|SECOND|PAIR|LOR|LAND|LXOR|EQ|GT|LT|GE|LE)_BOOL$" + ), + ], + } + + @classmethod + def _build(cls, name, monoid, binaryop, *, anonymous=False): + if type(monoid) is not Monoid: + raise TypeError(f"monoid must be a Monoid, not {type(monoid)}") + if type(binaryop) is not BinaryOp: + raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") + if name is None: + name = f"{monoid.name}_{binaryop.name}".replace(".", "_") + new_type_obj = cls(name, monoid, binaryop, anonymous=anonymous) + if binaryop._is_udt: + return new_type_obj + for binary_in, binary_func in binaryop._typed_ops.items(): + binary_out = binary_func.return_type + # Unfortunately, we can't have user-defined monoids over bools yet + # because numba can't compile correctly. + if ( + binary_out not in monoid.types + # Are all coercions bad, or just to bool? + or monoid.coercions.get(binary_out, binary_out) != binary_out + ): + continue + new_semiring = ffi_new("GrB_Semiring*") + check_status_carg( + lib.GrB_Semiring_new(new_semiring, monoid[binary_out].gb_obj, binary_func.gb_obj), + "Semiring", + new_semiring, + ) + ret_type = monoid[binary_out].return_type + op = TypedUserSemiring( + new_type_obj, + name, + binary_in, + ret_type, + new_semiring[0], + monoid[binary_out], + binary_func, + ) + new_type_obj._add(op) + return new_type_obj + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: + dtype2 = dtype + dtypes = (dtype, dtype2) + if dtypes in self._udt_types: + return self._udt_ops[dtypes] + binaryop = self.binaryop._compile_udt(dtype, dtype2) + monoid = self.monoid[binaryop.return_type] + ret_type = monoid.return_type + new_semiring = ffi_new("GrB_Semiring*") + status = lib.GrB_Semiring_new(new_semiring, monoid.gb_obj, binaryop.gb_obj) + check_status_carg(status, "Semiring", new_semiring) + op = TypedUserSemiring( + new_semiring, + self.name, + dtype, + ret_type, + new_semiring[0], + monoid, + binaryop, + dtype2=dtype2, + ) + self._udt_types[dtypes] = dtype + self._udt_ops[dtypes] = op + return op + + @classmethod + def register_anonymous(cls, monoid, binaryop, name=None): + """Register a Semiring without registering it in the ``graphblas.semiring`` namespace. + + Because it is not registered in the namespace, the name is optional. + + Parameters + ---------- + monoid : Monoid + Builtin or registered monoid + binaryop : BinaryOp + Builtin or registered binary operator + name : str, optional + Name associated with the semiring + + Returns + ------- + Function handle + """ + if type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: + return ParameterizedSemiring(name, monoid, binaryop, anonymous=True) + return cls._build(name, monoid, binaryop, anonymous=True) + + @classmethod + def register_new(cls, name, monoid, binaryop, *, lazy=False): + """Register a Semiring. The name will be used to identify the Semiring in the + ``graphblas.semiring`` namespace. + + >>> gb.core.operator.Semiring.register_new("max_max", gb.monoid.max, gb.binary.max) + >>> dir(gb.semiring) + [..., 'max_max', ...] + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "monoid": monoid, "binaryop": binaryop}, + ) + elif type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: + semiring = ParameterizedSemiring(name, monoid, binaryop) + setattr(module, funcname, semiring) + else: + semiring = cls._build(name, monoid, binaryop) + setattr(module, funcname, semiring) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, semiring) + if not cls._initialized: + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return semiring + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + super()._initialize() + # Rename div to cdiv (truncate towards 0) + div_semirings = { + attr: val + for attr, val in vars(semiring).items() + if type(val) is Semiring and attr.endswith("_div") + } + for orig_name, orig in div_semirings.items(): + name = f"{orig_name[:-3]}cdiv" + cdiv_semiring = Semiring(name) + setattr(semiring, name, cdiv_semiring) + setattr(op, name, cdiv_semiring) + delattr(semiring, orig_name) + delattr(op, orig_name) + for dtype, ret_type in orig.types.items(): + orig_semiring = orig[dtype] + new_semiring = TypedBuiltinSemiring( + cdiv_semiring, + name, + dtype, + ret_type, + orig_semiring.gb_obj, + orig_semiring.gb_name, + ) + cdiv_semiring._add(new_semiring) + # Also add truediv (always floating point) and floordiv (truncate towards -inf) + for orig_name, orig in div_semirings.items(): + cls.register_new(f"{orig_name[:-3]}truediv", orig.monoid, binary.truediv, lazy=True) + cls.register_new(f"{orig_name[:-3]}rtruediv", orig.monoid, "rtruediv", lazy=True) + cls.register_new(f"{orig_name[:-3]}floordiv", orig.monoid, "floordiv", lazy=True) + cls.register_new(f"{orig_name[:-3]}rfloordiv", orig.monoid, "rfloordiv", lazy=True) + # For aggregators + cls.register_new("plus_pow", monoid.plus, binary.pow) + cls.register_new("plus_rpow", monoid.plus, "rpow", lazy=True) + cls.register_new("plus_absfirst", monoid.plus, "absfirst", lazy=True) + cls.register_new("max_absfirst", monoid.max, "absfirst", lazy=True) + cls.register_new("plus_abssecond", monoid.plus, "abssecond", lazy=True) + cls.register_new("max_abssecond", monoid.max, "abssecond", lazy=True) + + # Update type information with sane coercion + for lname in ["any", "eq", "land", "lor", "lxnor", "lxor"]: + target_name = f"{lname}_ne" + source_name = f"{lname}_lxor" + if not _hasop(semiring, target_name): + continue + target_op = getattr(semiring, target_name) + if BOOL not in target_op.types: # pragma: no branch (safety) + source_op = getattr(semiring, source_name) + typed_op = source_op._typed_ops[BOOL] + target_op.types[BOOL] = BOOL + target_op._typed_ops[BOOL] = typed_op + target_op.coercions[dtype] = BOOL + + position_dtypes = [ + BOOL, + FP32, + FP64, + INT8, + INT16, + UINT8, + UINT16, + UINT32, + UINT64, + ] + notbool_dtypes = [ + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + ] + if _supports_complex: + position_dtypes.extend([FC32, FC64]) + notbool_dtypes.extend([FC32, FC64]) + for lnames, rnames, *types in [ + # fmt: off + ( + ("any", "max", "min", "plus", "times"), + ( + "firsti", "firsti1", "firstj", "firstj1", + "secondi", "secondi1", "secondj", "secondj1", + ), + ( + position_dtypes, + INT64, + ), + ), + ( + ("eq", "land", "lor", "lxnor", "lxor"), + ("first", "pair", "second"), + # TODO: check if FC coercion works here + ( + notbool_dtypes, + BOOL, + ), + ), + ( + ("band", "bor", "bxnor", "bxor"), + ("band", "bor", "bxnor", "bxor"), + ([INT8], UINT16), + ([INT16], UINT32), + ([INT32], UINT64), + ([INT64], UINT64), + ), + ( + ("any", "eq", "land", "lor", "lxnor", "lxor"), + ("eq", "land", "lor", "lxnor", "lxor", "ne"), + ( + ( + FP32, FP64, INT8, INT16, INT32, INT64, + UINT8, UINT16, UINT32, UINT64, + ), + BOOL, + ), + ), + # fmt: on + ]: + for left, right in itertools.product(lnames, rnames): + name = f"{left}_{right}" + if not _hasop(semiring, name): + continue + if name in _SS_OPERATORS: + cur_op = semiring._deprecated[name] + else: + cur_op = getattr(semiring, name) + for input_types, target_type in types: + typed_op = cur_op._typed_ops[target_type] + output_type = cur_op.types[target_type] + for dtype in input_types: + if dtype not in cur_op.types: + cur_op.types[dtype] = output_type + cur_op._typed_ops[dtype] = typed_op + cur_op.coercions[dtype] = target_type + + # Handle a few boolean cases + for opname, targetname in [ + ("max_first", "lor_first"), + ("max_second", "lor_second"), + ("max_land", "lor_land"), + ("max_lor", "lor_lor"), + ("max_lxor", "lor_lxor"), + ("min_first", "land_first"), + ("min_second", "land_second"), + ("min_land", "land_land"), + ("min_lor", "land_lor"), + ("min_lxor", "land_lxor"), + ]: + cur_op = getattr(semiring, opname) + target = getattr(semiring, targetname) + if BOOL in cur_op.types or BOOL not in target.types: # pragma: no cover (safety) + continue + cur_op.types[BOOL] = target.types[BOOL] + cur_op._typed_ops[BOOL] = target._typed_ops[BOOL] + cur_op.coercions[BOOL] = BOOL + cls._initialized = True + + def __init__(self, name, monoid=None, binaryop=None, *, anonymous=False): + super().__init__(name, anonymous=anonymous) + self._monoid = monoid + self._binaryop = binaryop + try: + if self.binaryop._udt_types is not None: + self._udt_types = {} # {(dtype, dtype): DataType} + self._udt_ops = {} # {(dtype, dtype): TypedUserSemiring} + except AttributeError: + # `*_div` semirings raise here, but don't need `_udt_types` + pass + + def __reduce__(self): + if self._anonymous: + return (self.register_anonymous, (self._monoid, self._binaryop, self.name)) + if (name := f"semiring.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self._monoid, self._binaryop)) + + @property + def binaryop(self): + """The :class:`BinaryOp` associated with the Semiring.""" + if self._binaryop is not None: + return self._binaryop + # Must be builtin + name = self.name.split("_")[1] + if name in _SS_OPERATORS: + return binary._deprecated[name] + return getattr(binary, name) + + @property + def monoid(self): + """The :class:`Monoid` associated with the Semiring.""" + if self._monoid is not None: + return self._monoid + # Must be builtin + return getattr(monoid, self.name.split("_")[0].split(".")[-1]) + + @property + def is_positional(self): + return self.binaryop.is_positional + + @property + def _is_udt(self): + return self._binaryop is not None and self._binaryop._is_udt + + @property + def _custom_dtype(self): + return self.binaryop._custom_dtype + + commutes_to = TypedBuiltinSemiring.commutes_to + is_commutative = TypedBuiltinSemiring.is_commutative + __call__ = TypedBuiltinSemiring.__call__ diff --git a/graphblas/core/operator/unary.py b/graphblas/core/operator/unary.py new file mode 100644 index 000000000..6b1319057 --- /dev/null +++ b/graphblas/core/operator/unary.py @@ -0,0 +1,408 @@ +import inspect +import re +from types import FunctionType + +import numba + +from ... import _STANDARD_OPERATOR_NAMES, op, unary +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + _sample_values, + _supports_complex, + lookup_dtype, +) +from ...exceptions import UdfParseError, check_status_carg +from .. import ffi, lib +from ..utils import output_type +from .base import ( + _SS_OPERATORS, + OpBase, + ParameterizedUdf, + TypedOpBase, + _deserialize_parameterized, + _get_udt_wrapper, + _hasop, +) + +if _supports_complex: + from ...dtypes import FC32, FC64 + +ffi_new = ffi.new + + +class TypedBuiltinUnaryOp(TypedOpBase): + __slots__ = () + opclass = "UnaryOp" + + def __call__(self, val): + from ..matrix import Matrix, TransposedMatrix + from ..vector import Vector + + if (typ := output_type(val)) in {Vector, Matrix, TransposedMatrix}: + return val.apply(self) + from ..scalar import Scalar, _as_scalar + + if typ is Scalar: + return val.apply(self) + try: + scalar = _as_scalar(val, is_cscalar=False) + except Exception: + pass + else: + return scalar.apply(self) + raise TypeError( + f"Bad type when calling {self!r}.\n" + " - Expected type: Scalar, Vector, Matrix, TransposedMatrix.\n" + f" - Got: {type(val)}.\n" + "Calling a UnaryOp is syntactic sugar for calling apply. " + f"For example, `A.apply({self!r})` is the same as `{self!r}(A)`." + ) + + +class TypedUserUnaryOp(TypedOpBase): + __slots__ = () + opclass = "UnaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + __call__ = TypedBuiltinUnaryOp.__call__ + + +class ParameterizedUnaryOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + unary = self.func(*args, **kwargs) + unary._parameterized_info = (self, args, kwargs) + return UnaryOp.register_anonymous(unary, self.name, is_udt=self._is_udt) + + def __reduce__(self): + name = f"unary.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + if anonymous: + return UnaryOp.register_anonymous(func, name, parameterized=True) + if (rv := UnaryOp._find(name)) is not None: + return rv + return UnaryOp.register_new(name, func, parameterized=True) + + +def _identity(x): + return x # pragma: no cover (numba) + + +def _one(x): + return 1 # pragma: no cover (numba) + + +class UnaryOp(OpBase): + """Takes one input and returns one output, possibly of a different data type. + + Built-in and registered UnaryOps are located in the ``graphblas.unary`` namespace + as well as in the ``graphblas.ops`` combined namespace. + """ + + __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" + _custom_dtype = None + _module = unary + _modname = "unary" + _typed_class = TypedBuiltinUnaryOp + _parse_config = { + "trim_from_front": 4, + "num_underscores": 1, + "re_exprs": [ + re.compile( + "^GrB_(IDENTITY|AINV|MINV|ABS|BNOT)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" + ), + re.compile( + "^GxB_(LNOT|ONE|POSITIONI1|POSITIONI|POSITIONJ1|POSITIONJ)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(SQRT|LOG|EXP|LOG2|SIN|COS|TAN|ACOS|ASIN|ATAN|SINH|COSH|TANH|ACOSH" + "|ASINH|ATANH|SIGNUM|CEIL|FLOOR|ROUND|TRUNC|EXP2|EXPM1|LOG10|LOG1P)" + "_(FP32|FP64|FC32|FC64)$" + ), + re.compile("^GxB_(LGAMMA|TGAMMA|ERF|ERFC|FREXPX|FREXPE|CBRT)_(FP32|FP64)$"), + re.compile("^GxB_(IDENTITY|AINV|MINV|ONE|CONJ)_(FC32|FC64)$"), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_LNOT$"), + re.compile("^GxB_(ISINF|ISNAN|ISFINITE)_(FP32|FP64|FC32|FC64)$"), + ], + "re_exprs_return_float": [re.compile("^GxB_(CREAL|CIMAG|CARG|ABS)_(FC32|FC64)$")], + } + _positional = {"positioni", "positioni1", "positionj", "positionj1"} + + @classmethod + def _build(cls, name, func, *, anonymous=False, is_udt=False): + if type(func) is not FunctionType: + raise TypeError(f"UDF argument must be a function, not {type(func)}") + if name is None: + name = getattr(func, "__name__", "") + success = False + unary_udf = numba.njit(func) + new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=unary_udf) + return_types = {} + nt = numba.types + if not is_udt: + for type_ in _sample_values: + sig = (type_.numba_type,) + try: + unary_udf.compile(sig) + except numba.TypingError: + continue + ret_type = lookup_dtype(unary_udf.overloads[sig].signature.return_type) + if ret_type != type_ and ( + ("INT" in ret_type.name and "INT" in type_.name) + or ("FP" in ret_type.name and "FP" in type_.name) + or ("FC" in ret_type.name and "FC" in type_.name) + or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) + ): + # Downcast `ret_type` to `type_`. + # This is what users want most of the time, but we can't make a perfect rule. + # There should be a way for users to be explicit. + ret_type = type_ + elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: + ret_type = INT8 + + # Numba is unable to handle BOOL correctly right now, but we have a workaround + # See: https://github.com/numba/numba/issues/5395 + # We're relying on coercion behaving correctly here + input_type = INT8 if type_ == BOOL else type_ + return_type = INT8 if ret_type == BOOL else ret_type + + # Build wrapper because GraphBLAS wants pointers and void return + wrapper_sig = nt.void( + nt.CPointer(return_type.numba_type), + nt.CPointer(input_type.numba_type), + ) + + if type_ == BOOL: + if ret_type == BOOL: + + def unary_wrapper(z, x): + z[0] = bool(unary_udf(bool(x[0]))) # pragma: no cover (numba) + + else: + + def unary_wrapper(z, x): + z[0] = unary_udf(bool(x[0])) # pragma: no cover (numba) + + elif ret_type == BOOL: + + def unary_wrapper(z, x): + z[0] = bool(unary_udf(x[0])) # pragma: no cover (numba) + + else: + + def unary_wrapper(z, x): + z[0] = unary_udf(x[0]) # pragma: no cover (numba) + + unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper) + new_unary = ffi_new("GrB_UnaryOp*") + check_status_carg( + lib.GrB_UnaryOp_new( + new_unary, unary_wrapper.cffi, ret_type.gb_obj, type_.gb_obj + ), + "UnaryOp", + new_unary, + ) + op = TypedUserUnaryOp(new_type_obj, name, type_, ret_type, new_unary[0]) + new_type_obj._add(op) + success = True + return_types[type_] = ret_type + if success or is_udt: + return new_type_obj + raise UdfParseError("Unable to parse function using Numba") + + def _compile_udt(self, dtype, dtype2): + if dtype in self._udt_types: + return self._udt_ops[dtype] + + numba_func = self._numba_func + sig = (dtype.numba_type,) + numba_func.compile(sig) # Should we catch and give additional error message? + ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) + + unary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype) + unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper) + new_unary = ffi_new("GrB_UnaryOp*") + check_status_carg( + lib.GrB_UnaryOp_new(new_unary, unary_wrapper.cffi, ret_type._carg, dtype._carg), + "UnaryOp", + new_unary, + ) + op = TypedUserUnaryOp(self, self.name, dtype, ret_type, new_unary[0]) + self._udt_types[dtype] = ret_type + self._udt_ops[dtype] = op + return op + + @classmethod + def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): + """Register a UnaryOp without registering it in the ``graphblas.unary`` namespace. + + Because it is not registered in the namespace, the name is optional. + """ + if parameterized: + return ParameterizedUnaryOp(name, func, anonymous=True, is_udt=is_udt) + return cls._build(name, func, anonymous=True, is_udt=is_udt) + + @classmethod + def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): + """Register a UnaryOp. The name will be used to identify the UnaryOp in the + ``graphblas.unary`` namespace. + + >>> gb.core.operator.UnaryOp.register_new("plus_one", lambda x: x + 1) + >>> dir(gb.unary) + [..., 'plus_one', ...] + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "func": func, "parameterized": parameterized}, + ) + elif parameterized: + unary_op = ParameterizedUnaryOp(name, func, is_udt=is_udt) + setattr(module, funcname, unary_op) + else: + unary_op = cls._build(name, func, is_udt=is_udt) + setattr(module, funcname, unary_op) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, unary_op) + if not cls._initialized: # pragma: no cover + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return unary_op + + @classmethod + def _initialize(cls): + if cls._initialized: + return + super()._initialize() + # Update type information with sane coercion + position_dtypes = [ + BOOL, + FP32, + FP64, + INT8, + INT16, + UINT8, + UINT16, + UINT32, + UINT64, + ] + if _supports_complex: + position_dtypes.extend([FC32, FC64]) + for names, *types in [ + # fmt: off + ( + ( + "erf", "erfc", "lgamma", "tgamma", "acos", "acosh", "asin", "asinh", + "atan", "atanh", "ceil", "cos", "cosh", "exp", "exp2", "expm1", "floor", + "log", "log10", "log1p", "log2", "round", "signum", "sin", "sinh", "sqrt", + "tan", "tanh", "trunc", "cbrt", + ), + ((BOOL, INT8, INT16, UINT8, UINT16), FP32), + ((INT32, INT64, UINT32, UINT64), FP64), + ), + ( + ("positioni", "positioni1", "positionj", "positionj1"), + ( + position_dtypes, + INT64, + ), + ), + # fmt: on + ]: + for name in names: + if name in _SS_OPERATORS: + op = unary._deprecated[name] + else: + op = getattr(unary, name) + for input_types, target_type in types: + typed_op = op._typed_ops[target_type] + output_type = op.types[target_type] + for dtype in input_types: + if dtype not in op.types: # pragma: no branch (safety) + op.types[dtype] = output_type + op._typed_ops[dtype] = typed_op + op.coercions[dtype] = target_type + # Allow some functions to work on UDTs + for unop, func in [ + (unary.identity, _identity), + (unary.one, _one), + ]: + unop.orig_func = func + unop._numba_func = numba.njit(func) + unop._udt_types = {} + unop._udt_ops = {} + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self.orig_func = func + self._numba_func = numba_func + self.is_positional = is_positional + self._is_udt = is_udt + if is_udt: + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserUnaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"unary.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinUnaryOp.__call__ diff --git a/graphblas/core/operator/utils.py b/graphblas/core/operator/utils.py new file mode 100644 index 000000000..00bc86cea --- /dev/null +++ b/graphblas/core/operator/utils.py @@ -0,0 +1,447 @@ +from types import BuiltinFunctionType, FunctionType, ModuleType + +from ... import backend, binary, config, indexunary, monoid, op, select, semiring, unary +from ...dtypes import UINT64, lookup_dtype, unify +from .base import ( + _SS_OPERATORS, + OpBase, + OpPath, + ParameterizedUdf, + TypedOpBase, + _builtin_to_op, + _hasop, + find_opclass, +) +from .binary import BinaryOp +from .indexunary import IndexUnaryOp +from .monoid import Monoid +from .select import SelectOp +from .semiring import Semiring +from .unary import UnaryOp + +# Now initialize all the things! +try: + UnaryOp._initialize() + IndexUnaryOp._initialize() + SelectOp._initialize() + BinaryOp._initialize() + Monoid._initialize() + Semiring._initialize() +except Exception: # pragma: no cover (debug) + # Exceptions here can often get ignored by Python + import traceback + + traceback.print_exc() + raise + + +def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scalar=False, kind=None): + if isinstance(op, OpBase): + # UDTs always get compiled + if op._is_udt: + return op._compile_udt(dtype, dtype2) + # Single dtype is simple lookup + if dtype2 is None: + return op[dtype] + # Handle special cases such as first and second (may have UDTs) + if op._custom_dtype is not None and (rv := op._custom_dtype(op, dtype, dtype2)) is not None: + return rv + # Generic case: try to unify the two dtypes + try: + return op[ + unify(dtype, dtype2, is_left_scalar=is_left_scalar, is_right_scalar=is_right_scalar) + ] + except (TypeError, AttributeError): + # Failure to unify implies a dtype is UDT; some builtin operators can handle UDTs + if op.is_positional: + return op[UINT64] + if op._udt_types is None: + raise + return op._compile_udt(dtype, dtype2) + if isinstance(op, ParameterizedUdf): + op = op() # Use default parameters of parameterized UDFs + return get_typed_op( + op, + dtype, + dtype2, + is_left_scalar=is_left_scalar, + is_right_scalar=is_right_scalar, + kind=kind, + ) + if isinstance(op, TypedOpBase): + return op + + from .agg import Aggregator, TypedAggregator + + if isinstance(op, Aggregator): + return op[dtype] + if isinstance(op, TypedAggregator): + return op + if isinstance(op, str): + if kind == "unary": + op = unary_from_string(op) + elif kind == "select": + op = select_from_string(op) + elif kind == "binary": + op = binary_from_string(op) + elif kind == "monoid": + op = monoid_from_string(op) + elif kind == "semiring": + op = semiring_from_string(op) + elif kind == "binary|aggregator": + try: + op = binary_from_string(op) + except ValueError: + try: + op = aggregator_from_string(op) + except ValueError: + raise ValueError( + f"Unknown binary or aggregator string: {op!r}. Example usage: '+[int]'" + ) from None + + else: + raise ValueError( + f"Unable to get op from string {op!r}. `kind=` argument must be provided as " + '"unary", "binary", "monoid", "semiring", "indexunary", "select", ' + 'or "binary|aggregator".' + ) + return get_typed_op( + op, + dtype, + dtype2, + is_left_scalar=is_left_scalar, + is_right_scalar=is_right_scalar, + kind=kind, + ) + if isinstance(op, FunctionType): + if kind == "unary": + op = UnaryOp.register_anonymous(op, is_udt=True) + return op._compile_udt(dtype, dtype2) + if kind.startswith("binary"): + op = BinaryOp.register_anonymous(op, is_udt=True) + return op._compile_udt(dtype, dtype2) + if isinstance(op, BuiltinFunctionType) and op in _builtin_to_op: + return get_typed_op( + _builtin_to_op[op], + dtype, + dtype2, + is_left_scalar=is_left_scalar, + is_right_scalar=is_right_scalar, + kind=kind, + ) + raise TypeError(f"Unable to get typed operator from object with type {type(op)}") + + +def get_semiring(monoid, binaryop, name=None): + """Get or create a Semiring object from a monoid and binaryop. + + If either are typed, then the returned semiring will also be typed. + + See Also + -------- + semiring.register_anonymous + semiring.register_new + semiring.from_string + """ + monoid, opclass = find_opclass(monoid) + switched = False + if opclass == "BinaryOp" and monoid.monoid is not None: + switched = True + monoid = monoid.monoid + elif opclass != "Monoid": + raise TypeError(f"Expected a Monoid for the monoid argument. Got type: {type(monoid)}") + binaryop, opclass = find_opclass(binaryop) + if opclass == "Monoid": + if switched: + raise TypeError( + "Got a BinaryOp for the monoid argument and a Monoid for the binaryop argument. " + "Are the arguments switched? Hint: you can do `mymonoid.binaryop` to get the " + "binaryop from a monoid." + ) + binaryop = binaryop.binaryop + elif opclass != "BinaryOp": + raise TypeError( + f"Expected a BinaryOp for the binaryop argument. Got type: {type(binaryop)}" + ) + if isinstance(monoid, Monoid): + monoid_type = None + else: + monoid_type = monoid.type + monoid = monoid.parent + if isinstance(binaryop, BinaryOp): + binary_type = None + else: + binary_type = binaryop.type + binaryop = binaryop.parent + if monoid._anonymous or binaryop._anonymous: + rv = Semiring.register_anonymous(monoid, binaryop, name=name) + else: + *monoid_prefix, monoid_name = monoid.name.rsplit(".", 1) + *binary_prefix, binary_name = binaryop.name.rsplit(".", 1) + if ( + monoid_prefix + and binary_prefix + and monoid_prefix == binary_prefix + or config.get("mapnumpy") + and ( + monoid_prefix == ["numpy"] + and not binary_prefix + or binary_prefix == ["numpy"] + and not monoid_prefix + ) + or backend == "suitesparse" + and binary_name in _SS_OPERATORS + ): + canonical_name = ( + ".".join(monoid_prefix or binary_prefix) + f".{monoid_name}_{binary_name}" + ) + else: + canonical_name = f"{monoid.name}_{binaryop.name}".replace(".", "_") + if name is None: + name = canonical_name + + module, funcname = Semiring._remove_nesting(canonical_name, strict=False) + rv = ( + getattr(module, funcname) + if funcname in module.__dict__ or funcname in module._delayed + else getattr(module, "_deprecated", {}).get(funcname) + ) + if rv is None and name != canonical_name: + module, funcname = Semiring._remove_nesting(name, strict=False) + rv = ( + getattr(module, funcname) + if funcname in module.__dict__ or funcname in module._delayed + else getattr(module, "_deprecated", {}).get(funcname) + ) + if rv is None: + rv = Semiring.register_new(canonical_name, monoid, binaryop) + elif rv.monoid is not monoid or rv.binaryop is not binaryop: # pragma: no cover + # It's not the object we expect (can this happen?) + rv = Semiring.register_anonymous(monoid, binaryop, name=name) + if name != canonical_name: + module, funcname = Semiring._remove_nesting(name, strict=False) + if not _hasop(module, funcname): # pragma: no branch (safety) + setattr(module, funcname, rv) + + if binary_type is not None: + return rv[binary_type] + if monoid_type is not None: + return rv[monoid_type] + return rv + + +unary.register_new = UnaryOp.register_new +unary.register_anonymous = UnaryOp.register_anonymous +indexunary.register_new = IndexUnaryOp.register_new +indexunary.register_anonymous = IndexUnaryOp.register_anonymous +select.register_new = SelectOp.register_new +select.register_anonymous = SelectOp.register_anonymous +binary.register_new = BinaryOp.register_new +binary.register_anonymous = BinaryOp.register_anonymous +monoid.register_new = Monoid.register_new +monoid.register_anonymous = Monoid.register_anonymous +semiring.register_new = Semiring.register_new +semiring.register_anonymous = Semiring.register_anonymous +semiring.get_semiring = get_semiring + +select._binary_to_select.update( + { + binary.eq: select.valueeq, + binary.ne: select.valuene, + binary.le: select.valuele, + binary.lt: select.valuelt, + binary.ge: select.valuege, + binary.gt: select.valuegt, + binary.iseq: select.valueeq, + binary.isne: select.valuene, + binary.isle: select.valuele, + binary.islt: select.valuelt, + binary.isge: select.valuege, + binary.isgt: select.valuegt, + } +) + +_builtin_to_op.update( + { + abs: unary.abs, + max: binary.max, + min: binary.min, + # Maybe someday: all, any, pow, sum + } +) + +_str_to_unary = { + "-": unary.ainv, + "~": unary.lnot, +} +_str_to_select = { + "<": select.valuelt, + ">": select.valuegt, + "<=": select.valuele, + ">=": select.valuege, + "!=": select.valuene, + "==": select.valueeq, + "col<=": select.colle, + "col>": select.colgt, + "row<=": select.rowle, + "row>": select.rowgt, + "index<=": select.indexle, + "index>": select.indexgt, +} +_str_to_binary = { + "<": binary.lt, + ">": binary.gt, + "<=": binary.le, + ">=": binary.ge, + "!=": binary.ne, + "==": binary.eq, + "+": binary.plus, + "-": binary.minus, + "*": binary.times, + "/": binary.truediv, + "//": "floordiv", + "%": "numpy.mod", + "**": binary.pow, + "&": binary.land, + "|": binary.lor, + "^": binary.lxor, +} +_str_to_monoid = { + "==": monoid.eq, + "+": monoid.plus, + "*": monoid.times, + "&": monoid.land, + "|": monoid.lor, + "^": monoid.lxor, +} + + +def _from_string(string, module, mapping, example): + s = string.lower().strip() + base, *dtype = s.split("[") + if len(dtype) > 1: + name = module.__name__.split(".")[-1] + raise ValueError( + f'Bad {name} string: {string!r}. Contains too many "[". Example usage: {example!r}' + ) + if dtype: + dtype = dtype[0] + if not dtype.endswith("]"): + name = module.__name__.split(".")[-1] + raise ValueError( + f'Bad {name} string: {string!r}. Datatype specification does not end with "]". ' + f"Example usage: {example!r}" + ) + dtype = lookup_dtype(dtype[:-1]) + if "]" in base: + name = module.__name__.split(".")[-1] + raise ValueError( + f'Bad {name} string: {string!r}. "]" not matched by "[". Example usage: {example!r}' + ) + if base in mapping: + op = mapping[base] + if type(op) is str: + op = mapping[base] = module.from_string(op) + elif hasattr(module, base): + op = getattr(module, base) + elif hasattr(module, "numpy") and hasattr(module.numpy, base): + op = getattr(module.numpy, base) + else: + *paths, attr = base.split(".") + op = None + cur = module + for path in paths: + cur = getattr(cur, path, None) + if not isinstance(cur, (OpPath, ModuleType)): + cur = None + break + op = getattr(cur, attr, None) + if op is None: + name = module.__name__.split(".")[-1] + raise ValueError(f"Unknown {name} string: {string!r}. Example usage: {example!r}") + if dtype: + op = op[dtype] + return op + + +def unary_from_string(string): + return _from_string(string, unary, _str_to_unary, "abs[int]") + + +def indexunary_from_string(string): + # "select" is a variant of IndexUnary, so the string abbreviations in + # _str_to_select are appropriate to reuse here + return _from_string(string, indexunary, _str_to_select, "row_index") + + +def select_from_string(string): + return _from_string(string, select, _str_to_select, "tril") + + +def binary_from_string(string): + return _from_string(string, binary, _str_to_binary, "+[int]") + + +def monoid_from_string(string): + return _from_string(string, monoid, _str_to_monoid, "+[int]") + + +def semiring_from_string(string): + split = string.split(".") + if len(split) == 1: + try: + return _from_string(string, semiring, {}, "min.+[int]") + except Exception: + pass + if len(split) != 2: + raise ValueError( + f"Bad semiring string: {string!r}. " + 'The monoid and binaryop should be separated by exactly one period, ".". ' + "Example usage: min.+[int]" + ) + cur_monoid = monoid_from_string(split[0]) + cur_binary = binary_from_string(split[1]) + return get_semiring(cur_monoid, cur_binary) + + +def op_from_string(string): + for func in [ + # Note: order matters here + unary_from_string, + binary_from_string, + monoid_from_string, + semiring_from_string, + indexunary_from_string, + select_from_string, + aggregator_from_string, + ]: + try: + return func(string) + except Exception: + pass + raise ValueError(f"Unknown op string: {string!r}. Example usage: 'abs[int]'") + + +unary.from_string = unary_from_string +indexunary.from_string = indexunary_from_string +select.from_string = select_from_string +binary.from_string = binary_from_string +monoid.from_string = monoid_from_string +semiring.from_string = semiring_from_string +op.from_string = op_from_string + +_str_to_agg = { + "+": "sum", + "*": "prod", + "&": "all", + "|": "any", +} + + +def aggregator_from_string(string): + return _from_string(string, agg, _str_to_agg, "sum[int]") + + +from ... import agg # noqa: E402 isort:skip + +agg.from_string = aggregator_from_string diff --git a/graphblas/monoid/numpy.py b/graphblas/monoid/numpy.py index 2d8d70c20..1d687443f 100644 --- a/graphblas/monoid/numpy.py +++ b/graphblas/monoid/numpy.py @@ -173,10 +173,8 @@ def __getattr__(name): if _config.get("mapnumpy") and name in _numpy_to_graphblas: globals()[name] = getattr(_monoid, _numpy_to_graphblas[name]) else: - from ..core import operator - func = getattr(_binary.numpy, name) - operator.Monoid.register_new( + _monoid.register_new( f"numpy.{name}", func, _monoid_identities[name], is_idempotent=name in _idempotent ) return globals()[name] diff --git a/graphblas/semiring/numpy.py b/graphblas/semiring/numpy.py index 64169168a..e47ac0336 100644 --- a/graphblas/semiring/numpy.py +++ b/graphblas/semiring/numpy.py @@ -136,7 +136,7 @@ def __dir__(): def __getattr__(name): - from ..core import operator + from ..core.operator import get_semiring if name in _delayed: func, kwargs = _delayed.pop(name) @@ -161,7 +161,7 @@ def __getattr__(name): binary_name = "_".join(words[i:]) if hasattr(_binary.numpy, binary_name): # pragma: no branch break - operator.get_semiring( + get_semiring( getattr(_monoid.numpy, monoid_name), getattr(_binary.numpy, binary_name), name=f"numpy.{name}", diff --git a/graphblas/tests/pickle1-vanilla.pkl b/graphblas/tests/pickle1-vanilla.pkl index 36ea20760..a494e405a 100644 Binary files a/graphblas/tests/pickle1-vanilla.pkl and b/graphblas/tests/pickle1-vanilla.pkl differ diff --git a/graphblas/tests/pickle1.pkl b/graphblas/tests/pickle1.pkl index 98a1fdf05..273b49901 100644 Binary files a/graphblas/tests/pickle1.pkl and b/graphblas/tests/pickle1.pkl differ diff --git a/graphblas/tests/pickle2-vanilla.pkl b/graphblas/tests/pickle2-vanilla.pkl index 3c6e18ba4..dd091c823 100644 Binary files a/graphblas/tests/pickle2-vanilla.pkl and b/graphblas/tests/pickle2-vanilla.pkl differ diff --git a/graphblas/tests/pickle2.pkl b/graphblas/tests/pickle2.pkl index 3c6e18ba4..dd091c823 100644 Binary files a/graphblas/tests/pickle2.pkl and b/graphblas/tests/pickle2.pkl differ diff --git a/graphblas/tests/pickle3-vanilla.pkl b/graphblas/tests/pickle3-vanilla.pkl index 29e79d7db..7f8408c95 100644 Binary files a/graphblas/tests/pickle3-vanilla.pkl and b/graphblas/tests/pickle3-vanilla.pkl differ diff --git a/graphblas/tests/pickle3.pkl b/graphblas/tests/pickle3.pkl index d04a53cb9..28b308452 100644 Binary files a/graphblas/tests/pickle3.pkl and b/graphblas/tests/pickle3.pkl differ diff --git a/graphblas/tests/test_core.py b/graphblas/tests/test_core.py index c08ca416f..71d0bd8a3 100644 --- a/graphblas/tests/test_core.py +++ b/graphblas/tests/test_core.py @@ -1,7 +1,18 @@ +import pathlib + import pytest import graphblas as gb +try: + import setuptools +except ImportError: # pragma: no cover (import) + setuptools = None +try: + import tomli +except ImportError: # pragma: no cover (import) + tomli = None + def test_import_special_attrs(): not_hidden = {x for x in dir(gb) if not x.startswith("__")} @@ -57,3 +68,22 @@ def test_version(): from packaging.version import parse assert parse(gb.__version__) > parse("2022.11.0") + + +@pytest.mark.skipif("not setuptools or not tomli or not gb.__file__") +def test_packages(): + """Ensure all packages are declared in pyproject.toml.""" + # Currently assume s`pyproject.toml` is at the same level as `graphblas` folder. + # This probably isn't always True, and we can probably do a better job of finding it. + path = pathlib.Path(gb.__file__).parent + pkgs = [f"graphblas.{x}" for x in setuptools.find_packages(str(path))] + pkgs.append("graphblas") + pkgs.sort() + pyproject = path.parent / "pyproject.toml" + if not pyproject.exists(): + pytest.skip("Did not find pyproject.toml") + with pyproject.open("rb") as f: + pkgs2 = sorted(tomli.load(f)["tool"]["setuptools"]["packages"]) + assert ( + pkgs == pkgs2 + ), "If there are extra items on the left, add them to pyproject.toml:tool.setuptools.packages" diff --git a/graphblas/tests/test_op.py b/graphblas/tests/test_op.py index e32606290..3a80dbe52 100644 --- a/graphblas/tests/test_op.py +++ b/graphblas/tests/test_op.py @@ -1332,6 +1332,8 @@ def test_deprecated(): gb.op.secondj with pytest.warns(DeprecationWarning, match="please use"): gb.agg.argmin + with pytest.warns(DeprecationWarning, match="please use"): + import graphblas.core.agg # noqa: F401 def test_is_idempotent(): diff --git a/graphblas/unary/numpy.py b/graphblas/unary/numpy.py index 06086569d..836da2024 100644 --- a/graphblas/unary/numpy.py +++ b/graphblas/unary/numpy.py @@ -133,19 +133,17 @@ def __getattr__(name): if _config.get("mapnumpy") and name in _numpy_to_graphblas: globals()[name] = getattr(_unary, _numpy_to_graphblas[name]) else: - from ..core import operator - numpy_func = getattr(_np, name) def func(x): # pragma: no cover (numba) return numpy_func(x) - operator.UnaryOp.register_new(f"numpy.{name}", func) + _unary.register_new(f"numpy.{name}", func) if name == "reciprocal": # numba doesn't match numpy here def reciprocal(x): # pragma: no cover (numba) return 1 if x else 0 - op = operator.UnaryOp.register_anonymous(reciprocal) + op = _unary.register_anonymous(reciprocal) globals()[name]._add(op["BOOL"]) return globals()[name] diff --git a/pyproject.toml b/pyproject.toml index c035b53db..47cf1e67f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ test = [ "packaging", "pandas >=1.2", "scipy >=1.8", + "tomli", ] complete = [ "pandas >=1.2", @@ -100,6 +101,7 @@ complete = [ "matplotlib >=3.5", "pytest", "packaging", + "tomli", ] [tool.setuptools] @@ -112,6 +114,7 @@ packages = [ "graphblas.agg", "graphblas.binary", "graphblas.core", + "graphblas.core.operator", "graphblas.core.ss", "graphblas.indexunary", "graphblas.monoid", @@ -314,12 +317,13 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"graphblas/core/operator.py" = ["S102"] # exec is used for UDF +"graphblas/core/agg.py" = ["F401", "F403"] # Deprecated +"graphblas/core/operator/base.py" = ["S102"] # exec is used for UDF "graphblas/core/ss/matrix.py" = ["NPY002"] # numba doesn't support rng generator yet "graphblas/core/ss/vector.py" = ["NPY002"] # numba doesn't support rng generator yet "graphblas/ss/_core.py" = ["N999"] # We want _core.py to be underscopre -# Allow assert, pickle, RNG, print, no docstring, and yoda in tests -"graphblas/tests/*py" = ["S101", "S301", "S311", "T201", "D103", "D100", "SIM300"] +# Allow useless expressions, assert, pickle, RNG, print, no docstring, and yoda in tests +"graphblas/tests/*py" = ["B018", "S101", "S301", "S311", "T201", "D103", "D100", "SIM300"] "graphblas/tests/test_formatting.py" = ["E501"] # Allow long lines "graphblas/**/__init__.py" = ["F401"] # Allow unused imports (w/o defining `__all__`) "scripts/*.py" = ["INP001"] # Not a package diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index d08ad6476..cdd4adf16 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -12,5 +12,5 @@ conda search 'sparse[channel=conda-forge]>=0.14.0' conda search 'fast_matrix_market[channel=conda-forge]>=1.4.5' conda search 'numba[channel=conda-forge]>=0.56.4' conda search 'pyyaml[channel=conda-forge]>=6.0' -conda search 'flake8-bugbear[channel=conda-forge]>=23.3.12' +conda search 'flake8-bugbear[channel=conda-forge]>=23.3.23' conda search 'flake8-simplify[channel=conda-forge]>=0.19.3' diff --git a/scripts/create_pickle.py b/scripts/create_pickle.py index 9ee672c41..10fe58630 100755 --- a/scripts/create_pickle.py +++ b/scripts/create_pickle.py @@ -6,7 +6,7 @@ """ import argparse import pickle -from pathlib import PurePath +from pathlib import Path import graphblas as gb from graphblas.tests.test_pickle import * @@ -158,7 +158,7 @@ def pickle3(filepath): extra = "-vanilla" else: extra = "" - path = PurePath(gb.tests.__file__).parent + path = Path(gb.tests.__file__).parent pickle1(path / f"pickle1{extra}.pkl") pickle2(path / f"pickle2{extra}.pkl") pickle3(path / f"pickle3{extra}.pkl") diff --git a/scripts/test_imports.sh b/scripts/test_imports.sh index c38e41d3e..cc989ef06 100755 --- a/scripts/test_imports.sh +++ b/scripts/test_imports.sh @@ -3,7 +3,7 @@ # Make sure imports work. Also, this is a good way to measure import performance. if ! python -c "from graphblas import * ; Matrix" ; then exit 1 ; fi if ! python -c "from graphblas import agg" ; then exit 1 ; fi -if ! python -c "from graphblas.core import agg" ; then exit 1 ; fi +if ! python -c "from graphblas.core.operator import agg" ; then exit 1 ; fi if ! python -c "from graphblas.agg import count" ; then exit 1 ; fi if ! python -c "from graphblas.binary import plus" ; then exit 1 ; fi if ! python -c "from graphblas.indexunary import tril" ; then exit 1 ; fi @@ -20,7 +20,7 @@ if ! (for attr in Matrix Scalar Vector Recorder agg binary dtypes exceptions \ fi done ) ; then exit 1 ; fi -if ! (for attr in agg base descriptor expr formatting ffi infix lib mask \ +if ! (for attr in base descriptor expr formatting ffi infix lib mask \ matrix operator scalar vector recorder automethods infixmethods slice ss do echo python -c \"from graphblas.core import $attr\" if ! python -c "from graphblas.core import $attr" @@ -44,7 +44,7 @@ if ! (for attr in agg binary binary.numpy dtypes exceptions io monoid monoid.num fi done ) ; then exit 1 ; fi -if ! (for attr in agg base descriptor expr formatting infix mask matrix \ +if ! (for attr in base descriptor expr formatting infix mask matrix \ operator scalar vector recorder automethods infixmethods slice ss do echo python -c \"import graphblas.core.$attr\" if ! python -c "import graphblas.core.$attr" @@ -60,3 +60,10 @@ if ! python -c "from graphblas import op ; op.plus" ; then exit 1 ; fi if ! python -c "from graphblas import select ; select.tril" ; then exit 1 ; fi if ! python -c "from graphblas import semiring ; semiring.plus_times" ; then exit 1 ; fi if ! python -c "from graphblas import unary ; unary.exp" ; then exit 1 ; fi +if ! (for attr in agg unary binary monoid semiring select indexunary base utils + do echo python -c \"import graphblas.core.operator.$attr\" + if ! python -c "import graphblas.core.operator.$attr" + then exit 1 + fi + done +) ; then exit 1 ; fi