From aba6999f97c452563d7cdfecf01a931154dd4e86 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Fri, 14 Jul 2023 11:38:49 -0500 Subject: [PATCH] Allow `__index__` only for integral dtypes on Scalars --- .pre-commit-config.yaml | 14 ++++++------- graphblas/core/scalar.py | 8 ++++++-- graphblas/core/ss/config.py | 7 +++---- graphblas/core/utils.py | 36 ++++++++++++++++++++++------------ graphblas/dtypes/__init__.py | 3 +++ graphblas/tests/test_scalar.py | 2 ++ scripts/check_versions.sh | 4 ++-- 7 files changed, 46 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 726538e16..b8d767f05 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ ci: # See: https://pre-commit.ci/#configuration autofix_prs: false - autoupdate_schedule: monthly + autoupdate_schedule: quarterly autoupdate_commit_msg: "chore: update pre-commit hooks" autofix_commit_msg: "style: pre-commit fixes" skip: [pylint, no-commit-to-branch] @@ -51,7 +51,7 @@ repos: - id: isort # Let's keep `pyupgrade` even though `ruff --fix` probably does most of it - repo: https://github.com/asottile/pyupgrade - rev: v3.8.0 + rev: v3.9.0 hooks: - id: pyupgrade args: [--py38-plus] @@ -61,12 +61,12 @@ repos: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black - id: black-jupyter - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.277 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.278 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -93,8 +93,8 @@ repos: types_or: [python, rst, markdown] additional_dependencies: [tomli] files: ^(graphblas|docs)/ - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.277 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.278 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/graphblas/core/scalar.py b/graphblas/core/scalar.py index b55d601af..8a95e1d71 100644 --- a/graphblas/core/scalar.py +++ b/graphblas/core/scalar.py @@ -3,7 +3,7 @@ import numpy as np from .. import backend, binary, config, monoid -from ..dtypes import _INDEX, FP64, lookup_dtype, unify +from ..dtypes import _INDEX, FP64, _index_dtypes, lookup_dtype, unify from ..exceptions import EmptyObject, check_status from . import _has_numba, _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, call @@ -158,7 +158,11 @@ def __int__(self): def __complex__(self): return complex(self.value) - __index__ = __int__ + @property + def __index__(self): + if self.dtype in _index_dtypes: + return self.__int__ + raise AttributeError("Scalar object only has `__index__` for integral dtypes") def __array__(self, dtype=None): if dtype is None: diff --git a/graphblas/core/ss/config.py b/graphblas/core/ss/config.py index 433716bb3..20cf318e8 100644 --- a/graphblas/core/ss/config.py +++ b/graphblas/core/ss/config.py @@ -1,10 +1,9 @@ from collections.abc import MutableMapping -from numbers import Integral from ...dtypes import lookup_dtype from ...exceptions import _error_code_lookup, check_status from .. import NULL, ffi, lib -from ..utils import values_to_numpy_buffer +from ..utils import maybe_integral, values_to_numpy_buffer class BaseConfig(MutableMapping): @@ -147,8 +146,8 @@ def __setitem__(self, key, val): bitwise = self._bitwise[key] if isinstance(val, str): val = bitwise[val.lower()] - elif isinstance(val, Integral): - val = bitwise.get(val, val) + elif (x := maybe_integral(val)) is not None: + val = bitwise.get(x, x) else: bits = 0 for x in val: diff --git a/graphblas/core/utils.py b/graphblas/core/utils.py index 74e03f2f9..7bb1a1fb0 100644 --- a/graphblas/core/utils.py +++ b/graphblas/core/utils.py @@ -1,4 +1,4 @@ -from numbers import Integral, Number +from operator import index import numpy as np @@ -158,6 +158,17 @@ def get_order(order): ) +def maybe_integral(val): + """Ensure ``val`` is an integer or return None if it's not.""" + try: + return index(val) + except TypeError: + pass + if isinstance(val, float) and val.is_integer(): + return int(val) + return None + + def normalize_chunks(chunks, shape): """Normalize chunks argument for use by ``Matrix.ss.split``. @@ -175,8 +186,8 @@ def normalize_chunks(chunks, shape): """ if isinstance(chunks, (list, tuple)): pass - elif isinstance(chunks, Number): - chunks = (chunks,) * len(shape) + elif (chunk := maybe_integral(chunks)) is not None: + chunks = (chunk,) * len(shape) elif isinstance(chunks, np.ndarray): chunks = chunks.tolist() else: @@ -192,22 +203,21 @@ def normalize_chunks(chunks, shape): for size, chunk in zip(shape, chunks): if chunk is None: cur_chunks = [size] - elif isinstance(chunk, Integral) or isinstance(chunk, float) and chunk.is_integer(): - chunk = int(chunk) - if chunk < 0: - raise ValueError(f"Chunksize must be greater than 0; got: {chunk}") - div, mod = divmod(size, chunk) - cur_chunks = [chunk] * div + elif (c := maybe_integral(chunk)) is not None: + if c < 0: + raise ValueError(f"Chunksize must be greater than 0; got: {c}") + div, mod = divmod(size, c) + cur_chunks = [c] * div if mod: cur_chunks.append(mod) elif isinstance(chunk, (list, tuple)): cur_chunks = [] none_index = None for c in chunk: - if isinstance(c, Integral) or isinstance(c, float) and c.is_integer(): - c = int(c) - if c < 0: - raise ValueError(f"Chunksize must be greater than 0; got: {c}") + if (val := maybe_integral(c)) is not None: + if val < 0: + raise ValueError(f"Chunksize must be greater than 0; got: {val}") + c = val elif c is None: if none_index is not None: raise TypeError( diff --git a/graphblas/dtypes/__init__.py b/graphblas/dtypes/__init__.py index 49e46d787..f9c144f13 100644 --- a/graphblas/dtypes/__init__.py +++ b/graphblas/dtypes/__init__.py @@ -41,3 +41,6 @@ def __getattr__(key): globals()["ss"] = ss return ss raise AttributeError(f"module {__name__!r} has no attribute {key!r}") + + +_index_dtypes = {BOOL, INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64, _INDEX} diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index 7b7c77177..cf4c6fd41 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -132,6 +132,8 @@ def test_casting(s): assert float(s) == 5.0 assert type(float(s)) is float assert range(s) == range(5) + with pytest.raises(AttributeError, match="Scalar .* only .*__index__.*integral"): + range(s.dup(float)) assert complex(s) == complex(5) assert type(complex(s)) is complex diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index ef1a76135..263b1d8f7 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -3,11 +3,11 @@ # Use, adjust, copy/paste, etc. as necessary to answer your questions. # This may be helpful when updating dependency versions in CI. # Tip: add `--json` for more information. -conda search 'numpy[channel=conda-forge]>=1.25.0' +conda search 'numpy[channel=conda-forge]>=1.25.1' conda search 'pandas[channel=conda-forge]>=2.0.3' conda search 'scipy[channel=conda-forge]>=1.11.1' conda search 'networkx[channel=conda-forge]>=3.1' -conda search 'awkward[channel=conda-forge]>=2.3.0' +conda search 'awkward[channel=conda-forge]>=2.3.1' conda search 'sparse[channel=conda-forge]>=0.14.0' conda search 'fast_matrix_market[channel=conda-forge]>=1.7.2' conda search 'numba[channel=conda-forge]>=0.57.1'