unxt

Unitful Quantities in JAX


unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.

Unxt supports JAX's compelling features:

  • JIT compilation (jit)
  • vectorization (vmap, etc.)
  • auto-differentiation (grad, jacobian, hessian)
  • GPU/TPU/multi-host acceleration

And best of all, unxt doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use unxt with existing JAX code, and with quax's simple decorator, JAX will work with unxt.Quantity.

Installation

pip install unxt
using uv
uv add unxt
from source, using pip
pip install git+https://github.com/GalacticDynamics/unxt.git
building from source
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e .  # editable mode

Read The Docs

For full documentation, including installation instructions, tutorials, and API reference, please see the unxt docs. This README provides a brief overview and some quick examples.

Dimensions

Dimensions represent the physical type of a quantity, such as length, time, or mass.

import unxt as u

Create dimensions from strings:

length_dim = u.dimension("length")
print(length_dim)
# PhysicalType('length')

Dimensions support mathematical expressions:

speed_dim = u.dimension("length / time")
print(speed_dim)
# PhysicalType({'speed', 'velocity'})
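
To see why an expression like "length / time" resolves to a single dimension, it can help to think of a dimension as a set of exponents over base dimensions. The following is a minimal, library-independent sketch of that idea; the `divide` helper and the dictionary representation are illustrative assumptions, not how unxt stores dimensions internally.

```python
from collections import Counter


def divide(numerator: dict, denominator: dict) -> dict:
    """Divide two dimensions by subtracting base-dimension exponents."""
    result = Counter(numerator)
    result.subtract(denominator)
    # Drop base dimensions whose exponents cancelled to zero.
    return {base: power for base, power in result.items() if power != 0}


# Hypothetical representation: base dimension name -> exponent.
LENGTH = {"length": 1}
TIME = {"time": 1}

speed = divide(LENGTH, TIME)
print(speed)  # {'length': 1, 'time': -1}

# Dividing a speed by a speed leaves a dimensionless result.
print(divide(speed, speed))  # {}
```

Two expressions describe the same dimension exactly when their exponents agree, which is the property that lets a library recognise "length / time" as speed.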

Multi-word dimension names require parentheses in expressions:

activity_dim = u.dimension("(amount of substance) / (time)")
print(activity_dim)
# PhysicalType('catalytic activity')

Units

Units specify the scale and dimension of measurements.

meter = u.unit("m")
print(meter)
# Unit("m")

Units can be combined:

velocity_unit = u.unit("km/h")  # parsed from a string expression
print(velocity_unit)
# Unit("km / h")

velocity_unit2 = u.unit("km") / u.unit("h")  # via arithmetic
print(velocity_unit2)
# Unit("km / h")

Get the dimension of a unit:

print(u.dimension_of(meter))
# PhysicalType('length')

Unit Systems

Unit systems define consistent sets of base units for specific domains. unxt provides built-in unit systems and tools for creating custom ones.

Built-in Unit Systems

# SI (International System of Units)
si = u.unitsystem("si")
print(si)
# unitsystem(m, kg, s, mol, A, K, cd, rad)

# CGS (centimeter-gram-second)
cgs = u.unitsystem("cgs")
print(cgs)
# unitsystem(cm, g, s, dyn, erg, Ba, P, St, rad)

# Galactic (astrophysics)
galactic = u.unitsystem("galactic")
print(galactic)
# unitsystem(kpc, Myr, solMass, rad)

Composing Units from a Unit System

Once you have a unit system, you can get units for any physical dimension by indexing the system:

usys = u.unitsystem("si")

# Get specific units
print(usys["length"])
# Unit("m")

Custom Unit Systems

Create custom unit systems by specifying base units:

import unxt as u

# Define a custom unit system
custom_usys = u.unitsystem("km", "h", "tonne", "degree")
print(custom_usys)
# unitsystem(km, h, t, deg)

# Access derived units
print(custom_usys["velocity"])
# Unit("km / h")

Dynamical Unit Systems

For domains like gravitational dynamics, use dynamical unit systems where $G = 1$:

from unxt.unitsystems import DynamicalSimUSysFlag

# Create a dynamical system where G=1
# Only specify 2 of (length, time, mass)
usys = u.unitsystem(DynamicalSimUSysFlag, "kpc", "Myr")
print(usys)
# unitsystem(kpc, Myr, ...)

# The third dimension (mass) is computed to make G=1
print(usys["mass"])
# Unit("10^11 solMass")  # computed value
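
The mass unit follows from the dimensions of $G$, which are length³ / (mass · time²). Setting $G = 1$ therefore fixes mass_unit = length_unit³ / (G · time_unit²). The sketch below works through that arithmetic in plain Python. The constants are approximate SI values supplied here as assumptions; they are not read from unxt, so the result should agree with the library only to the precision of these inputs.

```python
# Approximate SI values; assumptions for this sketch, not read from unxt.
G_SI = 6.6743e-11  # gravitational constant, m^3 kg^-1 s^-2
KPC_IN_M = 3.0857e19  # metres per kiloparsec
MYR_IN_S = 3.15576e13  # seconds per megayear (Julian years)
SOLMASS_IN_KG = 1.9884e30  # kilograms per solar mass

# G has dimensions length^3 / (mass * time^2), so G = 1 requires
# mass_unit = length_unit^3 / (G * time_unit^2).
mass_unit_kg = KPC_IN_M**3 / (G_SI * MYR_IN_S**2)
mass_unit_solmass = mass_unit_kg / SOLMASS_IN_KG

print(f"{mass_unit_solmass:.2e} solMass")
# 2.22e+11 solMass
```

The result is a few times 10^11 solar masses, the same order of magnitude as the unit shown above.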

Quantities

Quantities combine values with units, providing type-safe unitful arithmetic.

Basic Quantities

import jax.numpy as jnp

x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length']([1., 2., 3., 4.], unit='km')

The constituent value and unit are accessible as attributes:

repr(x.value)
# Array([1., 2., 3., 4.], dtype=float64)

repr(x.unit)
# Unit("km")

Quantity objects obey the rules of unitful arithmetic.

# Addition / Subtraction
print(x + x)
# Quantity['length']([2., 4., 6., 8.], unit='km')

# Multiplication / Division
print(2 * x)
# Quantity['length']([2., 4., 6., 8.], unit='km')

y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")

print(x / y)
# Quantity['speed']([0.25, 0.4 , 0.5 , 0.57142857], unit='km / yr')

# Exponentiation
print(x**2)
# Quantity['area']([ 1.,  4.,  9., 16.], unit='km2')

# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible

Quantities can be converted to different units:

print(u.uconvert("m", x))  # via function
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')

print(x.uconvert("m"))  # via method
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')

Since Quantity is parametric, it can do runtime dimension checking!

LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](2, unit='km')

try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.

BareQuantity

For performance-critical code where you don't need dimension checking, use BareQuantity:

import unxt as u
import jax.numpy as jnp

# BareQuantity skips dimension checks for better performance
bq = u.quantity.BareQuantity(jnp.array([1.0, 2.0, 3.0]), "m")
print(bq)
# BareQuantity([1., 2., 3.], unit='m')

# Works just like Quantity but without dimension validation
print(bq * 2)
# BareQuantity([2., 4., 6.], unit='m')

Angle

Angle is a specialized quantity with wrapping support for angular values:

import unxt as u
import jax.numpy as jnp

# Construct an Angle; values are stored as given
theta = u.Angle(jnp.array([0, 90, 180, 270, 360]), "deg")
print(theta)
# Angle([0., 90., 180., 270., 360.], unit='deg')

# Optional wrapping to a specified range
angle = u.Angle(jnp.array([370, -10]), "deg")
wrapped = angle.wrap_to(u.Q(0, "deg"), u.Q(360, "deg"))
print(wrapped)
# Angle([10., 350.], unit='deg')
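
The wrapped values above are consistent with ordinary modular arithmetic on the interval. The following plain-Python sketch reproduces them. The `wrap` helper is hypothetical, and treating the range as half-open, [lower, upper), is an assumption; this is not a statement about how `wrap_to` is implemented.

```python
def wrap(value: float, lower: float, upper: float) -> float:
    """Map value into the half-open interval [lower, upper)."""
    return (value - lower) % (upper - lower) + lower


# Same inputs as the Angle example: 370 deg and -10 deg wrapped to [0, 360).
print([wrap(a, 0.0, 360.0) for a in (370.0, -10.0)])  # [10.0, 350.0]

# The same formula handles a symmetric range such as [-180, 180).
print(wrap(190.0, -180.0, 180.0))  # -170.0
```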

StaticQuantity

For static configuration values (e.g., JAX static arguments), use StaticQuantity, which stores NumPy values and rejects JAX arrays:

import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
import unxt as u

cfg = u.StaticQuantity(np.array([1.0, 2.0]), "m")


@partial(jax.jit, static_argnames=("q",))
def add(x, q):
    return x + jnp.asarray(q.value)


print(add(1.0, cfg))

StaticValue

If you want a Quantity that keeps a static value but still participates in regular arithmetic, wrap the value with StaticValue. Arithmetic behaves like the wrapped array, and StaticValue + StaticValue returns a StaticValue. Comparison operators (==, !=, <, <=, >, >=) return NumPy boolean arrays for element-wise comparison:

import numpy as np
import jax.numpy as jnp
import unxt as u

sv = u.quantity.StaticValue(np.array([1.0, 2.0]))
q_static = u.Q(sv, "m")
q = u.Q(jnp.array([3.0, 4.0]), "m")

print(q_static + q)

# Comparisons return NumPy boolean arrays (element-wise)
sv2 = u.quantity.StaticValue(np.array([2.0, 1.0]))
print(sv < sv2)  # array([ True, False])
print(sv == np.array([1.0, 2.0]))  # array([ True,  True])

JAX Integration

unxt is built on quax, which enables custom array-ish objects in JAX. For convenience we use the quaxed library, which is just a quax.quaxify wrapper around jax to avoid boilerplate code.

Note

Using quaxed is optional. You can directly use quaxify, and even apply it to the top-level function instead of individual functions.

from quaxed import grad, vmap
import quaxed.numpy as jnp

# Using the x quantity from earlier examples
print(jnp.square(x))
# Quantity['area']([ 1.,  4.,  9., 16.], unit='km2')

print(jnp.power(x, 3))
# Quantity['volume']([ 1.,  8., 27., 64.], unit='km3')

print(vmap(grad(lambda x: x**3))(x))
# Quantity['area']([ 3., 12., 27., 48.], unit='km2')

See the documentation for more examples and details on JIT compilation and automatic differentiation.

Citation

If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.

Contributing and Development

We welcome contributions! Contributions are how open source projects improve and grow.

To contribute to unxt, please fork the repository, make a development branch, develop on that branch, then open a pull request from the branch in your fork to main.

To report bugs, request features, or suggest other ideas, please open an issue.

For more information, see CONTRIBUTING.md.