diff --git a/.github/workflows/test_and_build.yml b/.github/workflows/test_and_build.yml index 209060521..d93b4c25c 100644 --- a/.github/workflows/test_and_build.yml +++ b/.github/workflows/test_and_build.yml @@ -131,9 +131,9 @@ jobs: source upstream weights: | - 1000000 - 1000000 - 1000000 + 1 + 1 + 1 1 - name: Setup mamba uses: conda-incubator/setup-miniconda@v2 @@ -175,22 +175,22 @@ jobs: npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", ""]))') spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", ""]))') pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", "=2.0", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))') elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.9') }} == true ]]; then npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", ""]))') spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", "=1.11", ""]))') pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", "=2.0", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))') elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.10') }} == true ]]; then npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", ""]))') spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", "=1.11", ""]))') pdver=$(python -c 'import random ; print(random.choice(["=1.3", "=1.4", "=1.5", "=2.0", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))') else # Python 3.11 npver=$(python -c 'import random ; print(random.choice(["=1.23", "=1.24", "=1.25", ""]))') spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=1.11", ""]))') pdver=$(python -c 'import random ; print(random.choice(["=1.5", "=2.0", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))') fi if [[ ${{ steps.sourcetype.outputs.selected }} == "source" || ${{ steps.sourcetype.outputs.selected }} == "upstream" ]]; then # TODO: there are currently issues with some numpy versions when @@ -204,13 +204,13 @@ jobs: # But, it's still useful for us to test with different versions! psg="" if [[ ${{ steps.sourcetype.outputs.selected}} == "conda-forge" ]] ; then - psgver=$(python -c 'import random ; print(random.choice(["=7.4.0", "=7.4.1", "=7.4.2", "=7.4.3.0", "=7.4.3.1", "=7.4.3.2"]))') + psgver=$(python -c 'import random ; print(random.choice(["=7.4.0", "=7.4.1", "=7.4.2", "=7.4.3.0", "=7.4.3.1", "=7.4.3.2", "=8.0.2.1", ""]))') psg=python-suitesparse-graphblas${psgver} elif [[ ${{ steps.sourcetype.outputs.selected}} == "wheel" ]] ; then - psgver=$(python -c 'import random ; print(random.choice(["==7.4.3.2"]))') + psgver=$(python -c 'import random ; print(random.choice(["==7.4.3.2", "==8.0.2.1", ""]))') elif [[ ${{ steps.sourcetype.outputs.selected}} == "source" ]] ; then # These should be exact versions - psgver=$(python -c 'import random ; print(random.choice(["==7.4.0.0", "==7.4.1.0", "==7.4.2.0", "==7.4.3.0", "==7.4.3.1", "==7.4.3.2"]))') + psgver=$(python -c 'import random ; print(random.choice(["==7.4.0.0", "==7.4.1.0", "==7.4.2.0", "==7.4.3.0", "==7.4.3.1", "==7.4.3.2", "==8.0.2.1", ""]))') else psgver="" fi @@ -260,17 +260,18 @@ jobs: numba=numba${numbaver} sparse=sparse${sparsever} fi - echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psgver${psgver}" + echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psg${psgver}" set -x # echo on - $(command -v mamba || command -v conda) install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig tomli \ + $(command -v mamba || command -v conda) install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig tomli c-compiler make \ pyyaml${yamlver} ${sparse} pandas${pdver} scipy${spver} numpy${npver} ${awkward} \ networkx${nxver} ${numba} ${fmm} ${psg} \ ${{ matrix.slowtask == 'pytest_bizarro' && 'black' || '' }} \ ${{ matrix.slowtask == 'notebooks' && 'matplotlib nbconvert jupyter "ipython>=7"' || '' }} \ ${{ steps.sourcetype.outputs.selected == 'upstream' && 'cython' || '' }} \ - ${{ steps.sourcetype.outputs.selected != 'wheel' && '"graphblas=7.4"' || '' }} \ - ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'pypy' || '' }} + ${{ steps.sourcetype.outputs.selected != 'wheel' && '"graphblas>=7.4"' || '' }} \ + ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'pypy' || '' }} \ + ${{ matrix.os == 'windows-latest' && 'cmake' || 'm4' }} - name: Build extension module run: | if [[ ${{ steps.sourcetype.outputs.selected }} == "wheel" ]]; then @@ -291,6 +292,12 @@ jobs: pip install --no-deps git+https://github.com/GraphBLAS/python-suitesparse-graphblas.git@main#egg=suitesparse-graphblas fi pip install --no-deps -e . + - name: python-suitesparse-graphblas tests + run: | + # Don't use our conftest.py ; allow `test_print_jit_config` to fail if it doesn't exist + (cd .. + pytest --pyargs suitesparse_graphblas -s -k test_print_jit_config || true + pytest -v --pyargs suitesparse_graphblas) - name: Unit tests run: | A=${{ needs.rngs.outputs.mapnumpy == 'A' || '' }} ; B=${{ needs.rngs.outputs.mapnumpy == 'B' || '' }} @@ -318,7 +325,6 @@ jobs: if [[ $H && $normal ]] ; then if [[ $macos ]] ; then echo " $vanilla" ; elif [[ $windows ]] ; then echo " $suitesparse" ; fi ; fi)$( \ if [[ $H && $bizarro ]] ; then if [[ $macos ]] ; then echo " $suitesparse" ; elif [[ $windows ]] ; then echo " $vanilla" ; fi ; fi) echo ${args} - (cd .. && pytest -v --pyargs suitesparse_graphblas) # Don't use our conftest.py set -x # echo on coverage run -m pytest --color=yes --randomly -v ${args} \ ${{ matrix.slowtask == 'pytest_normal' && '--runslow' || '' }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f0ca307e8..726538e16 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: - id: isort # Let's keep `pyupgrade` even though `ruff --fix` probably does most of it - repo: https://github.com/asottile/pyupgrade - rev: v3.7.0 + rev: v3.8.0 hooks: - id: pyupgrade args: [--py38-plus] @@ -66,7 +66,7 @@ repos: - id: black - id: black-jupyter - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.275 + rev: v0.0.277 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -94,7 +94,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.275 + rev: v0.0.277 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/docs/env.yml b/docs/env.yml index 3636cfa2d..c0c4c8999 100644 --- a/docs/env.yml +++ b/docs/env.yml @@ -8,7 +8,7 @@ dependencies: # python-graphblas dependencies - donfig - numba - - python-suitesparse-graphblas>=7.4.0.0,<8 + - python-suitesparse-graphblas>=7.4.0.0 - pyyaml # extra dependencies - matplotlib diff --git a/graphblas/binary/ss.py b/graphblas/binary/ss.py index 97852fc12..0c294e322 100644 --- a/graphblas/binary/ss.py +++ b/graphblas/binary/ss.py @@ -1,4 +1,5 @@ from ..core import operator +from ..core.ss.binary import register_new # noqa: F401 _delayed = {} diff --git a/graphblas/core/dtypes.py b/graphblas/core/dtypes.py index 345c1be81..d7a83c99b 100644 --- a/graphblas/core/dtypes.py +++ b/graphblas/core/dtypes.py @@ -22,7 +22,7 @@ def __init__(self, name, gb_obj, gb_name, c_type, numba_type, np_type): self.gb_name = gb_name self.c_type = c_type self.numba_type = numba_type - self.np_type = np.dtype(np_type) + self.np_type = np.dtype(np_type) if np_type is not None else None def __repr__(self): return self.name diff --git a/graphblas/core/ss/__init__.py b/graphblas/core/ss/__init__.py index e69de29bb..c2e83ddcc 100644 --- a/graphblas/core/ss/__init__.py +++ b/graphblas/core/ss/__init__.py @@ -0,0 +1,3 @@ +import suitesparse_graphblas as _ssgb + +_IS_SSGB7 = _ssgb.__version__.split(".", 1)[0] == "7" diff --git a/graphblas/core/ss/binary.py b/graphblas/core/ss/binary.py new file mode 100644 index 000000000..898257fac --- /dev/null +++ b/graphblas/core/ss/binary.py @@ -0,0 +1,72 @@ +from ... import backend +from ...dtypes import lookup_dtype +from ...exceptions import check_status_carg +from .. import NULL, ffi, lib +from ..operator.base import TypedOpBase +from ..operator.binary import BinaryOp, TypedUserBinaryOp +from . import _IS_SSGB7 + +ffi_new = ffi.new + + +class TypedJitBinaryOp(TypedOpBase): + __slots__ = "_monoid", "_jit_c_definition" + opclass = "BinaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2) + self._monoid = None + self._jit_c_definition = jit_c_definition + + @property + def jit_c_definition(self): + return self._jit_c_definition + + monoid = TypedUserBinaryOp.monoid + commutes_to = TypedUserBinaryOp.commutes_to + _semiring_commutes_to = TypedUserBinaryOp._semiring_commutes_to + is_commutative = TypedUserBinaryOp.is_commutative + type2 = TypedUserBinaryOp.type2 + __call__ = TypedUserBinaryOp.__call__ + + +def register_new(name, jit_c_definition, left_type, right_type, ret_type): + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.binary.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + left_type = lookup_dtype(left_type) + right_type = lookup_dtype(right_type) + ret_type = lookup_dtype(ret_type) + name = name if name.startswith("ss.") else f"ss.{name}" + module, funcname = BinaryOp._remove_nesting(name) + + rv = BinaryOp(name) + gb_obj = ffi_new("GrB_BinaryOp*") + check_status_carg( + lib.GxB_BinaryOp_new( + gb_obj, + NULL, + ret_type._carg, + left_type._carg, + right_type._carg, + ffi_new("char[]", funcname.encode()), + ffi_new("char[]", jit_c_definition.encode()), + ), + "BinaryOp", + gb_obj[0], + ) + op = TypedJitBinaryOp( + rv, funcname, left_type, ret_type, gb_obj[0], jit_c_definition, dtype2=right_type + ) + rv._add(op) + setattr(module, funcname, rv) + return rv diff --git a/graphblas/core/ss/config.py b/graphblas/core/ss/config.py index 89536479d..433716bb3 100644 --- a/graphblas/core/ss/config.py +++ b/graphblas/core/ss/config.py @@ -65,7 +65,7 @@ def __getitem__(self, key): raise KeyError(key) key_obj, ctype = self._options[key] is_bool = ctype == "bool" - if is_context := (key in self._context_keys): # pragma: no cover (suitesparse 8) + if is_context := (key in self._context_keys): get_function_base = self._context_get_function else: get_function_base = self._get_function @@ -76,14 +76,14 @@ def __getitem__(self, key): get_function_name = f"{get_function_base}_INT64" elif ctype.startswith("double"): get_function_name = f"{get_function_base}_FP64" - elif ctype.startswith("char"): # pragma: no cover (suitesparse 8) + elif ctype.startswith("char"): get_function_name = f"{get_function_base}_CHAR" else: # pragma: no cover (sanity) raise ValueError(ctype) get_function = getattr(lib, get_function_name) is_array = "[" in ctype val_ptr = ffi.new(ctype if is_array else f"{ctype}*") - if is_context: # pragma: no cover (suitesparse 8) + if is_context: info = get_function(self._context._carg, key_obj, val_ptr) elif self._parent is None: info = get_function(key_obj, val_ptr) @@ -105,7 +105,7 @@ def __getitem__(self, key): return rv if is_bool: return bool(val_ptr[0]) - if ctype.startswith("char"): # pragma: no cover (suitesparse 8) + if ctype.startswith("char"): return ffi.string(val_ptr[0]).decode() return val_ptr[0] raise _error_code_lookup[info](f"Failed to get info for {key!r}") # pragma: no cover @@ -117,7 +117,7 @@ def __setitem__(self, key, val): if key in self._read_only: raise ValueError(f"Config option {key!r} is read-only") key_obj, ctype = self._options[key] - if is_context := (key in self._context_keys): # pragma: no cover (suitesparse 8) + if is_context := (key in self._context_keys): set_function_base = self._context_set_function else: set_function_base = self._set_function @@ -130,7 +130,7 @@ def __setitem__(self, key, val): set_function_name = f"{set_function_base}_INT64_ARRAY" elif ctype.startswith("double["): set_function_name = f"{set_function_base}_FP64_ARRAY" - elif ctype.startswith("char"): # pragma: no cover (suitesparse 8) + elif ctype.startswith("char"): set_function_name = f"{set_function_base}_CHAR" else: # pragma: no cover (sanity) raise ValueError(ctype) @@ -174,11 +174,11 @@ def __setitem__(self, key, val): f"expected {size}, got {vals.size}: {val}" ) val_obj = ffi.from_buffer(ctype, vals) - elif ctype.startswith("char"): # pragma: no cover (suitesparse 8) + elif ctype.startswith("char"): val_obj = ffi.new("char[]", val.encode()) else: val_obj = ffi.cast(ctype, val) - if is_context: # pragma: no cover (suitesparse 8) + if is_context: if self._context is None: from .context import Context diff --git a/graphblas/core/ss/context.py b/graphblas/core/ss/context.py new file mode 100644 index 000000000..9b48bcaa4 --- /dev/null +++ b/graphblas/core/ss/context.py @@ -0,0 +1,146 @@ +import threading + +from ...exceptions import InvalidValue, check_status, check_status_carg +from .. import ffi, lib +from . import _IS_SSGB7 +from .config import BaseConfig + +ffi_new = ffi.new +if _IS_SSGB7: + # Context was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise ImportError( + "Context was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + + +class Context(BaseConfig): + _context_keys = {"chunk", "gpu_id", "nthreads"} + _options = { + "chunk": (lib.GxB_CONTEXT_CHUNK, "double"), + "gpu_id": (lib.GxB_CONTEXT_GPU_ID, "int"), + "nthreads": (lib.GxB_CONTEXT_NTHREADS, "int"), + } + _defaults = { + "nthreads": 0, + "chunk": 0, + "gpu_id": -1, # -1 means no GPU + } + + def __init__(self, engage=True, *, stack=True, nthreads=None, chunk=None, gpu_id=None): + super().__init__() + self.gb_obj = ffi_new("GxB_Context*") + check_status_carg(lib.GxB_Context_new(self.gb_obj), "Context", self.gb_obj[0]) + if stack: + context = threadlocal.context + self["nthreads"] = context["nthreads"] if nthreads is None else nthreads + self["chunk"] = context["chunk"] if chunk is None else chunk + self["gpu_id"] = context["gpu_id"] if gpu_id is None else gpu_id + else: + if nthreads is not None: + self["nthreads"] = nthreads + if chunk is not None: + self["chunk"] = chunk + if gpu_id is not None: + self["gpu_id"] = gpu_id + self._prev_context = None + if engage: + self.engage() + + @classmethod + def _from_obj(cls, gb_obj=None): + self = object.__new__(cls) + self.gb_obj = gb_obj + self._prev_context = None + super().__init__(self) + return self + + @property + def _carg(self): + return self.gb_obj[0] + + def dup(self, engage=True, *, nthreads=None, chunk=None, gpu_id=None): + if nthreads is None: + nthreads = self["nthreads"] + if chunk is None: + chunk = self["chunk"] + if gpu_id is None: + gpu_id = self["gpu_id"] + return type(self)(engage, stack=False, nthreads=nthreads, chunk=chunk, gpu_id=gpu_id) + + def __del__(self): + gb_obj = getattr(self, "gb_obj", None) + if gb_obj is not None and lib is not None: # pragma: no branch (safety) + try: + self.disengage() + except InvalidValue: + pass + lib.GxB_Context_free(gb_obj) + + def engage(self): + if self._prev_context is None and (context := threadlocal.context) is not self: + self._prev_context = context + check_status(lib.GxB_Context_engage(self._carg), self) + threadlocal.context = self + + def _engage(self): + """Like engage, but don't set to threadlocal.context. + + This is useful if you want to disengage when the object is deleted by going out of scope. + """ + if self._prev_context is None and (context := threadlocal.context) is not self: + self._prev_context = context + check_status(lib.GxB_Context_engage(self._carg), self) + + def disengage(self): + prev_context = self._prev_context + self._prev_context = None + if threadlocal.context is self: + if prev_context is not None: + threadlocal.context = prev_context + prev_context.engage() + else: + threadlocal.context = global_context + check_status(lib.GxB_Context_disengage(self._carg), self) + elif prev_context is not None and threadlocal.context is prev_context: + prev_context.engage() + else: + check_status(lib.GxB_Context_disengage(self._carg), self) + + def __enter__(self): + self.engage() + + def __exit__(self, exc_type, exc, exc_tb): + self.disengage() + + @property + def _context(self): + return self + + @_context.setter + def _context(self, val): + if val is not None and val is not self: + raise AttributeError("'_context' attribute is read-only") + + +class GlobalContext(Context): + @property + def _carg(self): + return self.gb_obj + + def __del__(self): # pragma: no cover (safety) + pass + + +global_context = GlobalContext._from_obj(lib.GxB_CONTEXT_WORLD) + + +class ThreadLocal(threading.local): + """Hold the active context for the current thread.""" + + context = global_context + + +threadlocal = ThreadLocal() diff --git a/graphblas/core/ss/descriptor.py b/graphblas/core/ss/descriptor.py index 43553f5ea..52c43b95d 100644 --- a/graphblas/core/ss/descriptor.py +++ b/graphblas/core/ss/descriptor.py @@ -1,6 +1,7 @@ from ...exceptions import check_status, check_status_carg from .. import ffi, lib from ..descriptor import Descriptor +from . import _IS_SSGB7 from .config import BaseConfig ffi_new = ffi.new @@ -18,6 +19,8 @@ class _DescriptorConfig(BaseConfig): _get_function = "GxB_Desc_get" _set_function = "GxB_Desc_set" + if not _IS_SSGB7: + _context_keys = {"chunk", "gpu_id", "nthreads"} _options = { # GrB "output_replace": (lib.GrB_OUTP, "GrB_Desc_Value"), @@ -26,13 +29,25 @@ class _DescriptorConfig(BaseConfig): "transpose_first": (lib.GrB_INP0, "GrB_Desc_Value"), "transpose_second": (lib.GrB_INP1, "GrB_Desc_Value"), # GxB - "nthreads": (lib.GxB_DESCRIPTOR_NTHREADS, "int"), - "chunk": (lib.GxB_DESCRIPTOR_CHUNK, "double"), "axb_method": (lib.GxB_AxB_METHOD, "GrB_Desc_Value"), "sort": (lib.GxB_SORT, "int"), "secure_import": (lib.GxB_IMPORT, "int"), - # "gpu_control": (GxB_DESCRIPTOR_GPU_CONTROL, "GrB_Desc_Value"), # Coming soon... } + if _IS_SSGB7: + _options.update( + { + "nthreads": (lib.GxB_DESCRIPTOR_NTHREADS, "int"), + "chunk": (lib.GxB_DESCRIPTOR_CHUNK, "double"), + } + ) + else: + _options.update( + { + "chunk": (lib.GxB_CONTEXT_CHUNK, "double"), + "gpu_id": (lib.GxB_CONTEXT_GPU_ID, "int"), + "nthreads": (lib.GxB_CONTEXT_NTHREADS, "int"), + } + ) _enumerations = { # GrB "output_replace": { @@ -71,10 +86,6 @@ class _DescriptorConfig(BaseConfig): False: False, True: lib.GxB_SORT, }, - # "gpu_control": { # Coming soon... - # "always": lib.GxB_GPU_ALWAYS, - # "never": lib.GxB_GPU_NEVER, - # }, } _defaults = { # GrB @@ -90,6 +101,8 @@ class _DescriptorConfig(BaseConfig): "sort": False, "secure_import": False, } + if not _IS_SSGB7: + _defaults["gpu_id"] = -1 def __init__(self): gb_obj = ffi_new("GrB_Descriptor*") diff --git a/graphblas/core/ss/dtypes.py b/graphblas/core/ss/dtypes.py new file mode 100644 index 000000000..d2eb5b416 --- /dev/null +++ b/graphblas/core/ss/dtypes.py @@ -0,0 +1,88 @@ +import numpy as np + +from ... import backend, core, dtypes +from ...exceptions import check_status_carg +from .. import _has_numba, ffi, lib +from . import _IS_SSGB7 + +ffi_new = ffi.new +if _has_numba: + import numba + from cffi import FFI + from numba.core.typing import cffi_utils + + jit_ffi = FFI() + + +def register_new(name, jit_c_definition, *, np_type=None): + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.dtypes.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + if not name.isidentifier(): + raise ValueError(f"`name` argument must be a valid Python identifier; got: {name!r}") + if name in core.dtypes._registry or hasattr(dtypes.ss, name): + raise ValueError(f"{name!r} name for dtype is unavailable") + if len(name) > lib.GxB_MAX_NAME_LEN: + raise ValueError( + f"`name` argument is too large. Max size is {lib.GxB_MAX_NAME_LEN}; got {len(name)}" + ) + if name not in jit_c_definition: + raise ValueError("`name` argument must be same name as the typedef in `jit_c_definition`") + if "struct" not in jit_c_definition: + raise ValueError("Only struct typedefs are currently allowed for JIT dtypes") + + gb_obj = ffi.new("GrB_Type*") + status = lib.GxB_Type_new( + gb_obj, 0, ffi_new("char[]", name.encode()), ffi_new("char[]", jit_c_definition.encode()) + ) + check_status_carg(status, "Type", gb_obj[0]) + + # Let SuiteSparse:GraphBLAS determine the size (we gave 0 as size above) + size_ptr = ffi_new("size_t*") + check_status_carg(lib.GxB_Type_size(size_ptr, gb_obj[0]), "Type", gb_obj[0]) + size = size_ptr[0] + + save_np_type = True + if np_type is None and _has_numba and numba.__version__[:5] > "0.56.": + jit_ffi.cdef(jit_c_definition) + numba_type = cffi_utils.map_type(jit_ffi.typeof(name), use_record_dtype=True) + np_type = numba_type.dtype + if np_type.itemsize != size: # pragma: no cover + raise RuntimeError( + "Size of compiled user-defined type does not match size of inferred numpy type: " + f"{size} != {np_type.itemsize} != {size}.\n\n" + f"UDT C definition: {jit_c_definition}\n" + f"numpy dtype: {np_type}\n\n" + "To get around this, you may pass `np_type=` keyword argument." + ) + else: + if np_type is not None: + np_type = np.dtype(np_type) + else: + # Not an ideal numpy type, but minimally useful + np_type = np.dtype((np.uint8, size)) + save_np_type = False + if _has_numba: + numba_type = numba.typeof(np_type).dtype + else: + numba_type = None + + # For now, let's use "opaque" unsigned bytes for the c type. + rv = core.dtypes.DataType(name, gb_obj, None, f"uint8_t[{size}]", numba_type, np_type) + core.dtypes._registry[gb_obj] = rv + if save_np_type or np_type not in core.dtypes._registry: + core.dtypes._registry[np_type] = rv + if numba_type is not None and (save_np_type or numba_type not in core.dtypes._registry): + core.dtypes._registry[numba_type] = rv + core.dtypes._registry[numba_type.name] = rv + setattr(dtypes.ss, name, rv) + return rv diff --git a/graphblas/core/ss/indexunary.py b/graphblas/core/ss/indexunary.py new file mode 100644 index 000000000..c0f185737 --- /dev/null +++ b/graphblas/core/ss/indexunary.py @@ -0,0 +1,77 @@ +from ... import backend +from ...dtypes import BOOL, lookup_dtype +from ...exceptions import check_status_carg +from .. import NULL, ffi, lib +from ..operator.base import TypedOpBase +from ..operator.indexunary import IndexUnaryOp, TypedUserIndexUnaryOp +from . import _IS_SSGB7 + +ffi_new = ffi.new + + +class TypedJitIndexUnaryOp(TypedOpBase): + __slots__ = "_jit_c_definition" + opclass = "IndexUnaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2) + self._jit_c_definition = jit_c_definition + + @property + def jit_c_definition(self): + return self._jit_c_definition + + __call__ = TypedUserIndexUnaryOp.__call__ + + +def register_new(name, jit_c_definition, input_type, thunk_type, ret_type): + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.indexunary.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + input_type = lookup_dtype(input_type) + thunk_type = lookup_dtype(thunk_type) + ret_type = lookup_dtype(ret_type) + name = name if name.startswith("ss.") else f"ss.{name}" + module, funcname = IndexUnaryOp._remove_nesting(name) + + rv = IndexUnaryOp(name) + gb_obj = ffi_new("GrB_IndexUnaryOp*") + check_status_carg( + lib.GxB_IndexUnaryOp_new( + gb_obj, + NULL, + ret_type._carg, + input_type._carg, + thunk_type._carg, + ffi_new("char[]", funcname.encode()), + ffi_new("char[]", jit_c_definition.encode()), + ), + "IndexUnaryOp", + gb_obj[0], + ) + op = TypedJitIndexUnaryOp( + rv, funcname, input_type, ret_type, gb_obj[0], jit_c_definition, dtype2=thunk_type + ) + rv._add(op) + if ret_type == BOOL: + from ..operator.select import SelectOp + from .select import TypedJitSelectOp + + select_module, funcname = SelectOp._remove_nesting(name, strict=False) + selectop = SelectOp(name) + op2 = TypedJitSelectOp( + rv, funcname, input_type, ret_type, gb_obj[0], jit_c_definition, dtype2=thunk_type + ) + selectop._add(op2) + setattr(select_module, funcname, selectop) + setattr(module, funcname, rv) + return rv diff --git a/graphblas/core/ss/select.py b/graphblas/core/ss/select.py new file mode 100644 index 000000000..37c352b67 --- /dev/null +++ b/graphblas/core/ss/select.py @@ -0,0 +1,45 @@ +from ... import backend, indexunary +from ...dtypes import BOOL, lookup_dtype +from .. import ffi +from ..operator.base import TypedOpBase +from ..operator.select import SelectOp, TypedUserSelectOp +from . import _IS_SSGB7 + +ffi_new = ffi.new + + +class TypedJitSelectOp(TypedOpBase): + __slots__ = "_jit_c_definition" + opclass = "SelectOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2) + self._jit_c_definition = jit_c_definition + + @property + def jit_c_definition(self): + return self._jit_c_definition + + __call__ = TypedUserSelectOp.__call__ + + +def register_new(name, jit_c_definition, input_type, thunk_type): + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.select.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + input_type = lookup_dtype(input_type) + thunk_type = lookup_dtype(thunk_type) + name = name if name.startswith("ss.") else f"ss.{name}" + # Register to both `gb.indexunary.ss` and `gb.select.ss.` + indexunary.ss.register_new(name, jit_c_definition, input_type, thunk_type, BOOL) + module, funcname = SelectOp._remove_nesting(name, strict=False) + return getattr(module, funcname) diff --git a/graphblas/core/ss/unary.py b/graphblas/core/ss/unary.py new file mode 100644 index 000000000..97c4614c0 --- /dev/null +++ b/graphblas/core/ss/unary.py @@ -0,0 +1,62 @@ +from ... import backend +from ...dtypes import lookup_dtype +from ...exceptions import check_status_carg +from .. import NULL, ffi, lib +from ..operator.base import TypedOpBase +from ..operator.unary import TypedUserUnaryOp, UnaryOp +from . import _IS_SSGB7 + +ffi_new = ffi.new + + +class TypedJitUnaryOp(TypedOpBase): + __slots__ = "_jit_c_definition" + opclass = "UnaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition): + super().__init__(parent, name, type_, return_type, gb_obj, name) + self._jit_c_definition = jit_c_definition + + @property + def jit_c_definition(self): + return self._jit_c_definition + + __call__ = TypedUserUnaryOp.__call__ + + +def register_new(name, jit_c_definition, input_type, ret_type): + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.unary.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + input_type = lookup_dtype(input_type) + ret_type = lookup_dtype(ret_type) + name = name if name.startswith("ss.") else f"ss.{name}" + module, funcname = UnaryOp._remove_nesting(name) + + rv = UnaryOp(name) + gb_obj = ffi_new("GrB_UnaryOp*") + check_status_carg( + lib.GxB_UnaryOp_new( + gb_obj, + NULL, + ret_type._carg, + input_type._carg, + ffi_new("char[]", funcname.encode()), + ffi_new("char[]", jit_c_definition.encode()), + ), + "UnaryOp", + gb_obj[0], + ) + op = TypedJitUnaryOp(rv, funcname, input_type, ret_type, gb_obj[0], jit_c_definition) + rv._add(op) + setattr(module, funcname, rv) + return rv diff --git a/graphblas/dtypes/ss.py b/graphblas/dtypes/ss.py index e69de29bb..9f6083e01 100644 --- a/graphblas/dtypes/ss.py +++ b/graphblas/dtypes/ss.py @@ -0,0 +1 @@ +from ..core.ss.dtypes import register_new # noqa: F401 diff --git a/graphblas/indexunary/ss.py b/graphblas/indexunary/ss.py index 97852fc12..58218df6f 100644 --- a/graphblas/indexunary/ss.py +++ b/graphblas/indexunary/ss.py @@ -1,4 +1,5 @@ from ..core import operator +from ..core.ss.indexunary import register_new # noqa: F401 _delayed = {} diff --git a/graphblas/select/ss.py b/graphblas/select/ss.py index 97852fc12..173067382 100644 --- a/graphblas/select/ss.py +++ b/graphblas/select/ss.py @@ -1,4 +1,5 @@ from ..core import operator +from ..core.ss.select import register_new # noqa: F401 _delayed = {} diff --git a/graphblas/ss/__init__.py b/graphblas/ss/__init__.py index b36bc1bdc..b723d9cb8 100644 --- a/graphblas/ss/__init__.py +++ b/graphblas/ss/__init__.py @@ -1 +1,5 @@ -from ._core import about, concat, config, diag +from ._core import _IS_SSGB7, about, concat, config, diag + +if not _IS_SSGB7: + # Context was introduced in SuiteSparse:GraphBLAS 8.0 + from ..core.ss.context import Context, global_context diff --git a/graphblas/ss/_core.py b/graphblas/ss/_core.py index 53287f1a5..2639a7709 100644 --- a/graphblas/ss/_core.py +++ b/graphblas/ss/_core.py @@ -5,6 +5,7 @@ from ..core.descriptor import lookup as descriptor_lookup from ..core.matrix import Matrix, TransposedMatrix from ..core.scalar import _as_scalar +from ..core.ss import _IS_SSGB7 from ..core.ss.config import BaseConfig from ..core.ss.matrix import _concat_mn from ..core.vector import Vector @@ -126,13 +127,23 @@ class GlobalConfig(BaseConfig): Enable diagnostic printing from SuiteSparse:GraphBLAS print_1based : bool gpu_control : str, {"always", "never"} + Only available for SuiteSparse:GraphBLAS 7 + **GPU support is a work in progress--not recommended to use** gpu_chunk : double + Only available for SuiteSparse:GraphBLAS 7 + **GPU support is a work in progress--not recommended to use** + gpu_id : int + Which GPU to use; default is -1, which means do not run on the GPU. + Only available for SuiteSparse:GraphBLAS 8 + **GPU support is a work in progress--not recommended to use** Setting values to None restores the default value for most configurations. """ _get_function = "GxB_Global_Option_get" _set_function = "GxB_Global_Option_set" + if not _IS_SSGB7: + _context_keys = {"chunk", "gpu_id", "nthreads"} _null_valid = {"bitmap_switch"} _options = { # Matrix/Vector format @@ -147,10 +158,32 @@ class GlobalConfig(BaseConfig): # Diagnostics (skipping "printf" and "flush" for now) "burble": (lib.GxB_BURBLE, "bool"), "print_1based": (lib.GxB_PRINT_1BASED, "bool"), - # CUDA GPU control - "gpu_control": (lib.GxB_GLOBAL_GPU_CONTROL, "GrB_Desc_Value"), - "gpu_chunk": (lib.GxB_GLOBAL_GPU_CHUNK, "double"), } + if _IS_SSGB7: + _options.update( + { + "gpu_control": (lib.GxB_GLOBAL_GPU_CONTROL, "GrB_Desc_Value"), + "gpu_chunk": (lib.GxB_GLOBAL_GPU_CHUNK, "double"), + } + ) + else: + _options.update( + { + # JIT control + "jit_c_control": (lib.GxB_JIT_C_CONTROL, "int"), + "jit_use_cmake": (lib.GxB_JIT_USE_CMAKE, "bool"), + "jit_c_compiler_name": (lib.GxB_JIT_C_COMPILER_NAME, "char*"), + "jit_c_compiler_flags": (lib.GxB_JIT_C_COMPILER_FLAGS, "char*"), + "jit_c_linker_flags": (lib.GxB_JIT_C_LINKER_FLAGS, "char*"), + "jit_c_libraries": (lib.GxB_JIT_C_LIBRARIES, "char*"), + "jit_c_cmake_libs": (lib.GxB_JIT_C_CMAKE_LIBS, "char*"), + "jit_c_preface": (lib.GxB_JIT_C_PREFACE, "char*"), + "jit_error_log": (lib.GxB_JIT_ERROR_LOG, "char*"), + "jit_cache_path": (lib.GxB_JIT_CACHE_PATH, "char*"), + # CUDA GPU control + "gpu_id": (lib.GxB_GLOBAL_GPU_ID, "int"), + } + ) # Values to restore defaults _defaults = { "hyper_switch": lib.GxB_HYPER_DEFAULT, @@ -161,17 +194,28 @@ class GlobalConfig(BaseConfig): "burble": 0, "print_1based": 0, } + if not _IS_SSGB7: + _defaults["gpu_id"] = -1 # -1 means no GPU _enumerations = { "format": { "by_row": lib.GxB_BY_ROW, "by_col": lib.GxB_BY_COL, # "no_format": lib.GxB_NO_FORMAT, # Used by iterators; not valid here }, - "gpu_control": { + } + if _IS_SSGB7: + _enumerations["gpu_control"] = { "always": lib.GxB_GPU_ALWAYS, "never": lib.GxB_GPU_NEVER, - }, - } + } + else: + _enumerations["jit_c_control"] = { + "off": lib.GxB_JIT_OFF, + "pause": lib.GxB_JIT_PAUSE, + "run": lib.GxB_JIT_RUN, + "load": lib.GxB_JIT_LOAD, + "on": lib.GxB_JIT_ON, + } class About(Mapping): @@ -258,4 +302,10 @@ def __len__(self): about = About() -config = GlobalConfig() +if _IS_SSGB7: + config = GlobalConfig() +else: + # Context was introduced in SuiteSparse:GraphBLAS 8.0 + from ..core.ss.context import global_context + + config = GlobalConfig(context=global_context) diff --git a/graphblas/tests/test_ssjit.py b/graphblas/tests/test_ssjit.py new file mode 100644 index 000000000..57cb2bbba --- /dev/null +++ b/graphblas/tests/test_ssjit.py @@ -0,0 +1,269 @@ +import os +import sys + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import graphblas as gb +from graphblas import backend, binary, dtypes, indexunary, select, unary +from graphblas.core import _supports_udfs as supports_udfs +from graphblas.core.ss import _IS_SSGB7 + +from .conftest import autocompute, burble + +from graphblas import Vector # isort:skip (for dask-graphblas) + +try: + import numba +except ImportError: + numba = None + +if backend != "suitesparse": + pytest.skip("not suitesparse backend", allow_module_level=True) + + +@pytest.fixture(scope="module", autouse=True) +def _setup_jit(): + # Configuration values below were obtained from the output of the JIT config + # in CI, but with paths changed to use `{conda_prefix}` where appropriate. + if "CONDA_PREFIX" not in os.environ or _IS_SSGB7: + return + conda_prefix = os.environ["CONDA_PREFIX"] + gb.ss.config["jit_c_control"] = "on" + if sys.platform == "linux": + gb.ss.config["jit_c_compiler_name"] = f"{conda_prefix}/bin/x86_64-conda-linux-gnu-cc" + gb.ss.config["jit_c_compiler_flags"] = ( + "-march=nocona -mtune=haswell -ftree-vectorize -fPIC -fstack-protector-strong " + f"-fno-plt -O2 -ffunction-sections -pipe -isystem {conda_prefix}/include -Wundef " + "-std=c11 -lm -Wno-pragmas -fexcess-precision=fast -fcx-limited-range " + "-fno-math-errno -fwrapv -O3 -DNDEBUG -fopenmp -fPIC" + ) + gb.ss.config["jit_c_linker_flags"] = ( + "-Wl,-O2 -Wl,--sort-common -Wl,--as-needed -Wl,-z,relro -Wl,-z,now " + "-Wl,--disable-new-dtags -Wl,--gc-sections -Wl,--allow-shlib-undefined " + f"-Wl,-rpath,{conda_prefix}/lib -Wl,-rpath-link,{conda_prefix}/lib " + f"-L{conda_prefix}/lib -shared" + ) + gb.ss.config["jit_c_libraries"] = ( + f"-lm -ldl {conda_prefix}/lib/libgomp.so " + f"{conda_prefix}/x86_64-conda-linux-gnu/sysroot/usr/lib/libpthread.so" + ) + gb.ss.config["jit_c_cmake_libs"] = ( + f"m;dl;{conda_prefix}/lib/libgomp.so;" + f"{conda_prefix}/x86_64-conda-linux-gnu/sysroot/usr/lib/libpthread.so" + ) + elif sys.platform == "darwin": + gb.ss.config["jit_c_compiler_name"] = f"{conda_prefix}/bin/clang" + gb.ss.config["jit_c_compiler_flags"] = ( + "-march=core2 -mtune=haswell -mssse3 -ftree-vectorize -fPIC -fPIE " + f"-fstack-protector-strong -O2 -pipe -isystem {conda_prefix}/include -DGBNCPUFEAT " + "-Wno-pointer-sign -O3 -DNDEBUG -fopenmp=libomp -fPIC -arch x86_64" + ) + gb.ss.config["jit_c_linker_flags"] = ( + "-Wl,-pie -Wl,-headerpad_max_install_names -Wl,-dead_strip_dylibs " + f"-Wl,-rpath,{conda_prefix}/lib -L{conda_prefix}/lib -dynamiclib" + ) + gb.ss.config["jit_c_libraries"] = f"-lm -ldl {conda_prefix}/lib/libomp.dylib" + gb.ss.config["jit_c_cmake_libs"] = f"m;dl;{conda_prefix}/lib/libomp.dylib" + elif sys.platform == "win32": # pragma: no branch (sanity) + if "mingw" in gb.ss.config["jit_c_libraries"]: + # This probably means we're testing a `python-suitesparse-graphblas` wheel + # in a conda environment. This is not yet working. + gb.ss.config["jit_c_control"] = "off" + return + + gb.ss.config["jit_c_compiler_name"] = f"{conda_prefix}/bin/cc" + gb.ss.config["jit_c_compiler_flags"] = ( + '/DWIN32 /D_WINDOWS -DGBNCPUFEAT /O2 -wd"4244" -wd"4146" -wd"4018" ' + '-wd"4996" -wd"4047" -wd"4554" /O2 /Ob2 /DNDEBUG -openmp' + ) + gb.ss.config["jit_c_linker_flags"] = "/machine:x64" + gb.ss.config["jit_c_libraries"] = "" + gb.ss.config["jit_c_cmake_libs"] = "" + + +@pytest.fixture +def v(): + return Vector.from_coo([1, 3, 4, 6], [1, 1, 2, 0]) + + +@autocompute +def test_jit_udt(): + if _IS_SSGB7: + with pytest.raises(RuntimeError, match="JIT was added"): + dtypes.ss.register_new( + "myquaternion", "typedef struct { float x [4][4] ; int color ; } myquaternion ;" + ) + return + if gb.ss.config["jit_c_control"] == "off": + return + with burble(): + dtype = dtypes.ss.register_new( + "myquaternion", "typedef struct { float x [4][4] ; int color ; } myquaternion ;" + ) + assert not hasattr(dtypes, "myquaternion") + assert dtypes.ss.myquaternion is dtype + assert dtype.name == "myquaternion" + assert str(dtype) == "myquaternion" + assert dtype.gb_name is None + v = Vector(dtype, 2) + np_type = np.dtype([("x", "=1.25.0' -conda search 'pandas[channel=conda-forge]>=2.0.2' -conda search 'scipy[channel=conda-forge]>=1.11.0' +conda search 'pandas[channel=conda-forge]>=2.0.3' +conda search 'scipy[channel=conda-forge]>=1.11.1' conda search 'networkx[channel=conda-forge]>=3.1' -conda search 'awkward[channel=conda-forge]>=2.2.4' +conda search 'awkward[channel=conda-forge]>=2.3.0' conda search 'sparse[channel=conda-forge]>=0.14.0' conda search 'fast_matrix_market[channel=conda-forge]>=1.7.2' conda search 'numba[channel=conda-forge]>=0.57.1'