diff --git a/Cargo.lock b/Cargo.lock index 9d95d017003..ffbe6709469 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2450,11 +2450,16 @@ dependencies = [ [[package]] name = "pymath" -version = "0.0.2" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b66ab66a8610ce209d8b36cd0fecc3a15c494f715e0cb26f0586057f293abc9" +checksum = "bbfb6723b732fc7f0b29a0ee7150c7f70f947bf467b8c3e82530b13589a78b4c" dependencies = [ "libc", + "libm", + "malachite-bigint", + "num-complex", + "num-integer", + "num-traits", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a40744841c5..96e821e977c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -116,6 +116,7 @@ members = [ ".", "crates/*", ] +exclude = ["pymath"] [workspace.package] version = "0.4.0" @@ -184,7 +185,7 @@ once_cell = "1.20.3" parking_lot = "0.12.3" paste = "1.0.15" proc-macro2 = "1.0.105" -pymath = "0.0.2" +pymath = { version = "0.1.5", features = ["mul_add", "malachite-bigint", "complex"] } quote = "1.0.43" radium = "1.1.1" rand = "0.9" diff --git a/Lib/test/mathdata/cmath_testcases.txt b/Lib/test/mathdata/cmath_testcases.txt index 0165e17634f..7b98b5a2998 100644 --- a/Lib/test/mathdata/cmath_testcases.txt +++ b/Lib/test/mathdata/cmath_testcases.txt @@ -371,9 +371,9 @@ acosh1002 acosh 0.0 inf -> inf 1.5707963267948966 acosh1003 acosh 2.3 inf -> inf 1.5707963267948966 acosh1004 acosh -0.0 inf -> inf 1.5707963267948966 acosh1005 acosh -2.3 inf -> inf 1.5707963267948966 -acosh1006 acosh 0.0 nan -> nan nan +acosh1006 acosh 0.0 nan -> nan 1.5707963267948966 ignore-imag-sign acosh1007 acosh 2.3 nan -> nan nan -acosh1008 acosh -0.0 nan -> nan nan +acosh1008 acosh -0.0 nan -> nan 1.5707963267948966 ignore-imag-sign acosh1009 acosh -2.3 nan -> nan nan acosh1010 acosh -inf 0.0 -> inf 3.1415926535897931 acosh1011 acosh -inf 2.3 -> inf 3.1415926535897931 @@ -1992,9 +1992,9 @@ tanh0065 tanh 1.797e+308 0.0 -> 1.0 0.0 --special values tanh1000 tanh 0.0 0.0 -> 0.0 0.0 -tanh1001 tanh 0.0 inf -> nan nan invalid +tanh1001 tanh 0.0 inf -> 0.0 nan invalid tanh1002 tanh 2.3 inf -> nan nan invalid -tanh1003 tanh 0.0 nan -> nan nan +tanh1003 tanh 0.0 nan -> 0.0 nan tanh1004 tanh 2.3 nan -> nan nan tanh1005 tanh inf 0.0 -> 1.0 0.0 tanh1006 tanh inf 0.7 -> 1.0 0.0 @@ -2009,7 +2009,7 @@ tanh1014 tanh nan 2.3 -> nan nan tanh1015 tanh nan inf -> nan nan tanh1016 tanh nan nan -> nan nan tanh1017 tanh 0.0 -0.0 -> 0.0 -0.0 -tanh1018 tanh 0.0 -inf -> nan nan invalid +tanh1018 tanh 0.0 -inf -> 0.0 nan invalid tanh1019 tanh 2.3 -inf -> nan nan invalid tanh1020 tanh inf -0.0 -> 1.0 -0.0 tanh1021 tanh inf -0.7 -> 1.0 -0.0 @@ -2022,9 +2022,9 @@ tanh1027 tanh nan -0.0 -> nan -0.0 tanh1028 tanh nan -2.3 -> nan nan tanh1029 tanh nan -inf -> nan nan tanh1030 tanh -0.0 -0.0 -> -0.0 -0.0 -tanh1031 tanh -0.0 -inf -> nan nan invalid +tanh1031 tanh -0.0 -inf -> -0.0 nan invalid tanh1032 tanh -2.3 -inf -> nan nan invalid -tanh1033 tanh -0.0 nan -> nan nan +tanh1033 tanh -0.0 nan -> -0.0 nan tanh1034 tanh -2.3 nan -> nan nan tanh1035 tanh -inf -0.0 -> -1.0 -0.0 tanh1036 tanh -inf -0.7 -> -1.0 -0.0 @@ -2035,7 +2035,7 @@ tanh1040 tanh -inf -3.5 -> -1.0 -0.0 tanh1041 tanh -inf -inf -> -1.0 0.0 ignore-imag-sign tanh1042 tanh -inf nan -> -1.0 0.0 ignore-imag-sign tanh1043 tanh -0.0 0.0 -> -0.0 0.0 -tanh1044 tanh -0.0 inf -> nan nan invalid +tanh1044 tanh -0.0 inf -> -0.0 nan invalid tanh1045 tanh -2.3 inf -> nan nan invalid tanh1046 tanh -inf 0.0 -> -1.0 0.0 tanh1047 tanh -inf 0.7 -> -1.0 0.0 @@ -2307,9 +2307,9 @@ tan0066 tan -8.79645943005142 0.0 -> 
0.7265425280053614098 0.0 -- special values tan1000 tan -0.0 0.0 -> -0.0 0.0 -tan1001 tan -inf 0.0 -> nan nan invalid +tan1001 tan -inf 0.0 -> nan 0.0 invalid tan1002 tan -inf 2.2999999999999998 -> nan nan invalid -tan1003 tan nan 0.0 -> nan nan +tan1003 tan nan 0.0 -> nan 0.0 tan1004 tan nan 2.2999999999999998 -> nan nan tan1005 tan -0.0 inf -> -0.0 1.0 tan1006 tan -0.69999999999999996 inf -> -0.0 1.0 @@ -2324,7 +2324,7 @@ tan1014 tan -2.2999999999999998 nan -> nan nan tan1015 tan -inf nan -> nan nan tan1016 tan nan nan -> nan nan tan1017 tan 0.0 0.0 -> 0.0 0.0 -tan1018 tan inf 0.0 -> nan nan invalid +tan1018 tan inf 0.0 -> nan 0.0 invalid tan1019 tan inf 2.2999999999999998 -> nan nan invalid tan1020 tan 0.0 inf -> 0.0 1.0 tan1021 tan 0.69999999999999996 inf -> 0.0 1.0 @@ -2337,9 +2337,9 @@ tan1027 tan 0.0 nan -> 0.0 nan tan1028 tan 2.2999999999999998 nan -> nan nan tan1029 tan inf nan -> nan nan tan1030 tan 0.0 -0.0 -> 0.0 -0.0 -tan1031 tan inf -0.0 -> nan nan invalid +tan1031 tan inf -0.0 -> nan -0.0 invalid tan1032 tan inf -2.2999999999999998 -> nan nan invalid -tan1033 tan nan -0.0 -> nan nan +tan1033 tan nan -0.0 -> nan -0.0 tan1034 tan nan -2.2999999999999998 -> nan nan tan1035 tan 0.0 -inf -> 0.0 -1.0 tan1036 tan 0.69999999999999996 -inf -> 0.0 -1.0 @@ -2350,7 +2350,7 @@ tan1040 tan 3.5 -inf -> 0.0 -1.0 tan1041 tan inf -inf -> -0.0 -1.0 ignore-real-sign tan1042 tan nan -inf -> -0.0 -1.0 ignore-real-sign tan1043 tan -0.0 -0.0 -> -0.0 -0.0 -tan1044 tan -inf -0.0 -> nan nan invalid +tan1044 tan -inf -0.0 -> nan -0.0 invalid tan1045 tan -inf -2.2999999999999998 -> nan nan invalid tan1046 tan -0.0 -inf -> -0.0 -1.0 tan1047 tan -0.69999999999999996 -inf -> -0.0 -1.0 diff --git a/Lib/test/mathdata/ieee754.txt b/Lib/test/mathdata/ieee754.txt index 3e986cdb102..9be667826a6 100644 --- a/Lib/test/mathdata/ieee754.txt +++ b/Lib/test/mathdata/ieee754.txt @@ -116,7 +116,7 @@ inf >>> 0 ** -1 Traceback (most recent call last): ... -ZeroDivisionError: 0.0 cannot be raised to a negative power +ZeroDivisionError: zero to a negative power >>> pow(0, NAN) nan @@ -127,31 +127,31 @@ Trigonometric Functions >>> sin(INF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a finite input, got inf >>> sin(NINF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a finite input, got -inf >>> sin(NAN) nan >>> cos(INF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a finite input, got inf >>> cos(NINF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a finite input, got -inf >>> cos(NAN) nan >>> tan(INF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a finite input, got inf >>> tan(NINF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a finite input, got -inf >>> tan(NAN) nan @@ -169,11 +169,11 @@ True >>> asin(INF), asin(NINF) Traceback (most recent call last): ... -ValueError: math domain error +ValueError: expected a number in range from -1 up to 1, got inf >>> acos(INF), acos(NINF) Traceback (most recent call last): ... 
-ValueError: math domain error +ValueError: expected a number in range from -1 up to 1, got inf >>> equal(atan(INF), PI/2), equal(atan(NINF), -PI/2) (True, True) diff --git a/Lib/test/test_cmath.py b/Lib/test/test_cmath.py index 44f1b2da638..a96a5780b31 100644 --- a/Lib/test/test_cmath.py +++ b/Lib/test/test_cmath.py @@ -276,7 +276,6 @@ def test_cmath_matches_math(self): self.rAssertAlmostEqual(math.log(v, base), z.real) self.assertEqual(0., z.imag) - @unittest.expectedFailure # TODO: RUSTPYTHON @requires_IEEE_754 def test_specific_values(self): # Some tests need to be skipped on ancient OS X versions. @@ -530,13 +529,11 @@ def testTanhSign(self): # log1p function; If that system function doesn't respect the sign # of zero, then atan and atanh will also have difficulties with # the sign of complex zeros. - @unittest.expectedFailure # TODO: RUSTPYTHON @requires_IEEE_754 def testAtanSign(self): for z in complex_zeros: self.assertComplexesAreIdentical(cmath.atan(z), z) - @unittest.expectedFailure # TODO: RUSTPYTHON @requires_IEEE_754 def testAtanhSign(self): for z in complex_zeros: @@ -583,7 +580,6 @@ def test_complex_near_zero(self): self.assertIsClose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03) self.assertIsNotClose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_complex_special(self): self.assertIsNotClose(INF, INF*1j) self.assertIsNotClose(INF*1j, INF) diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 1a4d257586b..d14336f8bac 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -573,6 +573,8 @@ def testFloor(self): #self.assertEqual(math.ceil(NINF), NINF) #self.assertTrue(math.isnan(math.floor(NAN))) + class TestFloorIsNone(float): + __floor__ = None class TestFloor: def __floor__(self): return 42 @@ -588,6 +590,7 @@ class TestBadFloor: self.assertEqual(math.floor(FloatLike(41.9)), 41) self.assertRaises(TypeError, math.floor, TestNoFloor()) self.assertRaises(ValueError, math.floor, TestBadFloor()) + self.assertRaises(TypeError, math.floor, TestFloorIsNone(3.5)) t = TestNoFloor() t.__floor__ = lambda *args: args @@ -1125,6 +1128,15 @@ def __index__(self): with self.assertRaises(TypeError): math.isqrt(value) + @support.bigmemtest(2**32, memuse=0.85) + def test_isqrt_huge(self, size): + if size & 1: + size += 1 + v = 1 << size + w = math.isqrt(v) + self.assertEqual(w.bit_length(), size // 2 + 1) + self.assertEqual(w.bit_count(), 1) + def test_lcm(self): lcm = math.lcm self.assertEqual(lcm(0, 0), 0) @@ -1272,6 +1284,13 @@ def testLog10(self): self.assertEqual(math.log(INF), INF) self.assertTrue(math.isnan(math.log10(NAN))) + @support.bigmemtest(2**32, memuse=0.2) + def test_log_huge_integer(self, size): + v = 1 << size + self.assertAlmostEqual(math.log2(v), size) + self.assertAlmostEqual(math.log(v), size * 0.6931471805599453) + self.assertAlmostEqual(math.log10(v), size * 0.3010299956639812) + def testSumProd(self): sumprod = math.sumprod Decimal = decimal.Decimal @@ -1380,7 +1399,6 @@ def test_sumprod_accuracy(self): self.assertEqual(sumprod([True, False] * 10, [0.1] * 20), 1.0) self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0) - @unittest.skip("TODO: RUSTPYTHON, Taking a few minutes.") @support.requires_resource('cpu') def test_sumprod_stress(self): sumprod = math.sumprod @@ -2020,7 +2038,6 @@ def test_exceptions(self): else: self.fail("sqrt(-1) didn't raise ValueError") - @unittest.expectedFailure # TODO: RUSTPYTHON @requires_IEEE_754 def test_testfile(self): # Some tests need to be skipped on 
ancient OS X versions. @@ -2495,6 +2512,46 @@ def test_input_exceptions(self): self.assertRaises(TypeError, math.atan2, 1.0) self.assertRaises(TypeError, math.atan2, 1.0, 2.0, 3.0) + def test_exception_messages(self): + x = -1.1 + with self.assertRaisesRegex(ValueError, + f"expected a nonnegative input, got {x}"): + math.sqrt(x) + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {x}"): + math.log(x) + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {x}"): + math.log(123, x) + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {x}"): + math.log(x, 123) + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {x}"): + math.log2(x) + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {x}"): + math.log10(x) + x = decimal.Decimal('-1.1') + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {x}"): + math.log(x) + x = fractions.Fraction(1, 10**400) + with self.assertRaisesRegex(ValueError, + f"expected a positive input, got {float(x)}"): + math.log(x) + x = -123 + with self.assertRaisesRegex(ValueError, + "expected a positive input$"): + math.log(x) + with self.assertRaisesRegex(ValueError, + f"expected a noninteger or positive integer, got {x}"): + math.gamma(x) + x = 1.0 + with self.assertRaisesRegex(ValueError, + f"expected a number between -1 and 1, got {x}"): + math.atanh(x) + # Custom assertions. def assertIsNaN(self, value): @@ -2724,6 +2781,9 @@ def test_fma_infinities(self): or (sys.platform == "android" and platform.machine() == "x86_64") or support.linked_to_musl(), # gh-131032 f"this platform doesn't implement IEE 754-2008 properly") + # gh-131032: musl is fixed but the fix is not yet released; when the fixed + # version is known change this to: + # or support.linked_to_musl() < (1, ,

) def test_fma_zero_result(self): nonnegative_finites = [0.0, 1e-300, 2.3, 1e300] diff --git a/crates/stdlib/src/cmath.rs b/crates/stdlib/src/cmath.rs index 6ce471195cc..e7abb3bcad3 100644 --- a/crates/stdlib/src/cmath.rs +++ b/crates/stdlib/src/cmath.rs @@ -1,6 +1,5 @@ -// TODO: Keep track of rust-num/num-complex/issues/2. A common trait could help with duplication -// that exists between cmath and math. pub(crate) use cmath::make_module; + #[pymodule] mod cmath { use crate::vm::{ @@ -9,137 +8,141 @@ mod cmath { }; use num_complex::Complex64; + use crate::math::pymath_exception; + // Constants - #[pyattr] - use core::f64::consts::{E as e, PI as pi, TAU as tau}; + #[pyattr(name = "e")] + const E: f64 = pymath::cmath::E; + #[pyattr(name = "pi")] + const PI: f64 = pymath::cmath::PI; + #[pyattr(name = "tau")] + const TAU: f64 = pymath::cmath::TAU; #[pyattr(name = "inf")] - const INF: f64 = f64::INFINITY; + const INF: f64 = pymath::cmath::INF; #[pyattr(name = "nan")] - const NAN: f64 = f64::NAN; + const NAN: f64 = pymath::cmath::NAN; #[pyattr(name = "infj")] - const INFJ: Complex64 = Complex64::new(0., f64::INFINITY); + const INFJ: Complex64 = pymath::cmath::INFJ; #[pyattr(name = "nanj")] - const NANJ: Complex64 = Complex64::new(0., f64::NAN); + const NANJ: Complex64 = pymath::cmath::NANJ; #[pyfunction] - fn phase(z: ArgIntoComplex) -> f64 { - z.into_complex().arg() + fn phase(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::phase(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn polar(x: ArgIntoComplex) -> (f64, f64) { - x.into_complex().to_polar() + fn polar(x: ArgIntoComplex, vm: &VirtualMachine) -> PyResult<(f64, f64)> { + pymath::cmath::polar(x.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn rect(r: ArgIntoFloat, phi: ArgIntoFloat) -> Complex64 { - Complex64::from_polar(r.into_float(), phi.into_float()) + fn rect(r: ArgIntoFloat, phi: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::cmath::rect(r.into_float(), phi.into_float()) + .map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn isinf(z: ArgIntoComplex) -> bool { - let Complex64 { re, im } = z.into_complex(); - re.is_infinite() || im.is_infinite() + pymath::cmath::isinf(z.into_complex()) } #[pyfunction] fn isfinite(z: ArgIntoComplex) -> bool { - z.into_complex().is_finite() + pymath::cmath::isfinite(z.into_complex()) } #[pyfunction] fn isnan(z: ArgIntoComplex) -> bool { - z.into_complex().is_nan() + pymath::cmath::isnan(z.into_complex()) } #[pyfunction] fn exp(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { - let z = z.into_complex(); - result_or_overflow(z, z.exp(), vm) + pymath::cmath::exp(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn sqrt(z: ArgIntoComplex) -> Complex64 { - z.into_complex().sqrt() + fn sqrt(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::sqrt(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn sin(z: ArgIntoComplex) -> Complex64 { - z.into_complex().sin() + fn sin(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::sin(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn asin(z: ArgIntoComplex) -> Complex64 { - z.into_complex().asin() + fn asin(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::asin(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn cos(z: ArgIntoComplex) -> Complex64 { - z.into_complex().cos() + fn cos(z: 
ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::cos(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn acos(z: ArgIntoComplex) -> Complex64 { - z.into_complex().acos() + fn acos(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::acos(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn log(z: ArgIntoComplex, base: OptionalArg) -> Complex64 { - // TODO: Complex64.log with a negative base yields wrong results. - // Issue is with num_complex::Complex64 implementation of log - // which returns NaN when base is negative. - // log10(z) / log10(base) yields correct results but division - // doesn't handle pos/neg zero nicely. (i.e log(1, 0.5)) - z.into_complex().log( - base.into_option() - .map(|base| base.into_complex().re) - .unwrap_or(core::f64::consts::E), + fn log( + z: ArgIntoComplex, + base: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + pymath::cmath::log( + z.into_complex(), + base.into_option().map(|b| b.into_complex()), ) + .map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn log10(z: ArgIntoComplex) -> Complex64 { - z.into_complex().log(10.0) + fn log10(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::log10(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn acosh(z: ArgIntoComplex) -> Complex64 { - z.into_complex().acosh() + fn acosh(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::acosh(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn atan(z: ArgIntoComplex) -> Complex64 { - z.into_complex().atan() + fn atan(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::atan(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn atanh(z: ArgIntoComplex) -> Complex64 { - z.into_complex().atanh() + fn atanh(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::atanh(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn tan(z: ArgIntoComplex) -> Complex64 { - z.into_complex().tan() + fn tan(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::tan(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn tanh(z: ArgIntoComplex) -> Complex64 { - z.into_complex().tanh() + fn tanh(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::tanh(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn sinh(z: ArgIntoComplex) -> Complex64 { - z.into_complex().sinh() + fn sinh(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::sinh(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn cosh(z: ArgIntoComplex) -> Complex64 { - z.into_complex().cosh() + fn cosh(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::cosh(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn asinh(z: ArgIntoComplex) -> Complex64 { - z.into_complex().asinh() + fn asinh(z: ArgIntoComplex, vm: &VirtualMachine) -> PyResult { + pymath::cmath::asinh(z.into_complex()).map_err(|err| pymath_exception(err, vm)) } #[derive(FromArgs)] @@ -158,52 +161,10 @@ mod cmath { fn isclose(args: IsCloseArgs, vm: &VirtualMachine) -> PyResult { let a = args.a.into_complex(); let b = args.b.into_complex(); - let rel_tol = args.rel_tol.map_or(1e-09, |v| v.into_float()); - let abs_tol = args.abs_tol.map_or(0.0, |v| v.into_float()); - - if rel_tol < 0.0 || abs_tol < 0.0 { - return 
Err(vm.new_value_error("tolerances must be non-negative")); - } - - if a == b { - /* short circuit exact equality -- needed to catch two infinities of - the same sign. And perhaps speeds things up a bit sometimes. - */ - return Ok(true); - } - - /* This catches the case of two infinities of opposite sign, or - one infinity and one finite number. Two infinities of opposite - sign would otherwise have an infinite relative tolerance. - Two infinities of the same sign are caught by the equality check - above. - */ - if a.is_infinite() || b.is_infinite() { - return Ok(false); - } + let rel_tol = args.rel_tol.into_option().map(|v| v.into_float()); + let abs_tol = args.abs_tol.into_option().map(|v| v.into_float()); - let diff = c_abs(b - a); - - Ok(diff <= (rel_tol * c_abs(b)) || (diff <= (rel_tol * c_abs(a))) || diff <= abs_tol) - } - - #[inline] - fn c_abs(Complex64 { re, im }: Complex64) -> f64 { - re.hypot(im) - } - - #[inline] - fn result_or_overflow( - value: Complex64, - result: Complex64, - vm: &VirtualMachine, - ) -> PyResult { - if !result.is_finite() && value.is_finite() { - // CPython doesn't return `inf` when called with finite - // values, it raises OverflowError instead. - Err(vm.new_overflow_error("math range error")) - } else { - Ok(result) - } + pymath::cmath::isclose(a, b, rel_tol, abs_tol) + .map_err(|_| vm.new_value_error("tolerances must be non-negative")) } } diff --git a/crates/stdlib/src/math.rs b/crates/stdlib/src/math.rs index fb8945c74f7..1e014e49e24 100644 --- a/crates/stdlib/src/math.rs +++ b/crates/stdlib/src/math.rs @@ -1,70 +1,48 @@ pub(crate) use math::make_module; -use crate::{builtins::PyBaseExceptionRef, vm::VirtualMachine}; +use crate::vm::{VirtualMachine, builtins::PyBaseExceptionRef}; #[pymodule] mod math { use crate::vm::{ - PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine, + AsObject, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine, builtins::{PyFloat, PyInt, PyIntRef, PyStrInterned, try_bigint_to_f64, try_f64_to_bigint}, function::{ArgIndex, ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs}, identifier, }; - use core::cmp::Ordering; - use itertools::Itertools; use malachite_bigint::BigInt; - use num_traits::{One, Signed, ToPrimitive, Zero}; - use rustpython_common::{float_ops, int::true_div}; + use num_traits::{Signed, ToPrimitive}; + + use super::{float_repr, pymath_exception}; // Constants #[pyattr] use core::f64::consts::{E as e, PI as pi, TAU as tau}; - use super::pymath_error_to_exception; #[pyattr(name = "inf")] const INF: f64 = f64::INFINITY; #[pyattr(name = "nan")] const NAN: f64 = f64::NAN; - // Helper macro: - macro_rules! call_math_func { - ( $fun:ident, $name:ident, $vm:ident ) => {{ - let value = $name.into_float(); - let result = value.$fun(); - result_or_overflow(value, result, $vm) - }}; - } - - #[inline] - fn result_or_overflow(value: f64, result: f64, vm: &VirtualMachine) -> PyResult { - if !result.is_finite() && value.is_finite() { - // CPython doesn't return `inf` when called with finite - // values, it raises OverflowError instead. 
- Err(vm.new_overflow_error("math range error")) - } else { - Ok(result) - } - } - // Number theory functions: #[pyfunction] fn fabs(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(abs, x, vm) + pymath::math::fabs(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn isfinite(x: ArgIntoFloat) -> bool { - x.into_float().is_finite() + pymath::math::isfinite(x.into_float()) } #[pyfunction] fn isinf(x: ArgIntoFloat) -> bool { - x.into_float().is_infinite() + pymath::math::isinf(x.into_float()) } #[pyfunction] fn isnan(x: ArgIntoFloat) -> bool { - x.into_float().is_nan() + pymath::math::isnan(x.into_float()) } #[derive(FromArgs)] @@ -83,420 +61,286 @@ mod math { fn isclose(args: IsCloseArgs, vm: &VirtualMachine) -> PyResult { let a = args.a.into_float(); let b = args.b.into_float(); - let rel_tol = args.rel_tol.map_or(1e-09, |v| v.into_float()); - let abs_tol = args.abs_tol.map_or(0.0, |v| v.into_float()); - - if rel_tol < 0.0 || abs_tol < 0.0 { - return Err(vm.new_value_error("tolerances must be non-negative")); - } - - if a == b { - /* short circuit exact equality -- needed to catch two infinities of - the same sign. And perhaps speeds things up a bit sometimes. - */ - return Ok(true); - } + let rel_tol = args.rel_tol.into_option().map(|v| v.into_float()); + let abs_tol = args.abs_tol.into_option().map(|v| v.into_float()); - /* This catches the case of two infinities of opposite sign, or - one infinity and one finite number. Two infinities of opposite - sign would otherwise have an infinite relative tolerance. - Two infinities of the same sign are caught by the equality check - above. - */ - - if a.is_infinite() || b.is_infinite() { - return Ok(false); - } - - let diff = (b - a).abs(); - - Ok((diff <= (rel_tol * b).abs()) || (diff <= (rel_tol * a).abs()) || (diff <= abs_tol)) + pymath::math::isclose(a, b, rel_tol, abs_tol) + .map_err(|_| vm.new_value_error("tolerances must be non-negative")) } #[pyfunction] - fn copysign(x: ArgIntoFloat, y: ArgIntoFloat) -> f64 { - x.into_float().copysign(y.into_float()) + fn copysign(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::math::copysign(x.into_float(), y.into_float()) + .map_err(|err| pymath_exception(err, vm)) } // Power and logarithmic functions: #[pyfunction] fn exp(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(exp, x, vm) + pymath::math::exp(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn exp2(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(exp2, x, vm) + pymath::math::exp2(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn expm1(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(exp_m1, x, vm) + pymath::math::expm1(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn log(x: PyObjectRef, base: OptionalArg, vm: &VirtualMachine) -> PyResult { - let base: f64 = base.map(Into::into).unwrap_or(core::f64::consts::E); - if base.is_sign_negative() { - return Err(vm.new_value_error("math domain error")); + let base = base.into_option().map(|v| v.into_float()); + // Check base first for proper error messages + if let Some(b) = base { + if b <= 0.0 { + return Err(vm.new_value_error(format!( + "expected a positive input, got {}", + super::float_repr(b) + ))); + } + if b == 1.0 { + return Err(vm.new_value_error("math domain error".to_owned())); + } + } + // Handle BigInt specially for large values (only for actual int type, not float) + if let 
Some(i) = x.downcast_ref::() { + return pymath::math::log_bigint(i.as_bigint(), base).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error("expected a positive input".to_owned()), + _ => pymath_exception(err, vm), + }); } - log2(x, vm).map(|log_x| log_x / base.log2()) + let val = x.try_float(vm)?.to_f64(); + pymath::math::log(val, base).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a positive input, got {}", + super::float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn log1p(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_nan() || x > -1.0_f64 { - Ok(x.ln_1p()) - } else { - Err(vm.new_value_error("math domain error")) - } - } - - /// Generates the base-2 logarithm of a BigInt `x` - fn int_log2(x: &BigInt) -> f64 { - // log2(x) = log2(2^n * 2^-n * x) = n + log2(x/2^n) - // If we set 2^n to be the greatest power of 2 below x, then x/2^n is in [1, 2), and can - // thus be converted into a float. - let n = x.bits() as u32 - 1; - let frac = true_div(x, &BigInt::from(2).pow(n)); - f64::from(n) + frac.log2() + pymath::math::log1p(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn log2(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match x.try_float(vm) { - Ok(x) => { - let x = x.to_f64(); - if x.is_nan() || x > 0.0_f64 { - Ok(x.log2()) - } else { - Err(vm.new_value_error("math domain error")) - } - } - Err(float_err) => { - if let Ok(x) = x.try_int(vm) { - let x = x.as_bigint(); - if x.is_positive() { - Ok(int_log2(x)) - } else { - Err(vm.new_value_error("math domain error")) - } - } else { - // Return the float error, as it will be more intuitive to users - Err(float_err) - } - } + // Handle BigInt specially for large values (only for actual int type, not float) + if let Some(i) = x.downcast_ref::() { + return pymath::math::log2_bigint(i.as_bigint()).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error("expected a positive input".to_owned()), + _ => pymath_exception(err, vm), + }); } + let val = x.try_float(vm)?.to_f64(); + pymath::math::log2(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a positive input, got {}", + super::float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn log10(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { - log2(x, vm).map(|log_x| log_x / 10f64.log2()) - } - - #[pyfunction] - fn pow(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - let y = y.into_float(); - - if x < 0.0 && x.is_finite() && y.fract() != 0.0 && y.is_finite() - || x == 0.0 && y < 0.0 && y != f64::NEG_INFINITY - { - return Err(vm.new_value_error("math domain error")); - } - - let value = x.powf(y); - - if x.is_finite() && y.is_finite() && value.is_infinite() { - return Err(vm.new_overflow_error("math range error")); + // Handle BigInt specially for large values (only for actual int type, not float) + if let Some(i) = x.downcast_ref::() { + return pymath::math::log10_bigint(i.as_bigint()).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error("expected a positive input".to_owned()), + _ => pymath_exception(err, vm), + }); } - - Ok(value) + let val = x.try_float(vm)?.to_f64(); + pymath::math::log10(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a positive input, got {}", + super::float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } 
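The log family above routes exact ints through pymath::math::log_bigint / log2_bigint / log10_bigint so arguments far beyond f64 range (see the new test_log_huge_integer) still give finite results instead of overflowing during conversion. A minimal sketch of the underlying split, in the spirit of the removed int_log2 helper; big_log2 and the 53-bit truncation here are illustrative assumptions, not the pymath implementation:

use malachite_bigint::BigInt;
use num_traits::ToPrimitive;

/// Sketch: log2 of a positive big integer. Converting the whole value to f64
/// overflows to +inf past ~2^1024, so peel off the bit length first:
/// log2(x) = shift + log2(x >> shift), keeping a 53-bit mantissa.
/// (Assumes the num-bigint-compatible shift/ToPrimitive API that
/// malachite_bigint exposes.)
fn big_log2(x: &BigInt) -> f64 {
    let bits = x.bits(); // number of significant bits; x > 0 assumed
    if bits <= 53 {
        return x.to_f64().expect("value fits in f64").log2();
    }
    let shift = bits - 53;
    // Truncating the discarded low bits perturbs the result by far less
    // than one ulp once the argument has hundreds of bits.
    let top = (x >> shift).to_f64().expect("53-bit value fits in f64");
    shift as f64 + top.log2()
}

// ln and log10 then follow by rescaling:
//   ln(x)    = big_log2(x) * core::f64::consts::LN_2
//   log10(x) = big_log2(x) * core::f64::consts::LOG10_2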
#[pyfunction] - fn sqrt(value: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let value = value.into_float(); - if value.is_nan() { - return Ok(value); - } - if value.is_sign_negative() { - if value.is_zero() { - return Ok(-0.0f64); - } - return Err(vm.new_value_error("math domain error")); - } - Ok(value.sqrt()) + fn pow(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::math::pow(x.into_float(), y.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn isqrt(x: ArgIndex, vm: &VirtualMachine) -> PyResult { - let x = x.into_int_ref(); - let value = x.as_bigint(); - - if value.is_negative() { - return Err(vm.new_value_error("isqrt() argument must be nonnegative")); - } - Ok(value.sqrt()) + fn sqrt(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + let val = x.into_float(); + pymath::math::sqrt(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a nonnegative input, got {}", + super::float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } // Trigonometric functions: #[pyfunction] fn acos(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_nan() || (-1.0_f64..=1.0_f64).contains(&x) { - Ok(x.acos()) - } else { - Err(vm.new_value_error("math domain error")) - } + let val = x.into_float(); + pymath::math::acos(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a number in range from -1 up to 1, got {}", + float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn asin(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_nan() || (-1.0_f64..=1.0_f64).contains(&x) { - Ok(x.asin()) - } else { - Err(vm.new_value_error("math domain error")) - } + let val = x.into_float(); + pymath::math::asin(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a number in range from -1 up to 1, got {}", + float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn atan(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(atan, x, vm) + pymath::math::atan(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn atan2(y: ArgIntoFloat, x: ArgIntoFloat) -> f64 { - y.into_float().atan2(x.into()) + fn atan2(y: ArgIntoFloat, x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::math::atan2(y.into_float(), x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn cos(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_infinite() { - return Err(vm.new_value_error("math domain error")); - } - result_or_overflow(x, x.cos(), vm) + let val = x.into_float(); + pymath::math::cos(val).map_err(|err| match err { + pymath::Error::EDOM => { + vm.new_value_error(format!("expected a finite input, got {}", float_repr(val))) + } + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn hypot(coordinates: PosArgs) -> f64 { - let mut coordinates = ArgIntoFloat::vec_into_f64(coordinates.into_vec()); - let mut max = 0.0; - let mut has_nan = false; - for f in &mut coordinates { - *f = f.abs(); - if f.is_nan() { - has_nan = true; - } else if *f > max { - max = *f - } - } - // inf takes precedence over nan - if max.is_infinite() { - return max; - } - if has_nan { - return f64::NAN; - } - coordinates.sort_unstable_by(|x, y| x.total_cmp(y).reverse()); - vector_norm(&coordinates) - } - - /// Implementation of accurate hypotenuse algorithm from Borges 
2019. - /// See https://arxiv.org/abs/1904.09481. - /// This assumes that its arguments are positive finite and have been scaled to avoid overflow - /// and underflow. - fn accurate_hypot(max: f64, min: f64) -> f64 { - if min <= max * (f64::EPSILON / 2.0).sqrt() { - return max; - } - let hypot = max.mul_add(max, min * min).sqrt(); - let hypot_sq = hypot * hypot; - let max_sq = max * max; - let correction = (-min).mul_add(min, hypot_sq - max_sq) + hypot.mul_add(hypot, -hypot_sq) - - max.mul_add(max, -max_sq); - hypot - correction / (2.0 * hypot) - } - - /// Calculates the norm of the vector given by `v`. - /// `v` is assumed to be a list of non-negative finite floats, sorted in descending order. - fn vector_norm(v: &[f64]) -> f64 { - // Drop zeros from the vector. - let zero_count = v.iter().rev().cloned().take_while(|x| *x == 0.0).count(); - let v = &v[..v.len() - zero_count]; - if v.is_empty() { - return 0.0; - } - if v.len() == 1 { - return v[0]; - } - // Calculate scaling to avoid overflow / underflow. - let max = *v.first().unwrap(); - let min = *v.last().unwrap(); - let scale = if max > (f64::MAX / v.len() as f64).sqrt() { - max - } else if min < f64::MIN_POSITIVE.sqrt() { - // ^ This can be an `else if`, because if the max is near f64::MAX and the min is near - // f64::MIN_POSITIVE, then the min is relatively unimportant and will be effectively - // ignored. - min - } else { - 1.0 - }; - let mut norm = v - .iter() - .copied() - .map(|x| x / scale) - .reduce(accurate_hypot) - .unwrap_or_default(); - if v.len() > 2 { - // For larger lists of numbers, we can accumulate a rounding error, so a correction is - // needed, similar to that in `accurate_hypot()`. - // First, we estimate [sum of squares - norm^2], then we add the first-order - // approximation of the square root of that to `norm`. - let correction = v - .iter() - .copied() - .map(|x| (x / scale).powi(2)) - .chain(core::iter::once(-norm * norm)) - // Pairwise summation of floats gives less rounding error than a naive sum. 
- .tree_reduce(core::ops::Add::add) - .expect("expected at least 1 element"); - norm = norm + correction / (2.0 * norm); - } - norm * scale + let coords = ArgIntoFloat::vec_into_f64(coordinates.into_vec()); + pymath::math::hypot(&coords) } #[pyfunction] fn dist(p: Vec, q: Vec, vm: &VirtualMachine) -> PyResult { - let mut max = 0.0; - let mut has_nan = false; - let p = ArgIntoFloat::vec_into_f64(p); let q = ArgIntoFloat::vec_into_f64(q); - let mut diffs = vec![]; - if p.len() != q.len() { return Err(vm.new_value_error("both points must have the same number of dimensions")); } - - for i in 0..p.len() { - let px = p[i]; - let qx = q[i]; - - let x = (px - qx).abs(); - if x.is_nan() { - has_nan = true; - } - - diffs.push(x); - if x > max { - max = x; - } - } - - if max.is_infinite() { - return Ok(max); - } - if has_nan { - return Ok(f64::NAN); - } - diffs.sort_unstable_by(|x, y| x.total_cmp(y).reverse()); - Ok(vector_norm(&diffs)) + Ok(pymath::math::dist(&p, &q)) } #[pyfunction] fn sin(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_infinite() { - return Err(vm.new_value_error("math domain error")); - } - result_or_overflow(x, x.sin(), vm) + let val = x.into_float(); + pymath::math::sin(val).map_err(|err| match err { + pymath::Error::EDOM => { + vm.new_value_error(format!("expected a finite input, got {}", float_repr(val))) + } + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn tan(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_infinite() { - return Err(vm.new_value_error("math domain error")); - } - result_or_overflow(x, x.tan(), vm) + let val = x.into_float(); + pymath::math::tan(val).map_err(|err| match err { + pymath::Error::EDOM => { + vm.new_value_error(format!("expected a finite input, got {}", float_repr(val))) + } + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn degrees(x: ArgIntoFloat) -> f64 { - x.into_float() * (180.0 / core::f64::consts::PI) + pymath::math::degrees(x.into_float()) } #[pyfunction] fn radians(x: ArgIntoFloat) -> f64 { - x.into_float() * (core::f64::consts::PI / 180.0) + pymath::math::radians(x.into_float()) } // Hyperbolic functions: #[pyfunction] fn acosh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x.is_sign_negative() || x.is_zero() { - Err(vm.new_value_error("math domain error")) - } else { - Ok(x.acosh()) - } + pymath::math::acosh(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn asinh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(asinh, x, vm) + pymath::math::asinh(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn atanh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - if x >= 1.0_f64 || x <= -1.0_f64 { - Err(vm.new_value_error("math domain error")) - } else { - Ok(x.atanh()) - } + let val = x.into_float(); + pymath::math::atanh(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a number between -1 and 1, got {}", + super::float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn cosh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(cosh, x, vm) + pymath::math::cosh(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn sinh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(sinh, x, vm) + pymath::math::sinh(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn tanh(x: 
ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - call_math_func!(tanh, x, vm) + pymath::math::tanh(x.into_float()).map_err(|err| pymath_exception(err, vm)) } // Special functions: #[pyfunction] - fn erf(x: ArgIntoFloat) -> f64 { - pymath::erf(x.into()) + fn erf(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::math::erf(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn erfc(x: ArgIntoFloat) -> f64 { - pymath::erfc(x.into()) + fn erfc(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::math::erfc(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn gamma(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - pymath::gamma(x.into()).map_err(|err| pymath_error_to_exception(err, vm)) + let val = x.into_float(); + pymath::math::gamma(val).map_err(|err| match err { + pymath::Error::EDOM => vm.new_value_error(format!( + "expected a noninteger or positive integer, got {}", + super::float_repr(val) + )), + _ => pymath_exception(err, vm), + }) } #[pyfunction] fn lgamma(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - pymath::lgamma(x.into()).map_err(|err| pymath_error_to_exception(err, vm)) + pymath::math::lgamma(x.into_float()).map_err(|err| pymath_exception(err, vm)) } fn try_magic_method( @@ -521,37 +365,43 @@ mod math { #[pyfunction] fn ceil(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let result_or_err = try_magic_method(identifier!(vm, __ceil__), vm, &x); - if result_or_err.is_err() - && let Some(v) = x.try_float_opt(vm) - { + // Only call __ceil__ if the class defines it - if it exists but is not callable, + // the error should be propagated (not fall back to float conversion) + if x.class().has_attr(identifier!(vm, __ceil__)) { + return try_magic_method(identifier!(vm, __ceil__), vm, &x); + } + // __ceil__ not defined - fall back to float conversion + if let Some(v) = x.try_float_opt(vm) { let v = try_f64_to_bigint(v?.to_f64().ceil(), vm)?; return Ok(vm.ctx.new_int(v).into()); } - result_or_err + Err(vm.new_type_error(format!( + "type '{}' doesn't define '__ceil__' method", + x.class().name(), + ))) } #[pyfunction] fn floor(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let result_or_err = try_magic_method(identifier!(vm, __floor__), vm, &x); - if result_or_err.is_err() - && let Some(v) = x.try_float_opt(vm) - { + // Only call __floor__ if the class defines it - if it exists but is not callable, + // the error should be propagated (not fall back to float conversion) + if x.class().has_attr(identifier!(vm, __floor__)) { + return try_magic_method(identifier!(vm, __floor__), vm, &x); + } + // __floor__ not defined - fall back to float conversion + if let Some(v) = x.try_float_opt(vm) { let v = try_f64_to_bigint(v?.to_f64().floor(), vm)?; return Ok(vm.ctx.new_int(v).into()); } - result_or_err + Err(vm.new_type_error(format!( + "type '{}' doesn't define '__floor__' method", + x.class().name(), + ))) } #[pyfunction] fn frexp(x: ArgIntoFloat) -> (f64, i32) { - let value: f64 = x.into(); - if value.is_finite() { - let (m, exp) = float_ops::decompose_float(value); - (m * value.signum(), exp) - } else { - (value, 0) - } + pymath::math::frexp(x.into_float()) } #[pyfunction] @@ -564,315 +414,24 @@ mod math { Either::A(f) => f.to_f64(), Either::B(z) => try_bigint_to_f64(z.as_bigint(), vm)?, }; - - if value == 0_f64 || !value.is_finite() { - // NaNs, zeros and infinities are returned unchanged - return Ok(value); - } - - // Using IEEE 754 bit manipulation to handle large exponents correctly. 
- // Direct multiplication would overflow for large i values, especially when computing - // the largest finite float (i=1024, x<1.0). By directly modifying the exponent bits, - // we avoid intermediate overflow to infinity. - - // Scale subnormals to normal range first, then adjust exponent. - let (mant, exp0) = if value.abs() < f64::MIN_POSITIVE { - let scaled = value * (1u64 << 54) as f64; // multiply by 2^54 - let (mant_scaled, exp_scaled) = float_ops::decompose_float(scaled); - (mant_scaled, exp_scaled - 54) // adjust exponent back - } else { - float_ops::decompose_float(value) - }; - - let i_big = i.as_bigint(); - let overflow_bound = BigInt::from(1024_i32 - exp0); // i > 1024 - exp0 => overflow - if i_big > &overflow_bound { - return Err(vm.new_overflow_error("math range error")); - } - if i_big == &overflow_bound && mant == 1.0 { - return Err(vm.new_overflow_error("math range error")); - } - let underflow_bound = BigInt::from(-1074_i32 - exp0); // i < -1074 - exp0 => 0.0 with sign - if i_big < &underflow_bound { - return Ok(0.0f64.copysign(value)); - } - - let i_small: i32 = i_big - .to_i32() - .expect("exponent within [-1074-exp0, 1024-exp0] must fit in i32"); - let exp = exp0 + i_small; - - const SIGN_MASK: u64 = 0x8000_0000_0000_0000; - const FRAC_MASK: u64 = 0x000F_FFFF_FFFF_FFFF; - let sign_bit: u64 = if value.is_sign_negative() { - SIGN_MASK - } else { - 0 - }; - let mant_bits = mant.to_bits() & FRAC_MASK; - if exp >= -1021 { - let e_bits = (1022_i32 + exp) as u64; - let result_bits = sign_bit | (e_bits << 52) | mant_bits; - return Ok(f64::from_bits(result_bits)); - } - - let full_mant: u64 = (1u64 << 52) | mant_bits; - let shift: u32 = (-exp - 1021) as u32; - let frac_shifted = full_mant >> shift; - let lost_bits = full_mant & ((1u64 << shift) - 1); - - let half = 1u64 << (shift - 1); - let frac = if (lost_bits > half) || (lost_bits == half && (frac_shifted & 1) == 1) { - frac_shifted + 1 - } else { - frac_shifted - }; - - let result_bits = if frac >= (1u64 << 52) { - sign_bit | (1u64 << 52) - } else { - sign_bit | frac - }; - Ok(f64::from_bits(result_bits)) - } - - fn math_perf_arb_len_int_op(args: PosArgs, op: F, default: BigInt) -> BigInt - where - F: Fn(&BigInt, &PyInt) -> BigInt, - { - let arg_vec = args.into_vec(); - - if arg_vec.is_empty() { - return default; - } else if arg_vec.len() == 1 { - return op(arg_vec[0].as_ref().as_bigint(), arg_vec[0].as_ref()); - } - - let mut res = arg_vec[0].as_ref().as_bigint().clone(); - for num in &arg_vec[1..] { - res = op(&res, num.as_ref()) - } - res + pymath::math::ldexp_bigint(value, i.as_bigint()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] - fn gcd(args: PosArgs) -> BigInt { - use num_integer::Integer; - math_perf_arb_len_int_op(args, |x, y| x.gcd(y.as_bigint()), BigInt::zero()) - } - - #[pyfunction] - fn lcm(args: PosArgs) -> BigInt { - use num_integer::Integer; - math_perf_arb_len_int_op(args, |x, y| x.lcm(y.as_bigint()), BigInt::one()) - } - - #[pyfunction] - fn cbrt(x: ArgIntoFloat) -> f64 { - x.into_float().cbrt() + fn cbrt(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { + pymath::math::cbrt(x.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn fsum(seq: ArgIterable, vm: &VirtualMachine) -> PyResult { - let mut partials = Vec::with_capacity(32); - let mut special_sum = 0.0; - let mut inf_sum = 0.0; - - for obj in seq.iter(vm)? 
{ - let mut x = obj?.into_float(); - - let xsave = x; - let mut i = 0; - // This inner loop applies `hi`/`lo` summation to each - // partial so that the list of partial sums remains exact. - for j in 0..partials.len() { - let mut y: f64 = partials[j]; - if x.abs() < y.abs() { - core::mem::swap(&mut x, &mut y); - } - // Rounded `x+y` is stored in `hi` with round-off stored in - // `lo`. Together `hi+lo` are exactly equal to `x+y`. - let hi = x + y; - let lo = y - (hi - x); - if lo != 0.0 { - partials[i] = lo; - i += 1; - } - x = hi; - } - - partials.truncate(i); - if x != 0.0 { - if !x.is_finite() { - // a non-finite x could arise either as - // a result of intermediate overflow, or - // as a result of a nan or inf in the - // summands - if xsave.is_finite() { - return Err(vm.new_overflow_error("intermediate overflow in fsum")); - } - if xsave.is_infinite() { - inf_sum += xsave; - } - special_sum += xsave; - // reset partials - partials.clear(); - } else { - partials.push(x); - } - } - } - if special_sum != 0.0 { - return if inf_sum.is_nan() { - Err(vm.new_value_error("-inf + inf in fsum")) - } else { - Ok(special_sum) - }; - } - - let mut n = partials.len(); - if n > 0 { - n -= 1; - let mut hi = partials[n]; - - let mut lo = 0.0; - while n > 0 { - let x = hi; - - n -= 1; - let y = partials[n]; - - hi = x + y; - lo = y - (hi - x); - if lo != 0.0 { - break; - } - } - if n > 0 && ((lo < 0.0 && partials[n - 1] < 0.0) || (lo > 0.0 && partials[n - 1] > 0.0)) - { - let y = lo + lo; - let x = hi + y; - - // Make half-even rounding work across multiple partials. - // Needed so that sum([1e-16, 1, 1e16]) will round-up the last - // digit to two instead of down to zero (the 1e-16 makes the 1 - // slightly closer to two). With a potential 1 ULP rounding - // error fixed-up, math.fsum() can guarantee commutativity. 
- if y == x - hi { - hi = x; - } - } - - Ok(hi) - } else { - Ok(0.0) - } - } - - #[pyfunction] - fn factorial(x: PyIntRef, vm: &VirtualMachine) -> PyResult { - let value = x.as_bigint(); - let one = BigInt::one(); - if value.is_negative() { - return Err(vm.new_value_error("factorial() not defined for negative values")); - } else if *value <= one { - return Ok(one); - } - // start from 2, since we know that value > 1 and 1*2=2 - let mut current = one + 1; - let mut product = BigInt::from(2u8); - while current < *value { - current += 1; - product *= ¤t; - } - Ok(product) - } - - #[pyfunction] - fn perm( - n: ArgIndex, - k: OptionalArg>, - vm: &VirtualMachine, - ) -> PyResult { - let n = n.into_int_ref(); - let n = n.as_bigint(); - let k_ref; - let v = match k.flatten() { - Some(k) => { - k_ref = k.into_int_ref(); - k_ref.as_bigint() - } - None => n, - }; - - if n.is_negative() || v.is_negative() { - return Err(vm.new_value_error("perm() not defined for negative values")); - } - if v > n { - return Ok(BigInt::zero()); - } - let mut result = BigInt::one(); - let mut current = n.clone(); - let tmp = n - v; - while current > tmp { - result *= ¤t; - current -= 1; - } - Ok(result) - } - - #[pyfunction] - fn comb(n: ArgIndex, k: ArgIndex, vm: &VirtualMachine) -> PyResult { - let k = k.into_int_ref(); - let mut k = k.as_bigint(); - let n = n.into_int_ref(); - let n = n.as_bigint(); - let one = BigInt::one(); - let zero = BigInt::zero(); - - if n.is_negative() || k.is_negative() { - return Err(vm.new_value_error("comb() not defined for negative values")); - } - - let temp = n - k; - if temp.is_negative() { - return Ok(zero); - } - - if temp < *k { - k = &temp - } - - if k.is_zero() { - return Ok(one); - } - - let mut result = n.clone(); - let mut factor = n.clone(); - let mut current = one; - while current < *k { - factor -= 1; - current += 1; - - result *= &factor; - result /= ¤t; - } - - Ok(result) + let values: Result, _> = + seq.iter(vm)?.map(|r| r.map(|v| v.into_float())).collect(); + pymath::math::fsum(values?).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn modf(x: ArgIntoFloat) -> (f64, f64) { - let x = x.into_float(); - if !x.is_finite() { - if x.is_infinite() { - return (0.0_f64.copysign(x), x); - } else if x.is_nan() { - return (x, x); - } - } - - (x.fract(), x.trunc()) + pymath::math::modf(x.into_float()) } #[derive(FromArgs)] @@ -887,85 +446,36 @@ mod math { #[pyfunction] fn nextafter(arg: NextAfterArgs, vm: &VirtualMachine) -> PyResult { - let steps: Option = arg - .steps - .map(|v| v.into_int_ref().try_to_primitive(vm)) - .transpose()? 
- .into_option(); - let x: f64 = arg.x.into(); - let y: f64 = arg.y.into(); - match steps { + let x = arg.x.into_float(); + let y = arg.y.into_float(); + + let steps = match arg.steps.into_option() { Some(steps) => { + let steps: i64 = steps.into_int_ref().try_to_primitive(vm)?; if steps < 0 { return Err(vm.new_value_error("steps must be a non-negative integer")); } - Ok(float_ops::nextafter_with_steps(x, y, steps as u64)) + Some(steps as u64) } - None => Ok(float_ops::nextafter(x, y)), - } + None => None, + }; + Ok(pymath::math::nextafter(x, y, steps)) } #[pyfunction] fn ulp(x: ArgIntoFloat) -> f64 { - float_ops::ulp(x.into()) - } - - fn fmod(x: f64, y: f64) -> f64 { - if y.is_infinite() && x.is_finite() { - return x; - } - - x % y + pymath::math::ulp(x.into_float()) } #[pyfunction(name = "fmod")] fn py_fmod(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - let y = y.into_float(); - - let r = fmod(x, y); - - if r.is_nan() && !x.is_nan() && !y.is_nan() { - return Err(vm.new_value_error("math domain error")); - } - - Ok(r) + pymath::math::fmod(x.into_float(), y.into_float()).map_err(|err| pymath_exception(err, vm)) } #[pyfunction] fn remainder(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult { - let x = x.into_float(); - let y = y.into_float(); - - if x.is_finite() && y.is_finite() { - if y == 0.0 { - return Err(vm.new_value_error("math domain error")); - } - - let abs_x = x.abs(); - let abs_y = y.abs(); - let modulus = abs_x % abs_y; - - let c = abs_y - modulus; - let r = match modulus.partial_cmp(&c) { - Some(Ordering::Less) => modulus, - Some(Ordering::Greater) => -c, - _ => modulus - 2.0 * fmod(0.5 * (abs_x - modulus), abs_y), - }; - - return Ok(1.0_f64.copysign(x) * r); - } - if x.is_infinite() && !y.is_nan() { - return Err(vm.new_value_error("math domain error")); - } - if x.is_nan() || y.is_nan() { - return Ok(f64::NAN); - } - if y.is_infinite() { - Ok(x) - } else { - Err(vm.new_value_error("math domain error")) - } + pymath::math::remainder(x.into_float(), y.into_float()) + .map_err(|err| pymath_exception(err, vm)) } #[derive(FromArgs)] @@ -978,15 +488,118 @@ mod math { #[pyfunction] fn prod(args: ProdArgs, vm: &VirtualMachine) -> PyResult { + use crate::vm::builtins::PyInt; + let iter = args.iterable; + let start = args.start; + + // Check if start is provided and what type it is (exact types only, not subclasses) + let (mut obj_result, start_is_int, start_is_float) = match &start { + OptionalArg::Present(s) => { + let is_int = s.class().is(vm.ctx.types.int_type); + let is_float = s.class().is(vm.ctx.types.float_type); + (Some(s.clone()), is_int, is_float) + } + OptionalArg::Missing => (None, true, false), // Default is int 1 + }; + + let mut item_iter = iter.iter(vm)?; + + // Integer fast path + if start_is_int && !start_is_float { + let mut int_result: i64 = match &start { + OptionalArg::Present(s) => { + if let Some(i) = s.downcast_ref::() { + match i.as_bigint().try_into() { + Ok(v) => v, + Err(_) => { + // Start overflows i64, fall through to generic path + obj_result = Some(s.clone()); + i64::MAX // Will be ignored + } + } + } else { + 1 + } + } + OptionalArg::Missing => 1, + }; - let mut result = args.start.unwrap_or_else(|| vm.new_pyobj(1)); + if obj_result.is_none() { + loop { + let item = match item_iter.next() { + Some(r) => r?, + None => return Ok(vm.ctx.new_int(int_result).into()), + }; + + // Only use fast path for exact int type (not subclasses) + if item.class().is(vm.ctx.types.int_type) + && let 
Some(int_item) = item.downcast_ref::() + && let Ok(b) = int_item.as_bigint().try_into() as Result + && let Some(product) = int_result.checked_mul(b) + { + int_result = product; + continue; + } - // TODO: CPython has optimized implementation for this - // refer: https://github.com/python/cpython/blob/main/Modules/mathmodule.c#L3093-L3193 - for obj in iter.iter(vm)? { - let obj = obj?; - result = vm._mul(&result, &obj)?; + // Overflow or non-int: restore to PyObject and continue + obj_result = Some(vm.ctx.new_int(int_result).into()); + let temp = vm._mul(obj_result.as_ref().unwrap(), &item)?; + obj_result = Some(temp); + break; + } + } + } + + // Float fast path + let obj_float = obj_result + .as_ref() + .and_then(|obj| obj.clone().downcast::().ok()); + if obj_float.is_some() || start_is_float { + let mut flt_result: f64 = if let Some(ref f) = obj_float { + f.to_f64() + } else if start_is_float && let OptionalArg::Present(s) = &start { + s.downcast_ref::() + .map(|f| f.to_f64()) + .unwrap_or(1.0) + } else { + 1.0 + }; + + loop { + let item = match item_iter.next() { + Some(r) => r?, + None => return Ok(vm.ctx.new_float(flt_result).into()), + }; + + // Only use fast path for exact float/int types (not subclasses) + if item.class().is(vm.ctx.types.float_type) + && let Some(f) = item.downcast_ref::() + { + flt_result *= f.to_f64(); + continue; + } + if item.class().is(vm.ctx.types.int_type) + && let Some(i) = item.downcast_ref::() + && let Ok(v) = i.as_bigint().try_into() as Result + { + flt_result *= v as f64; + continue; + } + + // Non-exact-float/int: restore and continue with generic path + obj_result = Some(vm.ctx.new_float(flt_result).into()); + let temp = vm._mul(obj_result.as_ref().unwrap(), &item)?; + obj_result = Some(temp); + break; + } + } + + // Generic path for remaining items + let mut result = obj_result.unwrap_or_else(|| vm.ctx.new_int(1).into()); + for item in item_iter { + let item = item?; + result = vm._mul(&result, &item)?; } Ok(result) @@ -998,29 +611,145 @@ mod math { q: ArgIterable, vm: &VirtualMachine, ) -> PyResult { + use crate::vm::builtins::PyInt; + let mut p_iter = p.iter(vm)?; let mut q_iter = q.iter(vm)?; - // We cannot just create a float because the iterator may contain - // anything as long as it supports __add__ and __mul__. 
- let mut result = vm.new_pyobj(0); + + // Fast path state + let mut int_path_enabled = true; + let mut int_total: i64 = 0; + let mut int_total_in_use = false; + let mut flt_p_values: Vec = Vec::new(); + let mut flt_q_values: Vec = Vec::new(); + + // Fallback accumulator for generic Python path + let mut obj_total: Option = None; + loop { let m_p = p_iter.next(); let m_q = q_iter.next(); - match (m_p, m_q) { - (Some(r_p), Some(r_q)) => { - let p = r_p?; - let q = r_q?; - let tmp = vm._mul(&p, &q)?; - result = vm._add(&result, &tmp)?; + + let (p_i, q_i, finished) = match (m_p, m_q) { + (Some(r_p), Some(r_q)) => (Some(r_p?), Some(r_q?), false), + (None, None) => (None, None, true), + _ => return Err(vm.new_value_error("Inputs are not the same length")), + }; + + // Integer fast path (only for exact int types, not subclasses) + if int_path_enabled { + if !finished { + let (p_i, q_i) = (p_i.as_ref().unwrap(), q_i.as_ref().unwrap()); + if p_i.class().is(vm.ctx.types.int_type) + && q_i.class().is(vm.ctx.types.int_type) + && let (Some(p_int), Some(q_int)) = + (p_i.downcast_ref::(), q_i.downcast_ref::()) + && let (Ok(p_val), Ok(q_val)) = ( + p_int.as_bigint().try_into() as Result, + q_int.as_bigint().try_into() as Result, + ) + && let Some(prod) = p_val.checked_mul(q_val) + && let Some(new_total) = int_total.checked_add(prod) + { + int_total = new_total; + int_total_in_use = true; + continue; + } } - (None, None) => break, - _ => { - return Err(vm.new_value_error("Inputs are not the same length")); + // Finalize int path + int_path_enabled = false; + if int_total_in_use { + let int_obj: PyObjectRef = vm.ctx.new_int(int_total).into(); + obj_total = Some(match obj_total { + Some(total) => vm._add(&total, &int_obj)?, + None => int_obj, + }); + int_total = 0; + int_total_in_use = false; } } + + // Float fast path - only when at least one value is exact float type + // (not subclasses, to preserve custom __mul__/__add__ behavior) + { + if !finished { + let (p_i, q_i) = (p_i.as_ref().unwrap(), q_i.as_ref().unwrap()); + + let p_is_exact_float = p_i.class().is(vm.ctx.types.float_type); + let q_is_exact_float = q_i.class().is(vm.ctx.types.float_type); + let p_is_exact_int = p_i.class().is(vm.ctx.types.int_type); + let q_is_exact_int = q_i.class().is(vm.ctx.types.int_type); + let p_is_exact_numeric = p_is_exact_float || p_is_exact_int; + let q_is_exact_numeric = q_is_exact_float || q_is_exact_int; + let has_exact_float = p_is_exact_float || q_is_exact_float; + + // Only use float path if at least one is exact float and both are exact int/float + if has_exact_float && p_is_exact_numeric && q_is_exact_numeric { + let p_flt = if let Some(f) = p_i.downcast_ref::() { + Some(f.to_f64()) + } else if let Some(i) = p_i.downcast_ref::() { + // PyLong_AsDouble fails for integers too large for f64 + try_bigint_to_f64(i.as_bigint(), vm).ok() + } else { + None + }; + + let q_flt = if let Some(f) = q_i.downcast_ref::() { + Some(f.to_f64()) + } else if let Some(i) = q_i.downcast_ref::() { + // PyLong_AsDouble fails for integers too large for f64 + try_bigint_to_f64(i.as_bigint(), vm).ok() + } else { + None + }; + + if let (Some(p_val), Some(q_val)) = (p_flt, q_flt) { + flt_p_values.push(p_val); + flt_q_values.push(q_val); + continue; + } + } + } + // Finalize float path + if !flt_p_values.is_empty() { + let flt_result = pymath::math::sumprod(&flt_p_values, &flt_q_values); + let flt_obj: PyObjectRef = vm.ctx.new_float(flt_result).into(); + obj_total = Some(match obj_total { + Some(total) => vm._add(&total, &flt_obj)?, + 
+                    obj_total = Some(match obj_total {
+                        Some(total) => vm._add(&total, &flt_obj)?,
+                        None => flt_obj,
+                    });
+                    flt_p_values.clear();
+                    flt_q_values.clear();
+                }
             }
+
+            if finished {
+                break;
+            }
+
+            // Generic Python path
+            let (p_i, q_i) = (p_i.unwrap(), q_i.unwrap());
+
+            // Collect current + remaining elements
+            let p_remaining: Result<Vec<PyObjectRef>, _> =
+                std::iter::once(Ok(p_i)).chain(p_iter).collect();
+            let q_remaining: Result<Vec<PyObjectRef>, _> =
+                std::iter::once(Ok(q_i)).chain(q_iter).collect();
+            let (p_vec, q_vec) = (p_remaining?, q_remaining?);
+
+            if p_vec.len() != q_vec.len() {
+                return Err(vm.new_value_error("Inputs are not the same length"));
+            }
+
+            let mut total = obj_total.unwrap_or_else(|| vm.ctx.new_int(0).into());
+            for (p_item, q_item) in p_vec.into_iter().zip(q_vec) {
+                let prod = vm._mul(&p_item, &q_item)?;
+                total = vm._add(&total, &prod)?;
+            }
+            return Ok(total);
         }
-        Ok(result)
+        Ok(obj_total.unwrap_or_else(|| vm.ctx.new_int(0).into()))
     }

     #[pyfunction]
@@ -1030,30 +759,202 @@ mod math {
         z: ArgIntoFloat,
         vm: &VirtualMachine,
     ) -> PyResult<f64> {
-        let x = x.into_float();
-        let y = y.into_float();
-        let z = z.into_float();
-        let result = x.mul_add(y, z);
+        pymath::math::fma(x.into_float(), y.into_float(), z.into_float()).map_err(|err| match err {
+            pymath::Error::EDOM => vm.new_value_error("invalid operation in fma"),
+            pymath::Error::ERANGE => vm.new_overflow_error("overflow in fma"),
+        })
+    }
+
+    // Integer functions:
+
+    #[pyfunction]
+    fn isqrt(x: ArgIndex, vm: &VirtualMachine) -> PyResult<BigInt> {
+        let value = x.into_int_ref();
+        pymath::math::integer::isqrt(value.as_bigint())
+            .map_err(|_| vm.new_value_error("isqrt() argument must be nonnegative"))
+    }
-        if result.is_finite() {
-            return Ok(result);
+
+    #[pyfunction]
+    fn gcd(args: PosArgs<ArgIndex>) -> BigInt {
+        let ints: Vec<_> = args
+            .into_vec()
+            .into_iter()
+            .map(|x| x.into_int_ref())
+            .collect();
+        let refs: Vec<_> = ints.iter().map(|x| x.as_bigint()).collect();
+        pymath::math::integer::gcd(&refs)
+    }
+
+    #[pyfunction]
+    fn lcm(args: PosArgs<ArgIndex>) -> BigInt {
+        let ints: Vec<_> = args
+            .into_vec()
+            .into_iter()
+            .map(|x| x.into_int_ref())
+            .collect();
+        let refs: Vec<_> = ints.iter().map(|x| x.as_bigint()).collect();
+        pymath::math::integer::lcm(&refs)
+    }
+
+    #[pyfunction]
+    fn factorial(x: PyIntRef, vm: &VirtualMachine) -> PyResult<BigInt> {
+        // Check for negative before overflow - negative values are always invalid
+        if x.as_bigint().is_negative() {
+            return Err(vm.new_value_error("factorial() not defined for negative values"));
         }
+        let n: i64 = x.try_to_primitive(vm).map_err(|_| {
+            vm.new_overflow_error("factorial() argument should not exceed 9223372036854775807")
+        })?;
+        pymath::math::integer::factorial(n)
+            .map(|r| r.into())
+            .map_err(|_| vm.new_value_error("factorial() not defined for negative values"))
+    }
+
+    #[pyfunction]
+    fn perm(
+        n: ArgIndex,
+        k: OptionalArg<Option<ArgIndex>>,
+        vm: &VirtualMachine,
+    ) -> PyResult<BigInt> {
+        let n_int = n.into_int_ref();
+        let n_big = n_int.as_bigint();
-        if result.is_nan() {
-            if !x.is_nan() && !y.is_nan() && !z.is_nan() {
-                return Err(vm.new_value_error("invalid operation in fma"));
+        if n_big.is_negative() {
+            return Err(vm.new_value_error("n must be a non-negative integer"));
+        }
+
+        // k = None means k = n (factorial)
+        let k_int = k.flatten().map(|k| k.into_int_ref());
+        let k_big: Option<&BigInt> = k_int.as_ref().map(|k| k.as_bigint());
+
+        if let Some(k_val) = k_big {
+            if k_val.is_negative() {
+                return Err(vm.new_value_error("k must be a non-negative integer"));
+            }
+            if k_val > n_big {
+                return Ok(BigInt::from(0u8));
             }
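+            // Both checks passed here: 0 <= k <= n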
-        } else if x.is_finite() && y.is_finite() && z.is_finite() {
-            return Err(vm.new_overflow_error("overflow in fma"));
         }
-        Ok(result)
+
+        // Convert k to u64 (required by pymath)
+        let ki: u64 = match k_big {
+            None => match n_big.to_u64() {
+                Some(n) => n,
+                None => {
+                    return Err(vm.new_overflow_error(format!("n must not exceed {}", u64::MAX)));
+                }
+            },
+            Some(k_val) => match k_val.to_u64() {
+                Some(k) => k,
+                None => {
+                    return Err(vm.new_overflow_error(format!("k must not exceed {}", u64::MAX)));
+                }
+            },
+        };
+
+        // Fast path: n fits in i64
+        if let Some(ni) = n_big.to_i64()
+            && ni >= 0
+            && ki > 1
+        {
+            let result = pymath::math::integer::perm(ni, Some(ki as i64))
+                .map_err(|_| vm.new_value_error("perm() error"))?;
+            return Ok(result.into());
+        }
+
+        // BigInt path: use perm_bigint
+        let result = pymath::math::perm_bigint(n_big, ki);
+        Ok(result.into())
+    }
+
+    #[pyfunction]
+    fn comb(n: ArgIndex, k: ArgIndex, vm: &VirtualMachine) -> PyResult<BigInt> {
+        let n_int = n.into_int_ref();
+        let n_big = n_int.as_bigint();
+        let k_int = k.into_int_ref();
+        let k_big = k_int.as_bigint();
+
+        if n_big.is_negative() {
+            return Err(vm.new_value_error("n must be a non-negative integer"));
+        }
+        if k_big.is_negative() {
+            return Err(vm.new_value_error("k must be a non-negative integer"));
+        }
+
+        // Fast path: n fits in i64
+        if let Some(ni) = n_big.to_i64()
+            && ni >= 0
+        {
+            // k overflow or k > n means result is 0
+            let ki = match k_big.to_i64() {
+                Some(k) if k >= 0 && k <= ni => k,
+                _ => return Ok(BigInt::from(0u8)),
+            };
+            // Apply symmetry: use min(k, n-k)
+            let ki = ki.min(ni - ki);
+            if ki > 1 {
+                let result = pymath::math::integer::comb(ni, ki)
+                    .map_err(|_| vm.new_value_error("comb() error"))?;
+                return Ok(result.into());
+            }
+            // ki <= 1 cases
+            if ki == 0 {
+                return Ok(BigInt::from(1u8));
+            }
+            return Ok(n_big.clone()); // ki == 1
+        }
+
+        // BigInt path: n doesn't fit in i64
+        // Apply symmetry: k = min(k, n - k)
+        let n_minus_k = n_big - k_big;
+        if n_minus_k.is_negative() {
+            return Ok(BigInt::from(0u8));
+        }
+        let effective_k = if &n_minus_k < k_big {
+            &n_minus_k
+        } else {
+            k_big
+        };
+
+        // k must fit in u64
+        let ki: u64 = match effective_k.to_u64() {
+            Some(k) => k,
+            None => {
+                return Err(
+                    vm.new_overflow_error(format!("min(n - k, k) must not exceed {}", u64::MAX))
+                );
+            }
+        };
+
+        let result = pymath::math::comb_bigint(n_big, ki);
+        Ok(result.into())
+    }
 }

-fn pymath_error_to_exception(err: pymath::Error, vm: &VirtualMachine) -> PyBaseExceptionRef {
+pub(crate) fn pymath_exception(err: pymath::Error, vm: &VirtualMachine) -> PyBaseExceptionRef {
     match err {
         pymath::Error::EDOM => vm.new_value_error("math domain error"),
         pymath::Error::ERANGE => vm.new_overflow_error("math range error"),
     }
 }
+
+/// Format a float in Python style (ensures trailing .0 for integers).
+fn float_repr(value: f64) -> String {
+    if value.is_nan() {
+        "nan".to_owned()
+    } else if value.is_infinite() {
+        if value.is_sign_positive() {
+            "inf".to_owned()
+        } else {
+            "-inf".to_owned()
+        }
+    } else {
+        let s = format!("{}", value);
+        // If no decimal point and not in scientific notation, add .0
+        if !s.contains('.') && !s.contains('e') && !s.contains('E') {
+            format!("{}.0", s)
+        } else {
+            s
+        }
+    }
+}
diff --git a/crates/vm/src/builtins/float.rs b/crates/vm/src/builtins/float.rs
index f101a3aa8e3..9941aea9ba2 100644
--- a/crates/vm/src/builtins/float.rs
+++ b/crates/vm/src/builtins/float.rs
@@ -116,7 +116,7 @@ fn inner_divmod(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult<(f64, f64)> {

 pub fn float_pow(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult {
     if v1.is_zero() && v2.is_sign_negative() {
-        let msg = "0.0 cannot be raised to a negative power";
+        let msg = "zero to a negative power";
         Err(vm.new_zero_division_error(msg.to_owned()))
     } else if v1.is_sign_negative() && (v2.floor() - v2).abs() > f64::EPSILON {
         let v1 = Complex64::new(v1, 0.);