From 65bdfc3d4edfa6a0e3e86edb78dc8f7cca172311 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Tue, 16 Dec 2025 23:11:35 +0900 Subject: [PATCH 001/418] Integrate OSError creations into OSErrorBuilder (#6443) --- crates/vm/src/exceptions.rs | 114 ++++++++++++++++++++++++++++++++++ crates/vm/src/ospath.rs | 67 +++----------------- crates/vm/src/stdlib/io.rs | 46 +++++++------- crates/vm/src/stdlib/os.rs | 50 ++++++++------- crates/vm/src/stdlib/posix.rs | 23 +++---- crates/vm/src/vm/vm_new.rs | 24 +------ 6 files changed, 189 insertions(+), 135 deletions(-) diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 04dd78fb448..2b725085bec 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -1193,6 +1193,8 @@ pub(crate) fn errno_to_exc_type(_errno: i32, _vm: &VirtualMachine) -> Option<&'s None } +pub(crate) use types::{OSErrorBuilder, ToOSErrorBuilder}; + pub(super) mod types { use crate::common::lock::PyRwLock; use crate::object::{MaybeTraverse, Traverse, TraverseFn}; @@ -1204,6 +1206,7 @@ pub(super) mod types { PyInt, PyStrRef, PyTupleRef, PyType, PyTypeRef, traceback::PyTracebackRef, tuple::IntoPyTuple, }, + convert::ToPyObject, convert::ToPyResult, function::{ArgBytesLike, FuncArgs}, types::{Constructor, Initializer}, @@ -1212,6 +1215,117 @@ pub(super) mod types { use itertools::Itertools; use rustpython_common::str::UnicodeEscapeCodepoint; + pub(crate) trait ToOSErrorBuilder { + fn to_os_error_builder(&self, vm: &VirtualMachine) -> OSErrorBuilder; + } + + pub struct OSErrorBuilder { + exc_type: PyTypeRef, + errno: Option, + strerror: Option, + filename: Option, + #[cfg(windows)] + winerror: Option, + filename2: Option, + } + + impl OSErrorBuilder { + #[must_use] + pub fn with_subtype( + exc_type: PyTypeRef, + errno: Option, + strerror: impl ToPyObject, + vm: &VirtualMachine, + ) -> Self { + let strerror = strerror.to_pyobject(vm); + Self { + exc_type, + errno, + strerror: 
Some(strerror), + filename: None, + #[cfg(windows)] + winerror: None, + filename2: None, + } + } + #[must_use] + pub fn with_errno(errno: i32, strerror: impl ToPyObject, vm: &VirtualMachine) -> Self { + let exc_type = crate::exceptions::errno_to_exc_type(errno, vm) + .unwrap_or(vm.ctx.exceptions.os_error) + .to_owned(); + Self::with_subtype(exc_type, Some(errno), strerror, vm) + } + + // #[must_use] + // pub(crate) fn errno(mut self, errno: i32) -> Self { + // self.errno.replace(errno); + // self + // } + + #[must_use] + pub(crate) fn filename(mut self, filename: PyObjectRef) -> Self { + self.filename.replace(filename); + self + } + + #[must_use] + pub(crate) fn filename2(mut self, filename: PyObjectRef) -> Self { + self.filename2.replace(filename); + self + } + + #[must_use] + #[cfg(windows)] + pub(crate) fn winerror(mut self, winerror: PyObjectRef) -> Self { + self.winerror.replace(winerror); + self + } + + pub fn build(self, vm: &VirtualMachine) -> PyRef { + let OSErrorBuilder { + exc_type, + errno, + strerror, + filename, + #[cfg(windows)] + winerror, + filename2, + } = self; + + let args = if let Some(errno) = errno { + #[cfg(windows)] + let winerror = winerror.to_pyobject(vm); + #[cfg(not(windows))] + let winerror = vm.ctx.none(); + + vec![ + errno.to_pyobject(vm), + strerror.to_pyobject(vm), + filename.to_pyobject(vm), + winerror, + filename2.to_pyobject(vm), + ] + } else { + vec![strerror.to_pyobject(vm)] + }; + + let payload = PyOSError::py_new(&exc_type, args.clone().into(), vm) + .expect("new_os_error usage error"); + let os_error = payload + .into_ref_with_type(vm, exc_type) + .expect("new_os_error usage error"); + PyOSError::slot_init(os_error.as_object().to_owned(), args.into(), vm) + .expect("new_os_error usage error"); + os_error + } + } + + impl crate::convert::IntoPyException for OSErrorBuilder { + fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { + self.build(vm).upcast() + } + } + // Re-export exception group types from 
dedicated module pub use crate::exception_group::types::PyBaseExceptionGroup; diff --git a/crates/vm/src/ospath.rs b/crates/vm/src/ospath.rs index add40f9b20c..9fca53d869c 100644 --- a/crates/vm/src/ospath.rs +++ b/crates/vm/src/ospath.rs @@ -2,10 +2,8 @@ use rustpython_common::crt_fd; use crate::{ PyObjectRef, PyResult, VirtualMachine, - builtins::PyBaseExceptionRef, convert::{IntoPyException, ToPyException, ToPyObject, TryFromObject}, function::FsPath, - object::AsObject, }; use std::path::{Path, PathBuf}; @@ -144,62 +142,17 @@ impl OsPathOrFd<'_> { } } -// TODO: preserve the input `PyObjectRef` of filename and filename2 (Failing check `self.assertIs(err.filename, name, str(func)`) -pub struct IOErrorBuilder<'a> { - error: &'a std::io::Error, - filename: Option>, - filename2: Option>, -} - -impl<'a> IOErrorBuilder<'a> { - pub const fn new(error: &'a std::io::Error) -> Self { - Self { - error, - filename: None, - filename2: None, - } - } - - pub(crate) fn filename(mut self, filename: impl Into>) -> Self { - let filename = filename.into(); - self.filename.replace(filename); - self - } - - pub(crate) fn filename2(mut self, filename: impl Into>) -> Self { - let filename = filename.into(); - self.filename2.replace(filename); - self - } - - pub(crate) fn with_filename( - error: &'a std::io::Error, +impl crate::exceptions::OSErrorBuilder { + #[must_use] + pub(crate) fn with_filename<'a>( + error: &std::io::Error, filename: impl Into>, vm: &VirtualMachine, - ) -> PyBaseExceptionRef { - let zelf = IOErrorBuilder { - error, - filename: Some(filename.into()), - filename2: None, - }; - zelf.to_pyexception(vm) - } -} - -impl ToPyException for IOErrorBuilder<'_> { - fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { - let exc = self.error.to_pyexception(vm); - - if let Some(filename) = &self.filename { - exc.as_object() - .set_attr("filename", filename.filename(vm), vm) - .unwrap(); - } - if let Some(filename2) = &self.filename2 { - exc.as_object() - 
.set_attr("filename2", filename2.filename(vm), vm) - .unwrap(); - } - exc + ) -> crate::builtins::PyBaseExceptionRef { + // TODO: return type to PyRef + use crate::exceptions::ToOSErrorBuilder; + let builder = error.to_os_error_builder(vm); + let builder = builder.filename(filename.into().filename(vm)); + builder.build(vm).upcast() } } diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index 3d67591d567..56bcbefddeb 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -14,11 +14,12 @@ use crate::{ builtins::{PyBaseExceptionRef, PyModule}, common::os::ErrorExt, convert::{IntoPyException, ToPyException}, + exceptions::{OSErrorBuilder, ToOSErrorBuilder}, }; pub use _io::{OpenArgs, io_open as open}; -impl ToPyException for std::io::Error { - fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { +impl ToOSErrorBuilder for std::io::Error { + fn to_os_error_builder(&self, vm: &VirtualMachine) -> OSErrorBuilder { let errno = self.posix_errno(); #[cfg(windows)] let msg = 'msg: { @@ -53,23 +54,23 @@ impl ToPyException for std::io::Error { #[cfg(not(any(windows, unix)))] let msg = self.to_string(); - #[allow(clippy::let_and_return)] - let exc = vm.new_errno_error(errno, msg); - #[cfg(windows)] - { - use crate::object::AsObject; - let winerror = if let Some(winerror) = self.raw_os_error() { - vm.new_pyobj(winerror) - } else { - vm.ctx.none() - }; + #[allow(unused_mut)] + let mut builder = OSErrorBuilder::with_errno(errno, msg, vm); - // FIXME: manual setup winerror due to lack of OSError.__init__ support - exc.as_object() - .set_attr("winerror", vm.new_pyobj(winerror), vm) - .unwrap(); + #[cfg(windows)] + if let Some(winerror) = self.raw_os_error() { + use crate::convert::ToPyObject; + builder = builder.winerror(winerror.to_pyobject(vm)); } - exc.upcast() + + builder + } +} + +impl ToPyException for std::io::Error { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + let builder = 
self.to_os_error_builder(vm); + builder.into_pyexception(vm) } } @@ -4328,8 +4329,9 @@ mod fileio { builtins::{PyBaseExceptionRef, PyUtf8Str, PyUtf8StrRef}, common::crt_fd, convert::{IntoPyException, ToPyException}, + exceptions::OSErrorBuilder, function::{ArgBytesLike, ArgMemoryBuffer, OptionalArg, OptionalOption}, - ospath::{IOErrorBuilder, OsPath, OsPathOrFd}, + ospath::{OsPath, OsPathOrFd}, stdlib::os, types::{Constructor, DefaultConstructor, Destructor, Initializer, Representable}, }; @@ -4526,7 +4528,7 @@ mod fileio { let filename = OsPathOrFd::Path(path); match fd { Ok(fd) => (fd.into_raw(), Some(filename)), - Err(e) => return Err(IOErrorBuilder::with_filename(&e, filename, vm)), + Err(e) => return Err(OSErrorBuilder::with_filename(&e, filename, vm)), } } }; @@ -4541,7 +4543,7 @@ mod fileio { #[cfg(windows)] { if let Err(err) = fd_fstat { - return Err(IOErrorBuilder::with_filename(&err, filename, vm)); + return Err(OSErrorBuilder::with_filename(&err, filename, vm)); } } #[cfg(any(unix, target_os = "wasi"))] @@ -4550,12 +4552,12 @@ mod fileio { Ok(status) => { if (status.st_mode & libc::S_IFMT) == libc::S_IFDIR { let err = std::io::Error::from_raw_os_error(libc::EISDIR); - return Err(IOErrorBuilder::with_filename(&err, filename, vm)); + return Err(OSErrorBuilder::with_filename(&err, filename, vm)); } } Err(err) => { if err.raw_os_error() == Some(libc::EBADF) { - return Err(IOErrorBuilder::with_filename(&err, filename, vm)); + return Err(OSErrorBuilder::with_filename(&err, filename, vm)); } } } diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index f6ffd66759d..a698cae059c 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -153,6 +153,7 @@ pub(super) mod _os { AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, builtins::{ PyBytesRef, PyGenericAlias, PyIntRef, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, + ToOSErrorBuilder, }, common::{ crt_fd, @@ -161,8 +162,9 @@ pub(super) mod _os { suppress_iph, }, 
convert::{IntoPyException, ToPyObject}, + exceptions::OSErrorBuilder, function::{ArgBytesLike, FsPath, FuncArgs, OptionalArg}, - ospath::{IOErrorBuilder, OsPath, OsPathOrFd, OutputMode}, + ospath::{OsPath, OsPathOrFd, OutputMode}, protocol::PyIterReturn, recursion::ReprGuard, types::{IterNext, Iterable, PyStructSequence, Representable, SelfIter, Unconstructible}, @@ -263,7 +265,7 @@ pub(super) mod _os { crt_fd::open(&name, flags, mode) } }; - fd.map_err(|err| IOErrorBuilder::with_filename(&err, name, vm)) + fd.map_err(|err| OSErrorBuilder::with_filename(&err, name, vm)) } #[pyfunction] @@ -316,7 +318,7 @@ pub(super) mod _os { } else { fs::remove_file(&path) }; - res.map_err(|err| IOErrorBuilder::with_filename(&err, path, vm)) + res.map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } #[cfg(not(windows))] @@ -334,7 +336,7 @@ pub(super) mod _os { let res = unsafe { libc::mkdirat(fd, c_path.as_ptr(), mode as _) }; return if res < 0 { let err = crate::common::os::errno_io_error(); - Err(IOErrorBuilder::with_filename(&err, path, vm)) + Err(OSErrorBuilder::with_filename(&err, path, vm)) } else { Ok(()) }; @@ -344,7 +346,7 @@ pub(super) mod _os { let res = unsafe { libc::mkdir(c_path.as_ptr(), mode as _) }; if res < 0 { let err = crate::common::os::errno_io_error(); - return Err(IOErrorBuilder::with_filename(&err, path, vm)); + return Err(OSErrorBuilder::with_filename(&err, path, vm)); } Ok(()) } @@ -357,7 +359,7 @@ pub(super) mod _os { #[pyfunction] fn rmdir(path: OsPath, dir_fd: DirFd<'_, 0>, vm: &VirtualMachine) -> PyResult<()> { let [] = dir_fd.0; - fs::remove_dir(&path).map_err(|err| IOErrorBuilder::with_filename(&err, path, vm)) + fs::remove_dir(&path).map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } const LISTDIR_FD: bool = cfg!(all(unix, not(target_os = "redox"))); @@ -373,13 +375,13 @@ pub(super) mod _os { let dir_iter = match fs::read_dir(&path) { Ok(iter) => iter, Err(err) => { - return Err(IOErrorBuilder::with_filename(&err, path, vm)); 
+ return Err(OSErrorBuilder::with_filename(&err, path, vm)); } }; dir_iter .map(|entry| match entry { Ok(entry_path) => Ok(path.mode.process_path(entry_path.file_name(), vm)), - Err(err) => Err(IOErrorBuilder::with_filename(&err, path.clone(), vm)), + Err(err) => Err(OSErrorBuilder::with_filename(&err, path.clone(), vm)), }) .collect::>()? } @@ -546,7 +548,7 @@ pub(super) mod _os { let mode = path.mode; let [] = dir_fd.0; let path = - fs::read_link(&path).map_err(|err| IOErrorBuilder::with_filename(&err, path, vm))?; + fs::read_link(&path).map_err(|err| OSErrorBuilder::with_filename(&err, path, vm))?; Ok(mode.process_path(path, vm)) } @@ -859,7 +861,7 @@ pub(super) mod _os { fn scandir(path: OptionalArg, vm: &VirtualMachine) -> PyResult { let path = path.unwrap_or_else(|| OsPath::new_str(".")); let entries = fs::read_dir(&path.path) - .map_err(|err| IOErrorBuilder::with_filename(&err, path.clone(), vm))?; + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; Ok(ScandirIterator { entries: PyRwLock::new(Some(entries)), mode: path.mode, @@ -1084,7 +1086,7 @@ pub(super) mod _os { vm: &VirtualMachine, ) -> PyResult { let stat = stat_inner(file.clone(), dir_fd, follow_symlinks) - .map_err(|err| IOErrorBuilder::with_filename(&err, file, vm))? + .map_err(|err| OSErrorBuilder::with_filename(&err, file, vm))? 
.ok_or_else(|| crate::exceptions::cstring_error(vm))?; Ok(StatResultData::from_stat(&stat, vm).to_pyobject(vm)) } @@ -1115,7 +1117,7 @@ pub(super) mod _os { #[pyfunction] fn chdir(path: OsPath, vm: &VirtualMachine) -> PyResult<()> { env::set_current_dir(&path.path) - .map_err(|err| IOErrorBuilder::with_filename(&err, path, vm)) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } #[pyfunction] @@ -1127,10 +1129,10 @@ pub(super) mod _os { #[pyfunction(name = "replace")] fn rename(src: OsPath, dst: OsPath, vm: &VirtualMachine) -> PyResult<()> { fs::rename(&src.path, &dst.path).map_err(|err| { - IOErrorBuilder::new(&err) - .filename(src) - .filename2(dst) - .into_pyexception(vm) + let builder = err.to_os_error_builder(vm); + let builder = builder.filename(src.filename(vm)); + let builder = builder.filename2(dst.filename(vm)); + builder.build(vm).upcast() }) } @@ -1219,10 +1221,10 @@ pub(super) mod _os { #[pyfunction] fn link(src: OsPath, dst: OsPath, vm: &VirtualMachine) -> PyResult<()> { fs::hard_link(&src.path, &dst.path).map_err(|err| { - IOErrorBuilder::new(&err) - .filename(src) - .filename2(dst) - .into_pyexception(vm) + let builder = err.to_os_error_builder(vm); + let builder = builder.filename(src.filename(vm)); + let builder = builder.filename2(dst.filename(vm)); + builder.build(vm).upcast() }) } @@ -1334,7 +1336,7 @@ pub(super) mod _os { ) }; if ret < 0 { - Err(IOErrorBuilder::with_filename( + Err(OSErrorBuilder::with_filename( &io::Error::last_os_error(), path_for_err, vm, @@ -1385,14 +1387,14 @@ pub(super) mod _os { .write(true) .custom_flags(windows_sys::Win32::Storage::FileSystem::FILE_FLAG_BACKUP_SEMANTICS) .open(&path) - .map_err(|err| IOErrorBuilder::with_filename(&err, path.clone(), vm))?; + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; let ret = unsafe { FileSystem::SetFileTime(f.as_raw_handle() as _, std::ptr::null(), &acc, &modif) }; if ret == 0 { - Err(IOErrorBuilder::with_filename( + 
Err(OSErrorBuilder::with_filename( &io::Error::last_os_error(), path, vm, @@ -1565,7 +1567,7 @@ pub(super) mod _os { error: std::io::Error, path: OsPath, ) -> crate::builtins::PyBaseExceptionRef { - IOErrorBuilder::with_filename(&error, path, vm) + OSErrorBuilder::with_filename(&error, path, vm) } let path = OsPath::try_from_object(vm, path)?; diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index 680e9914a03..cfe605733d3 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -26,8 +26,9 @@ pub mod module { AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyDictRef, PyInt, PyListRef, PyStrRef, PyTupleRef, PyType, PyUtf8StrRef}, convert::{IntoPyException, ToPyObject, TryFromObject}, + exceptions::OSErrorBuilder, function::{Either, KwArgs, OptionalArg}, - ospath::{IOErrorBuilder, OsPath, OsPathOrFd}, + ospath::{OsPath, OsPathOrFd}, stdlib::os::{_os, DirFd, FollowSymlinks, SupportFunc, TargetIsDirectory, fs_metadata}, types::{Constructor, Representable}, utils::ToCString, @@ -412,7 +413,7 @@ pub mod module { } let metadata = - metadata.map_err(|err| IOErrorBuilder::with_filename(&err, path.clone(), vm))?; + metadata.map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; let user_id = metadata.uid(); let group_id = metadata.gid(); @@ -482,12 +483,12 @@ pub mod module { #[cfg(not(target_os = "redox"))] #[pyfunction] fn chroot(path: OsPath, vm: &VirtualMachine) -> PyResult<()> { - use crate::ospath::IOErrorBuilder; + use crate::exceptions::OSErrorBuilder; nix::unistd::chroot(&*path.path).map_err(|err| { // Use `From for io::Error` when it is available - let err = io::Error::from_raw_os_error(err as i32); - IOErrorBuilder::with_filename(&err, path, vm) + let io_err: io::Error = err.into(); + OSErrorBuilder::with_filename(&io_err, path, vm) }) } @@ -533,7 +534,7 @@ pub mod module { .map_err(|err| { // Use `From for io::Error` when it is available let err = 
io::Error::from_raw_os_error(err as i32); - IOErrorBuilder::with_filename(&err, path, vm) + OSErrorBuilder::with_filename(&err, path, vm) }) } @@ -1031,7 +1032,7 @@ pub mod module { permissions.set_mode(mode); fs::set_permissions(&path, permissions) }; - body().map_err(|err| IOErrorBuilder::with_filename(&err, err_path, vm)) + body().map_err(|err| OSErrorBuilder::with_filename(&err, err_path, vm)) } #[cfg(not(target_os = "redox"))] @@ -1093,7 +1094,7 @@ pub mod module { Ok(()) } else { let err = std::io::Error::last_os_error(); - Err(IOErrorBuilder::with_filename(&err, path, vm)) + Err(OSErrorBuilder::with_filename(&err, path, vm)) } } @@ -1554,7 +1555,7 @@ pub mod module { }; if let Err(err) = ret { let err = err.into(); - return Err(IOErrorBuilder::with_filename(&err, self.path, vm)); + return Err(OSErrorBuilder::with_filename(&err, self.path, vm)); } } } @@ -1653,7 +1654,7 @@ pub mod module { nix::spawn::posix_spawn(&*path, &file_actions, &attrp, &args, &env) }; ret.map(Into::into) - .map_err(|err| IOErrorBuilder::with_filename(&err.into(), self.path, vm)) + .map_err(|err| OSErrorBuilder::with_filename(&err.into(), self.path, vm)) } } @@ -2126,7 +2127,7 @@ pub mod module { if Errno::last_raw() == 0 { Ok(None) } else { - Err(IOErrorBuilder::with_filename( + Err(OSErrorBuilder::with_filename( &io::Error::from(Errno::last()), path, vm, diff --git a/crates/vm/src/vm/vm_new.rs b/crates/vm/src/vm/vm_new.rs index 36481e5dbe3..6d0e983c844 100644 --- a/crates/vm/src/vm/vm_new.rs +++ b/crates/vm/src/vm/vm_new.rs @@ -1,5 +1,5 @@ use crate::{ - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, + AsObject, Py, PyObject, PyObjectRef, PyRef, builtins::{ PyBaseException, PyBaseExceptionRef, PyBytesRef, PyDictRef, PyModule, PyOSError, PyStrRef, PyType, PyTypeRef, @@ -8,9 +8,9 @@ use crate::{ tuple::{IntoPyTuple, PyTupleRef}, }, convert::{ToPyException, ToPyObject}, + exceptions::OSErrorBuilder, function::{IntoPyNativeFn, PyMethodFlags}, scope::Scope, - types::Constructor, 
vm::VirtualMachine, }; use rustpython_compiler_core::SourceLocation; @@ -119,26 +119,8 @@ impl VirtualMachine { msg: impl ToPyObject, ) -> PyRef { debug_assert_eq!(exc_type.slots.basicsize, std::mem::size_of::()); - let msg = msg.to_pyobject(self); - - fn new_os_subtype_error_impl( - vm: &VirtualMachine, - exc_type: PyTypeRef, - errno: Option, - msg: PyObjectRef, - ) -> PyRef { - let args = match errno { - Some(e) => vec![vm.new_pyobj(e), msg], - None => vec![msg], - }; - let payload = - PyOSError::py_new(&exc_type, args.into(), vm).expect("new_os_error usage error"); - payload - .into_ref_with_type(vm, exc_type) - .expect("new_os_error usage error") - } - new_os_subtype_error_impl(self, exc_type, errno, msg) + OSErrorBuilder::with_subtype(exc_type, errno, msg, self).build(self) } /// Instantiate an exception with no arguments. From aef4de4ab883a8a442e956da5ce0cccf66466a8f Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Tue, 16 Dec 2025 23:27:21 +0900 Subject: [PATCH 002/418] fix buffer (#6447) --- .cspell.dict/cpython.txt | 3 +++ crates/vm/src/buffer.rs | 19 ++++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 26921e04080..8acb1468f66 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -15,6 +15,7 @@ cellvar cellvars cmpop denom +DICTFLAG dictoffset elts excepthandler @@ -28,6 +29,7 @@ heaptype HIGHRES IMMUTABLETYPE Itertool +keeped kwonlyarg kwonlyargs lasti @@ -48,6 +50,7 @@ PYTHREAD_NAME SA_ONSTACK SOABI stackdepth +stginfo stringlib structseq subparams diff --git a/crates/vm/src/buffer.rs b/crates/vm/src/buffer.rs index 13cebfc6a28..cf49d6815c0 100644 --- a/crates/vm/src/buffer.rs +++ b/crates/vm/src/buffer.rs @@ -238,10 +238,23 @@ impl FormatCode { _ => 1, }; + // Skip whitespace (Python ignores whitespace in format strings) + while let Some(b' ' | b'\t' | b'\n' | b'\r') = chars.peek() { + chars.next(); + 
} + // determine format char: - let c = chars - .next() - .ok_or_else(|| "repeat count given without format specifier".to_owned())?; + let c = match chars.next() { + Some(c) => c, + None => { + // If we have a repeat count but only whitespace follows, error + if repeat != 1 { + return Err("repeat count given without format specifier".to_owned()); + } + // Otherwise, we're done parsing + break; + } + }; // Check for embedded null character if c == 0 { From 246fab63f7cc89fa1a0ba423b7ada2bbce6f31f1 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 17 Dec 2025 00:12:03 +0900 Subject: [PATCH 003/418] flag: DISALLOW_INSTANTIATION (#6445) * Set tp_new slot when build heap/static type * Improve type tp_call impl to check tp_new existence and error if not exist * Set DISALLOW_INSTANTIATION flag on several types according to cpython impl * Allow #[pyslot] for function pointer * Fix DISALLOW_INSTANTIATION --------- Signed-off-by: snowapril Co-authored-by: snowapril --- Lib/test/test_descr.py | 1 - crates/derive-impl/src/pyclass.rs | 18 ++---- crates/stdlib/src/array.rs | 2 +- crates/stdlib/src/csv.rs | 4 +- crates/stdlib/src/pystruct.rs | 6 +- crates/stdlib/src/sqlite.rs | 10 +-- crates/stdlib/src/unicodedata.rs | 2 +- crates/stdlib/src/zlib.rs | 4 +- crates/vm/src/builtins/asyncgenerator.rs | 6 +- crates/vm/src/builtins/builtin_func.rs | 12 ++-- crates/vm/src/builtins/bytearray.rs | 6 +- crates/vm/src/builtins/bytes.rs | 5 +- crates/vm/src/builtins/coroutine.rs | 6 +- crates/vm/src/builtins/descriptor.rs | 15 ++--- crates/vm/src/builtins/dict.rs | 62 +++++++++--------- crates/vm/src/builtins/frame.rs | 6 +- crates/vm/src/builtins/generator.rs | 6 +- crates/vm/src/builtins/getset.rs | 5 +- crates/vm/src/builtins/list.rs | 8 +-- crates/vm/src/builtins/memory.rs | 5 +- crates/vm/src/builtins/range.rs | 8 +-- crates/vm/src/builtins/set.rs | 5 +- crates/vm/src/builtins/str.rs | 6 +- crates/vm/src/builtins/tuple.rs | 5 +- 
crates/vm/src/builtins/type.rs | 81 +++++++++++++++++++++--- crates/vm/src/class.rs | 24 +++++-- crates/vm/src/protocol/buffer.rs | 5 +- crates/vm/src/stdlib/ctypes/field.rs | 6 +- crates/vm/src/stdlib/os.rs | 9 +-- crates/vm/src/types/slot.rs | 9 --- 30 files changed, 184 insertions(+), 163 deletions(-) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index f592a88fc2c..8c711207fae 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1788,7 +1788,6 @@ class D(C): self.assertEqual(b.foo, 3) self.assertEqual(b.__class__, D) - @unittest.expectedFailure def test_bad_new(self): self.assertRaises(TypeError, object.__new__) self.assertRaises(TypeError, object.__new__, '') diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index 5060dced2b0..f784a2e2a76 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -954,7 +954,10 @@ where } else if let Ok(f) = args.item.function_or_method() { (&f.sig().ident, f.span()) } else { - return Err(self.new_syn_error(args.item.span(), "can only be on a method")); + return Err(self.new_syn_error( + args.item.span(), + "can only be on a method or const function pointer", + )); }; let item_attr = args.attrs.remove(self.index()); @@ -1496,7 +1499,9 @@ impl SlotItemMeta { } } else { let ident_str = self.inner().item_name(); - let name = if let Some(stripped) = ident_str.strip_prefix("slot_") { + // Convert to lowercase to handle both SLOT_NEW and slot_new + let ident_lower = ident_str.to_lowercase(); + let name = if let Some(stripped) = ident_lower.strip_prefix("slot_") { proc_macro2::Ident::new(stripped, inner.item_ident.span()) } else { inner.item_ident.clone() @@ -1609,7 +1614,6 @@ fn extract_impl_attrs(attr: PunctuatedNestedMeta, item: &Ident) -> Result { @@ -1634,9 +1638,6 @@ fn extract_impl_attrs(attr: PunctuatedNestedMeta, item: &Ident) -> Result::__extend_py_class), quote!(::__OWN_METHOD_DEFS), @@ -1689,11 +1690,6 @@ fn 
extract_impl_attrs(attr: PunctuatedNestedMeta, item: &Ident) -> Result bail_span!(attr, "Unknown pyimpl attribute"), } } - // TODO: DISALLOW_INSTANTIATION check is required - let _ = has_constructor; - // if !withs.is_empty() && !has_constructor { - // bail_span!(item, "#[pyclass(with(...))] does not have a Constructor. Either #[pyclass(with(Constructor, ...))] or #[pyclass(with(Unconstructible, ...))] is mandatory. Consider to add `impl DefaultConstructor for T {{}}` or `impl Unconstructible for T {{}}`.") - // } Ok(ExtractedImplAttrs { payload, diff --git a/crates/stdlib/src/array.rs b/crates/stdlib/src/array.rs index 4fcba1f8725..49ca4a89037 100644 --- a/crates/stdlib/src/array.rs +++ b/crates/stdlib/src/array.rs @@ -1399,7 +1399,7 @@ mod array { internal: PyMutex>, } - #[pyclass(with(IterNext, Iterable), flags(HAS_DICT))] + #[pyclass(with(IterNext, Iterable), flags(HAS_DICT, DISALLOW_INSTANTIATION))] impl PyArrayIter { #[pymethod] fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { diff --git a/crates/stdlib/src/csv.rs b/crates/stdlib/src/csv.rs index 3c7cc2ff807..792402e0580 100644 --- a/crates/stdlib/src/csv.rs +++ b/crates/stdlib/src/csv.rs @@ -908,7 +908,7 @@ mod _csv { } } - #[pyclass(with(IterNext, Iterable))] + #[pyclass(with(IterNext, Iterable), flags(DISALLOW_INSTANTIATION))] impl Reader { #[pygetset] fn line_num(&self) -> u64 { @@ -1059,7 +1059,7 @@ mod _csv { } } - #[pyclass] + #[pyclass(flags(DISALLOW_INSTANTIATION))] impl Writer { #[pygetset(name = "dialect")] const fn get_dialect(&self, _vm: &VirtualMachine) -> PyDialect { diff --git a/crates/stdlib/src/pystruct.rs b/crates/stdlib/src/pystruct.rs index 798e5f5de80..0a006f5a0f2 100644 --- a/crates/stdlib/src/pystruct.rs +++ b/crates/stdlib/src/pystruct.rs @@ -16,7 +16,7 @@ pub(crate) mod _struct { function::{ArgBytesLike, ArgMemoryBuffer, PosArgs}, match_class, protocol::PyIterReturn, - types::{Constructor, IterNext, Iterable, Representable, SelfIter, Unconstructible}, 
+ types::{Constructor, IterNext, Iterable, Representable, SelfIter}, }; use crossbeam_utils::atomic::AtomicCell; @@ -189,7 +189,7 @@ pub(crate) mod _struct { } } - #[pyclass(with(Unconstructible, IterNext, Iterable))] + #[pyclass(with(IterNext, Iterable), flags(DISALLOW_INSTANTIATION))] impl UnpackIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -197,7 +197,7 @@ pub(crate) mod _struct { } } impl SelfIter for UnpackIterator {} - impl Unconstructible for UnpackIterator {} + impl IterNext for UnpackIterator { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let size = zelf.format_spec.size; diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index deff3c3a66a..bc84cffbf80 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -75,7 +75,7 @@ mod _sqlite { sliceable::{SaturatedSliceIter, SliceableSequenceOp}, types::{ AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, Hashable, - Initializer, IterNext, Iterable, PyComparisonOp, SelfIter, Unconstructible, + Initializer, IterNext, Iterable, PyComparisonOp, SelfIter, }, utils::ToCString, }; @@ -2197,8 +2197,6 @@ mod _sqlite { inner: PyMutex>, } - impl Unconstructible for Blob {} - #[derive(Debug)] struct BlobInner { blob: SqliteBlob, @@ -2211,7 +2209,7 @@ mod _sqlite { } } - #[pyclass(with(AsMapping, Unconstructible, AsNumber, AsSequence))] + #[pyclass(flags(DISALLOW_INSTANTIATION), with(AsMapping, AsNumber, AsSequence))] impl Blob { #[pymethod] fn close(&self) { @@ -2592,9 +2590,7 @@ mod _sqlite { } } - impl Unconstructible for Statement {} - - #[pyclass(with(Unconstructible))] + #[pyclass(flags(DISALLOW_INSTANTIATION))] impl Statement { fn new( connection: &Connection, diff --git a/crates/stdlib/src/unicodedata.rs b/crates/stdlib/src/unicodedata.rs index 46e18357260..68d9a17e575 100644 --- a/crates/stdlib/src/unicodedata.rs +++ b/crates/stdlib/src/unicodedata.rs @@ -105,7 +105,7 @@ mod unicodedata { } } - #[pyclass] + 
#[pyclass(flags(DISALLOW_INSTANTIATION))] impl Ucd { #[pymethod] fn category(&self, character: PyStrRef, vm: &VirtualMachine) -> PyResult { diff --git a/crates/stdlib/src/zlib.rs b/crates/stdlib/src/zlib.rs index 328452ae9d5..9ca94939f78 100644 --- a/crates/stdlib/src/zlib.rs +++ b/crates/stdlib/src/zlib.rs @@ -225,7 +225,7 @@ mod zlib { inner: PyMutex, } - #[pyclass] + #[pyclass(flags(DISALLOW_INSTANTIATION))] impl PyDecompress { #[pygetset] fn eof(&self) -> bool { @@ -383,7 +383,7 @@ mod zlib { inner: PyMutex>, } - #[pyclass] + #[pyclass(flags(DISALLOW_INSTANTIATION))] impl PyCompress { #[pymethod] fn compress(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult> { diff --git a/crates/vm/src/builtins/asyncgenerator.rs b/crates/vm/src/builtins/asyncgenerator.rs index 073513184ff..483ba6a7f96 100644 --- a/crates/vm/src/builtins/asyncgenerator.rs +++ b/crates/vm/src/builtins/asyncgenerator.rs @@ -8,7 +8,7 @@ use crate::{ frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, - types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, + types::{IterNext, Iterable, Representable, SelfIter}, }; use crossbeam_utils::atomic::AtomicCell; @@ -32,7 +32,7 @@ impl PyPayload for PyAsyncGen { } } -#[pyclass(with(PyRef, Unconstructible, Representable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(PyRef, Representable))] impl PyAsyncGen { pub const fn as_coro(&self) -> &Coro { &self.inner @@ -201,8 +201,6 @@ impl Representable for PyAsyncGen { } } -impl Unconstructible for PyAsyncGen {} - #[pyclass(module = false, name = "async_generator_wrapped_value")] #[derive(Debug)] pub(crate) struct PyAsyncGenWrappedValue(pub PyObjectRef); diff --git a/crates/vm/src/builtins/builtin_func.rs b/crates/vm/src/builtins/builtin_func.rs index d1ce107e374..d25188affd2 100644 --- a/crates/vm/src/builtins/builtin_func.rs +++ b/crates/vm/src/builtins/builtin_func.rs @@ -5,7 +5,7 @@ use crate::{ common::wtf8::Wtf8, convert::TryFromObject, function::{FuncArgs, 
PyComparisonValue, PyMethodDef, PyMethodFlags, PyNativeFn}, - types::{Callable, Comparable, PyComparisonOp, Representable, Unconstructible}, + types::{Callable, Comparable, PyComparisonOp, Representable}, }; use std::fmt; @@ -74,7 +74,7 @@ impl Callable for PyNativeFunction { } } -#[pyclass(with(Callable, Unconstructible), flags(HAS_DICT))] +#[pyclass(with(Callable), flags(HAS_DICT, DISALLOW_INSTANTIATION))] impl PyNativeFunction { #[pygetset] fn __module__(zelf: NativeFunctionOrMethod) -> Option<&'static PyStrInterned> { @@ -145,8 +145,6 @@ impl Representable for PyNativeFunction { } } -impl Unconstructible for PyNativeFunction {} - // `PyCMethodObject` in CPython #[pyclass(name = "builtin_method", module = false, base = PyNativeFunction, ctx = "builtin_method_type")] pub struct PyNativeMethod { @@ -155,8 +153,8 @@ pub struct PyNativeMethod { } #[pyclass( - with(Unconstructible, Callable, Comparable, Representable), - flags(HAS_DICT) + with(Callable, Comparable, Representable), + flags(HAS_DICT, DISALLOW_INSTANTIATION) )] impl PyNativeMethod { #[pygetset] @@ -246,8 +244,6 @@ impl Representable for PyNativeMethod { } } -impl Unconstructible for PyNativeMethod {} - pub fn init(context: &Context) { PyNativeFunction::extend_class(context, context.types.builtin_function_or_method_type); PyNativeMethod::extend_class(context, context.types.builtin_method_type); diff --git a/crates/vm/src/builtins/bytearray.rs b/crates/vm/src/builtins/bytearray.rs index 32eaa2b3e27..c5861befb73 100644 --- a/crates/vm/src/builtins/bytearray.rs +++ b/crates/vm/src/builtins/bytearray.rs @@ -33,7 +33,7 @@ use crate::{ types::{ AsBuffer, AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, DefaultConstructor, Initializer, IterNext, Iterable, PyComparisonOp, Representable, - SelfIter, Unconstructible, + SelfIter, }, }; use bstr::ByteSlice; @@ -865,7 +865,7 @@ impl PyPayload for PyByteArrayIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] 
+#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyByteArrayIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -886,8 +886,6 @@ impl PyByteArrayIterator { } } -impl Unconstructible for PyByteArrayIterator {} - impl SelfIter for PyByteArrayIterator {} impl IterNext for PyByteArrayIterator { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/builtins/bytes.rs b/crates/vm/src/builtins/bytes.rs index 70a33401271..f782c035f86 100644 --- a/crates/vm/src/builtins/bytes.rs +++ b/crates/vm/src/builtins/bytes.rs @@ -25,7 +25,7 @@ use crate::{ sliceable::{SequenceIndex, SliceableSequenceOp}, types::{ AsBuffer, AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, Hashable, - IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, + IterNext, Iterable, PyComparisonOp, Representable, SelfIter, }, }; use bstr::ByteSlice; @@ -749,7 +749,7 @@ impl PyPayload for PyBytesIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyBytesIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -770,7 +770,6 @@ impl PyBytesIterator { .set_state(state, |obj, pos| pos.min(obj.len()), vm) } } -impl Unconstructible for PyBytesIterator {} impl SelfIter for PyBytesIterator {} impl IterNext for PyBytesIterator { diff --git a/crates/vm/src/builtins/coroutine.rs b/crates/vm/src/builtins/coroutine.rs index 0909cdfb444..21405448693 100644 --- a/crates/vm/src/builtins/coroutine.rs +++ b/crates/vm/src/builtins/coroutine.rs @@ -6,7 +6,7 @@ use crate::{ frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, - types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, + types::{IterNext, Iterable, Representable, SelfIter}, }; use crossbeam_utils::atomic::AtomicCell; @@ -24,7 +24,7 @@ impl PyPayload for PyCoroutine { } } -#[pyclass(with(Py, Unconstructible, IterNext, Representable))] 
+#[pyclass(flags(DISALLOW_INSTANTIATION), with(Py, IterNext, Representable))] impl PyCoroutine { pub const fn as_coro(&self) -> &Coro { &self.inner @@ -123,8 +123,6 @@ impl Py { } } -impl Unconstructible for PyCoroutine {} - impl Representable for PyCoroutine { #[inline] fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index bc3aded3253..cdcc456edfc 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -4,7 +4,7 @@ use crate::{ builtins::{PyTypeRef, builtin_func::PyNativeMethod, type_}, class::PyClassImpl, function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, - types::{Callable, GetDescriptor, Representable, Unconstructible}, + types::{Callable, GetDescriptor, Representable}, }; use rustpython_common::lock::PyRwLock; @@ -105,8 +105,8 @@ impl PyMethodDescriptor { } #[pyclass( - with(GetDescriptor, Callable, Unconstructible, Representable), - flags(METHOD_DESCRIPTOR) + with(GetDescriptor, Callable, Representable), + flags(METHOD_DESCRIPTOR, DISALLOW_INSTANTIATION) )] impl PyMethodDescriptor { #[pygetset] @@ -159,8 +159,6 @@ impl Representable for PyMethodDescriptor { } } -impl Unconstructible for PyMethodDescriptor {} - #[derive(Debug)] pub enum MemberKind { Bool = 14, @@ -246,7 +244,10 @@ fn calculate_qualname(descr: &PyDescriptorOwned, vm: &VirtualMachine) -> PyResul } } -#[pyclass(with(GetDescriptor, Unconstructible, Representable), flags(BASETYPE))] +#[pyclass( + with(GetDescriptor, Representable), + flags(BASETYPE, DISALLOW_INSTANTIATION) +)] impl PyMemberDescriptor { #[pygetset] fn __doc__(&self) -> Option { @@ -339,8 +340,6 @@ fn set_slot_at_object( Ok(()) } -impl Unconstructible for PyMemberDescriptor {} - impl Representable for PyMemberDescriptor { #[inline] fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index 
77126d4ee62..b34299d4170 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -19,7 +19,7 @@ use crate::{ recursion::ReprGuard, types::{ AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, DefaultConstructor, - Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, + Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, }, vm::VirtualMachine, }; @@ -848,7 +848,7 @@ macro_rules! dict_view { } } - #[pyclass(with(Unconstructible, IterNext, Iterable))] + #[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { @@ -878,8 +878,6 @@ macro_rules! dict_view { } } - impl Unconstructible for $iter_name {} - impl SelfIter for $iter_name {} impl IterNext for $iter_name { #[allow(clippy::redundant_closure_call)] @@ -923,7 +921,7 @@ macro_rules! dict_view { } } - #[pyclass(with(Unconstructible, IterNext, Iterable))] + #[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl $reverse_iter_name { fn new(dict: PyDictRef) -> Self { let size = dict.size(); @@ -957,8 +955,6 @@ macro_rules! 
dict_view { .rev_length_hint(|_| self.size.entries_size) } } - impl Unconstructible for $reverse_iter_name {} - impl SelfIter for $reverse_iter_name {} impl IterNext for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] @@ -1126,16 +1122,18 @@ trait ViewSetOps: DictView { } impl ViewSetOps for PyDictKeys {} -#[pyclass(with( - DictView, - Unconstructible, - Comparable, - Iterable, - ViewSetOps, - AsSequence, - AsNumber, - Representable -))] +#[pyclass( + flags(DISALLOW_INSTANTIATION), + with( + DictView, + Comparable, + Iterable, + ViewSetOps, + AsSequence, + AsNumber, + Representable + ) +)] impl PyDictKeys { #[pymethod] fn __contains__(zelf: PyObjectRef, key: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -1147,7 +1145,6 @@ impl PyDictKeys { PyMappingProxy::from(zelf.dict().clone()) } } -impl Unconstructible for PyDictKeys {} impl Comparable for PyDictKeys { fn cmp( @@ -1190,16 +1187,18 @@ impl AsNumber for PyDictKeys { } impl ViewSetOps for PyDictItems {} -#[pyclass(with( - DictView, - Unconstructible, - Comparable, - Iterable, - ViewSetOps, - AsSequence, - AsNumber, - Representable -))] +#[pyclass( + flags(DISALLOW_INSTANTIATION), + with( + DictView, + Comparable, + Iterable, + ViewSetOps, + AsSequence, + AsNumber, + Representable + ) +)] impl PyDictItems { #[pymethod] fn __contains__(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -1210,7 +1209,6 @@ impl PyDictItems { PyMappingProxy::from(zelf.dict().clone()) } } -impl Unconstructible for PyDictItems {} impl Comparable for PyDictItems { fn cmp( @@ -1264,14 +1262,16 @@ impl AsNumber for PyDictItems { } } -#[pyclass(with(DictView, Unconstructible, Iterable, AsSequence, Representable))] +#[pyclass( + flags(DISALLOW_INSTANTIATION), + with(DictView, Iterable, AsSequence, Representable) +)] impl PyDictValues { #[pygetset] fn mapping(zelf: PyRef) -> PyMappingProxy { PyMappingProxy::from(zelf.dict().clone()) } } -impl Unconstructible for PyDictValues {} impl AsSequence for 
PyDictValues { fn as_sequence() -> &'static PySequenceMethods { diff --git a/crates/vm/src/builtins/frame.rs b/crates/vm/src/builtins/frame.rs index 17dc88ac042..6ccc594d338 100644 --- a/crates/vm/src/builtins/frame.rs +++ b/crates/vm/src/builtins/frame.rs @@ -8,7 +8,7 @@ use crate::{ class::PyClassImpl, frame::{Frame, FrameRef}, function::PySetterValue, - types::{Representable, Unconstructible}, + types::Representable, }; use num_traits::Zero; @@ -16,8 +16,6 @@ pub fn init(context: &Context) { Frame::extend_class(context, context.types.frame_type); } -impl Unconstructible for Frame {} - impl Representable for Frame { #[inline] fn repr(_zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -31,7 +29,7 @@ impl Representable for Frame { } } -#[pyclass(with(Unconstructible, Py))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(Py))] impl Frame { #[pymethod] const fn clear(&self) { diff --git a/crates/vm/src/builtins/generator.rs b/crates/vm/src/builtins/generator.rs index da981b5a6c2..04cd7dd3456 100644 --- a/crates/vm/src/builtins/generator.rs +++ b/crates/vm/src/builtins/generator.rs @@ -10,7 +10,7 @@ use crate::{ frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, - types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, + types::{IterNext, Iterable, Representable, SelfIter}, }; #[pyclass(module = false, name = "generator")] @@ -26,7 +26,7 @@ impl PyPayload for PyGenerator { } } -#[pyclass(with(Py, Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(Py, IterNext, Iterable))] impl PyGenerator { pub const fn as_coro(&self) -> &Coro { &self.inner @@ -114,8 +114,6 @@ impl Py { } } -impl Unconstructible for PyGenerator {} - impl Representable for PyGenerator { #[inline] fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/builtins/getset.rs b/crates/vm/src/builtins/getset.rs index 4b966bbc31b..f56191f5f8b 100644 --- a/crates/vm/src/builtins/getset.rs +++ 
b/crates/vm/src/builtins/getset.rs @@ -7,7 +7,7 @@ use crate::{ builtins::type_::PointerSlot, class::PyClassImpl, function::{IntoPyGetterFunc, IntoPySetterFunc, PyGetterFunc, PySetterFunc, PySetterValue}, - types::{GetDescriptor, Unconstructible}, + types::GetDescriptor, }; #[pyclass(module = false, name = "getset_descriptor")] @@ -96,7 +96,7 @@ impl PyGetSet { } } -#[pyclass(with(GetDescriptor, Unconstructible))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(GetDescriptor))] impl PyGetSet { // Descriptor methods @@ -152,7 +152,6 @@ impl PyGetSet { Ok(unsafe { zelf.class.borrow_static() }.to_owned().into()) } } -impl Unconstructible for PyGetSet {} pub(crate) fn init(context: &Context) { PyGetSet::extend_class(context, context.types.getset_type); diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index b2927462cac..0683381f38a 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -15,7 +15,7 @@ use crate::{ sliceable::{SequenceIndex, SliceableSequenceMutOp, SliceableSequenceOp}, types::{ AsMapping, AsSequence, Comparable, Constructor, Initializer, IterNext, Iterable, - PyComparisonOp, Representable, SelfIter, Unconstructible, + PyComparisonOp, Representable, SelfIter, }, utils::collection_repr, vm::VirtualMachine, @@ -544,7 +544,7 @@ impl PyPayload for PyListIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyListIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -565,7 +565,6 @@ impl PyListIterator { .builtins_iter_reduce(|x| x.clone().into(), vm) } } -impl Unconstructible for PyListIterator {} impl SelfIter for PyListIterator {} impl IterNext for PyListIterator { @@ -590,7 +589,7 @@ impl PyPayload for PyListReverseIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyListReverseIterator { #[pymethod] fn 
__length_hint__(&self) -> usize { @@ -611,7 +610,6 @@ impl PyListReverseIterator { .builtins_reversed_reduce(|x| x.clone().into(), vm) } } -impl Unconstructible for PyListReverseIterator {} impl SelfIter for PyListReverseIterator {} impl IterNext for PyListReverseIterator { diff --git a/crates/vm/src/builtins/memory.rs b/crates/vm/src/builtins/memory.rs index c1b1496e8c6..4e895f92b7e 100644 --- a/crates/vm/src/builtins/memory.rs +++ b/crates/vm/src/builtins/memory.rs @@ -23,7 +23,7 @@ use crate::{ sliceable::SequenceIndexOp, types::{ AsBuffer, AsMapping, AsSequence, Comparable, Constructor, Hashable, IterNext, Iterable, - PyComparisonOp, Representable, SelfIter, Unconstructible, + PyComparisonOp, Representable, SelfIter, }, }; use crossbeam_utils::atomic::AtomicCell; @@ -1132,7 +1132,7 @@ impl PyPayload for PyMemoryViewIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyMemoryViewIterator { #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef { @@ -1141,7 +1141,6 @@ impl PyMemoryViewIterator { .builtins_iter_reduce(|x| x.clone().into(), vm) } } -impl Unconstructible for PyMemoryViewIterator {} impl SelfIter for PyMemoryViewIterator {} impl IterNext for PyMemoryViewIterator { diff --git a/crates/vm/src/builtins/range.rs b/crates/vm/src/builtins/range.rs index 3edd130ee28..9f79f8efb2d 100644 --- a/crates/vm/src/builtins/range.rs +++ b/crates/vm/src/builtins/range.rs @@ -11,7 +11,7 @@ use crate::{ protocol::{PyIterReturn, PyMappingMethods, PySequenceMethods}, types::{ AsMapping, AsSequence, Comparable, Hashable, IterNext, Iterable, PyComparisonOp, - Representable, SelfIter, Unconstructible, + Representable, SelfIter, }, }; use crossbeam_utils::atomic::AtomicCell; @@ -548,7 +548,7 @@ impl PyPayload for PyLongRangeIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl 
PyLongRangeIterator { #[pymethod] fn __length_hint__(&self) -> BigInt { @@ -577,7 +577,6 @@ impl PyLongRangeIterator { ) } } -impl Unconstructible for PyLongRangeIterator {} impl SelfIter for PyLongRangeIterator {} impl IterNext for PyLongRangeIterator { @@ -614,7 +613,7 @@ impl PyPayload for PyRangeIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyRangeIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -640,7 +639,6 @@ impl PyRangeIterator { ) } } -impl Unconstructible for PyRangeIterator {} impl SelfIter for PyRangeIterator {} impl IterNext for PyRangeIterator { diff --git a/crates/vm/src/builtins/set.rs b/crates/vm/src/builtins/set.rs index 7fde8d32781..5582ff3323c 100644 --- a/crates/vm/src/builtins/set.rs +++ b/crates/vm/src/builtins/set.rs @@ -18,7 +18,7 @@ use crate::{ types::AsNumber, types::{ AsSequence, Comparable, Constructor, DefaultConstructor, Hashable, Initializer, IterNext, - Iterable, PyComparisonOp, Representable, SelfIter, Unconstructible, + Iterable, PyComparisonOp, Representable, SelfIter, }, utils::collection_repr, vm::VirtualMachine, @@ -1304,7 +1304,7 @@ impl PyPayload for PySetIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PySetIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -1330,7 +1330,6 @@ impl PySetIterator { )) } } -impl Unconstructible for PySetIterator {} impl SelfIter for PySetIterator {} impl IterNext for PySetIterator { diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 9b05e195722..279b84362a6 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -21,7 +21,7 @@ use crate::{ sliceable::{SequenceIndex, SliceableSequenceOp}, types::{ AsMapping, AsNumber, AsSequence, Comparable, Constructor, Hashable, IterNext, Iterable, - PyComparisonOp, Representable, 
SelfIter, Unconstructible, + PyComparisonOp, Representable, SelfIter, }, }; use ascii::{AsciiChar, AsciiStr, AsciiString}; @@ -282,7 +282,7 @@ impl PyPayload for PyStrIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyStrIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -307,8 +307,6 @@ impl PyStrIterator { } } -impl Unconstructible for PyStrIterator {} - impl SelfIter for PyStrIterator {} impl IterNext for PyStrIterator { diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index fada8840bb1..adc5b483de3 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -16,7 +16,7 @@ use crate::{ sliceable::{SequenceIndex, SliceableSequenceOp}, types::{ AsMapping, AsSequence, Comparable, Constructor, Hashable, IterNext, Iterable, - PyComparisonOp, Representable, SelfIter, Unconstructible, + PyComparisonOp, Representable, SelfIter, }, utils::collection_repr, vm::VirtualMachine, @@ -533,7 +533,7 @@ impl PyPayload for PyTupleIterator { } } -#[pyclass(with(Unconstructible, IterNext, Iterable))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl PyTupleIterator { #[pymethod] fn __length_hint__(&self) -> usize { @@ -554,7 +554,6 @@ impl PyTupleIterator { .builtins_iter_reduce(|x| x.clone().into(), vm) } } -impl Unconstructible for PyTupleIterator {} impl SelfIter for PyTupleIterator {} impl IterNext for PyTupleIterator { diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 0f619b1399a..d014cdf015e 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -191,13 +191,16 @@ impl PyType { name: &str, bases: Vec>, attrs: PyAttributes, - slots: PyTypeSlots, + mut slots: PyTypeSlots, metaclass: PyRef, ctx: &Context, ) -> Result, String> { // TODO: ensure clean slot name // assert_eq!(slots.name.borrow(), ""); + // Set HEAPTYPE flag for 
heap-allocated types + slots.flags |= PyTypeFlags::HEAPTYPE; + let name = ctx.new_str(name); let heaptype_ext = HeapTypeExt { name: PyRwLock::new(name.clone()), @@ -401,6 +404,8 @@ impl PyType { None, ); + Self::set_new(&new_type.slots, &new_type.base); + let weakref_type = super::PyWeak::static_type(); for base in new_type.bases.read().iter() { base.subclasses.write().push( @@ -420,9 +425,6 @@ impl PyType { for cls in self.mro.read().iter() { for &name in cls.attributes.read().keys() { - if name == identifier!(ctx, __new__) { - continue; - } if name.as_bytes().starts_with(b"__") && name.as_bytes().ends_with(b"__") { slot_name_set.insert(name); } @@ -436,6 +438,20 @@ impl PyType { for attr_name in slot_name_set { self.update_slot::(attr_name, ctx); } + + Self::set_new(&self.slots, &self.base); + } + + fn set_new(slots: &PyTypeSlots, base: &Option) { + if slots.flags.contains(PyTypeFlags::DISALLOW_INSTANTIATION) { + slots.new.store(None) + } else if slots.new.load().is_none() { + slots.new.store( + base.as_ref() + .map(|base| base.slots.new.load()) + .unwrap_or(None), + ) + } } // This is used for class initialization where the vm is not yet available. @@ -1563,15 +1579,28 @@ impl Callable for PyType { type Args = FuncArgs; fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { vm_trace!("type_call: {:?}", zelf); - let obj = call_slot_new(zelf.to_owned(), zelf.to_owned(), args.clone(), vm)?; - if (zelf.is(vm.ctx.types.type_type) && args.kwargs.is_empty()) || !obj.fast_isinstance(zelf) - { + if zelf.is(vm.ctx.types.type_type) { + let num_args = args.args.len(); + if num_args == 1 && args.kwargs.is_empty() { + return Ok(args.args[0].obj_type()); + } + if num_args != 3 { + return Err(vm.new_type_error("type() takes 1 or 3 arguments".to_owned())); + } + } + + let obj = if let Some(slot_new) = zelf.slots.new.load() { + slot_new(zelf.to_owned(), args.clone(), vm)? 
+ } else { + return Err(vm.new_type_error(format!("cannot create '{}' instances", zelf.slots.name))); + }; + + if !obj.class().fast_issubclass(zelf) { return Ok(obj); } - let init = obj.class().mro_find_map(|cls| cls.slots.init.load()); - if let Some(init_method) = init { + if let Some(init_method) = obj.class().slots.init.load() { init_method(obj.clone(), args, vm)?; } Ok(obj) @@ -1700,6 +1729,40 @@ pub(crate) fn call_slot_new( args: FuncArgs, vm: &VirtualMachine, ) -> PyResult { + // Check DISALLOW_INSTANTIATION flag on subtype (the type being instantiated) + if subtype + .slots + .flags + .has_feature(PyTypeFlags::DISALLOW_INSTANTIATION) + { + return Err(vm.new_type_error(format!("cannot create '{}' instances", subtype.slot_name()))); + } + + // "is not safe" check (tp_new_wrapper logic) + // Check that the user doesn't do something silly and unsafe like + // object.__new__(dict). To do this, we check that the most derived base + // that's not a heap type is this type. + let mut staticbase = subtype.clone(); + while staticbase.slots.flags.has_feature(PyTypeFlags::HEAPTYPE) { + if let Some(base) = staticbase.base.as_ref() { + staticbase = base.clone(); + } else { + break; + } + } + + // Check if staticbase's tp_new differs from typ's tp_new + let typ_new = typ.slots.new.load(); + let staticbase_new = staticbase.slots.new.load(); + if typ_new.map(|f| f as usize) != staticbase_new.map(|f| f as usize) { + return Err(vm.new_type_error(format!( + "{}.__new__({}) is not safe, use {}.__new__()", + typ.slot_name(), + subtype.slot_name(), + staticbase.slot_name() + ))); + } + let slot_new = typ .deref() .mro_find_map(|cls| cls.slots.new.load()) diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 6b5a02dea73..6a366385702 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -116,13 +116,23 @@ pub trait PyClassImpl: PyClassDef { ); } - if class.slots.new.load().is_some() { - let bound_new = Context::genesis().slot_new_wrapper.build_bound_method( 
- ctx, - class.to_owned().into(), - class, - ); - class.set_attr(identifier!(ctx, __new__), bound_new.into()); + // Don't add __new__ attribute if slot_new is inherited from object + // (Python doesn't add __new__ to __dict__ for inherited slots) + // Exception: object itself should have __new__ in its dict + if let Some(slot_new) = class.slots.new.load() { + let object_new = ctx.types.object_type.slots.new.load(); + let is_object_itself = std::ptr::eq(class, ctx.types.object_type); + let is_inherited_from_object = !is_object_itself + && object_new.is_some_and(|obj_new| slot_new as usize == obj_new as usize); + + if !is_inherited_from_object { + let bound_new = Context::genesis().slot_new_wrapper.build_bound_method( + ctx, + class.to_owned().into(), + class, + ); + class.set_attr(identifier!(ctx, __new__), bound_new.into()); + } } if class.slots.hash.load().map_or(0, |h| h as usize) == hash_not_implemented as usize { diff --git a/crates/vm/src/protocol/buffer.rs b/crates/vm/src/protocol/buffer.rs index 1b1a4a14df5..1dafda203d9 100644 --- a/crates/vm/src/protocol/buffer.rs +++ b/crates/vm/src/protocol/buffer.rs @@ -9,7 +9,6 @@ use crate::{ }, object::PyObjectPayload, sliceable::SequenceIndexOp, - types::Unconstructible, }; use itertools::Itertools; use std::{borrow::Cow, fmt::Debug, ops::Range}; @@ -402,7 +401,7 @@ pub struct VecBuffer { data: PyMutex>, } -#[pyclass(flags(BASETYPE), with(Unconstructible))] +#[pyclass(flags(BASETYPE, DISALLOW_INSTANTIATION))] impl VecBuffer { pub fn take(&self) -> Vec { std::mem::take(&mut self.data.lock()) @@ -417,8 +416,6 @@ impl From> for VecBuffer { } } -impl Unconstructible for VecBuffer {} - impl PyRef { pub fn into_pybuffer(self, readonly: bool) -> PyBuffer { let len = self.data.lock().len(); diff --git a/crates/vm/src/stdlib/ctypes/field.rs b/crates/vm/src/stdlib/ctypes/field.rs index e760f07d035..659255f3329 100644 --- a/crates/vm/src/stdlib/ctypes/field.rs +++ b/crates/vm/src/stdlib/ctypes/field.rs @@ -1,6 +1,6 @@ use 
crate::builtins::PyType; use crate::function::PySetterValue; -use crate::types::{GetDescriptor, Representable, Unconstructible}; +use crate::types::{GetDescriptor, Representable}; use crate::{AsObject, Py, PyObjectRef, PyResult, VirtualMachine}; use num_traits::ToPrimitive; @@ -85,8 +85,6 @@ impl Representable for PyCField { } } -impl Unconstructible for PyCField {} - impl GetDescriptor for PyCField { fn descr_get( zelf: PyObjectRef, @@ -184,7 +182,7 @@ impl PyCField { #[pyclass( flags(DISALLOW_INSTANTIATION, IMMUTABLETYPE), - with(Unconstructible, Representable, GetDescriptor) + with(Representable, GetDescriptor) )] impl PyCField { #[pyslot] diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index a698cae059c..868fc727c65 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -167,7 +167,7 @@ pub(super) mod _os { ospath::{OsPath, OsPathOrFd, OutputMode}, protocol::PyIterReturn, recursion::ReprGuard, - types::{IterNext, Iterable, PyStructSequence, Representable, SelfIter, Unconstructible}, + types::{IterNext, Iterable, PyStructSequence, Representable, SelfIter}, utils::ToCString, vm::VirtualMachine, }; @@ -570,7 +570,7 @@ pub(super) mod _os { ino: AtomicCell>, } - #[pyclass(with(Representable, Unconstructible))] + #[pyclass(flags(DISALLOW_INSTANTIATION), with(Representable))] impl DirEntry { #[pygetset] fn name(&self, vm: &VirtualMachine) -> PyResult { @@ -760,8 +760,6 @@ pub(super) mod _os { } } } - impl Unconstructible for DirEntry {} - #[pyattr] #[pyclass(name = "ScandirIter")] #[derive(Debug, PyPayload)] @@ -770,7 +768,7 @@ pub(super) mod _os { mode: OutputMode, } - #[pyclass(with(IterNext, Iterable, Unconstructible))] + #[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))] impl ScandirIterator { #[pymethod] fn close(&self) { @@ -788,7 +786,6 @@ pub(super) mod _os { zelf.close() } } - impl Unconstructible for ScandirIterator {} impl SelfIter for ScandirIterator {} impl IterNext for ScandirIterator { fn 
next(zelf: &crate::Py, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index f52e7296a7b..5deb593818e 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -922,15 +922,6 @@ pub trait DefaultConstructor: PyPayload + Default + std::fmt::Debug { } } -/// For types that cannot be instantiated through Python code. -#[pyclass] -pub trait Unconstructible: PyPayload { - #[pyslot] - fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error(format!("cannot create '{}' instances", cls.slot_name()))) - } -} - impl Constructor for T where T: DefaultConstructor, From 9aa1f189982ff561478f30315c19c47b9c9d9cde Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 17 Dec 2025 00:16:51 +0900 Subject: [PATCH 004/418] Unraiseable traceback (#6449) --- crates/vm/src/stdlib/sys.rs | 50 ++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 45b1d566058..52ecf927d5c 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -15,7 +15,7 @@ mod sys { }, convert::ToPyObject, frame::FrameRef, - function::{FuncArgs, OptionalArg, PosArgs}, + function::{FuncArgs, KwArgs, OptionalArg, PosArgs}, stdlib::{builtins, warnings::warn}, types::PyStructSequence, version, @@ -688,13 +688,21 @@ mod sys { writeln!(stderr, "{}:", unraisable.err_msg.str(vm)?); } - // TODO: print received unraisable.exc_traceback - let tb_module = vm.import("traceback", 0)?; - let print_stack = tb_module.get_attr("print_stack", vm)?; - print_stack.call((), vm)?; + // Print traceback (using actual exc_traceback, not current stack) + if !vm.is_none(&unraisable.exc_traceback) { + let tb_module = vm.import("traceback", 0)?; + let print_tb = tb_module.get_attr("print_tb", vm)?; + let stderr_obj = super::get_stderr(vm)?; + let kwargs: 
KwArgs = [("file".to_string(), stderr_obj)].into_iter().collect(); + let _ = print_tb.call( + FuncArgs::new(vec![unraisable.exc_traceback.clone()], kwargs), + vm, + ); + } + // Check exc_type if vm.is_none(unraisable.exc_type.as_object()) { - // TODO: early return, but with what error? + return Ok(()); } assert!( unraisable @@ -702,10 +710,28 @@ mod sys { .fast_issubclass(vm.ctx.exceptions.base_exception_type) ); - // TODO: print module name and qualname + // Print module name (if not builtins or __main__) + let module_name = unraisable.exc_type.__module__(vm); + if let Ok(module_str) = module_name.downcast::() { + let module = module_str.as_str(); + if module != "builtins" && module != "__main__" { + write!(stderr, "{}.", module); + } + } else { + write!(stderr, "."); + } + // Print qualname + let qualname = unraisable.exc_type.__qualname__(vm); + if let Ok(qualname_str) = qualname.downcast::() { + write!(stderr, "{}", qualname_str.as_str()); + } else { + write!(stderr, "{}", unraisable.exc_type.name()); + } + + // Print exception value if !vm.is_none(&unraisable.exc_value) { - write!(stderr, "{}: ", unraisable.exc_type); + write!(stderr, ": "); if let Ok(str) = unraisable.exc_value.str(vm) { write!(stderr, "{}", str.to_str().unwrap_or("")); } else { @@ -713,7 +739,13 @@ mod sys { } } writeln!(stderr); - // TODO: call file.flush() + + // Flush stderr + if let Ok(stderr_obj) = super::get_stderr(vm) + && let Ok(flush) = stderr_obj.get_attr("flush", vm) + { + let _ = flush.call((), vm); + } Ok(()) } From adc2b0dbbe043095be17c037e3daa39b5351c96c Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:28:06 +0100 Subject: [PATCH 005/418] Update `test_zipfile64.py` from 3.13.11 (#6433) * Update `test_zipfile64.py` from 3.13.11 * Mark flaky test --- Lib/test/test_import/__init__.py | 1 + Lib/test/test_zipfile64.py | 13 +++++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git 
a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py index 44e7da1033d..85913cd7c58 100644 --- a/Lib/test/test_import/__init__.py +++ b/Lib/test/test_import/__init__.py @@ -457,6 +457,7 @@ def test_issue31492(self): with self.assertRaises(AttributeError): os.does_not_exist + @unittest.skipIf(sys.platform == 'win32', 'TODO: RUSTPYTHON; Flaky') @threading_helper.requires_working_threading() def test_concurrency(self): # bpo 38091: this is a hack to slow down the code that calls diff --git a/Lib/test/test_zipfile64.py b/Lib/test/test_zipfile64.py index 0947013afbc..2e1affe0252 100644 --- a/Lib/test/test_zipfile64.py +++ b/Lib/test/test_zipfile64.py @@ -11,7 +11,7 @@ 'test requires loads of disk-space bytes and a long time to run' ) -import zipfile, os, unittest +import zipfile, unittest import time import sys @@ -32,10 +32,6 @@ def setUp(self): line_gen = ("Test of zipfile line %d." % i for i in range(1000000)) self.data = '\n'.join(line_gen).encode('ascii') - # And write it to a file. - with open(TESTFN, "wb") as fp: - fp.write(self.data) - def zipTest(self, f, compression): # Create the ZIP archive. with zipfile.ZipFile(f, "w", compression) as zipfp: @@ -67,6 +63,9 @@ def zipTest(self, f, compression): (num, filecount)), file=sys.__stdout__) sys.__stdout__.flush() + # Check that testzip thinks the archive is valid + self.assertIsNone(zipfp.testzip()) + def testStored(self): # Try the temp file first. If we do TESTFN2 first, then it hogs # gigabytes of disk space for the duration of the test. 
@@ -85,9 +84,7 @@ def testDeflated(self): self.zipTest(TESTFN2, zipfile.ZIP_DEFLATED) def tearDown(self): - for fname in TESTFN, TESTFN2: - if os.path.exists(fname): - os.remove(fname) + os_helper.unlink(TESTFN2) class OtherTests(unittest.TestCase): From f3916950bf823f398b25cf554f0485787a2a84be Mon Sep 17 00:00:00 2001 From: ShaharNaveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Tue, 16 Dec 2025 21:35:55 +0100 Subject: [PATCH 006/418] Skip flakey test --- Lib/test/test_subprocess.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index 6b20a5c00a5..e04f8b8fcc4 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -1770,8 +1770,7 @@ def test_run_with_pathlike_path_and_arguments(self): res = subprocess.run(args) self.assertEqual(res.returncode, 57) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.skipIf(mswindows, 'TODO: RUSTPYTHON; Flakey') @unittest.skipUnless(mswindows, "Maybe test trigger a leak on Ubuntu") def test_run_with_an_empty_env(self): # gh-105436: fix subprocess.run(..., env={}) broken on Windows From be559ae4530f85c9101f166fd702646353f7f064 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Wed, 17 Dec 2025 00:15:25 +0100 Subject: [PATCH 007/418] Update `support/import_helper.py` from 3.13.11 (#6451) * Update `support/import_helper.py` from 3.13.11 * Unmark passing test --- Lib/test/support/import_helper.py | 43 ++++++++++++++++++++++++++++--- Lib/test/test_super.py | 1 - 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 67f18e530ed..2b91bdcf9cd 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -8,7 +8,7 @@ import unittest import warnings -from .os_helper import unlink +from .os_helper import unlink, temp_dir @contextlib.contextmanager @@ -58,8 +58,8 @@ def 
make_legacy_pyc(source): :return: The file system path to the legacy pyc file. """ pyc_file = importlib.util.cache_from_source(source) - up_one = os.path.dirname(os.path.abspath(source)) - legacy_pyc = os.path.join(up_one, source + 'c') + assert source.endswith('.py') + legacy_pyc = source + 'c' shutil.move(pyc_file, legacy_pyc) return legacy_pyc @@ -114,7 +114,7 @@ def multi_interp_extensions_check(enabled=True): This only applies to modules that haven't been imported yet. It overrides the PyInterpreterConfig.check_multi_interp_extensions setting (see support.run_in_subinterp_with_config() and - _xxsubinterpreters.create()). + _interpreters.create()). Also see importlib.utils.allowing_all_extensions(). """ @@ -268,9 +268,44 @@ def modules_cleanup(oldmodules): sys.modules.update(oldmodules) +@contextlib.contextmanager +def isolated_modules(): + """ + Save modules on entry and cleanup on exit. + """ + (saved,) = modules_setup() + try: + yield + finally: + modules_cleanup(saved) + + def mock_register_at_fork(func): # bpo-30599: Mock os.register_at_fork() when importing the random module, # since this function doesn't allow to unregister callbacks and would leak # memory. from unittest import mock return mock.patch('os.register_at_fork', create=True)(func) + + +@contextlib.contextmanager +def ready_to_import(name=None, source=""): + from test.support import script_helper + + # 1. Sets up a temporary directory and removes it afterwards + # 2. Creates the module file + # 3. Temporarily clears the module from sys.modules (if any) + # 4. 
Reverts or removes the module when cleaning up + name = name or "spam" + with temp_dir() as tempdir: + path = script_helper.make_script(tempdir, name, source) + old_module = sys.modules.pop(name, None) + try: + sys.path.insert(0, tempdir) + yield name, path + sys.path.remove(tempdir) + finally: + if old_module is not None: + sys.modules[name] = old_module + else: + sys.modules.pop(name, None) diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py index 8967dab8bdd..76eda799da4 100644 --- a/Lib/test/test_super.py +++ b/Lib/test/test_super.py @@ -349,7 +349,6 @@ def test_super_argtype(self): with self.assertRaisesRegex(TypeError, "argument 1 must be a type"): super(1, int) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'test.support.import_helper' has no attribute 'ready_to_import' def test_shadowed_global(self): source = textwrap.dedent( """ From 9d2dd1745565547d3fcaf0b792612a48192c02de Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Wed, 17 Dec 2025 11:02:44 +0100 Subject: [PATCH 008/418] Add regression test (#6452) --- extra_tests/snippets/builtin_bytes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/extra_tests/snippets/builtin_bytes.py b/extra_tests/snippets/builtin_bytes.py index 9347fbc8fab..4cb743baa6f 100644 --- a/extra_tests/snippets/builtin_bytes.py +++ b/extra_tests/snippets/builtin_bytes.py @@ -1,3 +1,5 @@ +import sys + from testutils import assert_raises, skip_if_unsupported # new @@ -611,6 +613,9 @@ assert b"\xc2\xae\x75\x73\x74".decode() == "®ust" assert b"\xe4\xb8\xad\xe6\x96\x87\xe5\xad\x97".decode("utf-8") == "中文字" +# gh-2391 +assert b"-\xff".decode(sys.getfilesystemencoding(), "surrogateescape") == "-\udcff" + # mod assert b"rust%bpython%b" % (b" ", b"!") == b"rust python!" 
assert b"x=%i y=%f" % (1, 2.5) == b"x=1 y=2.500000" From 0412dfdb3b4b64aeacf82822352ae6c93580688c Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:57:34 +0900 Subject: [PATCH 009/418] Fix winerror handling (#6454) --- crates/vm/src/exceptions.rs | 122 +++++++++++++++++++++++++----------- crates/vm/src/stdlib/io.rs | 11 ++-- 2 files changed, 94 insertions(+), 39 deletions(-) diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 2b725085bec..2958d1e497e 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -1677,7 +1677,7 @@ pub(super) mod types { // SAFETY: slot_init is called during object initialization, // so fields are None and swap result can be safely ignored if len <= 5 { - // Only set errno/strerror when args len is 2-5 (CPython behavior) + // Only set errno/strerror when args len is 2-5 if 2 <= len { let _ = unsafe { exc.errno.swap(Some(new_args.args[0].clone())) }; let _ = unsafe { exc.strerror.swap(Some(new_args.args[1].clone())) }; @@ -1708,8 +1708,10 @@ pub(super) mod types { } } - // args are truncated to 2 for compatibility (only when 2-5 args) - if (3..=5).contains(&len) { + // args are truncated to 2 for compatibility (only when 2-5 args and filename is not None) + // truncation happens inside "if (filename && filename != Py_None)" block + let has_filename = exc.filename.to_owned().filter(|f| !vm.is_none(f)).is_some(); + if (3..=5).contains(&len) && has_filename { new_args.args.truncate(2); } PyBaseException::slot_init(zelf, new_args, vm) @@ -1717,44 +1719,94 @@ pub(super) mod types { #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - let args = exc.args(); let obj = exc.as_object().to_owned(); - let str = if args.len() == 2 { - // SAFETY: len() == 2 is checked so get_arg 1 or 2 won't panic - let errno = exc.get_arg(0).unwrap().str(vm)?; - let msg = exc.get_arg(1).unwrap().str(vm)?; - - // On Windows, 
use [WinError X] format when winerror is set - #[cfg(windows)] - let (label, code) = match obj.get_attr("winerror", vm) { - Ok(winerror) if !vm.is_none(&winerror) => ("WinError", winerror.str(vm)?), - _ => ("Errno", errno.clone()), - }; - #[cfg(not(windows))] - let (label, code) = ("Errno", errno.clone()); + // Get OSError fields directly + let errno_field = obj.get_attr("errno", vm).ok().filter(|v| !vm.is_none(v)); + let strerror = obj.get_attr("strerror", vm).ok().filter(|v| !vm.is_none(v)); + let filename = obj.get_attr("filename", vm).ok().filter(|v| !vm.is_none(v)); + let filename2 = obj + .get_attr("filename2", vm) + .ok() + .filter(|v| !vm.is_none(v)); + #[cfg(windows)] + let winerror = obj.get_attr("winerror", vm).ok().filter(|v| !vm.is_none(v)); - let s = match obj.get_attr("filename", vm) { - Ok(filename) if !vm.is_none(&filename) => match obj.get_attr("filename2", vm) { - Ok(filename2) if !vm.is_none(&filename2) => format!( - "[{} {}] {}: '{}' -> '{}'", - label, + // Windows: winerror takes priority over errno + #[cfg(windows)] + if let Some(ref win_err) = winerror { + let code = win_err.str(vm)?; + if let Some(ref f) = filename { + let msg = strerror + .as_ref() + .map(|s| s.str(vm)) + .transpose()? + .map(|s| s.to_string()) + .unwrap_or_else(|| "None".to_owned()); + if let Some(ref f2) = filename2 { + return Ok(vm.ctx.new_str(format!( + "[WinError {}] {}: {} -> {}", code, msg, - filename.str(vm)?, - filename2.str(vm)? - ), - _ => format!("[{} {}] {}: '{}'", label, code, msg, filename.str(vm)?), - }, - _ => { - format!("[{label} {code}] {msg}") + f.repr(vm)?, + f2.repr(vm)? + ))); } - }; - vm.ctx.new_str(s) - } else { - exc.__str__(vm) - }; - Ok(str) + return Ok(vm.ctx.new_str(format!( + "[WinError {}] {}: {}", + code, + msg, + f.repr(vm)? 
+ ))); + } + // winerror && strerror (no filename) + if let Some(ref s) = strerror { + return Ok(vm + .ctx + .new_str(format!("[WinError {}] {}", code, s.str(vm)?))); + } + } + + // Non-Windows or fallback: use errno + if let Some(ref f) = filename { + let errno_str = errno_field + .as_ref() + .map(|e| e.str(vm)) + .transpose()? + .map(|s| s.to_string()) + .unwrap_or_else(|| "None".to_owned()); + let msg = strerror + .as_ref() + .map(|s| s.str(vm)) + .transpose()? + .map(|s| s.to_string()) + .unwrap_or_else(|| "None".to_owned()); + if let Some(ref f2) = filename2 { + return Ok(vm.ctx.new_str(format!( + "[Errno {}] {}: {} -> {}", + errno_str, + msg, + f.repr(vm)?, + f2.repr(vm)? + ))); + } + return Ok(vm.ctx.new_str(format!( + "[Errno {}] {}: {}", + errno_str, + msg, + f.repr(vm)? + ))); + } + + // errno && strerror (no filename) + if let (Some(e), Some(s)) = (&errno_field, &strerror) { + return Ok(vm + .ctx + .new_str(format!("[Errno {}] {}", e.str(vm)?, s.str(vm)?))); + } + + // fallback to BaseException.__str__ + Ok(exc.__str__(vm)) } #[pymethod] diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index 56bcbefddeb..547e482f13a 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -1084,10 +1084,13 @@ mod _io { self.write_end = buffer_size; // TODO: BlockingIOError(errno, msg, written) // written += self.buffer.len(); - return Err(vm.new_exception_msg( - vm.ctx.exceptions.blocking_io_error.to_owned(), - "write could not complete without blocking".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.blocking_io_error.to_owned(), + None, + "write could not complete without blocking".to_owned(), + ) + .upcast()); } else { break; } From a4c93dfbbfed868e3fe896463a39d401079a3160 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:58:00 +0900 Subject: [PATCH 010/418] Remove useless &PyRef patterns (#6455) * fix &PyTypeRef * Remove useless 
&PyObjectRef * Remove useless &PyRef --- crates/stdlib/src/csv.rs | 18 ++++---- crates/stdlib/src/openssl.rs | 5 ++- crates/stdlib/src/scproxy.rs | 6 +-- crates/stdlib/src/ssl.rs | 14 +++---- crates/stdlib/src/ssl/compat.rs | 8 ++-- crates/vm/src/builtins/builtin_func.rs | 4 +- crates/vm/src/builtins/dict.rs | 10 ++--- crates/vm/src/builtins/function/jit.rs | 6 ++- crates/vm/src/builtins/genericalias.rs | 6 +-- crates/vm/src/builtins/list.rs | 4 +- crates/vm/src/builtins/property.rs | 6 +-- crates/vm/src/builtins/type.rs | 4 +- crates/vm/src/builtins/union.rs | 2 +- crates/vm/src/cformat.rs | 16 +++++--- crates/vm/src/codecs.rs | 2 +- crates/vm/src/coroutine.rs | 5 ++- crates/vm/src/exception_group.rs | 14 +++---- crates/vm/src/exceptions.rs | 14 +++---- crates/vm/src/frame.rs | 4 +- crates/vm/src/import.rs | 7 ++-- crates/vm/src/object/payload.rs | 4 +- crates/vm/src/scope.rs | 2 +- crates/vm/src/sequence.rs | 4 +- crates/vm/src/stdlib/builtins.rs | 4 +- crates/vm/src/stdlib/collections.rs | 4 +- crates/vm/src/stdlib/ctypes.rs | 4 +- crates/vm/src/stdlib/ctypes/base.rs | 52 ++++++++++++------------ crates/vm/src/stdlib/ctypes/field.rs | 4 +- crates/vm/src/stdlib/ctypes/structure.rs | 6 +-- crates/vm/src/stdlib/ctypes/union.rs | 4 +- crates/vm/src/stdlib/typevar.rs | 8 ++-- crates/vm/src/suggestion.rs | 8 ++-- crates/vm/src/vm/context.rs | 6 +-- crates/vm/src/vm/mod.rs | 5 ++- crates/vm/src/warn.rs | 4 +- crates/wasm/src/convert.rs | 6 +-- crates/wasm/src/js_module.rs | 3 +- 37 files changed, 146 insertions(+), 137 deletions(-) diff --git a/crates/stdlib/src/csv.rs b/crates/stdlib/src/csv.rs index 792402e0580..a62594a9f1b 100644 --- a/crates/stdlib/src/csv.rs +++ b/crates/stdlib/src/csv.rs @@ -4,7 +4,8 @@ pub(crate) use _csv::make_module; mod _csv { use crate::common::lock::PyMutex; use crate::vm::{ - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, 
TryFromObject, + VirtualMachine, builtins::{PyBaseExceptionRef, PyInt, PyNone, PyStr, PyType, PyTypeRef}, function::{ArgIterable, ArgumentError, FromArgs, FuncArgs, OptionalArg}, protocol::{PyIter, PyIterReturn}, @@ -130,11 +131,11 @@ mod _csv { /// /// * If the 'delimiter' attribute is not a single-character string, a type error is returned. /// * If the 'obj' is not of string type and does not have a 'delimiter' attribute, a type error is returned. - fn parse_delimiter_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + fn parse_delimiter_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult { if let Ok(attr) = obj.get_attr("delimiter", vm) { parse_delimiter_from_obj(vm, &attr) } else { - match_class!(match obj.clone() { + match_class!(match obj.to_owned() { s @ PyStr => { Ok(s.as_str().bytes().exactly_one().map_err(|_| { let msg = r#""delimiter" must be a 1-character string"#; @@ -148,7 +149,7 @@ mod _csv { }) } } - fn parse_quotechar_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + fn parse_quotechar_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult> { match_class!(match obj.get_attr("quotechar", vm)? { s @ PyStr => { Ok(Some(s.as_str().bytes().exactly_one().map_err(|_| { @@ -169,7 +170,7 @@ mod _csv { } }) } - fn parse_escapechar_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + fn parse_escapechar_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult> { match_class!(match obj.get_attr("escapechar", vm)? { s @ PyStr => { Ok(Some(s.as_str().bytes().exactly_one().map_err(|_| { @@ -191,10 +192,7 @@ mod _csv { } }) } - fn prase_lineterminator_from_obj( - vm: &VirtualMachine, - obj: &PyObjectRef, - ) -> PyResult { + fn prase_lineterminator_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult { match_class!(match obj.get_attr("lineterminator", vm)? 
{ s @ PyStr => { Ok(if s.as_bytes().eq(b"\r\n") { @@ -217,7 +215,7 @@ mod _csv { } }) } - fn prase_quoting_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + fn prase_quoting_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult { match_class!(match obj.get_attr("quoting", vm)? { i @ PyInt => { Ok(i.try_to_primitive::(vm)?.try_into().map_err(|_| { diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index ea67d605f76..bf03813b180 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -56,7 +56,8 @@ mod _ssl { vm::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyTypeRef, PyWeak, + PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, + PyTypeRef, PyWeak, }, class_or_notimplemented, convert::ToPyException, @@ -3351,7 +3352,7 @@ mod _ssl { // Helper function to set verify_code and verify_message on SSLCertVerificationError fn set_verify_error_info( - exc: &PyBaseExceptionRef, + exc: &Py, ssl_ptr: *const sys::SSL, vm: &VirtualMachine, ) { diff --git a/crates/stdlib/src/scproxy.rs b/crates/stdlib/src/scproxy.rs index f49b6890a69..1974e7814ae 100644 --- a/crates/stdlib/src/scproxy.rs +++ b/crates/stdlib/src/scproxy.rs @@ -5,8 +5,8 @@ mod _scproxy { // straight-forward port of Modules/_scproxy.c use crate::vm::{ - PyResult, VirtualMachine, - builtins::{PyDictRef, PyStr}, + Py, PyResult, VirtualMachine, + builtins::{PyDict, PyDictRef, PyStr}, convert::ToPyObject, }; use system_configuration::core_foundation::{ @@ -74,7 +74,7 @@ mod _scproxy { let result = vm.ctx.new_dict(); - let set_proxy = |result: &PyDictRef, + let set_proxy = |result: &Py, proto: &str, enabled_key: CFStringRef, host_key: CFStringRef, diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index c23062d639d..e6f8c5fda7c 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -1423,7 
+1423,7 @@ mod _ssl { /// Helper: Get path from Python's os.environ fn get_env_path( - environ: &PyObjectRef, + environ: &PyObject, var_name: &str, vm: &VirtualMachine, ) -> PyResult { @@ -2101,10 +2101,10 @@ mod _ssl { // Helper functions (private): /// Parse path argument (str or bytes) to string - fn parse_path_arg(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult { - if let Ok(s) = PyStrRef::try_from_object(vm, arg.clone()) { + fn parse_path_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult { + if let Ok(s) = PyStrRef::try_from_object(vm, arg.to_owned()) { Ok(s.as_str().to_owned()) - } else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.clone()) { + } else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.to_owned()) { String::from_utf8(b.borrow_buf().to_vec()) .map_err(|_| vm.new_value_error("path contains invalid UTF-8".to_owned())) } else { @@ -2279,10 +2279,10 @@ mod _ssl { } /// Helper: Parse cadata argument (str or bytes) - fn parse_cadata_arg(&self, arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(s) = PyStrRef::try_from_object(vm, arg.clone()) { + fn parse_cadata_arg(&self, arg: &PyObject, vm: &VirtualMachine) -> PyResult> { + if let Ok(s) = PyStrRef::try_from_object(vm, arg.to_owned()) { Ok(s.as_str().as_bytes().to_vec()) - } else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.clone()) { + } else if let Ok(b) = ArgBytesLike::try_from_object(vm, arg.to_owned()) { Ok(b.borrow_buf().to_vec()) } else { Err(vm.new_type_error("cadata should be a str or bytes".to_owned())) diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 798542f210a..4ccc590360a 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -23,10 +23,10 @@ use rustls::server::ResolvesServerCert; use rustls::server::ServerConfig; use rustls::server::ServerConnection; use rustls::sign::CertifiedKey; -use rustpython_vm::builtins::PyBaseExceptionRef; +use rustpython_vm::builtins::{PyBaseException, 
PyBaseExceptionRef}; use rustpython_vm::convert::IntoPyException; use rustpython_vm::function::ArgBytesLike; -use rustpython_vm::{AsObject, PyObjectRef, PyPayload, PyResult, TryFromObject}; +use rustpython_vm::{AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject}; use std::io::Read; use std::sync::{Arc, Once}; @@ -984,7 +984,7 @@ pub(super) fn create_client_config(options: ClientConfigOptions) -> Result bool { +pub(super) fn is_blocking_io_error(err: &Py, vm: &VirtualMachine) -> bool { err.fast_isinstance(vm.ctx.exceptions.blocking_io_error) } @@ -1534,7 +1534,7 @@ fn ssl_read_tls_records( /// Check if an exception is a connection closed error /// In SSL context, these errors indicate unexpected connection termination without proper TLS shutdown -fn is_connection_closed_error(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> bool { +fn is_connection_closed_error(exc: &Py, vm: &VirtualMachine) -> bool { use rustpython_vm::stdlib::errno::errors; // Check for ConnectionAbortedError, ConnectionResetError (Python exception types) diff --git a/crates/vm/src/builtins/builtin_func.rs b/crates/vm/src/builtins/builtin_func.rs index d25188affd2..da5fd5e8075 100644 --- a/crates/vm/src/builtins/builtin_func.rs +++ b/crates/vm/src/builtins/builtin_func.rs @@ -51,11 +51,11 @@ impl PyNativeFunction { } // PyCFunction_GET_SELF - pub const fn get_self(&self) -> Option<&PyObjectRef> { + pub fn get_self(&self) -> Option<&PyObject> { if self.value.flags.contains(PyMethodFlags::STATIC) { return None; } - self.zelf.as_ref() + self.zelf.as_deref() } pub const fn as_func(&self) -> &'static dyn PyNativeFn { diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index b34299d4170..04915b035f8 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -753,7 +753,7 @@ impl ExactSizeIterator for DictIter<'_> { trait DictView: PyPayload + PyClassDef + Iterable + Representable { type ReverseIter: PyPayload + std::fmt::Debug; - fn 
dict(&self) -> &PyDictRef; + fn dict(&self) -> &Py; fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef; #[pymethod] @@ -785,7 +785,7 @@ macro_rules! dict_view { impl DictView for $name { type ReverseIter = $reverse_iter_name; - fn dict(&self) -> &PyDictRef { + fn dict(&self) -> &Py { &self.dict } @@ -1142,7 +1142,7 @@ impl PyDictKeys { #[pygetset] fn mapping(zelf: PyRef) -> PyMappingProxy { - PyMappingProxy::from(zelf.dict().clone()) + PyMappingProxy::from(zelf.dict().to_owned()) } } @@ -1206,7 +1206,7 @@ impl PyDictItems { } #[pygetset] fn mapping(zelf: PyRef) -> PyMappingProxy { - PyMappingProxy::from(zelf.dict().clone()) + PyMappingProxy::from(zelf.dict().to_owned()) } } @@ -1269,7 +1269,7 @@ impl AsNumber for PyDictItems { impl PyDictValues { #[pygetset] fn mapping(zelf: PyRef) -> PyMappingProxy { - PyMappingProxy::from(zelf.dict().clone()) + PyMappingProxy::from(zelf.dict().to_owned()) } } diff --git a/crates/vm/src/builtins/function/jit.rs b/crates/vm/src/builtins/function/jit.rs index 21d8c9c0abf..de0a528b734 100644 --- a/crates/vm/src/builtins/function/jit.rs +++ b/crates/vm/src/builtins/function/jit.rs @@ -1,6 +1,8 @@ use crate::{ AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, - builtins::{PyBaseExceptionRef, PyDictRef, PyFunction, PyStrInterned, bool_, float, int}, + builtins::{ + PyBaseExceptionRef, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int, + }, bytecode::CodeFlags, convert::ToPyObject, function::FuncArgs, @@ -42,7 +44,7 @@ pub fn new_jit_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef { vm.new_exception_msg(jit_error, msg) } -fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResult { +fn get_jit_arg_type(dict: &Py, name: &str, vm: &VirtualMachine) -> PyResult { if let Some(value) = dict.get_item_opt(name, vm)? 
{ if value.is(vm.ctx.types.int_type) { Ok(JitType::Int) diff --git a/crates/vm/src/builtins/genericalias.rs b/crates/vm/src/builtins/genericalias.rs index 494b580e563..8a7288980fa 100644 --- a/crates/vm/src/builtins/genericalias.rs +++ b/crates/vm/src/builtins/genericalias.rs @@ -294,11 +294,11 @@ pub(crate) fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyTupl } #[inline] -fn tuple_index(vec: &[PyObjectRef], item: &PyObjectRef) -> Option { +fn tuple_index(vec: &[PyObjectRef], item: &PyObject) -> Option { vec.iter().position(|element| element.is(item)) } -fn is_unpacked_typevartuple(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult { +fn is_unpacked_typevartuple(arg: &PyObject, vm: &VirtualMachine) -> PyResult { if arg.class().is(vm.ctx.types.type_type) { return Ok(false); } @@ -312,7 +312,7 @@ fn is_unpacked_typevartuple(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult< fn subs_tvars( obj: PyObjectRef, - params: &PyTupleRef, + params: &Py, arg_items: &[PyObjectRef], vm: &VirtualMachine, ) -> PyResult { diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 0683381f38a..13e8864cd1f 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -367,8 +367,8 @@ where impl MutObjectSequenceOp for PyList { type Inner = [PyObjectRef]; - fn do_get(index: usize, inner: &[PyObjectRef]) -> Option<&PyObjectRef> { - inner.get(index) + fn do_get(index: usize, inner: &[PyObjectRef]) -> Option<&PyObject> { + inner.get(index).map(|r| r.as_ref()) } fn do_lock(&self) -> impl std::ops::Deref { diff --git a/crates/vm/src/builtins/property.rs b/crates/vm/src/builtins/property.rs index 1a2e04ee8b0..41b05a60049 100644 --- a/crates/vm/src/builtins/property.rs +++ b/crates/vm/src/builtins/property.rs @@ -266,7 +266,7 @@ impl PyProperty { #[pygetset] fn __isabstractmethod__(&self, vm: &VirtualMachine) -> PyResult { // Helper to check if a method is abstract - let is_abstract = |method: &PyObjectRef| -> PyResult { + let 
is_abstract = |method: &PyObject| -> PyResult { match method.get_attr("__isabstractmethod__", vm) { Ok(isabstract) => isabstract.try_to_bool(vm), Err(_) => Ok(false), @@ -309,7 +309,7 @@ impl PyProperty { #[cold] fn format_property_error( &self, - obj: &PyObjectRef, + obj: &PyObject, error_type: &str, vm: &VirtualMachine, ) -> PyResult { @@ -356,7 +356,7 @@ impl Initializer for PyProperty { let mut getter_doc = false; // Helper to get doc from getter - let get_getter_doc = |fget: &PyObjectRef| -> Option { + let get_getter_doc = |fget: &PyObject| -> Option { fget.get_attr("__doc__", vm) .ok() .filter(|doc| !vm.is_none(doc)) diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index d014cdf015e..15743350397 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -175,12 +175,12 @@ fn is_subtype_with_mro(a_mro: &[PyTypeRef], a: &Py, b: &Py) -> b impl PyType { pub fn new_simple_heap( name: &str, - base: &PyTypeRef, + base: &Py, ctx: &Context, ) -> Result, String> { Self::new_heap( name, - vec![base.clone()], + vec![base.to_owned()], Default::default(), Default::default(), Self::static_type().to_owned(), diff --git a/crates/vm/src/builtins/union.rs b/crates/vm/src/builtins/union.rs index 16d2b7831cd..8310201a129 100644 --- a/crates/vm/src/builtins/union.rs +++ b/crates/vm/src/builtins/union.rs @@ -42,7 +42,7 @@ impl PyUnion { /// Direct access to args field, matching CPython's _Py_union_args #[inline] - pub const fn args(&self) -> &PyTupleRef { + pub fn args(&self) -> &Py { &self.args } diff --git a/crates/vm/src/cformat.rs b/crates/vm/src/cformat.rs index 507079e7deb..efb3cb2acc9 100644 --- a/crates/vm/src/cformat.rs +++ b/crates/vm/src/cformat.rs @@ -6,7 +6,8 @@ use crate::common::cformat::*; use crate::common::wtf8::{CodePoint, Wtf8, Wtf8Buf}; use crate::{ - AsObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject, VirtualMachine, + AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, 
TryFromObject, + VirtualMachine, builtins::{ PyBaseExceptionRef, PyByteArray, PyBytes, PyFloat, PyInt, PyStr, try_f64_to_bigint, tuple, }, @@ -207,7 +208,7 @@ fn spec_format_string( fn try_update_quantity_from_element( vm: &VirtualMachine, - element: Option<&PyObjectRef>, + element: Option<&PyObject>, ) -> PyResult { match element { Some(width_obj) => { @@ -224,7 +225,7 @@ fn try_update_quantity_from_element( fn try_conversion_flag_from_tuple( vm: &VirtualMachine, - element: Option<&PyObjectRef>, + element: Option<&PyObject>, ) -> PyResult { match element { Some(width_obj) => { @@ -254,8 +255,11 @@ fn try_update_quantity_from_tuple<'a, I: Iterator>( return Ok(()); }; let element = elements.next(); - f.insert(try_conversion_flag_from_tuple(vm, element)?); - let quantity = try_update_quantity_from_element(vm, element)?; + f.insert(try_conversion_flag_from_tuple( + vm, + element.map(|v| v.as_ref()), + )?); + let quantity = try_update_quantity_from_element(vm, element.map(|v| v.as_ref()))?; *q = Some(quantity); Ok(()) } @@ -268,7 +272,7 @@ fn try_update_precision_from_tuple<'a, I: Iterator>( let Some(CFormatPrecision::Quantity(CFormatQuantity::FromValuesTuple)) = p else { return Ok(()); }; - let quantity = try_update_quantity_from_element(vm, elements.next())?; + let quantity = try_update_quantity_from_element(vm, elements.next().map(|v| v.as_ref()))?; *p = Some(CFormatPrecision::Quantity(quantity)); Ok(()) } diff --git a/crates/vm/src/codecs.rs b/crates/vm/src/codecs.rs index dac637c396d..2edb67b497b 100644 --- a/crates/vm/src/codecs.rs +++ b/crates/vm/src/codecs.rs @@ -52,7 +52,7 @@ impl PyCodec { self.0 } #[inline] - pub const fn as_tuple(&self) -> &PyTupleRef { + pub fn as_tuple(&self) -> &Py { &self.0 } diff --git a/crates/vm/src/coroutine.rs b/crates/vm/src/coroutine.rs index 4e76490ed65..ebe3107cb36 100644 --- a/crates/vm/src/coroutine.rs +++ b/crates/vm/src/coroutine.rs @@ -1,7 +1,8 @@ use crate::{ - AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine, + 
AsObject, Py, PyObject, PyObjectRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyStrRef}, common::lock::PyMutex, + exceptions::types::PyBaseException, frame::{ExecutionResult, FrameRef}, protocol::PyIterReturn, }; @@ -207,6 +208,6 @@ impl Coro { } } -pub fn is_gen_exit(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> bool { +pub fn is_gen_exit(exc: &Py, vm: &VirtualMachine) -> bool { exc.fast_isinstance(vm.ctx.exceptions.generator_exit) } diff --git a/crates/vm/src/exception_group.rs b/crates/vm/src/exception_group.rs index 8d033b26110..cd943ae1bd9 100644 --- a/crates/vm/src/exception_group.rs +++ b/crates/vm/src/exception_group.rs @@ -352,12 +352,12 @@ pub(super) mod types { } // Helper functions for ExceptionGroup - fn is_base_exception_group(obj: &PyObjectRef, vm: &VirtualMachine) -> bool { + fn is_base_exception_group(obj: &PyObject, vm: &VirtualMachine) -> bool { obj.fast_isinstance(vm.ctx.exceptions.base_exception_group) } fn get_exceptions_tuple( - exc: &PyRef, + exc: &Py, vm: &VirtualMachine, ) -> PyResult> { let obj = exc @@ -376,7 +376,7 @@ pub(super) mod types { } fn get_condition_matcher( - condition: &PyObjectRef, + condition: &PyObject, vm: &VirtualMachine, ) -> PyResult { // If it's a type and subclass of BaseException @@ -409,19 +409,19 @@ pub(super) mod types { // If it's callable (but not a type) if condition.is_callable() && condition.downcast_ref::().is_none() { - return Ok(ConditionMatcher::Callable(condition.clone())); + return Ok(ConditionMatcher::Callable(condition.to_owned())); } Err(vm.new_type_error("expected a function, exception type or tuple of exception types")) } impl ConditionMatcher { - fn check(&self, exc: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn check(&self, exc: &PyObject, vm: &VirtualMachine) -> PyResult { match self { ConditionMatcher::Type(typ) => Ok(exc.fast_isinstance(typ)), ConditionMatcher::Types(types) => Ok(types.iter().any(|t| exc.fast_isinstance(t))), ConditionMatcher::Callable(func) => { 
- let result = func.call((exc.clone(),), vm)?; + let result = func.call((exc.to_owned(),), vm)?; result.try_to_bool(vm) } } @@ -429,7 +429,7 @@ pub(super) mod types { } fn derive_and_copy_attributes( - orig: &PyRef, + orig: &Py, excs: Vec, vm: &VirtualMachine, ) -> PyResult { diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 2958d1e497e..036d914810d 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -79,7 +79,7 @@ impl VirtualMachine { pub fn write_exception( &self, output: &mut W, - exc: &PyBaseExceptionRef, + exc: &Py, ) -> Result<(), W::Error> { let seen = &mut HashSet::::new(); self.write_exception_recursive(output, exc, seen) @@ -88,7 +88,7 @@ impl VirtualMachine { fn write_exception_recursive( &self, output: &mut W, - exc: &PyBaseExceptionRef, + exc: &Py, seen: &mut HashSet, ) -> Result<(), W::Error> { // This function should not be called directly, @@ -132,7 +132,7 @@ impl VirtualMachine { pub fn write_exception_inner( &self, output: &mut W, - exc: &PyBaseExceptionRef, + exc: &Py, ) -> Result<(), W::Error> { let vm = self; if let Some(tb) = exc.traceback.read().clone() { @@ -177,7 +177,7 @@ impl VirtualMachine { fn write_syntaxerror( &self, output: &mut W, - exc: &PyBaseExceptionRef, + exc: &Py, exc_type: &Py, args_repr: &[PyRef], ) -> Result<(), W::Error> { @@ -369,7 +369,7 @@ fn print_source_line( /// Print exception occurrence location from traceback element fn write_traceback_entry( output: &mut W, - tb_entry: &PyTracebackRef, + tb_entry: &Py, ) -> Result<(), W::Error> { let filename = tb_entry.frame.code.source_path.as_str(); writeln!( @@ -1053,12 +1053,12 @@ fn system_exit_code(exc: PyBaseExceptionRef) -> Option { #[cfg(feature = "serde")] pub struct SerializeException<'vm, 's> { vm: &'vm VirtualMachine, - exc: &'s PyBaseExceptionRef, + exc: &'s Py, } #[cfg(feature = "serde")] impl<'vm, 's> SerializeException<'vm, 's> { - pub fn new(vm: &'vm VirtualMachine, exc: &'s PyBaseExceptionRef) -> Self { + 
pub fn new(vm: &'vm VirtualMachine, exc: &'s Py) -> Self { SerializeException { vm, exc } } } diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 4a460a95884..ad50f972aef 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -1866,14 +1866,14 @@ impl ExecutingFrame<'_> { /// This ensures proper order preservation for OrderedDict and other custom mappings. fn iterate_mapping_keys( vm: &VirtualMachine, - mapping: &PyObjectRef, + mapping: &PyObject, error_prefix: &str, mut key_handler: F, ) -> PyResult<()> where F: FnMut(PyObjectRef) -> PyResult<()>, { - let Some(keys_method) = vm.get_method(mapping.clone(), vm.ctx.intern_str("keys")) else { + let Some(keys_method) = vm.get_method(mapping.to_owned(), vm.ctx.intern_str("keys")) else { return Err(vm.new_type_error(format!("{error_prefix} must be a mapping"))); }; diff --git a/crates/vm/src/import.rs b/crates/vm/src/import.rs index c119405fe1d..3f4a437c599 100644 --- a/crates/vm/src/import.rs +++ b/crates/vm/src/import.rs @@ -1,8 +1,9 @@ //! Import mechanics use crate::{ - AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - builtins::{PyBaseExceptionRef, PyCode, list, traceback::PyTraceback}, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, + builtins::{PyCode, list, traceback::PyTraceback}, + exceptions::types::PyBaseException, scope::Scope, version::get_git_revision, vm::{VirtualMachine, thread}, @@ -204,7 +205,7 @@ fn remove_importlib_frames_inner( // TODO: This function should do nothing on verbose mode. 
// TODO: Fix this function after making PyTraceback.next mutable -pub fn remove_importlib_frames(vm: &VirtualMachine, exc: &PyBaseExceptionRef) { +pub fn remove_importlib_frames(vm: &VirtualMachine, exc: &Py) { if vm.state.settings.verbose != 0 { return; } diff --git a/crates/vm/src/object/payload.rs b/crates/vm/src/object/payload.rs index cf903871179..4b900b7caa1 100644 --- a/crates/vm/src/object/payload.rs +++ b/crates/vm/src/object/payload.rs @@ -106,7 +106,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline(never)] fn _into_ref_size_error( vm: &VirtualMachine, - cls: &PyTypeRef, + cls: &Py, exact_class: &Py, ) -> PyBaseExceptionRef { vm.new_type_error(format!( @@ -123,7 +123,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline(never)] fn _into_ref_with_type_error( vm: &VirtualMachine, - cls: &PyTypeRef, + cls: &Py, exact_class: &Py, ) -> PyBaseExceptionRef { vm.new_type_error(format!( diff --git a/crates/vm/src/scope.rs b/crates/vm/src/scope.rs index 9311fa5c2db..4f80e9999ec 100644 --- a/crates/vm/src/scope.rs +++ b/crates/vm/src/scope.rs @@ -34,7 +34,7 @@ impl Scope { Self::new(locals, globals) } - // pub fn get_locals(&self) -> &PyDictRef { + // pub fn get_locals(&self) -> &Py { // match self.locals.first() { // Some(dict) => dict, // None => &self.globals, diff --git a/crates/vm/src/sequence.rs b/crates/vm/src/sequence.rs index e75c0a6da55..6e03ad1697e 100644 --- a/crates/vm/src/sequence.rs +++ b/crates/vm/src/sequence.rs @@ -12,7 +12,7 @@ use std::ops::{Deref, Range}; pub trait MutObjectSequenceOp { type Inner: ?Sized; - fn do_get(index: usize, inner: &Self::Inner) -> Option<&PyObjectRef>; + fn do_get(index: usize, inner: &Self::Inner) -> Option<&PyObject>; fn do_lock(&self) -> impl Deref; fn mut_count(&self, vm: &VirtualMachine, needle: &PyObject) -> PyResult { @@ -76,7 +76,7 @@ pub trait MutObjectSequenceOp { } borrower = Some(guard); } else { - let elem = elem.clone(); + let 
elem = elem.to_owned(); drop(guard); if elem.rich_compare_bool(needle, PyComparisonOp::Eq, vm)? { diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 542476d68c7..442bb79b94e 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -10,7 +10,7 @@ mod builtins { use std::io::IsTerminal; use crate::{ - AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, + AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ PyByteArray, PyBytes, PyDictRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, enumerate::PyReverseSequenceIterator, @@ -261,7 +261,7 @@ mod builtins { func_name: &'static str, ) -> PyResult { fn validate_globals_dict( - globals: &PyObjectRef, + globals: &PyObject, vm: &VirtualMachine, func_name: &'static str, ) -> PyResult<()> { diff --git a/crates/vm/src/stdlib/collections.rs b/crates/vm/src/stdlib/collections.rs index 32596b65386..eae56968cba 100644 --- a/crates/vm/src/stdlib/collections.rs +++ b/crates/vm/src/stdlib/collections.rs @@ -422,8 +422,8 @@ mod _collections { impl MutObjectSequenceOp for PyDeque { type Inner = VecDeque; - fn do_get(index: usize, inner: &Self::Inner) -> Option<&PyObjectRef> { - inner.get(index) + fn do_get(index: usize, inner: &Self::Inner) -> Option<&PyObject> { + inner.get(index).map(|r| r.as_ref()) } fn do_lock(&self) -> impl std::ops::Deref { diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/ctypes.rs index 70aee7378d3..ebe2d16ffb2 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/ctypes.rs @@ -53,7 +53,7 @@ pub(crate) mod _ctypes { use crate::convert::ToPyObject; use crate::function::{Either, FuncArgs, OptionalArg}; use crate::stdlib::ctypes::library; - use crate::{AsObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; + use crate::{AsObject, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; use crossbeam_utils::atomic::AtomicCell; 
use std::ffi::{ c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, @@ -349,7 +349,7 @@ pub(crate) mod _ctypes { const SIMPLE_TYPE_CHARS: &str = "cbBhHiIlLdfguzZPqQ?O"; pub fn new_simple_type( - cls: Either<&PyObjectRef, &PyTypeRef>, + cls: Either<&PyObject, &PyTypeRef>, vm: &VirtualMachine, ) -> PyResult { let cls = match cls { diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index a4664ad3671..e45ff0b3b70 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -5,7 +5,9 @@ use crate::function::{ArgBytesLike, Either, FuncArgs, KwArgs, OptionalArg}; use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods}; use crate::stdlib::ctypes::_ctypes::new_simple_type; use crate::types::{AsBuffer, AsNumber, Constructor}; -use crate::{AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine}; +use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, +}; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; use rustpython_common::lock::PyRwLock; @@ -46,19 +48,19 @@ pub fn ffi_type_from_str(_type_: &str) -> Option { } #[allow(dead_code)] -fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyResult { +fn set_primitive(_type_: &str, value: &PyObject, vm: &VirtualMachine) -> PyResult { match _type_ { "c" => { if value - .clone() + .to_owned() .downcast_exact::(vm) .is_ok_and(|v| v.len() == 1) || value - .clone() + .to_owned() .downcast_exact::(vm) .is_ok_and(|v| v.len() == 1) || value - .clone() + .to_owned() .downcast_exact::(vm) .map_or(Ok(false), |v| { let n = v.as_bigint().to_i64(); @@ -69,7 +71,7 @@ fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyRe } })? 
{ - Ok(value.clone()) + Ok(value.to_owned()) } else { Err(vm.new_type_error("one character bytes, bytearray or integer expected")) } @@ -77,7 +79,7 @@ fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyRe "u" => { if let Ok(b) = value.str(vm).map(|v| v.to_string().chars().count() == 1) { if b { - Ok(value.clone()) + Ok(value.to_owned()) } else { Err(vm.new_type_error("one character unicode string expected")) } @@ -89,8 +91,8 @@ fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyRe } } "b" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { - if value.clone().downcast_exact::(vm).is_ok() { - Ok(value.clone()) + if value.to_owned().downcast_exact::(vm).is_ok() { + Ok(value.to_owned()) } else { Err(vm.new_type_error(format!( "an integer is required (got type {})", @@ -100,30 +102,30 @@ fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyRe } "f" | "d" | "g" => { // float allows int - if value.clone().downcast_exact::(vm).is_ok() - || value.clone().downcast_exact::(vm).is_ok() + if value.to_owned().downcast_exact::(vm).is_ok() + || value.to_owned().downcast_exact::(vm).is_ok() { - Ok(value.clone()) + Ok(value.to_owned()) } else { Err(vm.new_type_error(format!("must be real number, not {}", value.class().name()))) } } "?" 
=> Ok(PyObjectRef::from( - vm.ctx.new_bool(value.clone().try_to_bool(vm)?), + vm.ctx.new_bool(value.to_owned().try_to_bool(vm)?), )), "B" => { - if value.clone().downcast_exact::(vm).is_ok() { + if value.to_owned().downcast_exact::(vm).is_ok() { // Store as-is, conversion to unsigned happens in the getter - Ok(value.clone()) + Ok(value.to_owned()) } else { Err(vm.new_type_error(format!("int expected instead of {}", value.class().name()))) } } "z" => { - if value.clone().downcast_exact::(vm).is_ok() - || value.clone().downcast_exact::(vm).is_ok() + if value.to_owned().downcast_exact::(vm).is_ok() + || value.to_owned().downcast_exact::(vm).is_ok() { - Ok(value.clone()) + Ok(value.to_owned()) } else { Err(vm.new_type_error(format!( "bytes or integer address expected instead of {} instance", @@ -132,8 +134,8 @@ fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyRe } } "Z" => { - if value.clone().downcast_exact::(vm).is_ok() { - Ok(value.clone()) + if value.to_owned().downcast_exact::(vm).is_ok() { + Ok(value.to_owned()) } else { Err(vm.new_type_error(format!( "unicode string or integer address expected instead of {} instance", @@ -143,10 +145,10 @@ fn set_primitive(_type_: &str, value: &PyObjectRef, vm: &VirtualMachine) -> PyRe } _ => { // "P" - if value.clone().downcast_exact::(vm).is_ok() - || value.clone().downcast_exact::(vm).is_ok() + if value.to_owned().downcast_exact::(vm).is_ok() + || value.to_owned().downcast_exact::(vm).is_ok() { - Ok(value.clone()) + Ok(value.to_owned()) } else { Err(vm.new_type_error("cannot be converted to pointer")) } @@ -437,7 +439,7 @@ impl Debug for PyCSimple { fn value_to_bytes_endian( _type_: &str, - value: &PyObjectRef, + value: &PyObject, swapped: bool, vm: &VirtualMachine, ) -> Vec { @@ -598,7 +600,7 @@ fn value_to_bytes_endian( } "?" 
=> { // c_bool (1 byte) - if let Ok(b) = value.clone().try_to_bool(vm) { + if let Ok(b) = value.to_owned().try_to_bool(vm) { return vec![if b { 1 } else { 0 }]; } vec![0] diff --git a/crates/vm/src/stdlib/ctypes/field.rs b/crates/vm/src/stdlib/ctypes/field.rs index 659255f3329..ea57d68065a 100644 --- a/crates/vm/src/stdlib/ctypes/field.rs +++ b/crates/vm/src/stdlib/ctypes/field.rs @@ -1,7 +1,7 @@ use crate::builtins::PyType; use crate::function::PySetterValue; use crate::types::{GetDescriptor, Representable}; -use crate::{AsObject, Py, PyObjectRef, PyResult, VirtualMachine}; +use crate::{AsObject, Py, PyObject, PyObjectRef, PyResult, VirtualMachine}; use num_traits::ToPrimitive; use super::structure::PyCStructure; @@ -152,7 +152,7 @@ impl PyCField { } /// Convert a Python value to bytes - fn value_to_bytes(value: &PyObjectRef, size: usize, vm: &VirtualMachine) -> PyResult> { + fn value_to_bytes(value: &PyObject, size: usize, vm: &VirtualMachine) -> PyResult> { if let Ok(int_val) = value.try_int(vm) { let i = int_val.as_bigint(); match size { diff --git a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/ctypes/structure.rs index f32d6865cb6..ca67a2fe7d6 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ b/crates/vm/src/stdlib/ctypes/structure.rs @@ -7,7 +7,7 @@ use crate::function::FuncArgs; use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods}; use crate::stdlib::ctypes::_ctypes::get_size; use crate::types::{AsBuffer, AsNumber, Constructor}; -use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; +use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; use indexmap::IndexMap; use num_traits::ToPrimitive; use rustpython_common::lock::PyRwLock; @@ -109,7 +109,7 @@ impl PyCStructType { } /// Get the size of a ctypes type - fn get_field_size(field_type: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> 
PyResult { // Try to get _type_ attribute for simple types if let Some(size) = field_type .get_attr("_type_", vm) @@ -139,7 +139,7 @@ impl PyCStructType { } /// Get the alignment of a ctypes type - fn get_field_align(field_type: &PyObjectRef, vm: &VirtualMachine) -> usize { + fn get_field_align(field_type: &PyObject, vm: &VirtualMachine) -> usize { // Try to get _type_ attribute for simple types if let Some(align) = field_type .get_attr("_type_", vm) diff --git a/crates/vm/src/stdlib/ctypes/union.rs b/crates/vm/src/stdlib/ctypes/union.rs index e6873e87506..308a5e4e98f 100644 --- a/crates/vm/src/stdlib/ctypes/union.rs +++ b/crates/vm/src/stdlib/ctypes/union.rs @@ -7,7 +7,7 @@ use crate::function::FuncArgs; use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer as ProtocolPyBuffer}; use crate::stdlib::ctypes::_ctypes::get_size; use crate::types::{AsBuffer, Constructor}; -use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; +use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; use num_traits::ToPrimitive; use rustpython_common::lock::PyRwLock; @@ -90,7 +90,7 @@ impl PyCUnionType { Ok(()) } - fn get_field_size(field_type: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> PyResult { if let Some(size) = field_type .get_attr("_type_", vm) .ok() diff --git a/crates/vm/src/stdlib/typevar.rs b/crates/vm/src/stdlib/typevar.rs index e8cc407da15..65249bfd075 100644 --- a/crates/vm/src/stdlib/typevar.rs +++ b/crates/vm/src/stdlib/typevar.rs @@ -44,7 +44,7 @@ fn caller(vm: &VirtualMachine) -> Option { /// Set __module__ attribute for an object based on the caller's module. /// This follows CPython's behavior for TypeVar and similar objects. 
-fn set_module_from_caller(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { +fn set_module_from_caller(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> { // Note: CPython gets module from frame->f_funcobj, but RustPython's Frame // architecture is different - we use globals['__name__'] instead if let Some(module_name) = caller(vm) { @@ -1006,15 +1006,15 @@ pub fn set_typeparam_default( ) -> PyResult { // Inner function to handle common pattern of setting evaluate_default fn try_set_default( - obj: &PyObjectRef, - evaluate_default: &PyObjectRef, + obj: &PyObject, + evaluate_default: &PyObject, get_field: impl FnOnce(&T) -> &PyMutex, ) -> bool where T: PyPayload, { if let Some(typed_obj) = obj.downcast_ref::() { - *get_field(typed_obj).lock() = evaluate_default.clone(); + *get_field(typed_obj).lock() = evaluate_default.to_owned(); true } else { false diff --git a/crates/vm/src/suggestion.rs b/crates/vm/src/suggestion.rs index 2cc935873c2..866deb668eb 100644 --- a/crates/vm/src/suggestion.rs +++ b/crates/vm/src/suggestion.rs @@ -2,9 +2,9 @@ //! This is used during tracebacks. 
use crate::{ - AsObject, Py, PyObjectRef, VirtualMachine, + AsObject, Py, PyObject, PyObjectRef, VirtualMachine, builtins::{PyStr, PyStrRef}, - exceptions::types::PyBaseExceptionRef, + exceptions::types::PyBaseException, sliceable::SliceableSequenceOp, }; use rustpython_common::str::levenshtein::{MOVE_COST, levenshtein_distance}; @@ -14,7 +14,7 @@ const MAX_CANDIDATE_ITEMS: usize = 750; pub fn calculate_suggestions<'a>( dir_iter: impl ExactSizeIterator, - name: &PyObjectRef, + name: &PyObject, ) -> Option { if dir_iter.len() >= MAX_CANDIDATE_ITEMS { return None; @@ -47,7 +47,7 @@ pub fn calculate_suggestions<'a>( suggestion.map(|r| r.to_owned()) } -pub fn offer_suggestions(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> Option { +pub fn offer_suggestions(exc: &Py, vm: &VirtualMachine) -> Option { if exc.class().is(vm.ctx.exceptions.attribute_error) { let name = exc.as_object().get_attr("name", vm).unwrap(); let obj = exc.as_object().get_attr("obj", vm).unwrap(); diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs index fbda71dc1f6..486c1861fb1 100644 --- a/crates/vm/src/vm/context.rs +++ b/crates/vm/src/vm/context.rs @@ -62,9 +62,9 @@ macro_rules! 
declare_const_name { } impl ConstName { - unsafe fn new(pool: &StringPool, typ: &PyTypeRef) -> Self { + unsafe fn new(pool: &StringPool, typ: &Py) -> Self { Self { - $($name: unsafe { pool.intern(declare_const_name!(@string $name $($s)?), typ.clone()) },)* + $($name: unsafe { pool.intern(declare_const_name!(@string $name $($s)?), typ.to_owned()) },)* } } } @@ -317,7 +317,7 @@ impl Context { ); let string_pool = StringPool::default(); - let names = unsafe { ConstName::new(&string_pool, &types.str_type.to_owned()) }; + let names = unsafe { ConstName::new(&string_pool, types.str_type) }; let slot_new_wrapper = PyMethodDef::new_const( names.__new__.as_str(), diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index fd37b2494dd..4574b2de370 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -23,6 +23,7 @@ use crate::{ codecs::CodecsRegistry, common::{hash::HashSecret, lock::PyMutex, rc::PyRc}, convert::ToPyObject, + exceptions::types::PyBaseException, frame::{ExecutionResult, Frame, FrameRef}, frozen::FrozenModule, function::{ArgMapping, FuncArgs, PySetterValue}, @@ -728,7 +729,7 @@ impl VirtualMachine { pub fn set_attribute_error_context( &self, - exc: &PyBaseExceptionRef, + exc: &Py, obj: PyObjectRef, name: PyStrRef, ) { @@ -814,7 +815,7 @@ impl VirtualMachine { drop(prev); } - pub(crate) fn contextualize_exception(&self, exception: &PyBaseExceptionRef) { + pub(crate) fn contextualize_exception(&self, exception: &Py) { if let Some(context_exc) = self.topmost_exception() && !context_exc.is(exception) { diff --git a/crates/vm/src/warn.rs b/crates/vm/src/warn.rs index 6480f778433..b632495eb4a 100644 --- a/crates/vm/src/warn.rs +++ b/crates/vm/src/warn.rs @@ -1,5 +1,5 @@ use crate::{ - AsObject, Context, Py, PyObjectRef, PyResult, VirtualMachine, + AsObject, Context, Py, PyObject, PyObjectRef, PyResult, VirtualMachine, builtins::{ PyDictRef, PyListRef, PyStr, PyStrInterned, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }, @@ -38,7 +38,7 @@ impl 
WarningsState { } } -fn check_matched(obj: &PyObjectRef, arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult { +fn check_matched(obj: &PyObject, arg: &PyObject, vm: &VirtualMachine) -> PyResult { if obj.class().is(vm.ctx.types.none_type) { return Ok(true); } diff --git a/crates/wasm/src/convert.rs b/crates/wasm/src/convert.rs index d1821f2e733..f84b0d46239 100644 --- a/crates/wasm/src/convert.rs +++ b/crates/wasm/src/convert.rs @@ -4,8 +4,8 @@ use crate::js_module; use crate::vm_class::{WASMVirtualMachine, stored_vm_from_wasm}; use js_sys::{Array, ArrayBuffer, Object, Promise, Reflect, SyntaxError, Uint8Array}; use rustpython_vm::{ - AsObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, - builtins::PyBaseExceptionRef, + AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, VirtualMachine, + builtins::{PyBaseException, PyBaseExceptionRef}, compiler::{CompileError, ParseError, parser::LexicalErrorType, parser::ParseErrorType}, exceptions, function::{ArgBytesLike, FuncArgs}, @@ -32,7 +32,7 @@ extern "C" { fn new(info: JsValue) -> PyError; } -pub fn py_err_to_js_err(vm: &VirtualMachine, py_err: &PyBaseExceptionRef) -> JsValue { +pub fn py_err_to_js_err(vm: &VirtualMachine, py_err: &Py) -> JsValue { let js_err = vm.try_class("_js", "JSError").ok(); let js_arg = if js_err.is_some_and(|js_err| py_err.fast_isinstance(&js_err)) { py_err.get_arg(0) diff --git a/crates/wasm/src/js_module.rs b/crates/wasm/src/js_module.rs index 1d8ca0961ca..d4f623da9f4 100644 --- a/crates/wasm/src/js_module.rs +++ b/crates/wasm/src/js_module.rs @@ -612,8 +612,7 @@ mod _js { fn js_error(vm: &VirtualMachine) -> PyTypeRef { let ctx = &vm.ctx; let js_error = PyRef::leak( - PyType::new_simple_heap("JSError", &vm.ctx.exceptions.exception_type.to_owned(), ctx) - .unwrap(), + PyType::new_simple_heap("JSError", vm.ctx.exceptions.exception_type, ctx).unwrap(), ); extend_class!(ctx, js_error, { "value" => ctx.new_readonly_getset("value", js_error, |exc: 
PyBaseExceptionRef| exc.get_arg(0)), From 9e439667dac0729f3d0c488ab3f6fedaadc580f7 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 18 Dec 2025 10:44:56 +0900 Subject: [PATCH 011/418] Fix openssl build and shared ssl/error.rs (#6456) * Fix openssl * shared error --- crates/stdlib/src/openssl.rs | 502 +++++++++++++----------------- crates/stdlib/src/openssl/cert.rs | 3 +- crates/stdlib/src/ssl.rs | 114 +------ crates/stdlib/src/ssl/compat.rs | 11 +- crates/stdlib/src/ssl/error.rs | 117 +++++++ 5 files changed, 354 insertions(+), 393 deletions(-) create mode 100644 crates/stdlib/src/ssl/error.rs diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index bf03813b180..e07ad552f17 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -2,6 +2,10 @@ mod cert; +// SSL exception types (shared with rustls backend) +#[path = "ssl/error.rs"] +mod ssl_error; + // Conditional compilation for OpenSSL version-specific error codes cfg_if::cfg_if! { if #[cfg(ossl310)] { @@ -45,9 +49,16 @@ cfg_if::cfg_if! 
{ } #[allow(non_upper_case_globals)] -#[pymodule(with(cert::ssl_cert, ossl101, ossl111, windows))] +#[pymodule(with(cert::ssl_cert, ssl_error::ssl_error, ossl101, ossl111, windows))] mod _ssl { use super::{bio, probe}; + + // Import error types used in this module (others are exposed via pymodule(with(...))) + use super::ssl_error::{ + PySSLCertVerificationError as PySslCertVerificationError, PySSLEOFError as PySslEOFError, + PySSLError as PySslError, PySSLWantReadError as PySslWantReadError, + PySSLWantWriteError as PySslWantWriteError, + }; use crate::{ common::lock::{ PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, @@ -56,8 +67,8 @@ mod _ssl { vm::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ - PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, - PyTypeRef, PyWeak, + PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, + PyWeak, }, class_or_notimplemented, convert::ToPyException, @@ -300,85 +311,6 @@ mod _ssl { parse_version_info(openssl_api_version) } - // SSL Exception Types - - /// An error occurred in the SSL implementation. - #[pyattr] - #[pyexception(name = "SSLError", base = PyOSError)] - #[derive(Debug)] - pub struct PySslError {} - - #[pyexception] - impl PySslError { - // Returns strerror attribute if available, otherwise str(args) - #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - // Try to get strerror attribute first (OSError compatibility) - if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) - && !vm.is_none(&strerror) - { - return strerror.str(vm); - } - - // Otherwise return str(args) - exc.args().as_object().str(vm) - } - } - - /// A certificate could not be verified. 
- #[pyattr] - #[pyexception(name = "SSLCertVerificationError", base = PySslError)] - #[derive(Debug)] - pub struct PySslCertVerificationError {} - - #[pyexception] - impl PySslCertVerificationError {} - - /// SSL/TLS session closed cleanly. - #[pyattr] - #[pyexception(name = "SSLZeroReturnError", base = PySslError)] - #[derive(Debug)] - pub struct PySslZeroReturnError {} - - #[pyexception] - impl PySslZeroReturnError {} - - /// Non-blocking SSL socket needs to read more data. - #[pyattr] - #[pyexception(name = "SSLWantReadError", base = PySslError)] - #[derive(Debug)] - pub struct PySslWantReadError {} - - #[pyexception] - impl PySslWantReadError {} - - /// Non-blocking SSL socket needs to write more data. - #[pyattr] - #[pyexception(name = "SSLWantWriteError", base = PySslError)] - #[derive(Debug)] - pub struct PySslWantWriteError {} - - #[pyexception] - impl PySslWantWriteError {} - - /// System error when attempting SSL operation. - #[pyattr] - #[pyexception(name = "SSLSyscallError", base = PySslError)] - #[derive(Debug)] - pub struct PySslSyscallError {} - - #[pyexception] - impl PySslSyscallError {} - - /// SSL/TLS connection terminated abruptly. 
- #[pyattr] - #[pyexception(name = "SSLEOFError", base = PySslError)] - #[derive(Debug)] - pub struct PySslEOFError {} - - #[pyexception] - impl PySslEOFError {} - type OpensslVersionInfo = (u8, u8, u8, u8, u8); const fn parse_version_info(mut n: i64) -> OpensslVersionInfo { let status = (n & 0xF) as u8; @@ -582,18 +514,53 @@ mod _ssl { Ok(buf) } - // Callback data stored in SSL context for SNI + // Callback data stored in SSL ex_data for SNI/msg callbacks struct SniCallbackData { ssl_context: PyRef, - vm_ptr: *const VirtualMachine, + // Use weak reference to avoid reference cycle: + // PySslSocket -> SslStream -> SSL -> ex_data -> SniCallbackData -> PySslSocket + ssl_socket_weak: PyRef, + } + + // Thread-local storage for VirtualMachine pointer during handshake + // SNI callback is only called during handshake which is synchronous + thread_local! { + static HANDSHAKE_VM: std::cell::Cell> = const { std::cell::Cell::new(None) }; + // SSL pointer during handshake - needed because connection lock is held during handshake + // and callbacks may need to access SSL without acquiring the lock + static HANDSHAKE_SSL_PTR: std::cell::Cell> = const { std::cell::Cell::new(None) }; } - impl Drop for SniCallbackData { + // RAII guard to set/clear thread-local handshake context + struct HandshakeVmGuard { + _ssl_ptr: *mut sys::SSL, + } + + impl HandshakeVmGuard { + fn new(vm: &VirtualMachine, ssl_ptr: *mut sys::SSL) -> Self { + HANDSHAKE_VM.with(|cell| cell.set(Some(vm as *const _))); + HANDSHAKE_SSL_PTR.with(|cell| cell.set(Some(ssl_ptr))); + HandshakeVmGuard { _ssl_ptr: ssl_ptr } + } + } + + impl Drop for HandshakeVmGuard { fn drop(&mut self) { - // PyRef will handle reference counting + HANDSHAKE_VM.with(|cell| cell.set(None)); + HANDSHAKE_SSL_PTR.with(|cell| cell.set(None)); } } + // Get SSL pointer - either from thread-local (during handshake) or from connection + fn get_ssl_ptr_for_context_change(connection: &PyRwLock) -> *mut sys::SSL { + // First check if we're in a 
handshake callback (lock already held) + if let Some(ptr) = HANDSHAKE_SSL_PTR.with(|cell| cell.get()) { + return ptr; + } + // Otherwise, acquire the lock normally + connection.read().ssl().as_ptr() + } + // Get or create an ex_data index for SNI callback data fn get_sni_ex_data_index() -> libc::c_int { use std::sync::LazyLock; @@ -610,17 +577,30 @@ mod _ssl { } // Free function for callback data + // NOTE: We don't free the data here because it's managed manually in do_handshake + // to avoid use-after-free when the SSL object is dropped after timeout unsafe extern "C" fn sni_callback_data_free( _parent: *mut libc::c_void, - ptr: *mut libc::c_void, + _ptr: *mut libc::c_void, _ad: *mut sys::CRYPTO_EX_DATA, _idx: libc::c_int, _argl: libc::c_long, _argp: *mut libc::c_void, ) { - if !ptr.is_null() { - unsafe { - let _ = Box::from_raw(ptr as *mut SniCallbackData); + // Intentionally empty - data is freed in cleanup_sni_ex_data() + } + + // Clean up SNI callback data from SSL ex_data + // Called after handshake to free the data and release references + unsafe fn cleanup_sni_ex_data(ssl_ptr: *mut sys::SSL) { + unsafe { + let idx = get_sni_ex_data_index(); + let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); + if !data_ptr.is_null() { + // Free the Box - this releases references to context and socket + let _ = Box::from_raw(data_ptr as *mut SniCallbackData); + // Clear the ex_data to prevent double-free + sys::SSL_set_ex_data(ssl_ptr, idx, std::ptr::null_mut()); } } } @@ -658,9 +638,13 @@ mod _ssl { let callback_data = &*(data_ptr as *const SniCallbackData); - // SAFETY: vm_ptr is stored during wrap_socket and is valid for the lifetime - // of the SSL connection. The handshake happens synchronously in the same thread. 
- let vm = &*callback_data.vm_ptr; + // Get VM from thread-local storage (set by HandshakeVmGuard in do_handshake) + let Some(vm_ptr) = HANDSHAKE_VM.with(|cell| cell.get()) else { + // VM not available - this shouldn't happen during handshake + *al = SSL_AD_INTERNAL_ERROR; + return SSL_TLSEXT_ERR_ALERT_FATAL; + }; + let vm = &*vm_ptr; // Get server name let servername = sys::SSL_get_servername(ssl_ptr, TLSEXT_NAMETYPE_host_name); @@ -674,20 +658,11 @@ mod _ssl { } }; - // Get SSL socket from SSL ex_data (stored as PySslSocket pointer) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); // Index 0 for SSL socket - let ssl_socket_obj = if !ssl_socket_ptr.is_null() { - let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); - // Try to get owner first - ssl_socket - .owner - .read() - .as_ref() - .and_then(|weak| weak.upgrade()) - .unwrap_or_else(|| vm.ctx.none()) - } else { - vm.ctx.none() - }; + // Get SSL socket from callback data via weak reference + let ssl_socket_obj = callback_data + .ssl_socket_weak + .upgrade() + .unwrap_or_else(|| vm.ctx.none()); // Call the Python callback match callback.call( @@ -735,81 +710,20 @@ mod _ssl { } // Message callback function called by OpenSSL - // Based on CPython's _PySSL_msg_callback in Modules/_ssl/debughelpers.c + // NOTE: This callback is intentionally a no-op to avoid deadlocks. + // The msg_callback can be called during various SSL operations (read, write, handshake), + // and invoking Python code from within these operations can cause deadlocks + // (see CPython bpo-43577). A proper implementation would require careful lock ordering. 
unsafe extern "C" fn _msg_callback( - write_p: libc::c_int, - version: libc::c_int, - content_type: libc::c_int, - buf: *const libc::c_void, - len: usize, - ssl_ptr: *mut sys::SSL, + _write_p: libc::c_int, + _version: libc::c_int, + _content_type: libc::c_int, + _buf: *const libc::c_void, + _len: usize, + _ssl_ptr: *mut sys::SSL, _arg: *mut libc::c_void, ) { - if ssl_ptr.is_null() { - return; - } - - unsafe { - // Get SSL socket from SSL_get_app_data (index 0) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); - if ssl_socket_ptr.is_null() { - return; - } - - let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); - - // Get the callback from the context - let callback_opt = ssl_socket.ctx.read().msg_callback.lock().clone(); - let Some(callback) = callback_opt else { - return; - }; - - // Get callback data from SSL ex_data (for VM) - let idx = get_sni_ex_data_index(); - let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); - if data_ptr.is_null() { - return; - } - - let callback_data = &*(data_ptr as *const SniCallbackData); - let vm = &*callback_data.vm_ptr; - - // Get SSL socket owner object - let ssl_socket_obj = ssl_socket - .owner - .read() - .as_ref() - .and_then(|weak| weak.upgrade()) - .unwrap_or_else(|| vm.ctx.none()); - - // Create the message bytes - let buf_slice = std::slice::from_raw_parts(buf as *const u8, len); - let msg_bytes = vm.ctx.new_bytes(buf_slice.to_vec()); - - // Determine direction string - let direction_str = if write_p != 0 { "write" } else { "read" }; - - // Call the Python callback - // Signature: callback(conn, direction, version, content_type, msg_type, data) - // For simplicity, we'll pass msg_type as 0 (would need more parsing to get the actual type) - match callback.call( - ( - ssl_socket_obj, - vm.ctx.new_str(direction_str), - vm.ctx.new_int(version), - vm.ctx.new_int(content_type), - vm.ctx.new_int(0), // msg_type - would need parsing - msg_bytes, - ), - vm, - ) { - Ok(_) => {} - Err(exc) => { - // Log the exception but 
don't propagate it - vm.run_unraisable(exc, None, vm.ctx.none()); - } - } - } + // Intentionally empty to avoid deadlocks } #[pyfunction(name = "RAND_pseudo_bytes")] @@ -850,7 +764,11 @@ mod _ssl { impl Constructor for PySslContext { type Args = i32; - fn py_new(cls: PyTypeRef, proto_version: Self::Args, vm: &VirtualMachine) -> PyResult { + fn py_new( + _cls: &Py, + proto_version: Self::Args, + vm: &VirtualMachine, + ) -> PyResult { let proto = SslVersion::try_from(proto_version) .map_err(|_| vm.new_value_error("invalid protocol version"))?; let method = match proto { @@ -932,16 +850,14 @@ mod _ssl { sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST); } - PySslContext { + Ok(PySslContext { ctx: PyRwLock::new(builder), check_hostname: AtomicCell::new(check_hostname), protocol: proto, post_handshake_auth: PyMutex::new(false), sni_callback: PyMutex::new(None), msg_callback: PyMutex::new(None), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + }) } } @@ -981,12 +897,9 @@ mod _ssl { if ciphers.contains('\0') { return Err(exceptions::cstring_error(vm)); } - self.builder().set_cipher_list(ciphers).map_err(|_| { - vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "No cipher can be selected.".to_owned(), - ) - }) + self.builder() + .set_cipher_list(ciphers) + .map_err(|_| new_ssl_error(vm, "No cipher can be selected.")) } #[pymethod] @@ -1126,16 +1039,10 @@ mod _ssl { let set = !flags & new_flags; if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Failed to clear verify flags".to_owned(), - )); + return Err(new_ssl_error(vm, "Failed to clear verify flags")); } if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Failed to set verify flags".to_owned(), - )); + return Err(new_ssl_error(vm, "Failed to set verify flags")); } Ok(()) } @@ -1477,10 
+1384,13 @@ mod _ssl { let fp = rustpython_common::fileutils::fopen(path.as_path(), "rb").map_err(|e| { match e.kind() { - std::io::ErrorKind::NotFound => vm.new_exception_msg( - vm.ctx.exceptions.file_not_found_error.to_owned(), - e.to_string(), - ), + std::io::ErrorKind::NotFound => vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + None, + e.to_string(), + ) + .upcast(), _ => vm.new_os_error(e.to_string()), } })?; @@ -1670,15 +1580,15 @@ mod _ssl { ) -> PyResult<(ssl::Ssl, SslServerOrClient, Option)> { // Validate socket type and context protocol if server_side && ctx_ref.protocol == SslVersion::TlsClient { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + return Err(new_ssl_error( + vm, + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context", )); } if !server_side && ctx_ref.protocol == SslVersion::TlsServer { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + return Err(new_ssl_error( + vm, + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context", )); } @@ -1791,21 +1701,22 @@ mod _ssl { let py_ref = py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; - // Set SNI callback data if callback is configured - if zelf.sni_callback.lock().is_some() { + // Check if SNI callback is configured (minimize lock time) + let has_sni_callback = zelf.sni_callback.lock().is_some(); + + // Set SNI callback data if needed (after releasing the lock) + if has_sni_callback { + let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store callback data in SSL ex_data + // Store callback data in SSL ex_data - use weak reference to avoid cycle let callback_data = Box::new(SniCallbackData { ssl_context: 
zelf.clone(), - vm_ptr: vm as *const _, + ssl_socket_weak, }); let idx = get_sni_ex_data_index(); sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); - - // Store PyRef pointer (heap-allocated) in ex_data index 0 - sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); } } @@ -1851,21 +1762,22 @@ mod _ssl { let py_ref = py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; - // Set SNI callback data if callback is configured - if zelf.sni_callback.lock().is_some() { + // Check if SNI callback is configured (minimize lock time) + let has_sni_callback = zelf.sni_callback.lock().is_some(); + + // Set SNI callback data if needed (after releasing the lock) + if has_sni_callback { + let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store callback data in SSL ex_data + // Store callback data in SSL ex_data - use weak reference to avoid cycle let callback_data = Box::new(SniCallbackData { ssl_context: zelf.clone(), - vm_ptr: vm as *const _, + ssl_socket_weak, }); let idx = get_sni_ex_data_index(); sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); - - // Store PyRef pointer (heap-allocated) in ex_data index 0 - sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); } } @@ -1992,10 +1904,7 @@ mod _ssl { } fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Underlying socket has been closed.".to_owned(), - ) + new_ssl_error(vm, "Underlying socket has been closed.") } // BIO stream wrapper to implement Read/Write traits for MemoryBIO @@ -2152,12 +2061,13 @@ mod _ssl { } #[pygetset(setter)] fn set_context(&self, value: PyRef, vm: &VirtualMachine) -> PyResult<()> { - // Update the SSL context in the underlying SSL object - let stream = self.connection.read(); + // Get SSL pointer - use thread-local during handshake to 
avoid deadlock + // (connection lock is already held during handshake) + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); // Set the new SSL_CTX on the SSL object unsafe { - let result = SSL_set_SSL_CTX(stream.ssl().as_ptr(), value.ctx().as_ptr()); + let result = SSL_set_SSL_CTX(ssl_ptr, value.ctx().as_ptr()); if result.is_null() { return Err(vm.new_runtime_error("Failed to set SSL context".to_owned())); } @@ -2275,6 +2185,12 @@ mod _ssl { .map(cipher_to_tuple) } + #[pymethod] + fn pending(&self) -> i32 { + let stream = self.connection.read(); + unsafe { sys::SSL_pending(stream.ssl().as_ptr()) } + } + #[pymethod] fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { #[cfg(ossl110)] @@ -2450,8 +2366,8 @@ mod _ssl { // Non-blocking would block - this is okay for shutdown // Return the underlying socket } else { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), + return Err(new_ssl_error( + vm, format!("SSL shutdown failed: error code {}", err), )); } @@ -2491,9 +2407,13 @@ mod _ssl { let mut stream = self.connection.write(); let ssl_ptr = stream.ssl().as_ptr(); + // Set up thread-local VM and SSL pointer for callbacks + // This allows callbacks to access SSL without acquiring the connection lock + let _vm_guard = HandshakeVmGuard::new(vm, ssl_ptr); + // BIO mode: no timeout/select logic, just do handshake if stream.is_bio() { - return stream.do_handshake().map_err(|e| { + let result = stream.do_handshake().map_err(|e| { let exc = convert_ssl_error(vm, e); // If it's a cert verification error, set verify info if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { @@ -2501,6 +2421,10 @@ mod _ssl { } exc }); + // Clean up SNI ex_data after handshake (success or failure) + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; + return result; } // Socket mode: handle timeout and blocking @@ -2510,7 +2434,12 @@ mod _ssl { .timeout_deadline(); loop { let err = match 
stream.do_handshake() { - Ok(()) => return Ok(()), + Ok(()) => { + // Clean up SNI ex_data after successful handshake + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; + return Ok(()); + } Err(e) => e, }; let (needs, state) = stream @@ -2519,12 +2448,20 @@ mod _ssl { .socket_needs(&err, &timeout); match state { SelectRet::TimedOut => { + // Clean up SNI ex_data before returning error + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; return Err(socket::timeout_error_msg( vm, "The handshake operation timed out".to_owned(), - )); + ) + .upcast()); + } + SelectRet::Closed => { + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; + return Err(socket_closed_error(vm)); } - SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} SelectRet::IsBlocking | SelectRet::Ok => { // For blocking sockets, select() has completed successfully @@ -2539,6 +2476,9 @@ mod _ssl { if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { set_verify_error_info(&exc, ssl_ptr, vm); } + // Clean up SNI ex_data before returning error + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; return Err(exc); } } @@ -2565,7 +2505,8 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The write operation timed out".to_owned(), - )); + ) + .upcast()); } SelectRet::Closed => return Err(socket_closed_error(vm)), _ => {} @@ -2584,7 +2525,8 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The write operation timed out".to_owned(), - )); + ) + .upcast()); } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} @@ -2727,7 +2669,8 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The read operation timed out".to_owned(), - )); + ) + .upcast()); } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} 
@@ -3070,7 +3013,7 @@ mod _ssl { impl Constructor for PySslMemoryBio { type Args = (); - fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { + fn py_new(_cls: &Py, _args: Self::Args, vm: &VirtualMachine) -> PyResult { unsafe { let bio = sys::BIO_new(sys::BIO_s_mem()); if bio.is_null() { @@ -3080,12 +3023,10 @@ mod _ssl { sys::BIO_set_retry_read(bio); BIO_set_mem_eof_return(bio, -1); - PySslMemoryBio { + Ok(PySslMemoryBio { bio, eof_written: AtomicCell::new(false), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + }) } } } @@ -3143,10 +3084,7 @@ mod _ssl { #[pymethod] fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { if self.eof_written.load() { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "cannot write() after write_eof()".to_owned(), - )); + return Err(new_ssl_error(vm, "cannot write() after write_eof()")); } data.with_ref(|buf| unsafe { @@ -3235,6 +3173,12 @@ mod _ssl { } } + /// Helper function to create SSL error with proper OSError subtype handling + fn new_ssl_error(vm: &VirtualMachine, msg: impl ToString) -> PyBaseExceptionRef { + vm.new_os_subtype_error(PySslError::class(&vm.ctx).to_owned(), None, msg.to_string()) + .upcast() + } + #[track_caller] pub(crate) fn convert_openssl_error( vm: &VirtualMachine, @@ -3255,12 +3199,7 @@ mod _ssl { } else { vm.ctx.exceptions.os_error.to_owned() }; - let exc = vm.new_exception(exc_type, vec![vm.ctx.new_int(reason).into()]); - // Set errno attribute explicitly - let _ = exc - .as_object() - .set_attr("errno", vm.ctx.new_int(reason), vm); - return exc; + return vm.new_os_subtype_error(exc_type, Some(reason), "").upcast(); } let caller = std::panic::Location::caller(); @@ -3310,13 +3249,8 @@ mod _ssl { // Create exception instance let reason = sys::ERR_GET_REASON(e.code()); - let exc = vm.new_exception( - cls, - vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()], - ); - - // Set attributes on instance, not class - let 
exc_obj: PyObjectRef = exc.into(); + let exc = vm.new_os_subtype_error(cls, Some(reason), msg); + let exc_obj: PyObjectRef = exc.upcast::().into(); // Set reason attribute (always set, even if just the error string) let reason_value = vm.ctx.new_str(errstr); @@ -3345,7 +3279,8 @@ mod _ssl { } None => { let cls = PySslError::class(&vm.ctx).to_owned(); - vm.new_exception_empty(cls) + vm.new_os_subtype_error(cls, None, "unknown SSL error") + .upcast() } } } @@ -3396,15 +3331,13 @@ mod _ssl { // this is an EOF in violation of protocol -> SSLEOFError // Need to set args[0] = SSL_ERROR_EOF for suppress_ragged_eofs check None => { - return vm.new_exception( - PySslEOFError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_EOF).into(), - vm.ctx - .new_str("EOF occurred in violation of protocol") - .into(), - ], - ); + return vm + .new_os_subtype_error( + PySslEOFError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_EOF as i32), + "EOF occurred in violation of protocol", + ) + .upcast(); } }, ssl::ErrorCode::SSL => { @@ -3417,15 +3350,13 @@ mod _ssl { let reason = sys::ERR_GET_REASON(err_code); let lib = sys::ERR_GET_LIB(err_code); if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING { - return vm.new_exception( - PySslEOFError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_EOF).into(), - vm.ctx - .new_str("EOF occurred in violation of protocol") - .into(), - ], - ); + return vm + .new_os_subtype_error( + PySslEOFError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_EOF as i32), + "EOF occurred in violation of protocol", + ) + .upcast(); } } return convert_openssl_error(vm, ssl_err.clone()); @@ -3440,7 +3371,7 @@ mod _ssl { "A failure in the SSL library occurred", ), }; - vm.new_exception_msg(cls, msg.to_owned()) + vm.new_os_subtype_error(cls, None, msg).upcast() } // SSL_FILETYPE_ASN1 part of _add_ca_certs in CPython @@ -3543,10 +3474,13 @@ mod _ssl { ) -> Result<(), PyBaseExceptionRef> { let root = Path::new(CERT_DIR); if 
!root.is_dir() { - return Err(vm.new_exception_msg( - vm.ctx.exceptions.file_not_found_error.to_owned(), - CERT_DIR.to_string(), - )); + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + None, + CERT_DIR.to_string(), + ) + .upcast()); } let mut combined_pem = String::new(); diff --git a/crates/stdlib/src/openssl/cert.rs b/crates/stdlib/src/openssl/cert.rs index 1139f0e26f0..1197bf4aa46 100644 --- a/crates/stdlib/src/openssl/cert.rs +++ b/crates/stdlib/src/openssl/cert.rs @@ -165,7 +165,8 @@ pub(crate) mod ssl_cert { format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]) } else if ip.len() == 16 { // IPv6 - format with all zeros visible (not compressed) - let ip_addr = std::net::Ipv6Addr::from(ip[0..16]); + let ip_addr = + std::net::Ipv6Addr::from(<[u8; 16]>::try_from(&ip[0..16]).unwrap()); let s = ip_addr.segments(); format!( "{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}", diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index e6f8c5fda7c..992b32e00ea 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -22,11 +22,14 @@ mod cert; // OpenSSL compatibility layer (abstracts rustls operations) mod compat; +// SSL exception types (shared with openssl backend) +mod error; + pub(crate) use _ssl::make_module; #[allow(non_snake_case)] #[allow(non_upper_case_globals)] -#[pymodule] +#[pymodule(with(error::ssl_error))] mod _ssl { use crate::{ common::{ @@ -37,15 +40,18 @@ mod _ssl { vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, - builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyType, PyTypeRef, - }, + builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef}, convert::IntoPyException, function::{ArgBytesLike, ArgMemoryBuffer, FuncArgs, OptionalArg, PyComparisonValue}, stdlib::warnings, types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }, }; + + // Import error types used in 
this module (others are exposed via pymodule(with(...))) + use super::error::{ + PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error, + }; use std::{ collections::HashMap, sync::{ @@ -342,106 +348,6 @@ mod _ssl { #[pyattr] const ENCODING_PEM_AUX: i32 = 0x101; // PEM + 0x100 - #[pyattr] - #[pyexception(name = "SSLError", base = PyOSError)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLError(PyOSError); - - #[pyexception] - impl PySSLError { - // Returns strerror attribute if available, otherwise str(args) - #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - // Try to get strerror attribute first (OSError compatibility) - if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) - && !vm.is_none(&strerror) - { - return strerror.str(vm); - } - - // Otherwise return str(args) - let args = exc.args(); - if args.len() == 1 { - args.as_slice()[0].str(vm) - } else { - args.as_object().str(vm) - } - } - } - - #[pyattr] - #[pyexception(name = "SSLZeroReturnError", base = PySSLError)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLZeroReturnError(PySSLError); - - #[pyexception] - impl PySSLZeroReturnError {} - - #[pyattr] - #[pyexception(name = "SSLWantReadError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLWantReadError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLWantWriteError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLWantWriteError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLSyscallError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLSyscallError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLEOFError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLEOFError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLCertVerificationError", base = PySSLError, impl)] - #[derive(Debug)] - 
#[repr(transparent)] - pub struct PySSLCertVerificationError(PySSLError); - - // Helper functions to create SSL exceptions with proper errno attribute - pub(super) fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLWantReadError::class(&vm.ctx).to_owned(), - Some(SSL_ERROR_WANT_READ), - "The operation did not complete (read)", - ) - } - - pub(super) fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLWantWriteError::class(&vm.ctx).to_owned(), - Some(SSL_ERROR_WANT_WRITE), - "The operation did not complete (write)", - ) - } - - pub(crate) fn create_ssl_eof_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLEOFError::class(&vm.ctx).to_owned(), - None, - "EOF occurred in violation of protocol", - ) - } - - pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLZeroReturnError::class(&vm.ctx).to_owned(), - None, - "TLS/SSL connection has been closed (EOF)", - ) - } - /// Validate server hostname for TLS SNI /// /// Checks that the hostname: diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 4ccc590360a..ab3c81b7a4e 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -30,10 +30,13 @@ use rustpython_vm::{AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObjec use std::io::Read; use std::sync::{Arc, Once}; -// Import PySSLSocket and helper functions from parent module -use super::_ssl::{ - PySSLCertVerificationError, PySSLError, PySSLSocket, create_ssl_eof_error, - create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, +// Import PySSLSocket from parent module +use super::_ssl::PySSLSocket; + +// Import error types and helper functions from error module +use super::error::{ + PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_want_read_error, + create_ssl_want_write_error, 
create_ssl_zero_return_error, }; // SSL Verification Flags diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs new file mode 100644 index 00000000000..e31683ec72d --- /dev/null +++ b/crates/stdlib/src/ssl/error.rs @@ -0,0 +1,117 @@ +// SSL exception types shared between ssl (rustls) and openssl backends + +pub(crate) use ssl_error::*; + +#[pymodule(sub)] +pub(crate) mod ssl_error { + use crate::vm::{ + PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyBaseExceptionRef, PyOSError, PyStrRef}, + types::Constructor, + }; + + // Error type constants (needed for create_ssl_want_read_error etc.) + pub(crate) const SSL_ERROR_WANT_READ: i32 = 2; + pub(crate) const SSL_ERROR_WANT_WRITE: i32 = 3; + + #[pyattr] + #[pyexception(name = "SSLError", base = PyOSError)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLError(PyOSError); + + #[pyexception] + impl PySSLError { + // Returns strerror attribute if available, otherwise str(args) + #[pymethod] + fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { + use crate::vm::AsObject; + // Try to get strerror attribute first (OSError compatibility) + if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) + && !vm.is_none(&strerror) + { + return strerror.str(vm); + } + + // Otherwise return str(args) + let args = exc.args(); + if args.len() == 1 { + args.as_slice()[0].str(vm) + } else { + args.as_object().str(vm) + } + } + } + + #[pyattr] + #[pyexception(name = "SSLZeroReturnError", base = PySSLError)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLZeroReturnError(PySSLError); + + #[pyexception] + impl PySSLZeroReturnError {} + + #[pyattr] + #[pyexception(name = "SSLWantReadError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLWantReadError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLWantWriteError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct 
PySSLWantWriteError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLSyscallError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLSyscallError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLEOFError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLEOFError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLCertVerificationError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLCertVerificationError(PySSLError); + + // Helper functions to create SSL exceptions with proper errno attribute + pub fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLWantReadError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_WANT_READ), + "The operation did not complete (read)", + ) + } + + pub fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLWantWriteError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_WANT_WRITE), + "The operation did not complete (write)", + ) + } + + pub fn create_ssl_eof_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLEOFError::class(&vm.ctx).to_owned(), + None, + "EOF occurred in violation of protocol", + ) + } + + pub fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLZeroReturnError::class(&vm.ctx).to_owned(), + None, + "TLS/SSL connection has been closed (EOF)", + ) + } +} From 70b93898d4e4dfe5b9badc29618a0fbbc6542787 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Fri, 19 Dec 2025 14:10:55 +0900 Subject: [PATCH 012/418] ctypes overhaul (#6450) --- .cspell.dict/python-more.txt | 2 + .cspell.json | 5 +- crates/stdlib/src/pystruct.rs | 2 +- crates/vm/src/buffer.rs | 50 + crates/vm/src/builtins/builtin_func.rs | 4 +- crates/vm/src/builtins/function.rs | 5 +- crates/vm/src/builtins/object.rs | 4 +- crates/vm/src/builtins/str.rs | 6 
+- crates/vm/src/builtins/type.rs | 4 +- crates/vm/src/exceptions.rs | 11 +- crates/vm/src/object/core.rs | 22 + crates/vm/src/protocol/buffer.rs | 18 +- crates/vm/src/stdlib/ast/python.rs | 4 +- crates/vm/src/stdlib/codecs.rs | 4 +- crates/vm/src/stdlib/ctypes.rs | 1351 ++++++---- crates/vm/src/stdlib/ctypes/array.rs | 1658 +++++++----- crates/vm/src/stdlib/ctypes/base.rs | 3068 +++++++++++++++------- crates/vm/src/stdlib/ctypes/function.rs | 1986 ++++++++++++-- crates/vm/src/stdlib/ctypes/library.rs | 34 +- crates/vm/src/stdlib/ctypes/pointer.rs | 792 ++++-- crates/vm/src/stdlib/ctypes/simple.rs | 1379 ++++++++++ crates/vm/src/stdlib/ctypes/structure.rs | 747 +++--- crates/vm/src/stdlib/ctypes/thunk.rs | 319 --- crates/vm/src/stdlib/ctypes/union.rs | 665 +++-- crates/vm/src/stdlib/ctypes/util.rs | 88 - crates/vm/src/stdlib/functools.rs | 4 +- crates/vm/src/stdlib/operator.rs | 2 +- crates/vm/src/types/structseq.rs | 3 +- 28 files changed, 8879 insertions(+), 3358 deletions(-) create mode 100644 crates/vm/src/stdlib/ctypes/simple.rs delete mode 100644 crates/vm/src/stdlib/ctypes/thunk.rs delete mode 100644 crates/vm/src/stdlib/ctypes/util.rs diff --git a/.cspell.dict/python-more.txt b/.cspell.dict/python-more.txt index 58a0e816087..e8534e9744a 100644 --- a/.cspell.dict/python-more.txt +++ b/.cspell.dict/python-more.txt @@ -148,6 +148,7 @@ nbytes ncallbacks ndigits ndim +needsfree nldecoder nlocals NOARGS @@ -168,6 +169,7 @@ pycache pycodecs pycs pyexpat +PYTHONAPI PYTHONBREAKPOINT PYTHONDEBUG PYTHONDONTWRITEBYTECODE diff --git a/.cspell.json b/.cspell.json index 9f88a74f96d..3bd06fc2032 100644 --- a/.cspell.json +++ b/.cspell.json @@ -75,9 +75,9 @@ "makeunicodedata", "miri", "notrace", + "oparg", "openat", "pyarg", - "pyarg", "pyargs", "pyast", "PyAttr", @@ -107,6 +107,7 @@ "pystruct", "pystructseq", "pytrace", + "pytype", "reducelib", "richcompare", "RustPython", @@ -116,7 +117,6 @@ "sysmodule", "tracebacks", "typealiases", - "unconstructible", "unhashable", 
"uninit", "unraisable", @@ -131,6 +131,7 @@ "getrusage", "nanosleep", "sigaction", + "sighandler", "WRLCK", // win32 "birthtime", diff --git a/crates/stdlib/src/pystruct.rs b/crates/stdlib/src/pystruct.rs index 0a006f5a0f2..34a4905ed9f 100644 --- a/crates/stdlib/src/pystruct.rs +++ b/crates/stdlib/src/pystruct.rs @@ -28,7 +28,7 @@ pub(crate) mod _struct { // CPython turns str to bytes but we do reversed way here // The only performance difference is this transition cost let fmt = match_class!(match obj { - s @ PyStr => s.is_ascii().then_some(s), + s @ PyStr => s.isascii().then_some(s), b @ PyBytes => ascii::AsciiStr::from_ascii(&b) .ok() .map(|s| vm.ctx.new_str(s)), diff --git a/crates/vm/src/buffer.rs b/crates/vm/src/buffer.rs index cf49d6815c0..eeb6a676542 100644 --- a/crates/vm/src/buffer.rs +++ b/crates/vm/src/buffer.rs @@ -261,6 +261,56 @@ impl FormatCode { return Err("embedded null character".to_owned()); } + // PEP3118: Handle extended format specifiers + // T{...} - struct, X{} - function pointer, (...) 
- array shape, :name: - field name + if c == b'T' || c == b'X' { + // Skip struct/function pointer: consume until matching '}' + if chars.peek() == Some(&b'{') { + chars.next(); // consume '{' + let mut depth = 1; + while depth > 0 { + match chars.next() { + Some(b'{') => depth += 1, + Some(b'}') => depth -= 1, + None => return Err("unmatched '{' in format".to_owned()), + _ => {} + } + } + continue; + } + } + + if c == b'(' { + // Skip array shape: consume until matching ')' + let mut depth = 1; + while depth > 0 { + match chars.next() { + Some(b'(') => depth += 1, + Some(b')') => depth -= 1, + None => return Err("unmatched '(' in format".to_owned()), + _ => {} + } + } + continue; + } + + if c == b':' { + // Skip field name: consume until next ':' + loop { + match chars.next() { + Some(b':') => break, + None => return Err("unmatched ':' in format".to_owned()), + _ => {} + } + } + continue; + } + + if c == b'{' || c == b'}' { + // Skip standalone braces (pointer targets, etc.) + continue; + } + let code = FormatType::try_from(c) .ok() .filter(|c| match c { diff --git a/crates/vm/src/builtins/builtin_func.rs b/crates/vm/src/builtins/builtin_func.rs index da5fd5e8075..2b569375b28 100644 --- a/crates/vm/src/builtins/builtin_func.rs +++ b/crates/vm/src/builtins/builtin_func.rs @@ -114,7 +114,7 @@ impl PyNativeFunction { zelf.0.value.doc } - #[pygetset(name = "__self__")] + #[pygetset] fn __self__(_zelf: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { vm.ctx.none() } @@ -181,7 +181,7 @@ impl PyNativeMethod { Ok((getattr, (target, name))) } - #[pygetset(name = "__self__")] + #[pygetset] fn __self__(zelf: PyRef, _vm: &VirtualMachine) -> Option { zelf.func.zelf.clone() } diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 0459cecbdd2..c29e45ddcf6 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -629,12 +629,11 @@ impl GetDescriptor for PyFunction { vm: &VirtualMachine, ) -> PyResult { let 
(_zelf, obj) = Self::_unwrap(&zelf, obj, vm)?; - let obj = if vm.is_none(&obj) && !Self::_cls_is(&cls, obj.class()) { + Ok(if vm.is_none(&obj) && !Self::_cls_is(&cls, obj.class()) { zelf } else { PyBoundMethod::new(obj, zelf).into_ref(&vm.ctx).into() - }; - Ok(obj) + }) } } diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index 6f917cd853c..cb95652f937 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -450,8 +450,8 @@ impl PyBaseObject { Ok(()) } - #[pygetset(name = "__class__")] - fn get_class(obj: PyObjectRef) -> PyTypeRef { + #[pygetset] + fn __class__(obj: PyObjectRef) -> PyTypeRef { obj.class().to_owned() } diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 279b84362a6..8084c4d053e 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -625,9 +625,9 @@ impl PyStr { self.data.char_len() } - #[pymethod(name = "isascii")] + #[pymethod] #[inline(always)] - pub const fn is_ascii(&self) -> bool { + pub const fn isascii(&self) -> bool { matches!(self.kind(), StrKind::Ascii) } @@ -960,7 +960,7 @@ impl PyStr { format_map(&format_string, &mapping, vm) } - #[pymethod(name = "__format__")] + #[pymethod] fn __format__( zelf: PyRef, spec: PyStrRef, diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 15743350397..68de17f60b6 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -1445,8 +1445,8 @@ impl GetAttr for PyType { #[pyclass] impl Py { - #[pygetset(name = "__mro__")] - fn get_mro(&self) -> PyTuple { + #[pygetset] + fn __mro__(&self) -> PyTuple { let elements: Vec = self.mro_map_collect(|x| x.as_object().to_owned()); PyTuple::new_unchecked(elements.into_boxed_slice()) } diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 036d914810d..bb10ca02c2c 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -624,8 +624,8 @@ impl PyBaseException 
{ *self.context.write() = context; } - #[pygetset(name = "__suppress_context__")] - pub(super) fn get_suppress_context(&self) -> bool { + #[pygetset] + pub(super) fn __suppress_context__(&self) -> bool { self.suppress_context.load() } @@ -1112,7 +1112,7 @@ impl serde::Serialize for SerializeException<'_, '_> { .__context__() .map(|exc| SerializeExceptionOwned { vm: self.vm, exc }), )?; - struc.serialize_field("suppress_context", &self.exc.get_suppress_context())?; + struc.serialize_field("suppress_context", &self.exc.__suppress_context__())?; let args = { struct Args<'vm>(&'vm VirtualMachine, PyTupleRef); @@ -1550,6 +1550,7 @@ pub(super) mod types { pub struct PyUnboundLocalError(PyNameError); #[pyexception(name, base = PyException, ctx = "os_error")] + #[repr(C)] pub struct PyOSError { base: PyException, errno: PyAtomicRef>, @@ -1857,8 +1858,8 @@ pub(super) mod types { self.errno.swap_to_temporary_refs(value, vm); } - #[pygetset(name = "strerror")] - fn get_strerror(&self) -> Option { + #[pygetset] + fn strerror(&self) -> Option { self.strerror.to_owned() } diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index e04b87de594..60b623ef3ed 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -1102,6 +1102,28 @@ where } } +impl Py { + /// Converts `&Py` to `&Py`. + #[inline] + pub fn to_base(&self) -> &Py { + debug_assert!(self.as_object().downcast_ref::().is_some()); + // SAFETY: T is #[repr(transparent)] over T::Base, + // so Py and Py have the same layout. + unsafe { &*(self as *const Py as *const Py) } + } + + /// Converts `&Py` to `&Py` where U is an ancestor type. + #[inline] + pub fn upcast_ref(&self) -> &Py + where + T: StaticType, + { + debug_assert!(T::static_type().is_subtype(U::static_type())); + // SAFETY: T is a subtype of U, so Py can be viewed as Py. 
+ unsafe { &*(self as *const Py as *const Py) } + } +} + impl Borrow for PyRef where T: PyPayload, diff --git a/crates/vm/src/protocol/buffer.rs b/crates/vm/src/protocol/buffer.rs index 1dafda203d9..948ec763dc6 100644 --- a/crates/vm/src/protocol/buffer.rs +++ b/crates/vm/src/protocol/buffer.rs @@ -202,14 +202,18 @@ impl BufferDescriptor { #[cfg(debug_assertions)] pub fn validate(self) -> Self { assert!(self.itemsize != 0); - assert!(self.ndim() != 0); - let mut shape_product = 1; - for (shape, stride, suboffset) in self.dim_desc.iter().cloned() { - shape_product *= shape; - assert!(suboffset >= 0); - assert!(stride != 0); + // ndim=0 is valid for scalar types (e.g., ctypes Structure) + if self.ndim() == 0 { + assert!(self.itemsize == self.len); + } else { + let mut shape_product = 1; + for (shape, stride, suboffset) in self.dim_desc.iter().cloned() { + shape_product *= shape; + assert!(suboffset >= 0); + assert!(stride != 0); + } + assert!(shape_product * self.itemsize == self.len); } - assert!(shape_product * self.itemsize == self.len); self } diff --git a/crates/vm/src/stdlib/ast/python.rs b/crates/vm/src/stdlib/ast/python.rs index 042db4aa74e..aa21d8b034a 100644 --- a/crates/vm/src/stdlib/ast/python.rs +++ b/crates/vm/src/stdlib/ast/python.rs @@ -47,8 +47,8 @@ pub(crate) mod _ast { Ok(()) } - #[pyattr(name = "_fields")] - fn fields(ctx: &Context) -> PyTupleRef { + #[pyattr] + fn _fields(ctx: &Context) -> PyTupleRef { ctx.empty_tuple.clone() } } diff --git a/crates/vm/src/stdlib/codecs.rs b/crates/vm/src/stdlib/codecs.rs index 5f1b721dfb4..821b313090c 100644 --- a/crates/vm/src/stdlib/codecs.rs +++ b/crates/vm/src/stdlib/codecs.rs @@ -176,7 +176,7 @@ mod _codecs { #[pyfunction] fn latin_1_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { - if args.s.is_ascii() { + if args.s.isascii() { return Ok((args.s.as_bytes().to_vec(), args.s.byte_len())); } do_codec!(latin_1::encode, args, vm) @@ -189,7 +189,7 @@ mod _codecs { #[pyfunction] fn 
ascii_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { - if args.s.is_ascii() { + if args.s.isascii() { return Ok((args.s.as_bytes().to_vec(), args.s.byte_len())); } do_codec!(ascii::encode, args, vm) diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/ctypes.rs index ebe2d16ffb2..3fdb2df6104 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/ctypes.rs @@ -1,77 +1,372 @@ // spell-checker:disable -pub(crate) mod array; -pub(crate) mod base; -pub(crate) mod field; -pub(crate) mod function; -pub(crate) mod library; -pub(crate) mod pointer; -pub(crate) mod structure; -pub(crate) mod thunk; -pub(crate) mod union; -pub(crate) mod util; - -use crate::builtins::PyModule; -use crate::class::PyClassImpl; -use crate::{Py, PyRef, VirtualMachine}; - -pub use crate::stdlib::ctypes::base::{CDataObject, PyCData, PyCSimple, PyCSimpleType}; - -pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { +mod array; +mod base; +mod function; +mod library; +mod pointer; +mod simple; +mod structure; +mod union; + +use crate::{ + AsObject, Py, PyObjectRef, PyRef, PyResult, VirtualMachine, + builtins::{PyModule, PyStr, PyType}, + class::PyClassImpl, + types::TypeDataRef, +}; +use std::ffi::{ + c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, + c_ulonglong, c_ushort, +}; +use std::mem; +use widestring::WideChar; + +pub use array::PyCArray; +pub use base::{FfiArgValue, PyCData, PyCField, StgInfo, StgInfoFlags}; +pub use pointer::PyCPointer; +pub use simple::{PyCSimple, PyCSimpleType}; +pub use structure::PyCStructure; +pub use union::PyCUnion; + +/// Extension for PyType to get StgInfo +/// PyStgInfo_FromType +impl Py { + /// Get StgInfo from a ctypes type object + /// + /// Returns a TypeDataRef to StgInfo if the type has one and is initialized, error otherwise. + /// Abstract classes (whose metaclass __init__ was not called) will have uninitialized StgInfo. 
+ fn stg_info<'a>(&'a self, vm: &VirtualMachine) -> PyResult> { + self.stg_info_opt() + .ok_or_else(|| vm.new_type_error("abstract class")) + } + + /// Get StgInfo if initialized, None otherwise. + fn stg_info_opt(&self) -> Option> { + self.get_type_data::() + .filter(|info| info.initialized) + } + + /// Get _type_ attribute as String (type code like "i", "d", etc.) + fn type_code(&self, vm: &VirtualMachine) -> Option { + self.as_object() + .get_attr("_type_", vm) + .ok() + .and_then(|t: PyObjectRef| t.downcast_ref::().map(|s| s.to_string())) + } + + /// Mark all base classes as finalized + fn mark_bases_final(&self) { + for base in self.bases.read().iter() { + if let Some(mut stg) = base.get_type_data_mut::() { + stg.flags |= StgInfoFlags::DICTFLAG_FINAL; + } else { + let mut stg = StgInfo::default(); + stg.flags |= StgInfoFlags::DICTFLAG_FINAL; + let _ = base.init_type_data(stg); + } + } + } +} + +impl PyType { + /// Check if StgInfo is already initialized - prevent double initialization + pub(crate) fn check_not_initialized(&self, vm: &VirtualMachine) -> PyResult<()> { + if let Some(stg_info) = self.get_type_data::() + && stg_info.initialized + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.system_error.to_owned(), + format!("StgInfo of '{}' is already initialized.", self.name()), + )); + } + Ok(()) + } +} + +// Dynamic type check helpers for PyCData +// These check if an object's type's metaclass is a subclass of a specific metaclass + +pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { + let module = _ctypes::make_module(vm); let ctx = &vm.ctx; PyCSimpleType::make_class(ctx); array::PyCArrayType::make_class(ctx); - field::PyCFieldType::make_class(ctx); pointer::PyCPointerType::make_class(ctx); structure::PyCStructType::make_class(ctx); union::PyCUnionType::make_class(ctx); - extend_module!(vm, module, { + extend_module!(vm, &module, { "_CData" => PyCData::make_class(ctx), "_SimpleCData" => PyCSimple::make_class(ctx), - "Array" => 
array::PyCArray::make_class(ctx), - "CField" => field::PyCField::make_class(ctx), + "Array" => PyCArray::make_class(ctx), + "CField" => PyCField::make_class(ctx), "CFuncPtr" => function::PyCFuncPtr::make_class(ctx), - "_Pointer" => pointer::PyCPointer::make_class(ctx), + "_Pointer" => PyCPointer::make_class(ctx), "_pointer_type_cache" => ctx.new_dict(), - "Structure" => structure::PyCStructure::make_class(ctx), - "CThunkObject" => thunk::PyCThunk::make_class(ctx), - "Union" => union::PyCUnion::make_class(ctx), - }) + "_array_type_cache" => ctx.new_dict(), + "Structure" => PyCStructure::make_class(ctx), + "CThunkObject" => function::PyCThunk::make_class(ctx), + "Union" => PyCUnion::make_class(ctx), + }); + module } -pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { - let module = _ctypes::make_module(vm); - extend_module_nodes(vm, &module); - module +/// Size of long double - platform dependent +/// x86_64 macOS/Linux: 16 bytes (80-bit extended + padding) +/// ARM64: 16 bytes (128-bit) +/// Windows: 8 bytes (same as double) +#[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") +))] +const LONG_DOUBLE_SIZE: usize = 16; + +#[cfg(target_os = "windows")] +const LONG_DOUBLE_SIZE: usize = mem::size_of::(); + +#[cfg(not(any( + all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ), + target_os = "windows" +)))] +const LONG_DOUBLE_SIZE: usize = mem::size_of::(); + +/// Type information for ctypes simple types +struct TypeInfo { + pub size: usize, + pub ffi_type_fn: fn() -> libffi::middle::Type, +} + +/// Get type information (size and ffi_type) for a ctypes type code +fn type_info(ty: &str) -> Option { + use libffi::middle::Type; + match ty { + "c" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::u8, + }), + "u" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: if mem::size_of::() == 2 { + Type::u16 + } else { + Type::u32 + }, + }), + "b" => Some(TypeInfo { + 
size: mem::size_of::(), + ffi_type_fn: Type::i8, + }), + "B" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::u8, + }), + "h" | "v" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::i16, + }), + "H" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::u16, + }), + "i" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::i32, + }), + "I" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::u32, + }), + "l" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: if mem::size_of::() == 8 { + Type::i64 + } else { + Type::i32 + }, + }), + "L" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: if mem::size_of::() == 8 { + Type::u64 + } else { + Type::u32 + }, + }), + "q" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::i64, + }), + "Q" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::u64, + }), + "f" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::f32, + }), + "d" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::f64, + }), + "g" => Some(TypeInfo { + // long double - platform dependent size + // x86_64 macOS/Linux: 16 bytes (80-bit extended + padding) + // ARM64: 16 bytes (128-bit) + // Windows: 8 bytes (same as double) + // Note: Use f64 as FFI type since Rust doesn't support long double natively + size: LONG_DOUBLE_SIZE, + ffi_type_fn: Type::f64, + }), + "?" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::u8, + }), + "z" | "Z" | "P" | "X" | "O" => Some(TypeInfo { + size: mem::size_of::(), + ffi_type_fn: Type::pointer, + }), + "void" => Some(TypeInfo { + size: 0, + ffi_type_fn: Type::void, + }), + _ => None, + } +} + +/// Get size for a ctypes type code +fn get_size(ty: &str) -> usize { + type_info(ty).map(|t| t.size).expect("invalid type code") +} + +/// Get alignment for simple type codes from type_info(). +/// For primitive C types (c_int, c_long, etc.), alignment equals size. 
+fn get_align(ty: &str) -> usize { + get_size(ty) } #[pymodule] pub(crate) mod _ctypes { - use super::base::{CDataObject, PyCData, PyCSimple}; - use crate::builtins::PyTypeRef; + use super::library; + use super::{PyCArray, PyCData, PyCPointer, PyCSimple, PyCStructure, PyCUnion}; + use crate::builtins::{PyType, PyTypeRef}; use crate::class::StaticType; use crate::convert::ToPyObject; - use crate::function::{Either, FuncArgs, OptionalArg}; - use crate::stdlib::ctypes::library; - use crate::{AsObject, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; - use crossbeam_utils::atomic::AtomicCell; - use std::ffi::{ - c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, - c_ulonglong, c_ushort, - }; - use std::mem; - use widestring::WideChar; - - /// CArgObject - returned by byref() + use crate::function::{Either, OptionalArg}; + use crate::types::Representable; + use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; + use num_traits::ToPrimitive; + + /// CArgObject - returned by byref() and paramfunc + /// tagPyCArgObject #[pyclass(name = "CArgObject", module = "_ctypes", no_attr)] #[derive(Debug, PyPayload)] pub struct CArgObject { + /// Type tag ('P', 'V', 'i', 'd', etc.) 
+ pub tag: u8, + /// The actual FFI value (mirrors union value) + pub value: super::FfiArgValue, + /// Reference to original object (for memory safety) pub obj: PyObjectRef, + /// Size for struct/union ('V' tag) #[allow(dead_code)] + pub size: usize, + /// Offset for byref() pub offset: isize, } - #[pyclass] + /// is_literal_char - check if character is printable literal (not \\ or ') + fn is_literal_char(c: u8) -> bool { + c < 128 && c.is_ascii_graphic() && c != b'\\' && c != b'\'' + } + + impl Representable for CArgObject { + // PyCArg_repr - use tag and value fields directly + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + use super::base::FfiArgValue; + + let tag_char = zelf.tag as char; + + // Format value based on tag + match zelf.tag { + b'b' | b'h' | b'i' | b'l' | b'q' => { + // Signed integers + let n = match zelf.value { + FfiArgValue::I8(v) => v as i64, + FfiArgValue::I16(v) => v as i64, + FfiArgValue::I32(v) => v as i64, + FfiArgValue::I64(v) => v, + _ => 0, + }; + Ok(format!("", tag_char, n)) + } + b'B' | b'H' | b'I' | b'L' | b'Q' => { + // Unsigned integers + let n = match zelf.value { + FfiArgValue::U8(v) => v as u64, + FfiArgValue::U16(v) => v as u64, + FfiArgValue::U32(v) => v as u64, + FfiArgValue::U64(v) => v, + _ => 0, + }; + Ok(format!("", tag_char, n)) + } + b'f' => { + let v = match zelf.value { + FfiArgValue::F32(v) => v as f64, + _ => 0.0, + }; + Ok(format!("", tag_char, v)) + } + b'd' | b'g' => { + let v = match zelf.value { + FfiArgValue::F64(v) => v, + FfiArgValue::F32(v) => v as f64, + _ => 0.0, + }; + Ok(format!("", tag_char, v)) + } + b'c' => { + // c_char - single byte + let byte = match zelf.value { + FfiArgValue::I8(v) => v as u8, + FfiArgValue::U8(v) => v, + _ => 0, + }; + if is_literal_char(byte) { + Ok(format!("", tag_char, byte as char)) + } else { + Ok(format!("", tag_char, byte)) + } + } + b'z' | b'Z' | b'P' | b'V' => { + // Pointer types + let ptr = match zelf.value { + FfiArgValue::Pointer(v) => v, + _ => 0, + 
}; + if ptr == 0 { + Ok(format!("", tag_char)) + } else { + Ok(format!("", tag_char, ptr)) + } + } + _ => { + // Default fallback + let addr = zelf.get_id(); + if is_literal_char(zelf.tag) { + Ok(format!("", tag_char, addr)) + } else { + Ok(format!("", zelf.tag, addr)) + } + } + } + } + } + + #[pyclass(with(Representable))] impl CArgObject { #[pygetset] fn _obj(&self) -> PyObjectRef { @@ -83,43 +378,43 @@ pub(crate) mod _ctypes { const __VERSION__: &str = "1.1.0"; // TODO: get properly - #[pyattr(name = "RTLD_LOCAL")] + #[pyattr] const RTLD_LOCAL: i32 = 0; // TODO: get properly - #[pyattr(name = "RTLD_GLOBAL")] + #[pyattr] const RTLD_GLOBAL: i32 = 0; #[cfg(target_os = "windows")] - #[pyattr(name = "SIZEOF_TIME_T")] - pub const SIZEOF_TIME_T: usize = 8; + #[pyattr] + const SIZEOF_TIME_T: usize = 8; #[cfg(not(target_os = "windows"))] - #[pyattr(name = "SIZEOF_TIME_T")] - pub const SIZEOF_TIME_T: usize = 4; + #[pyattr] + const SIZEOF_TIME_T: usize = 4; - #[pyattr(name = "CTYPES_MAX_ARGCOUNT")] - pub const CTYPES_MAX_ARGCOUNT: usize = 1024; + #[pyattr] + const CTYPES_MAX_ARGCOUNT: usize = 1024; #[pyattr] - pub const FUNCFLAG_STDCALL: u32 = 0x0; + const FUNCFLAG_STDCALL: u32 = 0x0; #[pyattr] - pub const FUNCFLAG_CDECL: u32 = 0x1; + const FUNCFLAG_CDECL: u32 = 0x1; #[pyattr] - pub const FUNCFLAG_HRESULT: u32 = 0x2; + const FUNCFLAG_HRESULT: u32 = 0x2; #[pyattr] - pub const FUNCFLAG_PYTHONAPI: u32 = 0x4; + const FUNCFLAG_PYTHONAPI: u32 = 0x4; #[pyattr] - pub const FUNCFLAG_USE_ERRNO: u32 = 0x8; + const FUNCFLAG_USE_ERRNO: u32 = 0x8; #[pyattr] - pub const FUNCFLAG_USE_LASTERROR: u32 = 0x10; + const FUNCFLAG_USE_LASTERROR: u32 = 0x10; #[pyattr] - pub const TYPEFLAG_ISPOINTER: u32 = 0x100; + const TYPEFLAG_ISPOINTER: u32 = 0x100; #[pyattr] - pub const TYPEFLAG_HASPOINTER: u32 = 0x200; + const TYPEFLAG_HASPOINTER: u32 = 0x200; #[pyattr] - pub const DICTFLAG_FINAL: u32 = 0x1000; + const DICTFLAG_FINAL: u32 = 0x1000; #[pyattr(name = "ArgumentError", once)] fn argument_error(vm: 
&VirtualMachine) -> PyTypeRef { @@ -130,369 +425,138 @@ pub(crate) mod _ctypes { ) } - #[pyattr(name = "FormatError", once)] - fn format_error(vm: &VirtualMachine) -> PyTypeRef { - vm.ctx.new_exception_type( - "_ctypes", - "FormatError", - Some(vec![vm.ctx.exceptions.exception_type.to_owned()]), - ) - } - - pub fn get_size(ty: &str) -> usize { - match ty { - "u" => mem::size_of::(), - "c" | "b" => mem::size_of::(), - "h" => mem::size_of::(), - "H" => mem::size_of::(), - "i" => mem::size_of::(), - "I" => mem::size_of::(), - "l" => mem::size_of::(), - "q" => mem::size_of::(), - "L" => mem::size_of::(), - "Q" => mem::size_of::(), - "f" => mem::size_of::(), - "d" | "g" => mem::size_of::(), - "?" | "B" => mem::size_of::(), - "P" | "z" | "Z" => mem::size_of::(), - "O" => mem::size_of::(), - _ => unreachable!(), - } - } - - /// Get alignment for a simple type - for C types, alignment equals size - pub fn get_align(ty: &str) -> usize { - get_size(ty) - } - - /// Get the size of a ctypes type from its type object - #[allow(dead_code)] - pub fn get_size_from_type(cls: &PyTypeRef, vm: &VirtualMachine) -> PyResult { - // Try to get _type_ attribute for simple types - if let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) - && let Ok(s) = type_attr.str(vm) - { - let s = s.to_string(); - if s.len() == 1 && SIMPLE_TYPE_CHARS.contains(s.as_str()) { - return Ok(get_size(&s)); - } + #[cfg(target_os = "windows")] + #[pyattr(name = "COMError", once)] + fn com_error(vm: &VirtualMachine) -> PyTypeRef { + use crate::builtins::type_::PyAttributes; + use crate::function::FuncArgs; + use crate::types::{PyTypeFlags, PyTypeSlots}; + + // Sets hresult, text, details as instance attributes in __init__ + // This function has InitFunc signature for direct slots.init use + fn comerror_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + let (hresult, text, details): ( + Option, + Option, + Option, + ) = args.bind(vm)?; + let hresult = hresult.unwrap_or_else(|| 
vm.ctx.none()); + let text = text.unwrap_or_else(|| vm.ctx.none()); + let details = details.unwrap_or_else(|| vm.ctx.none()); + + // Set instance attributes + zelf.set_attr("hresult", hresult.clone(), vm)?; + zelf.set_attr("text", text.clone(), vm)?; + zelf.set_attr("details", details.clone(), vm)?; + + // self.args = args[1:] = (text, details) + // via: PyObject_SetAttrString(self, "args", PySequence_GetSlice(args, 1, size)) + let args_tuple: PyObjectRef = vm.ctx.new_tuple(vec![text, details]).into(); + zelf.set_attr("args", args_tuple, vm)?; + + Ok(()) } - // Fall back to sizeof - size_of(cls.clone().into(), vm) - } - /// Convert bytes to appropriate Python object based on ctypes type - pub fn bytes_to_pyobject( - cls: &PyTypeRef, - bytes: &[u8], - vm: &VirtualMachine, - ) -> PyResult { - // Try to get _type_ attribute - if let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) - && let Ok(s) = type_attr.str(vm) - { - let ty = s.to_string(); - return match ty.as_str() { - "c" => { - // c_char - single byte - Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) - } - "b" => { - // c_byte - signed char - let val = if !bytes.is_empty() { bytes[0] as i8 } else { 0 }; - Ok(vm.ctx.new_int(val).into()) - } - "B" => { - // c_ubyte - unsigned char - let val = if !bytes.is_empty() { bytes[0] } else { 0 }; - Ok(vm.ctx.new_int(val).into()) - } - "h" => { - // c_short - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "H" => { - // c_ushort - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_ushort::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "i" => { - // c_int - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_int::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { 
- 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "I" => { - // c_uint - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_uint::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "l" => { - // c_long - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_long::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "L" => { - // c_ulong - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_ulong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "q" => { - // c_longlong - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_longlong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "Q" => { - // c_ulonglong - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_ulonglong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "f" => { - // c_float - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_float::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0.0 - }; - Ok(vm.ctx.new_float(val as f64).into()) - } - "d" | "g" => { - // c_double - const SIZE: usize = mem::size_of::(); - let val = if bytes.len() >= SIZE { - c_double::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) - } else { - 0.0 - }; - Ok(vm.ctx.new_float(val).into()) - } - "?" 
=> { - // c_bool - let val = !bytes.is_empty() && bytes[0] != 0; - Ok(vm.ctx.new_bool(val).into()) - } - "P" | "z" | "Z" => { - // Pointer types - return as integer address - let val = if bytes.len() >= mem::size_of::() { - const UINTPTR_LEN: usize = mem::size_of::(); - let mut arr = [0u8; UINTPTR_LEN]; - arr[..bytes.len().min(UINTPTR_LEN)] - .copy_from_slice(&bytes[..bytes.len().min(UINTPTR_LEN)]); - usize::from_ne_bytes(arr) - } else { - 0 - }; - Ok(vm.ctx.new_int(val).into()) - } - "u" => { - // c_wchar - wide character - let val = if bytes.len() >= mem::size_of::() { - let wc = if mem::size_of::() == 2 { - u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 - } else { - u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) - }; - char::from_u32(wc).unwrap_or('\0') - } else { - '\0' - }; - Ok(vm.ctx.new_str(val.to_string()).into()) - } - _ => Ok(vm.ctx.none()), - }; - } - // Default: return bytes as-is - Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) - } + // Create exception type with IMMUTABLETYPE flag + let mut attrs = PyAttributes::default(); + attrs.insert( + vm.ctx.intern_str("__module__"), + vm.ctx.new_str("_ctypes").into(), + ); + attrs.insert( + vm.ctx.intern_str("__doc__"), + vm.ctx + .new_str("Raised when a COM method call failed.") + .into(), + ); + + // Create slots with IMMUTABLETYPE flag + let slots = PyTypeSlots { + name: "COMError", + flags: PyTypeFlags::heap_type_flags() + | PyTypeFlags::HAS_DICT + | PyTypeFlags::IMMUTABLETYPE, + ..PyTypeSlots::default() + }; - const SIMPLE_TYPE_CHARS: &str = "cbBhHiIlLdfguzZPqQ?O"; + let exc_type = PyType::new_heap( + "COMError", + vec![vm.ctx.exceptions.exception_type.to_owned()], + attrs, + slots, + vm.ctx.types.type_type.to_owned(), + &vm.ctx, + ) + .unwrap(); - pub fn new_simple_type( - cls: Either<&PyObject, &PyTypeRef>, - vm: &VirtualMachine, - ) -> PyResult { - let cls = match cls { - Either::A(obj) => obj, - Either::B(typ) => typ.as_object(), - }; + // Set our custom init after new_heap, which runs 
init_slots that would + // otherwise overwrite slots.init with init_wrapper (due to __init__ in MRO). + exc_type.slots.init.store(Some(comerror_init)); - if let Ok(_type_) = cls.get_attr("_type_", vm) { - if _type_.is_instance((&vm.ctx.types.str_type).as_ref(), vm)? { - let tp_str = _type_.str(vm)?.to_string(); - - if tp_str.len() != 1 { - Err(vm.new_value_error( - format!("class must define a '_type_' attribute which must be a string of length 1, str: {tp_str}"), - )) - } else if !SIMPLE_TYPE_CHARS.contains(tp_str.as_str()) { - Err(vm.new_attribute_error(format!("class must define a '_type_' attribute which must be\n a single character string containing one of {SIMPLE_TYPE_CHARS}, currently it is {tp_str}."))) - } else { - let size = get_size(&tp_str); - let cdata = CDataObject::from_bytes(vec![0u8; size], None); - Ok(PyCSimple { - _base: PyCData::new(cdata.clone()), - _type_: tp_str, - value: AtomicCell::new(vm.ctx.none()), - cdata: rustpython_common::lock::PyRwLock::new(cdata), - }) - } - } else { - Err(vm.new_type_error("class must define a '_type_' string attribute")) - } - } else { - Err(vm.new_attribute_error("class must define a '_type_' attribute")) - } + exc_type } /// Get the size of a ctypes type or instance - #[pyfunction(name = "sizeof")] - pub fn size_of(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - use super::pointer::PyCPointer; - use super::structure::{PyCStructType, PyCStructure}; + #[pyfunction] + pub fn sizeof(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + use super::structure::PyCStructType; use super::union::PyCUnionType; - use super::util::StgInfo; - use crate::builtins::PyType; - // 1. Check TypeDataSlot on class (for instances) - if let Some(stg_info) = obj.class().get_type_data::() { - return Ok(stg_info.size); - } - - // 2. Check TypeDataSlot on type itself (for type objects) - if let Some(type_obj) = obj.downcast_ref::() - && let Some(stg_info) = type_obj.get_type_data::() - { - return Ok(stg_info.size); - } - - // 3. 
Instances with cdata buffer - if let Some(structure) = obj.downcast_ref::() { - return Ok(structure.cdata.read().size()); - } - if let Some(simple) = obj.downcast_ref::() { - return Ok(simple.cdata.read().size()); - } - if obj.fast_isinstance(PyCPointer::static_type()) { - return Ok(std::mem::size_of::()); - } - - // 3. Type objects - if let Ok(type_ref) = obj.clone().downcast::() { - // Structure types - check if metaclass is or inherits from PyCStructType - if type_ref + // 1. Check if obj is a TYPE object (not instance) - PyStgInfo_FromType + if let Some(type_obj) = obj.downcast_ref::() { + // Type object - return StgInfo.size + if let Some(stg_info) = type_obj.stg_info_opt() { + return Ok(stg_info.size); + } + // Fallback for type objects without StgInfo + // Array types + if type_obj + .class() + .fast_issubclass(super::array::PyCArrayType::static_type()) + && let Ok(stg) = type_obj.stg_info(vm) + { + return Ok(stg.size); + } + // Structure types + if type_obj .class() .fast_issubclass(PyCStructType::static_type()) { - return calculate_struct_size(&type_ref, vm); + return super::structure::calculate_struct_size(type_obj, vm); } - // Union types - check if metaclass is or inherits from PyCUnionType - if type_ref + // Union types + if type_obj .class() .fast_issubclass(PyCUnionType::static_type()) { - return calculate_union_size(&type_ref, vm); + return super::union::calculate_union_size(type_obj, vm); } - // Simple types (c_int, c_char, etc.) 
- if type_ref.fast_issubclass(PyCSimple::static_type()) { - let instance = new_simple_type(Either::B(&type_ref), vm)?; - return Ok(get_size(&instance._type_)); + // Simple types + if type_obj.fast_issubclass(PyCSimple::static_type()) { + if let Ok(type_attr) = type_obj.as_object().get_attr("_type_", vm) + && let Ok(type_str) = type_attr.str(vm) + { + return Ok(super::get_size(type_str.as_ref())); + } + return Ok(std::mem::size_of::()); } // Pointer types - if type_ref.fast_issubclass(PyCPointer::static_type()) { + if type_obj.fast_issubclass(PyCPointer::static_type()) { return Ok(std::mem::size_of::()); } + return Err(vm.new_type_error("this type has no size")); } - Err(vm.new_type_error("this type has no size")) - } - - /// Calculate Structure type size from _fields_ (sum of field sizes) - fn calculate_struct_size( - cls: &crate::builtins::PyTypeRef, - vm: &VirtualMachine, - ) -> PyResult { - use crate::AsObject; - - if let Ok(fields_attr) = cls.as_object().get_attr("_fields_", vm) { - let fields: Vec = fields_attr.try_to_value(vm).unwrap_or_default(); - let mut total_size = 0usize; - - for field in fields.iter() { - if let Some(tuple) = field.downcast_ref::() - && let Some(field_type) = tuple.get(1) - { - // Recursively calculate field type size - total_size += size_of(field_type.clone(), vm)?; - } - } - return Ok(total_size); + // 2. 
Instance object - return actual buffer size (b_size) + // CDataObject_Check + return obj->b_size + if let Some(cdata) = obj.downcast_ref::() { + return Ok(cdata.size()); } - Ok(0) - } - - /// Calculate Union type size from _fields_ (max field size) - fn calculate_union_size( - cls: &crate::builtins::PyTypeRef, - vm: &VirtualMachine, - ) -> PyResult { - use crate::AsObject; - - if let Ok(fields_attr) = cls.as_object().get_attr("_fields_", vm) { - let fields: Vec = fields_attr.try_to_value(vm).unwrap_or_default(); - let mut max_size = 0usize; - - for field in fields.iter() { - if let Some(tuple) = field.downcast_ref::() - && let Some(field_type) = tuple.get(1) - { - let field_size = size_of(field_type.clone(), vm)?; - max_size = max_size.max(field_size); - } - } - return Ok(max_size); + if obj.fast_isinstance(PyCPointer::static_type()) { + return Ok(std::mem::size_of::()); } - Ok(0) + + Err(vm.new_type_error("this type has no size")) } #[cfg(windows)] @@ -513,7 +577,7 @@ pub(crate) mod _ctypes { #[cfg(not(windows))] #[pyfunction(name = "dlopen")] fn load_library_unix( - name: Option, + name: Option, _load_flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { @@ -523,9 +587,12 @@ pub(crate) mod _ctypes { Some(name) => { let cache = library::libcache(); let mut cache_write = cache.write(); - let (id, _) = cache_write - .get_or_insert_lib(&name, vm) - .map_err(|e| vm.new_os_error(e.to_string()))?; + let os_str = name.as_os_str(vm)?; + let (id, _) = cache_write.get_or_insert_lib(&*os_str, vm).map_err(|e| { + // Include filename in error message for better diagnostics + let name_str = os_str.to_string_lossy(); + vm.new_os_error(format!("{}: {}", name_str, e)) + })?; Ok(id) } None => { @@ -548,7 +615,9 @@ pub(crate) mod _ctypes { } #[pyfunction(name = "POINTER")] - pub fn create_pointer_type(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn create_pointer_type(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult { + use crate::builtins::PyStr; + // Get the 
_pointer_type_cache let ctypes_module = vm.import("_ctypes", 0)?; let cache = ctypes_module.get_attr("_pointer_type_cache", vm)?; @@ -563,33 +632,60 @@ pub(crate) mod _ctypes { // Get the _Pointer base class let pointer_base = ctypes_module.get_attr("_Pointer", vm)?; + // Create a new type that inherits from _Pointer + let pointer_base_type = pointer_base + .clone() + .downcast::() + .map_err(|_| vm.new_type_error("_Pointer must be a type"))?; + let metaclass = pointer_base_type.class().to_owned(); + + let bases = vm.ctx.new_tuple(vec![pointer_base]); + let dict = vm.ctx.new_dict(); + + // PyUnicode_CheckExact(cls) - string creates incomplete pointer type + if let Some(s) = cls.downcast_ref::() { + // Incomplete pointer type: _type_ not set, cache key is id(result) + let name = format!("LP_{}", s.as_str()); + + let new_type = metaclass + .as_object() + .call((vm.ctx.new_str(name), bases, dict), vm)?; + + // Store with id(result) as key for incomplete pointer types + let id_key: PyObjectRef = vm.ctx.new_int(new_type.get_id() as i64).into(); + vm.call_method(&cache, "__setitem__", (id_key, new_type.clone()))?; + + return Ok(new_type); + } + + // PyType_Check(cls) - type creates complete pointer type + if !cls.class().fast_issubclass(vm.ctx.types.type_type.as_ref()) { + return Err(vm.new_type_error("must be a ctypes type")); + } + // Create the name for the pointer type let name = if let Ok(type_obj) = cls.get_attr("__name__", vm) { format!("LP_{}", type_obj.str(vm)?) 
- } else if let Ok(s) = cls.str(vm) { - format!("LP_{}", s) } else { "LP_unknown".to_string() }; - // Create a new type that inherits from _Pointer - let type_type = &vm.ctx.types.type_type; - let bases = vm.ctx.new_tuple(vec![pointer_base]); - let dict = vm.ctx.new_dict(); + // Complete pointer type: set _type_ attribute dict.set_item("_type_", cls.clone(), vm)?; - let new_type = type_type + // Call the metaclass (PyCPointerType) to create the new type + let new_type = metaclass .as_object() .call((vm.ctx.new_str(name), bases, dict), vm)?; - // Store in cache using __setitem__ + // Store in cache with cls as key vm.call_method(&cache, "__setitem__", (cls, new_type.clone()))?; Ok(new_type) } - #[pyfunction(name = "pointer")] - pub fn create_pointer_inst(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + #[pyfunction] + fn pointer(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { // Get the type of the object let obj_type = obj.class().to_owned(); @@ -607,7 +703,7 @@ pub(crate) mod _ctypes { #[cfg(target_os = "windows")] #[pyfunction(name = "_check_HRESULT")] - pub fn check_hresult(_self: PyObjectRef, hr: i32, _vm: &VirtualMachine) -> PyResult { + fn check_hresult(_self: PyObjectRef, hr: i32, _vm: &VirtualMachine) -> PyResult { // TODO: fixme if hr < 0 { // vm.ctx.new_windows_error(hr) @@ -619,18 +715,17 @@ pub(crate) mod _ctypes { #[pyfunction] fn addressof(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if obj.is_instance(PyCSimple::static_type().as_ref(), vm)? 
{ - let simple = obj.downcast_ref::().unwrap(); - Ok(simple.value.as_ptr() as usize) + // All ctypes objects should return cdata buffer pointer + if let Some(cdata) = obj.downcast_ref::() { + Ok(cdata.buffer.read().as_ptr() as usize) } else { Err(vm.new_type_error("expected a ctypes instance")) } } #[pyfunction] - fn byref(obj: PyObjectRef, offset: OptionalArg, vm: &VirtualMachine) -> PyResult { - use super::base::PyCData; - use crate::class::StaticType; + pub fn byref(obj: PyObjectRef, offset: OptionalArg, vm: &VirtualMachine) -> PyResult { + use super::FfiArgValue; // Check if obj is a ctypes instance if !obj.fast_isinstance(PyCData::static_type()) @@ -644,9 +739,23 @@ pub(crate) mod _ctypes { let offset_val = offset.unwrap_or(0); + // Get buffer address: (char *)((CDataObject *)obj)->b_ptr + offset + let ptr_val = if let Some(simple) = obj.downcast_ref::() { + let buffer = simple.0.buffer.read(); + (buffer.as_ptr() as isize + offset_val) as usize + } else if let Some(cdata) = obj.downcast_ref::() { + let buffer = cdata.buffer.read(); + (buffer.as_ptr() as isize + offset_val) as usize + } else { + 0 + }; + // Create CArgObject to hold the reference Ok(CArgObject { + tag: b'P', + value: FfiArgValue::Pointer(ptr_val), obj, + size: 0, offset: offset_val, } .to_pyobject(vm)) @@ -654,11 +763,6 @@ pub(crate) mod _ctypes { #[pyfunction] fn alignment(tp: Either, vm: &VirtualMachine) -> PyResult { - use super::base::PyCSimpleType; - use super::pointer::PyCPointer; - use super::structure::PyCStructure; - use super::union::PyCUnion; - use super::util::StgInfo; use crate::builtins::PyType; let obj = match &tp { @@ -667,23 +771,27 @@ pub(crate) mod _ctypes { }; // 1. Check TypeDataSlot on class (for instances) - if let Some(stg_info) = obj.class().get_type_data::() { + if let Some(stg_info) = obj.class().stg_info_opt() { return Ok(stg_info.align); } // 2. 
Check TypeDataSlot on type itself (for type objects) if let Some(type_obj) = obj.downcast_ref::() - && let Some(stg_info) = type_obj.get_type_data::() + && let Some(stg_info) = type_obj.stg_info_opt() { return Ok(stg_info.align); } - // 3. Fallback for simple types without TypeDataSlot - if obj.fast_isinstance(PyCSimple::static_type()) { - // Get stg_info from the type by reading _type_ attribute - let cls = obj.class().to_owned(); - let stg_info = PyCSimpleType::get_stg_info(&cls, vm); - return Ok(stg_info.align); + // 3. Fallback for simple types + if obj.fast_isinstance(PyCSimple::static_type()) + && let Ok(stg) = obj.class().stg_info(vm) + { + return Ok(stg.align); + } + if obj.fast_isinstance(PyCArray::static_type()) + && let Ok(stg) = obj.class().stg_info(vm) + { + return Ok(stg.align); } if obj.fast_isinstance(PyCStructure::static_type()) { // Calculate alignment from _fields_ @@ -715,8 +823,8 @@ pub(crate) mod _ctypes { // Simple type: _type_ is a single character string if let Ok(s) = type_attr.str(vm) { let ty = s.to_string(); - if ty.len() == 1 && SIMPLE_TYPE_CHARS.contains(ty.as_str()) { - return Ok(get_align(&ty)); + if ty.len() == 1 && super::simple::SIMPLE_TYPE_CHARS.contains(ty.as_str()) { + return Ok(super::get_align(&ty)); } } } @@ -754,9 +862,45 @@ pub(crate) mod _ctypes { } #[pyfunction] - fn resize(_args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { - // TODO: RUSTPYTHON - Err(vm.new_value_error("not implemented")) + fn resize(obj: PyObjectRef, size: isize, vm: &VirtualMachine) -> PyResult<()> { + use std::borrow::Cow; + + // 1. Get StgInfo from object's class (validates ctypes instance) + let stg_info = obj + .class() + .stg_info_opt() + .ok_or_else(|| vm.new_type_error("expected ctypes instance"))?; + + // 2. Validate size + if size < 0 || (size as usize) < stg_info.size { + return Err(vm.new_value_error(format!("minimum size is {}", stg_info.size))); + } + + // 3. 
Get PyCData via upcast (works for all ctypes types due to repr(transparent)) + let cdata = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected ctypes instance"))?; + + // 4. Check if buffer is owned (not borrowed from external memory) + { + let buffer = cdata.buffer.read(); + if matches!(&*buffer, Cow::Borrowed(_)) { + return Err(vm.new_value_error( + "Memory cannot be resized because this object doesn't own it".to_owned(), + )); + } + } + + // 5. Resize the buffer + let new_size = size as usize; + let mut buffer = cdata.buffer.write(); + let old_data = buffer.to_vec(); + let mut new_data = vec![0u8; new_size]; + let copy_len = old_data.len().min(new_size); + new_data[..copy_len].copy_from_slice(&old_data[..copy_len]); + *buffer = Cow::Owned(new_data); + + Ok(()) } #[pyfunction] @@ -796,77 +940,306 @@ pub(crate) mod _ctypes { #[pyattr] fn _string_at_addr(_vm: &VirtualMachine) -> usize { - let f = libc::strnlen; - f as usize + super::function::INTERNAL_STRING_AT_ADDR } #[pyattr] fn _wstring_at_addr(_vm: &VirtualMachine) -> usize { - // Return address of wcsnlen or similar wide string function - #[cfg(not(target_os = "windows"))] - { - let f = libc::wcslen; - f as usize - } - #[cfg(target_os = "windows")] - { - // FIXME: On Windows, use wcslen from ucrt - 0 - } + super::function::INTERNAL_WSTRING_AT_ADDR } #[pyattr] fn _cast_addr(_vm: &VirtualMachine) -> usize { - // todo!("Implement _cast_addr") - 0 + super::function::INTERNAL_CAST_ADDR } - #[pyfunction(name = "_cast")] - pub fn pycfunction_cast( + #[pyfunction] + fn _cast( obj: PyObjectRef, - _obj2: PyObjectRef, + src: PyObjectRef, ctype: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - use super::array::PyCArray; - use super::base::PyCData; - use super::pointer::PyCPointer; - use crate::class::StaticType; + super::function::cast_impl(obj, src, ctype, vm) + } + + /// Python-level cast function (PYFUNCTYPE wrapper) + #[pyfunction] + fn cast(obj: PyObjectRef, typ: PyObjectRef, vm: &VirtualMachine) 
-> PyResult { + super::function::cast_impl(obj.clone(), obj, typ, vm) + } + + /// Return buffer interface information for a ctypes type or object. + /// Returns a tuple (format, ndim, shape) where: + /// - format: PEP 3118 format string + /// - ndim: number of dimensions + /// - shape: tuple of dimension sizes + #[pyfunction] + fn buffer_info(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // Determine if obj is a type or an instance + let is_type = obj.class().fast_issubclass(vm.ctx.types.type_type.as_ref()); + let cls = if is_type { + obj.clone() + } else { + obj.class().to_owned().into() + }; + + // Get format from type - try _type_ first (for simple types), then _stg_info_format_ + let format = if let Ok(type_attr) = cls.get_attr("_type_", vm) { + type_attr.str(vm)?.to_string() + } else if let Ok(format_attr) = cls.get_attr("_stg_info_format_", vm) { + format_attr.str(vm)?.to_string() + } else { + return Err(vm.new_type_error("not a ctypes type or object")); + }; + + // Non-array types have ndim=0 and empty shape + // TODO: Implement ndim/shape for arrays when StgInfo supports it + let ndim = 0; + let shape: Vec = vec![]; + + let shape_tuple = vm.ctx.new_tuple(shape); + Ok(vm + .ctx + .new_tuple(vec![ + vm.ctx.new_str(format).into(), + vm.ctx.new_int(ndim).into(), + shape_tuple.into(), + ]) + .into()) + } + + /// Unpickle a ctypes object. + #[pyfunction] + fn _unpickle(typ: PyObjectRef, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if !state.class().is(vm.ctx.types.tuple_type.as_ref()) { + return Err(vm.new_type_error("state must be a tuple")); + } + let obj = vm.call_method(&typ, "__new__", (typ.clone(),))?; + vm.call_method(&obj, "__setstate__", (state,))?; + Ok(obj) + } + + /// Call a function at the given address with the given arguments. 
+ #[pyfunction] + fn call_function( + func_addr: usize, + args: crate::builtins::PyTupleRef, + vm: &VirtualMachine, + ) -> PyResult { + call_function_internal(func_addr, args, 0, vm) + } + + /// Call a cdecl function at the given address with the given arguments. + #[pyfunction] + fn call_cdeclfunction( + func_addr: usize, + args: crate::builtins::PyTupleRef, + vm: &VirtualMachine, + ) -> PyResult { + call_function_internal(func_addr, args, FUNCFLAG_CDECL, vm) + } - // Python signature: _cast(obj, obj, ctype) - // Python passes the same object twice (obj and _obj2 are the same) - // We ignore _obj2 as it's redundant + fn call_function_internal( + func_addr: usize, + args: crate::builtins::PyTupleRef, + _flags: u32, + vm: &VirtualMachine, + ) -> PyResult { + use libffi::middle::{Arg, Cif, CodePtr, Type}; - // Check if this is a pointer type (has _type_ attribute) - if ctype.get_attr("_type_", vm).is_err() { - return Err(vm.new_type_error("cast() argument 2 must be a pointer type".to_string())); + if func_addr == 0 { + return Err(vm.new_value_error("NULL function pointer")); } - // Create an instance of the target pointer type with no arguments - let result = ctype.call((), vm)?; + let mut ffi_args: Vec = Vec::with_capacity(args.len()); + let mut arg_values: Vec = Vec::with_capacity(args.len()); + let mut arg_types: Vec = Vec::with_capacity(args.len()); + + for arg in args.iter() { + if vm.is_none(arg) { + arg_values.push(0); + arg_types.push(Type::pointer()); + } else if let Ok(int_val) = arg.try_int(vm) { + let val = int_val.as_bigint().to_i64().unwrap_or(0) as isize; + arg_values.push(val); + arg_types.push(Type::isize()); + } else if let Some(bytes) = arg.downcast_ref::() { + let ptr = bytes.as_bytes().as_ptr() as isize; + arg_values.push(ptr); + arg_types.push(Type::pointer()); + } else if let Some(s) = arg.downcast_ref::() { + let ptr = s.as_str().as_ptr() as isize; + arg_values.push(ptr); + arg_types.push(Type::pointer()); + } else { + return 
Err(vm.new_type_error(format!( + "Don't know how to convert parameter of type '{}'", + arg.class().name() + ))); + } + } - // Get the pointer value from the source object - // If obj is a CData instance (including arrays), use the object itself - // If obj is an integer, use it directly as the pointer value - let ptr_value: PyObjectRef = if obj.fast_isinstance(PyCData::static_type()) - || obj.fast_isinstance(PyCArray::static_type()) - || obj.fast_isinstance(PyCPointer::static_type()) - { - // For CData objects (including arrays and pointers), store the object itself - obj.clone() - } else if let Ok(int_val) = obj.try_int(vm) { - // For integers, treat as pointer address - vm.ctx.new_int(int_val.as_bigint().clone()).into() - } else { - return Err(vm.new_type_error(format!( - "cast() argument 1 must be a ctypes instance or an integer, not {}", - obj.class().name() - ))); + for val in &arg_values { + ffi_args.push(Arg::new(val)); + } + + let cif = Cif::new(arg_types, Type::isize()); + let code_ptr = CodePtr::from_ptr(func_addr as *const _); + let result: isize = unsafe { cif.call(code_ptr, &ffi_args) }; + Ok(vm.ctx.new_int(result).into()) + } + + /// Convert a pointer (as integer) to a Python object. 
+ #[pyfunction(name = "PyObj_FromPtr")] + fn py_obj_from_ptr(ptr: usize, vm: &VirtualMachine) -> PyResult { + if ptr == 0 { + return Err(vm.new_value_error("NULL pointer access")); + } + let raw_ptr = ptr as *mut crate::object::PyObject; + unsafe { + let obj = crate::PyObjectRef::from_raw(std::ptr::NonNull::new_unchecked(raw_ptr)); + let obj = std::mem::ManuallyDrop::new(obj); + Ok((*obj).clone()) + } + } + + #[pyfunction(name = "Py_INCREF")] + fn py_incref(obj: PyObjectRef, _vm: &VirtualMachine) -> PyObjectRef { + // TODO: + obj + } + + #[pyfunction(name = "Py_DECREF")] + fn py_decref(obj: PyObjectRef, _vm: &VirtualMachine) -> PyObjectRef { + // TODO: + obj + } + + #[cfg(target_os = "macos")] + #[pyfunction] + fn _dyld_shared_cache_contains_path( + path: Option, + vm: &VirtualMachine, + ) -> PyResult { + use std::ffi::CString; + + let path = match path { + Some(p) if !vm.is_none(&p) => p, + _ => return Ok(false), }; - // Set the contents of the pointer by setting the attribute - result.set_attr("contents", ptr_value, vm)?; + let path_str = path.str(vm)?.to_string(); + let c_path = + CString::new(path_str).map_err(|_| vm.new_value_error("path contains null byte"))?; + unsafe extern "C" { + fn _dyld_shared_cache_contains_path(path: *const libc::c_char) -> bool; + } + + let result = unsafe { _dyld_shared_cache_contains_path(c_path.as_ptr()) }; Ok(result) } + + #[cfg(windows)] + #[pyfunction(name = "FormatError")] + fn format_error_func(code: OptionalArg, _vm: &VirtualMachine) -> PyResult { + use windows_sys::Win32::Foundation::{GetLastError, LocalFree}; + use windows_sys::Win32::System::Diagnostics::Debug::{ + FORMAT_MESSAGE_ALLOCATE_BUFFER, FORMAT_MESSAGE_FROM_SYSTEM, + FORMAT_MESSAGE_IGNORE_INSERTS, FormatMessageW, + }; + + let error_code = code.unwrap_or_else(|| unsafe { GetLastError() }); + + let mut buffer: *mut u16 = std::ptr::null_mut(); + let len = unsafe { + FormatMessageW( + FORMAT_MESSAGE_ALLOCATE_BUFFER + | FORMAT_MESSAGE_FROM_SYSTEM + | 
FORMAT_MESSAGE_IGNORE_INSERTS, + std::ptr::null(), + error_code, + 0, + &mut buffer as *mut *mut u16 as *mut u16, + 0, + std::ptr::null(), + ) + }; + + if len == 0 || buffer.is_null() { + return Ok("".to_string()); + } + + let message = unsafe { + let slice = std::slice::from_raw_parts(buffer, len as usize); + let msg = String::from_utf16_lossy(slice).trim_end().to_string(); + LocalFree(buffer as *mut _); + msg + }; + + Ok(message) + } + + #[cfg(windows)] + #[pyfunction(name = "CopyComPointer")] + fn copy_com_pointer(src: PyObjectRef, dst: PyObjectRef, vm: &VirtualMachine) -> PyResult { + use windows_sys::Win32::Foundation::{E_POINTER, S_OK}; + + // 1. Extract pointer-to-pointer address from dst (byref() result) + let pdst: usize = if let Some(carg) = dst.downcast_ref::() { + // byref() result: object buffer address + offset + let base = if let Some(cdata) = carg.obj.downcast_ref::() { + cdata.buffer.read().as_ptr() as usize + } else { + return Ok(E_POINTER); + }; + (base as isize + carg.offset) as usize + } else { + return Ok(E_POINTER); + }; + + if pdst == 0 { + return Ok(E_POINTER); + } + + // 2. Extract COM pointer value from src + let src_ptr: usize = if vm.is_none(&src) { + 0 + } else if let Some(cdata) = src.downcast_ref::() { + // c_void_p etc: read pointer value from buffer + let buffer = cdata.buffer.read(); + if buffer.len() >= std::mem::size_of::() { + usize::from_ne_bytes( + buffer[..std::mem::size_of::()] + .try_into() + .unwrap_or([0; std::mem::size_of::()]), + ) + } else { + 0 + } + } else { + return Ok(E_POINTER); + }; + + // 3. Call IUnknown::AddRef if src is non-NULL + if src_ptr != 0 { + unsafe { + // IUnknown vtable: [QueryInterface, AddRef, Release, ...] 
+ let iunknown = src_ptr as *mut *const usize; + let vtable = *iunknown; + debug_assert!(!vtable.is_null(), "IUnknown vtable is null"); + let addref_fn: extern "system" fn(*mut std::ffi::c_void) -> u32 = + std::mem::transmute(*vtable.add(1)); // AddRef is index 1 + addref_fn(src_ptr as *mut std::ffi::c_void); + } + } + + // 4. Copy pointer: *pdst = src + unsafe { + *(pdst as *mut usize) = src_ptr; + } + + Ok(S_OK) + } } diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/ctypes/array.rs index fe12a781d9f..60e6516bfe0 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs +++ b/crates/vm/src/stdlib/ctypes/array.rs @@ -1,41 +1,109 @@ -use crate::atomic_func; -use crate::builtins::{PyBytes, PyInt}; -use crate::class::StaticType; -use crate::function::FuncArgs; -use crate::protocol::{ - BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods, PySequenceMethods, -}; -use crate::stdlib::ctypes::base::CDataObject; -use crate::stdlib::ctypes::util::StgInfo; -use crate::types::{AsBuffer, AsNumber, AsSequence}; -use crate::{AsObject, Py, PyObjectRef, PyPayload}; +use super::StgInfo; +use super::base::{CDATA_BUFFER_METHODS, PyCData}; use crate::{ - PyResult, VirtualMachine, - builtins::{PyType, PyTypeRef}, - types::Constructor, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, + atomic_func, + builtins::{PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef}, + class::StaticType, + function::{ArgBytesLike, FuncArgs, PySetterValue}, + protocol::{BufferDescriptor, PyBuffer, PyNumberMethods, PySequenceMethods}, + types::{AsBuffer, AsNumber, AsSequence, Constructor, Initializer}, }; -use crossbeam_utils::atomic::AtomicCell; -use num_traits::ToPrimitive; -use rustpython_common::lock::PyRwLock; -use rustpython_vm::stdlib::ctypes::_ctypes::get_size; -use rustpython_vm::stdlib::ctypes::base::PyCData; +use num_traits::{Signed, ToPrimitive}; + +/// Creates array type for (element_type, length) +/// Uses 
_array_type_cache to ensure identical calls return the same type object +pub(super) fn array_type_from_ctype( + itemtype: PyObjectRef, + length: usize, + vm: &VirtualMachine, +) -> PyResult { + // PyCArrayType_from_ctype + + // Get the _array_type_cache from _ctypes module + let ctypes_module = vm.import("_ctypes", 0)?; + let cache = ctypes_module.get_attr("_array_type_cache", vm)?; + + // Create cache key: (itemtype, length) tuple + let length_obj: PyObjectRef = vm.ctx.new_int(length).into(); + let cache_key = vm.ctx.new_tuple(vec![itemtype.clone(), length_obj]); + + // Check if already in cache + if let Ok(cached) = vm.call_method(&cache, "__getitem__", (cache_key.clone(),)) + && !vm.is_none(&cached) + { + return Ok(cached); + } -/// PyCArrayType - metatype for Array types -/// CPython stores array info (type, length) in StgInfo via type_data -#[pyclass(name = "PyCArrayType", base = PyType, module = "_ctypes")] -#[derive(Debug)] -#[repr(transparent)] -pub struct PyCArrayType(PyType); + // Cache miss - create new array type + let itemtype_ref = itemtype + .clone() + .downcast::() + .map_err(|_| vm.new_type_error("Expected a type object"))?; + + let item_stg = itemtype_ref + .stg_info_opt() + .ok_or_else(|| vm.new_type_error("_type_ must have storage info"))?; + + let element_size = item_stg.size; + let element_align = item_stg.align; + let item_format = item_stg.format.clone(); + let item_shape = item_stg.shape.clone(); + let item_flags = item_stg.flags; + + // Check overflow before multiplication + let total_size = element_size + .checked_mul(length) + .ok_or_else(|| vm.new_overflow_error("array too large"))?; + + // format name: "c_int_Array_5" + let type_name = format!("{}_Array_{}", itemtype_ref.name(), length); + + // Get item type code before moving itemtype + let item_type_code = itemtype_ref + .as_object() + .get_attr("_type_", vm) + .ok() + .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); + + let stg_info = StgInfo::new_array( + total_size, + 
element_align, + length, + itemtype_ref.clone(), + element_size, + item_format.as_deref(), + &item_shape, + item_flags, + ); -/// Create a new Array type with StgInfo stored in type_data (CPython style) -pub fn create_array_type_with_stg_info(stg_info: StgInfo, vm: &VirtualMachine) -> PyResult { - // Get PyCArrayType as metaclass - let metaclass = PyCArrayType::static_type().to_owned(); + let new_type = create_array_type_with_name(stg_info, &type_name, vm)?; + + // Special case for character arrays - add value/raw attributes + let new_type_ref: PyTypeRef = new_type + .clone() + .downcast() + .map_err(|_| vm.new_type_error("expected type"))?; + + match item_type_code.as_deref() { + Some("c") => add_char_array_getsets(&new_type_ref, vm), + Some("u") => add_wchar_array_getsets(&new_type_ref, vm), + _ => {} + } + + // Store in cache + vm.call_method(&cache, "__setitem__", (cache_key, new_type.clone()))?; - // Create a unique name for the array type - let type_name = format!("Array_{}", stg_info.length); + Ok(new_type) +} - // Create args for type(): (name, bases, dict) +/// create_array_type_with_name - create array type with specified name +fn create_array_type_with_name( + stg_info: StgInfo, + type_name: &str, + vm: &VirtualMachine, +) -> PyResult { + let metaclass = PyCArrayType::static_type().to_owned(); let name = vm.ctx.new_str(type_name); let bases = vm .ctx @@ -47,170 +115,205 @@ pub fn create_array_type_with_stg_info(stg_info: StgInfo, vm: &VirtualMachine) - crate::function::KwArgs::default(), ); - // Create the new type using PyType::slot_new with PyCArrayType as metaclass let new_type = crate::builtins::type_::PyType::slot_new(metaclass, args, vm)?; - // Set StgInfo in type_data let type_ref: PyTypeRef = new_type .clone() .downcast() - .map_err(|_| vm.new_type_error("Failed to create array type".to_owned()))?; + .map_err(|_| vm.new_type_error("Failed to create array type"))?; - if type_ref.init_type_data(stg_info.clone()).is_err() { - // Type data already 
initialized - update it - if let Some(mut existing) = type_ref.get_type_data_mut::() { - *existing = stg_info; - } + // Set class attributes for _type_ and _length_ + if let Some(element_type) = stg_info.element_type.clone() { + new_type.set_attr("_type_", element_type, vm)?; } + new_type.set_attr("_length_", vm.ctx.new_int(stg_info.length), vm)?; + + super::base::set_or_init_stginfo(&type_ref, stg_info); Ok(new_type) } -impl Constructor for PyCArrayType { - type Args = PyObjectRef; +/// PyCArrayType - metatype for Array types +#[pyclass(name = "PyCArrayType", base = PyType, module = "_ctypes")] +#[derive(Debug)] +#[repr(transparent)] +pub(super) struct PyCArrayType(PyType); - fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { - unimplemented!("use slot_new") - } -} +// PyCArrayType implements Initializer for slots.init (PyCArrayType_init) +impl Initializer for PyCArrayType { + type Args = FuncArgs; -#[pyclass(flags(IMMUTABLETYPE), with(Constructor, AsNumber))] -impl PyCArrayType { - #[pygetset(name = "_type_")] - fn typ(zelf: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - zelf.downcast_ref::() - .and_then(|t| t.get_type_data::()) - .and_then(|stg| stg.element_type.clone()) - .unwrap_or_else(|| vm.ctx.none()) - } + fn init(zelf: PyRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + // zelf is the newly created array type (e.g., T in "class T(Array)") + let new_type: &PyType = &zelf.0; + + new_type.check_not_initialized(vm)?; + + // 1. Get _length_ from class dict first + let direct_length = new_type + .attributes + .read() + .get(vm.ctx.intern_str("_length_")) + .cloned(); + + // 2. Get _type_ from class dict first + let direct_type = new_type + .attributes + .read() + .get(vm.ctx.intern_str("_type_")) + .cloned(); + + // 3. 
Find parent StgInfo from MRO (for inheritance) + // Note: PyType.mro does NOT include self, so no skip needed + let parent_stg_info = new_type + .mro + .read() + .iter() + .find_map(|base| base.stg_info_opt().map(|s| s.clone())); + + // 4. Resolve _length_ (direct or inherited) + let length = if let Some(length_attr) = direct_length { + // Direct _length_ defined - validate it (PyLong_Check) + let length_int = length_attr + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("The '_length_' attribute must be an integer"))?; + let bigint = length_int.as_bigint(); + // Check sign first - negative values are ValueError + if bigint.is_negative() { + return Err(vm.new_value_error("The '_length_' attribute must not be negative")); + } + // Positive values that don't fit in usize are OverflowError + bigint + .to_usize() + .ok_or_else(|| vm.new_overflow_error("The '_length_' attribute is too large"))? + } else if let Some(ref parent_info) = parent_stg_info { + // Inherit from parent + parent_info.length + } else { + return Err(vm.new_attribute_error("class must define a '_length_' attribute")); + }; - #[pygetset(name = "_length_")] - fn length(zelf: PyObjectRef) -> usize { - zelf.downcast_ref::() - .and_then(|t| t.get_type_data::()) - .map(|stg| stg.length) - .unwrap_or(0) - } + // 5. 
Resolve _type_ and get item_info (direct or inherited) + let (element_type, item_size, item_align, item_format, item_shape, item_flags) = + if let Some(type_attr) = direct_type { + // Direct _type_ defined - validate it (PyStgInfo_FromType) + let type_ref = type_attr + .clone() + .downcast::() + .map_err(|_| vm.new_type_error("_type_ must be a type"))?; + let (size, align, format, shape, flags) = { + let item_info = type_ref + .stg_info_opt() + .ok_or_else(|| vm.new_type_error("_type_ must have storage info"))?; + ( + item_info.size, + item_info.align, + item_info.format.clone(), + item_info.shape.clone(), + item_info.flags, + ) + }; + (type_ref, size, align, format, shape, flags) + } else if let Some(ref parent_info) = parent_stg_info { + // Inherit from parent + let parent_type = parent_info + .element_type + .clone() + .ok_or_else(|| vm.new_type_error("_type_ must have storage info"))?; + ( + parent_type, + parent_info.element_size, + parent_info.align, + parent_info.format.clone(), + parent_info.shape.clone(), + parent_info.flags, + ) + } else { + return Err(vm.new_attribute_error("class must define a '_type_' attribute")); + }; - #[pymethod] - fn __mul__(zelf: PyObjectRef, n: isize, vm: &VirtualMachine) -> PyResult { - if n < 0 { - return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); + // 6. 
Check overflow (item_size != 0 && length > MAX / item_size) + if item_size != 0 && length > usize::MAX / item_size { + return Err(vm.new_overflow_error("array too large")); } - // Get inner array info from TypeDataSlot - let type_ref = zelf.downcast_ref::().unwrap(); - let (_inner_length, inner_size) = type_ref - .get_type_data::() - .map(|stg| (stg.length, stg.size)) - .unwrap_or((0, 0)); - - // The element type of the new array is the current array type itself - let current_array_type: PyObjectRef = zelf.clone(); - - // Element size is the total size of the inner array - let new_element_size = inner_size; - let total_size = new_element_size * (n as usize); - + // 7. Initialize StgInfo (PyStgInfo_Init + field assignment) let stg_info = StgInfo::new_array( - total_size, - new_element_size, - n as usize, - current_array_type, - new_element_size, + item_size * length, // size = item_size * length + item_align, // align = item_info->align + length, // length + element_type.clone(), + item_size, // element_size + item_format.as_deref(), + &item_shape, + item_flags, ); - create_array_type_with_stg_info(stg_info, vm) - } + // 8. Store StgInfo in type_data + super::base::set_or_init_stginfo(new_type, stg_info); - #[pyclassmethod] - fn in_dll( - zelf: PyObjectRef, - dll: PyObjectRef, - name: crate::builtins::PyStrRef, - vm: &VirtualMachine, - ) -> PyResult { - use libloading::Symbol; + // 9. Get type code before moving element_type + let item_type_code = element_type + .as_object() + .get_attr("_type_", vm) + .ok() + .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); + + // 10. Set class attributes for _type_ and _length_ + zelf.as_object().set_attr("_type_", element_type, vm)?; + zelf.as_object() + .set_attr("_length_", vm.ctx.new_int(length), vm)?; + + // 11. 
Special case for character arrays - add value/raw attributes + // if (iteminfo->getfunc == _ctypes_get_fielddesc("c")->getfunc) + // add_getset((PyTypeObject*)self, CharArray_getsets); + // else if (iteminfo->getfunc == _ctypes_get_fielddesc("u")->getfunc) + // add_getset((PyTypeObject*)self, WCharArray_getsets); + + // Get type ref for add_getset + let type_ref: PyTypeRef = zelf.as_object().to_owned().downcast().unwrap(); + match item_type_code.as_deref() { + Some("c") => add_char_array_getsets(&type_ref, vm), + Some("u") => add_wchar_array_getsets(&type_ref, vm), + _ => {} + } - // Get the library handle from dll object - let handle = if let Ok(int_handle) = dll.try_int(vm) { - // dll is an integer handle - int_handle - .as_bigint() - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? - } else { - // dll is a CDLL/PyDLL/WinDLL object with _handle attribute - dll.get_attr("_handle", vm)? - .try_int(vm)? - .as_bigint() - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? 
- }; + Ok(()) + } +} - // Get the library from cache - let library_cache = crate::stdlib::ctypes::library::libcache().read(); - let library = library_cache - .get_lib(handle) - .ok_or_else(|| vm.new_attribute_error("Library not found".to_owned()))?; - - // Get symbol address from library - let symbol_name = format!("{}\0", name.as_str()); - let inner_lib = library.lib.lock(); - - let symbol_address = if let Some(lib) = &*inner_lib { - unsafe { - // Try to get the symbol from the library - let symbol: Symbol<'_, *mut u8> = lib.get(symbol_name.as_bytes()).map_err(|e| { - vm.new_attribute_error(format!("{}: symbol '{}' not found", e, name.as_str())) - })?; - *symbol as usize - } - } else { - return Err(vm.new_attribute_error("Library is closed".to_owned())); - }; +#[pyclass(flags(IMMUTABLETYPE), with(Initializer, AsNumber))] +impl PyCArrayType { + #[pymethod] + fn from_param(zelf: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // zelf is the array type class that from_param was called on + let cls = zelf + .downcast::() + .map_err(|_| vm.new_type_error("from_param: expected a type"))?; + + // 1. If already an instance of the requested type, return it + if value.is_instance(cls.as_object(), vm)? { + return Ok(value); + } - // Get size from the array type via TypeDataSlot - let type_ref = zelf.downcast_ref::().unwrap(); - let (element_type, length, element_size) = type_ref - .get_type_data::() - .map(|stg| { - ( - stg.element_type.clone().unwrap_or_else(|| vm.ctx.none()), - stg.length, - stg.element_size, - ) - }) - .unwrap_or_else(|| (vm.ctx.none(), 0, 0)); - let total_size = element_size * length; - - // Read data from symbol address - let data = if symbol_address != 0 && total_size > 0 { - unsafe { - let ptr = symbol_address as *const u8; - std::slice::from_raw_parts(ptr, total_size).to_vec() + // 2. 
Check for CArgObject (PyCArg_CheckExact) + if let Some(carg) = value.downcast_ref::() { + // Check if the wrapped object is an instance of the requested type + if carg.obj.is_instance(cls.as_object(), vm)? { + return Ok(value); // Return the CArgObject as-is } - } else { - vec![0; total_size] - }; - - // Create instance - let cdata = CDataObject::from_bytes(data, None); - let instance = PyCArray { - _base: PyCData::new(cdata.clone()), - typ: PyRwLock::new(element_type), - length: AtomicCell::new(length), - element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(cdata), } - .into_pyobject(vm); - // Store base reference to keep dll alive - if let Ok(array_ref) = instance.clone().downcast::() { - array_ref.cdata.write().base = Some(dll); + // 3. Check for _as_parameter_ attribute + if let Ok(as_parameter) = value.get_attr("_as_parameter_", vm) { + return PyCArrayType::from_param(cls.as_object().to_owned(), as_parameter, vm); } - Ok(instance) + Err(vm.new_type_error(format!( + "expected {} instance instead of {}", + cls.name(), + value.class().name() + ))) } } @@ -223,8 +326,28 @@ impl AsNumber for PyCArrayType { .try_index(vm)? 
.as_bigint() .to_isize() - .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; - PyCArrayType::__mul__(a.to_owned(), n, vm) + .ok_or_else(|| vm.new_overflow_error("array size too large"))?; + + if n < 0 { + return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); + } + + // Check for overflow before creating the new array type + let zelf_type = a + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("Expected type"))?; + + if let Some(stg_info) = zelf_type.stg_info_opt() { + let current_size = stg_info.size; + // Check if current_size * n would overflow + if current_size != 0 && (n as usize) > isize::MAX as usize / current_size { + return Err(vm.new_overflow_error("array too large")); + } + } + + // Use cached array type creation + // The element type of the new array is the current array type itself + array_type_from_ctype(a.to_owned(), n as usize, vm) }), ..PyNumberMethods::NOT_IMPLEMENTED }; @@ -232,27 +355,28 @@ impl AsNumber for PyCArrayType { } } +/// PyCArray - Array instance +/// All array metadata (element_type, length, element_size) is stored in the type's StgInfo #[pyclass( name = "Array", base = PyCData, metaclass = "PyCArrayType", module = "_ctypes" )] -pub struct PyCArray { - _base: PyCData, - /// Element type - can be a simple type (c_int) or an array type (c_int * 5) - pub(super) typ: PyRwLock, - pub(super) length: AtomicCell, - pub(super) element_size: AtomicCell, - pub(super) cdata: PyRwLock, -} +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCArray(pub PyCData); -impl std::fmt::Debug for PyCArray { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PyCArray") - .field("typ", &self.typ) - .field("length", &self.length) - .finish() +impl PyCArray { + /// Get the type code of array element type (e.g., "c" for c_char, "u" for c_wchar) + fn get_element_type_code(zelf: &Py, vm: &VirtualMachine) -> Option { + zelf.class() + .stg_info_opt() + .and_then(|info| 
info.element_type.clone())? + .as_object() + .get_attr("_type_", vm) + .ok() + .and_then(|t| t.downcast_ref::().map(|s| s.to_string())) } } @@ -260,60 +384,29 @@ impl Constructor for PyCArray { type Args = FuncArgs; fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // Get _type_ and _length_ from the class - let type_attr = cls.as_object().get_attr("_type_", vm).ok(); - let length_attr = cls.as_object().get_attr("_length_", vm).ok(); - - let element_type = type_attr.unwrap_or_else(|| vm.ctx.types.object_type.to_owned().into()); - let length = if let Some(len_obj) = length_attr { - len_obj.try_int(vm)?.as_bigint().to_usize().unwrap_or(0) - } else { - 0 + // Check for abstract class - StgInfo must exist and be initialized + // Extract values in a block to drop the borrow before using cls + let (length, total_size) = { + let stg = cls.stg_info(vm)?; + (stg.length, stg.size) }; - // Get element size from _type_ - let element_size = if let Ok(type_code) = element_type.get_attr("_type_", vm) { - if let Ok(s) = type_code.str(vm) { - let s = s.to_string(); - if s.len() == 1 { - get_size(&s) - } else { - std::mem::size_of::() - } - } else { - std::mem::size_of::() - } - } else { - std::mem::size_of::() - }; + // Check for too many initializers + if args.args.len() > length { + return Err(vm.new_index_error("too many initializers")); + } - let total_size = element_size * length; - let mut buffer = vec![0u8; total_size]; + // Create array with zero-initialized buffer + let buffer = vec![0u8; total_size]; + let instance = PyCArray(PyCData::from_bytes_with_length(buffer, None, length)) + .into_ref_with_type(vm, cls)?; - // Initialize from positional arguments + // Initialize elements using setitem_by_index (Array_init pattern) for (i, value) in args.args.iter().enumerate() { - if i >= length { - break; - } - let offset = i * element_size; - if let Ok(int_val) = value.try_int(vm) { - let bytes = PyCArray::int_to_bytes(int_val.as_bigint(), 
element_size); - if offset + element_size <= buffer.len() { - buffer[offset..offset + element_size].copy_from_slice(&bytes); - } - } + PyCArray::setitem_by_index(&instance, i as isize, value.clone(), vm)?; } - let cdata = CDataObject::from_bytes(buffer, None); - PyCArray { - _base: PyCData::new(cdata.clone()), - typ: PyRwLock::new(element_type), - length: AtomicCell::new(length), - element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(cdata), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + Ok(instance.into()) } fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { @@ -325,15 +418,19 @@ impl AsSequence for PyCArray { fn as_sequence() -> &'static PySequenceMethods { use std::sync::LazyLock; static AS_SEQUENCE: LazyLock = LazyLock::new(|| PySequenceMethods { - length: atomic_func!(|seq, _vm| Ok(PyCArray::sequence_downcast(seq).length.load())), + length: atomic_func!(|seq, _vm| { + let zelf = PyCArray::sequence_downcast(seq); + Ok(zelf.class().stg_info_opt().map_or(0, |i| i.length)) + }), item: atomic_func!(|seq, i, vm| { - PyCArray::getitem_by_index(PyCArray::sequence_downcast(seq), i, vm) + let zelf = PyCArray::sequence_downcast(seq); + PyCArray::getitem_by_index(zelf, i, vm) }), ass_item: atomic_func!(|seq, i, value, vm| { let zelf = PyCArray::sequence_downcast(seq); match value { Some(v) => PyCArray::setitem_by_index(zelf, i, v, vm), - None => Err(vm.new_type_error("cannot delete array elements".to_owned())), + None => Err(vm.new_type_error("cannot delete array elements")), } }), ..PySequenceMethods::NOT_IMPLEMENTED @@ -347,468 +444,839 @@ impl AsSequence for PyCArray { with(Constructor, AsSequence, AsBuffer) )] impl PyCArray { - #[pygetset] - fn _objects(&self) -> Option { - self.cdata.read().objects.clone() - } - fn int_to_bytes(i: &malachite_bigint::BigInt, size: usize) -> Vec { + // Try unsigned first (handles values like 0xFFFFFFFF that overflow signed) + // then fall back to signed (handles negative values) match 
size { - 1 => vec![i.to_i8().unwrap_or(0) as u8], - 2 => i.to_i16().unwrap_or(0).to_ne_bytes().to_vec(), - 4 => i.to_i32().unwrap_or(0).to_ne_bytes().to_vec(), - 8 => i.to_i64().unwrap_or(0).to_ne_bytes().to_vec(), + 1 => { + if let Some(v) = i.to_u8() { + vec![v] + } else { + vec![i.to_i8().unwrap_or(0) as u8] + } + } + 2 => { + if let Some(v) = i.to_u16() { + v.to_ne_bytes().to_vec() + } else { + i.to_i16().unwrap_or(0).to_ne_bytes().to_vec() + } + } + 4 => { + if let Some(v) = i.to_u32() { + v.to_ne_bytes().to_vec() + } else { + i.to_i32().unwrap_or(0).to_ne_bytes().to_vec() + } + } + 8 => { + if let Some(v) = i.to_u64() { + v.to_ne_bytes().to_vec() + } else { + i.to_i64().unwrap_or(0).to_ne_bytes().to_vec() + } + } _ => vec![0u8; size], } } - fn bytes_to_int(bytes: &[u8], size: usize, vm: &VirtualMachine) -> PyObjectRef { - match size { - 1 => vm.ctx.new_int(bytes[0] as i8).into(), - 2 => { + fn bytes_to_int( + bytes: &[u8], + size: usize, + type_code: Option<&str>, + vm: &VirtualMachine, + ) -> PyObjectRef { + // Unsigned type codes: B (uchar), H (ushort), I (uint), L (ulong), Q (ulonglong) + let is_unsigned = matches!( + type_code, + Some("B") | Some("H") | Some("I") | Some("L") | Some("Q") + ); + + match (size, is_unsigned) { + (1, false) => vm.ctx.new_int(bytes[0] as i8).into(), + (1, true) => vm.ctx.new_int(bytes[0]).into(), + (2, false) => { let val = i16::from_ne_bytes([bytes[0], bytes[1]]); vm.ctx.new_int(val).into() } - 4 => { + (2, true) => { + let val = u16::from_ne_bytes([bytes[0], bytes[1]]); + vm.ctx.new_int(val).into() + } + (4, false) => { let val = i32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); vm.ctx.new_int(val).into() } - 8 => { + (4, true) => { + let val = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + vm.ctx.new_int(val).into() + } + (8, false) => { let val = i64::from_ne_bytes([ bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], ]); vm.ctx.new_int(val).into() } + (8, true) => { 
+ let val = u64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]); + vm.ctx.new_int(val).into() + } _ => vm.ctx.new_int(0).into(), } } - fn getitem_by_index(zelf: &PyCArray, i: isize, vm: &VirtualMachine) -> PyResult { - let length = zelf.length.load() as isize; + fn getitem_by_index(zelf: &Py, i: isize, vm: &VirtualMachine) -> PyResult { + let stg = zelf.class().stg_info_opt(); + let length = stg.as_ref().map_or(0, |i| i.length) as isize; let index = if i < 0 { length + i } else { i }; if index < 0 || index >= length { - return Err(vm.new_index_error("array index out of range".to_owned())); + return Err(vm.new_index_error("invalid index")); } let index = index as usize; - let element_size = zelf.element_size.load(); + let element_size = stg.as_ref().map_or(0, |i| i.element_size); let offset = index * element_size; - let buffer = zelf.cdata.read().buffer.clone(); - if offset + element_size <= buffer.len() { - let bytes = &buffer[offset..offset + element_size]; - Ok(Self::bytes_to_int(bytes, element_size, vm)) + let type_code = Self::get_element_type_code(zelf, vm); + + // Get target buffer and offset (base's buffer if available, otherwise own) + let base_obj = zelf.0.base.read().clone(); + let (buffer_lock, final_offset) = if let Some(cdata) = base_obj + .as_ref() + .and_then(|b| b.downcast_ref::()) + { + (&cdata.buffer, zelf.0.base_offset.load() + offset) } else { - Ok(vm.ctx.new_int(0).into()) + (&zelf.0.buffer, offset) + }; + + let buffer = buffer_lock.read(); + Self::read_element_from_buffer( + &buffer, + final_offset, + element_size, + type_code.as_deref(), + vm, + ) + } + + /// Helper to read an element value from a buffer at given offset + fn read_element_from_buffer( + buffer: &[u8], + offset: usize, + element_size: usize, + type_code: Option<&str>, + vm: &VirtualMachine, + ) -> PyResult { + match type_code { + Some("c") => { + // Return single byte as bytes + if offset < buffer.len() { + 
Ok(vm.ctx.new_bytes(vec![buffer[offset]]).into()) + } else { + Ok(vm.ctx.new_bytes(vec![0]).into()) + } + } + Some("u") => { + // Return single wchar as str + if let Some(code) = wchar_from_bytes(&buffer[offset..]) { + let s = char::from_u32(code) + .map(|c| c.to_string()) + .unwrap_or_default(); + Ok(vm.ctx.new_str(s).into()) + } else { + Ok(vm.ctx.new_str("").into()) + } + } + Some("z") => { + // c_char_p: pointer to bytes - dereference to get string + if offset + element_size > buffer.len() { + return Ok(vm.ctx.none()); + } + let ptr_bytes = &buffer[offset..offset + element_size]; + let ptr_val = usize::from_ne_bytes( + ptr_bytes + .try_into() + .unwrap_or([0; std::mem::size_of::()]), + ); + if ptr_val == 0 { + return Ok(vm.ctx.none()); + } + // Read null-terminated string from pointer address + unsafe { + let ptr = ptr_val as *const u8; + let mut len = 0; + while *ptr.add(len) != 0 { + len += 1; + } + let bytes = std::slice::from_raw_parts(ptr, len); + Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) + } + } + Some("Z") => { + // c_wchar_p: pointer to wchar_t - dereference to get string + if offset + element_size > buffer.len() { + return Ok(vm.ctx.none()); + } + let ptr_bytes = &buffer[offset..offset + element_size]; + let ptr_val = usize::from_ne_bytes( + ptr_bytes + .try_into() + .unwrap_or([0; std::mem::size_of::()]), + ); + if ptr_val == 0 { + return Ok(vm.ctx.none()); + } + // Read null-terminated wide string using WCHAR_SIZE + unsafe { + let ptr = ptr_val as *const u8; + let mut chars = Vec::new(); + let mut pos = 0usize; + loop { + let code = if WCHAR_SIZE == 2 { + let bytes = std::slice::from_raw_parts(ptr.add(pos), 2); + u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 + } else { + let bytes = std::slice::from_raw_parts(ptr.add(pos), 4); + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + }; + if code == 0 { + break; + } + if let Some(ch) = char::from_u32(code) { + chars.push(ch); + } + pos += WCHAR_SIZE; + } + let s: String = 
chars.into_iter().collect(); + Ok(vm.ctx.new_str(s).into()) + } + } + Some("f") => { + // c_float + if offset + 4 <= buffer.len() { + let bytes: [u8; 4] = buffer[offset..offset + 4].try_into().unwrap(); + let val = f32::from_ne_bytes(bytes); + Ok(vm.ctx.new_float(val as f64).into()) + } else { + Ok(vm.ctx.new_float(0.0).into()) + } + } + Some("d") | Some("g") => { + // c_double / c_longdouble - read f64 from first 8 bytes + if offset + 8 <= buffer.len() { + let bytes: [u8; 8] = buffer[offset..offset + 8].try_into().unwrap(); + let val = f64::from_ne_bytes(bytes); + Ok(vm.ctx.new_float(val).into()) + } else { + Ok(vm.ctx.new_float(0.0).into()) + } + } + _ => { + if offset + element_size <= buffer.len() { + let bytes = &buffer[offset..offset + element_size]; + Ok(Self::bytes_to_int(bytes, element_size, type_code, vm)) + } else { + Ok(vm.ctx.new_int(0).into()) + } + } } } + /// Helper to write an element value to a buffer at given offset + /// This is extracted to share code between direct write and base-buffer write + #[allow(clippy::too_many_arguments)] + fn write_element_to_buffer( + buffer: &mut [u8], + offset: usize, + element_size: usize, + type_code: Option<&str>, + value: &PyObject, + zelf: &Py, + index: usize, + vm: &VirtualMachine, + ) -> PyResult<()> { + match type_code { + Some("c") => { + if let Some(b) = value.downcast_ref::() { + if offset < buffer.len() { + buffer[offset] = b.as_bytes().first().copied().unwrap_or(0); + } + } else if let Ok(int_val) = value.try_int(vm) { + if offset < buffer.len() { + buffer[offset] = int_val.as_bigint().to_u8().unwrap_or(0); + } + } else { + return Err(vm.new_type_error("an integer or bytes of length 1 is required")); + } + } + Some("u") => { + if let Some(s) = value.downcast_ref::() { + let code = s.as_str().chars().next().map(|c| c as u32).unwrap_or(0); + if offset + WCHAR_SIZE <= buffer.len() { + wchar_to_bytes(code, &mut buffer[offset..]); + } + } else { + return Err(vm.new_type_error("unicode string expected")); + 
} + } + Some("z") => { + let (ptr_val, converted) = if value.is(&vm.ctx.none) { + (0usize, None) + } else if let Some(bytes) = value.downcast_ref::() { + let (c, ptr) = super::base::ensure_z_null_terminated(bytes, vm); + (ptr, Some(c)) + } else if let Ok(int_val) = value.try_index(vm) { + (int_val.as_bigint().to_usize().unwrap_or(0), None) + } else { + return Err(vm.new_type_error( + "bytes or integer address expected instead of {}".to_owned(), + )); + }; + if offset + element_size <= buffer.len() { + buffer[offset..offset + element_size].copy_from_slice(&ptr_val.to_ne_bytes()); + } + if let Some(c) = converted { + return zelf.0.keep_ref(index, c, vm); + } + } + Some("Z") => { + let (ptr_val, converted) = if value.is(&vm.ctx.none) { + (0usize, None) + } else if let Some(s) = value.downcast_ref::() { + let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_str(), vm); + (ptr, Some(holder)) + } else if let Ok(int_val) = value.try_index(vm) { + (int_val.as_bigint().to_usize().unwrap_or(0), None) + } else { + return Err(vm.new_type_error("unicode string or integer address expected")); + }; + if offset + element_size <= buffer.len() { + buffer[offset..offset + element_size].copy_from_slice(&ptr_val.to_ne_bytes()); + } + if let Some(c) = converted { + return zelf.0.keep_ref(index, c, vm); + } + } + Some("f") => { + // c_float: convert int/float to f32 bytes + let f32_val = if let Ok(float_val) = value.try_float(vm) { + float_val.to_f64() as f32 + } else if let Ok(int_val) = value.try_int(vm) { + int_val.as_bigint().to_f64().unwrap_or(0.0) as f32 + } else { + return Err(vm.new_type_error("a float is required")); + }; + if offset + 4 <= buffer.len() { + buffer[offset..offset + 4].copy_from_slice(&f32_val.to_ne_bytes()); + } + } + Some("d") | Some("g") => { + // c_double / c_longdouble: convert int/float to f64 bytes + let f64_val = if let Ok(float_val) = value.try_float(vm) { + float_val.to_f64() + } else if let Ok(int_val) = value.try_int(vm) { + 
int_val.as_bigint().to_f64().unwrap_or(0.0) + } else { + return Err(vm.new_type_error("a float is required")); + }; + if offset + 8 <= buffer.len() { + buffer[offset..offset + 8].copy_from_slice(&f64_val.to_ne_bytes()); + } + // For "g" type, remaining bytes stay zero + } + _ => { + // Handle ctypes instances (copy their buffer) + if let Some(cdata) = value.downcast_ref::() { + let src_buffer = cdata.buffer.read(); + let copy_len = src_buffer.len().min(element_size); + if offset + copy_len <= buffer.len() { + buffer[offset..offset + copy_len].copy_from_slice(&src_buffer[..copy_len]); + } + // Other types: use int_to_bytes + } else if let Ok(int_val) = value.try_int(vm) { + let bytes = Self::int_to_bytes(int_val.as_bigint(), element_size); + if offset + element_size <= buffer.len() { + buffer[offset..offset + element_size].copy_from_slice(&bytes); + } + } else { + return Err(vm.new_type_error(format!( + "expected {} instance, not {}", + type_code.unwrap_or("value"), + value.class().name() + ))); + } + } + } + + // KeepRef + if super::base::PyCData::should_keep_ref(value) { + let to_keep = super::base::PyCData::get_kept_objects(value, vm); + zelf.0.keep_ref(index, to_keep, vm)?; + } + + Ok(()) + } + fn setitem_by_index( - zelf: &PyCArray, + zelf: &Py, i: isize, value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - let length = zelf.length.load() as isize; + let stg = zelf.class().stg_info_opt(); + let length = stg.as_ref().map_or(0, |i| i.length) as isize; let index = if i < 0 { length + i } else { i }; if index < 0 || index >= length { - return Err(vm.new_index_error("array index out of range".to_owned())); + return Err(vm.new_index_error("invalid index")); } let index = index as usize; - let element_size = zelf.element_size.load(); + let element_size = stg.as_ref().map_or(0, |i| i.element_size); let offset = index * element_size; + let type_code = Self::get_element_type_code(zelf, vm); + + // Get target buffer and offset (base's buffer if available, 
otherwise own) + let base_obj = zelf.0.base.read().clone(); + let (buffer_lock, final_offset) = if let Some(cdata) = base_obj + .as_ref() + .and_then(|b| b.downcast_ref::()) + { + (&cdata.buffer, zelf.0.base_offset.load() + offset) + } else { + (&zelf.0.buffer, offset) + }; - let int_val = value.try_int(vm)?; - let bytes = Self::int_to_bytes(int_val.as_bigint(), element_size); - - let mut cdata = zelf.cdata.write(); - if offset + element_size <= cdata.buffer.len() { - cdata.buffer[offset..offset + element_size].copy_from_slice(&bytes); - } - Ok(()) + let mut buffer = buffer_lock.write(); + Self::write_element_to_buffer( + buffer.to_mut(), + final_offset, + element_size, + type_code.as_deref(), + &value, + zelf, + index, + vm, + ) } + // Array_subscript #[pymethod] - fn __getitem__(&self, index: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if let Some(i) = index.downcast_ref::() { + fn __getitem__(zelf: &Py, item: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // PyIndex_Check + if let Some(i) = item.downcast_ref::() { let i = i.as_bigint().to_isize().ok_or_else(|| { - vm.new_index_error("cannot fit index into an index-sized integer".to_owned()) + vm.new_index_error("cannot fit index into an index-sized integer") })?; - Self::getitem_by_index(self, i, vm) + // getitem_by_index handles negative index normalization + Self::getitem_by_index(zelf, i, vm) + } + // PySlice_Check + else if let Some(slice) = item.downcast_ref::() { + Self::getitem_by_slice(zelf, slice, vm) } else { - Err(vm.new_type_error("array indices must be integers".to_owned())) + Err(vm.new_type_error("indices must be integers")) + } + } + + // Array_subscript slice handling + fn getitem_by_slice(zelf: &Py, slice: &PySlice, vm: &VirtualMachine) -> PyResult { + use crate::sliceable::SaturatedSliceIter; + + let stg = zelf.class().stg_info_opt(); + let length = stg.as_ref().map_or(0, |i| i.length); + + // PySlice_Unpack + PySlice_AdjustIndices + let sat_slice = slice.to_saturated(vm)?; + let 
(range, step, slice_len) = sat_slice.adjust_indices(length); + + let type_code = Self::get_element_type_code(zelf, vm); + let element_size = stg.as_ref().map_or(0, |i| i.element_size); + let start = range.start; + + match type_code.as_deref() { + // c_char → bytes (item_info->getfunc == "c") + Some("c") => { + if slice_len == 0 { + return Ok(vm.ctx.new_bytes(vec![]).into()); + } + let buffer = zelf.0.buffer.read(); + // step == 1 optimization: direct memcpy + if step == 1 { + let start_offset = start * element_size; + let end_offset = start_offset + slice_len; + if end_offset <= buffer.len() { + return Ok(vm + .ctx + .new_bytes(buffer[start_offset..end_offset].to_vec()) + .into()); + } + } + // Non-contiguous: iterate + let iter = SaturatedSliceIter::from_adjust_indices(range, step, slice_len); + let mut result = Vec::with_capacity(slice_len); + for idx in iter { + let offset = idx * element_size; + if offset < buffer.len() { + result.push(buffer[offset]); + } + } + Ok(vm.ctx.new_bytes(result).into()) + } + // c_wchar → str (item_info->getfunc == "u") + Some("u") => { + if slice_len == 0 { + return Ok(vm.ctx.new_str("").into()); + } + let buffer = zelf.0.buffer.read(); + // step == 1 optimization: direct conversion + if step == 1 { + let start_offset = start * WCHAR_SIZE; + let end_offset = start_offset + slice_len * WCHAR_SIZE; + if end_offset <= buffer.len() { + let wchar_bytes = &buffer[start_offset..end_offset]; + let result: String = wchar_bytes + .chunks(WCHAR_SIZE) + .filter_map(|chunk| wchar_from_bytes(chunk).and_then(char::from_u32)) + .collect(); + return Ok(vm.ctx.new_str(result).into()); + } + } + // Non-contiguous: iterate + let iter = SaturatedSliceIter::from_adjust_indices(range, step, slice_len); + let mut result = String::with_capacity(slice_len); + for idx in iter { + let offset = idx * WCHAR_SIZE; + if let Some(code_point) = wchar_from_bytes(&buffer[offset..]) + && let Some(c) = char::from_u32(code_point) + { + result.push(c); + } + } + 
Ok(vm.ctx.new_str(result).into()) + } + // Other types → list (PyList_New + Array_item for each) + _ => { + let iter = SaturatedSliceIter::from_adjust_indices(range, step, slice_len); + let mut result = Vec::with_capacity(slice_len); + for idx in iter { + result.push(Self::getitem_by_index(zelf, idx as isize, vm)?); + } + Ok(PyList::from(result).into_ref(&vm.ctx).into()) + } } } + // Array_ass_subscript #[pymethod] fn __setitem__( - &self, - index: PyObjectRef, + zelf: &Py, + item: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - if let Some(i) = index.downcast_ref::() { + // Array does not support item deletion + // (handled implicitly - value is always provided in __setitem__) + + // PyIndex_Check + if let Some(i) = item.downcast_ref::() { let i = i.as_bigint().to_isize().ok_or_else(|| { - vm.new_index_error("cannot fit index into an index-sized integer".to_owned()) + vm.new_index_error("cannot fit index into an index-sized integer") })?; - Self::setitem_by_index(self, i, value, vm) + // setitem_by_index handles negative index normalization + Self::setitem_by_index(zelf, i, value, vm) + } + // PySlice_Check + else if let Some(slice) = item.downcast_ref::() { + Self::setitem_by_slice(zelf, slice, value, vm) } else { - Err(vm.new_type_error("array indices must be integers".to_owned())) + Err(vm.new_type_error("indices must be integer")) } } + // Array does not support item deletion #[pymethod] - fn __len__(&self) -> usize { - self.length.load() + fn __delitem__(&self, _item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + Err(vm.new_type_error("Array does not support item deletion")) } - #[pygetset(name = "_type_")] - fn typ(&self) -> PyObjectRef { - self.typ.read().clone() - } + // Array_ass_subscript slice handling + fn setitem_by_slice( + zelf: &Py, + slice: &PySlice, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + use crate::sliceable::SaturatedSliceIter; - #[pygetset(name = "_length_")] - fn 
length_getter(&self) -> usize { - self.length.load() - } + let length = zelf.class().stg_info_opt().map_or(0, |i| i.length); - #[pygetset] - fn value(&self, vm: &VirtualMachine) -> PyObjectRef { - // Return bytes representation of the buffer - let buffer = self.cdata.read().buffer.clone(); - vm.ctx.new_bytes(buffer.clone()).into() - } + // PySlice_Unpack + PySlice_AdjustIndices + let sat_slice = slice.to_saturated(vm)?; + let (range, step, slice_len) = sat_slice.adjust_indices(length); - #[pygetset(setter)] - fn set_value(&self, value: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> { - if let Some(bytes) = value.downcast_ref::() { - let mut cdata = self.cdata.write(); - let src = bytes.as_bytes(); - let len = std::cmp::min(src.len(), cdata.buffer.len()); - cdata.buffer[..len].copy_from_slice(&src[..len]); + // other_len = PySequence_Length(value); + let items: Vec = vm.extract_elements_with(&value, Ok)?; + let other_len = items.len(); + + if other_len != slice_len { + return Err(vm.new_value_error("Can only assign sequence of same size")); } - Ok(()) - } - #[pygetset] - fn raw(&self, vm: &VirtualMachine) -> PyObjectRef { - let cdata = self.cdata.read(); - vm.ctx.new_bytes(cdata.buffer.clone()).into() - } + // Use SaturatedSliceIter for correct index iteration (handles negative step) + let iter = SaturatedSliceIter::from_adjust_indices(range, step, slice_len); - #[pygetset(setter)] - fn set_raw(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let Some(bytes) = value.downcast_ref::() { - let mut cdata = self.cdata.write(); - let src = bytes.as_bytes(); - let len = std::cmp::min(src.len(), cdata.buffer.len()); - cdata.buffer[..len].copy_from_slice(&src[..len]); - Ok(()) - } else { - Err(vm.new_type_error("expected bytes".to_owned())) + for (idx, item) in iter.zip(items) { + Self::setitem_by_index(zelf, idx as isize, item, vm)?; } + Ok(()) } - #[pyclassmethod] - fn from_address(cls: PyTypeRef, address: isize, vm: &VirtualMachine) -> PyResult { 
- use crate::stdlib::ctypes::_ctypes::size_of; + #[pymethod] + fn __len__(zelf: &Py, _vm: &VirtualMachine) -> usize { + zelf.class().stg_info_opt().map_or(0, |i| i.length) + } +} - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; +impl PyCArray { + #[allow(unused)] + pub fn to_arg(&self, _vm: &VirtualMachine) -> PyResult { + let buffer = self.0.buffer.read(); + Ok(libffi::middle::Arg::new(&*buffer)) + } +} - // Create instance with data from address - if address == 0 || size == 0 { - return Err(vm.new_value_error("NULL pointer access".to_owned())); - } - unsafe { - let ptr = address as *const u8; - let bytes = std::slice::from_raw_parts(ptr, size); - // Get element type and length from cls - let element_type = cls.as_object().get_attr("_type_", vm)?; - let element_type: PyTypeRef = element_type - .downcast() - .map_err(|_| vm.new_type_error("_type_ must be a type".to_owned()))?; - let length = cls - .as_object() - .get_attr("_length_", vm)? - .try_int(vm)? - .as_bigint() - .to_usize() - .unwrap_or(0); - let element_size = if length > 0 { size / length } else { 0 }; - - let cdata = CDataObject::from_bytes(bytes.to_vec(), None); - Ok(PyCArray { - _base: PyCData::new(cdata.clone()), - typ: PyRwLock::new(element_type.into()), - length: AtomicCell::new(length), - element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(cdata), +impl AsBuffer for PyCArray { + fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + let buffer_len = zelf.0.buffer.read().len(); + + // Get format and shape from type's StgInfo + let stg_info = zelf + .class() + .stg_info_opt() + .expect("PyCArray type must have StgInfo"); + let format = stg_info.format.clone(); + let shape = stg_info.shape.clone(); + let element_size = stg_info.element_size; + + let desc = if let Some(fmt) = format + && !shape.is_empty() + { + // Build dim_desc from shape (C-contiguous: row-major order) + // stride[i] = product(shape[i+1:]) * itemsize + let mut dim_desc = 
Vec::with_capacity(shape.len()); + let mut stride = element_size as isize; + + // Calculate strides from innermost to outermost dimension + for &dim_size in shape.iter().rev() { + dim_desc.push((dim_size, stride, 0)); + stride *= dim_size as isize; } - .into_pyobject(vm)) - } - } + dim_desc.reverse(); + + BufferDescriptor { + len: buffer_len, + readonly: false, + itemsize: element_size, + format: std::borrow::Cow::Owned(fmt), + dim_desc, + } + } else { + // Fallback to simple buffer if no format/shape info + BufferDescriptor::simple(buffer_len, false) + }; - #[pyclassmethod] - fn from_buffer( - cls: PyTypeRef, - source: PyObjectRef, - offset: crate::function::OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use crate::TryFromObject; - use crate::protocol::PyBuffer; - use crate::stdlib::ctypes::_ctypes::size_of; + let buf = PyBuffer::new(zelf.to_owned().into(), desc, &CDATA_BUFFER_METHODS); + Ok(buf) + } +} - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; +// CharArray and WCharArray getsets - added dynamically via add_getset - // Get buffer from source - let buffer = PyBuffer::try_from_object(vm, source.clone())?; +// CharArray_get_value +fn char_array_get_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let zelf = obj.downcast_ref::().unwrap(); + let buffer = zelf.0.buffer.read(); + let len = buffer.iter().position(|&b| b == 0).unwrap_or(buffer.len()); + Ok(vm.ctx.new_bytes(buffer[..len].to_vec()).into()) +} - // Check if buffer is writable - if buffer.desc.readonly { - return Err(vm.new_type_error("underlying buffer is not writable".to_owned())); - } +// CharArray_set_value +fn char_array_set_value(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let zelf = obj.downcast_ref::().unwrap(); + let bytes = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("bytes expected"))?; + let mut buffer = 
zelf.0.buffer.write(); + let src = bytes.as_bytes(); + + if src.len() > buffer.len() { + return Err(vm.new_value_error("byte string too long")); + } - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; - - // Check if buffer is large enough - let buffer_len = buffer.desc.len; - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); - } + buffer.to_mut()[..src.len()].copy_from_slice(src); + if src.len() < buffer.len() { + buffer.to_mut()[src.len()] = 0; + } + Ok(()) +} - // Read bytes from buffer at offset - let bytes = buffer.obj_bytes(); - let data = &bytes[offset..offset + size]; +// CharArray_get_raw +fn char_array_get_raw(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let zelf = obj.downcast_ref::().unwrap(); + let buffer = zelf.0.buffer.read(); + Ok(vm.ctx.new_bytes(buffer.to_vec()).into()) +} - // Get element type and length from cls - let element_type = cls.as_object().get_attr("_type_", vm)?; - let element_type: PyTypeRef = element_type - .downcast() - .map_err(|_| vm.new_type_error("_type_ must be a type".to_owned()))?; - let length = cls - .as_object() - .get_attr("_length_", vm)? - .try_int(vm)? 
- .as_bigint() - .to_usize() - .unwrap_or(0); - let element_size = if length > 0 { size / length } else { 0 }; - - let cdata = CDataObject::from_bytes(data.to_vec(), Some(buffer.obj.clone())); - Ok(PyCArray { - _base: PyCData::new(cdata.clone()), - typ: PyRwLock::new(element_type.into()), - length: AtomicCell::new(length), - element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(cdata), - } - .into_pyobject(vm)) +// CharArray_set_raw +fn char_array_set_raw( + obj: PyObjectRef, + value: PySetterValue, + vm: &VirtualMachine, +) -> PyResult<()> { + let value = value.ok_or_else(|| vm.new_attribute_error("cannot delete attribute"))?; + let zelf = obj.downcast_ref::().unwrap(); + let bytes_like = ArgBytesLike::try_from_object(vm, value)?; + let mut buffer = zelf.0.buffer.write(); + let src = bytes_like.borrow_buf(); + if src.len() > buffer.len() { + return Err(vm.new_value_error("byte string too long")); } + buffer.to_mut()[..src.len()].copy_from_slice(&src); + Ok(()) +} - #[pyclassmethod] - fn from_buffer_copy( - cls: PyTypeRef, - source: crate::function::ArgBytesLike, - offset: crate::function::OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use crate::stdlib::ctypes::_ctypes::size_of; - - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; - - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; - - // Borrow bytes from source - let source_bytes = source.borrow_buf(); - let buffer_len = source_bytes.len(); - - // Check if buffer is large enough - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); - } +// WCharArray_get_value +fn wchar_array_get_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let zelf = obj.downcast_ref::().unwrap(); + let buffer = zelf.0.buffer.read(); + 
Ok(vm.ctx.new_str(wstring_from_bytes(&buffer)).into()) +} - // Copy bytes from buffer at offset - let data = &source_bytes[offset..offset + size]; +// WCharArray_set_value +fn wchar_array_set_value( + obj: PyObjectRef, + value: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult<()> { + let zelf = obj.downcast_ref::().unwrap(); + let s = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("unicode string expected"))?; + let mut buffer = zelf.0.buffer.write(); + let wchar_count = buffer.len() / WCHAR_SIZE; + let char_count = s.as_str().chars().count(); + + if char_count > wchar_count { + return Err(vm.new_value_error("string too long")); + } - // Get element type and length from cls - let element_type = cls.as_object().get_attr("_type_", vm)?; - let element_type: PyTypeRef = element_type - .downcast() - .map_err(|_| vm.new_type_error("_type_ must be a type".to_owned()))?; - let length = cls - .as_object() - .get_attr("_length_", vm)? - .try_int(vm)? - .as_bigint() - .to_usize() - .unwrap_or(0); - let element_size = if length > 0 { size / length } else { 0 }; - - let cdata = CDataObject::from_bytes(data.to_vec(), None); - Ok(PyCArray { - _base: PyCData::new(cdata.clone()), - typ: PyRwLock::new(element_type.into()), - length: AtomicCell::new(length), - element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(cdata), - } - .into_pyobject(vm)) + for (i, ch) in s.as_str().chars().enumerate() { + let offset = i * WCHAR_SIZE; + wchar_to_bytes(ch as u32, &mut buffer.to_mut()[offset..]); } - #[pyclassmethod] - fn in_dll( - cls: PyTypeRef, - dll: PyObjectRef, - name: crate::builtins::PyStrRef, - vm: &VirtualMachine, - ) -> PyResult { - use crate::stdlib::ctypes::_ctypes::size_of; - use libloading::Symbol; - - // Get the library handle from dll object - let handle = if let Ok(int_handle) = dll.try_int(vm) { - // dll is an integer handle - int_handle - .as_bigint() - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? 
- } else { - // dll is a CDLL/PyDLL/WinDLL object with _handle attribute - dll.get_attr("_handle", vm)? - .try_int(vm)? - .as_bigint() - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? - }; + let terminator_offset = char_count * WCHAR_SIZE; + if terminator_offset + WCHAR_SIZE <= buffer.len() { + wchar_to_bytes(0, &mut buffer.to_mut()[terminator_offset..]); + } + Ok(()) +} - // Get the library from cache - let library_cache = crate::stdlib::ctypes::library::libcache().read(); - let library = library_cache - .get_lib(handle) - .ok_or_else(|| vm.new_attribute_error("Library not found".to_owned()))?; - - // Get symbol address from library - let symbol_name = format!("{}\0", name.as_str()); - let inner_lib = library.lib.lock(); - - let symbol_address = if let Some(lib) = &*inner_lib { - unsafe { - // Try to get the symbol from the library - let symbol: Symbol<'_, *mut u8> = lib.get(symbol_name.as_bytes()).map_err(|e| { - vm.new_attribute_error(format!("{}: symbol '{}' not found", e, name.as_str())) - })?; - *symbol as usize - } - } else { - return Err(vm.new_attribute_error("Library is closed".to_owned())); - }; +/// add_getset for c_char arrays - adds 'value' and 'raw' attributes +/// add_getset((PyTypeObject*)self, CharArray_getsets) +fn add_char_array_getsets(array_type: &Py, vm: &VirtualMachine) { + // SAFETY: getset is owned by array_type which outlives the getset + let value_getset = unsafe { + vm.ctx.new_getset( + "value", + array_type, + char_array_get_value, + char_array_set_value, + ) + }; + let raw_getset = unsafe { + vm.ctx + .new_getset("raw", array_type, char_array_get_raw, char_array_set_raw) + }; + + array_type + .attributes + .write() + .insert(vm.ctx.intern_str("value"), value_getset.into()); + array_type + .attributes + .write() + .insert(vm.ctx.intern_str("raw"), raw_getset.into()); +} - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; +/// add_getset for c_wchar arrays - adds only 'value' 
attribute (no 'raw') +fn add_wchar_array_getsets(array_type: &Py, vm: &VirtualMachine) { + // SAFETY: getset is owned by array_type which outlives the getset + let value_getset = unsafe { + vm.ctx.new_getset( + "value", + array_type, + wchar_array_get_value, + wchar_array_set_value, + ) + }; - // Read data from symbol address - let data = if symbol_address != 0 && size > 0 { - unsafe { - let ptr = symbol_address as *const u8; - std::slice::from_raw_parts(ptr, size).to_vec() - } - } else { - vec![0; size] - }; + array_type + .attributes + .write() + .insert(vm.ctx.intern_str("value"), value_getset.into()); +} - // Get element type and length from cls - let element_type = cls.as_object().get_attr("_type_", vm)?; - let element_type: PyTypeRef = element_type - .downcast() - .map_err(|_| vm.new_type_error("_type_ must be a type".to_owned()))?; - let length = cls - .as_object() - .get_attr("_length_", vm)? - .try_int(vm)? - .as_bigint() - .to_usize() - .unwrap_or(0); - let element_size = if length > 0 { size / length } else { 0 }; - - // Create instance - let cdata = CDataObject::from_bytes(data, None); - let instance = PyCArray { - _base: PyCData::new(cdata.clone()), - typ: PyRwLock::new(element_type.into()), - length: AtomicCell::new(length), - element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(cdata), - } - .into_pyobject(vm); +// wchar_t helpers - Platform-independent wide character handling +// Windows: sizeof(wchar_t) == 2 (UTF-16) +// Linux/macOS: sizeof(wchar_t) == 4 (UTF-32) - // Store base reference to keep dll alive - if let Ok(array_ref) = instance.clone().downcast::() { - array_ref.cdata.write().base = Some(dll); - } +/// Size of wchar_t on this platform +pub(super) const WCHAR_SIZE: usize = std::mem::size_of::(); - Ok(instance) +/// Read a single wchar_t from bytes (platform-endian) +#[inline] +pub(super) fn wchar_from_bytes(bytes: &[u8]) -> Option { + if bytes.len() < WCHAR_SIZE { + return None; } + Some(if WCHAR_SIZE == 2 { + 
u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 + } else { + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + }) } -impl PyCArray { - #[allow(unused)] - pub fn to_arg(&self, _vm: &VirtualMachine) -> PyResult { - let cdata = self.cdata.read(); - Ok(libffi::middle::Arg::new(&cdata.buffer)) +/// Write a single wchar_t to bytes (platform-endian) +#[inline] +pub(super) fn wchar_to_bytes(ch: u32, buffer: &mut [u8]) { + if WCHAR_SIZE == 2 { + if buffer.len() >= 2 { + buffer[..2].copy_from_slice(&(ch as u16).to_ne_bytes()); + } + } else if buffer.len() >= 4 { + buffer[..4].copy_from_slice(&ch.to_ne_bytes()); } } -static ARRAY_BUFFER_METHODS: BufferMethods = BufferMethods { - obj_bytes: |buffer| { - rustpython_common::lock::PyMappedRwLockReadGuard::map( - rustpython_common::lock::PyRwLockReadGuard::map( - buffer.obj_as::().cdata.read(), - |x: &CDataObject| x, - ), - |x: &CDataObject| x.buffer.as_slice(), - ) - .into() - }, - obj_bytes_mut: |buffer| { - rustpython_common::lock::PyMappedRwLockWriteGuard::map( - rustpython_common::lock::PyRwLockWriteGuard::map( - buffer.obj_as::().cdata.write(), - |x: &mut CDataObject| x, - ), - |x: &mut CDataObject| x.buffer.as_mut_slice(), - ) - .into() - }, - release: |_| {}, - retain: |_| {}, -}; - -impl AsBuffer for PyCArray { - fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { - let buffer_len = zelf.cdata.read().buffer.len(); - let buf = PyBuffer::new( - zelf.to_owned().into(), - BufferDescriptor::simple(buffer_len, false), // readonly=false for ctypes - &ARRAY_BUFFER_METHODS, - ); - Ok(buf) +/// Read a null-terminated wchar_t string from bytes, returns String +fn wstring_from_bytes(buffer: &[u8]) -> String { + let mut chars = Vec::new(); + for chunk in buffer.chunks(WCHAR_SIZE) { + if chunk.len() < WCHAR_SIZE { + break; + } + let code = if WCHAR_SIZE == 2 { + u16::from_ne_bytes([chunk[0], chunk[1]]) as u32 + } else { + u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) + }; + if code == 0 { + break; // 
null terminator + } + if let Some(ch) = char::from_u32(code) { + chars.push(ch); + } } + chars.into_iter().collect() } diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index e45ff0b3b70..38c371346e0 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -1,871 +1,1113 @@ -use super::_ctypes::bytes_to_pyobject; -use super::util::StgInfo; -use crate::builtins::{PyBytes, PyFloat, PyInt, PyNone, PyStr, PyStrRef, PyType, PyTypeRef}; -use crate::function::{ArgBytesLike, Either, FuncArgs, KwArgs, OptionalArg}; -use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods}; -use crate::stdlib::ctypes::_ctypes::new_simple_type; -use crate::types::{AsBuffer, AsNumber, Constructor}; +use super::array::{WCHAR_SIZE, wchar_from_bytes, wchar_to_bytes}; +use crate::builtins::{PyBytes, PyDict, PyMemoryView, PyStr, PyType, PyTypeRef}; +use crate::class::StaticType; +use crate::function::{ArgBytesLike, OptionalArg, PySetterValue}; +use crate::protocol::{BufferMethods, PyBuffer}; +use crate::types::{GetDescriptor, Representable}; use crate::{ - AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; -use num_traits::ToPrimitive; +use num_traits::{Signed, ToPrimitive}; use rustpython_common::lock::PyRwLock; -use std::ffi::{c_uint, c_ulong, c_ulonglong, c_ushort}; +use std::borrow::Cow; +use std::ffi::{ + c_double, c_float, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort, +}; use std::fmt::Debug; +use std::mem; +use widestring::WideChar; + +// StgInfo - Storage information for ctypes types +// Stored in TypeDataSlot of heap types (PyType::init_type_data/get_type_data) + +// Flag constants +bitflags::bitflags! 
{ + #[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] + pub struct StgInfoFlags: u32 { + // Function calling convention flags + /// Standard call convention (Windows) + const FUNCFLAG_STDCALL = 0x0; + /// C calling convention + const FUNCFLAG_CDECL = 0x1; + /// Function returns HRESULT + const FUNCFLAG_HRESULT = 0x2; + /// Use Python API calling convention + const FUNCFLAG_PYTHONAPI = 0x4; + /// Capture errno after call + const FUNCFLAG_USE_ERRNO = 0x8; + /// Capture last error after call (Windows) + const FUNCFLAG_USE_LASTERROR = 0x10; + + // Type flags + /// Type is a pointer type + const TYPEFLAG_ISPOINTER = 0x100; + /// Type contains pointer fields + const TYPEFLAG_HASPOINTER = 0x200; + /// Type is or contains a union + const TYPEFLAG_HASUNION = 0x400; + /// Type contains bitfield members + const TYPEFLAG_HASBITFIELD = 0x800; + + // Dict flags + /// Type is finalized (_fields_ has been set) + const DICTFLAG_FINAL = 0x1000; + } +} -/// Get the type code string from a ctypes type (e.g., "i" for c_int) -pub fn get_type_code(cls: &PyTypeRef, vm: &VirtualMachine) -> Option { - cls.as_object() - .get_attr("_type_", vm) - .ok() - .and_then(|t| t.downcast_ref::().map(|s| s.to_string())) +/// ParamFunc - determines how a type is passed to foreign functions +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub(super) enum ParamFunc { + #[default] + None, + /// Array types are passed as pointers (tag = 'P') + Array, + /// Simple types use their specific conversion (tag = type code) + Simple, + /// Pointer types (tag = 'P') + Pointer, + /// Structure types (tag = 'V' for value) + Structure, + /// Union types (tag = 'V' for value) + Union, +} + +#[derive(Clone)] +pub struct StgInfo { + pub initialized: bool, + pub size: usize, // number of bytes + pub align: usize, // alignment requirements + pub length: usize, // number of fields (for arrays/structures) + pub proto: Option, // Only for Pointer/ArrayObject + pub flags: StgInfoFlags, // type flags (TYPEFLAG_*, 
DICTFLAG_*) + + // Array-specific fields + pub element_type: Option, // _type_ for arrays + pub element_size: usize, // size of each element + + // PEP 3118 buffer protocol fields + pub format: Option, // struct format string (e.g., "i", "(5)i") + pub shape: Vec, // shape for multi-dimensional arrays + + // Function parameter conversion + pub(super) paramfunc: ParamFunc, // how to pass to foreign functions + + // Byte order (for _swappedbytes_) + pub big_endian: bool, // true if big endian, false if little endian + + // FFI field types for structure/union passing (inherited from base class) + pub ffi_field_types: Vec, } -pub fn ffi_type_from_str(_type_: &str) -> Option { - match _type_ { - "c" => Some(libffi::middle::Type::u8()), - "u" => Some(libffi::middle::Type::u32()), - "b" => Some(libffi::middle::Type::i8()), - "B" => Some(libffi::middle::Type::u8()), - "h" => Some(libffi::middle::Type::i16()), - "H" => Some(libffi::middle::Type::u16()), - "i" => Some(libffi::middle::Type::i32()), - "I" => Some(libffi::middle::Type::u32()), - "l" => Some(libffi::middle::Type::i32()), - "L" => Some(libffi::middle::Type::u32()), - "q" => Some(libffi::middle::Type::i64()), - "Q" => Some(libffi::middle::Type::u64()), - "f" => Some(libffi::middle::Type::f32()), - "d" => Some(libffi::middle::Type::f64()), - "g" => Some(libffi::middle::Type::f64()), - "?" => Some(libffi::middle::Type::u8()), - "z" => Some(libffi::middle::Type::u64()), - "Z" => Some(libffi::middle::Type::u64()), - "P" => Some(libffi::middle::Type::u64()), - _ => None, +// StgInfo is stored in type_data which requires Send + Sync. +// The PyTypeRef in proto/element_type fields is protected by the type system's locking mechanism. +// ctypes objects are not thread-safe by design; users must synchronize access. 
+unsafe impl Send for StgInfo {} +unsafe impl Sync for StgInfo {} + +impl std::fmt::Debug for StgInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StgInfo") + .field("initialized", &self.initialized) + .field("size", &self.size) + .field("align", &self.align) + .field("length", &self.length) + .field("proto", &self.proto) + .field("flags", &self.flags) + .field("element_type", &self.element_type) + .field("element_size", &self.element_size) + .field("format", &self.format) + .field("shape", &self.shape) + .field("paramfunc", &self.paramfunc) + .field("big_endian", &self.big_endian) + .field("ffi_field_types", &self.ffi_field_types.len()) + .finish() } } -#[allow(dead_code)] -fn set_primitive(_type_: &str, value: &PyObject, vm: &VirtualMachine) -> PyResult { - match _type_ { - "c" => { - if value - .to_owned() - .downcast_exact::(vm) - .is_ok_and(|v| v.len() == 1) - || value - .to_owned() - .downcast_exact::(vm) - .is_ok_and(|v| v.len() == 1) - || value - .to_owned() - .downcast_exact::(vm) - .map_or(Ok(false), |v| { - let n = v.as_bigint().to_i64(); - if let Some(n) = n { - Ok((0..=255).contains(&n)) - } else { - Ok(false) - } - })? 
- { - Ok(value.to_owned()) - } else { - Err(vm.new_type_error("one character bytes, bytearray or integer expected")) - } +impl Default for StgInfo { + fn default() -> Self { + StgInfo { + initialized: false, + size: 0, + align: 1, + length: 0, + proto: None, + flags: StgInfoFlags::empty(), + element_type: None, + element_size: 0, + format: None, + shape: Vec::new(), + paramfunc: ParamFunc::None, + big_endian: cfg!(target_endian = "big"), // native endian by default + ffi_field_types: Vec::new(), } - "u" => { - if let Ok(b) = value.str(vm).map(|v| v.to_string().chars().count() == 1) { - if b { - Ok(value.to_owned()) + } +} + +impl StgInfo { + pub fn new(size: usize, align: usize) -> Self { + StgInfo { + initialized: true, + size, + align, + length: 0, + proto: None, + flags: StgInfoFlags::empty(), + element_type: None, + element_size: 0, + format: None, + shape: Vec::new(), + paramfunc: ParamFunc::None, + big_endian: cfg!(target_endian = "big"), // native endian by default + ffi_field_types: Vec::new(), + } + } + + /// Create StgInfo for an array type + /// item_format: the innermost element's format string (kept as-is, e.g., ", + item_shape: &[usize], + item_flags: StgInfoFlags, + ) -> Self { + // Format is kept from innermost element (e.g., "flags & (TYPEFLAG_ISPOINTER | TYPEFLAG_HASPOINTER)) + // stginfo->flags |= TYPEFLAG_HASPOINTER; + let flags = if item_flags + .intersects(StgInfoFlags::TYPEFLAG_ISPOINTER | StgInfoFlags::TYPEFLAG_HASPOINTER) + { + StgInfoFlags::TYPEFLAG_HASPOINTER + } else { + StgInfoFlags::empty() + }; + + StgInfo { + initialized: true, + size, + align, + length, + proto: None, + flags, + element_type: Some(element_type), + element_size, + format, + shape, + paramfunc: ParamFunc::Array, + big_endian: cfg!(target_endian = "big"), // native endian by default + ffi_field_types: Vec::new(), + } + } + + /// Get libffi type for this StgInfo + /// Note: For very large types, returns pointer type to avoid overflow + pub fn to_ffi_type(&self) -> 
libffi::middle::Type { + // Limit to avoid overflow in libffi (MAX_STRUCT_SIZE is platform-dependent) + const MAX_FFI_STRUCT_SIZE: usize = 1024 * 1024; // 1MB limit for safety + + match self.paramfunc { + ParamFunc::Structure | ParamFunc::Union => { + if !self.ffi_field_types.is_empty() { + libffi::middle::Type::structure(self.ffi_field_types.iter().cloned()) + } else if self.size <= MAX_FFI_STRUCT_SIZE { + // Small struct without field types: use bytes array + libffi::middle::Type::structure(std::iter::repeat_n( + libffi::middle::Type::u8(), + self.size, + )) } else { - Err(vm.new_type_error("one character unicode string expected")) + // Large struct: treat as pointer (passed by reference) + libffi::middle::Type::pointer() } - } else { - Err(vm.new_type_error(format!( - "unicode string expected instead of {} instance", - value.class().name() - ))) - } - } - "b" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { - if value.to_owned().downcast_exact::(vm).is_ok() { - Ok(value.to_owned()) - } else { - Err(vm.new_type_error(format!( - "an integer is required (got type {})", - value.class().name() - ))) } - } - "f" | "d" | "g" => { - // float allows int - if value.to_owned().downcast_exact::(vm).is_ok() - || value.to_owned().downcast_exact::(vm).is_ok() - { - Ok(value.to_owned()) - } else { - Err(vm.new_type_error(format!("must be real number, not {}", value.class().name()))) + ParamFunc::Array => { + if self.size > MAX_FFI_STRUCT_SIZE || self.length > MAX_FFI_STRUCT_SIZE { + // Large array: treat as pointer + libffi::middle::Type::pointer() + } else if let Some(ref fmt) = self.format { + let elem_type = Self::format_to_ffi_type(fmt); + libffi::middle::Type::structure(std::iter::repeat_n(elem_type, self.length)) + } else { + libffi::middle::Type::structure(std::iter::repeat_n( + libffi::middle::Type::u8(), + self.size, + )) + } } - } - "?" 
=> Ok(PyObjectRef::from( - vm.ctx.new_bool(value.to_owned().try_to_bool(vm)?), - )), - "B" => { - if value.to_owned().downcast_exact::(vm).is_ok() { - // Store as-is, conversion to unsigned happens in the getter - Ok(value.to_owned()) - } else { - Err(vm.new_type_error(format!("int expected instead of {}", value.class().name()))) + ParamFunc::Pointer => libffi::middle::Type::pointer(), + _ => { + // Simple type: derive from format + if let Some(ref fmt) = self.format { + Self::format_to_ffi_type(fmt) + } else { + libffi::middle::Type::u8() + } } } - "z" => { - if value.to_owned().downcast_exact::(vm).is_ok() - || value.to_owned().downcast_exact::(vm).is_ok() - { - Ok(value.to_owned()) - } else { - Err(vm.new_type_error(format!( - "bytes or integer address expected instead of {} instance", - value.class().name() - ))) - } + } + + /// Convert format string to libffi type + fn format_to_ffi_type(fmt: &str) -> libffi::middle::Type { + // Strip endian prefix if present + let code = fmt.trim_start_matches(['<', '>', '!', '@', '=']); + match code { + "b" => libffi::middle::Type::i8(), + "B" => libffi::middle::Type::u8(), + "h" => libffi::middle::Type::i16(), + "H" => libffi::middle::Type::u16(), + "i" | "l" => libffi::middle::Type::i32(), + "I" | "L" => libffi::middle::Type::u32(), + "q" => libffi::middle::Type::i64(), + "Q" => libffi::middle::Type::u64(), + "f" => libffi::middle::Type::f32(), + "d" => libffi::middle::Type::f64(), + "P" | "z" | "Z" | "O" => libffi::middle::Type::pointer(), + _ => libffi::middle::Type::u8(), // default } - "Z" => { - if value.to_owned().downcast_exact::(vm).is_ok() { - Ok(value.to_owned()) - } else { - Err(vm.new_type_error(format!( - "unicode string or integer address expected instead of {} instance", - value.class().name() - ))) - } + } + + /// Check if this type is finalized (cannot set _fields_ again) + pub fn is_final(&self) -> bool { + self.flags.contains(StgInfoFlags::DICTFLAG_FINAL) + } + + /// Get proto type reference (for 
Pointer/Array types) + pub fn proto(&self) -> &Py { + self.proto.as_deref().expect("type has proto") + } +} + +/// Get PEP3118 format string for a field type +/// Returns the format string considering byte order +pub(super) fn get_field_format( + field_type: &PyObject, + big_endian: bool, + vm: &VirtualMachine, +) -> String { + // 1. Check StgInfo for format + if let Some(type_obj) = field_type.downcast_ref::() + && let Some(stg_info) = type_obj.stg_info_opt() + && let Some(fmt) = &stg_info.format + { + // Handle endian prefix for simple types + if fmt.len() == 1 { + let endian_prefix = if big_endian { ">" } else { "<" }; + return format!("{}{}", endian_prefix, fmt); } - _ => { - // "P" - if value.to_owned().downcast_exact::(vm).is_ok() - || value.to_owned().downcast_exact::(vm).is_ok() - { - Ok(value.to_owned()) - } else { - Err(vm.new_type_error("cannot be converted to pointer")) - } + return fmt.clone(); + } + + // 2. Try to get _type_ attribute for simple types + if let Ok(type_attr) = field_type.get_attr("_type_", vm) + && let Some(type_str) = type_attr.downcast_ref::() + { + let s = type_str.as_str(); + if s.len() == 1 { + let endian_prefix = if big_endian { ">" } else { "<" }; + return format!("{}{}", endian_prefix, s); } + return s.to_string(); } + + // Default: single byte + "B".to_string() } -/// Common data object for all ctypes types -#[derive(Debug, Clone)] -pub struct CDataObject { - /// pointer to memory block (b_ptr + b_size) - pub buffer: Vec, +/// Compute byte order based on swapped flag +#[inline] +pub(super) fn is_big_endian(is_swapped: bool) -> bool { + if is_swapped { + !cfg!(target_endian = "big") + } else { + cfg!(target_endian = "big") + } +} + +/// Shared BufferMethods for all ctypes types (PyCArray, PyCSimple, PyCStructure, PyCUnion) +/// All these types are #[repr(transparent)] wrappers around PyCData +pub(super) static CDATA_BUFFER_METHODS: BufferMethods = BufferMethods { + obj_bytes: |buffer| { + 
rustpython_common::lock::PyRwLockReadGuard::map( + buffer.obj_as::().buffer.read(), + |x| &**x, + ) + .into() + }, + obj_bytes_mut: |buffer| { + rustpython_common::lock::PyRwLockWriteGuard::map( + buffer.obj_as::().buffer.write(), + |x| x.to_mut().as_mut_slice(), + ) + .into() + }, + release: |_| {}, + retain: |_| {}, +}; + +/// Convert Vec to Vec by reinterpreting the memory (same allocation). +fn vec_to_bytes(vec: Vec) -> Vec { + let len = vec.len() * std::mem::size_of::(); + let cap = vec.capacity() * std::mem::size_of::(); + let ptr = vec.as_ptr() as *mut u8; + std::mem::forget(vec); + unsafe { Vec::from_raw_parts(ptr, len, cap) } +} + +/// Ensure PyBytes is null-terminated. Returns (PyBytes to keep, pointer). +/// If already contains null, returns original. Otherwise creates new with null appended. +pub(super) fn ensure_z_null_terminated( + bytes: &PyBytes, + vm: &VirtualMachine, +) -> (PyObjectRef, usize) { + let data = bytes.as_bytes(); + if data.contains(&0) { + // Already has null, use original + let original: PyObjectRef = vm.ctx.new_bytes(data.to_vec()).into(); + (original, data.as_ptr() as usize) + } else { + // Create new with null appended + let mut buffer = data.to_vec(); + buffer.push(0); + let ptr = buffer.as_ptr() as usize; + let new_bytes: PyObjectRef = vm.ctx.new_bytes(buffer).into(); + (new_bytes, ptr) + } +} + +/// Convert str to null-terminated wchar_t buffer. Returns (PyBytes holder, pointer). 
+pub(super) fn str_to_wchar_bytes(s: &str, vm: &VirtualMachine) -> (PyObjectRef, usize) { + let wchars: Vec = s + .chars() + .map(|c| c as libc::wchar_t) + .chain(std::iter::once(0)) + .collect(); + let ptr = wchars.as_ptr() as usize; + let bytes = vec_to_bytes(wchars); + let holder: PyObjectRef = vm.ctx.new_bytes(bytes).into(); + (holder, ptr) +} + +/// PyCData - base type for all ctypes data types +#[pyclass(name = "_CData", module = "_ctypes")] +#[derive(Debug, PyPayload)] +pub struct PyCData { + /// Memory buffer - Owned (self-owned) or Borrowed (external reference) + /// + /// SAFETY: Borrowed variant's 'static lifetime is not actually static. + /// When created via from_address or from_base_obj, only valid for the lifetime of the source memory. + /// Same behavior as CPython's b_ptr (user responsibility, kept alive via b_base). + pub buffer: PyRwLock>, /// pointer to base object or None (b_base) - #[allow(dead_code)] - pub base: Option, + pub base: PyRwLock>, + /// byte offset within base's buffer (for field access) + pub base_offset: AtomicCell, /// index into base's b_objects list (b_index) - #[allow(dead_code)] - pub index: usize, + pub index: AtomicCell, /// dictionary of references we need to keep (b_objects) - pub objects: Option, + pub objects: PyRwLock>, + /// number of references we need (b_length) + pub length: AtomicCell, } -impl CDataObject { +impl PyCData { /// Create from StgInfo (PyCData_MallocBuffer pattern) pub fn from_stg_info(stg_info: &StgInfo) -> Self { - CDataObject { - buffer: vec![0u8; stg_info.size], - base: None, - index: 0, - objects: None, + PyCData { + buffer: PyRwLock::new(Cow::Owned(vec![0u8; stg_info.size])), + base: PyRwLock::new(None), + base_offset: AtomicCell::new(0), + index: AtomicCell::new(0), + objects: PyRwLock::new(None), + length: AtomicCell::new(stg_info.length), } } /// Create from existing bytes (copies data) pub fn from_bytes(data: Vec, objects: Option) -> Self { - CDataObject { - buffer: data, - base: None, - 
index: 0, - objects, + PyCData { + buffer: PyRwLock::new(Cow::Owned(data)), + base: PyRwLock::new(None), + base_offset: AtomicCell::new(0), + index: AtomicCell::new(0), + objects: PyRwLock::new(objects), + length: AtomicCell::new(0), } } - /// Create from base object (copies data from base's buffer at offset) - #[allow(dead_code)] - pub fn from_base( - base: PyObjectRef, - _offset: usize, - size: usize, - index: usize, + /// Create from bytes with specified length (for arrays) + pub fn from_bytes_with_length( + data: Vec, objects: Option, + length: usize, ) -> Self { - CDataObject { - buffer: vec![0u8; size], - base: Some(base), - index, - objects, + PyCData { + buffer: PyRwLock::new(Cow::Owned(data)), + base: PyRwLock::new(None), + base_offset: AtomicCell::new(0), + index: AtomicCell::new(0), + objects: PyRwLock::new(objects), + length: AtomicCell::new(length), } } - #[inline] - pub fn size(&self) -> usize { - self.buffer.len() + /// Create from external memory address + /// + /// # Safety + /// The returned slice's 'static lifetime is a lie. + /// Actually only valid for the lifetime of the memory pointed to by ptr. + /// PyCData_AtAddress + pub unsafe fn at_address(ptr: *const u8, size: usize) -> Self { + // = PyCData_AtAddress + // SAFETY: Caller must ensure ptr is valid for the lifetime of returned PyCData + let slice: &'static [u8] = unsafe { std::slice::from_raw_parts(ptr, size) }; + PyCData { + buffer: PyRwLock::new(Cow::Borrowed(slice)), + base: PyRwLock::new(None), + base_offset: AtomicCell::new(0), + index: AtomicCell::new(0), + objects: PyRwLock::new(None), + length: AtomicCell::new(0), + } } -} - -#[pyclass(name = "_CData", module = "_ctypes")] -#[derive(Debug, PyPayload)] -pub struct PyCData { - pub cdata: PyRwLock, -} -impl PyCData { - pub fn new(cdata: CDataObject) -> Self { - Self { - cdata: PyRwLock::new(cdata), + /// Create from base object with offset and data copy + /// + /// Similar to from_base_with_offset, but also stores a copy of the data. 
+ /// This is used for arrays where we need our own buffer for the buffer protocol, + /// but still maintain the base reference for KeepRef and tracking. + pub fn from_base_with_data( + base_obj: PyObjectRef, + offset: usize, + idx: usize, + length: usize, + data: Vec, + ) -> Self { + PyCData { + buffer: PyRwLock::new(Cow::Owned(data)), // Has its own buffer copy + base: PyRwLock::new(Some(base_obj)), // But still tracks base + base_offset: AtomicCell::new(offset), // And offset for writes + index: AtomicCell::new(idx), + objects: PyRwLock::new(None), + length: AtomicCell::new(length), } } -} -#[pyclass(flags(BASETYPE))] -impl PyCData { - #[pygetset] - fn _objects(&self) -> Option { - self.cdata.read().objects.clone() + /// Create from base object's buffer + /// + /// This creates a borrowed view into the base's buffer at the given address. + /// The base object is stored in b_base to keep the memory alive. + /// + /// # Safety + /// ptr must point into base_obj's buffer and remain valid as long as base_obj is alive. 
+ pub unsafe fn from_base_obj( + ptr: *mut u8, + size: usize, + base_obj: PyObjectRef, + idx: usize, + ) -> Self { + // = PyCData_FromBaseObj + // SAFETY: ptr points into base_obj's buffer, kept alive via base reference + let slice: &'static [u8] = unsafe { std::slice::from_raw_parts(ptr, size) }; + PyCData { + buffer: PyRwLock::new(Cow::Borrowed(slice)), + base: PyRwLock::new(Some(base_obj)), + base_offset: AtomicCell::new(0), + index: AtomicCell::new(idx), + objects: PyRwLock::new(None), + length: AtomicCell::new(0), + } } -} -#[pyclass(module = "_ctypes", name = "PyCSimpleType", base = PyType)] -#[derive(Debug)] -#[repr(transparent)] -pub struct PyCSimpleType(PyType); - -#[pyclass(flags(BASETYPE), with(AsNumber))] -impl PyCSimpleType { - /// Get stg_info for a simple type by reading _type_ attribute - pub fn get_stg_info(cls: &PyTypeRef, vm: &VirtualMachine) -> StgInfo { - if let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) - && let Ok(type_str) = type_attr.str(vm) - { - let tp_str = type_str.to_string(); - if tp_str.len() == 1 { - let size = super::_ctypes::get_size(&tp_str); - let align = super::_ctypes::get_align(&tp_str); - return StgInfo::new(size, align); - } + /// Create from buffer protocol object (for from_buffer method) + /// + /// Unlike from_bytes, this shares memory with the source buffer. + /// The source object is stored in objects dict to keep the buffer alive. + /// Python stores with key -1 via KeepRef(result, -1, mv). + /// + /// # Safety + /// ptr must point to valid memory that remains valid as long as source is alive. 
+ pub unsafe fn from_buffer_shared( + ptr: *const u8, + size: usize, + length: usize, + source: PyObjectRef, + vm: &VirtualMachine, + ) -> Self { + // SAFETY: Caller must ensure ptr is valid for the lifetime of source + let slice: &'static [u8] = unsafe { std::slice::from_raw_parts(ptr, size) }; + + // Python stores the reference in a dict with key "-1" (unique_key pattern) + let objects_dict = vm.ctx.new_dict(); + objects_dict + .set_item("-1", source, vm) + .expect("Failed to store buffer reference"); + + PyCData { + buffer: PyRwLock::new(Cow::Borrowed(slice)), + base: PyRwLock::new(None), + base_offset: AtomicCell::new(0), + index: AtomicCell::new(0), + objects: PyRwLock::new(Some(objects_dict.into())), + length: AtomicCell::new(length), } - StgInfo::default() - } - #[allow(clippy::new_ret_no_self)] - #[pymethod] - fn new(cls: PyTypeRef, _: OptionalArg, vm: &VirtualMachine) -> PyResult { - Ok(PyObjectRef::from( - new_simple_type(Either::B(&cls), vm)? - .into_ref_with_type(vm, cls)? - .clone(), - )) } - #[pyclassmethod] - fn from_param(cls: PyTypeRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { - // 1. If the value is already an instance of the requested type, return it - if value.fast_isinstance(&cls) { - return Ok(value); + /// Common implementation for from_buffer class method. + /// Validates buffer, creates memoryview, and returns PyCData sharing memory with source. + /// + /// CDataType_from_buffer_impl + pub fn from_buffer_impl( + cls: &Py, + source: PyObjectRef, + offset: isize, + vm: &VirtualMachine, + ) -> PyResult { + let (size, length) = { + let stg_info = cls + .stg_info_opt() + .ok_or_else(|| vm.new_type_error("not a ctypes type"))?; + (stg_info.size, stg_info.length) + }; + + if offset < 0 { + return Err(vm.new_value_error("offset cannot be negative")); } + let offset = offset as usize; - // 2. 
Get the type code to determine conversion rules - let type_code = get_type_code(&cls, vm); + // Get buffer from source (this exports the buffer) + let buffer = PyBuffer::try_from_object(vm, source)?; - // 3. Handle None for pointer types (c_char_p, c_wchar_p, c_void_p) - if vm.is_none(&value) && matches!(type_code.as_deref(), Some("z") | Some("Z") | Some("P")) { - return Ok(value); + // Check if buffer is writable + if buffer.desc.readonly { + return Err(vm.new_type_error("underlying buffer is not writable")); } - // 4. Try to convert value based on type code - match type_code.as_deref() { - // Integer types: accept integers - Some("b" | "B" | "h" | "H" | "i" | "I" | "l" | "L" | "q" | "Q") => { - if value.try_int(vm).is_ok() { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - } - // Float types: accept numbers - Some("f" | "d" | "g") => { - if value.try_float(vm).is_ok() || value.try_int(vm).is_ok() { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - } - // c_char: 1 byte character - Some("c") => { - if let Some(bytes) = value.downcast_ref::() - && bytes.len() == 1 - { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - if let Ok(int_val) = value.try_int(vm) - && int_val.as_bigint().to_u8().is_some() - { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - return Err(vm.new_type_error( - "one character bytes, bytearray or integer expected".to_string(), - )); - } - // c_wchar: 1 unicode character - Some("u") => { - if let Some(s) = value.downcast_ref::() - && s.as_str().chars().count() == 1 - { - let simple = 
new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - return Err(vm.new_type_error("one character unicode string expected".to_string())); - } - // c_char_p: bytes pointer - Some("z") => { - if value.downcast_ref::().is_some() { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - } - // c_wchar_p: unicode pointer - Some("Z") => { - if value.downcast_ref::().is_some() { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - } - // c_void_p: most flexible - accepts int, bytes, str - Some("P") => { - if value.try_int(vm).is_ok() - || value.downcast_ref::().is_some() - || value.downcast_ref::().is_some() - { - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(value.clone()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - } - // c_bool - Some("?") => { - let bool_val = value.is_true(vm)?; - let simple = new_simple_type(Either::B(&cls), vm)?; - simple.value.store(vm.ctx.new_bool(bool_val).into()); - return simple.into_ref_with_type(vm, cls.clone()).map(Into::into); - } - _ => {} + // Check if buffer is C contiguous + if !buffer.desc.is_contiguous() { + return Err(vm.new_type_error("underlying buffer is not C contiguous")); } - // 5. Check for _as_parameter_ attribute - if let Ok(as_parameter) = value.get_attr("_as_parameter_", vm) { - return PyCSimpleType::from_param(cls, as_parameter, vm); + // Check if buffer is large enough + let buffer_len = buffer.desc.len; + if offset + size > buffer_len { + return Err(vm.new_value_error(format!( + "Buffer size too small ({} instead of at least {} bytes)", + buffer_len, + offset + size + ))); } - // 6. 
Type-specific error messages - match type_code.as_deref() { - Some("z") => Err(vm.new_type_error(format!( - "'{}' object cannot be interpreted as ctypes.c_char_p", - value.class().name() - ))), - Some("Z") => Err(vm.new_type_error(format!( - "'{}' object cannot be interpreted as ctypes.c_wchar_p", - value.class().name() - ))), - _ => Err(vm.new_type_error("wrong type".to_string())), - } - } + // Get buffer pointer - the memory is owned by source + let ptr = { + let bytes = buffer.obj_bytes(); + bytes.as_ptr().wrapping_add(offset) + }; - #[pymethod] - fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { - PyCSimple::repeat(cls, n, vm) + // Create memoryview to keep buffer exported (prevents source from being modified) + // mv = PyMemoryView_FromObject(obj); KeepRef(result, -1, mv); + let memoryview = PyMemoryView::from_buffer(buffer, vm)?; + let mv_obj = memoryview.into_pyobject(vm); + + // Create CData that shares memory with the buffer + Ok(unsafe { Self::from_buffer_shared(ptr, size, length, mv_obj, vm) }) } -} -impl AsNumber for PyCSimpleType { - fn as_number() -> &'static PyNumberMethods { - static AS_NUMBER: PyNumberMethods = PyNumberMethods { - multiply: Some(|a, b, vm| { - // a is a PyCSimpleType instance (type object like c_char) - // b is int (array size) - let cls = a - .downcast_ref::() - .ok_or_else(|| vm.new_type_error("expected type".to_owned()))?; - let n = b - .try_index(vm)? - .as_bigint() - .to_isize() - .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; - PyCSimple::repeat(cls.to_owned(), n, vm) - }), - ..PyNumberMethods::NOT_IMPLEMENTED + /// Common implementation for from_buffer_copy class method. + /// Copies data from buffer and creates new independent instance. 
+ /// + /// CDataType_from_buffer_copy_impl + pub fn from_buffer_copy_impl( + cls: &Py, + source: &[u8], + offset: isize, + vm: &VirtualMachine, + ) -> PyResult { + let (size, length) = { + let stg_info = cls + .stg_info_opt() + .ok_or_else(|| vm.new_type_error("not a ctypes type"))?; + (stg_info.size, stg_info.length) }; - &AS_NUMBER - } -} -#[pyclass( - module = "_ctypes", - name = "_SimpleCData", - base = PyCData, - metaclass = "PyCSimpleType" -)] -pub struct PyCSimple { - pub _base: PyCData, - pub _type_: String, - pub value: AtomicCell, - pub cdata: PyRwLock, -} + if offset < 0 { + return Err(vm.new_value_error("offset cannot be negative")); + } + let offset = offset as usize; -impl Debug for PyCSimple { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PyCSimple") - .field("_type_", &self._type_) - .finish() + // Check if buffer is large enough + if offset + size > source.len() { + return Err(vm.new_value_error(format!( + "Buffer size too small ({} instead of at least {} bytes)", + source.len(), + offset + size + ))); + } + + // Copy bytes from buffer at offset + let data = source[offset..offset + size].to_vec(); + + Ok(Self::from_bytes_with_length(data, None, length)) } -} -fn value_to_bytes_endian( - _type_: &str, - value: &PyObject, - swapped: bool, - vm: &VirtualMachine, -) -> Vec { - // Helper macro for endian conversion - macro_rules! 
to_bytes { - ($val:expr) => { - if swapped { - // Use opposite endianness - #[cfg(target_endian = "little")] - { - $val.to_be_bytes().to_vec() - } - #[cfg(target_endian = "big")] - { - $val.to_le_bytes().to_vec() - } - } else { - $val.to_ne_bytes().to_vec() - } - }; + #[inline] + pub fn size(&self) -> usize { + self.buffer.read().len() } - match _type_ { - "c" => { - // c_char - single byte - if let Some(bytes) = value.downcast_ref::() - && !bytes.is_empty() - { - return vec![bytes.as_bytes()[0]]; - } - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_u8() - { - return vec![v]; - } - vec![0] - } - "u" => { - // c_wchar - 4 bytes (wchar_t on most platforms) - if let Ok(s) = value.str(vm) - && let Some(c) = s.as_str().chars().next() - { - return to_bytes!(c as u32); - } - vec![0; 4] + /// Check if this buffer is borrowed (external memory reference) + #[inline] + pub fn is_borrowed(&self) -> bool { + matches!(&*self.buffer.read(), Cow::Borrowed(_)) + } + + /// Write bytes at offset - handles both borrowed and owned buffers + /// + /// For borrowed buffers (from from_address), writes directly to external memory. + /// For owned buffers, writes through to_mut() as normal. + /// + /// # Safety + /// For borrowed buffers, caller must ensure the memory is writable. 
+ pub fn write_bytes_at_offset(&self, offset: usize, bytes: &[u8]) { + let buffer = self.buffer.read(); + if offset + bytes.len() > buffer.len() { + return; // Out of bounds } - "b" => { - // c_byte - signed char (1 byte) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_i8() - { - return vec![v as u8]; + + match &*buffer { + Cow::Borrowed(slice) => { + // For borrowed memory, write directly + // SAFETY: We assume the caller knows this memory is writable + // (e.g., from from_address pointing to a ctypes buffer) + unsafe { + let ptr = slice.as_ptr() as *mut u8; + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr.add(offset), bytes.len()); + } } - vec![0] - } - "B" => { - // c_ubyte - unsigned char (1 byte) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_u8() - { - return vec![v]; + Cow::Owned(_) => { + // For owned memory, use to_mut() through write lock + drop(buffer); + let mut buffer = self.buffer.write(); + buffer.to_mut()[offset..offset + bytes.len()].copy_from_slice(bytes); } - vec![0] } - "h" => { - // c_short (2 bytes) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_i16() - { - return to_bytes!(v); - } - vec![0; 2] + } + + /// Generate unique key for nested references (unique_key) + /// Creates a hierarchical key by walking up the b_base chain. + /// Format: "index:parent_index:grandparent_index:..." 
+ pub fn unique_key(&self, index: usize) -> String { + let mut key = format!("{index:x}"); + // Walk up the base chain to build hierarchical key + if self.base.read().is_some() { + let parent_index = self.index.load(); + key.push_str(&format!(":{parent_index:x}")); } - "H" => { - // c_ushort (2 bytes) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_u16() - { - return to_bytes!(v); - } - vec![0; 2] + key + } + + /// Keep a reference in the objects dictionary (KeepRef) + /// + /// Stores 'keep' in this object's b_objects dict at key 'index'. + /// If keep is None, does nothing (optimization). + /// This function stores the value directly - caller should use get_kept_objects() + /// first if they want to store the _objects of a CData instead of the object itself. + /// + /// If this object has a base (is embedded in another structure/union/array), + /// the reference is stored in the root object's b_objects with a hierarchical key. + pub fn keep_ref(&self, index: usize, keep: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Optimization: no need to store None + if vm.is_none(&keep) { + return Ok(()); } - "i" => { - // c_int (4 bytes) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_i32() - { - return to_bytes!(v); - } - vec![0; 4] + + // Build hierarchical key + let key = self.unique_key(index); + + // If we have a base object, find root and store there + if let Some(base_obj) = self.base.read().clone() { + // Find root by walking up the base chain + let root_obj = Self::find_root_object(&base_obj); + Self::store_in_object(&root_obj, &key, keep, vm)?; + return Ok(()); } - "I" => { - // c_uint (4 bytes) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_u32() - { - return to_bytes!(v); + + // No base - store in own objects dict + let mut objects = self.objects.write(); + + // Initialize b_objects if needed + if objects.is_none() { + if self.length.load() > 0 { + 
// Need to store multiple references - create a dict + *objects = Some(vm.ctx.new_dict().into()); + } else { + // Only one reference needed - store directly + *objects = Some(keep); + return Ok(()); } - vec![0; 4] } - "l" => { - // c_long (platform dependent) - if let Ok(int_val) = value.try_to_value::(vm) { - return to_bytes!(int_val); - } - const SIZE: usize = std::mem::size_of::(); - vec![0; SIZE] + + // If b_objects is not a dict, convert it to a dict first + // This preserves the existing reference (e.g., from cast) when adding new references + if let Some(obj) = objects.as_ref() + && obj.downcast_ref::().is_none() + { + // Convert existing single reference to a dict + let dict = vm.ctx.new_dict(); + // Store the original object with a special key (id-based) + let id_key: PyObjectRef = vm.ctx.new_int(obj.get_id() as i64).into(); + dict.set_item(&*id_key, obj.clone(), vm)?; + *objects = Some(dict.into()); } - "L" => { - // c_ulong (platform dependent) - if let Ok(int_val) = value.try_to_value::(vm) { - return to_bytes!(int_val); - } - const SIZE: usize = std::mem::size_of::(); - vec![0; SIZE] + + // Store in dict with unique key + if let Some(dict_obj) = objects.as_ref() + && let Some(dict) = dict_obj.downcast_ref::() + { + let key_obj: PyObjectRef = vm.ctx.new_str(key).into(); + dict.set_item(&*key_obj, keep, vm)?; } - "q" => { - // c_longlong (8 bytes) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_i64() - { - return to_bytes!(v); - } - vec![0; 8] + + Ok(()) + } + + /// Find the root object (one without a base) by walking up the base chain + fn find_root_object(obj: &PyObject) -> PyObjectRef { + // Try to get base from different ctypes types + let base = if let Some(cdata) = obj.downcast_ref::() { + cdata.base.read().clone() + } else { + None + }; + + // Recurse if there's a base, otherwise this is the root + if let Some(base_obj) = base { + Self::find_root_object(&base_obj) + } else { + obj.to_owned() } - "Q" => { - // 
c_ulonglong (8 bytes) - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_u64() - { - return to_bytes!(v); - } - vec![0; 8] - } - "f" => { - // c_float (4 bytes) - intë„ í—ˆìš© - if let Ok(float_val) = value.try_float(vm) { - return to_bytes!(float_val.to_f64() as f32); - } - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_f64() - { - return to_bytes!(v as f32); - } - vec![0; 4] - } - "d" | "g" => { - // c_double (8 bytes) - intë„ í—ˆìš© - if let Ok(float_val) = value.try_float(vm) { - return to_bytes!(float_val.to_f64()); - } - if let Ok(int_val) = value.try_int(vm) - && let Some(v) = int_val.as_bigint().to_f64() - { - return to_bytes!(v); - } - vec![0; 8] - } - "?" => { - // c_bool (1 byte) - if let Ok(b) = value.to_owned().try_to_bool(vm) { - return vec![if b { 1 } else { 0 }]; - } - vec![0] - } - "P" | "z" | "Z" => { - // Pointer types (platform pointer size) - vec![0; std::mem::size_of::()] - } - _ => vec![0], } -} -impl Constructor for PyCSimple { - type Args = (OptionalArg,); - - fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - let args: Self::Args = args.bind(vm)?; - let attributes = cls.get_attributes(); - let _type_ = attributes - .iter() - .find(|(k, _)| { - k.to_object() - .str(vm) - .map(|s| s.to_string() == "_type_") - .unwrap_or(false) - }) - .ok_or_else(|| { - vm.new_type_error(format!( - "cannot create '{}' instances: no _type_ attribute", - cls.name() - )) - })? - .1 - .str(vm)? - .to_string(); - let value = if let Some(ref v) = args.0.into_option() { - set_primitive(_type_.as_str(), v, vm)? 
+ /// Store a value in an object's _objects dict with the given key + fn store_in_object( + obj: &PyObject, + key: &str, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Get the objects dict from the object + let objects_lock = if let Some(cdata) = obj.downcast_ref::() { + &cdata.objects } else { - match _type_.as_str() { - "c" | "u" => PyObjectRef::from(vm.ctx.new_bytes(vec![0])), - "b" | "B" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { - PyObjectRef::from(vm.ctx.new_int(0)) - } - "f" | "d" | "g" => PyObjectRef::from(vm.ctx.new_float(0.0)), - "?" => PyObjectRef::from(vm.ctx.new_bool(false)), - _ => vm.ctx.none(), // "z" | "Z" | "P" - } + return Ok(()); // Unknown type, skip }; - // Check if this is a swapped endian type - let swapped = cls - .as_object() - .get_attr("_swappedbytes_", vm) - .map(|v| v.is_true(vm).unwrap_or(false)) - .unwrap_or(false); + let mut objects = objects_lock.write(); - let buffer = value_to_bytes_endian(&_type_, &value, swapped, vm); - let cdata = CDataObject::from_bytes(buffer, None); - PyCSimple { - _base: PyCData::new(cdata.clone()), - _type_, - value: AtomicCell::new(value), - cdata: PyRwLock::new(cdata), + // Initialize if needed + if objects.is_none() { + *objects = Some(vm.ctx.new_dict().into()); } - .into_ref_with_type(vm, cls) - .map(Into::into) - } - fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { - unimplemented!("use slot_new") - } -} + // If not a dict, convert to dict + if let Some(obj) = objects.as_ref() + && obj.downcast_ref::().is_none() + { + let dict = vm.ctx.new_dict(); + let id_key: PyObjectRef = vm.ctx.new_int(obj.get_id() as i64).into(); + dict.set_item(&*id_key, obj.clone(), vm)?; + *objects = Some(dict.into()); + } -#[pyclass(flags(BASETYPE), with(Constructor, AsBuffer))] -impl PyCSimple { - #[pygetset] - fn _objects(&self) -> Option { - self.cdata.read().objects.clone() - } - - #[pygetset(name = "value")] - pub fn value(instance: PyObjectRef, vm: 
&VirtualMachine) -> PyResult { - let zelf: &Py = instance - .downcast_ref() - .ok_or_else(|| vm.new_type_error("cannot get value of instance"))?; - let raw_value = unsafe { (*zelf.value.as_ptr()).clone() }; - - // Convert to unsigned if needed for unsigned types - match zelf._type_.as_str() { - "B" | "H" | "I" | "L" | "Q" => { - if let Ok(int_val) = raw_value.try_int(vm) { - let n = int_val.as_bigint(); - // Use platform-specific C types for correct unsigned conversion - match zelf._type_.as_str() { - "B" => { - if let Some(v) = n.to_i64() { - return Ok(vm.ctx.new_int((v as u8) as u64).into()); - } - } - "H" => { - if let Some(v) = n.to_i64() { - return Ok(vm.ctx.new_int((v as c_ushort) as u64).into()); - } - } - "I" => { - if let Some(v) = n.to_i64() { - return Ok(vm.ctx.new_int((v as c_uint) as u64).into()); - } - } - "L" => { - if let Some(v) = n.to_i128() { - return Ok(vm.ctx.new_int(v as c_ulong).into()); - } - } - "Q" => { - if let Some(v) = n.to_i128() { - return Ok(vm.ctx.new_int(v as c_ulonglong).into()); - } - } - _ => {} - }; - } - Ok(raw_value) - } - _ => Ok(raw_value), + // Store in dict + if let Some(dict_obj) = objects.as_ref() + && let Some(dict) = dict_obj.downcast_ref::() + { + let key_obj: PyObjectRef = vm.ctx.new_str(key).into(); + dict.set_item(&*key_obj, value, vm)?; } + + Ok(()) } - #[pygetset(name = "value", setter)] - fn set_value(instance: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let zelf: PyRef = instance - .clone() - .downcast() - .map_err(|_| vm.new_type_error("cannot set value of instance"))?; - let content = set_primitive(zelf._type_.as_str(), &value, vm)?; + /// Get kept objects from a CData instance + /// Returns the _objects of the CData, or an empty dict if None. 
+ pub fn get_kept_objects(value: &PyObject, vm: &VirtualMachine) -> PyObjectRef { + value + .downcast_ref::() + .and_then(|cdata| cdata.objects.read().clone()) + .unwrap_or_else(|| vm.ctx.new_dict().into()) + } - // Check if this is a swapped endian type - let swapped = instance - .class() - .as_object() - .get_attr("_swappedbytes_", vm) - .map(|v| v.is_true(vm).unwrap_or(false)) - .unwrap_or(false); - - // Update buffer when value changes - let buffer_bytes = value_to_bytes_endian(&zelf._type_, &content, swapped, vm); - zelf.cdata.write().buffer = buffer_bytes; - zelf.value.store(content); - Ok(()) + /// Check if a value should be stored in _objects + /// Returns true for ctypes objects and bytes (for c_char_p) + pub fn should_keep_ref(value: &PyObject) -> bool { + value.downcast_ref::().is_some() || value.downcast_ref::().is_some() } - #[pyclassmethod] - fn repeat(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { - use super::_ctypes::get_size; - use super::array::create_array_type_with_stg_info; - - if n < 0 { - return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); - } - // Get element size from cls - let element_size = if let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) { - if let Ok(s) = type_attr.str(vm) { - let s = s.to_string(); - if s.len() == 1 { - get_size(&s) - } else { - std::mem::size_of::() - } + /// PyCData_set + /// Sets a field value at the given offset, handling type conversion and KeepRef + #[allow(clippy::too_many_arguments)] + pub fn set_field( + &self, + proto: &PyObject, + value: PyObjectRef, + index: usize, + size: usize, + offset: usize, + needs_swap: bool, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Check if this is a c_char or c_wchar array field + let is_char_array = PyCField::is_char_array(proto, vm); + let is_wchar_array = PyCField::is_wchar_array(proto, vm); + + // For c_char arrays with bytes input, copy only up to first null + if is_char_array { + if let Some(bytes_val) = 
value.downcast_ref::() { + let src = bytes_val.as_bytes(); + let to_copy = PyCField::bytes_for_char_array(src); + let copy_len = std::cmp::min(to_copy.len(), size); + self.write_bytes_at_offset(offset, &to_copy[..copy_len]); + self.keep_ref(index, value, vm)?; + return Ok(()); } else { - std::mem::size_of::() + return Err(vm.new_type_error("bytes expected instead of str instance")); } - } else { - std::mem::size_of::() - }; - let total_size = element_size * (n as usize); - let stg_info = super::util::StgInfo::new_array( - total_size, - element_size, - n as usize, - cls.clone().into(), - element_size, - ); - create_array_type_with_stg_info(stg_info, vm) - } + } - #[pyclassmethod] - fn from_address(cls: PyTypeRef, address: isize, vm: &VirtualMachine) -> PyResult { - use super::_ctypes::get_size; - // Get _type_ attribute directly - let type_attr = cls - .as_object() - .get_attr("_type_", vm) - .map_err(|_| vm.new_type_error(format!("'{}' has no _type_ attribute", cls.name())))?; - let type_str = type_attr.str(vm)?.to_string(); - let size = get_size(&type_str); - - // Create instance with value read from address - let value = if address != 0 && size > 0 { - // Safety: This is inherently unsafe - reading from arbitrary memory address - unsafe { - let ptr = address as *const u8; - let bytes = std::slice::from_raw_parts(ptr, size); - // Convert bytes to appropriate Python value based on type - bytes_to_pyobject(&cls, bytes, vm)? 
+ // For c_wchar arrays with str input, convert to wchar_t + if is_wchar_array { + if let Some(str_val) = value.downcast_ref::() { + // Convert str to wchar_t bytes (platform-dependent size) + let mut wchar_bytes = Vec::with_capacity(size); + for ch in str_val.as_str().chars().take(size / WCHAR_SIZE) { + let mut bytes = [0u8; 4]; + wchar_to_bytes(ch as u32, &mut bytes); + wchar_bytes.extend_from_slice(&bytes[..WCHAR_SIZE]); + } + // Pad with nulls to fill the array + while wchar_bytes.len() < size { + wchar_bytes.push(0); + } + self.write_bytes_at_offset(offset, &wchar_bytes); + self.keep_ref(index, value, vm)?; + return Ok(()); + } else if value.downcast_ref::().is_some() { + return Err(vm.new_type_error("str expected instead of bytes instance")); } + } + + // Special handling for Pointer fields with Array values + if let Some(proto_type) = proto.downcast_ref::() + && proto_type + .class() + .fast_issubclass(super::pointer::PyCPointerType::static_type()) + && let Some(array) = value.downcast_ref::() + { + let buffer_addr = { + let array_buffer = array.0.buffer.read(); + array_buffer.as_ptr() as usize + }; + let addr_bytes = buffer_addr.to_ne_bytes(); + let len = std::cmp::min(addr_bytes.len(), size); + self.write_bytes_at_offset(offset, &addr_bytes[..len]); + self.keep_ref(index, value, vm)?; + return Ok(()); + } + + // Get field type code for special handling + let field_type_code = proto + .get_attr("_type_", vm) + .ok() + .and_then(|attr| attr.downcast_ref::().map(|s| s.to_string())); + + let (mut bytes, converted_value) = if let Some(type_code) = &field_type_code { + PyCField::value_to_bytes_for_type(type_code, &value, size, vm)? 
} else { - vm.ctx.none() + (PyCField::value_to_bytes(&value, size, vm)?, None) }; - // Create instance using the type's constructor - let args = FuncArgs::new(vec![value], KwArgs::default()); - PyCSimple::slot_new(cls.clone(), args, vm) + // Swap bytes for opposite endianness + if needs_swap { + bytes.reverse(); + } + + self.write_bytes_at_offset(offset, &bytes); + + // KeepRef: for z/Z types use converted value, otherwise use original + if let Some(converted) = converted_value { + self.keep_ref(index, converted, vm)?; + } else if Self::should_keep_ref(&value) { + let to_keep = Self::get_kept_objects(&value, vm); + self.keep_ref(index, to_keep, vm)?; + } + + Ok(()) } - #[pyclassmethod] - fn from_buffer( - cls: PyTypeRef, - source: PyObjectRef, - offset: OptionalArg, + /// PyCData_get + /// Gets a field value at the given offset + pub fn get_field( + &self, + proto: &PyObject, + index: usize, + size: usize, + offset: usize, + base_obj: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - use super::_ctypes::get_size; - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); + // Get buffer data at offset + let buffer = self.buffer.read(); + if offset + size > buffer.len() { + return Ok(vm.ctx.new_int(0).into()); } - let offset = offset as usize; - // Get buffer from source - let buffer = PyBuffer::try_from_object(vm, source.clone())?; + // Check if field type is an array type + if let Some(type_ref) = proto.downcast_ref::() + && let Some(stg) = type_ref.stg_info_opt() + && stg.element_type.is_some() + { + // c_char array → return bytes + if PyCField::is_char_array(proto, vm) { + let data = &buffer[offset..offset + size]; + // Find first null terminator (or use full length) + let end = data.iter().position(|&b| b == 0).unwrap_or(data.len()); + return Ok(vm.ctx.new_bytes(data[..end].to_vec()).into()); + } - // Check if buffer is writable - if buffer.desc.readonly { - return 
Err(vm.new_type_error("underlying buffer is not writable".to_owned())); + // c_wchar array → return str + if PyCField::is_wchar_array(proto, vm) { + let data = &buffer[offset..offset + size]; + // wchar_t → char conversion, skip null + let chars: String = data + .chunks(WCHAR_SIZE) + .filter_map(|chunk| { + wchar_from_bytes(chunk) + .filter(|&wchar| wchar != 0) + .and_then(char::from_u32) + }) + .collect(); + return Ok(vm.ctx.new_str(chars).into()); + } + + // Other array types - create array with a copy of data from the base's buffer + // The array also keeps a reference to the base for keeping it alive and for writes + let array_data = buffer[offset..offset + size].to_vec(); + drop(buffer); + + let cdata_obj = + Self::from_base_with_data(base_obj, offset, index, stg.length, array_data); + let array_type: PyTypeRef = proto + .to_owned() + .downcast() + .map_err(|_| vm.new_type_error("expected array type"))?; + + return super::array::PyCArray(cdata_obj) + .into_ref_with_type(vm, array_type) + .map(Into::into); } - // Get _type_ attribute directly - let type_attr = cls - .as_object() - .get_attr("_type_", vm) - .map_err(|_| vm.new_type_error(format!("'{}' has no _type_ attribute", cls.name())))?; - let type_str = type_attr.str(vm)?.to_string(); - let size = get_size(&type_str); + let buffer_data = buffer[offset..offset + size].to_vec(); + drop(buffer); - // Check if buffer is large enough - let buffer_len = buffer.desc.len; - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); + // Get proto as type + let proto_type: PyTypeRef = proto + .to_owned() + .downcast() + .map_err(|_| vm.new_type_error("field proto is not a type"))?; + + let proto_metaclass = proto_type.class(); + + // Simple types: return primitive value + if proto_metaclass.fast_issubclass(super::simple::PyCSimpleType::static_type()) { + // Check for byte swapping + let needs_swap = 
base_obj + .class() + .as_object() + .get_attr("_swappedbytes_", vm) + .is_ok() + || proto_type + .as_object() + .get_attr("_swappedbytes_", vm) + .is_ok(); + + let data = if needs_swap && size > 1 { + let mut swapped = buffer_data.clone(); + swapped.reverse(); + swapped + } else { + buffer_data + }; + + return bytes_to_pyobject(&proto_type, &data, vm); } - // Read bytes from buffer at offset - let bytes = buffer.obj_bytes(); - let data = &bytes[offset..offset + size]; - let value = bytes_to_pyobject(&cls, data, vm)?; + // Complex types: create ctypes instance via PyCData_FromBaseObj + let ptr = self.buffer.read().as_ptr().wrapping_add(offset) as *mut u8; + let cdata_obj = unsafe { Self::from_base_obj(ptr, size, base_obj.clone(), index) }; - // Create instance - let args = FuncArgs::new(vec![value], KwArgs::default()); - let instance = PyCSimple::slot_new(cls.clone(), args, vm)?; + if proto_metaclass.fast_issubclass(super::structure::PyCStructType::static_type()) + || proto_metaclass.fast_issubclass(super::union::PyCUnionType::static_type()) + || proto_metaclass.fast_issubclass(super::pointer::PyCPointerType::static_type()) + { + cdata_obj.into_ref_with_type(vm, proto_type).map(Into::into) + } else { + // Fallback + Ok(vm.ctx.new_int(0).into()) + } + } +} - // TODO: Store reference to source in _objects to keep buffer alive - Ok(instance) +#[pyclass(flags(BASETYPE))] +impl PyCData { + #[pygetset] + fn _objects(&self) -> Option { + self.objects.read().clone() + } + + #[pygetset] + fn _b_base_(&self) -> Option { + self.base.read().clone() + } + + #[pygetset] + fn _b_needsfree_(&self) -> i32 { + // Borrowed (from_address) or has base object → 0 (don't free) + // Owned and no base → 1 (need to free) + if self.is_borrowed() || self.base.read().is_some() { + 0 + } else { + 1 + } + } + + // CDataType_methods - shared across all ctypes types + + #[pyclassmethod] + fn from_buffer( + cls: PyTypeRef, + source: PyObjectRef, + offset: OptionalArg, + vm: &VirtualMachine, + ) -> 
PyResult { + let cdata = Self::from_buffer_impl(&cls, source, offset.unwrap_or(0), vm)?; + cdata.into_ref_with_type(vm, cls).map(Into::into) } #[pyclassmethod] @@ -875,191 +1117,1237 @@ impl PyCSimple { offset: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use super::_ctypes::get_size; - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; - - // Get _type_ attribute directly for simple types - let type_attr = cls - .as_object() - .get_attr("_type_", vm) - .map_err(|_| vm.new_type_error(format!("'{}' has no _type_ attribute", cls.name())))?; - let type_str = type_attr.str(vm)?.to_string(); - let size = get_size(&type_str); + let cdata = + Self::from_buffer_copy_impl(&cls, &source.borrow_buf(), offset.unwrap_or(0), vm)?; + cdata.into_ref_with_type(vm, cls).map(Into::into) + } - // Borrow bytes from source - let source_bytes = source.borrow_buf(); - let buffer_len = source_bytes.len(); + #[pyclassmethod] + fn from_address(cls: PyTypeRef, address: isize, vm: &VirtualMachine) -> PyResult { + let size = { + let stg_info = cls.stg_info(vm)?; + stg_info.size + }; - // Check if buffer is large enough - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); + if size == 0 { + return Err(vm.new_type_error("abstract class")); } - // Copy bytes from buffer at offset - let data = &source_bytes[offset..offset + size]; - let value = bytes_to_pyobject(&cls, data, vm)?; - - // Create instance (independent copy, no reference tracking) - let args = FuncArgs::new(vec![value], KwArgs::default()); - PyCSimple::slot_new(cls.clone(), args, vm) + // PyCData_AtAddress + let cdata = unsafe { Self::at_address(address as *const u8, size) }; + cdata.into_ref_with_type(vm, cls).map(Into::into) } #[pyclassmethod] - fn in_dll(cls: PyTypeRef, dll: PyObjectRef, name: 
PyStrRef, vm: &VirtualMachine) -> PyResult { - use super::_ctypes::get_size; - use libloading::Symbol; + fn in_dll( + cls: PyTypeRef, + dll: PyObjectRef, + name: crate::builtins::PyStrRef, + vm: &VirtualMachine, + ) -> PyResult { + let size = { + let stg_info = cls.stg_info(vm)?; + stg_info.size + }; + + if size == 0 { + return Err(vm.new_type_error("abstract class")); + } // Get the library handle from dll object let handle = if let Ok(int_handle) = dll.try_int(vm) { - // dll is an integer handle int_handle .as_bigint() .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? + .ok_or_else(|| vm.new_value_error("Invalid library handle"))? } else { - // dll is a CDLL/PyDLL/WinDLL object with _handle attribute dll.get_attr("_handle", vm)? .try_int(vm)? .as_bigint() .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? + .ok_or_else(|| vm.new_value_error("Invalid library handle"))? }; - // Get the library from cache - let library_cache = crate::stdlib::ctypes::library::libcache().read(); - let library = library_cache - .get_lib(handle) - .ok_or_else(|| vm.new_attribute_error("Library not found".to_owned()))?; - - // Get symbol address from library - let symbol_name = format!("{}\0", name.as_str()); - let inner_lib = library.lib.lock(); - - let symbol_address = if let Some(lib) = &*inner_lib { - unsafe { - // Try to get the symbol from the library - let symbol: Symbol<'_, *mut u8> = lib.get(symbol_name.as_bytes()).map_err(|e| { - vm.new_attribute_error(format!("{}: symbol '{}' not found", e, name.as_str())) - })?; - *symbol as usize + // Get symbol address using platform-specific API + let symbol_name = std::ffi::CString::new(name.as_str()) + .map_err(|_| vm.new_value_error("Invalid symbol name"))?; + + #[cfg(windows)] + let ptr: *const u8 = unsafe { + match windows_sys::Win32::System::LibraryLoader::GetProcAddress( + handle as windows_sys::Win32::Foundation::HMODULE, + symbol_name.as_ptr() as *const u8, + 
) { + Some(p) => p as *const u8, + None => std::ptr::null(), } + }; + + #[cfg(not(windows))] + let ptr: *const u8 = + unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) as *const u8 }; + + if ptr.is_null() { + return Err( + vm.new_value_error(format!("symbol '{}' not found in library", name.as_str())) + ); + } + + // PyCData_AtAddress + let cdata = unsafe { Self::at_address(ptr, size) }; + cdata.into_ref_with_type(vm, cls).map(Into::into) + } +} + +// PyCField - Field descriptor for Structure/Union types + +/// CField descriptor for Structure/Union field access +#[pyclass(name = "CField", module = "_ctypes")] +#[derive(Debug, PyPayload)] +pub struct PyCField { + /// Byte offset of the field within the structure/union + pub(crate) offset: isize, + /// Encoded size: for bitfields (bit_size << 16) | bit_offset, otherwise byte size + pub(crate) size: isize, + /// Index into PyCData's object array + pub(crate) index: usize, + /// The ctypes type for this field + pub(crate) proto: PyTypeRef, + /// Flag indicating if the field is anonymous (MakeAnonFields sets this) + pub(crate) anonymous: bool, +} + +#[inline(always)] +const fn num_bits(size: isize) -> isize { + size >> 16 +} + +#[inline(always)] +const fn field_size(size: isize) -> isize { + size & 0xFFFF +} + +#[inline(always)] +const fn is_bitfield(size: isize) -> bool { + (size >> 16) != 0 +} + +impl PyCField { + /// Create a new CField descriptor (non-bitfield) + pub fn new(proto: PyTypeRef, offset: isize, size: isize, index: usize) -> Self { + Self { + offset, + size, + index, + proto, + anonymous: false, + } + } + + /// Create a new CField descriptor for a bitfield + #[allow(dead_code)] + pub fn new_bitfield( + proto: PyTypeRef, + offset: isize, + bit_size: u16, + bit_offset: u16, + index: usize, + ) -> Self { + let encoded_size = ((bit_size as isize) << 16) | (bit_offset as isize); + Self { + offset, + size: encoded_size, + index, + proto, + anonymous: false, + } + } + + /// Get the actual 
byte size (for non-bitfields) or bit storage size (for bitfields) + pub fn byte_size(&self) -> usize { + field_size(self.size) as usize + } + + /// Create a new CField from an existing field with adjusted offset and index + /// Used by MakeFields to promote anonymous fields + pub fn new_from_field(fdescr: &PyCField, index_offset: usize, offset_delta: isize) -> Self { + Self { + offset: fdescr.offset + offset_delta, + size: fdescr.size, + index: fdescr.index + index_offset, + proto: fdescr.proto.clone(), + anonymous: false, // promoted fields are not anonymous themselves + } + } + + /// Set anonymous flag + pub fn set_anonymous(&mut self, anonymous: bool) { + self.anonymous = anonymous; + } +} + +impl Representable for PyCField { + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // Get type name from proto (which is always PyTypeRef) + let tp_name = zelf.proto.name().to_string(); + + // Bitfield: + // Regular: + if is_bitfield(zelf.size) { + let bit_offset = field_size(zelf.size); + let bits = num_bits(zelf.size); + Ok(format!( + "", + tp_name, zelf.offset, bit_offset, bits + )) } else { - return Err(vm.new_attribute_error("Library is closed".to_owned())); + Ok(format!( + "", + tp_name, zelf.offset, zelf.size + )) + } + } +} + +/// PyCField_get +impl GetDescriptor for PyCField { + fn descr_get( + zelf: PyObjectRef, + obj: Option, + _cls: Option, + vm: &VirtualMachine, + ) -> PyResult { + let zelf = zelf + .downcast::() + .map_err(|_| vm.new_type_error("expected CField"))?; + + // If obj is None, return the descriptor itself (class attribute access) + let obj = match obj { + Some(obj) if !vm.is_none(&obj) => obj, + _ => return Ok(zelf.into()), }; - // Get _type_ attribute and size - let type_attr = cls - .as_object() - .get_attr("_type_", vm) - .map_err(|_| vm.new_type_error(format!("'{}' has no _type_ attribute", cls.name())))?; - let type_str = type_attr.str(vm)?.to_string(); - let size = get_size(&type_str); - - // Read value from symbol address - let 
value = if symbol_address != 0 && size > 0 { - // Safety: Reading from a symbol address provided by dlsym - unsafe { - let ptr = symbol_address as *const u8; - let bytes = std::slice::from_raw_parts(ptr, size); - bytes_to_pyobject(&cls, bytes, vm)? + let offset = zelf.offset as usize; + let size = zelf.byte_size(); + + // Get PyCData from obj (works for both Structure and Union) + let cdata = PyCField::get_cdata_from_obj(&obj, vm)?; + + // PyCData_get + cdata.get_field( + zelf.proto.as_object(), + zelf.index, + size, + offset, + obj.clone(), + vm, + ) + } +} + +impl PyCField { + /// Convert a Python value to bytes + fn value_to_bytes(value: &PyObject, size: usize, vm: &VirtualMachine) -> PyResult> { + // 1. Handle bytes objects + if let Some(bytes) = value.downcast_ref::() { + let src = bytes.as_bytes(); + let mut result = vec![0u8; size]; + let len = std::cmp::min(src.len(), size); + result[..len].copy_from_slice(&src[..len]); + Ok(result) + } + // 2. Handle ctypes array instances (copy their buffer) + else if let Some(cdata) = value.downcast_ref::() { + let buffer = cdata.buffer.read(); + let mut result = vec![0u8; size]; + let len = std::cmp::min(buffer.len(), size); + result[..len].copy_from_slice(&buffer[..len]); + Ok(result) + } + // 4. Handle float values (check before int, since float.try_int would truncate) + else if let Some(float_val) = value.downcast_ref::() { + let f = float_val.to_f64(); + match size { + 4 => { + let val = f as f32; + Ok(val.to_ne_bytes().to_vec()) + } + 8 => Ok(f.to_ne_bytes().to_vec()), + _ => unreachable!("wrong payload size"), + } + } + // 4. 
Handle integer values + else if let Ok(int_val) = value.try_int(vm) { + let i = int_val.as_bigint(); + match size { + 1 => { + let val = i.to_i8().unwrap_or(0); + Ok(val.to_ne_bytes().to_vec()) + } + 2 => { + let val = i.to_i16().unwrap_or(0); + Ok(val.to_ne_bytes().to_vec()) + } + 4 => { + let val = i.to_i32().unwrap_or(0); + Ok(val.to_ne_bytes().to_vec()) + } + 8 => { + let val = i.to_i64().unwrap_or(0); + Ok(val.to_ne_bytes().to_vec()) + } + _ => Ok(vec![0u8; size]), } } else { - vm.ctx.none() - }; + Ok(vec![0u8; size]) + } + } - // Create instance - let args = FuncArgs::new(vec![value], KwArgs::default()); - let instance = PyCSimple::slot_new(cls.clone(), args, vm)?; + /// Convert a Python value to bytes with type-specific handling for pointer types. + /// Returns (bytes, optional holder for wchar buffer). + fn value_to_bytes_for_type( + type_code: &str, + value: &PyObject, + size: usize, + vm: &VirtualMachine, + ) -> PyResult<(Vec, Option)> { + match type_code { + // c_float: always convert to float first (f_set) + "f" => { + let f = if let Some(float_val) = value.downcast_ref::() { + float_val.to_f64() + } else if let Ok(int_val) = value.try_int(vm) { + int_val.as_bigint().to_i64().unwrap_or(0) as f64 + } else { + return Err(vm.new_type_error(format!( + "float expected instead of {}", + value.class().name() + ))); + }; + let val = f as f32; + Ok((val.to_ne_bytes().to_vec(), None)) + } + // c_double: always convert to float first (d_set) + "d" => { + let f = if let Some(float_val) = value.downcast_ref::() { + float_val.to_f64() + } else if let Ok(int_val) = value.try_int(vm) { + int_val.as_bigint().to_i64().unwrap_or(0) as f64 + } else { + return Err(vm.new_type_error(format!( + "float expected instead of {}", + value.class().name() + ))); + }; + Ok((f.to_ne_bytes().to_vec(), None)) + } + // c_longdouble: convert to float (treated as f64 in RustPython) + "g" => { + let f = if let Some(float_val) = value.downcast_ref::() { + float_val.to_f64() + } else if let 
Ok(int_val) = value.try_int(vm) { + int_val.as_bigint().to_i64().unwrap_or(0) as f64 + } else { + return Err(vm.new_type_error(format!( + "float expected instead of {}", + value.class().name() + ))); + }; + Ok((f.to_ne_bytes().to_vec(), None)) + } + "z" => { + // c_char_p: store pointer to null-terminated bytes + if let Some(bytes) = value.downcast_ref::() { + let (converted, ptr) = ensure_z_null_terminated(bytes, vm); + let mut result = vec![0u8; size]; + let addr_bytes = ptr.to_ne_bytes(); + let len = std::cmp::min(addr_bytes.len(), size); + result[..len].copy_from_slice(&addr_bytes[..len]); + return Ok((result, Some(converted))); + } + // Integer address + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_usize().unwrap_or(0); + let mut result = vec![0u8; size]; + let bytes = v.to_ne_bytes(); + let len = std::cmp::min(bytes.len(), size); + result[..len].copy_from_slice(&bytes[..len]); + return Ok((result, None)); + } + // None -> NULL pointer + if vm.is_none(value) { + return Ok((vec![0u8; size], None)); + } + Ok((PyCField::value_to_bytes(value, size, vm)?, None)) + } + "Z" => { + // c_wchar_p: store pointer to null-terminated wchar_t buffer + if let Some(s) = value.downcast_ref::() { + let (holder, ptr) = str_to_wchar_bytes(s.as_str(), vm); + let mut result = vec![0u8; size]; + let addr_bytes = ptr.to_ne_bytes(); + let len = std::cmp::min(addr_bytes.len(), size); + result[..len].copy_from_slice(&addr_bytes[..len]); + return Ok((result, Some(holder))); + } + // Integer address + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_usize().unwrap_or(0); + let mut result = vec![0u8; size]; + let bytes = v.to_ne_bytes(); + let len = std::cmp::min(bytes.len(), size); + result[..len].copy_from_slice(&bytes[..len]); + return Ok((result, None)); + } + // None -> NULL pointer + if vm.is_none(value) { + return Ok((vec![0u8; size], None)); + } + Ok((PyCField::value_to_bytes(value, size, vm)?, None)) + } + "P" => { + // 
c_void_p: store integer as pointer + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_usize().unwrap_or(0); + let mut result = vec![0u8; size]; + let bytes = v.to_ne_bytes(); + let len = std::cmp::min(bytes.len(), size); + result[..len].copy_from_slice(&bytes[..len]); + return Ok((result, None)); + } + // None -> NULL pointer + if vm.is_none(value) { + return Ok((vec![0u8; size], None)); + } + Ok((PyCField::value_to_bytes(value, size, vm)?, None)) + } + _ => Ok((PyCField::value_to_bytes(value, size, vm)?, None)), + } + } + + /// Check if the field type is a c_char array (element type has _type_ == 'c') + fn is_char_array(proto: &PyObject, vm: &VirtualMachine) -> bool { + // Get element_type from StgInfo (for array types) + if let Some(proto_type) = proto.downcast_ref::() + && let Some(stg) = proto_type.stg_info_opt() + && let Some(element_type) = &stg.element_type + { + // Check if element type has _type_ == "c" + if let Ok(type_code) = element_type.as_object().get_attr("_type_", vm) + && let Some(s) = type_code.downcast_ref::() + { + return s.as_str() == "c"; + } + } + false + } - // Store base reference to keep dll alive - if let Ok(simple_ref) = instance.clone().downcast::() { - simple_ref.cdata.write().base = Some(dll); + /// Check if the field type is a c_wchar array (element type has _type_ == 'u') + fn is_wchar_array(proto: &PyObject, vm: &VirtualMachine) -> bool { + // Get element_type from StgInfo (for array types) + if let Some(proto_type) = proto.downcast_ref::() + && let Some(stg) = proto_type.stg_info_opt() + && let Some(element_type) = &stg.element_type + { + // Check if element type has _type_ == "u" + if let Ok(type_code) = element_type.as_object().get_attr("_type_", vm) + && let Some(s) = type_code.downcast_ref::() + { + return s.as_str() == "u"; + } } + false + } - Ok(instance) + /// Convert bytes for c_char array assignment (stops at first null terminator) + /// Returns (bytes_to_copy, copy_len) + fn 
bytes_for_char_array(src: &[u8]) -> &[u8] { + // Find first null terminator and include it + if let Some(null_pos) = src.iter().position(|&b| b == 0) { + &src[..=null_pos] + } else { + src + } } } -impl PyCSimple { - pub fn to_arg( - &self, - ty: libffi::middle::Type, +#[pyclass( + flags(DISALLOW_INSTANTIATION, IMMUTABLETYPE), + with(Representable, GetDescriptor) +)] +impl PyCField { + /// Get PyCData from object (works for both Structure and Union) + fn get_cdata_from_obj<'a>(obj: &'a PyObjectRef, vm: &VirtualMachine) -> PyResult<&'a PyCData> { + if let Some(s) = obj.downcast_ref::() { + Ok(&s.0) + } else if let Some(u) = obj.downcast_ref::() { + Ok(&u.0) + } else { + Err(vm.new_type_error(format!( + "descriptor works only on Structure or Union instances, got {}", + obj.class().name() + ))) + } + } + + /// PyCField_set + #[pyslot] + fn descr_set( + zelf: &crate::PyObject, + obj: PyObjectRef, + value: PySetterValue, + vm: &VirtualMachine, + ) -> PyResult<()> { + let zelf = zelf + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected CField"))?; + + let offset = zelf.offset as usize; + let size = zelf.byte_size(); + + // Get PyCData from obj (works for both Structure and Union) + let cdata = Self::get_cdata_from_obj(&obj, vm)?; + + match value { + PySetterValue::Assign(value) => { + // Check if needs byte swapping + let needs_swap = (obj + .class() + .as_object() + .get_attr("_swappedbytes_", vm) + .is_ok() + || zelf + .proto + .as_object() + .get_attr("_swappedbytes_", vm) + .is_ok()) + && size > 1; + + // PyCData_set + cdata.set_field( + zelf.proto.as_object(), + value, + zelf.index, + size, + offset, + needs_swap, + vm, + ) + } + PySetterValue::Delete => Err(vm.new_type_error("cannot delete field")), + } + } + + #[pymethod] + fn __set__( + zelf: PyObjectRef, + obj: PyObjectRef, + value: PyObjectRef, vm: &VirtualMachine, - ) -> Option { - let value = unsafe { (*self.value.as_ptr()).clone() }; - if let Ok(i) = value.try_int(vm) { - let i = i.as_bigint(); - 
return if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u8().as_raw_ptr()) { - i.to_u8().map(|r: u8| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i8().as_raw_ptr()) { - i.to_i8().map(|r: i8| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u16().as_raw_ptr()) { - i.to_u16().map(|r: u16| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i16().as_raw_ptr()) { - i.to_i16().map(|r: i16| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u32().as_raw_ptr()) { - i.to_u32().map(|r: u32| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i32().as_raw_ptr()) { - i.to_i32().map(|r: i32| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u64().as_raw_ptr()) { - i.to_u64().map(|r: u64| libffi::middle::Arg::new(&r)) - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i64().as_raw_ptr()) { - i.to_i64().map(|r: i64| libffi::middle::Arg::new(&r)) + ) -> PyResult<()> { + Self::descr_set(&zelf, obj, PySetterValue::Assign(value), vm) + } + + #[pymethod] + fn __delete__(zelf: PyObjectRef, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + Self::descr_set(&zelf, obj, PySetterValue::Delete, vm) + } + + #[pygetset] + fn offset(&self) -> isize { + self.offset + } + + #[pygetset] + fn size(&self) -> isize { + self.size + } +} + +// ParamFunc implementations (PyCArgObject creation) + +use super::_ctypes::CArgObject; + +/// Call the appropriate paramfunc based on StgInfo.paramfunc +/// info->paramfunc(st, obj) +pub(super) fn call_paramfunc(obj: &PyObject, vm: &VirtualMachine) -> PyResult { + let cls = obj.class(); + let stg_info = cls + .stg_info_opt() + .ok_or_else(|| vm.new_type_error("not a ctypes type"))?; + + match stg_info.paramfunc { + ParamFunc::Simple => simple_paramfunc(obj, vm), + 
ParamFunc::Array => array_paramfunc(obj, vm), + ParamFunc::Pointer => pointer_paramfunc(obj, vm), + ParamFunc::Structure | ParamFunc::Union => struct_union_paramfunc(obj, &stg_info, vm), + ParamFunc::None => Err(vm.new_type_error("no paramfunc")), + } +} + +/// PyCSimpleType_paramfunc +fn simple_paramfunc(obj: &PyObject, vm: &VirtualMachine) -> PyResult { + use super::simple::PyCSimple; + + let simple = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected simple type"))?; + + // Get type code from _type_ attribute + let cls = obj.class().to_owned(); + let type_code = cls + .type_code(vm) + .ok_or_else(|| vm.new_type_error("no _type_ attribute"))?; + let tag = type_code.as_bytes().first().copied().unwrap_or(b'?'); + + // Read value from buffer: memcpy(&parg->value, self->b_ptr, self->b_size) + let buffer = simple.0.buffer.read(); + let ffi_value = buffer_to_ffi_value(&type_code, &buffer); + + Ok(CArgObject { + tag, + value: ffi_value, + obj: obj.to_owned(), + size: 0, + offset: 0, + }) +} + +/// PyCArrayType_paramfunc +fn array_paramfunc(obj: &PyObject, vm: &VirtualMachine) -> PyResult { + use super::array::PyCArray; + + let array = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected array"))?; + + // p->value.p = (char *)self->b_ptr + let buffer = array.0.buffer.read(); + let ptr_val = buffer.as_ptr() as usize; + + Ok(CArgObject { + tag: b'P', + value: FfiArgValue::Pointer(ptr_val), + obj: obj.to_owned(), + size: 0, + offset: 0, + }) +} + +/// PyCPointerType_paramfunc +fn pointer_paramfunc(obj: &PyObject, vm: &VirtualMachine) -> PyResult { + use super::pointer::PyCPointer; + + let ptr = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected pointer"))?; + + // parg->value.p = *(void **)self->b_ptr + let ptr_val = ptr.get_ptr_value(); + + Ok(CArgObject { + tag: b'P', + value: FfiArgValue::Pointer(ptr_val), + obj: obj.to_owned(), + size: 0, + offset: 0, + }) +} + +/// StructUnionType_paramfunc (for both Structure and 
Union) +fn struct_union_paramfunc( + obj: &PyObject, + stg_info: &StgInfo, + _vm: &VirtualMachine, +) -> PyResult { + // Get buffer pointer + // For large structs (> sizeof(void*)), we'd need to allocate and copy. + // For now, just point to buffer directly and keep obj reference for memory safety. + let buffer = if let Some(cdata) = obj.downcast_ref::() { + cdata.buffer.read() + } else { + return Ok(CArgObject { + tag: b'V', + value: FfiArgValue::Pointer(0), + obj: obj.to_owned(), + size: stg_info.size, + offset: 0, + }); + }; + + let ptr_val = buffer.as_ptr() as usize; + let size = buffer.len(); + + Ok(CArgObject { + tag: b'V', + value: FfiArgValue::Pointer(ptr_val), + obj: obj.to_owned(), + size, + offset: 0, + }) +} + +// FfiArgValue - Owned FFI argument value + +/// Owned FFI argument value. Keeps the value alive for the duration of the FFI call. +#[derive(Debug, Clone)] +pub enum FfiArgValue { + U8(u8), + I8(i8), + U16(u16), + I16(i16), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + F32(f32), + F64(f64), + Pointer(usize), + /// Pointer with owned data. The PyObjectRef keeps the pointed data alive. 
+ OwnedPointer(usize, #[allow(dead_code)] crate::PyObjectRef), +} + +impl FfiArgValue { + /// Create an Arg reference to this owned value + pub fn as_arg(&self) -> libffi::middle::Arg { + match self { + FfiArgValue::U8(v) => libffi::middle::Arg::new(v), + FfiArgValue::I8(v) => libffi::middle::Arg::new(v), + FfiArgValue::U16(v) => libffi::middle::Arg::new(v), + FfiArgValue::I16(v) => libffi::middle::Arg::new(v), + FfiArgValue::U32(v) => libffi::middle::Arg::new(v), + FfiArgValue::I32(v) => libffi::middle::Arg::new(v), + FfiArgValue::U64(v) => libffi::middle::Arg::new(v), + FfiArgValue::I64(v) => libffi::middle::Arg::new(v), + FfiArgValue::F32(v) => libffi::middle::Arg::new(v), + FfiArgValue::F64(v) => libffi::middle::Arg::new(v), + FfiArgValue::Pointer(v) => libffi::middle::Arg::new(v), + FfiArgValue::OwnedPointer(v, _) => libffi::middle::Arg::new(v), + } + } +} + +/// Convert buffer bytes to FfiArgValue based on type code +pub(super) fn buffer_to_ffi_value(type_code: &str, buffer: &[u8]) -> FfiArgValue { + match type_code { + "c" | "b" => { + let v = buffer.first().map(|&b| b as i8).unwrap_or(0); + FfiArgValue::I8(v) + } + "B" => { + let v = buffer.first().copied().unwrap_or(0); + FfiArgValue::U8(v) + } + "h" => { + let v = if buffer.len() >= 2 { + i16::from_ne_bytes(buffer[..2].try_into().unwrap()) + } else { + 0 + }; + FfiArgValue::I16(v) + } + "H" => { + let v = if buffer.len() >= 2 { + u16::from_ne_bytes(buffer[..2].try_into().unwrap()) + } else { + 0 + }; + FfiArgValue::U16(v) + } + "i" => { + let v = if buffer.len() >= 4 { + i32::from_ne_bytes(buffer[..4].try_into().unwrap()) + } else { + 0 + }; + FfiArgValue::I32(v) + } + "I" => { + let v = if buffer.len() >= 4 { + u32::from_ne_bytes(buffer[..4].try_into().unwrap()) + } else { + 0 + }; + FfiArgValue::U32(v) + } + "l" | "q" => { + let v = if buffer.len() >= 8 { + i64::from_ne_bytes(buffer[..8].try_into().unwrap()) + } else if buffer.len() >= 4 { + i32::from_ne_bytes(buffer[..4].try_into().unwrap()) as i64 + } 
else { + 0 + }; + FfiArgValue::I64(v) + } + "L" | "Q" => { + let v = if buffer.len() >= 8 { + u64::from_ne_bytes(buffer[..8].try_into().unwrap()) + } else if buffer.len() >= 4 { + u32::from_ne_bytes(buffer[..4].try_into().unwrap()) as u64 + } else { + 0 + }; + FfiArgValue::U64(v) + } + "f" => { + let v = if buffer.len() >= 4 { + f32::from_ne_bytes(buffer[..4].try_into().unwrap()) + } else { + 0.0 + }; + FfiArgValue::F32(v) + } + "d" | "g" => { + let v = if buffer.len() >= 8 { + f64::from_ne_bytes(buffer[..8].try_into().unwrap()) } else { - None + 0.0 }; + FfiArgValue::F64(v) } - if let Ok(_f) = value.try_float(vm) { - todo!(); + "z" | "Z" | "P" | "O" => FfiArgValue::Pointer(read_ptr_from_buffer(buffer)), + "?" => { + let v = buffer.first().map(|&b| b != 0).unwrap_or(false); + FfiArgValue::U8(if v { 1 } else { 0 }) } - if let Ok(_b) = value.try_to_bool(vm) { - todo!(); + "u" => { + // wchar_t - 4 bytes on most platforms + let v = if buffer.len() >= 4 { + u32::from_ne_bytes(buffer[..4].try_into().unwrap()) + } else { + 0 + }; + FfiArgValue::U32(v) } - None + _ => FfiArgValue::Pointer(0), } } -static SIMPLE_BUFFER_METHODS: BufferMethods = BufferMethods { - obj_bytes: |buffer| { - rustpython_common::lock::PyMappedRwLockReadGuard::map( - rustpython_common::lock::PyRwLockReadGuard::map( - buffer.obj_as::().cdata.read(), - |x: &CDataObject| x, - ), - |x: &CDataObject| x.buffer.as_slice(), - ) - .into() - }, - obj_bytes_mut: |buffer| { - rustpython_common::lock::PyMappedRwLockWriteGuard::map( - rustpython_common::lock::PyRwLockWriteGuard::map( - buffer.obj_as::().cdata.write(), - |x: &mut CDataObject| x, - ), - |x: &mut CDataObject| x.buffer.as_mut_slice(), - ) - .into() - }, - release: |_| {}, - retain: |_| {}, -}; +/// Convert bytes to appropriate Python object based on ctypes type +pub(super) fn bytes_to_pyobject( + cls: &Py, + bytes: &[u8], + vm: &VirtualMachine, +) -> PyResult { + // Try to get _type_ attribute + if let Ok(type_attr) = 
cls.as_object().get_attr("_type_", vm) + && let Ok(s) = type_attr.str(vm) + { + let ty = s.to_string(); + return match ty.as_str() { + "c" => Ok(vm.ctx.new_bytes(bytes.to_vec()).into()), + "b" => { + let val = if !bytes.is_empty() { bytes[0] as i8 } else { 0 }; + Ok(vm.ctx.new_int(val).into()) + } + "B" => { + let val = if !bytes.is_empty() { bytes[0] } else { 0 }; + Ok(vm.ctx.new_int(val).into()) + } + "h" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "H" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_ushort::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "i" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_int::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "I" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_uint::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "l" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_long::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "L" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_ulong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "q" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_longlong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "Q" => { + const SIZE: usize = mem::size_of::(); + let val 
= if bytes.len() >= SIZE { + c_ulonglong::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_int(val).into()) + } + "f" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_float::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0.0 + }; + Ok(vm.ctx.new_float(val as f64).into()) + } + "d" => { + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_double::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0.0 + }; + Ok(vm.ctx.new_float(val).into()) + } + "g" => { + // long double - read as f64 for now since Rust doesn't have native long double + // This may lose precision on platforms where long double > 64 bits + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_double::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0.0 + }; + Ok(vm.ctx.new_float(val).into()) + } + "?" => { + let val = !bytes.is_empty() && bytes[0] != 0; + Ok(vm.ctx.new_bool(val).into()) + } + "v" => { + // VARIANT_BOOL: non-zero = True, zero = False + const SIZE: usize = mem::size_of::(); + let val = if bytes.len() >= SIZE { + c_short::from_ne_bytes(bytes[..SIZE].try_into().expect("size checked")) + } else { + 0 + }; + Ok(vm.ctx.new_bool(val != 0).into()) + } + "z" => { + // c_char_p: read NULL-terminated string from pointer + let ptr = read_ptr_from_buffer(bytes); + if ptr == 0 { + return Ok(vm.ctx.none()); + } + let c_str = unsafe { std::ffi::CStr::from_ptr(ptr as _) }; + Ok(vm.ctx.new_bytes(c_str.to_bytes().to_vec()).into()) + } + "Z" => { + // c_wchar_p: read NULL-terminated wide string from pointer + let ptr = read_ptr_from_buffer(bytes); + if ptr == 0 { + return Ok(vm.ctx.none()); + } + let len = unsafe { libc::wcslen(ptr as *const libc::wchar_t) }; + let wchars = + unsafe { std::slice::from_raw_parts(ptr as *const libc::wchar_t, len) }; + let s: String = wchars + .iter() + 
.filter_map(|&c| char::from_u32(c as u32)) + .collect(); + Ok(vm.ctx.new_str(s).into()) + } + "P" => { + // c_void_p: return pointer value as integer + let val = read_ptr_from_buffer(bytes); + if val == 0 { + return Ok(vm.ctx.none()); + } + Ok(vm.ctx.new_int(val).into()) + } + "u" => { + let val = if bytes.len() >= mem::size_of::() { + let wc = if mem::size_of::() == 2 { + u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 + } else { + u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) + }; + char::from_u32(wc).unwrap_or('\0') + } else { + '\0' + }; + Ok(vm.ctx.new_str(val).into()) + } + _ => Ok(vm.ctx.none()), + }; + } + // Default: return bytes as-is + Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) +} -impl AsBuffer for PyCSimple { - fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { - let buffer_len = zelf.cdata.read().buffer.len(); - let buf = PyBuffer::new( - zelf.to_owned().into(), - BufferDescriptor::simple(buffer_len, false), // readonly=false for ctypes - &SIMPLE_BUFFER_METHODS, - ); - Ok(buf) +// Shared functions for Structure and Union types + +/// Parse a non-negative integer attribute, returning default if not present +pub(super) fn get_usize_attr( + obj: &PyObject, + attr: &str, + default: usize, + vm: &VirtualMachine, +) -> PyResult { + let Ok(attr_val) = obj.get_attr(vm.ctx.intern_str(attr), vm) else { + return Ok(default); + }; + let n = attr_val + .try_int(vm) + .map_err(|_| vm.new_value_error(format!("{attr} must be a non-negative integer")))?; + let val = n.as_bigint(); + if val.is_negative() { + return Err(vm.new_value_error(format!("{attr} must be a non-negative integer"))); + } + Ok(val.to_usize().unwrap_or(default)) +} + +/// Read a pointer value from buffer +#[inline] +pub(super) fn read_ptr_from_buffer(buffer: &[u8]) -> usize { + const PTR_SIZE: usize = std::mem::size_of::(); + if buffer.len() >= PTR_SIZE { + usize::from_ne_bytes(buffer[..PTR_SIZE].try_into().unwrap()) + } else { + 0 + } +} + +/// Set or initialize StgInfo on 
a type +pub(super) fn set_or_init_stginfo(type_ref: &PyType, stg_info: StgInfo) { + if type_ref.init_type_data(stg_info.clone()).is_err() + && let Some(mut existing) = type_ref.get_type_data_mut::() + { + *existing = stg_info; + } +} + +/// Check if a field type supports byte order swapping +pub(super) fn check_other_endian_support( + field_type: &PyObject, + vm: &VirtualMachine, +) -> PyResult<()> { + let other_endian_attr = if cfg!(target_endian = "little") { + "__ctype_be__" + } else { + "__ctype_le__" + }; + + if field_type.get_attr(other_endian_attr, vm).is_ok() { + return Ok(()); + } + + // Array type: recursively check element type + if let Ok(elem_type) = field_type.get_attr("_type_", vm) + && field_type.get_attr("_length_", vm).is_ok() + { + return check_other_endian_support(&elem_type, vm); } + + // Structure/Union: has StgInfo but no _type_ attribute + if let Some(type_obj) = field_type.downcast_ref::() + && type_obj.stg_info_opt().is_some() + && field_type.get_attr("_type_", vm).is_err() + { + return Ok(()); + } + + Err(vm.new_type_error(format!( + "This type does not support other endian: {}", + field_type.class().name() + ))) +} + +/// Get the size of a ctypes field type +pub(super) fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> PyResult { + if let Some(type_obj) = field_type.downcast_ref::() + && let Some(stg_info) = type_obj.stg_info_opt() + { + return Ok(stg_info.size); + } + + if let Some(size) = field_type + .get_attr("_type_", vm) + .ok() + .and_then(|type_attr| type_attr.str(vm).ok()) + .and_then(|type_str| { + let s = type_str.to_string(); + (s.len() == 1).then(|| super::get_size(&s)) + }) + { + return Ok(size); + } + + if let Some(s) = field_type + .get_attr("size_of_instances", vm) + .ok() + .and_then(|size_method| size_method.call((), vm).ok()) + .and_then(|size| size.try_int(vm).ok()) + .and_then(|n| n.as_bigint().to_usize()) + { + return Ok(s); + } + + Ok(std::mem::size_of::()) +} + +/// Get the alignment of a ctypes 
field type +pub(super) fn get_field_align(field_type: &PyObject, vm: &VirtualMachine) -> usize { + if let Some(type_obj) = field_type.downcast_ref::() + && let Some(stg_info) = type_obj.stg_info_opt() + && stg_info.align > 0 + { + return stg_info.align; + } + + if let Some(align) = field_type + .get_attr("_type_", vm) + .ok() + .and_then(|type_attr| type_attr.str(vm).ok()) + .and_then(|type_str| { + let s = type_str.to_string(); + (s.len() == 1).then(|| super::get_size(&s)) + }) + { + return align; + } + + 1 +} + +/// Promote fields from anonymous struct/union to parent type +fn make_fields( + cls: &Py, + descr: &super::PyCField, + index: usize, + offset: isize, + vm: &VirtualMachine, +) -> PyResult<()> { + use crate::builtins::{PyList, PyTuple}; + use crate::convert::ToPyObject; + + let fields = descr.proto.as_object().get_attr("_fields_", vm)?; + let fieldlist: Vec = if let Some(list) = fields.downcast_ref::() { + list.borrow_vec().to_vec() + } else if let Some(tuple) = fields.downcast_ref::() { + tuple.to_vec() + } else { + return Err(vm.new_type_error("_fields_ must be a sequence")); + }; + + for pair in fieldlist.iter() { + let field_tuple = pair + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples"))?; + + if field_tuple.len() < 2 { + continue; + } + + let fname = field_tuple + .first() + .expect("len checked") + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("field name must be a string"))?; + + let fdescr_obj = descr + .proto + .as_object() + .get_attr(vm.ctx.intern_str(fname.as_str()), vm)?; + let fdescr = fdescr_obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("unexpected type"))?; + + if fdescr.anonymous { + make_fields( + cls, + fdescr, + index + fdescr.index, + offset + fdescr.offset, + vm, + )?; + continue; + } + + let new_descr = super::PyCField::new_from_field(fdescr, index, offset); + cls.set_attr(vm.ctx.intern_str(fname.as_str()), new_descr.to_pyobject(vm)); + } + + Ok(()) +} + +/// Process 
_anonymous_ attribute for struct/union +pub(super) fn make_anon_fields(cls: &Py, vm: &VirtualMachine) -> PyResult<()> { + use crate::builtins::{PyList, PyTuple}; + use crate::convert::ToPyObject; + + let anon = match cls.as_object().get_attr("_anonymous_", vm) { + Ok(anon) => anon, + Err(_) => return Ok(()), + }; + + let anon_names: Vec = if let Some(list) = anon.downcast_ref::() { + list.borrow_vec().to_vec() + } else if let Some(tuple) = anon.downcast_ref::() { + tuple.to_vec() + } else { + return Err(vm.new_type_error("_anonymous_ must be a sequence")); + }; + + for fname_obj in anon_names.iter() { + let fname = fname_obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("_anonymous_ items must be strings"))?; + + let descr_obj = cls + .as_object() + .get_attr(vm.ctx.intern_str(fname.as_str()), vm)?; + + let descr = descr_obj.downcast_ref::().ok_or_else(|| { + vm.new_attribute_error(format!( + "'{}' is specified in _anonymous_ but not in _fields_", + fname.as_str() + )) + })?; + + let mut new_descr = super::PyCField::new_from_field(descr, 0, 0); + new_descr.set_anonymous(true); + cls.set_attr(vm.ctx.intern_str(fname.as_str()), new_descr.to_pyobject(vm)); + + make_fields(cls, descr, descr.index, descr.offset, vm)?; + } + + Ok(()) } diff --git a/crates/vm/src/stdlib/ctypes/function.rs b/crates/vm/src/stdlib/ctypes/function.rs index b4e600f77ba..9bddb0ef0e8 100644 --- a/crates/vm/src/stdlib/ctypes/function.rs +++ b/crates/vm/src/stdlib/ctypes/function.rs @@ -1,67 +1,310 @@ // spell-checker:disable -use crate::builtins::{PyNone, PyStr, PyTuple, PyTupleRef, PyType, PyTypeRef}; -use crate::convert::ToPyObject; -use crate::function::FuncArgs; -use crate::stdlib::ctypes::PyCData; -use crate::stdlib::ctypes::base::{CDataObject, PyCSimple, ffi_type_from_str}; -use crate::stdlib::ctypes::thunk::PyCThunk; -use crate::types::Representable; -use crate::types::{Callable, Constructor}; -use crate::{AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; 
-use crossbeam_utils::atomic::AtomicCell; -use libffi::middle::{Arg, Cif, CodePtr, Type}; +use super::{ + _ctypes::CArgObject, PyCArray, PyCData, PyCPointer, PyCStructure, base::FfiArgValue, + simple::PyCSimple, type_info, +}; +use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyBytes, PyDict, PyNone, PyStr, PyTuple, PyType, PyTypeRef}, + class::StaticType, + convert::ToPyObject, + function::FuncArgs, + types::{Callable, Constructor, Representable}, + vm::thread::with_current_vm, +}; +use libffi::{ + low, + middle::{Arg, Cif, Closure, CodePtr, Type}, +}; use libloading::Symbol; use num_traits::ToPrimitive; use rustpython_common::lock::PyRwLock; use std::ffi::{self, c_void}; use std::fmt::Debug; -// See also: https://github.com/python/cpython/blob/4f8bb3947cfbc20f970ff9d9531e1132a9e95396/Modules/_ctypes/callproc.c#L15 +// Internal function addresses for special ctypes functions +pub(super) const INTERNAL_CAST_ADDR: usize = 1; +pub(super) const INTERNAL_STRING_AT_ADDR: usize = 2; +pub(super) const INTERNAL_WSTRING_AT_ADDR: usize = 3; type FP = unsafe extern "C" fn(); -pub trait ArgumentType { +/// Get FFI type for a ctypes type code +fn get_ffi_type(ty: &str) -> Option { + type_info(ty).map(|t| (t.ffi_type_fn)()) +} + +// PyCFuncPtr - Function pointer implementation + +/// Get FFI type from CArgObject tag character +fn ffi_type_from_tag(tag: u8) -> Type { + match tag { + b'c' | b'b' => Type::i8(), + b'B' => Type::u8(), + b'h' => Type::i16(), + b'H' => Type::u16(), + b'i' => Type::i32(), + b'I' => Type::u32(), + b'l' => { + if std::mem::size_of::() == 8 { + Type::i64() + } else { + Type::i32() + } + } + b'L' => { + if std::mem::size_of::() == 8 { + Type::u64() + } else { + Type::u32() + } + } + b'q' => Type::i64(), + b'Q' => Type::u64(), + b'f' => Type::f32(), + b'd' | b'g' => Type::f64(), + b'?' 
=> Type::u8(), + b'u' => { + if std::mem::size_of::() == 2 { + Type::u16() + } else { + Type::u32() + } + } + _ => Type::pointer(), // 'P', 'V', 'z', 'Z', 'O', etc. + } +} + +/// Convert any object to a pointer value for c_void_p arguments +/// Follows ConvParam logic for pointer types +fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult { + // 0. CArgObject (from byref()) -> buffer address + offset + if let Some(carg) = value.downcast_ref::() { + // Get buffer address from the underlying object + let base_addr = if let Some(cdata) = carg.obj.downcast_ref::() { + cdata.buffer.read().as_ptr() as usize + } else { + return Err(vm.new_type_error(format!( + "byref() argument must be a ctypes instance, not '{}'", + carg.obj.class().name() + ))); + }; + let addr = (base_addr as isize + carg.offset) as usize; + return Ok(FfiArgValue::Pointer(addr)); + } + + // 1. None -> NULL + if value.is(&vm.ctx.none) { + return Ok(FfiArgValue::Pointer(0)); + } + + // 2. PyCArray -> buffer address (PyCArrayType_paramfunc) + if let Some(array) = value.downcast_ref::() { + let addr = array.0.buffer.read().as_ptr() as usize; + return Ok(FfiArgValue::Pointer(addr)); + } + + // 3. PyCPointer -> stored pointer value + if let Some(ptr) = value.downcast_ref::() { + return Ok(FfiArgValue::Pointer(ptr.get_ptr_value())); + } + + // 4. PyCStructure -> buffer address + if let Some(struct_obj) = value.downcast_ref::() { + let addr = struct_obj.0.buffer.read().as_ptr() as usize; + return Ok(FfiArgValue::Pointer(addr)); + } + + // 5. PyCSimple (c_void_p, c_char_p, etc.) -> value from buffer + if let Some(simple) = value.downcast_ref::() { + let buffer = simple.0.buffer.read(); + if buffer.len() >= std::mem::size_of::() { + let addr = super::base::read_ptr_from_buffer(&buffer); + return Ok(FfiArgValue::Pointer(addr)); + } + } + + // 6. 
bytes -> buffer address (PyBytes_AsString) + if let Some(bytes) = value.downcast_ref::() { + let addr = bytes.as_bytes().as_ptr() as usize; + return Ok(FfiArgValue::Pointer(addr)); + } + + // 7. Integer -> direct value + if let Ok(int_val) = value.try_int(vm) { + return Ok(FfiArgValue::Pointer( + int_val.as_bigint().to_usize().unwrap_or(0), + )); + } + + // 8. Check _as_parameter_ attribute ( recursive ConvParam) + if let Ok(as_param) = value.get_attr("_as_parameter_", vm) { + return convert_to_pointer(&as_param, vm); + } + + Err(vm.new_type_error(format!( + "cannot convert '{}' to c_void_p", + value.class().name() + ))) +} + +/// ConvParam-like conversion for when argtypes is None +/// Returns both the FFI type and the converted value +fn conv_param(value: &PyObject, vm: &VirtualMachine) -> PyResult<(Type, FfiArgValue)> { + // 1. CArgObject (from byref() or paramfunc) -> use stored type and value + if let Some(carg) = value.downcast_ref::() { + let ffi_type = ffi_type_from_tag(carg.tag); + return Ok((ffi_type, carg.value.clone())); + } + + // 2. None -> NULL pointer + if value.is(&vm.ctx.none) { + return Ok((Type::pointer(), FfiArgValue::Pointer(0))); + } + + // 3. ctypes objects -> use paramfunc + if let Ok(carg) = super::base::call_paramfunc(value, vm) { + let ffi_type = ffi_type_from_tag(carg.tag); + return Ok((ffi_type, carg.value.clone())); + } + + // 4. Python str -> pointer (use internal UTF-8 buffer) + if let Some(s) = value.downcast_ref::() { + let addr = s.as_str().as_ptr() as usize; + return Ok((Type::pointer(), FfiArgValue::Pointer(addr))); + } + + // 9. Python bytes -> pointer to buffer + if let Some(bytes) = value.downcast_ref::() { + let addr = bytes.as_bytes().as_ptr() as usize; + return Ok((Type::pointer(), FfiArgValue::Pointer(addr))); + } + + // 10. 
Python int -> i32 (default integer type) + if let Ok(int_val) = value.try_int(vm) { + let val = int_val.as_bigint().to_i32().unwrap_or(0); + return Ok((Type::i32(), FfiArgValue::I32(val))); + } + + // 11. Python float -> f64 + if let Ok(float_val) = value.try_float(vm) { + return Ok((Type::f64(), FfiArgValue::F64(float_val.to_f64()))); + } + + // 12. Check _as_parameter_ attribute + if let Ok(as_param) = value.get_attr("_as_parameter_", vm) { + return conv_param(&as_param, vm); + } + + Err(vm.new_type_error(format!( + "Don't know how to convert parameter {}", + value.class().name() + ))) +} + +trait ArgumentType { fn to_ffi_type(&self, vm: &VirtualMachine) -> PyResult; - fn convert_object(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult; + fn convert_object(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult; } impl ArgumentType for PyTypeRef { fn to_ffi_type(&self, vm: &VirtualMachine) -> PyResult { + use super::pointer::PyCPointer; + use super::structure::PyCStructure; + + // CArgObject (from byref()) should be treated as pointer + if self.fast_issubclass(CArgObject::static_type()) { + return Ok(Type::pointer()); + } + + // Pointer types (POINTER(T)) are always pointer FFI type + // Check if type is a subclass of _Pointer (PyCPointer) + if self.fast_issubclass(PyCPointer::static_type()) { + return Ok(Type::pointer()); + } + + // Structure types are passed as pointers + if self.fast_issubclass(PyCStructure::static_type()) { + return Ok(Type::pointer()); + } + + // Use get_attr to traverse MRO (for subclasses like MyInt(c_int)) let typ = self - .get_class_attr(vm.ctx.intern_str("_type_")) - .ok_or(vm.new_type_error("Unsupported argument type".to_string()))?; + .as_object() + .get_attr(vm.ctx.intern_str("_type_"), vm) + .ok() + .ok_or(vm.new_type_error("Unsupported argument type"))?; let typ = typ .downcast_ref::() - .ok_or(vm.new_type_error("Unsupported argument type".to_string()))?; + .ok_or(vm.new_type_error("Unsupported argument type"))?; let 
typ = typ.to_string(); let typ = typ.as_str(); - let converted_typ = ffi_type_from_str(typ); - if let Some(typ) = converted_typ { - Ok(typ) - } else { - Err(vm.new_type_error(format!("Unsupported argument type: {}", typ))) - } + get_ffi_type(typ) + .ok_or_else(|| vm.new_type_error(format!("Unsupported argument type: {}", typ))) } - fn convert_object(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { - // if self.fast_isinstance::(vm) { - // let array = value.downcast::()?; - // return Ok(Arg::from(array.as_ptr())); - // } - if let Ok(simple) = value.downcast::() { + fn convert_object(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // Call from_param first to convert the value (like CPython's callproc.c:1235) + // converter = PyTuple_GET_ITEM(argtypes, i); + // v = PyObject_CallOneArg(converter, arg); + let from_param = self + .as_object() + .get_attr(vm.ctx.intern_str("from_param"), vm)?; + let converted = from_param.call((value.clone(),), vm)?; + + // Then pass the converted value to ConvParam logic + // CArgObject (from from_param) -> use stored value directly + if let Some(carg) = converted.downcast_ref::() { + return Ok(carg.value.clone()); + } + + // None -> NULL pointer + if vm.is_none(&converted) { + return Ok(FfiArgValue::Pointer(0)); + } + + // For pointer types (POINTER(T)), we need to pass the ADDRESS of the value's buffer + if self.fast_issubclass(PyCPointer::static_type()) { + if let Some(cdata) = converted.downcast_ref::() { + let addr = cdata.buffer.read().as_ptr() as usize; + return Ok(FfiArgValue::Pointer(addr)); + } + return convert_to_pointer(&converted, vm); + } + + // For structure types, convert to pointer to structure + if self.fast_issubclass(PyCStructure::static_type()) { + return convert_to_pointer(&converted, vm); + } + + // Get the type code for this argument type + let type_code = self + .as_object() + .get_attr(vm.ctx.intern_str("_type_"), vm) + .ok() + .and_then(|t| t.downcast_ref::().map(|s| 
s.to_string())); + + // For pointer types (c_void_p, c_char_p, c_wchar_p), handle as pointer + if matches!(type_code.as_deref(), Some("P") | Some("z") | Some("Z")) { + return convert_to_pointer(&converted, vm); + } + + // PyCSimple (already a ctypes instance from from_param) + if let Ok(simple) = converted.clone().downcast::() { let typ = ArgumentType::to_ffi_type(self, vm)?; - let arg = simple - .to_arg(typ, vm) - .ok_or(vm.new_type_error("Unsupported argument type".to_string()))?; - return Ok(arg); + let ffi_value = simple + .to_ffi_value(typ, vm) + .ok_or(vm.new_type_error("Unsupported argument type"))?; + return Ok(ffi_value); } - Err(vm.new_type_error("Unsupported argument type".to_string())) + + Err(vm.new_type_error("Unsupported argument type")) } } -pub trait ReturnType { - fn to_ffi_type(&self) -> Option; +trait ReturnType { + fn to_ffi_type(&self, vm: &VirtualMachine) -> Option; #[allow(clippy::wrong_self_convention)] fn from_ffi_type( &self, @@ -71,8 +314,34 @@ pub trait ReturnType { } impl ReturnType for PyTypeRef { - fn to_ffi_type(&self) -> Option { - ffi_type_from_str(self.name().to_string().as_str()) + fn to_ffi_type(&self, vm: &VirtualMachine) -> Option { + // Try to get _type_ attribute first (for ctypes types like c_void_p) + if let Ok(type_attr) = self.as_object().get_attr(vm.ctx.intern_str("_type_"), vm) + && let Some(s) = type_attr.downcast_ref::() + && let Some(ffi_type) = get_ffi_type(s.as_str()) + { + return Some(ffi_type); + } + + // Check for Structure/Array types (have StgInfo but no _type_) + // _ctypes_get_ffi_type: returns appropriately sized type for struct returns + if let Some(stg_info) = self.stg_info_opt() { + let size = stg_info.size; + // Small structs can be returned in registers + // Match can_return_struct_as_int/can_return_struct_as_sint64 + return Some(if size <= 4 { + Type::i32() + } else if size <= 8 { + Type::i64() + } else { + // Large structs: use pointer-sized return + // (ABI typically returns via hidden pointer 
parameter) + Type::pointer() + }); + } + + // Fallback to class name + get_ffi_type(self.name().to_string().as_str()) } fn from_ffi_type( @@ -80,9 +349,11 @@ impl ReturnType for PyTypeRef { value: *mut ffi::c_void, vm: &VirtualMachine, ) -> PyResult> { - // Get the type code from _type_ attribute + // Get the type code from _type_ attribute (use get_attr to traverse MRO) let type_code = self - .get_class_attr(vm.ctx.intern_str("_type_")) + .as_object() + .get_attr(vm.ctx.intern_str("_type_"), vm) + .ok() .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); let result = match type_code.as_deref() { @@ -129,27 +400,59 @@ impl ReturnType for PyTypeRef { .new_float(unsafe { *(value as *const f32) } as f64) .into(), Some("d") => vm.ctx.new_float(unsafe { *(value as *const f64) }).into(), - Some("P") | Some("z") | Some("Z") => vm.ctx.new_int(value as usize).into(), + Some("P") | Some("z") | Some("Z") => { + vm.ctx.new_int(unsafe { *(value as *const usize) }).into() + } Some("?") => vm .ctx .new_bool(unsafe { *(value as *const u8) } != 0) .into(), None => { - // No _type_ attribute, try to create an instance of the type - // This handles cases like Structure or Array return types - return Ok(Some( - vm.ctx.new_int(unsafe { *(value as *const i32) }).into(), - )); + // No _type_ attribute - check for Structure/Array types + // GetResult: PyCData_FromBaseObj creates instance from memory + if let Some(stg_info) = self.stg_info_opt() { + let size = stg_info.size; + // Create instance of the ctypes type + let instance = self.as_object().call((), vm)?; + + // Copy return value memory into instance buffer + // Use a block to properly scope the borrow + { + let src = unsafe { std::slice::from_raw_parts(value as *const u8, size) }; + if let Some(cdata) = instance.downcast_ref::() { + let mut buffer = cdata.buffer.write(); + if buffer.len() >= size { + buffer.to_mut()[..size].copy_from_slice(src); + } + } else if let Some(structure) = instance.downcast_ref::() { + let mut 
buffer = structure.0.buffer.write(); + if buffer.len() >= size { + buffer.to_mut()[..size].copy_from_slice(src); + } + } else if let Some(array) = instance.downcast_ref::() { + let mut buffer = array.0.buffer.write(); + if buffer.len() >= size { + buffer.to_mut()[..size].copy_from_slice(src); + } + } + } + return Ok(Some(instance)); + } + // Not a ctypes type - call type with int result + return self + .as_object() + .call((unsafe { *(value as *const i32) },), vm) + .map(Some); } - _ => return Err(vm.new_type_error("Unsupported return type".to_string())), + _ => return Err(vm.new_type_error("Unsupported return type")), }; Ok(Some(result)) } } impl ReturnType for PyNone { - fn to_ffi_type(&self) -> Option { - ffi_type_from_str("void") + fn to_ffi_type(&self, _vm: &VirtualMachine) -> Option { + get_ffi_type("void") } fn from_ffi_type( @@ -161,47 +464,319 @@ impl ReturnType for PyNone { } } +/// PyCFuncPtr - Function pointer instance +/// Saved in _base.buffer #[pyclass(module = "_ctypes", name = "CFuncPtr", base = PyCData)] -pub struct PyCFuncPtr { - _base: PyCData, - pub name: PyRwLock>, - pub ptr: PyRwLock>, - #[allow(dead_code)] - pub needs_free: AtomicCell, - pub arg_types: PyRwLock>>, - pub res_type: PyRwLock>, - pub _flags_: AtomicCell, - #[allow(dead_code)] - pub handler: PyObjectRef, +#[repr(C)] +pub(super) struct PyCFuncPtr { + pub _base: PyCData, + /// Thunk for callbacks (keeps thunk alive) + pub thunk: PyRwLock>>, + /// Original Python callable (for callbacks) + pub callable: PyRwLock>, + /// Converters cache + pub converters: PyRwLock>, + /// Instance-level argtypes override + pub argtypes: PyRwLock>, + /// Instance-level restype override + pub restype: PyRwLock>, + /// Checker function + pub checker: PyRwLock>, + /// Error checking function + pub errcheck: PyRwLock>, + /// COM method vtable index + /// When set, the function reads the function pointer from the vtable at call time + #[cfg(windows)] + pub index: PyRwLock>, + /// COM method IID (interface 
ID) for error handling + #[cfg(windows)] + pub iid: PyRwLock>, + /// Parameter flags for COM methods (direction: IN=1, OUT=2, IN|OUT=4) + /// Each element is (direction, name, default) tuple + pub paramflags: PyRwLock>, } impl Debug for PyCFuncPtr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PyCFuncPtr") - .field("flags", &self._flags_) + .field("func_ptr", &self.get_func_ptr()) .finish() } } +/// Extract pointer value from a ctypes argument (c_void_p conversion) +fn extract_ptr_from_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult { + // Try to get pointer value from various ctypes types + if let Some(ptr) = arg.downcast_ref::() { + return Ok(ptr.get_ptr_value()); + } + if let Some(simple) = arg.downcast_ref::() { + let buffer = simple.0.buffer.read(); + if buffer.len() >= std::mem::size_of::() { + return Ok(usize::from_ne_bytes( + buffer[..std::mem::size_of::()].try_into().unwrap(), + )); + } + } + if let Some(cdata) = arg.downcast_ref::() { + // For arrays/structures, return address of buffer + return Ok(cdata.buffer.read().as_ptr() as usize); + } + // PyStr: return internal buffer address + if let Some(s) = arg.downcast_ref::() { + return Ok(s.as_str().as_ptr() as usize); + } + // PyBytes: return internal buffer address + if let Some(bytes) = arg.downcast_ref::() { + return Ok(bytes.as_bytes().as_ptr() as usize); + } + // Try as integer + if let Ok(int_val) = arg.try_int(vm) { + return Ok(int_val.as_bigint().to_usize().unwrap_or(0)); + } + Err(vm.new_type_error(format!( + "cannot convert '{}' to pointer", + arg.class().name() + ))) +} + +/// string_at implementation - read bytes from memory at ptr +fn string_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { + if ptr == 0 { + return Err(vm.new_value_error("NULL pointer access")); + } + let ptr = ptr as *const u8; + let len = if size < 0 { + // size == -1 means use strlen + unsafe { libc::strlen(ptr as _) } + } else { + // Overflow check for huge 
size values + let size_usize = size as usize; + if size_usize > isize::MAX as usize / 2 { + return Err(vm.new_overflow_error("string too long")); + } + size_usize + }; + let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; + Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) +} + +/// wstring_at implementation - read wide string from memory at ptr +fn wstring_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { + if ptr == 0 { + return Err(vm.new_value_error("NULL pointer access")); + } + let w_ptr = ptr as *const libc::wchar_t; + let len = if size < 0 { + unsafe { libc::wcslen(w_ptr) } + } else { + // Overflow check for huge size values + let size_usize = size as usize; + if size_usize > isize::MAX as usize / std::mem::size_of::() { + return Err(vm.new_overflow_error("string too long")); + } + size_usize + }; + let wchars = unsafe { std::slice::from_raw_parts(w_ptr, len) }; + + // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide + // macOS/Linux: wchar_t = i32 (UTF-32) -> convert via char::from_u32 + #[cfg(windows)] + { + use rustpython_common::wtf8::Wtf8Buf; + let wide: Vec = wchars.to_vec(); + let wtf8 = Wtf8Buf::from_wide(&wide); + Ok(vm.ctx.new_str(wtf8).into()) + } + #[cfg(not(windows))] + { + let s: String = wchars + .iter() + .filter_map(|&c| char::from_u32(c as u32)) + .collect(); + Ok(vm.ctx.new_str(s).into()) + } +} + +// cast_check_pointertype +fn cast_check_pointertype(ctype: &PyObject, vm: &VirtualMachine) -> bool { + use super::pointer::PyCPointerType; + + // PyCPointerTypeObject_Check + if ctype.class().fast_issubclass(PyCPointerType::static_type()) { + return true; + } + + // PyCFuncPtrTypeObject_Check - TODO + + // simple pointer types via StgInfo.proto (c_void_p, c_char_p, etc.) 
+ if let Ok(type_attr) = ctype.get_attr("_type_", vm) + && let Some(s) = type_attr.downcast_ref::() + { + let c = s.as_str(); + if c.len() == 1 && "sPzUZXO".contains(c) { + return true; + } + } + + false +} + +/// cast implementation +/// _ctypes.c cast() +pub(super) fn cast_impl( + obj: PyObjectRef, + src: PyObjectRef, + ctype: PyObjectRef, + vm: &VirtualMachine, +) -> PyResult { + // 1. cast_check_pointertype + if !cast_check_pointertype(&ctype, vm) { + return Err(vm.new_type_error(format!( + "cast() argument 2 must be a pointer type, not {}", + ctype.class().name() + ))); + } + + // 2. Extract pointer value - matches c_void_p_from_param_impl order + let ptr_value: usize = if vm.is_none(&obj) { + // None → NULL pointer + 0 + } else if let Ok(int_val) = obj.try_int(vm) { + // int/long → direct pointer value + int_val.as_bigint().to_usize().unwrap_or(0) + } else if let Some(bytes) = obj.downcast_ref::() { + // bytes → buffer address (c_void_p_from_param: PyBytes_Check) + bytes.as_bytes().as_ptr() as usize + } else if let Some(s) = obj.downcast_ref::() { + // unicode/str → buffer address (c_void_p_from_param: PyUnicode_Check) + s.as_str().as_ptr() as usize + } else if let Some(ptr) = obj.downcast_ref::() { + // Pointer instance → contained pointer value + ptr.get_ptr_value() + } else if let Some(simple) = obj.downcast_ref::() { + // Simple type (c_void_p, c_char_p, etc.) → value from buffer + let buffer = simple.0.buffer.read(); + super::base::read_ptr_from_buffer(&buffer) + } else if let Some(cdata) = obj.downcast_ref::() { + // Array, Structure, Union → buffer address (b_ptr) + cdata.buffer.read().as_ptr() as usize + } else { + return Err(vm.new_type_error(format!( + "cast() argument 1 must be a ctypes instance, not {}", + obj.class().name() + ))); + }; + + // 3. Create result instance + let result = ctype.call((), vm)?; + + // 4. 
_objects reference tracking + // Share _objects dict between source and result, add id(src): src + if src.class().fast_issubclass(PyCData::static_type()) { + // Get the source's _objects, create dict if needed + let shared_objects: PyObjectRef = if let Some(src_cdata) = src.downcast_ref::() { + let mut src_objects = src_cdata.objects.write(); + if src_objects.is_none() { + // Create new dict + let dict = vm.ctx.new_dict(); + *src_objects = Some(dict.clone().into()); + dict.into() + } else if let Some(obj) = src_objects.as_ref() { + if obj.downcast_ref::().is_none() { + // Convert to dict (keep existing reference) + let dict = vm.ctx.new_dict(); + let id_key: PyObjectRef = vm.ctx.new_int(obj.get_id() as i64).into(); + let _ = dict.set_item(&*id_key, obj.clone(), vm); + *src_objects = Some(dict.clone().into()); + dict.into() + } else { + obj.clone() + } + } else { + vm.ctx.new_dict().into() + } + } else { + vm.ctx.new_dict().into() + }; + + // Add id(src): src to the shared dict + if let Some(dict) = shared_objects.downcast_ref::() { + let id_key: PyObjectRef = vm.ctx.new_int(src.get_id() as i64).into(); + let _ = dict.set_item(&*id_key, src.clone(), vm); + } + + // Set result's _objects to the shared dict + if let Some(result_cdata) = result.downcast_ref::() { + *result_cdata.objects.write() = Some(shared_objects); + } + } + + // 5. 
Store pointer value + if let Some(ptr) = result.downcast_ref::() { + ptr.set_ptr_value(ptr_value); + } else if let Some(cdata) = result.downcast_ref::() { + let bytes = ptr_value.to_ne_bytes(); + let mut buffer = cdata.buffer.write(); + let buf = buffer.to_mut(); + if buf.len() >= bytes.len() { + buf[..bytes.len()].copy_from_slice(&bytes); + } + } + + Ok(result) +} + +impl PyCFuncPtr { + /// Get function pointer address from buffer + fn get_func_ptr(&self) -> usize { + let buffer = self._base.buffer.read(); + super::base::read_ptr_from_buffer(&buffer) + } + + /// Get CodePtr from buffer for FFI calls + fn get_code_ptr(&self) -> Option { + let addr = self.get_func_ptr(); + if addr != 0 { + Some(CodePtr(addr as *mut _)) + } else { + None + } + } + + /// Create buffer with function pointer address + fn make_ptr_buffer(addr: usize) -> Vec { + addr.to_ne_bytes().to_vec() + } +} + impl Constructor for PyCFuncPtr { type Args = FuncArgs; fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // Handle different argument forms like CPython: - // 1. Empty args: create uninitialized + // Handle different argument forms: + // 1. Empty args: create uninitialized (NULL pointer) // 2. One integer argument: function address // 3. Tuple argument: (name, dll) form + // 4. 
Callable: callback creation + + let ptr_size = std::mem::size_of::(); if args.args.is_empty() { return PyCFuncPtr { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - ptr: PyRwLock::new(None), - needs_free: AtomicCell::new(false), - arg_types: PyRwLock::new(None), - _flags_: AtomicCell::new(0), - res_type: PyRwLock::new(None), - name: PyRwLock::new(None), - handler: vm.ctx.none(), + _base: PyCData::from_bytes(vec![0u8; ptr_size], None), + thunk: PyRwLock::new(None), + callable: PyRwLock::new(None), + converters: PyRwLock::new(None), + argtypes: PyRwLock::new(None), + restype: PyRwLock::new(None), + checker: PyRwLock::new(None), + errcheck: PyRwLock::new(None), + #[cfg(windows)] + index: PyRwLock::new(None), + #[cfg(windows)] + iid: PyRwLock::new(None), + paramflags: PyRwLock::new(None), } .into_ref_with_type(vm, cls) .map(Into::into); @@ -209,18 +784,68 @@ impl Constructor for PyCFuncPtr { let first_arg = &args.args[0]; + // Check for COM method form: (index, name, [paramflags], [iid]) + // First arg is integer (vtable index), second arg is string (method name) + if args.args.len() >= 2 + && first_arg.try_int(vm).is_ok() + && args.args[1].downcast_ref::().is_some() + { + #[cfg(windows)] + let index = first_arg.try_int(vm)?.as_bigint().to_usize().unwrap_or(0); + + // args[3] is iid (GUID struct, optional) + // Also check if args[2] is a GUID (has Data1 attribute) when args[3] is not present + #[cfg(windows)] + let iid = args.args.get(3).cloned().or_else(|| { + args.args.get(2).and_then(|arg| { + // If it's a GUID struct (has Data1 attribute), use it as IID + if arg.get_attr("Data1", vm).is_ok() { + Some(arg.clone()) + } else { + None + } + }) + }); + + // args[2] is paramflags (tuple or None) + let paramflags = args.args.get(2).filter(|arg| !vm.is_none(arg)).cloned(); + + return PyCFuncPtr { + _base: PyCData::from_bytes(vec![0u8; ptr_size], None), + thunk: PyRwLock::new(None), + callable: PyRwLock::new(None), + converters: PyRwLock::new(None), + 
argtypes: PyRwLock::new(None), + restype: PyRwLock::new(None), + checker: PyRwLock::new(None), + errcheck: PyRwLock::new(None), + #[cfg(windows)] + index: PyRwLock::new(Some(index)), + #[cfg(windows)] + iid: PyRwLock::new(iid), + paramflags: PyRwLock::new(paramflags), + } + .into_ref_with_type(vm, cls) + .map(Into::into); + } + // Check if first argument is an integer (function address) if let Ok(addr) = first_arg.try_int(vm) { let ptr_val = addr.as_bigint().to_usize().unwrap_or(0); return PyCFuncPtr { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - ptr: PyRwLock::new(Some(CodePtr(ptr_val as *mut _))), - needs_free: AtomicCell::new(false), - arg_types: PyRwLock::new(None), - _flags_: AtomicCell::new(0), - res_type: PyRwLock::new(None), - name: PyRwLock::new(Some(format!("CFuncPtr@{:#x}", ptr_val))), - handler: vm.ctx.new_int(ptr_val).into(), + _base: PyCData::from_bytes(Self::make_ptr_buffer(ptr_val), None), + thunk: PyRwLock::new(None), + callable: PyRwLock::new(None), + converters: PyRwLock::new(None), + argtypes: PyRwLock::new(None), + restype: PyRwLock::new(None), + checker: PyRwLock::new(None), + errcheck: PyRwLock::new(None), + #[cfg(windows)] + index: PyRwLock::new(None), + #[cfg(windows)] + iid: PyRwLock::new(None), + paramflags: PyRwLock::new(None), } .into_ref_with_type(vm, cls) .map(Into::into); @@ -234,53 +859,58 @@ impl Constructor for PyCFuncPtr { .downcast_ref::() .ok_or(vm.new_type_error("Expected a string"))? .to_string(); - let handler = tuple + let dll = tuple .iter() .nth(1) .ok_or(vm.new_type_error("Expected a tuple with at least 2 elements"))? .clone(); // Get library handle and load function - let handle = handler.try_int(vm); + let handle = dll.try_int(vm); let handle = match handle { Ok(handle) => handle.as_bigint().clone(), - Err(_) => handler + Err(_) => dll .get_attr("_handle", vm)? .try_int(vm)? 
.as_bigint() .clone(), }; - let library_cache = crate::stdlib::ctypes::library::libcache().read(); + let library_cache = super::library::libcache().read(); let library = library_cache .get_lib( handle .to_usize() - .ok_or(vm.new_value_error("Invalid handle".to_string()))?, + .ok_or(vm.new_value_error("Invalid handle"))?, ) - .ok_or_else(|| vm.new_value_error("Library not found".to_string()))?; + .ok_or_else(|| vm.new_value_error("Library not found"))?; let inner_lib = library.lib.lock(); let terminated = format!("{}\0", &name); - let code_ptr = if let Some(lib) = &*inner_lib { + let ptr_val = if let Some(lib) = &*inner_lib { let pointer: Symbol<'_, FP> = unsafe { lib.get(terminated.as_bytes()) .map_err(|err| err.to_string()) .map_err(|err| vm.new_attribute_error(err))? }; - Some(CodePtr(*pointer as *mut _)) + *pointer as usize } else { - None + 0 }; return PyCFuncPtr { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - ptr: PyRwLock::new(code_ptr), - needs_free: AtomicCell::new(false), - arg_types: PyRwLock::new(None), - _flags_: AtomicCell::new(0), - res_type: PyRwLock::new(None), - name: PyRwLock::new(Some(name)), - handler, + _base: PyCData::from_bytes(Self::make_ptr_buffer(ptr_val), None), + thunk: PyRwLock::new(None), + callable: PyRwLock::new(None), + converters: PyRwLock::new(None), + argtypes: PyRwLock::new(None), + restype: PyRwLock::new(None), + checker: PyRwLock::new(None), + errcheck: PyRwLock::new(None), + #[cfg(windows)] + index: PyRwLock::new(None), + #[cfg(windows)] + iid: PyRwLock::new(None), + paramflags: PyRwLock::new(None), } .into_ref_with_type(vm, cls) .map(Into::into); @@ -289,42 +919,36 @@ impl Constructor for PyCFuncPtr { // Check if first argument is a Python callable (callback creation) if first_arg.is_callable() { // Get argument types and result type from the class - let argtypes = cls.get_attr(vm.ctx.intern_str("_argtypes_")); - let restype = cls.get_attr(vm.ctx.intern_str("_restype_")); + let class_argtypes = 
cls.get_attr(vm.ctx.intern_str("_argtypes_")); + let class_restype = cls.get_attr(vm.ctx.intern_str("_restype_")); // Create the thunk (C-callable wrapper for the Python function) - let thunk = PyCThunk::new(first_arg.clone(), argtypes.clone(), restype.clone(), vm)?; + let thunk = PyCThunk::new( + first_arg.clone(), + class_argtypes.clone(), + class_restype.clone(), + vm, + )?; let code_ptr = thunk.code_ptr(); - - // Parse argument types for storage - let arg_type_vec: Option> = if let Some(ref args) = argtypes { - if vm.is_none(args) { - None - } else { - let mut types = Vec::new(); - for item in args.try_to_value::>(vm)? { - types.push(item.downcast::().map_err(|_| { - vm.new_type_error("_argtypes_ must be a sequence of types".to_string()) - })?); - } - Some(types) - } - } else { - None - }; + let ptr_val = code_ptr.0 as usize; // Store the thunk as a Python object to keep it alive let thunk_ref: PyRef = thunk.into_ref(&vm.ctx); return PyCFuncPtr { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - ptr: PyRwLock::new(Some(code_ptr)), - needs_free: AtomicCell::new(true), - arg_types: PyRwLock::new(arg_type_vec), - _flags_: AtomicCell::new(0), - res_type: PyRwLock::new(restype), - name: PyRwLock::new(Some("".to_string())), - handler: thunk_ref.into(), + _base: PyCData::from_bytes(Self::make_ptr_buffer(ptr_val), None), + thunk: PyRwLock::new(Some(thunk_ref)), + callable: PyRwLock::new(Some(first_arg.clone())), + converters: PyRwLock::new(None), + argtypes: PyRwLock::new(class_argtypes), + restype: PyRwLock::new(class_restype), + checker: PyRwLock::new(None), + errcheck: PyRwLock::new(None), + #[cfg(windows)] + index: PyRwLock::new(None), + #[cfg(windows)] + iid: PyRwLock::new(None), + paramflags: PyRwLock::new(None), } .into_ref_with_type(vm, cls) .map(Into::into); @@ -338,142 +962,1054 @@ impl Constructor for PyCFuncPtr { } } -impl Callable for PyCFuncPtr { - type Args = FuncArgs; - fn call(zelf: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult 
{ - // This is completely seperate from the C python implementation - - // Cif init - let arg_types: Vec<_> = match zelf.arg_types.read().clone() { - Some(tys) => tys, - None => args - .args - .clone() - .into_iter() - .map(|a| a.class().as_object().to_pyobject(vm).downcast().unwrap()) - .collect(), +// PyCFuncPtr call helpers (similar to callproc.c flow) + +/// Handle internal function addresses (PYFUNCTYPE special cases) +/// Returns Some(result) if handled, None if should continue with normal call +fn handle_internal_func(addr: usize, args: &FuncArgs, vm: &VirtualMachine) -> Option { + if addr == INTERNAL_CAST_ADDR { + let result: PyResult<(PyObjectRef, PyObjectRef, PyObjectRef)> = args.clone().bind(vm); + return Some(result.and_then(|(obj, src, ctype)| cast_impl(obj, src, ctype, vm))); + } + + if addr == INTERNAL_STRING_AT_ADDR { + let result: PyResult<(PyObjectRef, Option)> = args.clone().bind(vm); + return Some(result.and_then(|(ptr_arg, size_arg)| { + let ptr = extract_ptr_from_arg(&ptr_arg, vm)?; + let size = size_arg + .and_then(|s| s.try_int(vm).ok()) + .and_then(|i| i.as_bigint().to_isize()) + .unwrap_or(-1); + string_at_impl(ptr, size, vm) + })); + } + + if addr == INTERNAL_WSTRING_AT_ADDR { + let result: PyResult<(PyObjectRef, Option)> = args.clone().bind(vm); + return Some(result.and_then(|(ptr_arg, size_arg)| { + let ptr = extract_ptr_from_arg(&ptr_arg, vm)?; + let size = size_arg + .and_then(|s| s.try_int(vm).ok()) + .and_then(|i| i.as_bigint().to_isize()) + .unwrap_or(-1); + wstring_at_impl(ptr, size, vm) + })); + } + + None +} + +/// Call information extracted from PyCFuncPtr (argtypes, restype, etc.) 
+struct CallInfo { + explicit_arg_types: Option>, + restype_obj: Option, + restype_is_none: bool, + ffi_return_type: Type, + is_pointer_return: bool, +} + +/// Extract call information (argtypes, restype) from PyCFuncPtr +fn extract_call_info(zelf: &Py, vm: &VirtualMachine) -> PyResult { + // Get argtypes - first from instance, then from type's _argtypes_ + let explicit_arg_types: Option> = + if let Some(argtypes_obj) = zelf.argtypes.read().as_ref() { + if !vm.is_none(argtypes_obj) { + Some( + argtypes_obj + .try_to_value::>(vm)? + .into_iter() + .filter_map(|obj| obj.downcast::().ok()) + .collect(), + ) + } else { + None // argtypes is None -> use ConvParam + } + } else if let Some(class_argtypes) = zelf + .as_object() + .class() + .get_attr(vm.ctx.intern_str("_argtypes_")) + && !vm.is_none(&class_argtypes) + { + Some( + class_argtypes + .try_to_value::>(vm)? + .into_iter() + .filter_map(|obj| obj.downcast::().ok()) + .collect(), + ) + } else { + None // No argtypes -> use ConvParam }; - let ffi_arg_types = arg_types - .clone() - .iter() - .map(|t| ArgumentType::to_ffi_type(t, vm)) - .collect::>>()?; - let return_type = zelf.res_type.read(); - let ffi_return_type = return_type + + // Get restype - first from instance, then from class's _restype_ + let restype_obj = zelf.restype.read().clone().or_else(|| { + zelf.as_object() + .class() + .get_attr(vm.ctx.intern_str("_restype_")) + }); + + // Check if restype is explicitly None (return void) + let restype_is_none = restype_obj.as_ref().is_some_and(|t| vm.is_none(t)); + let ffi_return_type = if restype_is_none { + Type::void() + } else { + restype_obj .as_ref() .and_then(|t| t.clone().downcast::().ok()) - .and_then(|t| ReturnType::to_ffi_type(&t)) - .unwrap_or_else(Type::i32); - let cif = Cif::new(ffi_arg_types, ffi_return_type); - - // Call the function - let ffi_args = args - .args - .into_iter() - .enumerate() - .map(|(n, arg)| { - let arg_type = arg_types - .get(n) - .ok_or_else(|| vm.new_type_error("argument 
amount mismatch".to_string()))?; - arg_type.convert_object(arg, vm) - }) - .collect::, _>>()?; - let pointer = zelf.ptr.read(); - let code_ptr = pointer - .as_ref() - .ok_or_else(|| vm.new_type_error("Function pointer not set".to_string()))?; - let mut output: c_void = unsafe { cif.call(*code_ptr, &ffi_args) }; - let return_type = return_type + .and_then(|t| ReturnType::to_ffi_type(&t, vm)) + .unwrap_or_else(Type::i32) + }; + + // Check if return type is a pointer type (P, z, Z) - need special handling on 64-bit + let is_pointer_return = restype_obj + .as_ref() + .and_then(|t| t.clone().downcast::().ok()) + .and_then(|t| t.as_object().get_attr(vm.ctx.intern_str("_type_"), vm).ok()) + .and_then(|t| t.downcast_ref::().map(|s| s.to_string())) + .is_some_and(|tc| matches!(tc.as_str(), "P" | "z" | "Z")); + + Ok(CallInfo { + explicit_arg_types, + restype_obj, + restype_is_none, + ffi_return_type, + is_pointer_return, + }) +} + +/// Parsed paramflags: (direction, name, default) tuples +/// direction: 1=IN, 2=OUT, 4=IN|OUT (or 1|2=3) +type ParsedParamFlags = Vec<(u32, Option, Option)>; + +/// Parse paramflags from PyCFuncPtr +fn parse_paramflags( + zelf: &Py, + vm: &VirtualMachine, +) -> PyResult> { + let Some(pf) = zelf.paramflags.read().as_ref().cloned() else { + return Ok(None); + }; + + let pf_vec = pf.try_to_value::>(vm)?; + let parsed = pf_vec + .into_iter() + .map(|item| { + let Some(tuple) = item.downcast_ref::() else { + // Single value means just the direction + let direction = item + .try_int(vm) + .ok() + .and_then(|i| i.as_bigint().to_u32()) + .unwrap_or(1); + return (direction, None, None); + }; + let direction = tuple + .first() + .and_then(|d| d.try_int(vm).ok()) + .and_then(|i| i.as_bigint().to_u32()) + .unwrap_or(1); + let name = tuple + .get(1) + .and_then(|n| n.downcast_ref::().map(|s| s.to_string())); + let default = tuple.get(2).cloned(); + (direction, name, default) + }) + .collect(); + Ok(Some(parsed)) +} + +/// Resolve COM method pointer from 
vtable (Windows only) +/// Returns (Some(CodePtr), true) if this is a COM method call, (None, false) otherwise +#[cfg(windows)] +fn resolve_com_method( + zelf: &Py, + args: &FuncArgs, + vm: &VirtualMachine, +) -> PyResult<(Option, bool)> { + let com_index = zelf.index.read(); + let Some(idx) = *com_index else { + return Ok((None, false)); + }; + + // First arg must be the COM object pointer + if args.args.is_empty() { + return Err( + vm.new_type_error("COM method requires at least one argument (self)".to_string()) + ); + } + + // Extract COM pointer value from first argument + let self_arg = &args.args[0]; + let com_ptr = if let Some(simple) = self_arg.downcast_ref::() { + let buffer = simple.0.buffer.read(); + if buffer.len() >= std::mem::size_of::() { + super::base::read_ptr_from_buffer(&buffer) + } else { + 0 + } + } else if let Ok(int_val) = self_arg.try_int(vm) { + int_val.as_bigint().to_usize().unwrap_or(0) + } else { + return Err( + vm.new_type_error("COM method first argument must be a COM pointer".to_string()) + ); + }; + + if com_ptr == 0 { + return Err(vm.new_value_error("NULL COM pointer access")); + } + + // Read vtable pointer from COM object: vtable = *(void**)com_ptr + let vtable_ptr = unsafe { *(com_ptr as *const usize) }; + if vtable_ptr == 0 { + return Err(vm.new_value_error("NULL vtable pointer")); + } + + // Read function pointer from vtable: func = vtable[index] + let fptr = unsafe { + let vtable = vtable_ptr as *const usize; + *vtable.add(idx) + }; + + if fptr == 0 { + return Err(vm.new_value_error("NULL function pointer in vtable")); + } + + Ok((Some(CodePtr(fptr as *mut _)), true)) +} + +/// Prepared arguments for FFI call +struct PreparedArgs { + ffi_arg_types: Vec, + ffi_values: Vec, + out_buffers: Vec<(usize, PyObjectRef)>, +} + +/// Get buffer address from a ctypes object +fn get_buffer_addr(obj: &PyObjectRef) -> Option { + obj.downcast_ref::() + .map(|s| s.0.buffer.read().as_ptr() as usize) + .or_else(|| { + obj.downcast_ref::() + 
.map(|s| s.0.buffer.read().as_ptr() as usize) + }) + .or_else(|| { + obj.downcast_ref::() + .map(|s| s.0.buffer.read().as_ptr() as usize) + }) +} + +/// Create OUT buffer for a parameter type +fn create_out_buffer(arg_type: &PyTypeRef, vm: &VirtualMachine) -> PyResult { + // For POINTER(T) types, create T instance (the pointed-to type) + if arg_type.fast_issubclass(PyCPointer::static_type()) + && let Some(stg_info) = arg_type.stg_info_opt() + && let Some(ref proto) = stg_info.proto + { + return proto.as_object().call((), vm); + } + // Not a pointer type or no proto, create instance directly + arg_type.as_object().call((), vm) +} + +/// Build callargs when no argtypes specified (use ConvParam) +fn build_callargs_no_argtypes(args: &FuncArgs, vm: &VirtualMachine) -> PyResult { + let results: Vec<(Type, FfiArgValue)> = args + .args + .iter() + .map(|arg| conv_param(arg, vm)) + .collect::>>()?; + let (ffi_arg_types, ffi_values) = results.into_iter().unzip(); + Ok(PreparedArgs { + ffi_arg_types, + ffi_values, + out_buffers: Vec::new(), + }) +} + +/// Build callargs for regular function with argtypes (no paramflags) +fn build_callargs_simple( + args: &FuncArgs, + arg_types: &[PyTypeRef], + vm: &VirtualMachine, +) -> PyResult { + let ffi_arg_types = arg_types + .iter() + .map(|t| ArgumentType::to_ffi_type(t, vm)) + .collect::>>()?; + let ffi_values = args + .args + .iter() + .enumerate() + .map(|(n, arg)| { + let arg_type = arg_types + .get(n) + .ok_or_else(|| vm.new_type_error("argument amount mismatch"))?; + arg_type.convert_object(arg.clone(), vm) + }) + .collect::, _>>()?; + Ok(PreparedArgs { + ffi_arg_types, + ffi_values, + out_buffers: Vec::new(), + }) +} + +/// Build callargs with paramflags (handles IN/OUT parameters) +fn build_callargs_with_paramflags( + args: &FuncArgs, + arg_types: &[PyTypeRef], + paramflags: &ParsedParamFlags, + skip_first_arg: bool, // true for COM methods + vm: &VirtualMachine, +) -> PyResult { + let mut ffi_arg_types = Vec::new(); + let mut 
ffi_values = Vec::new(); + let mut out_buffers = Vec::new(); + + // For COM methods, first arg is self (pointer) + let mut caller_arg_idx = if skip_first_arg { + ffi_arg_types.push(Type::pointer()); + if !args.args.is_empty() { + ffi_values.push(conv_param(&args.args[0], vm)?.1); + } + 1usize + } else { + 0usize + }; + + // Add FFI types for all argtypes + for arg_type in arg_types { + ffi_arg_types.push(ArgumentType::to_ffi_type(arg_type, vm)?); + } + + // Process parameters based on paramflags + for (param_idx, (direction, _name, default)) in paramflags.iter().enumerate() { + let arg_type = arg_types + .get(param_idx) + .ok_or_else(|| vm.new_type_error("paramflags/argtypes mismatch"))?; + + let is_out = (*direction & 2) != 0; // OUT flag + let is_in = (*direction & 1) != 0 || *direction == 0; // IN flag or default + + if is_out && !is_in { + // Pure OUT parameter: create buffer, don't consume caller arg + let buffer = create_out_buffer(arg_type, vm)?; + let addr = get_buffer_addr(&buffer).ok_or_else(|| { + vm.new_type_error("Cannot create OUT buffer for this type".to_string()) + })?; + ffi_values.push(FfiArgValue::Pointer(addr)); + out_buffers.push((param_idx, buffer)); + } else { + // IN or IN|OUT: get from caller args or default + let arg = if caller_arg_idx < args.args.len() { + caller_arg_idx += 1; + args.args[caller_arg_idx - 1].clone() + } else if let Some(def) = default { + def.clone() + } else { + return Err(vm.new_type_error(format!("required argument {} missing", param_idx))); + }; + + if is_out { + // IN|OUT: track for return + out_buffers.push((param_idx, arg.clone())); + } + ffi_values.push(arg_type.convert_object(arg, vm)?); + } + } + + Ok(PreparedArgs { + ffi_arg_types, + ffi_values, + out_buffers, + }) +} + +/// Build call arguments (main dispatcher) +fn build_callargs( + args: &FuncArgs, + call_info: &CallInfo, + paramflags: Option<&ParsedParamFlags>, + is_com_method: bool, + vm: &VirtualMachine, +) -> PyResult { + let Some(ref arg_types) = 
call_info.explicit_arg_types else { + // No argtypes: use ConvParam + return build_callargs_no_argtypes(args, vm); + }; + + if let Some(pflags) = paramflags { + // Has paramflags: handle IN/OUT + build_callargs_with_paramflags(args, arg_types, pflags, is_com_method, vm) + } else if is_com_method { + // COM method without paramflags + let mut ffi_types = vec![Type::pointer()]; + ffi_types.extend( + arg_types + .iter() + .map(|t| ArgumentType::to_ffi_type(t, vm)) + .collect::>>()?, + ); + let mut ffi_vals = Vec::new(); + if !args.args.is_empty() { + ffi_vals.push(conv_param(&args.args[0], vm)?.1); + } + for (n, arg) in args.args.iter().skip(1).enumerate() { + let arg_type = arg_types + .get(n) + .ok_or_else(|| vm.new_type_error("argument amount mismatch"))?; + ffi_vals.push(arg_type.convert_object(arg.clone(), vm)?); + } + Ok(PreparedArgs { + ffi_arg_types: ffi_types, + ffi_values: ffi_vals, + out_buffers: Vec::new(), + }) + } else { + // Regular function + build_callargs_simple(args, arg_types, vm) + } +} + +/// Raw result from FFI call +enum RawResult { + Void, + Pointer(usize), + Value(libffi::low::ffi_arg), +} + +/// Execute FFI call +fn ctypes_callproc(code_ptr: CodePtr, prepared: &PreparedArgs, call_info: &CallInfo) -> RawResult { + let cif = Cif::new( + prepared.ffi_arg_types.clone(), + call_info.ffi_return_type.clone(), + ); + let ffi_args: Vec = prepared.ffi_values.iter().map(|v| v.as_arg()).collect(); + + if call_info.restype_is_none { + unsafe { cif.call::<()>(code_ptr, &ffi_args) }; + RawResult::Void + } else if call_info.is_pointer_return { + let result = unsafe { cif.call::(code_ptr, &ffi_args) }; + RawResult::Pointer(result) + } else { + let result = unsafe { cif.call::(code_ptr, &ffi_args) }; + RawResult::Value(result) + } +} + +/// Check and handle HRESULT errors (Windows) +#[cfg(windows)] +fn check_hresult(hresult: i32, zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { + if hresult >= 0 { + return Ok(()); + } + + if zelf.iid.read().is_some() { + // 
Raise COMError + let ctypes_module = vm.import("_ctypes", 0)?; + let com_error_type = ctypes_module.get_attr("COMError", vm)?; + let com_error_type = com_error_type + .downcast::() + .map_err(|_| vm.new_type_error("COMError is not a type"))?; + let hresult_obj: PyObjectRef = vm.ctx.new_int(hresult).into(); + let text: PyObjectRef = vm + .ctx + .new_str(format!("HRESULT: 0x{:08X}", hresult as u32)) + .into(); + let details: PyObjectRef = vm.ctx.none(); + let exc = vm.invoke_exception( + com_error_type.to_owned(), + vec![text.clone(), details.clone()], + )?; + let _ = exc.as_object().set_attr("hresult", hresult_obj, vm); + let _ = exc.as_object().set_attr("text", text, vm); + let _ = exc.as_object().set_attr("details", details, vm); + Err(exc) + } else { + // Raise OSError + let exc = vm.new_os_error(format!("HRESULT: 0x{:08X}", hresult as u32)); + let _ = exc + .as_object() + .set_attr("winerror", vm.ctx.new_int(hresult), vm); + Err(exc) + } +} + +/// Convert raw FFI result to Python object +fn convert_raw_result( + raw_result: &mut RawResult, + call_info: &CallInfo, + vm: &VirtualMachine, +) -> Option { + match raw_result { + RawResult::Void => None, + RawResult::Pointer(ptr) => { + // Get type code from restype to determine conversion method + let type_code = call_info + .restype_obj + .as_ref() + .and_then(|t| t.clone().downcast::().ok()) + .and_then(|t| t.as_object().get_attr(vm.ctx.intern_str("_type_"), vm).ok()) + .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); + + match type_code.as_deref() { + Some("z") => { + // c_char_p: NULL -> None, otherwise read C string -> bytes + if *ptr == 0 { + Some(vm.ctx.none()) + } else { + let cstr = unsafe { std::ffi::CStr::from_ptr(*ptr as _) }; + Some(vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into()) + } + } + Some("Z") => { + // c_wchar_p: NULL -> None, otherwise read wide string -> str + if *ptr == 0 { + Some(vm.ctx.none()) + } else { + let wstr_ptr = *ptr as *const libc::wchar_t; + let mut len = 0; + unsafe { + 
while *wstr_ptr.add(len) != 0 { + len += 1; + } + } + let slice = unsafe { std::slice::from_raw_parts(wstr_ptr, len) }; + let s: String = slice + .iter() + .filter_map(|&c| char::from_u32(c as u32)) + .collect(); + Some(vm.ctx.new_str(s).into()) + } + } + _ => { + // c_void_p ("P") and other pointer types: NULL -> None, otherwise int + if *ptr == 0 { + Some(vm.ctx.none()) + } else { + Some(vm.ctx.new_int(*ptr).into()) + } + } + } + } + RawResult::Value(val) => call_info + .restype_obj .as_ref() .and_then(|f| f.clone().downcast::().ok()) - .map(|f| f.from_ffi_type(&mut output, vm).ok().flatten()) - .unwrap_or_else(|| Some(vm.ctx.new_int(output as i32).as_object().to_pyobject(vm))); - if let Some(return_type) = return_type { - Ok(return_type) - } else { - Ok(vm.ctx.none()) + .map(|f| { + f.from_ffi_type(val as *mut _ as *mut c_void, vm) + .ok() + .flatten() + }) + .unwrap_or_else(|| Some(vm.ctx.new_int(*val as usize).as_object().to_pyobject(vm))), + } +} + +/// Extract values from OUT buffers +fn extract_out_values( + out_buffers: Vec<(usize, PyObjectRef)>, + vm: &VirtualMachine, +) -> Vec { + out_buffers + .into_iter() + .map(|(_, buffer)| buffer.get_attr("value", vm).unwrap_or(buffer)) + .collect() +} + +/// Build final result (main function) +fn build_result( + mut raw_result: RawResult, + call_info: &CallInfo, + prepared: PreparedArgs, + zelf: &Py, + args: &FuncArgs, + vm: &VirtualMachine, +) -> PyResult { + // Check HRESULT on Windows + #[cfg(windows)] + if let RawResult::Value(val) = raw_result { + let is_hresult = call_info + .restype_obj + .as_ref() + .and_then(|t| t.clone().downcast::().ok()) + .is_some_and(|t| t.name().to_string() == "HRESULT"); + if is_hresult { + check_hresult(val as i32, zelf, vm)?; } } + + // Convert raw result to Python object + let mut result = convert_raw_result(&mut raw_result, call_info, vm); + + // Apply errcheck if set + if let Some(errcheck) = zelf.errcheck.read().as_ref() { + let args_tuple = PyTuple::new_ref(args.args.clone(), 
&vm.ctx); + let func_obj = zelf.as_object().to_owned(); + let result_obj = result.clone().unwrap_or_else(|| vm.ctx.none()); + result = Some(errcheck.call((result_obj, func_obj, args_tuple), vm)?); + } + + // Handle OUT parameter return values + if prepared.out_buffers.is_empty() { + return result.map(Ok).unwrap_or_else(|| Ok(vm.ctx.none())); + } + + let out_values = extract_out_values(prepared.out_buffers, vm); + Ok(match <[PyObjectRef; 1]>::try_from(out_values) { + Ok([single]) => single, + Err(v) => PyTuple::new_ref(v, &vm.ctx).into(), + }) } +impl Callable for PyCFuncPtr { + type Args = FuncArgs; + fn call(zelf: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { + // 1. Check for internal PYFUNCTYPE addresses + if let Some(result) = handle_internal_func(zelf.get_func_ptr(), &args, vm) { + return result; + } + + // 2. Resolve function pointer (COM or direct) + #[cfg(windows)] + let (func_ptr, is_com_method) = resolve_com_method(zelf, &args, vm)?; + #[cfg(not(windows))] + let (func_ptr, is_com_method) = (None::, false); + + // 3. Extract call info (argtypes, restype) + let call_info = extract_call_info(zelf, vm)?; + + // 4. Parse paramflags + let paramflags = parse_paramflags(zelf, vm)?; + + // 5. Build call arguments + let prepared = build_callargs(&args, &call_info, paramflags.as_ref(), is_com_method, vm)?; + + // 6. Get code pointer + let code_ptr = match func_ptr.or_else(|| zelf.get_code_ptr()) { + Some(cp) => cp, + None => { + debug_assert!(false, "NULL function pointer"); + // In release mode, this will crash like CPython + CodePtr(std::ptr::null_mut()) + } + }; + + // 7. Call the function + let raw_result = ctypes_callproc(code_ptr, &prepared, &call_info); + + // 8. 
Build result + build_result(raw_result, &call_info, prepared, zelf, &args, vm) + } +} + +// PyCFuncPtr_repr impl Representable for PyCFuncPtr { fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { - let index = zelf.ptr.read(); - let index = index.map(|ptr| ptr.0 as usize).unwrap_or(0); let type_name = zelf.class().name(); - if cfg!(windows) { - let index = index - 0x1000; - Ok(format!("")) - } else { - Ok(format!("<{type_name} object at {index:#x}>")) - } + // Use object id, not function pointer address + let addr = zelf.get_id(); + Ok(format!("<{} object at {:#x}>", type_name, addr)) } } -// TODO: fix -unsafe impl Send for PyCFuncPtr {} -unsafe impl Sync for PyCFuncPtr {} - #[pyclass(flags(BASETYPE), with(Callable, Constructor, Representable))] impl PyCFuncPtr { - #[pygetset(name = "_restype_")] + // restype getter/setter + #[pygetset] fn restype(&self) -> Option { - self.res_type.read().as_ref().cloned() + self.restype.read().clone() } - #[pygetset(name = "_restype_", setter)] - fn set_restype(&self, restype: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // has to be type, callable, or none - // TODO: Callable support - if vm.is_none(&restype) || restype.downcast_ref::().is_some() { - *self.res_type.write() = Some(restype); - Ok(()) + #[pygetset(setter)] + fn set_restype(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Must be type, callable, or None + if vm.is_none(&value) { + *self.restype.write() = None; + } else if value.downcast_ref::().is_some() || value.is_callable() { + *self.restype.write() = Some(value); } else { - Err(vm.new_type_error("restype must be a type, a callable, or None".to_string())) + return Err(vm.new_type_error("restype must be a type, a callable, or None")); } + Ok(()) } - #[pygetset(name = "argtypes")] - fn argtypes(&self, vm: &VirtualMachine) -> PyTupleRef { - PyTuple::new_ref( - self.arg_types - .read() - .clone() - .unwrap_or_default() - .into_iter() - .map(|t| t.to_pyobject(vm)) - .collect(), - 
&vm.ctx, - ) + // argtypes getter/setter + #[pygetset] + fn argtypes(&self, vm: &VirtualMachine) -> PyObjectRef { + self.argtypes + .read() + .clone() + .unwrap_or_else(|| vm.ctx.empty_tuple.clone().into()) } #[pygetset(name = "argtypes", setter)] - fn set_argtypes(&self, argtypes: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let none = vm.is_none(&argtypes); - if none { - *self.arg_types.write() = None; - Ok(()) + fn set_argtypes(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if vm.is_none(&value) { + *self.argtypes.write() = None; } else { - let tuple = argtypes.downcast::().unwrap(); - *self.arg_types.write() = Some( - tuple - .iter() - .map(|obj| obj.clone().downcast::().unwrap()) - .collect::>(), - ); - Ok(()) + // Store the argtypes object directly as it is + *self.argtypes.write() = Some(value); } + Ok(()) } + // errcheck getter/setter #[pygetset] - fn __name__(&self) -> Option { - self.name.read().clone() + fn errcheck(&self) -> Option { + self.errcheck.read().clone() } #[pygetset(setter)] - fn set___name__(&self, name: String) -> PyResult<()> { - *self.name.write() = Some(name); - // TODO: update handle and stuff + fn set_errcheck(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if vm.is_none(&value) { + *self.errcheck.write() = None; + } else if value.is_callable() { + *self.errcheck.write() = Some(value); + } else { + return Err(vm.new_type_error("errcheck must be a callable or None")); + } Ok(()) } + + // _flags_ getter (read-only, from type's class attribute or StgInfo) + #[pygetset] + fn _flags_(zelf: &Py, vm: &VirtualMachine) -> u32 { + // First try to get _flags_ from type's class attribute (for dynamically created types) + // This is how CDLL sets use_errno: class _FuncPtr(_CFuncPtr): _flags_ = flags + if let Ok(flags_attr) = zelf.class().as_object().get_attr("_flags_", vm) + && let Ok(flags_int) = flags_attr.try_to_value::(vm) + { + return flags_int; + } + + // Fallback to StgInfo for native types + 
use super::base::StgInfoFlags; + zelf.class() + .stg_info_opt() + .map(|stg| stg.flags.bits()) + .unwrap_or(StgInfoFlags::empty().bits()) + } + + // bool conversion - check if function pointer is set + #[pymethod] + fn __bool__(&self) -> bool { + self.get_func_ptr() != 0 + } +} + +// CThunkObject - FFI callback (thunk) implementation + +/// Userdata passed to the libffi callback. +struct ThunkUserData { + /// The Python callable to invoke + callable: PyObjectRef, + /// Argument types for conversion + arg_types: Vec, + /// Result type for conversion (None means void) + res_type: Option, +} + +/// Check if ty is a subclass of a simple type (like MyInt(c_int)). +fn is_simple_subclass(ty: &Py, vm: &VirtualMachine) -> bool { + let Ok(base) = ty.as_object().get_attr(vm.ctx.intern_str("__base__"), vm) else { + return false; + }; + base.get_attr(vm.ctx.intern_str("_type_"), vm).is_ok() +} + +/// Convert a C value to a Python object based on the type code. +fn ffi_to_python(ty: &Py, ptr: *const c_void, vm: &VirtualMachine) -> PyObjectRef { + let type_code = ty.type_code(vm); + let raw_value: PyObjectRef = unsafe { + match type_code.as_deref() { + Some("b") => vm.ctx.new_int(*(ptr as *const i8) as i32).into(), + Some("B") => vm.ctx.new_int(*(ptr as *const u8) as i32).into(), + Some("c") => vm.ctx.new_bytes(vec![*(ptr as *const u8)]).into(), + Some("h") => vm.ctx.new_int(*(ptr as *const i16) as i32).into(), + Some("H") => vm.ctx.new_int(*(ptr as *const u16) as i32).into(), + Some("i") => vm.ctx.new_int(*(ptr as *const i32)).into(), + Some("I") => vm.ctx.new_int(*(ptr as *const u32)).into(), + Some("l") => vm.ctx.new_int(*(ptr as *const libc::c_long)).into(), + Some("L") => vm.ctx.new_int(*(ptr as *const libc::c_ulong)).into(), + Some("q") => vm.ctx.new_int(*(ptr as *const libc::c_longlong)).into(), + Some("Q") => vm.ctx.new_int(*(ptr as *const libc::c_ulonglong)).into(), + Some("f") => vm.ctx.new_float(*(ptr as *const f32) as f64).into(), + Some("d") => vm.ctx.new_float(*(ptr 
as *const f64)).into(), + Some("z") => { + // c_char_p: C string pointer → Python bytes + let cstr_ptr = *(ptr as *const *const libc::c_char); + if cstr_ptr.is_null() { + vm.ctx.none() + } else { + let cstr = std::ffi::CStr::from_ptr(cstr_ptr); + vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into() + } + } + Some("Z") => { + // c_wchar_p: wchar_t* → Python str + let wstr_ptr = *(ptr as *const *const libc::wchar_t); + if wstr_ptr.is_null() { + vm.ctx.none() + } else { + let mut len = 0; + while *wstr_ptr.add(len) != 0 { + len += 1; + } + let slice = std::slice::from_raw_parts(wstr_ptr, len); + let s: String = slice + .iter() + .filter_map(|&c| char::from_u32(c as u32)) + .collect(); + vm.ctx.new_str(s).into() + } + } + Some("P") => vm.ctx.new_int(*(ptr as *const usize)).into(), + Some("?") => vm.ctx.new_bool(*(ptr as *const u8) != 0).into(), + _ => return vm.ctx.none(), + } + }; + + if !is_simple_subclass(ty, vm) { + return raw_value; + } + ty.as_object() + .call((raw_value.clone(),), vm) + .unwrap_or(raw_value) +} + +/// Convert a Python object to a C value and store it at the result pointer +fn python_to_ffi(obj: PyResult, ty: &Py, result: *mut c_void, vm: &VirtualMachine) { + let Ok(obj) = obj else { return }; + + let type_code = ty.type_code(vm); + unsafe { + match type_code.as_deref() { + Some("b") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut i8) = i.as_bigint().to_i8().unwrap_or(0); + } + } + Some("B") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut u8) = i.as_bigint().to_u8().unwrap_or(0); + } + } + Some("c") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut u8) = i.as_bigint().to_u8().unwrap_or(0); + } + } + Some("h") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut i16) = i.as_bigint().to_i16().unwrap_or(0); + } + } + Some("H") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut u16) = i.as_bigint().to_u16().unwrap_or(0); + } + } + Some("i") => { + if let Ok(i) = obj.try_int(vm) { + let val = 
i.as_bigint().to_i32().unwrap_or(0); + *(result as *mut libffi::low::ffi_arg) = val as libffi::low::ffi_arg; + } + } + Some("I") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut u32) = i.as_bigint().to_u32().unwrap_or(0); + } + } + Some("l") | Some("q") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut i64) = i.as_bigint().to_i64().unwrap_or(0); + } + } + Some("L") | Some("Q") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut u64) = i.as_bigint().to_u64().unwrap_or(0); + } + } + Some("f") => { + if let Ok(f) = obj.try_float(vm) { + *(result as *mut f32) = f.to_f64() as f32; + } + } + Some("d") => { + if let Ok(f) = obj.try_float(vm) { + *(result as *mut f64) = f.to_f64(); + } + } + Some("P") | Some("z") | Some("Z") => { + if let Ok(i) = obj.try_int(vm) { + *(result as *mut usize) = i.as_bigint().to_usize().unwrap_or(0); + } + } + Some("?") => { + if let Ok(b) = obj.is_true(vm) { + *(result as *mut u8) = u8::from(b); + } + } + _ => {} + } + } +} + +/// The callback function that libffi calls when the closure is invoked. 
+unsafe extern "C" fn thunk_callback( + _cif: &low::ffi_cif, + result: &mut c_void, + args: *const *const c_void, + userdata: &ThunkUserData, +) { + with_current_vm(|vm| { + let py_args: Vec = userdata + .arg_types + .iter() + .enumerate() + .map(|(i, ty)| { + let arg_ptr = unsafe { *args.add(i) }; + ffi_to_python(ty, arg_ptr, vm) + }) + .collect(); + + let py_result = userdata.callable.call(py_args, vm); + + // Call unraisable hook if exception occurred + if let Err(exc) = &py_result { + let repr = userdata + .callable + .repr(vm) + .map(|s| s.to_string()) + .unwrap_or_else(|_| "".to_string()); + let msg = format!( + "Exception ignored on calling ctypes callback function {}", + repr + ); + vm.run_unraisable(exc.clone(), Some(msg), vm.ctx.none()); + } + + if let Some(ref res_type) = userdata.res_type { + python_to_ffi(py_result, res_type, result as *mut c_void, vm); + } + }); +} + +/// Holds the closure and userdata together to ensure proper lifetime. +struct ThunkData { + #[allow(dead_code)] + closure: Closure<'static>, + userdata_ptr: *mut ThunkUserData, +} + +impl Drop for ThunkData { + fn drop(&mut self) { + unsafe { + drop(Box::from_raw(self.userdata_ptr)); + } + } +} + +/// CThunkObject wraps a Python callable to make it callable from C code. +#[pyclass(name = "CThunkObject", module = "_ctypes")] +#[derive(PyPayload)] +pub(super) struct PyCThunk { + callable: PyObjectRef, + #[allow(dead_code)] + thunk_data: PyRwLock>, + code_ptr: CodePtr, +} + +impl Debug for PyCThunk { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyCThunk") + .field("callable", &self.callable) + .finish() + } +} + +impl PyCThunk { + pub fn new( + callable: PyObjectRef, + arg_types: Option, + res_type: Option, + vm: &VirtualMachine, + ) -> PyResult { + let arg_type_vec: Vec = match arg_types { + Some(args) if !vm.is_none(&args) => args + .try_to_value::>(vm)? 
+ .into_iter() + .map(|item| { + item.downcast::() + .map_err(|_| vm.new_type_error("_argtypes_ must be a sequence of types")) + }) + .collect::>>()?, + _ => Vec::new(), + }; + + let res_type_ref: Option = match res_type { + Some(ref rt) if !vm.is_none(rt) => Some( + rt.clone() + .downcast::() + .map_err(|_| vm.new_type_error("restype must be a ctypes type"))?, + ), + _ => None, + }; + + let ffi_arg_types: Vec = arg_type_vec + .iter() + .map(|ty| { + ty.type_code(vm) + .and_then(|code| get_ffi_type(&code)) + .unwrap_or(Type::pointer()) + }) + .collect(); + + let ffi_res_type = res_type_ref + .as_ref() + .and_then(|ty| ty.type_code(vm)) + .and_then(|code| get_ffi_type(&code)) + .unwrap_or(Type::void()); + + let cif = Cif::new(ffi_arg_types, ffi_res_type); + + let userdata = Box::new(ThunkUserData { + callable: callable.clone(), + arg_types: arg_type_vec, + res_type: res_type_ref, + }); + let userdata_ptr = Box::into_raw(userdata); + let userdata_ref: &'static ThunkUserData = unsafe { &*userdata_ptr }; + + let closure = Closure::new(cif, thunk_callback, userdata_ref); + let code_ptr = CodePtr(*closure.code_ptr() as *mut _); + + let thunk_data = ThunkData { + closure, + userdata_ptr, + }; + + Ok(Self { + callable, + thunk_data: PyRwLock::new(Some(thunk_data)), + code_ptr, + }) + } + + pub fn code_ptr(&self) -> CodePtr { + self.code_ptr + } +} + +unsafe impl Send for PyCThunk {} +unsafe impl Sync for PyCThunk {} + +#[pyclass] +impl PyCThunk { + #[pygetset] + fn callable(&self) -> PyObjectRef { + self.callable.clone() + } } diff --git a/crates/vm/src/stdlib/ctypes/library.rs b/crates/vm/src/stdlib/ctypes/library.rs index e918470b6c8..ec8ca91af0d 100644 --- a/crates/vm/src/stdlib/ctypes/library.rs +++ b/crates/vm/src/stdlib/ctypes/library.rs @@ -2,12 +2,12 @@ use crate::VirtualMachine; use libloading::Library; use rustpython_common::lock::{PyMutex, PyRwLock}; use std::collections::HashMap; -use std::ffi::c_void; +use std::ffi::{OsStr, c_void}; use std::fmt; use 
std::ptr::null; -pub struct SharedLibrary { - pub(crate) lib: PyMutex>, +pub(super) struct SharedLibrary { + pub(super) lib: PyMutex>, } impl fmt::Debug for SharedLibrary { @@ -17,13 +17,13 @@ impl fmt::Debug for SharedLibrary { } impl SharedLibrary { - pub fn new(name: &str) -> Result { + fn new(name: impl AsRef) -> Result { Ok(SharedLibrary { - lib: PyMutex::new(unsafe { Some(Library::new(name)?) }), + lib: PyMutex::new(unsafe { Some(Library::new(name.as_ref())?) }), }) } - pub fn get_pointer(&self) -> usize { + fn get_pointer(&self) -> usize { let lib_lock = self.lib.lock(); if let Some(l) = &*lib_lock { l as *const Library as usize @@ -32,12 +32,12 @@ impl SharedLibrary { } } - pub fn is_closed(&self) -> bool { + fn is_closed(&self) -> bool { let lib_lock = self.lib.lock(); lib_lock.is_none() } - pub fn close(&self) { + fn close(&self) { *self.lib.lock() = None; } } @@ -48,25 +48,24 @@ impl Drop for SharedLibrary { } } -pub struct ExternalLibs { +pub(super) struct ExternalLibs { libraries: HashMap, } impl ExternalLibs { - pub fn new() -> Self { + fn new() -> Self { Self { libraries: HashMap::new(), } } - #[allow(dead_code)] pub fn get_lib(&self, key: usize) -> Option<&SharedLibrary> { self.libraries.get(&key) } pub fn get_or_insert_lib( &mut self, - library_path: &str, + library_path: impl AsRef, _vm: &VirtualMachine, ) -> Result<(usize, &SharedLibrary), libloading::Error> { let new_lib = SharedLibrary::new(library_path)?; @@ -83,7 +82,7 @@ impl ExternalLibs { } }; - Ok((key, self.libraries.get(&key).unwrap())) + Ok((key, self.libraries.get(&key).expect("just inserted"))) } pub fn drop_lib(&mut self, key: usize) { @@ -91,10 +90,9 @@ impl ExternalLibs { } } -rustpython_common::static_cell! { - static LIBCACHE: PyRwLock; -} - -pub fn libcache() -> &'static PyRwLock { +pub(super) fn libcache() -> &'static PyRwLock { + rustpython_common::static_cell! 
{ + static LIBCACHE: PyRwLock; + } LIBCACHE.get_or_init(|| PyRwLock::new(ExternalLibs::new())) } diff --git a/crates/vm/src/stdlib/ctypes/pointer.rs b/crates/vm/src/stdlib/ctypes/pointer.rs index 735034e7936..3ee39af3a7a 100644 --- a/crates/vm/src/stdlib/ctypes/pointer.rs +++ b/crates/vm/src/stdlib/ctypes/pointer.rs @@ -1,37 +1,164 @@ -use num_traits::ToPrimitive; -use rustpython_common::lock::PyRwLock; - -use crate::builtins::{PyType, PyTypeRef}; -use crate::function::FuncArgs; +use super::{PyCArray, PyCData, PyCSimple, PyCStructure, StgInfo, StgInfoFlags}; use crate::protocol::PyNumberMethods; -use crate::stdlib::ctypes::{CDataObject, PyCData}; -use crate::types::{AsNumber, Constructor}; -use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; +use crate::types::{AsNumber, Constructor, Initializer}; +use crate::{ + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef}, + class::StaticType, + function::{FuncArgs, OptionalArg}, +}; +use num_traits::ToPrimitive; #[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")] #[derive(Debug)] #[repr(transparent)] -pub struct PyCPointerType(PyType); +pub(super) struct PyCPointerType(PyType); + +impl Initializer for PyCPointerType { + type Args = FuncArgs; + + fn init(zelf: crate::PyRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + // Get the type as PyTypeRef + let obj: PyObjectRef = zelf.clone().into(); + let new_type: PyTypeRef = obj + .downcast() + .map_err(|_| vm.new_type_error("expected type"))?; + + new_type.check_not_initialized(vm)?; + + // Get the _type_ attribute (element type) + // PyCPointerType_init gets the element type from _type_ attribute + let proto = new_type + .as_object() + .get_attr("_type_", vm) + .ok() + .and_then(|obj| obj.downcast::().ok()); + + // Initialize StgInfo for pointer type + let pointer_size = std::mem::size_of::(); + let mut stg_info = 
StgInfo::new(pointer_size, pointer_size); + stg_info.proto = proto; + stg_info.paramfunc = super::base::ParamFunc::Pointer; + stg_info.length = 1; + stg_info.flags |= StgInfoFlags::TYPEFLAG_ISPOINTER; + + // Set format string: "&" + if let Some(ref proto) = stg_info.proto { + let item_info = proto.stg_info_opt().expect("proto has StgInfo"); + let current_format = item_info.format.as_deref().unwrap_or("B"); + stg_info.format = Some(format!("&{}", current_format)); + } + + let _ = new_type.init_type_data(stg_info); -#[pyclass(flags(IMMUTABLETYPE), with(AsNumber))] + Ok(()) + } +} + +#[pyclass(flags(IMMUTABLETYPE), with(AsNumber, Initializer))] impl PyCPointerType { + #[pymethod] + fn from_param(zelf: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // zelf is the pointer type class that from_param was called on + let cls = zelf + .downcast::() + .map_err(|_| vm.new_type_error("from_param: expected a type"))?; + + // 1. None is allowed for pointer types + if vm.is_none(&value) { + return Ok(value); + } + + // 2. If already an instance of the requested type, return it + if value.is_instance(cls.as_object(), vm)? { + return Ok(value); + } + + // 3. If value is an instance of _type_ (the pointed-to type), wrap with byref + if let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) + && let Ok(type_ref) = type_attr.downcast::() + && value.is_instance(type_ref.as_object(), vm)? + { + // Return byref(value) + return super::_ctypes::byref(value, crate::function::OptionalArg::Missing, vm); + } + + // 4. Array/Pointer instances with compatible proto + // "Array instances are also pointers when the item types are the same." 
+ let is_pointer_or_array = value.downcast_ref::().is_some() + || value.downcast_ref::().is_some(); + + if is_pointer_or_array { + let is_compatible = { + if let Some(value_stginfo) = value.class().stg_info_opt() + && let Some(ref value_proto) = value_stginfo.proto + && let Some(cls_stginfo) = cls.stg_info_opt() + && let Some(ref cls_proto) = cls_stginfo.proto + { + // Check if value's proto is a subclass of target's proto + value_proto.fast_issubclass(cls_proto) + } else { + false + } + }; + if is_compatible { + return Ok(value); + } + } + + // 5. Check for _as_parameter_ attribute + if let Ok(as_parameter) = value.get_attr("_as_parameter_", vm) { + return PyCPointerType::from_param(cls.as_object().to_owned(), as_parameter, vm); + } + + Err(vm.new_type_error(format!( + "expected {} instance instead of {}", + cls.name(), + value.class().name() + ))) + } + #[pymethod] fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { - use super::array::create_array_type_with_stg_info; + use super::array::array_type_from_ctype; + if n < 0 { return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); } - // Pointer size - let element_size = std::mem::size_of::(); - let total_size = element_size * (n as usize); - let stg_info = super::util::StgInfo::new_array( - total_size, - element_size, - n as usize, - cls.as_object().to_owned(), - element_size, - ); - create_array_type_with_stg_info(stg_info, vm) + // Use cached array type creation + array_type_from_ctype(cls.into(), n as usize, vm) + } + + // PyCPointerType_set_type: Complete an incomplete pointer type + #[pymethod] + fn set_type(zelf: PyTypeRef, typ: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + use crate::AsObject; + + // 1. Validate that typ is a type + let typ_type = typ + .clone() + .downcast::() + .map_err(|_| vm.new_type_error("_type_ must be a type"))?; + + // 2. 
Validate that typ has storage info + if typ_type.stg_info_opt().is_none() { + return Err(vm.new_type_error("_type_ must have storage info")); + } + + // 3. Update StgInfo.proto and format using mutable access + if let Some(mut stg_info) = zelf.get_type_data_mut::() { + stg_info.proto = Some(typ_type.clone()); + + // Update format string: "&" + let item_info = typ_type.stg_info_opt().expect("proto has StgInfo"); + let current_format = item_info.format.as_deref().unwrap_or("B"); + stg_info.format = Some(format!("&{}", current_format)); + } + + // 4. Set _type_ attribute on the pointer type + zelf.as_object().set_attr("_type_", typ_type, vm)?; + + Ok(()) } } @@ -41,12 +168,12 @@ impl AsNumber for PyCPointerType { multiply: Some(|a, b, vm| { let cls = a .downcast_ref::() - .ok_or_else(|| vm.new_type_error("expected type".to_owned()))?; + .ok_or_else(|| vm.new_type_error("expected type"))?; let n = b .try_index(vm)? .as_bigint() .to_isize() - .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; + .ok_or_else(|| vm.new_overflow_error("array size too large"))?; PyCPointerType::__mul__(cls.to_owned(), n, vm) }), ..PyNumberMethods::NOT_IMPLEMENTED @@ -55,6 +182,8 @@ impl AsNumber for PyCPointerType { } } +/// PyCPointer - Pointer instance +/// `contents` is a computed property, not a stored field. 
#[pyclass( name = "_Pointer", base = PyCData, @@ -62,26 +191,27 @@ impl AsNumber for PyCPointerType { module = "_ctypes" )] #[derive(Debug)] -pub struct PyCPointer { - _base: PyCData, - contents: PyRwLock, -} +#[repr(transparent)] +pub struct PyCPointer(pub PyCData); impl Constructor for PyCPointer { - type Args = (crate::function::OptionalArg,); - - fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - let args: Self::Args = args.bind(vm)?; - // Get the initial contents value if provided - let initial_contents = args.0.into_option().unwrap_or_else(|| vm.ctx.none()); + type Args = FuncArgs; - // Create a new PyCPointer instance with the provided value - PyCPointer { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - contents: PyRwLock::new(initial_contents), + fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { + // Pointer_new: Check if _type_ is defined + let has_type = cls.stg_info_opt().is_some_and(|info| info.proto.is_some()); + if !has_type { + return Err(vm.new_type_error("Cannot create instance: has no _type_")); } - .into_ref_with_type(vm, cls) - .map(Into::into) + + // Create a new PyCPointer instance with NULL pointer (all zeros) + // Initial contents is set via __init__ if provided + let cdata = PyCData::from_bytes(vec![0u8; std::mem::size_of::()], None); + // pointer instance has b_length set to 2 (for index 0 and 1) + cdata.length.store(2); + PyCPointer(cdata) + .into_ref_with_type(vm, cls) + .map(Into::into) } fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { @@ -89,186 +219,496 @@ impl Constructor for PyCPointer { } } -#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor))] +impl Initializer for PyCPointer { + type Args = (OptionalArg,); + + fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + let (value,) = args; + if let OptionalArg::Present(val) = value + && !vm.is_none(&val) + { + Self::set_contents(&zelf, val, vm)?; + } 
+ Ok(()) + } +} + +#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor, Initializer))] impl PyCPointer { - // TODO: not correct + /// Get the pointer value stored in buffer as usize + pub fn get_ptr_value(&self) -> usize { + let buffer = self.0.buffer.read(); + super::base::read_ptr_from_buffer(&buffer) + } + + /// Set the pointer value in buffer + pub fn set_ptr_value(&self, value: usize) { + let mut buffer = self.0.buffer.write(); + let bytes = value.to_ne_bytes(); + if buffer.len() >= bytes.len() { + buffer.to_mut()[..bytes.len()].copy_from_slice(&bytes); + } + } + + /// Pointer_bool: returns True if pointer is not NULL + #[pymethod] + fn __bool__(&self) -> bool { + self.get_ptr_value() != 0 + } + + /// contents getter - reads address from b_ptr and creates an instance of the pointed-to type #[pygetset] - fn contents(&self) -> PyResult { - let contents = self.contents.read().clone(); - Ok(contents) + fn contents(zelf: &Py, vm: &VirtualMachine) -> PyResult { + // Pointer_get_contents + let ptr_val = zelf.get_ptr_value(); + if ptr_val == 0 { + return Err(vm.new_value_error("NULL pointer access")); + } + + // Get element type from StgInfo.proto + let stg_info = zelf.class().stg_info(vm)?; + let proto_type = stg_info.proto(); + let element_size = proto_type + .stg_info_opt() + .map_or(std::mem::size_of::(), |info| info.size); + + // Create instance that references the memory directly + // PyCData.into_ref_with_type works for all ctypes (simple, structure, union, array, pointer) + let cdata = unsafe { super::base::PyCData::at_address(ptr_val as *const u8, element_size) }; + cdata + .into_ref_with_type(vm, proto_type.to_owned()) + .map(Into::into) } + + /// contents setter - stores address in b_ptr and keeps reference + /// Pointer_set_contents #[pygetset(setter)] - fn set_contents(&self, contents: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> { - // Validate that the contents is a CData instance if we have a _type_ - // For now, just store it - 
*self.contents.write() = contents; + fn set_contents(zelf: &Py, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // Get stginfo and proto for type validation + let stg_info = zelf.class().stg_info(vm)?; + let proto = stg_info.proto(); + + // Check if value is CData, or isinstance(value, proto) + let cdata = if let Some(c) = value.downcast_ref::() { + c + } else if value.is_instance(proto.as_object(), vm)? { + value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected ctypes instance"))? + } else { + return Err(vm.new_type_error(format!( + "expected {} instead of {}", + proto.name(), + value.class().name() + ))); + }; + + // Set pointer value + { + let buffer = cdata.buffer.read(); + let addr = buffer.as_ptr() as usize; + drop(buffer); + zelf.set_ptr_value(addr); + } + + // KeepRef: store the object directly with index 1 + zelf.0.keep_ref(1, value.clone(), vm)?; + + // KeepRef: store GetKeepedObjects(dst) at index 0 + if let Some(kept) = cdata.objects.read().clone() { + zelf.0.keep_ref(0, kept, vm)?; + } + Ok(()) } + // Pointer_subscript #[pymethod] - fn __init__( - &self, - value: crate::function::OptionalArg, - _vm: &VirtualMachine, - ) -> PyResult<()> { - // Pointer can be initialized with 0 or 1 argument - // If 1 argument is provided, it should be a CData instance - if let crate::function::OptionalArg::Present(val) = value { - *self.contents.write() = val; + fn __getitem__(zelf: &Py, item: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // PyIndex_Check + if let Some(i) = item.downcast_ref::() { + let i = i.as_bigint().to_isize().ok_or_else(|| { + vm.new_index_error("cannot fit index into an index-sized integer") + })?; + // Note: Pointer does NOT adjust negative indices (no length) + Self::getitem_by_index(zelf, i, vm) + } + // PySlice_Check + else if let Some(slice) = item.downcast_ref::() { + Self::getitem_by_slice(zelf, slice, vm) + } else { + Err(vm.new_type_error("Pointer indices must be integer")) } - - Ok(()) } - 
#[pyclassmethod] - fn from_address(cls: PyTypeRef, address: isize, vm: &VirtualMachine) -> PyResult { - if address == 0 { - return Err(vm.new_value_error("NULL pointer access".to_owned())); + // Pointer_item + fn getitem_by_index(zelf: &Py, index: isize, vm: &VirtualMachine) -> PyResult { + // if (*(void **)self->b_ptr == NULL) { PyErr_SetString(...); } + let ptr_value = zelf.get_ptr_value(); + if ptr_value == 0 { + return Err(vm.new_value_error("NULL pointer access")); } - // Pointer just stores the address value - Ok(PyCPointer { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - contents: PyRwLock::new(vm.ctx.new_int(address).into()), + + // Get element type and size from StgInfo.proto + let stg_info = zelf.class().stg_info(vm)?; + let proto_type = stg_info.proto(); + let element_size = proto_type + .stg_info_opt() + .map_or(std::mem::size_of::(), |info| info.size); + + // offset = index * iteminfo->size + let offset = index * element_size as isize; + let addr = (ptr_value as isize + offset) as usize; + + // Check if it's a simple type (has _type_ attribute) + if let Ok(type_attr) = proto_type.as_object().get_attr("_type_", vm) + && let Ok(type_str) = type_attr.str(vm) + { + let type_code = type_str.to_string(); + return Self::read_value_at_address(addr, element_size, Some(&type_code), vm); } - .into_ref_with_type(vm, cls)? 
- .into()) + + // Complex type: create instance that references the memory directly (not a copy) + // This allows p[i].val = x to modify the original memory + // PyCData.into_ref_with_type works for all ctypes (array, structure, union, pointer) + let cdata = unsafe { super::base::PyCData::at_address(addr as *const u8, element_size) }; + cdata + .into_ref_with_type(vm, proto_type.to_owned()) + .map(Into::into) } - #[pyclassmethod] - fn from_buffer( - cls: PyTypeRef, - source: PyObjectRef, - offset: crate::function::OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use crate::TryFromObject; - use crate::protocol::PyBuffer; + // Pointer_subscript slice handling (manual parsing, not PySlice_Unpack) + fn getitem_by_slice(zelf: &Py, slice: &PySlice, vm: &VirtualMachine) -> PyResult { + // Since pointers have no length, we have to dissect the slice ourselves + + // step: defaults to 1, step == 0 is error + let step: isize = if let Some(ref step_obj) = slice.step + && !vm.is_none(step_obj) + { + let s = step_obj + .try_int(vm)? + .as_bigint() + .to_isize() + .ok_or_else(|| vm.new_value_error("slice step too large"))?; + if s == 0 { + return Err(vm.new_value_error("slice step cannot be zero")); + } + s + } else { + 1 + }; + + // start: defaults to 0, but required if step < 0 + let start: isize = if let Some(ref start_obj) = slice.start + && !vm.is_none(start_obj) + { + start_obj + .try_int(vm)? + .as_bigint() + .to_isize() + .ok_or_else(|| vm.new_value_error("slice start too large"))? + } else { + if step < 0 { + return Err(vm.new_value_error("slice start is required for step < 0")); + } + 0 + }; - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); + // stop: ALWAYS required for pointers + if vm.is_none(&slice.stop) { + return Err(vm.new_value_error("slice stop is required")); } - let offset = offset as usize; - let size = std::mem::size_of::(); + let stop: isize = slice + .stop + .try_int(vm)? 
+ .as_bigint() + .to_isize() + .ok_or_else(|| vm.new_value_error("slice stop too large"))?; - let buffer = PyBuffer::try_from_object(vm, source.clone())?; + // calculate length + let len: usize = if (step > 0 && start > stop) || (step < 0 && start < stop) { + 0 + } else if step > 0 { + ((stop - start - 1) / step + 1) as usize + } else { + ((stop - start + 1) / step + 1) as usize + }; + + // Get element info + let stg_info = zelf.class().stg_info(vm)?; + let element_size = if let Some(ref proto_type) = stg_info.proto { + proto_type.stg_info_opt().expect("proto has StgInfo").size + } else { + std::mem::size_of::() + }; + let type_code = stg_info + .proto + .as_ref() + .and_then(|p| p.as_object().get_attr("_type_", vm).ok()) + .and_then(|t| t.str(vm).ok()) + .map(|s| s.to_string()); - if buffer.desc.readonly { - return Err(vm.new_type_error("underlying buffer is not writable".to_owned())); + let ptr_value = zelf.get_ptr_value(); + + // c_char → bytes + if type_code.as_deref() == Some("c") { + if len == 0 { + return Ok(vm.ctx.new_bytes(vec![]).into()); + } + let mut result = Vec::with_capacity(len); + if step == 1 { + // Optimized contiguous copy + let start_addr = (ptr_value as isize + start * element_size as isize) as *const u8; + unsafe { + result.extend_from_slice(std::slice::from_raw_parts(start_addr, len)); + } + } else { + let mut cur = start; + for _ in 0..len { + let addr = (ptr_value as isize + cur * element_size as isize) as *const u8; + unsafe { + result.push(*addr); + } + cur += step; + } + } + return Ok(vm.ctx.new_bytes(result).into()); } - let buffer_len = buffer.desc.len; - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); + // c_wchar → str + if type_code.as_deref() == Some("u") { + if len == 0 { + return Ok(vm.ctx.new_str("").into()); + } + let mut result = String::with_capacity(len); + let wchar_size = std::mem::size_of::(); + let 
mut cur = start; + for _ in 0..len { + let addr = (ptr_value as isize + cur * wchar_size as isize) as *const libc::wchar_t; + unsafe { + if let Some(c) = char::from_u32(*addr as u32) { + result.push(c); + } + } + cur += step; + } + return Ok(vm.ctx.new_str(result).into()); } - // Read pointer value from buffer - let bytes = buffer.obj_bytes(); - let ptr_bytes = &bytes[offset..offset + size]; - let ptr_val = usize::from_ne_bytes(ptr_bytes.try_into().expect("size is checked above")); + // other types → list with Pointer_item for each + let mut items = Vec::with_capacity(len); + let mut cur = start; + for _ in 0..len { + items.push(Self::getitem_by_index(zelf, cur, vm)?); + cur += step; + } + Ok(PyList::from(items).into_ref(&vm.ctx).into()) + } - Ok(PyCPointer { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - contents: PyRwLock::new(vm.ctx.new_int(ptr_val).into()), + // Pointer_ass_item + #[pymethod] + fn __setitem__( + zelf: &Py, + item: PyObjectRef, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Pointer does not support item deletion (value always provided) + // only integer indices supported for setitem + if let Some(i) = item.downcast_ref::() { + let i = i.as_bigint().to_isize().ok_or_else(|| { + vm.new_index_error("cannot fit index into an index-sized integer") + })?; + Self::setitem_by_index(zelf, i, value, vm) + } else { + Err(vm.new_type_error("Pointer indices must be integer")) } - .into_ref_with_type(vm, cls)? 
- .into()) } - #[pyclassmethod] - fn from_buffer_copy( - cls: PyTypeRef, - source: crate::function::ArgBytesLike, - offset: crate::function::OptionalArg, + fn setitem_by_index( + zelf: &Py, + index: isize, + value: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); + ) -> PyResult<()> { + let ptr_value = zelf.get_ptr_value(); + if ptr_value == 0 { + return Err(vm.new_value_error("NULL pointer access")); } - let offset = offset as usize; - let size = std::mem::size_of::(); - let source_bytes = source.borrow_buf(); - let buffer_len = source_bytes.len(); + // Get element type, size and type_code from StgInfo.proto + let stg_info = zelf.class().stg_info(vm)?; + let proto_type = stg_info.proto(); - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); - } + // Get type code from proto's _type_ attribute + let type_code: Option = proto_type + .as_object() + .get_attr("_type_", vm) + .ok() + .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); - // Read pointer value from buffer - let ptr_bytes = &source_bytes[offset..offset + size]; - let ptr_val = usize::from_ne_bytes(ptr_bytes.try_into().expect("size is checked above")); + let element_size = proto_type + .stg_info_opt() + .map_or(std::mem::size_of::(), |info| info.size); - Ok(PyCPointer { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - contents: PyRwLock::new(vm.ctx.new_int(ptr_val).into()), + // Calculate address + let offset = index * element_size as isize; + let addr = (ptr_value as isize + offset) as usize; + + // Write value at address + // Handle Structure/Array types by copying their buffer + if let Some(cdata) = value.downcast_ref::() + && (cdata.fast_isinstance(PyCStructure::static_type()) + || cdata.fast_isinstance(PyCArray::static_type()) + || 
cdata.fast_isinstance(PyCSimple::static_type())) + { + let src_buffer = cdata.buffer.read(); + let copy_len = src_buffer.len().min(element_size); + unsafe { + let dest_ptr = addr as *mut u8; + std::ptr::copy_nonoverlapping(src_buffer.as_ptr(), dest_ptr, copy_len); + } + } else { + // Handle z/Z specially to store converted value + if type_code.as_deref() == Some("z") + && let Some(bytes) = value.downcast_ref::() + { + let (converted, ptr_val) = super::base::ensure_z_null_terminated(bytes, vm); + unsafe { + *(addr as *mut usize) = ptr_val; + } + return zelf.0.keep_ref(index as usize, converted, vm); + } else if type_code.as_deref() == Some("Z") + && let Some(s) = value.downcast_ref::() + { + let (holder, ptr_val) = super::base::str_to_wchar_bytes(s.as_str(), vm); + unsafe { + *(addr as *mut usize) = ptr_val; + } + return zelf.0.keep_ref(index as usize, holder, vm); + } else { + Self::write_value_at_address(addr, element_size, &value, type_code.as_deref(), vm)?; + } } - .into_ref_with_type(vm, cls)? 
- .into()) + + // KeepRef: store reference to keep value alive using actual index + zelf.0.keep_ref(index as usize, value, vm) } - #[pyclassmethod] - fn in_dll( - cls: PyTypeRef, - dll: PyObjectRef, - name: crate::builtins::PyStrRef, + /// Read a value from memory address + fn read_value_at_address( + addr: usize, + size: usize, + type_code: Option<&str>, vm: &VirtualMachine, ) -> PyResult { - use libloading::Symbol; + unsafe { + let ptr = addr as *const u8; + match type_code { + Some("c") => Ok(vm.ctx.new_bytes(vec![*ptr]).into()), + Some("b") => Ok(vm.ctx.new_int(*(ptr as *const i8) as i32).into()), + Some("B") => Ok(vm.ctx.new_int(*ptr as i32).into()), + Some("h") => Ok(vm.ctx.new_int(*(ptr as *const i16) as i32).into()), + Some("H") => Ok(vm.ctx.new_int(*(ptr as *const u16) as i32).into()), + Some("i") | Some("l") => Ok(vm.ctx.new_int(*(ptr as *const i32)).into()), + Some("I") | Some("L") => Ok(vm.ctx.new_int(*(ptr as *const u32)).into()), + Some("q") => Ok(vm.ctx.new_int(*(ptr as *const i64)).into()), + Some("Q") => Ok(vm.ctx.new_int(*(ptr as *const u64)).into()), + Some("f") => Ok(vm.ctx.new_float(*(ptr as *const f32) as f64).into()), + Some("d") | Some("g") => Ok(vm.ctx.new_float(*(ptr as *const f64)).into()), + Some("P") | Some("z") | Some("Z") => { + Ok(vm.ctx.new_int(*(ptr as *const usize)).into()) + } + _ => { + // Default: read as bytes + let bytes = std::slice::from_raw_parts(ptr, size).to_vec(); + Ok(vm.ctx.new_bytes(bytes).into()) + } + } + } + } - // Get the library handle from dll object - let handle = if let Ok(int_handle) = dll.try_int(vm) { - // dll is an integer handle - int_handle - .as_bigint() - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? - } else { - // dll is a CDLL/PyDLL/WinDLL object with _handle attribute - dll.get_attr("_handle", vm)? - .try_int(vm)? - .as_bigint() - .to_usize() - .ok_or_else(|| vm.new_value_error("Invalid library handle".to_owned()))? 
- }; + /// Write a value to memory address + fn write_value_at_address( + addr: usize, + size: usize, + value: &PyObject, + type_code: Option<&str>, + vm: &VirtualMachine, + ) -> PyResult<()> { + unsafe { + let ptr = addr as *mut u8; - // Get the library from cache - let library_cache = crate::stdlib::ctypes::library::libcache().read(); - let library = library_cache - .get_lib(handle) - .ok_or_else(|| vm.new_attribute_error("Library not found".to_owned()))?; + // Handle c_char_p (z) and c_wchar_p (Z) - store pointer address + // Note: PyBytes/PyStr cases are handled by caller (setitem_by_index) + match type_code { + Some("z") | Some("Z") => { + let ptr_val = if vm.is_none(value) { + 0usize + } else if let Ok(int_val) = value.try_index(vm) { + int_val.as_bigint().to_usize().unwrap_or(0) + } else { + return Err(vm.new_type_error( + "bytes/string or integer address expected".to_owned(), + )); + }; + *(ptr as *mut usize) = ptr_val; + return Ok(()); + } + _ => {} + } - // Get symbol address from library - let symbol_name = format!("{}\0", name.as_str()); - let inner_lib = library.lib.lock(); + // Try to get value as integer + if let Ok(int_val) = value.try_int(vm) { + let i = int_val.as_bigint(); + match size { + 1 => { + *ptr = i.to_u8().unwrap_or(0); + } + 2 => { + *(ptr as *mut i16) = i.to_i16().unwrap_or(0); + } + 4 => { + *(ptr as *mut i32) = i.to_i32().unwrap_or(0); + } + 8 => { + *(ptr as *mut i64) = i.to_i64().unwrap_or(0); + } + _ => { + let bytes = i.to_signed_bytes_le(); + let copy_len = bytes.len().min(size); + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); + } + } + return Ok(()); + } - let symbol_address = if let Some(lib) = &*inner_lib { - unsafe { - // Try to get the symbol from the library - let symbol: Symbol<'_, *mut u8> = lib.get(symbol_name.as_bytes()).map_err(|e| { - vm.new_attribute_error(format!("{}: symbol '{}' not found", e, name.as_str())) - })?; - *symbol as usize + // Try to get value as float + if let Ok(float_val) = 
value.try_float(vm) { + let f = float_val.to_f64(); + match size { + 4 => { + *(ptr as *mut f32) = f as f32; + } + 8 => { + *(ptr as *mut f64) = f; + } + _ => {} + } + return Ok(()); + } + + // Try bytes + if let Ok(bytes) = value.try_bytes_like(vm, |b| b.to_vec()) { + let copy_len = bytes.len().min(size); + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); + return Ok(()); } - } else { - return Err(vm.new_attribute_error("Library is closed".to_owned())); - }; - // For pointer types, we return a pointer to the symbol address - Ok(PyCPointer { - _base: PyCData::new(CDataObject::from_bytes(vec![], None)), - contents: PyRwLock::new(vm.ctx.new_int(symbol_address).into()), + Err(vm.new_type_error(format!( + "cannot convert {} to ctypes data", + value.class().name() + ))) } - .into_ref_with_type(vm, cls)? - .into()) } } diff --git a/crates/vm/src/stdlib/ctypes/simple.rs b/crates/vm/src/stdlib/ctypes/simple.rs new file mode 100644 index 00000000000..1c0ec250d72 --- /dev/null +++ b/crates/vm/src/stdlib/ctypes/simple.rs @@ -0,0 +1,1379 @@ +use super::_ctypes::CArgObject; +use super::array::{PyCArray, WCHAR_SIZE, wchar_to_bytes}; +use super::base::{ + CDATA_BUFFER_METHODS, FfiArgValue, PyCData, StgInfo, StgInfoFlags, buffer_to_ffi_value, + bytes_to_pyobject, +}; +use super::function::PyCFuncPtr; +use super::get_size; +use super::pointer::PyCPointer; +use crate::builtins::{PyByteArray, PyBytes, PyInt, PyNone, PyStr, PyType, PyTypeRef}; +use crate::convert::ToPyObject; +use crate::function::{Either, FuncArgs, OptionalArg}; +use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; +use crate::types::{AsBuffer, AsNumber, Constructor, Initializer, Representable}; +use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; +use num_traits::ToPrimitive; +use std::fmt::Debug; + +/// Valid type codes for ctypes simple types +// spell-checker: disable-next-line +pub(super) const SIMPLE_TYPE_CHARS: &str = 
"cbBhHiIlLdfuzZqQPXOv?g"; + +/// Create a new simple type instance from a class +fn new_simple_type( + cls: Either<&PyObject, &Py>, + vm: &VirtualMachine, +) -> PyResult { + let cls = match cls { + Either::A(obj) => obj, + Either::B(typ) => typ.as_object(), + }; + + let _type_ = cls + .get_attr("_type_", vm) + .map_err(|_| vm.new_attribute_error("class must define a '_type_' attribute"))?; + + if !_type_.is_instance((&vm.ctx.types.str_type).as_ref(), vm)? { + return Err(vm.new_type_error("class must define a '_type_' string attribute")); + } + + let tp_str = _type_.str(vm)?.to_string(); + + if tp_str.len() != 1 { + return Err(vm.new_value_error(format!( + "class must define a '_type_' attribute which must be a string of length 1, str: {tp_str}" + ))); + } + + if !SIMPLE_TYPE_CHARS.contains(tp_str.as_str()) { + return Err(vm.new_attribute_error(format!( + "class must define a '_type_' attribute which must be\n a single character string containing one of {SIMPLE_TYPE_CHARS}, currently it is {tp_str}." 
+ ))); + } + + let size = get_size(&tp_str); + Ok(PyCSimple(PyCData::from_bytes(vec![0u8; size], None))) +} + +fn set_primitive(_type_: &str, value: &PyObject, vm: &VirtualMachine) -> PyResult { + match _type_ { + "c" => { + // c_set: accepts bytes(len=1), bytearray(len=1), or int(0-255) + if value + .downcast_ref_if_exact::(vm) + .is_some_and(|v| v.len() == 1) + || value + .downcast_ref_if_exact::(vm) + .is_some_and(|v| v.borrow_buf().len() == 1) + || value.downcast_ref_if_exact::(vm).is_some_and(|v| { + v.as_bigint() + .to_i64() + .is_some_and(|n| (0..=255).contains(&n)) + }) + { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error("one character bytes, bytearray or integer expected")) + } + } + "u" => { + if let Ok(b) = value.str(vm).map(|v| v.to_string().chars().count() == 1) { + if b { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error("one character unicode string expected")) + } + } else { + Err(vm.new_type_error(format!( + "unicode string expected instead of {} instance", + value.class().name() + ))) + } + } + "b" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { + // Support __index__ protocol + if value.try_index(vm).is_ok() { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error(format!( + "an integer is required (got type {})", + value.class().name() + ))) + } + } + "f" | "d" | "g" => { + // Handle int specially to check overflow + if let Some(int_obj) = value.downcast_ref_if_exact::(vm) { + // Check if int can fit in f64 + if int_obj.as_bigint().to_f64().is_some() { + return Ok(value.to_owned()); + } else { + return Err(vm.new_overflow_error("int too large to convert to float")); + } + } + // __float__ protocol + if value.try_float(vm).is_ok() { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error(format!("must be real number, not {}", value.class().name()))) + } + } + "?" 
=> Ok(PyObjectRef::from( + vm.ctx.new_bool(value.to_owned().try_to_bool(vm)?), + )), + "v" => { + // VARIANT_BOOL: any truthy → True + Ok(PyObjectRef::from( + vm.ctx.new_bool(value.to_owned().try_to_bool(vm)?), + )) + } + "B" => { + // Support __index__ protocol + if value.try_index(vm).is_ok() { + // Store as-is, conversion to unsigned happens in the getter + Ok(value.to_owned()) + } else { + Err(vm.new_type_error(format!("int expected instead of {}", value.class().name()))) + } + } + "z" => { + if value.is(&vm.ctx.none) + || value.downcast_ref_if_exact::(vm).is_some() + || value.downcast_ref_if_exact::(vm).is_some() + { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error(format!( + "bytes or integer address expected instead of {} instance", + value.class().name() + ))) + } + } + "Z" => { + if value.is(&vm.ctx.none) + || value.downcast_ref_if_exact::(vm).is_some() + || value.downcast_ref_if_exact::(vm).is_some() + { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error(format!( + "unicode string or integer address expected instead of {} instance", + value.class().name() + ))) + } + } + // O_set: py_object accepts any Python object + "O" => Ok(value.to_owned()), + _ => { + // "P" + if value.downcast_ref_if_exact::(vm).is_some() + || value.downcast_ref_if_exact::(vm).is_some() + { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error("cannot be converted to pointer")) + } + } + } +} + +#[pyclass(module = "_ctypes", name = "PyCSimpleType", base = PyType)] +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCSimpleType(PyType); + +#[pyclass(flags(BASETYPE), with(AsNumber, Initializer))] +impl PyCSimpleType { + #[allow(clippy::new_ret_no_self)] + #[pymethod] + fn new(cls: PyTypeRef, _: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(PyObjectRef::from( + new_simple_type(Either::B(&cls), vm)? + .into_ref_with_type(vm, cls)? 
+ .clone(), + )) + } + + #[pymethod] + fn from_param(zelf: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // zelf is the class (e.g., c_int) that from_param was called on + let cls = zelf + .downcast::() + .map_err(|_| vm.new_type_error("from_param: expected a type"))?; + + // 1. If the value is already an instance of the requested type, return it + if value.is_instance(cls.as_object(), vm)? { + return Ok(value); + } + + // 2. Get the type code to determine conversion rules + let type_code = cls.type_code(vm); + + // 3. Handle None for pointer types (c_char_p, c_wchar_p, c_void_p) + if vm.is_none(&value) && matches!(type_code.as_deref(), Some("z") | Some("Z") | Some("P")) { + return Ok(value); + } + + // Helper to create CArgObject wrapping a simple instance + let create_simple_with_value = |type_str: &str, val: &PyObject| -> PyResult { + let simple = new_simple_type(Either::B(&cls), vm)?; + let buffer_bytes = value_to_bytes_endian(type_str, val, false, vm); + *simple.0.buffer.write() = std::borrow::Cow::Owned(buffer_bytes.clone()); + let simple_obj: PyObjectRef = simple.into_ref_with_type(vm, cls.clone())?.into(); + // from_param returns CArgObject, not the simple type itself + let tag = type_str.as_bytes().first().copied().unwrap_or(b'?'); + let ffi_value = buffer_to_ffi_value(type_str, &buffer_bytes); + Ok(CArgObject { + tag, + value: ffi_value, + obj: simple_obj, + size: 0, + offset: 0, + } + .to_pyobject(vm)) + }; + + // 4. 
Try to convert value based on type code + match type_code.as_deref() { + // Integer types: accept integers + Some(tc @ ("b" | "B" | "h" | "H" | "i" | "I" | "l" | "L" | "q" | "Q")) => { + if value.try_int(vm).is_ok() { + return create_simple_with_value(tc, &value); + } + } + // Float types: accept numbers + Some(tc @ ("f" | "d" | "g")) => { + if value.try_float(vm).is_ok() || value.try_int(vm).is_ok() { + return create_simple_with_value(tc, &value); + } + } + // c_char: 1 byte character + Some("c") => { + if let Some(bytes) = value.downcast_ref::() + && bytes.len() == 1 + { + return create_simple_with_value("c", &value); + } + if let Ok(int_val) = value.try_int(vm) + && int_val.as_bigint().to_u8().is_some() + { + return create_simple_with_value("c", &value); + } + return Err(vm.new_type_error( + "one character bytes, bytearray or integer expected".to_string(), + )); + } + // c_wchar: 1 unicode character + Some("u") => { + if let Some(s) = value.downcast_ref::() + && s.as_str().chars().count() == 1 + { + return create_simple_with_value("u", &value); + } + return Err(vm.new_type_error("one character unicode string expected")); + } + // c_char_p: bytes pointer + Some("z") => { + // 1. bytes → create CArgObject with null-terminated buffer + if let Some(bytes) = value.downcast_ref::() { + let (holder, ptr) = super::base::ensure_z_null_terminated(bytes, vm); + return Ok(CArgObject { + tag: b'z', + value: FfiArgValue::OwnedPointer(ptr, holder), + obj: value.clone(), + size: 0, + offset: 0, + } + .to_pyobject(vm)); + } + // 2. Array/Pointer with c_char element type + if is_cchar_array_or_pointer(&value, vm) { + return Ok(value); + } + // 3. CArgObject (byref(c_char(...))) + if let Some(carg) = value.downcast_ref::() + && carg.tag == b'c' + { + return Ok(value.clone()); + } + } + // c_wchar_p: unicode pointer + Some("Z") => { + // 1. 
str → create CArgObject with null-terminated wchar buffer + if let Some(s) = value.downcast_ref::() { + let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_str(), vm); + return Ok(CArgObject { + tag: b'Z', + value: FfiArgValue::OwnedPointer(ptr, holder), + obj: value.clone(), + size: 0, + offset: 0, + } + .to_pyobject(vm)); + } + // 2. Array/Pointer with c_wchar element type + if is_cwchar_array_or_pointer(&value, vm) { + return Ok(value); + } + // 3. CArgObject (byref(c_wchar(...))) + if let Some(carg) = value.downcast_ref::() + && carg.tag == b'u' + { + return Ok(value.clone()); + } + } + // c_void_p: most flexible - accepts int, bytes, str, any array/pointer, funcptr + Some("P") => { + // 1. int → create c_void_p with that address + if value.try_int(vm).is_ok() { + return create_simple_with_value("P", &value); + } + // 2. bytes → create CArgObject with null-terminated buffer + if let Some(bytes) = value.downcast_ref::() { + let (holder, ptr) = super::base::ensure_z_null_terminated(bytes, vm); + return Ok(CArgObject { + tag: b'z', + value: FfiArgValue::OwnedPointer(ptr, holder), + obj: value.clone(), + size: 0, + offset: 0, + } + .to_pyobject(vm)); + } + // 3. str → create CArgObject with null-terminated wchar buffer + if let Some(s) = value.downcast_ref::() { + let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_str(), vm); + return Ok(CArgObject { + tag: b'Z', + value: FfiArgValue::OwnedPointer(ptr, holder), + obj: value.clone(), + size: 0, + offset: 0, + } + .to_pyobject(vm)); + } + // 4. Any Array or Pointer → accept directly + if value.downcast_ref::().is_some() + || value.downcast_ref::().is_some() + { + return Ok(value); + } + // 5. CArgObject with 'P' tag (byref(c_void_p(...))) + if let Some(carg) = value.downcast_ref::() + && carg.tag == b'P' + { + return Ok(value.clone()); + } + // 6. 
PyCFuncPtr → extract function pointer address + if let Some(funcptr) = value.downcast_ref::() { + let ptr_val = { + let buffer = funcptr._base.buffer.read(); + if buffer.len() >= std::mem::size_of::() { + usize::from_ne_bytes( + buffer[..std::mem::size_of::()].try_into().unwrap(), + ) + } else { + 0 + } + }; + return Ok(CArgObject { + tag: b'P', + value: FfiArgValue::Pointer(ptr_val), + obj: value.clone(), + size: 0, + offset: 0, + } + .to_pyobject(vm)); + } + // 7. c_char_p or c_wchar_p instance → extract pointer value + if let Some(simple) = value.downcast_ref::() { + let value_type_code = value.class().type_code(vm); + if matches!(value_type_code.as_deref(), Some("z") | Some("Z")) { + let ptr_val = { + let buffer = simple.0.buffer.read(); + if buffer.len() >= std::mem::size_of::() { + usize::from_ne_bytes( + buffer[..std::mem::size_of::()].try_into().unwrap(), + ) + } else { + 0 + } + }; + return Ok(CArgObject { + tag: b'Z', + value: FfiArgValue::Pointer(ptr_val), + obj: value.clone(), + size: 0, + offset: 0, + } + .to_pyobject(vm)); + } + } + } + // c_bool + Some("?") => { + let bool_val = value.is_true(vm)?; + let bool_obj: PyObjectRef = vm.ctx.new_bool(bool_val).into(); + return create_simple_with_value("?", &bool_obj); + } + _ => {} + } + + // 5. Check for _as_parameter_ attribute + if let Ok(as_parameter) = value.get_attr("_as_parameter_", vm) { + return PyCSimpleType::from_param(cls.as_object().to_owned(), as_parameter, vm); + } + + // 6. 
Type-specific error messages + match type_code.as_deref() { + Some("z") => Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as ctypes.c_char_p", + value.class().name() + ))), + Some("Z") => Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as ctypes.c_wchar_p", + value.class().name() + ))), + _ => Err(vm.new_type_error("wrong type")), + } + } + + #[pymethod] + fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { + PyCSimple::repeat(cls, n, vm) + } +} + +impl AsNumber for PyCSimpleType { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + multiply: Some(|a, b, vm| { + // a is a PyCSimpleType instance (type object like c_char) + // b is int (array size) + let cls = a + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("expected type"))?; + let n = b + .try_index(vm)? + .as_bigint() + .to_isize() + .ok_or_else(|| vm.new_overflow_error("array size too large"))?; + PyCSimple::repeat(cls.to_owned(), n, vm) + }), + ..PyNumberMethods::NOT_IMPLEMENTED + }; + &AS_NUMBER + } +} + +impl Initializer for PyCSimpleType { + type Args = FuncArgs; + + fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + // type_init requires exactly 3 positional arguments: name, bases, dict + if args.args.len() != 3 { + return Err(vm.new_type_error(format!( + "type.__init__() takes 3 positional arguments but {} were given", + args.args.len() + ))); + } + + // Get the type from the metatype instance + let type_ref: PyTypeRef = zelf + .as_object() + .to_owned() + .downcast() + .map_err(|_| vm.new_type_error("expected type"))?; + + type_ref.check_not_initialized(vm)?; + + // Get _type_ attribute + let type_attr = match type_ref.as_object().get_attr("_type_", vm) { + Ok(attr) => attr, + Err(_) => { + return Err(vm.new_attribute_error("class must define a '_type_' attribute")); + } + }; + + // Validate _type_ is a string + let type_str = 
type_attr.str(vm)?.to_string(); + + // Validate _type_ is a single character + if type_str.len() != 1 { + return Err(vm.new_value_error( + "class must define a '_type_' attribute which must be a string of length 1" + .to_owned(), + )); + } + + // Validate _type_ is a valid type character + if !SIMPLE_TYPE_CHARS.contains(type_str.as_str()) { + return Err(vm.new_attribute_error(format!( + "class must define a '_type_' attribute which must be a single character string containing one of '{}', currently it is '{}'.", + SIMPLE_TYPE_CHARS, type_str + ))); + } + + // Initialize StgInfo + let size = super::get_size(&type_str); + let align = super::get_align(&type_str); + let mut stg_info = StgInfo::new(size, align); + + // Set format for PEP 3118 buffer protocol + // Format is endian prefix + type code (e.g., "" + }; + stg_info.format = Some(format!("{}{}", endian_prefix, type_str)); + stg_info.paramfunc = super::base::ParamFunc::Simple; + + // Set TYPEFLAG_ISPOINTER for pointer types: z (c_char_p), Z (c_wchar_p), + // P (c_void_p), s (char array), X (BSTR), O (py_object) + if matches!(type_str.as_str(), "z" | "Z" | "P" | "s" | "X" | "O") { + stg_info.flags |= StgInfoFlags::TYPEFLAG_ISPOINTER; + } + + super::base::set_or_init_stginfo(&type_ref, stg_info); + + // Create __ctype_le__ and __ctype_be__ swapped types + create_swapped_types(&type_ref, &type_str, vm)?; + + Ok(()) + } +} + +/// Create __ctype_le__ and __ctype_be__ swapped byte order types +/// On little-endian systems: __ctype_le__ = self, __ctype_be__ = swapped type +/// On big-endian systems: __ctype_be__ = self, __ctype_le__ = swapped type +/// +/// - Single-byte types (c, b, B): __ctype_le__ = __ctype_be__ = self +/// - Pointer/unsupported types (z, Z, P, u, O): NO __ctype_le__/__ctype_be__ attributes +/// - Multi-byte numeric types (h, H, i, I, l, L, q, Q, f, d, g, ?): create swapped types +fn create_swapped_types( + type_ref: &Py, + type_str: &str, + vm: &VirtualMachine, +) -> PyResult<()> { + use 
crate::builtins::PyDict; + + // Avoid infinite recursion - if __ctype_le__ already exists, skip + if type_ref.as_object().get_attr("__ctype_le__", vm).is_ok() { + return Ok(()); + } + + // Types that don't support byte order swapping - no __ctype_le__/__ctype_be__ + // c_void_p (P), c_char_p (z), c_wchar_p (Z), c_wchar (u), py_object (O) + let unsupported_types = ["P", "z", "Z", "u", "O"]; + if unsupported_types.contains(&type_str) { + return Ok(()); + } + + // Single-byte types - __ctype_le__ = __ctype_be__ = self (no swapping needed) + // c_char (c), c_byte (b), c_ubyte (B) + let single_byte_types = ["c", "b", "B"]; + if single_byte_types.contains(&type_str) { + type_ref + .as_object() + .set_attr("__ctype_le__", type_ref.as_object().to_owned(), vm)?; + type_ref + .as_object() + .set_attr("__ctype_be__", type_ref.as_object().to_owned(), vm)?; + return Ok(()); + } + + // Multi-byte types - create swapped type + // Check system byte order at compile time + let is_little_endian = cfg!(target_endian = "little"); + + // Create dict for the swapped (non-native) type + let swapped_dict: crate::PyRef = PyDict::default().into_ref(&vm.ctx); + swapped_dict.set_item("_type_", vm.ctx.new_str(type_str).into(), vm)?; + + // Create the swapped type using the same metaclass + let metaclass = type_ref.class(); + let bases = vm.ctx.new_tuple(vec![type_ref.as_object().to_owned()]); + + // Set placeholder first to prevent recursion + type_ref + .as_object() + .set_attr("__ctype_le__", vm.ctx.none(), vm)?; + type_ref + .as_object() + .set_attr("__ctype_be__", vm.ctx.none(), vm)?; + + // Create only the non-native endian type + let suffix = if is_little_endian { "_be" } else { "_le" }; + let swapped_type = metaclass.as_object().call( + ( + vm.ctx.new_str(format!("{}{}", type_ref.name(), suffix)), + bases, + swapped_dict.as_object().to_owned(), + ), + vm, + )?; + + // Set _swappedbytes_ on the swapped type to indicate byte swapping is needed + swapped_type.set_attr("_swappedbytes_", 
vm.ctx.none(), vm)?; + + // Update swapped type's StgInfo format to use opposite endian prefix + // Native uses '<' on little-endian, '>' on big-endian + // Swapped uses the opposite + if let Ok(swapped_type_ref) = swapped_type.clone().downcast::() + && let Some(mut sw_stg) = swapped_type_ref.get_type_data_mut::() + { + let swapped_prefix = if is_little_endian { ">" } else { "<" }; + sw_stg.format = Some(format!("{}{}", swapped_prefix, type_str)); + } + + // Set attributes based on system byte order + // Native endian attribute points to self, non-native points to swapped type + if is_little_endian { + // Little-endian system: __ctype_le__ = self, __ctype_be__ = swapped + type_ref + .as_object() + .set_attr("__ctype_le__", type_ref.as_object().to_owned(), vm)?; + type_ref + .as_object() + .set_attr("__ctype_be__", swapped_type.clone(), vm)?; + swapped_type.set_attr("__ctype_le__", type_ref.as_object().to_owned(), vm)?; + swapped_type.set_attr("__ctype_be__", swapped_type.clone(), vm)?; + } else { + // Big-endian system: __ctype_be__ = self, __ctype_le__ = swapped + type_ref + .as_object() + .set_attr("__ctype_be__", type_ref.as_object().to_owned(), vm)?; + type_ref + .as_object() + .set_attr("__ctype_le__", swapped_type.clone(), vm)?; + swapped_type.set_attr("__ctype_be__", type_ref.as_object().to_owned(), vm)?; + swapped_type.set_attr("__ctype_le__", swapped_type.clone(), vm)?; + } + + Ok(()) +} + +#[pyclass( + module = "_ctypes", + name = "_SimpleCData", + base = PyCData, + metaclass = "PyCSimpleType" +)] +#[repr(transparent)] +pub struct PyCSimple(pub PyCData); + +impl Debug for PyCSimple { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyCSimple") + .field("size", &self.0.buffer.read().len()) + .finish() + } +} + +fn value_to_bytes_endian( + _type_: &str, + value: &PyObject, + swapped: bool, + vm: &VirtualMachine, +) -> Vec { + // Helper macro for endian conversion + macro_rules! 
to_bytes { + ($val:expr) => { + if swapped { + // Use opposite endianness + #[cfg(target_endian = "little")] + { + $val.to_be_bytes().to_vec() + } + #[cfg(target_endian = "big")] + { + $val.to_le_bytes().to_vec() + } + } else { + $val.to_ne_bytes().to_vec() + } + }; + } + + match _type_ { + "c" => { + // c_char - single byte (bytes, bytearray, or int 0-255) + if let Some(bytes) = value.downcast_ref::() + && !bytes.is_empty() + { + return vec![bytes.as_bytes()[0]]; + } + if let Some(bytearray) = value.downcast_ref::() { + let buf = bytearray.borrow_buf(); + if !buf.is_empty() { + return vec![buf[0]]; + } + } + if let Ok(int_val) = value.try_int(vm) + && let Some(v) = int_val.as_bigint().to_u8() + { + return vec![v]; + } + vec![0] + } + "u" => { + // c_wchar - platform-dependent size (2 on Windows, 4 on Unix) + if let Ok(s) = value.str(vm) + && let Some(c) = s.as_str().chars().next() + { + let mut buffer = vec![0u8; WCHAR_SIZE]; + wchar_to_bytes(c as u32, &mut buffer); + if swapped { + buffer.reverse(); + } + return buffer; + } + vec![0; WCHAR_SIZE] + } + "b" => { + // c_byte - signed char (1 byte) + // PyLong_AsLongMask pattern: wrapping for overflow values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().unwrap_or(0) as i8; + return vec![v as u8]; + } + vec![0] + } + "B" => { + // c_ubyte - unsigned char (1 byte) + // PyLong_AsUnsignedLongMask: wrapping for negative values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().map(|n| n as u8).unwrap_or(0); + return vec![v]; + } + vec![0] + } + "h" => { + // c_short (2 bytes) + // PyLong_AsLongMask pattern: wrapping for overflow values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().unwrap_or(0) as i16; + return to_bytes!(v); + } + vec![0; 2] + } + "H" => { + // c_ushort (2 bytes) + // PyLong_AsUnsignedLongMask: wrapping for negative values + if let Ok(int_val) = value.try_index(vm) { + let v = 
int_val.as_bigint().to_i128().map(|n| n as u16).unwrap_or(0); + return to_bytes!(v); + } + vec![0; 2] + } + "i" => { + // c_int (4 bytes) + // PyLong_AsLongMask pattern: wrapping for overflow values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().unwrap_or(0) as i32; + return to_bytes!(v); + } + vec![0; 4] + } + "I" => { + // c_uint (4 bytes) + // PyLong_AsUnsignedLongMask: wrapping for negative values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().map(|n| n as u32).unwrap_or(0); + return to_bytes!(v); + } + vec![0; 4] + } + "l" => { + // c_long (platform dependent) + // PyLong_AsLongMask pattern: wrapping for overflow values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().unwrap_or(0) as libc::c_long; + return to_bytes!(v); + } + const SIZE: usize = std::mem::size_of::(); + vec![0; SIZE] + } + "L" => { + // c_ulong (platform dependent) + // PyLong_AsUnsignedLongMask: wrapping for negative values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val + .as_bigint() + .to_i128() + .map(|n| n as libc::c_ulong) + .unwrap_or(0); + return to_bytes!(v); + } + const SIZE: usize = std::mem::size_of::(); + vec![0; SIZE] + } + "q" => { + // c_longlong (8 bytes) + // PyLong_AsLongMask pattern: wrapping for overflow values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().unwrap_or(0) as i64; + return to_bytes!(v); + } + vec![0; 8] + } + "Q" => { + // c_ulonglong (8 bytes) + // PyLong_AsUnsignedLongLongMask: wrapping for negative values + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_i128().map(|n| n as u64).unwrap_or(0); + return to_bytes!(v); + } + vec![0; 8] + } + "f" => { + // c_float (4 bytes) - also accepts int + if let Ok(float_val) = value.try_float(vm) { + return to_bytes!(float_val.to_f64() as f32); + } + if let Ok(int_val) = value.try_int(vm) + && let Some(v) = 
int_val.as_bigint().to_f64() + { + return to_bytes!(v as f32); + } + vec![0; 4] + } + "d" => { + // c_double (8 bytes) - also accepts int + if let Ok(float_val) = value.try_float(vm) { + return to_bytes!(float_val.to_f64()); + } + if let Ok(int_val) = value.try_int(vm) + && let Some(v) = int_val.as_bigint().to_f64() + { + return to_bytes!(v); + } + vec![0; 8] + } + "g" => { + // long double - platform dependent size + // Store as f64, zero-pad to platform long double size + // Note: This may lose precision on platforms where long double > 64 bits + let f64_val = if let Ok(float_val) = value.try_float(vm) { + float_val.to_f64() + } else if let Ok(int_val) = value.try_int(vm) { + int_val.as_bigint().to_f64().unwrap_or(0.0) + } else { + 0.0 + }; + let f64_bytes = if swapped { + #[cfg(target_endian = "little")] + { + f64_val.to_be_bytes().to_vec() + } + #[cfg(target_endian = "big")] + { + f64_val.to_le_bytes().to_vec() + } + } else { + f64_val.to_ne_bytes().to_vec() + }; + // Pad to long double size + let long_double_size = super::get_size("g"); + let mut result = f64_bytes; + result.resize(long_double_size, 0); + result + } + "?" 
=> { + // c_bool (1 byte) + if let Ok(b) = value.to_owned().try_to_bool(vm) { + return vec![if b { 1 } else { 0 }]; + } + vec![0] + } + "v" => { + // VARIANT_BOOL: True = 0xFFFF (-1 as i16), False = 0x0000 + if let Ok(b) = value.to_owned().try_to_bool(vm) { + let val: i16 = if b { -1 } else { 0 }; + return to_bytes!(val); + } + vec![0; 2] + } + "P" => { + // c_void_p - pointer type (platform pointer size) + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_usize().unwrap_or(0); + return to_bytes!(v); + } + vec![0; std::mem::size_of::()] + } + "z" => { + // c_char_p - pointer to char (stores pointer value from int) + // PyBytes case is handled in slot_new/set_value with make_z_buffer() + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_usize().unwrap_or(0); + return to_bytes!(v); + } + vec![0; std::mem::size_of::()] + } + "Z" => { + // c_wchar_p - pointer to wchar_t (stores pointer value from int) + // PyStr case is handled in slot_new/set_value with make_wchar_buffer() + if let Ok(int_val) = value.try_index(vm) { + let v = int_val.as_bigint().to_usize().unwrap_or(0); + return to_bytes!(v); + } + vec![0; std::mem::size_of::()] + } + "O" => { + // py_object - store object id as non-zero marker + // The actual object is stored in _objects + // Use object's id as a non-zero placeholder (indicates non-NULL) + let id = value.get_id(); + to_bytes!(id) + } + _ => vec![0], + } +} + +/// Check if value is a c_char array or pointer(c_char) +fn is_cchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> bool { + // Check Array with c_char element type + if let Some(arr) = value.downcast_ref::() + && let Some(info) = arr.class().stg_info_opt() + && let Some(ref elem_type) = info.element_type + && let Some(elem_code) = elem_type.class().type_code(vm) + { + return elem_code == "c"; + } + // Check Pointer to c_char + if let Some(ptr) = value.downcast_ref::() + && let Some(info) = ptr.class().stg_info_opt() + && let 
Some(ref proto) = info.proto + && let Some(proto_code) = proto.class().type_code(vm) + { + return proto_code == "c"; + } + false +} + +/// Check if value is a c_wchar array or pointer(c_wchar) +fn is_cwchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> bool { + // Check Array with c_wchar element type + if let Some(arr) = value.downcast_ref::() { + let info = arr.class().stg_info_opt().expect("array has StgInfo"); + let elem_type = info.element_type.as_ref().expect("array has element_type"); + if let Some(elem_code) = elem_type.class().type_code(vm) { + return elem_code == "u"; + } + } + // Check Pointer to c_wchar + if let Some(ptr) = value.downcast_ref::() { + let info = ptr.class().stg_info_opt().expect("pointer has StgInfo"); + if let Some(ref proto) = info.proto + && let Some(proto_code) = proto.class().type_code(vm) + { + return proto_code == "u"; + } + } + false +} + +impl Constructor for PyCSimple { + type Args = (OptionalArg,); + + fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + let args: Self::Args = args.bind(vm)?; + let _type_ = cls + .type_code(vm) + .ok_or_else(|| vm.new_type_error("abstract class"))?; + // Save the initial argument for c_char_p/c_wchar_p _objects + let init_arg = args.0.into_option(); + + // Handle z/Z types with PyBytes/PyStr separately to avoid memory leak + if let Some(ref v) = init_arg { + if _type_ == "z" { + if let Some(bytes) = v.downcast_ref::() { + let (converted, ptr) = super::base::ensure_z_null_terminated(bytes, vm); + let buffer = ptr.to_ne_bytes().to_vec(); + let cdata = PyCData::from_bytes(buffer, Some(converted)); + return PyCSimple(cdata).into_ref_with_type(vm, cls).map(Into::into); + } + } else if _type_ == "Z" + && let Some(s) = v.downcast_ref::() + { + let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_str(), vm); + let buffer = ptr.to_ne_bytes().to_vec(); + let cdata = PyCData::from_bytes(buffer, Some(holder)); + return PyCSimple(cdata).into_ref_with_type(vm, 
cls).map(Into::into); + } + } + + let value = if let Some(ref v) = init_arg { + set_primitive(_type_.as_str(), v, vm)? + } else { + match _type_.as_str() { + "c" | "u" => PyObjectRef::from(vm.ctx.new_bytes(vec![0])), + "b" | "B" | "h" | "H" | "i" | "I" | "l" | "q" | "L" | "Q" => { + PyObjectRef::from(vm.ctx.new_int(0)) + } + "f" | "d" | "g" => PyObjectRef::from(vm.ctx.new_float(0.0)), + "?" => PyObjectRef::from(vm.ctx.new_bool(false)), + _ => vm.ctx.none(), // "z" | "Z" | "P" + } + }; + + // Check if this is a swapped endian type (presence of attribute indicates swapping) + let swapped = cls.as_object().get_attr("_swappedbytes_", vm).is_ok(); + + let buffer = value_to_bytes_endian(&_type_, &value, swapped, vm); + + // For c_char_p (type "z"), c_wchar_p (type "Z"), and py_object (type "O"), + // store the initial value in _objects + let objects = if (_type_ == "z" || _type_ == "Z" || _type_ == "O") && init_arg.is_some() { + init_arg + } else { + None + }; + + PyCSimple(PyCData::from_bytes(buffer, objects)) + .into_ref_with_type(vm, cls) + .map(Into::into) + } + + fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { + unimplemented!("use slot_new") + } +} + +// Simple_repr +impl Representable for PyCSimple { + fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let cls = zelf.class(); + let type_name = cls.name(); + + // Check if base is _SimpleCData (direct simple type like c_int, c_char) + // vs subclass of simple type (like class X(c_int): pass) + let bases = cls.bases.read(); + let is_direct_simple = bases + .iter() + .any(|base| base.name().to_string() == "_SimpleCData"); + + if is_direct_simple { + // Direct SimpleCData: "typename(repr(value))" + let value = PyCSimple::value(zelf.to_owned().into(), vm)?; + let value_repr = value.repr(vm)?.to_string(); + Ok(format!("{}({})", type_name, value_repr)) + } else { + // Subclass: "" + let addr = zelf.get_id(); + Ok(format!("<{} object at {:#x}>", type_name, addr)) + } + } +} + 
+#[pyclass(flags(BASETYPE), with(Constructor, AsBuffer, AsNumber, Representable))] +impl PyCSimple { + #[pygetset] + fn _b0_(&self) -> Option { + self.0.base.read().clone() + } + + /// return True if any byte in buffer is non-zero + #[pymethod] + fn __bool__(&self) -> bool { + let buffer = self.0.buffer.read(); + // Simple_bool: memcmp(self->b_ptr, zeros, self->b_size) + buffer.iter().any(|&b| b != 0) + } + + #[pygetset] + pub fn value(instance: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let zelf: &Py = instance + .downcast_ref() + .ok_or_else(|| vm.new_type_error("cannot get value of instance"))?; + + // Get _type_ from class + let cls = zelf.class(); + let type_attr = cls + .as_object() + .get_attr("_type_", vm) + .map_err(|_| vm.new_type_error("no _type_ attribute"))?; + let type_code = type_attr.str(vm)?.to_string(); + + // Special handling for c_char_p (z) and c_wchar_p (Z) + // z_get, Z_get - dereference pointer to get string + if type_code == "z" { + // c_char_p: read pointer from buffer, dereference to get bytes string + let buffer = zelf.0.buffer.read(); + let ptr = super::base::read_ptr_from_buffer(&buffer); + if ptr == 0 { + return Ok(vm.ctx.none()); + } + // Read null-terminated string at the address + unsafe { + let cstr = std::ffi::CStr::from_ptr(ptr as _); + return Ok(vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into()); + } + } + if type_code == "Z" { + // c_wchar_p: read pointer from buffer, dereference to get wide string + let buffer = zelf.0.buffer.read(); + let ptr = super::base::read_ptr_from_buffer(&buffer); + if ptr == 0 { + return Ok(vm.ctx.none()); + } + // Read null-terminated wide string at the address + unsafe { + let w_ptr = ptr as *const libc::wchar_t; + let len = libc::wcslen(w_ptr); + let wchars = std::slice::from_raw_parts(w_ptr, len); + let s: String = wchars + .iter() + .filter_map(|&c| char::from_u32(c as u32)) + .collect(); + return Ok(vm.ctx.new_str(s).into()); + } + } + + // O_get: py_object - read PyObject pointer from 
buffer + if type_code == "O" { + let buffer = zelf.0.buffer.read(); + let ptr = super::base::read_ptr_from_buffer(&buffer); + if ptr == 0 { + return Err(vm.new_value_error("PyObject is NULL")); + } + // Non-NULL: return stored object from _objects if available + if let Some(obj) = zelf.0.objects.read().as_ref() { + return Ok(obj.clone()); + } + return Err(vm.new_value_error("PyObject is NULL")); + } + + // Check if this is a swapped endian type (presence of attribute indicates swapping) + let swapped = cls.as_object().get_attr("_swappedbytes_", vm).is_ok(); + + // Read value from buffer, swap bytes if needed + let buffer = zelf.0.buffer.read(); + let buffer_data: std::borrow::Cow<'_, [u8]> = if swapped { + // Reverse bytes for swapped endian types + let mut swapped_bytes = buffer.to_vec(); + swapped_bytes.reverse(); + std::borrow::Cow::Owned(swapped_bytes) + } else { + std::borrow::Cow::Borrowed(&*buffer) + }; + + let cls_ref = cls.to_owned(); + bytes_to_pyobject(&cls_ref, &buffer_data, vm).or_else(|_| { + // Fallback: return bytes as integer based on type + match type_code.as_str() { + "c" => { + if !buffer.is_empty() { + Ok(vm.ctx.new_bytes(vec![buffer[0]]).into()) + } else { + Ok(vm.ctx.new_bytes(vec![0]).into()) + } + } + "?" 
=> { + let val = buffer.first().copied().unwrap_or(0); + Ok(vm.ctx.new_bool(val != 0).into()) + } + _ => Ok(vm.ctx.new_int(0).into()), + } + }) + } + + #[pygetset(setter)] + fn set_value(instance: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let zelf: PyRef = instance + .clone() + .downcast() + .map_err(|_| vm.new_type_error("cannot set value of instance"))?; + + // Get _type_ from class + let cls = zelf.class(); + let type_attr = cls + .as_object() + .get_attr("_type_", vm) + .map_err(|_| vm.new_type_error("no _type_ attribute"))?; + let type_code = type_attr.str(vm)?.to_string(); + + // Handle z/Z types with PyBytes/PyStr separately to avoid memory leak + if type_code == "z" { + if let Some(bytes) = value.downcast_ref::() { + let (converted, ptr) = super::base::ensure_z_null_terminated(bytes, vm); + *zelf.0.buffer.write() = std::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); + *zelf.0.objects.write() = Some(converted); + return Ok(()); + } + } else if type_code == "Z" + && let Some(s) = value.downcast_ref::() + { + let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_str(), vm); + *zelf.0.buffer.write() = std::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); + *zelf.0.objects.write() = Some(holder); + return Ok(()); + } + + let content = set_primitive(&type_code, &value, vm)?; + + // Check if this is a swapped endian type (presence of attribute indicates swapping) + let swapped = instance + .class() + .as_object() + .get_attr("_swappedbytes_", vm) + .is_ok(); + + // Update buffer when value changes + let buffer_bytes = value_to_bytes_endian(&type_code, &content, swapped, vm); + *zelf.0.buffer.write() = std::borrow::Cow::Owned(buffer_bytes); + + // For c_char_p (type "z"), c_wchar_p (type "Z"), and py_object (type "O"), + // keep the reference in _objects + if type_code == "z" || type_code == "Z" || type_code == "O" { + *zelf.0.objects.write() = Some(value); + } + Ok(()) + } + + #[pyclassmethod] + fn repeat(cls: PyTypeRef, n: 
isize, vm: &VirtualMachine) -> PyResult { + use super::array::array_type_from_ctype; + + if n < 0 { + return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); + } + // Use cached array type creation + array_type_from_ctype(cls.into(), n as usize, vm) + } + + /// Simple_from_outparm - convert output parameter back to Python value + /// For direct subclasses of _SimpleCData (e.g., c_int), returns the value. + /// For subclasses of those (e.g., class MyInt(c_int)), returns self. + #[pymethod] + fn __ctypes_from_outparam__(zelf: PyRef, vm: &VirtualMachine) -> PyResult { + // _ctypes_simple_instance: returns true if NOT a direct subclass of Simple_Type + // i.e., c_int (direct) -> false, MyInt(c_int) (subclass) -> true + let is_subclass_of_simple = { + let cls = zelf.class(); + let bases = cls.bases.read(); + // If base is NOT _SimpleCData, then it's a subclass of a subclass + !bases + .iter() + .any(|base| base.name().to_string() == "_SimpleCData") + }; + + if is_subclass_of_simple { + // Subclass of simple type (e.g., MyInt(c_int)): return self + Ok(zelf.into()) + } else { + // Direct simple type (e.g., c_int): return value + PyCSimple::value(zelf.into(), vm) + } + } +} + +impl PyCSimple { + /// Extract the value from this ctypes object as an owned FfiArgValue. + /// The value must be kept alive until after the FFI call completes. 
+ pub fn to_ffi_value( + &self, + ty: libffi::middle::Type, + _vm: &VirtualMachine, + ) -> Option { + let buffer = self.0.buffer.read(); + let bytes: &[u8] = &buffer; + + if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u8().as_raw_ptr()) { + if !bytes.is_empty() { + return Some(FfiArgValue::U8(bytes[0])); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i8().as_raw_ptr()) { + if !bytes.is_empty() { + return Some(FfiArgValue::I8(bytes[0] as i8)); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u16().as_raw_ptr()) { + if bytes.len() >= 2 { + return Some(FfiArgValue::U16(u16::from_ne_bytes([bytes[0], bytes[1]]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i16().as_raw_ptr()) { + if bytes.len() >= 2 { + return Some(FfiArgValue::I16(i16::from_ne_bytes([bytes[0], bytes[1]]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u32().as_raw_ptr()) { + if bytes.len() >= 4 { + return Some(FfiArgValue::U32(u32::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], + ]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i32().as_raw_ptr()) { + if bytes.len() >= 4 { + return Some(FfiArgValue::I32(i32::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], + ]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u64().as_raw_ptr()) { + if bytes.len() >= 8 { + return Some(FfiArgValue::U64(u64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i64().as_raw_ptr()) { + if bytes.len() >= 8 { + return Some(FfiArgValue::I64(i64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f32().as_raw_ptr()) { + if bytes.len() >= 4 { + return Some(FfiArgValue::F32(f32::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], 
bytes[3], + ]))); + } + } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f64().as_raw_ptr()) { + if bytes.len() >= 8 { + return Some(FfiArgValue::F64(f64::from_ne_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]))); + } + } else if std::ptr::eq( + ty.as_raw_ptr(), + libffi::middle::Type::pointer().as_raw_ptr(), + ) && bytes.len() >= std::mem::size_of::() + { + let val = + usize::from_ne_bytes(bytes[..std::mem::size_of::()].try_into().unwrap()); + return Some(FfiArgValue::Pointer(val)); + } + None + } +} + +impl AsBuffer for PyCSimple { + fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + let buffer_len = zelf.0.buffer.read().len(); + let buf = PyBuffer::new( + zelf.to_owned().into(), + BufferDescriptor::simple(buffer_len, false), // readonly=false for ctypes + &CDATA_BUFFER_METHODS, + ); + Ok(buf) + } +} + +/// Simple_bool: return non-zero if any byte in buffer is non-zero +impl AsNumber for PyCSimple { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + boolean: Some(|obj, _vm| { + let zelf = obj + .downcast_ref::() + .expect("PyCSimple::as_number called on non-PyCSimple"); + let buffer = zelf.0.buffer.read(); + // Simple_bool: memcmp(self->b_ptr, zeros, self->b_size) + // Returns true if any byte is non-zero + Ok(buffer.iter().any(|&b| b != 0)) + }), + ..PyNumberMethods::NOT_IMPLEMENTED + }; + &AS_NUMBER + } +} diff --git a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/ctypes/structure.rs index ca67a2fe7d6..10b8812e42c 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ b/crates/vm/src/stdlib/ctypes/structure.rs @@ -1,42 +1,60 @@ -use super::base::{CDataObject, PyCData}; -use super::field::PyCField; -use super::util::StgInfo; +use super::base::{CDATA_BUFFER_METHODS, PyCData, PyCField, StgInfo, StgInfoFlags}; use crate::builtins::{PyList, PyStr, PyTuple, PyType, PyTypeRef}; use crate::convert::ToPyObject; use 
crate::function::FuncArgs; -use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods}; -use crate::stdlib::ctypes::_ctypes::get_size; -use crate::types::{AsBuffer, AsNumber, Constructor}; -use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; -use indexmap::IndexMap; +use crate::function::PySetterValue; +use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; +use crate::types::{AsBuffer, AsNumber, Constructor, Initializer, SetAttr}; +use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; use num_traits::ToPrimitive; -use rustpython_common::lock::PyRwLock; +use std::borrow::Cow; use std::fmt::Debug; +/// Calculate Structure type size from _fields_ (sum of field sizes) +pub(super) fn calculate_struct_size(cls: &Py, vm: &VirtualMachine) -> PyResult { + if let Ok(fields_attr) = cls.as_object().get_attr("_fields_", vm) { + let fields: Vec = fields_attr.try_to_value(vm)?; + let mut total_size = 0usize; + + for field in fields.iter() { + if let Some(tuple) = field.downcast_ref::() + && let Some(field_type) = tuple.get(1) + { + total_size += super::_ctypes::sizeof(field_type.clone(), vm)?; + } + } + return Ok(total_size); + } + Ok(0) +} + /// PyCStructType - metaclass for Structure #[pyclass(name = "PyCStructType", base = PyType, module = "_ctypes")] #[derive(Debug)] #[repr(transparent)] -pub struct PyCStructType(PyType); +pub(super) struct PyCStructType(PyType); impl Constructor for PyCStructType { type Args = FuncArgs; fn slot_new(metatype: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // 1. Create the new class using PyType::py_new + // 1. Create the new class using PyType::slot_new let new_class = crate::builtins::type_::PyType::slot_new(metatype, args, vm)?; - // 2. Process _fields_ if defined on the new class + // 2. 
Get the new type let new_type = new_class .clone() .downcast::() .map_err(|_| vm.new_type_error("expected type"))?; - // Only process _fields_ if defined directly on this class (not inherited) - if let Some(fields_attr) = new_type.get_direct_attr(vm.ctx.intern_str("_fields_")) { - Self::process_fields(&new_type, fields_attr, vm)?; - } + // 3. Mark base classes as finalized (subclassing finalizes the parent) + new_type.mark_bases_final(); + + // 4. Initialize StgInfo for the new type (initialized=false, to be set in init) + let stg_info = StgInfo::default(); + let _ = new_type.init_type_data(stg_info); + // Note: _fields_ processing moved to Initializer::init() Ok(new_class) } @@ -45,11 +63,102 @@ impl Constructor for PyCStructType { } } -#[pyclass(flags(BASETYPE), with(AsNumber, Constructor))] +impl Initializer for PyCStructType { + type Args = FuncArgs; + + fn init(zelf: crate::PyRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + // Get the type as PyTypeRef by converting PyRef -> PyObjectRef -> PyRef + let obj: PyObjectRef = zelf.clone().into(); + let new_type: PyTypeRef = obj + .downcast() + .map_err(|_| vm.new_type_error("expected type"))?; + + // Backward compatibility: skip initialization for abstract types + if new_type + .get_direct_attr(vm.ctx.intern_str("_abstract_")) + .is_some() + { + return Ok(()); + } + + new_type.check_not_initialized(vm)?; + + // Process _fields_ if defined directly on this class (not inherited) + if let Some(fields_attr) = new_type.get_direct_attr(vm.ctx.intern_str("_fields_")) { + Self::process_fields(&new_type, fields_attr, vm)?; + } else { + // No _fields_ defined - try to copy from base class (PyCStgInfo_clone) + let (has_base_info, base_clone) = { + let bases = new_type.bases.read(); + if let Some(base) = bases.first() { + (base.stg_info_opt().is_some(), Some(base.clone())) + } else { + (false, None) + } + }; + + if has_base_info && let Some(ref base) = base_clone { + // Clone base StgInfo (release guard before 
getting mutable reference) + let stg_info_opt = base.stg_info_opt().map(|baseinfo| { + let mut stg_info = baseinfo.clone(); + stg_info.flags &= !StgInfoFlags::DICTFLAG_FINAL; // Clear FINAL in subclass + stg_info.initialized = true; + stg_info + }); + + if let Some(stg_info) = stg_info_opt { + // Mark base as FINAL (now guard is released) + if let Some(mut base_stg) = base.get_type_data_mut::() { + base_stg.flags |= StgInfoFlags::DICTFLAG_FINAL; + } + + super::base::set_or_init_stginfo(&new_type, stg_info); + return Ok(()); + } + } + + // No base StgInfo - create default + let mut stg_info = StgInfo::new(0, 1); + stg_info.paramfunc = super::base::ParamFunc::Structure; + stg_info.format = Some("B".to_string()); + super::base::set_or_init_stginfo(&new_type, stg_info); + } + + Ok(()) + } +} + +#[pyclass(flags(BASETYPE), with(AsNumber, Constructor, Initializer, SetAttr))] impl PyCStructType { + #[pymethod] + fn from_param(zelf: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // zelf is the structure type class that from_param was called on + let cls = zelf + .downcast::() + .map_err(|_| vm.new_type_error("from_param: expected a type"))?; + + // 1. If already an instance of the requested type, return it + if value.is_instance(cls.as_object(), vm)? { + return Ok(value); + } + + // 2. 
Check for _as_parameter_ attribute + if let Ok(as_parameter) = value.get_attr("_as_parameter_", vm) { + return PyCStructType::from_param(cls.as_object().to_owned(), as_parameter, vm); + } + + Err(vm.new_type_error(format!( + "expected {} instance instead of {}", + cls.name(), + value.class().name() + ))) + } + /// Called when a new Structure subclass is created #[pyclassmethod] fn __init_subclass__(cls: PyTypeRef, vm: &VirtualMachine) -> PyResult<()> { + cls.mark_bases_final(); + // Check if _fields_ is defined if let Some(fields_attr) = cls.get_direct_attr(vm.ctx.intern_str("_fields_")) { Self::process_fields(&cls, fields_attr, vm)?; @@ -59,24 +168,63 @@ impl PyCStructType { /// Process _fields_ and create CField descriptors fn process_fields( - cls: &PyTypeRef, + cls: &Py, fields_attr: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { + // Check if this is a swapped byte order structure + let is_swapped = cls.as_object().get_attr("_swappedbytes_", vm).is_ok(); + // Try to downcast to list or tuple let fields: Vec = if let Some(list) = fields_attr.downcast_ref::() { list.borrow_vec().to_vec() } else if let Some(tuple) = fields_attr.downcast_ref::() { tuple.to_vec() } else { - return Err(vm.new_type_error("_fields_ must be a list or tuple".to_string())); + return Err(vm.new_type_error("_fields_ must be a list or tuple")); + }; + + let pack = super::base::get_usize_attr(cls.as_object(), "_pack_", 0, vm)?; + let forced_alignment = + super::base::get_usize_attr(cls.as_object(), "_align_", 1, vm)?.max(1); + + // Determine byte order for format string + let big_endian = super::base::is_big_endian(is_swapped); + + // Initialize offset, alignment, type flags, and ffi_field_types from base class + let ( + mut offset, + mut max_align, + mut has_pointer, + mut has_union, + mut has_bitfield, + mut ffi_field_types, + ) = { + let bases = cls.bases.read(); + if let Some(base) = bases.first() + && let Some(baseinfo) = base.stg_info_opt() + { + ( + baseinfo.size, + 
std::cmp::max(baseinfo.align, forced_alignment), + baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASPOINTER), + baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASUNION), + baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASBITFIELD), + baseinfo.ffi_field_types.clone(), + ) + } else { + (0, forced_alignment, false, false, false, Vec::new()) + } }; - let mut offset = 0usize; + // Initialize PEP3118 format string + let mut format = String::from("T{"); + let mut last_end = 0usize; // Track end of last field for padding calculation + for (index, field) in fields.iter().enumerate() { let field_tuple = field .downcast_ref::() - .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples".to_string()))?; + .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples"))?; if field_tuple.len() < 2 { return Err(vm.new_type_error( @@ -86,99 +234,173 @@ impl PyCStructType { let name = field_tuple .first() - .unwrap() + .expect("len checked") .downcast_ref::() - .ok_or_else(|| vm.new_type_error("field name must be a string".to_string()))? + .ok_or_else(|| vm.new_type_error("field name must be a string"))? 
.to_string(); - let field_type = field_tuple.get(1).unwrap().clone(); + let field_type = field_tuple.get(1).expect("len checked").clone(); + + // For swapped byte order structures, validate field type supports byte swapping + if is_swapped { + super::base::check_other_endian_support(&field_type, vm)?; + } + + // Get size and alignment of the field type + let size = super::base::get_field_size(&field_type, vm)?; + let field_align = super::base::get_field_align(&field_type, vm); + + // Calculate effective alignment (PyCField_FromDesc) + let effective_align = if pack > 0 { + std::cmp::min(pack, field_align) + } else { + field_align + }; + + // Apply padding to align offset (cfield.c NO_BITFIELD case) + if effective_align > 0 && offset % effective_align != 0 { + let delta = effective_align - (offset % effective_align); + offset += delta; + } + + max_align = max_align.max(effective_align); + + // Propagate type flags from field type (HASPOINTER, HASUNION, HASBITFIELD) + if let Some(type_obj) = field_type.downcast_ref::() + && let Some(field_stg) = type_obj.stg_info_opt() + { + // HASPOINTER: propagate if field is pointer or contains pointer + if field_stg.flags.intersects( + StgInfoFlags::TYPEFLAG_ISPOINTER | StgInfoFlags::TYPEFLAG_HASPOINTER, + ) { + has_pointer = true; + } + // HASUNION, HASBITFIELD: propagate directly + if field_stg.flags.contains(StgInfoFlags::TYPEFLAG_HASUNION) { + has_union = true; + } + if field_stg.flags.contains(StgInfoFlags::TYPEFLAG_HASBITFIELD) { + has_bitfield = true; + } + // Collect FFI type for this field + ffi_field_types.push(field_stg.to_ffi_type()); + } + + // Mark field type as finalized (using type as field finalizes it) + if let Some(type_obj) = field_type.downcast_ref::() { + if let Some(mut stg_info) = type_obj.get_type_data_mut::() { + stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL; + } else { + // Create StgInfo with FINAL flag if it doesn't exist + let mut stg_info = StgInfo::new(size, field_align); + stg_info.flags |= 
StgInfoFlags::DICTFLAG_FINAL; + let _ = type_obj.init_type_data(stg_info); + } + } + + // Build format string: add padding before field + let padding = offset - last_end; + if padding > 0 { + if padding != 1 { + format.push_str(&padding.to_string()); + } + format.push('x'); + } + + // Get field format and add to format string + let field_format = super::base::get_field_format(&field_type, big_endian, vm); + + // Handle arrays: prepend shape + if let Some(type_obj) = field_type.downcast_ref::() + && let Some(field_stg) = type_obj.stg_info_opt() + && !field_stg.shape.is_empty() + { + let shape_str = field_stg + .shape + .iter() + .map(|d| d.to_string()) + .collect::>() + .join(","); + format.push_str(&std::format!("({}){}", shape_str, field_format)); + } else { + format.push_str(&field_format); + } - // Get size of the field type - let size = Self::get_field_size(&field_type, vm)?; + // Add field name + format.push(':'); + format.push_str(&name); + format.push(':'); - // Create CField descriptor (accepts any ctypes type including arrays) - let c_field = PyCField::new(name.clone(), field_type, offset, size, index); + // Create CField descriptor with padding-adjusted offset + let field_type_ref = field_type + .clone() + .downcast::() + .map_err(|_| vm.new_type_error("_fields_ type must be a ctypes type"))?; + let c_field = PyCField::new(field_type_ref, offset as isize, size as isize, index); // Set the CField as a class attribute - cls.set_attr(vm.ctx.intern_str(name), c_field.to_pyobject(vm)); + cls.set_attr(vm.ctx.intern_str(name.clone()), c_field.to_pyobject(vm)); + // Update tracking + last_end = offset + size; offset += size; } - Ok(()) - } + // Calculate total_align = max(max_align, forced_alignment) + let total_align = std::cmp::max(max_align, forced_alignment); - /// Get the size of a ctypes type - fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> PyResult { - // Try to get _type_ attribute for simple types - if let Some(size) = field_type - 
.get_attr("_type_", vm) - .ok() - .and_then(|type_attr| type_attr.str(vm).ok()) - .and_then(|type_str| { - let s = type_str.to_string(); - (s.len() == 1).then(|| get_size(&s)) - }) - { - return Ok(size); - } + // Calculate aligned_size (PyCStructUnionType_update_stginfo) + let aligned_size = if total_align > 0 { + offset.div_ceil(total_align) * total_align + } else { + offset + }; - // Try sizeof for other types - if let Some(s) = field_type - .get_attr("size_of_instances", vm) - .ok() - .and_then(|size_method| size_method.call((), vm).ok()) - .and_then(|size| size.try_int(vm).ok()) - .and_then(|n| n.as_bigint().to_usize()) - { - return Ok(s); + // Complete format string: add final padding and close + let final_padding = aligned_size - last_end; + if final_padding > 0 { + if final_padding != 1 { + format.push_str(&final_padding.to_string()); + } + format.push('x'); + } + format.push('}'); + + // Store StgInfo with aligned size and total alignment + let mut stg_info = StgInfo::new(aligned_size, total_align); + stg_info.format = Some(format); + stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL; // Mark as finalized + if has_pointer { + stg_info.flags |= StgInfoFlags::TYPEFLAG_HASPOINTER; + } + if has_union { + stg_info.flags |= StgInfoFlags::TYPEFLAG_HASUNION; } + if has_bitfield { + stg_info.flags |= StgInfoFlags::TYPEFLAG_HASBITFIELD; + } + stg_info.paramfunc = super::base::ParamFunc::Structure; + // Set byte order: swap if _swappedbytes_ is defined + stg_info.big_endian = super::base::is_big_endian(is_swapped); + // Store FFI field types for structure passing + stg_info.ffi_field_types = ffi_field_types; + super::base::set_or_init_stginfo(cls, stg_info); - // Default to pointer size for unknown types - Ok(std::mem::size_of::()) - } + // Process _anonymous_ fields + super::base::make_anon_fields(cls, vm)?; - /// Get the alignment of a ctypes type - fn get_field_align(field_type: &PyObject, vm: &VirtualMachine) -> usize { - // Try to get _type_ attribute for simple 
types - if let Some(align) = field_type - .get_attr("_type_", vm) - .ok() - .and_then(|type_attr| type_attr.str(vm).ok()) - .and_then(|type_str| { - let s = type_str.to_string(); - (s.len() == 1).then(|| get_size(&s)) // alignment == size for simple types - }) - { - return align; - } - // Default alignment - 1 + Ok(()) } #[pymethod] fn __mul__(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult { - use super::array::create_array_type_with_stg_info; - use crate::stdlib::ctypes::_ctypes::size_of; + use super::array::array_type_from_ctype; if n < 0 { return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}"))); } - - // Calculate element size from the Structure type - let element_size = size_of(cls.clone().into(), vm)?; - - let total_size = element_size - .checked_mul(n as usize) - .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; - let stg_info = super::util::StgInfo::new_array( - total_size, - element_size, - n as usize, - cls.clone().into(), - element_size, - ); - create_array_type_with_stg_info(stg_info, vm) + // Use cached array type creation + array_type_from_ctype(cls.into(), n as usize, vm) } } @@ -188,12 +410,12 @@ impl AsNumber for PyCStructType { multiply: Some(|a, b, vm| { let cls = a .downcast_ref::() - .ok_or_else(|| vm.new_type_error("expected type".to_owned()))?; + .ok_or_else(|| vm.new_type_error("expected type"))?; let n = b .try_index(vm)? 
.as_bigint() .to_isize() - .ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?; + .ok_or_else(|| vm.new_overflow_error("array size too large"))?; PyCStructType::__mul__(cls.to_owned(), n, vm) }), ..PyNumberMethods::NOT_IMPLEMENTED @@ -202,14 +424,70 @@ impl AsNumber for PyCStructType { } } -/// Structure field info stored in instance -#[allow(dead_code)] -#[derive(Debug, Clone)] -pub struct FieldInfo { - pub name: String, - pub offset: usize, - pub size: usize, - pub type_ref: PyTypeRef, +impl SetAttr for PyCStructType { + fn setattro( + zelf: &Py, + attr_name: &Py, + value: PySetterValue, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Check if _fields_ is being set + if attr_name.as_str() == "_fields_" { + let pytype: &Py = zelf.to_base(); + + // Check finalization in separate scope to release read lock before process_fields + // This prevents deadlock: process_fields needs write lock on the same RwLock + let is_final = { + let Some(stg_info) = pytype.get_type_data::() else { + return Err(vm.new_type_error("ctypes state is not initialized")); + }; + stg_info.is_final() + }; // Read lock released here + + if is_final { + return Err(vm.new_attribute_error("_fields_ is final")); + } + + // Process _fields_ and set attribute + let PySetterValue::Assign(fields_value) = value else { + return Err(vm.new_attribute_error("cannot delete _fields_")); + }; + // Process fields (this will also set DICTFLAG_FINAL) + PyCStructType::process_fields(pytype, fields_value.clone(), vm)?; + // Set the _fields_ attribute on the type + pytype + .attributes + .write() + .insert(vm.ctx.intern_str("_fields_"), fields_value); + return Ok(()); + } + // Delegate to PyType's setattro logic for type attributes + let attr_name_interned = vm.ctx.intern_str(attr_name.as_str()); + let pytype: &Py = zelf.to_base(); + + // Check for data descriptor first + if let Some(attr) = pytype.get_class_attr(attr_name_interned) { + let descr_set = attr.class().mro_find_map(|cls| 
cls.slots.descr_set.load()); + if let Some(descriptor) = descr_set { + return descriptor(&attr, pytype.to_owned().into(), value, vm); + } + } + + // Store in type's attributes dict + if let PySetterValue::Assign(value) = value { + pytype.attributes.write().insert(attr_name_interned, value); + } else { + let prev = pytype.attributes.write().shift_remove(attr_name_interned); + if prev.is_none() { + return Err(vm.new_attribute_error(format!( + "type object '{}' has no attribute '{}'", + pytype.name(), + attr_name.as_str(), + ))); + } + } + Ok(()) + } } /// PyCStructure - base class for Structure instances @@ -219,19 +497,13 @@ pub struct FieldInfo { base = PyCData, metaclass = "PyCStructType" )] -pub struct PyCStructure { - _base: PyCData, - /// Common CDataObject for memory buffer - pub(super) cdata: PyRwLock, - /// Field information (name -> FieldInfo) - #[allow(dead_code)] - pub(super) fields: PyRwLock>, -} +#[repr(transparent)] +pub struct PyCStructure(pub PyCData); impl Debug for PyCStructure { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PyCStructure") - .field("size", &self.cdata.read().size()) + .field("size", &self.0.size()) .finish() } } @@ -240,13 +512,22 @@ impl Constructor for PyCStructure { type Args = FuncArgs; fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + // Check for abstract class and extract values in a block to drop the borrow + let (total_size, total_align, length) = { + let stg_info = cls.stg_info(vm)?; + (stg_info.size, stg_info.align, stg_info.length) + }; + + // Mark the class as finalized (instance creation finalizes the type) + if let Some(mut stg_info_mut) = cls.get_type_data_mut::() { + stg_info_mut.flags |= StgInfoFlags::DICTFLAG_FINAL; + } + // Get _fields_ from the class using get_attr to properly search MRO let fields_attr = cls.as_object().get_attr("_fields_", vm).ok(); - let mut fields_map = IndexMap::new(); - let mut total_size = 0usize; - let mut max_align = 
1usize; - + // Collect field names for initialization + let mut field_names: Vec = Vec::new(); if let Some(fields_attr) = fields_attr { let fields: Vec = if let Some(list) = fields_attr.downcast_ref::() { @@ -257,7 +538,6 @@ impl Constructor for PyCStructure { vec![] }; - let mut offset = 0usize; for field in fields.iter() { let Some(field_tuple) = field.downcast_ref::() else { continue; @@ -265,43 +545,21 @@ impl Constructor for PyCStructure { if field_tuple.len() < 2 { continue; } - let Some(name) = field_tuple.first().unwrap().downcast_ref::() else { - continue; - }; - let name = name.to_string(); - let field_type = field_tuple.get(1).unwrap().clone(); - let size = PyCStructType::get_field_size(&field_type, vm)?; - let field_align = PyCStructType::get_field_align(&field_type, vm); - max_align = max_align.max(field_align); - - let type_ref = field_type - .downcast::() - .unwrap_or_else(|_| vm.ctx.types.object_type.to_owned()); - - fields_map.insert( - name.clone(), - FieldInfo { - name, - offset, - size, - type_ref, - }, - ); - - offset += size; + if let Some(name) = field_tuple.first().unwrap().downcast_ref::() { + field_names.push(name.to_string()); + } } - total_size = offset; } - // Initialize buffer with zeros - let mut stg_info = StgInfo::new(total_size, max_align); - stg_info.length = fields_map.len(); - let cdata = CDataObject::from_stg_info(&stg_info); - let instance = PyCStructure { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - fields: PyRwLock::new(fields_map.clone()), + // Initialize buffer with zeros using computed size + let mut stg_info = StgInfo::new(total_size, total_align); + stg_info.length = if length > 0 { + length + } else { + field_names.len() }; + stg_info.paramfunc = super::base::ParamFunc::Structure; + let instance = PyCStructure(PyCData::from_stg_info(&stg_info)); // Handle keyword arguments for field initialization let py_instance = instance.into_ref_with_type(vm, cls.clone())?; @@ -309,21 +567,21 @@ impl 
Constructor for PyCStructure { // Set field values from kwargs using standard attribute setting for (key, value) in args.kwargs.iter() { - if fields_map.contains_key(key.as_str()) { + if field_names.iter().any(|n| n == key.as_str()) { py_obj.set_attr(vm.ctx.intern_str(key.as_str()), value.clone(), vm)?; } } // Set field values from positional args - let field_names: Vec = fields_map.keys().cloned().collect(); + if args.args.len() > field_names.len() { + return Err(vm.new_type_error("too many initializers".to_string())); + } for (i, value) in args.args.iter().enumerate() { - if i < field_names.len() { - py_obj.set_attr( - vm.ctx.intern_str(field_names[i].as_str()), - value.clone(), - vm, - )?; - } + py_obj.set_attr( + vm.ctx.intern_str(field_names[i].as_str()), + value.clone(), + vm, + )?; } Ok(py_instance.into()) @@ -337,11 +595,11 @@ impl Constructor for PyCStructure { // Note: GetAttr and SetAttr are not implemented here. // Field access is handled by CField descriptors registered on the class. 
-#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor))] +#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor, AsBuffer))] impl PyCStructure { #[pygetset] - fn _objects(&self) -> Option { - self.cdata.read().objects.clone() + fn _b0_(&self) -> Option { + self.0.base.read().clone() } #[pygetset] @@ -349,165 +607,30 @@ impl PyCStructure { // Return the _fields_ from the class, not instance vm.ctx.none() } - - #[pyclassmethod] - fn from_address(cls: PyTypeRef, address: isize, vm: &VirtualMachine) -> PyResult { - use crate::stdlib::ctypes::_ctypes::size_of; - - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; - - // Read data from address - if address == 0 || size == 0 { - return Err(vm.new_value_error("NULL pointer access".to_owned())); - } - let data = unsafe { - let ptr = address as *const u8; - std::slice::from_raw_parts(ptr, size).to_vec() - }; - - // Create instance - let cdata = CDataObject::from_bytes(data, None); - Ok(PyCStructure { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - fields: PyRwLock::new(IndexMap::new()), - } - .into_ref_with_type(vm, cls)? 
- .into()) - } - - #[pyclassmethod] - fn from_buffer( - cls: PyTypeRef, - source: PyObjectRef, - offset: crate::function::OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use crate::TryFromObject; - use crate::protocol::PyBuffer; - use crate::stdlib::ctypes::_ctypes::size_of; - - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; - - // Get buffer from source - let buffer = PyBuffer::try_from_object(vm, source.clone())?; - - // Check if buffer is writable - if buffer.desc.readonly { - return Err(vm.new_type_error("underlying buffer is not writable".to_owned())); - } - - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; - - // Check if buffer is large enough - let buffer_len = buffer.desc.len; - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); - } - - // Read bytes from buffer at offset - let bytes = buffer.obj_bytes(); - let data = bytes[offset..offset + size].to_vec(); - - // Create instance - let cdata = CDataObject::from_bytes(data, Some(source)); - Ok(PyCStructure { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - fields: PyRwLock::new(IndexMap::new()), - } - .into_ref_with_type(vm, cls)? 
- .into()) - } - - #[pyclassmethod] - fn from_buffer_copy( - cls: PyTypeRef, - source: crate::function::ArgBytesLike, - offset: crate::function::OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use crate::stdlib::ctypes::_ctypes::size_of; - - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; - - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; - - // Borrow bytes from source - let source_bytes = source.borrow_buf(); - let buffer_len = source_bytes.len(); - - // Check if buffer is large enough - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); - } - - // Copy bytes from buffer at offset - let data = source_bytes[offset..offset + size].to_vec(); - - // Create instance - let cdata = CDataObject::from_bytes(data, None); - Ok(PyCStructure { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - fields: PyRwLock::new(IndexMap::new()), - } - .into_ref_with_type(vm, cls)? 
- .into()) - } } -static STRUCTURE_BUFFER_METHODS: BufferMethods = BufferMethods { - obj_bytes: |buffer| { - rustpython_common::lock::PyMappedRwLockReadGuard::map( - rustpython_common::lock::PyRwLockReadGuard::map( - buffer.obj_as::().cdata.read(), - |x: &CDataObject| x, - ), - |x: &CDataObject| x.buffer.as_slice(), - ) - .into() - }, - obj_bytes_mut: |buffer| { - rustpython_common::lock::PyMappedRwLockWriteGuard::map( - rustpython_common::lock::PyRwLockWriteGuard::map( - buffer.obj_as::().cdata.write(), - |x: &mut CDataObject| x, - ), - |x: &mut CDataObject| x.buffer.as_mut_slice(), - ) - .into() - }, - release: |_| {}, - retain: |_| {}, -}; - impl AsBuffer for PyCStructure { fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { - let buffer_len = zelf.cdata.read().buffer.len(); + let buffer_len = zelf.0.buffer.read().len(); + + // PyCData_NewGetBuffer: use info->format if available, otherwise "B" + let format = zelf + .class() + .stg_info_opt() + .and_then(|info| info.format.clone()) + .unwrap_or_else(|| "B".to_string()); + + // Structure: ndim=0, shape=(), itemsize=struct_size let buf = PyBuffer::new( zelf.to_owned().into(), - BufferDescriptor::simple(buffer_len, false), // readonly=false for ctypes - &STRUCTURE_BUFFER_METHODS, + BufferDescriptor { + len: buffer_len, + readonly: false, + itemsize: buffer_len, + format: Cow::Owned(format), + dim_desc: vec![], // ndim=0 means empty dim_desc + }, + &CDATA_BUFFER_METHODS, ); Ok(buf) } diff --git a/crates/vm/src/stdlib/ctypes/thunk.rs b/crates/vm/src/stdlib/ctypes/thunk.rs deleted file mode 100644 index 2de2308e1a3..00000000000 --- a/crates/vm/src/stdlib/ctypes/thunk.rs +++ /dev/null @@ -1,319 +0,0 @@ -//! FFI callback (thunk) implementation for ctypes. -//! -//! This module implements CThunkObject which wraps Python callables -//! to be callable from C code via libffi closures. 
- -use crate::builtins::{PyStr, PyType, PyTypeRef}; -use crate::vm::thread::with_current_vm; -use crate::{PyObjectRef, PyPayload, PyResult, VirtualMachine}; -use libffi::low; -use libffi::middle::{Cif, Closure, CodePtr, Type}; -use num_traits::ToPrimitive; -use rustpython_common::lock::PyRwLock; -use std::ffi::c_void; -use std::fmt::Debug; - -use super::base::ffi_type_from_str; -/// Userdata passed to the libffi callback. -/// This contains everything needed to invoke the Python callable. -pub struct ThunkUserData { - /// The Python callable to invoke - pub callable: PyObjectRef, - /// Argument types for conversion - pub arg_types: Vec, - /// Result type for conversion (None means void) - pub res_type: Option, -} - -/// Get the type code string from a ctypes type -fn get_type_code(ty: &PyTypeRef, vm: &VirtualMachine) -> Option { - ty.get_attr(vm.ctx.intern_str("_type_")) - .and_then(|t| t.downcast_ref::().map(|s| s.to_string())) -} - -/// Convert a C value to a Python object based on the type code -fn ffi_to_python(ty: &PyTypeRef, ptr: *const c_void, vm: &VirtualMachine) -> PyObjectRef { - let type_code = get_type_code(ty, vm); - // SAFETY: ptr is guaranteed to be valid by libffi calling convention - unsafe { - match type_code.as_deref() { - Some("b") => vm.ctx.new_int(*(ptr as *const i8) as i32).into(), - Some("B") => vm.ctx.new_int(*(ptr as *const u8) as i32).into(), - Some("c") => vm.ctx.new_bytes(vec![*(ptr as *const u8)]).into(), - Some("h") => vm.ctx.new_int(*(ptr as *const i16) as i32).into(), - Some("H") => vm.ctx.new_int(*(ptr as *const u16) as i32).into(), - Some("i") => vm.ctx.new_int(*(ptr as *const i32)).into(), - Some("I") => vm.ctx.new_int(*(ptr as *const u32)).into(), - Some("l") => vm.ctx.new_int(*(ptr as *const libc::c_long)).into(), - Some("L") => vm.ctx.new_int(*(ptr as *const libc::c_ulong)).into(), - Some("q") => vm.ctx.new_int(*(ptr as *const libc::c_longlong)).into(), - Some("Q") => vm.ctx.new_int(*(ptr as *const libc::c_ulonglong)).into(), 
- Some("f") => vm.ctx.new_float(*(ptr as *const f32) as f64).into(), - Some("d") => vm.ctx.new_float(*(ptr as *const f64)).into(), - Some("P") | Some("z") | Some("Z") => vm.ctx.new_int(ptr as usize).into(), - _ => vm.ctx.none(), - } - } -} - -/// Convert a Python object to a C value and store it at the result pointer -fn python_to_ffi(obj: PyResult, ty: &PyTypeRef, result: *mut c_void, vm: &VirtualMachine) { - let obj = match obj { - Ok(o) => o, - Err(_) => return, // Exception occurred, leave result as-is - }; - - let type_code = get_type_code(ty, vm); - // SAFETY: result is guaranteed to be valid by libffi calling convention - unsafe { - match type_code.as_deref() { - Some("b") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i8) = i.as_bigint().to_i8().unwrap_or(0); - } - } - Some("B") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u8) = i.as_bigint().to_u8().unwrap_or(0); - } - } - Some("c") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u8) = i.as_bigint().to_u8().unwrap_or(0); - } - } - Some("h") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i16) = i.as_bigint().to_i16().unwrap_or(0); - } - } - Some("H") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u16) = i.as_bigint().to_u16().unwrap_or(0); - } - } - Some("i") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i32) = i.as_bigint().to_i32().unwrap_or(0); - } - } - Some("I") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u32) = i.as_bigint().to_u32().unwrap_or(0); - } - } - Some("l") | Some("q") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut i64) = i.as_bigint().to_i64().unwrap_or(0); - } - } - Some("L") | Some("Q") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut u64) = i.as_bigint().to_u64().unwrap_or(0); - } - } - Some("f") => { - if let Ok(f) = obj.try_float(vm) { - *(result as *mut f32) = f.to_f64() as f32; - } - } - Some("d") => { - if let Ok(f) = obj.try_float(vm) { - *(result as *mut f64) = 
f.to_f64(); - } - } - Some("P") | Some("z") | Some("Z") => { - if let Ok(i) = obj.try_int(vm) { - *(result as *mut usize) = i.as_bigint().to_usize().unwrap_or(0); - } - } - _ => {} - } - } -} - -/// The callback function that libffi calls when the closure is invoked. -/// This function converts C arguments to Python objects, calls the Python -/// callable, and converts the result back to C. -unsafe extern "C" fn thunk_callback( - _cif: &low::ffi_cif, - result: &mut c_void, - args: *const *const c_void, - userdata: &ThunkUserData, -) { - with_current_vm(|vm| { - // Convert C arguments to Python objects - let py_args: Vec = userdata - .arg_types - .iter() - .enumerate() - .map(|(i, ty)| { - let arg_ptr = unsafe { *args.add(i) }; - ffi_to_python(ty, arg_ptr, vm) - }) - .collect(); - - // Call the Python callable - let py_result = userdata.callable.call(py_args, vm); - - // Convert result back to C type - if let Some(ref res_type) = userdata.res_type { - python_to_ffi(py_result, res_type, result as *mut c_void, vm); - } - }); -} - -/// Holds the closure and userdata together to ensure proper lifetime. -/// The userdata is leaked to create a 'static reference that the closure can use. -struct ThunkData { - #[allow(dead_code)] - closure: Closure<'static>, - /// Raw pointer to the leaked userdata, for cleanup - userdata_ptr: *mut ThunkUserData, -} - -impl Drop for ThunkData { - fn drop(&mut self) { - // SAFETY: We created this with Box::into_raw, so we can reclaim it - unsafe { - drop(Box::from_raw(self.userdata_ptr)); - } - } -} - -/// CThunkObject wraps a Python callable to make it callable from C code. 
-#[pyclass(name = "CThunkObject", module = "_ctypes")] -#[derive(PyPayload)] -pub struct PyCThunk { - /// The Python callable - callable: PyObjectRef, - /// The libffi closure (must be kept alive) - #[allow(dead_code)] - thunk_data: PyRwLock>, - /// The code pointer for the closure - code_ptr: CodePtr, -} - -impl Debug for PyCThunk { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PyCThunk") - .field("callable", &self.callable) - .finish() - } -} - -impl PyCThunk { - /// Create a new thunk wrapping a Python callable. - /// - /// # Arguments - /// * `callable` - The Python callable to wrap - /// * `arg_types` - Optional sequence of argument types - /// * `res_type` - Optional result type - /// * `vm` - The virtual machine - pub fn new( - callable: PyObjectRef, - arg_types: Option, - res_type: Option, - vm: &VirtualMachine, - ) -> PyResult { - // Parse argument types - let arg_type_vec: Vec = if let Some(args) = arg_types { - if vm.is_none(&args) { - Vec::new() - } else { - let mut types = Vec::new(); - for item in args.try_to_value::>(vm)? { - types.push(item.downcast::().map_err(|_| { - vm.new_type_error("_argtypes_ must be a sequence of types".to_string()) - })?); - } - types - } - } else { - Vec::new() - }; - - // Parse result type - let res_type_ref: Option = - if let Some(ref rt) = res_type { - if vm.is_none(rt) { - None - } else { - Some(rt.clone().downcast::().map_err(|_| { - vm.new_type_error("restype must be a ctypes type".to_string()) - })?) 
- } - } else { - None - }; - - // Build FFI types - let ffi_arg_types: Vec = arg_type_vec - .iter() - .map(|ty| { - get_type_code(ty, vm) - .and_then(|code| ffi_type_from_str(&code)) - .unwrap_or(Type::pointer()) - }) - .collect(); - - let ffi_res_type = res_type_ref - .as_ref() - .and_then(|ty| get_type_code(ty, vm)) - .and_then(|code| ffi_type_from_str(&code)) - .unwrap_or(Type::void()); - - // Create the CIF - let cif = Cif::new(ffi_arg_types, ffi_res_type); - - // Create userdata and leak it to get a 'static reference - let userdata = Box::new(ThunkUserData { - callable: callable.clone(), - arg_types: arg_type_vec, - res_type: res_type_ref, - }); - let userdata_ptr = Box::into_raw(userdata); - - // SAFETY: We maintain the userdata lifetime by storing it in ThunkData - // and cleaning it up in Drop - let userdata_ref: &'static ThunkUserData = unsafe { &*userdata_ptr }; - - // Create the closure - let closure = Closure::new(cif, thunk_callback, userdata_ref); - - // Get the code pointer - let code_ptr = CodePtr(*closure.code_ptr() as *mut _); - - // Store closure and userdata together - let thunk_data = ThunkData { - closure, - userdata_ptr, - }; - - Ok(Self { - callable, - thunk_data: PyRwLock::new(Some(thunk_data)), - code_ptr, - }) - } - - /// Get the code pointer for this thunk - pub fn code_ptr(&self) -> CodePtr { - self.code_ptr - } -} - -// SAFETY: PyCThunk is safe to send/sync because: -// - callable is a PyObjectRef which is Send+Sync -// - thunk_data contains the libffi closure which is heap-allocated -// - code_ptr is just a pointer to executable memory -unsafe impl Send for PyCThunk {} -unsafe impl Sync for PyCThunk {} - -#[pyclass] -impl PyCThunk { - #[pygetset] - fn callable(&self) -> PyObjectRef { - self.callable.clone() - } -} diff --git a/crates/vm/src/stdlib/ctypes/union.rs b/crates/vm/src/stdlib/ctypes/union.rs index 308a5e4e98f..500aa8e6244 100644 --- a/crates/vm/src/stdlib/ctypes/union.rs +++ b/crates/vm/src/stdlib/ctypes/union.rs @@ -1,40 
+1,60 @@ -use super::base::{CDataObject, PyCData}; -use super::field::PyCField; -use super::util::StgInfo; +use super::base::{CDATA_BUFFER_METHODS, StgInfoFlags}; +use super::{PyCData, PyCField, StgInfo}; use crate::builtins::{PyList, PyStr, PyTuple, PyType, PyTypeRef}; use crate::convert::ToPyObject; use crate::function::FuncArgs; -use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer as ProtocolPyBuffer}; -use crate::stdlib::ctypes::_ctypes::get_size; -use crate::types::{AsBuffer, Constructor}; -use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine}; -use num_traits::ToPrimitive; -use rustpython_common::lock::PyRwLock; +use crate::function::PySetterValue; +use crate::protocol::{BufferDescriptor, PyBuffer}; +use crate::types::{AsBuffer, Constructor, Initializer, SetAttr}; +use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; +use std::borrow::Cow; + +/// Calculate Union type size from _fields_ (max field size) +pub(super) fn calculate_union_size(cls: &Py, vm: &VirtualMachine) -> PyResult { + if let Ok(fields_attr) = cls.as_object().get_attr("_fields_", vm) { + let fields: Vec = fields_attr.try_to_value(vm)?; + let mut max_size = 0usize; + + for field in fields.iter() { + if let Some(tuple) = field.downcast_ref::() + && let Some(field_type) = tuple.get(1) + { + let field_size = super::_ctypes::sizeof(field_type.clone(), vm)?; + max_size = max_size.max(field_size); + } + } + return Ok(max_size); + } + Ok(0) +} /// PyCUnionType - metaclass for Union #[pyclass(name = "UnionType", base = PyType, module = "_ctypes")] #[derive(Debug)] #[repr(transparent)] -pub struct PyCUnionType(PyType); +pub(super) struct PyCUnionType(PyType); impl Constructor for PyCUnionType { type Args = FuncArgs; fn slot_new(metatype: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // 1. Create the new class using PyType::py_new - let new_class = crate::builtins::type_::PyType::slot_new(metatype, args, vm)?; + // 1. 
Create the new class using PyType::slot_new + let new_class = crate::builtins::PyType::slot_new(metatype, args, vm)?; - // 2. Process _fields_ if defined on the new class + // 2. Get the new type let new_type = new_class .clone() .downcast::() .map_err(|_| vm.new_type_error("expected type"))?; - // Only process _fields_ if defined directly on this class (not inherited) - if let Some(fields_attr) = new_type.get_direct_attr(vm.ctx.intern_str("_fields_")) { - Self::process_fields(&new_type, fields_attr, vm)?; - } + // 3. Mark base classes as finalized (subclassing finalizes the parent) + new_type.mark_bases_final(); + + // 4. Initialize StgInfo for the new type (initialized=false, to be set in init) + let stg_info = StgInfo::default(); + let _ = new_type.init_type_data(stg_info); + // Note: _fields_ processing moved to Initializer::init() Ok(new_class) } @@ -43,26 +63,132 @@ impl Constructor for PyCUnionType { } } +impl Initializer for PyCUnionType { + type Args = FuncArgs; + + fn init(zelf: crate::PyRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + // Get the type as PyTypeRef by converting PyRef -> PyObjectRef -> PyRef + let obj: PyObjectRef = zelf.clone().into(); + let new_type: PyTypeRef = obj + .downcast() + .map_err(|_| vm.new_type_error("expected type"))?; + + // Check for _abstract_ attribute - skip initialization if present + if new_type + .get_direct_attr(vm.ctx.intern_str("_abstract_")) + .is_some() + { + return Ok(()); + } + + new_type.check_not_initialized(vm)?; + + // Process _fields_ if defined directly on this class (not inherited) + // Use set_attr to trigger setattro + if let Some(fields_attr) = new_type.get_direct_attr(vm.ctx.intern_str("_fields_")) { + new_type + .as_object() + .set_attr(vm.ctx.intern_str("_fields_"), fields_attr, vm)?; + } else { + // No _fields_ defined - try to copy from base class + let (has_base_info, base_clone) = { + let bases = new_type.bases.read(); + if let Some(base) = bases.first() { + 
(base.stg_info_opt().is_some(), Some(base.clone())) + } else { + (false, None) + } + }; + + if has_base_info && let Some(ref base) = base_clone { + // Clone base StgInfo (release guard before getting mutable reference) + let stg_info_opt = base.stg_info_opt().map(|baseinfo| { + let mut stg_info = baseinfo.clone(); + stg_info.flags &= !StgInfoFlags::DICTFLAG_FINAL; // Clear FINAL flag in subclass + stg_info.initialized = true; + stg_info + }); + + if let Some(stg_info) = stg_info_opt { + // Mark base as FINAL (now guard is released) + if let Some(mut base_stg) = base.get_type_data_mut::() { + base_stg.flags |= StgInfoFlags::DICTFLAG_FINAL; + } + + super::base::set_or_init_stginfo(&new_type, stg_info); + return Ok(()); + } + } + + // No base StgInfo - create default + let mut stg_info = StgInfo::new(0, 1); + stg_info.flags |= StgInfoFlags::TYPEFLAG_HASUNION; + stg_info.paramfunc = super::base::ParamFunc::Union; + // PEP 3118 doesn't support union. Use 'B' for bytes. + stg_info.format = Some("B".to_string()); + super::base::set_or_init_stginfo(&new_type, stg_info); + } + + Ok(()) + } +} + impl PyCUnionType { /// Process _fields_ and create CField descriptors /// For Union, all fields start at offset 0 fn process_fields( - cls: &PyTypeRef, + cls: &Py, fields_attr: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { + // Check if already finalized + { + let Some(stg_info) = cls.get_type_data::() else { + return Err(vm.new_type_error("ctypes state is not initialized")); + }; + if stg_info.is_final() { + return Err(vm.new_attribute_error("_fields_ is final")); + } + } // Read lock released here + + // Check if this is a swapped byte order union + let is_swapped = cls.as_object().get_attr("_swappedbytes_", vm).is_ok(); + let fields: Vec = if let Some(list) = fields_attr.downcast_ref::() { list.borrow_vec().to_vec() } else if let Some(tuple) = fields_attr.downcast_ref::() { tuple.to_vec() } else { - return Err(vm.new_type_error("_fields_ must be a list or 
tuple".to_string())); + return Err(vm.new_type_error("_fields_ must be a list or tuple")); + }; + + let pack = super::base::get_usize_attr(cls.as_object(), "_pack_", 0, vm)?; + let forced_alignment = + super::base::get_usize_attr(cls.as_object(), "_align_", 1, vm)?.max(1); + + // Initialize size, alignment, type flags, and ffi_field_types from base class + // Note: Union fields always start at offset 0, but we inherit base size/align + let (mut max_size, mut max_align, mut has_pointer, mut has_bitfield, mut ffi_field_types) = { + let bases = cls.bases.read(); + if let Some(base) = bases.first() + && let Some(baseinfo) = base.stg_info_opt() + { + ( + baseinfo.size, + std::cmp::max(baseinfo.align, forced_alignment), + baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASPOINTER), + baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASBITFIELD), + baseinfo.ffi_field_types.clone(), + ) + } else { + (0, forced_alignment, false, false, Vec::new()) + } }; for (index, field) in fields.iter().enumerate() { let field_tuple = field .downcast_ref::() - .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples".to_string()))?; + .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples"))?; if field_tuple.len() < 2 { return Err(vm.new_type_error( @@ -72,66 +198,230 @@ impl PyCUnionType { let name = field_tuple .first() - .unwrap() + .expect("len checked") .downcast_ref::() - .ok_or_else(|| vm.new_type_error("field name must be a string".to_string()))? + .ok_or_else(|| vm.new_type_error("field name must be a string"))? 
.to_string(); - let field_type = field_tuple.get(1).unwrap().clone(); - let size = Self::get_field_size(&field_type, vm)?; + let field_type = field_tuple.get(1).expect("len checked").clone(); + + // For swapped byte order unions, validate field type supports byte swapping + if is_swapped { + super::base::check_other_endian_support(&field_type, vm)?; + } + + let size = super::base::get_field_size(&field_type, vm)?; + let field_align = super::base::get_field_align(&field_type, vm); + + // Calculate effective alignment + let effective_align = if pack > 0 { + std::cmp::min(pack, field_align) + } else { + field_align + }; + + max_size = max_size.max(size); + max_align = max_align.max(effective_align); + + // Propagate type flags from field type (HASPOINTER, HASBITFIELD) + if let Some(type_obj) = field_type.downcast_ref::() + && let Some(field_stg) = type_obj.stg_info_opt() + { + // HASPOINTER: propagate if field is pointer or contains pointer + if field_stg.flags.intersects( + StgInfoFlags::TYPEFLAG_ISPOINTER | StgInfoFlags::TYPEFLAG_HASPOINTER, + ) { + has_pointer = true; + } + // HASBITFIELD: propagate directly + if field_stg.flags.contains(StgInfoFlags::TYPEFLAG_HASBITFIELD) { + has_bitfield = true; + } + // Collect FFI type for this field + ffi_field_types.push(field_stg.to_ffi_type()); + } + + // Mark field type as finalized (using type as field finalizes it) + if let Some(type_obj) = field_type.downcast_ref::() { + if let Some(mut stg_info) = type_obj.get_type_data_mut::() { + stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL; + } else { + // Create StgInfo with FINAL flag if it doesn't exist + let mut stg_info = StgInfo::new(size, field_align); + stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL; + let _ = type_obj.init_type_data(stg_info); + } + } // For Union, all fields start at offset 0 - // Create CField descriptor (accepts any ctypes type including arrays) - let c_field = PyCField::new(name.clone(), field_type, 0, size, index); + let field_type_ref = field_type + 
.clone() + .downcast::() + .map_err(|_| vm.new_type_error("_fields_ type must be a ctypes type"))?; + let c_field = PyCField::new(field_type_ref, 0, size as isize, index); cls.set_attr(vm.ctx.intern_str(name), c_field.to_pyobject(vm)); } + // Calculate total_align and aligned_size + let total_align = std::cmp::max(max_align, forced_alignment); + let aligned_size = if total_align > 0 { + max_size.div_ceil(total_align) * total_align + } else { + max_size + }; + + // Store StgInfo with aligned size + let mut stg_info = StgInfo::new(aligned_size, total_align); + stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL | StgInfoFlags::TYPEFLAG_HASUNION; + // PEP 3118 doesn't support union. Use 'B' for bytes. + stg_info.format = Some("B".to_string()); + if has_pointer { + stg_info.flags |= StgInfoFlags::TYPEFLAG_HASPOINTER; + } + if has_bitfield { + stg_info.flags |= StgInfoFlags::TYPEFLAG_HASBITFIELD; + } + stg_info.paramfunc = super::base::ParamFunc::Union; + // Set byte order: swap if _swappedbytes_ is defined + stg_info.big_endian = super::base::is_big_endian(is_swapped); + // Store FFI field types for union passing + stg_info.ffi_field_types = ffi_field_types; + super::base::set_or_init_stginfo(cls, stg_info); + + // Process _anonymous_ fields + super::base::make_anon_fields(cls, vm)?; + Ok(()) } +} - fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> PyResult { - if let Some(size) = field_type - .get_attr("_type_", vm) - .ok() - .and_then(|type_attr| type_attr.str(vm).ok()) - .and_then(|type_str| { - let s = type_str.to_string(); - (s.len() == 1).then(|| get_size(&s)) - }) - { - return Ok(size); +#[pyclass(flags(BASETYPE), with(Constructor, Initializer, SetAttr))] +impl PyCUnionType { + #[pymethod] + fn from_param(zelf: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // zelf is the union type class that from_param was called on + let cls = zelf + .downcast::() + .map_err(|_| vm.new_type_error("from_param: expected a type"))?; + + // 1. 
If already an instance of the requested type, return it + if value.is_instance(cls.as_object(), vm)? { + return Ok(value); } - if let Some(s) = field_type - .get_attr("size_of_instances", vm) - .ok() - .and_then(|size_method| size_method.call((), vm).ok()) - .and_then(|size| size.try_int(vm).ok()) - .and_then(|n| n.as_bigint().to_usize()) - { - return Ok(s); + // 2. Check for CArgObject (PyCArg_CheckExact) + if let Some(carg) = value.downcast_ref::() { + // Check against proto (for pointer types) + if let Some(stg_info) = cls.stg_info_opt() + && let Some(ref proto) = stg_info.proto + && carg.obj.is_instance(proto.as_object(), vm)? + { + return Ok(value); + } + // Fallback: check if the wrapped object is an instance of the requested type + if carg.obj.is_instance(cls.as_object(), vm)? { + return Ok(value); // Return the CArgObject as-is + } + // CArgObject but wrong type + return Err(vm.new_type_error(format!( + "expected {} instance instead of pointer to {}", + cls.name(), + carg.obj.class().name() + ))); } - Ok(std::mem::size_of::()) + // 3. 
Check for _as_parameter_ attribute + if let Ok(as_parameter) = value.get_attr("_as_parameter_", vm) { + return PyCUnionType::from_param(cls.as_object().to_owned(), as_parameter, vm); + } + + Err(vm.new_type_error(format!( + "expected {} instance instead of {}", + cls.name(), + value.class().name() + ))) + } + + /// Called when a new Union subclass is created + #[pyclassmethod] + fn __init_subclass__(cls: PyTypeRef, vm: &VirtualMachine) -> PyResult<()> { + cls.mark_bases_final(); + + // Check if _fields_ is defined + if let Some(fields_attr) = cls.get_direct_attr(vm.ctx.intern_str("_fields_")) { + Self::process_fields(&cls, fields_attr, vm)?; + } + Ok(()) } } -#[pyclass(flags(BASETYPE), with(Constructor))] -impl PyCUnionType {} +impl SetAttr for PyCUnionType { + fn setattro( + zelf: &Py, + attr_name: &Py, + value: PySetterValue, + vm: &VirtualMachine, + ) -> PyResult<()> { + let pytype: &Py = zelf.to_base(); + let attr_name_interned = vm.ctx.intern_str(attr_name.as_str()); + + // 1. First, do PyType's setattro (PyType_Type.tp_setattro first) + // Check for data descriptor first + if let Some(attr) = pytype.get_class_attr(attr_name_interned) { + let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + if let Some(descriptor) = descr_set { + descriptor(&attr, pytype.to_owned().into(), value.clone(), vm)?; + // After successful setattro, check if _fields_ and call process_fields + if attr_name.as_str() == "_fields_" + && let PySetterValue::Assign(fields_value) = value + { + PyCUnionType::process_fields(pytype, fields_value, vm)?; + } + return Ok(()); + } + } + + // Store in type's attributes dict + match &value { + PySetterValue::Assign(v) => { + pytype + .attributes + .write() + .insert(attr_name_interned, v.clone()); + } + PySetterValue::Delete => { + let prev = pytype.attributes.write().shift_remove(attr_name_interned); + if prev.is_none() { + return Err(vm.new_attribute_error(format!( + "type object '{}' has no attribute '{}'", + 
pytype.name(), + attr_name.as_str(), + ))); + } + } + } + + // 2. If _fields_, call process_fields (which checks FINAL internally) + if attr_name.as_str() == "_fields_" + && let PySetterValue::Assign(fields_value) = value + { + PyCUnionType::process_fields(pytype, fields_value, vm)?; + } + + Ok(()) + } +} /// PyCUnion - base class for Union #[pyclass(module = "_ctypes", name = "Union", base = PyCData, metaclass = "PyCUnionType")] -pub struct PyCUnion { - _base: PyCData, - /// Common CDataObject for memory buffer - pub(super) cdata: PyRwLock, -} +#[repr(transparent)] +pub struct PyCUnion(pub PyCData); impl std::fmt::Debug for PyCUnion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PyCUnion") - .field("size", &self.cdata.read().size()) + .field("size", &self.0.size()) .finish() } } @@ -140,47 +430,22 @@ impl Constructor for PyCUnion { type Args = FuncArgs; fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // Get _fields_ from the class - let fields_attr = cls.as_object().get_attr("_fields_", vm).ok(); - - // Calculate union size (max of all field sizes) and alignment - let mut max_size = 0usize; - let mut max_align = 1usize; - - if let Some(fields_attr) = fields_attr { - let fields: Vec = if let Some(list) = fields_attr.downcast_ref::() - { - list.borrow_vec().to_vec() - } else if let Some(tuple) = fields_attr.downcast_ref::() { - tuple.to_vec() - } else { - vec![] - }; + // Check for abstract class and extract values in a block to drop the borrow + let (total_size, total_align) = { + let stg_info = cls.stg_info(vm)?; + (stg_info.size, stg_info.align) + }; - for field in fields.iter() { - let Some(field_tuple) = field.downcast_ref::() else { - continue; - }; - if field_tuple.len() < 2 { - continue; - } - let field_type = field_tuple.get(1).unwrap().clone(); - let size = PyCUnionType::get_field_size(&field_type, vm)?; - max_size = max_size.max(size); - // For simple types, alignment == size - 
max_align = max_align.max(size); - } + // Mark the class as finalized (instance creation finalizes the type) + if let Some(mut stg_info_mut) = cls.get_type_data_mut::() { + stg_info_mut.flags |= StgInfoFlags::DICTFLAG_FINAL; } - // Initialize buffer with zeros - let stg_info = StgInfo::new(max_size, max_align); - let cdata = CDataObject::from_stg_info(&stg_info); - PyCUnion { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + // Initialize buffer with zeros using computed size + let new_stg_info = StgInfo::new(total_size, total_align); + PyCUnion(PyCData::from_stg_info(&new_stg_info)) + .into_ref_with_type(vm, cls) + .map(Into::into) } fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { @@ -188,147 +453,125 @@ impl Constructor for PyCUnion { } } -#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor, AsBuffer))] impl PyCUnion { - #[pygetset] - fn _objects(&self) -> Option { - self.cdata.read().objects.clone() - } - - #[pyclassmethod] - fn from_address(cls: PyTypeRef, address: isize, vm: &VirtualMachine) -> PyResult { - use crate::stdlib::ctypes::_ctypes::size_of; - - // Get size from cls - let size = size_of(cls.clone().into(), vm)?; - - // Create instance with data from address - if address == 0 || size == 0 { - return Err(vm.new_value_error("NULL pointer access".to_owned())); - } - let stg_info = StgInfo::new(size, 1); - let cdata = CDataObject::from_stg_info(&stg_info); - Ok(PyCUnion { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - } - .into_ref_with_type(vm, cls)? 
- .into()) - } - - #[pyclassmethod] - fn from_buffer( - cls: PyTypeRef, - source: PyObjectRef, - offset: crate::function::OptionalArg, + /// Recursively initialize positional arguments through inheritance chain + /// Returns the number of arguments consumed + fn init_pos_args( + self_obj: &Py, + type_obj: &Py, + args: &[PyObjectRef], + kwargs: &indexmap::IndexMap, + index: usize, vm: &VirtualMachine, - ) -> PyResult { - use crate::TryFromObject; - use crate::protocol::PyBuffer; - use crate::stdlib::ctypes::_ctypes::size_of; - - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; - - let buffer = PyBuffer::try_from_object(vm, source.clone())?; + ) -> PyResult { + let mut current_index = index; + + // 1. First process base class fields recursively + // Recurse if base has StgInfo + let base_clone = { + let bases = type_obj.bases.read(); + if let Some(base) = bases.first() && + // Check if base has StgInfo + base.stg_info_opt().is_some() + { + Some(base.clone()) + } else { + None + } + }; - if buffer.desc.readonly { - return Err(vm.new_type_error("underlying buffer is not writable".to_owned())); + if let Some(ref base) = base_clone { + current_index = Self::init_pos_args(self_obj, base, args, kwargs, current_index, vm)?; } - let size = size_of(cls.clone().into(), vm)?; - let buffer_len = buffer.desc.len; + // 2. 
Process this class's _fields_ + if let Some(fields_attr) = type_obj.get_direct_attr(vm.ctx.intern_str("_fields_")) { + let fields: Vec = fields_attr.try_to_value(vm)?; - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); + for field in fields.iter() { + if current_index >= args.len() { + break; + } + if let Some(tuple) = field.downcast_ref::() + && let Some(name) = tuple.first() + && let Some(name_str) = name.downcast_ref::() + { + let field_name = name_str.as_str().to_owned(); + // Check for duplicate in kwargs + if kwargs.contains_key(&field_name) { + return Err(vm.new_type_error(format!( + "duplicate values for field {:?}", + field_name + ))); + } + self_obj.as_object().set_attr( + vm.ctx.intern_str(field_name), + args[current_index].clone(), + vm, + )?; + current_index += 1; + } + } } - // Copy data from source buffer - let bytes = buffer.obj_bytes(); - let data = bytes[offset..offset + size].to_vec(); - - let cdata = CDataObject::from_bytes(data, None); - Ok(PyCUnion { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), - } - .into_ref_with_type(vm, cls)? 
- .into()) + Ok(current_index) } +} - #[pyclassmethod] - fn from_buffer_copy( - cls: PyTypeRef, - source: crate::function::ArgBytesLike, - offset: crate::function::OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use crate::stdlib::ctypes::_ctypes::size_of; +impl Initializer for PyCUnion { + type Args = FuncArgs; - let offset = offset.unwrap_or(0); - if offset < 0 { - return Err(vm.new_value_error("offset cannot be negative".to_owned())); - } - let offset = offset as usize; + fn init(zelf: crate::PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + // Struct_init: handle positional and keyword arguments + let cls = zelf.class().to_owned(); - let size = size_of(cls.clone().into(), vm)?; - let source_bytes = source.borrow_buf(); - let buffer_len = source_bytes.len(); + // 1. Process positional arguments recursively through inheritance chain + if !args.args.is_empty() { + let consumed = PyCUnion::init_pos_args(&zelf, &cls, &args.args, &args.kwargs, 0, vm)?; - if offset + size > buffer_len { - return Err(vm.new_value_error(format!( - "Buffer size too small ({} instead of at least {} bytes)", - buffer_len, - offset + size - ))); + if consumed < args.args.len() { + return Err(vm.new_type_error("too many initializers")); + } } - // Copy data from source - let data = source_bytes[offset..offset + size].to_vec(); - - let cdata = CDataObject::from_bytes(data, None); - Ok(PyCUnion { - _base: PyCData::new(cdata.clone()), - cdata: PyRwLock::new(cdata), + // 2. Process keyword arguments + for (key, value) in args.kwargs.iter() { + zelf.as_object() + .set_attr(vm.ctx.intern_str(key.as_str()), value.clone(), vm)?; } - .into_ref_with_type(vm, cls)? 
- .into()) + + Ok(()) } } -static UNION_BUFFER_METHODS: BufferMethods = BufferMethods { - obj_bytes: |buffer| { - rustpython_common::lock::PyRwLockReadGuard::map( - buffer.obj_as::().cdata.read(), - |x: &CDataObject| x.buffer.as_slice(), - ) - .into() - }, - obj_bytes_mut: |buffer| { - rustpython_common::lock::PyRwLockWriteGuard::map( - buffer.obj_as::().cdata.write(), - |x: &mut CDataObject| x.buffer.as_mut_slice(), - ) - .into() - }, - release: |_| {}, - retain: |_| {}, -}; +#[pyclass( + flags(BASETYPE, IMMUTABLETYPE), + with(Constructor, Initializer, AsBuffer) +)] +impl PyCUnion {} impl AsBuffer for PyCUnion { - fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { - let buffer_len = zelf.cdata.read().buffer.len(); - let buf = ProtocolPyBuffer::new( + fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + let buffer_len = zelf.0.buffer.read().len(); + + // PyCData_NewGetBuffer: use info->format if available, otherwise "B" + let format = zelf + .class() + .stg_info_opt() + .and_then(|info| info.format.clone()) + .unwrap_or_else(|| "B".to_string()); + + // Union: ndim=0, shape=(), itemsize=union_size + let buf = PyBuffer::new( zelf.to_owned().into(), - BufferDescriptor::simple(buffer_len, false), // readonly=false for ctypes - &UNION_BUFFER_METHODS, + BufferDescriptor { + len: buffer_len, + readonly: false, + itemsize: buffer_len, + format: Cow::Owned(format), + dim_desc: vec![], // ndim=0 means empty dim_desc + }, + &CDATA_BUFFER_METHODS, ); Ok(buf) } diff --git a/crates/vm/src/stdlib/ctypes/util.rs b/crates/vm/src/stdlib/ctypes/util.rs deleted file mode 100644 index b8c6def63ca..00000000000 --- a/crates/vm/src/stdlib/ctypes/util.rs +++ /dev/null @@ -1,88 +0,0 @@ -use crate::PyObjectRef; - -/// Storage information for ctypes types -/// Stored in TypeDataSlot of heap types (PyType::init_type_data/get_type_data) -#[derive(Clone)] -pub struct StgInfo { - pub initialized: bool, - pub size: usize, // number of bytes - pub align: usize, // alignment 
requirements - pub length: usize, // number of fields (for arrays/structures) - pub proto: Option, // Only for Pointer/ArrayObject - pub flags: i32, // calling convention and such - - // Array-specific fields (moved from PyCArrayType) - pub element_type: Option, // _type_ for arrays - pub element_size: usize, // size of each element -} - -// StgInfo is stored in type_data which requires Send + Sync. -// The PyObjectRef in proto/element_type fields is protected by the type system's locking mechanism. -// CPython: ctypes objects are not thread-safe by design; users must synchronize access. -unsafe impl Send for StgInfo {} -unsafe impl Sync for StgInfo {} - -impl std::fmt::Debug for StgInfo { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StgInfo") - .field("initialized", &self.initialized) - .field("size", &self.size) - .field("align", &self.align) - .field("length", &self.length) - .field("proto", &self.proto) - .field("flags", &self.flags) - .field("element_type", &self.element_type) - .field("element_size", &self.element_size) - .finish() - } -} - -impl Default for StgInfo { - fn default() -> Self { - StgInfo { - initialized: false, - size: 0, - align: 1, - length: 0, - proto: None, - flags: 0, - element_type: None, - element_size: 0, - } - } -} - -impl StgInfo { - pub fn new(size: usize, align: usize) -> Self { - StgInfo { - initialized: true, - size, - align, - length: 0, - proto: None, - flags: 0, - element_type: None, - element_size: 0, - } - } - - /// Create StgInfo for an array type - pub fn new_array( - size: usize, - align: usize, - length: usize, - element_type: PyObjectRef, - element_size: usize, - ) -> Self { - StgInfo { - initialized: true, - size, - align, - length, - proto: None, - flags: 0, - element_type: Some(element_type), - element_size, - } - } -} diff --git a/crates/vm/src/stdlib/functools.rs b/crates/vm/src/stdlib/functools.rs index d5a42739e96..26dff8b4426 100644 --- 
a/crates/vm/src/stdlib/functools.rs +++ b/crates/vm/src/stdlib/functools.rs @@ -73,8 +73,8 @@ mod _functools { self.inner.read().keywords.clone() } - #[pymethod(name = "__reduce__")] - fn reduce(zelf: &Py, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn __reduce__(zelf: &Py, vm: &VirtualMachine) -> PyResult { let inner = zelf.inner.read(); let partial_type = zelf.class(); diff --git a/crates/vm/src/stdlib/operator.rs b/crates/vm/src/stdlib/operator.rs index 0c048ea2a3f..7877ddb0114 100644 --- a/crates/vm/src/stdlib/operator.rs +++ b/crates/vm/src/stdlib/operator.rs @@ -323,7 +323,7 @@ mod _operator { ) -> PyResult { let res = match (a, b) { (Either::A(a), Either::A(b)) => { - if !a.is_ascii() || !b.is_ascii() { + if !a.isascii() || !b.isascii() { return Err(vm.new_type_error( "comparing strings with non-ASCII characters is not supported", )); diff --git a/crates/vm/src/types/structseq.rs b/crates/vm/src/types/structseq.rs index be0a1c9a70c..2b6a2530b02 100644 --- a/crates/vm/src/types/structseq.rs +++ b/crates/vm/src/types/structseq.rs @@ -199,7 +199,7 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { .ok_or_else(|| vm.new_type_error("unexpected payload for __repr__"))?; let field_names = Self::Data::REQUIRED_FIELD_NAMES; - let format_field = |(value, name): (&PyObjectRef, _)| { + let format_field = |(value, name): (&PyObject, _)| { let s = value.repr(vm)?; Ok(format!("{name}={s}")) }; @@ -212,6 +212,7 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { } else { let fields: PyResult> = zelf .iter() + .map(|value| value.as_ref()) .zip(field_names.iter().copied()) .map(format_field) .collect(); From 6c186e389381cb1b9d99e74e085a9c53ddd24624 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Fri, 19 Dec 2025 14:51:09 +0100 Subject: [PATCH 013/418] Update `smptlib` and `test_smtpnet.py` from 3.13.11 (#6435) * Update `test_smtpnet.py` from 3.13.11 * Update 
`test_smtplib.py` from 3.13.11 * Update `smtplib.py` from 3.13.11 * Catch AttributeError --- Lib/smtplib.py | 43 ++++++++++++++------- Lib/test/test_smtplib.py | 83 +++++++++++++++++++++++++++++++--------- Lib/test/test_smtpnet.py | 9 +++-- 3 files changed, 100 insertions(+), 35 deletions(-) mode change 100644 => 100755 Lib/smtplib.py diff --git a/Lib/smtplib.py b/Lib/smtplib.py old mode 100644 new mode 100755 index 912233d8176..9b81bcfbc41 --- a/Lib/smtplib.py +++ b/Lib/smtplib.py @@ -171,7 +171,7 @@ def quotedata(data): internet CRLF end-of-line. """ return re.sub(r'(?m)^\.', '..', - re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data)) + re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data)) def _quote_periods(bindata): return re.sub(br'(?m)^\.', b'..', bindata) @@ -179,6 +179,16 @@ def _quote_periods(bindata): def _fix_eols(data): return re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data) + +try: + hmac.digest(b'', b'', 'md5') +# except ValueError: +except (ValueError, AttributeError): # TODO: RUSTPYTHON + _have_cram_md5_support = False +else: + _have_cram_md5_support = True + + try: import ssl except ImportError: @@ -475,7 +485,7 @@ def ehlo(self, name=''): if auth_match: # This doesn't remove duplicates, but that's no problem self.esmtp_features["auth"] = self.esmtp_features.get("auth", "") \ - + " " + auth_match.groups(0)[0] + + " " + auth_match.groups(0)[0] continue # RFC 1869 requires a space between ehlo keyword and parameters. 
@@ -488,7 +498,7 @@ def ehlo(self, name=''): params = m.string[m.end("feature"):].strip() if feature == "auth": self.esmtp_features[feature] = self.esmtp_features.get(feature, "") \ - + " " + params + + " " + params else: self.esmtp_features[feature] = params return (code, msg) @@ -542,7 +552,7 @@ def mail(self, sender, options=()): raise SMTPNotSupportedError( 'SMTPUTF8 not supported by server') optionlist = ' ' + ' '.join(options) - self.putcmd("mail", "FROM:%s%s" % (quoteaddr(sender), optionlist)) + self.putcmd("mail", "from:%s%s" % (quoteaddr(sender), optionlist)) return self.getreply() def rcpt(self, recip, options=()): @@ -550,7 +560,7 @@ def rcpt(self, recip, options=()): optionlist = '' if options and self.does_esmtp: optionlist = ' ' + ' '.join(options) - self.putcmd("rcpt", "TO:%s%s" % (quoteaddr(recip), optionlist)) + self.putcmd("rcpt", "to:%s%s" % (quoteaddr(recip), optionlist)) return self.getreply() def data(self, msg): @@ -667,8 +677,11 @@ def auth_cram_md5(self, challenge=None): # CRAM-MD5 does not support initial-response. if challenge is None: return None - return self.user + " " + hmac.HMAC( - self.password.encode('ascii'), challenge, 'md5').hexdigest() + if not _have_cram_md5_support: + raise SMTPException("CRAM-MD5 is not supported") + password = self.password.encode('ascii') + authcode = hmac.HMAC(password, challenge, 'md5') + return f"{self.user} {authcode.hexdigest()}" def auth_plain(self, challenge=None): """ Authobject to use with PLAIN authentication. 
Requires self.user and @@ -720,8 +733,10 @@ def login(self, user, password, *, initial_response_ok=True): advertised_authlist = self.esmtp_features["auth"].split() # Authentication methods we can handle in our preferred order: - preferred_auths = ['CRAM-MD5', 'PLAIN', 'LOGIN'] - + if _have_cram_md5_support: + preferred_auths = ['CRAM-MD5', 'PLAIN', 'LOGIN'] + else: + preferred_auths = ['PLAIN', 'LOGIN'] # We try the supported authentications in our preferred order, if # the server supports them. authlist = [auth for auth in preferred_auths @@ -905,7 +920,7 @@ def send_message(self, msg, from_addr=None, to_addrs=None, The arguments are as for sendmail, except that msg is an email.message.Message object. If from_addr is None or to_addrs is None, these arguments are taken from the headers of the Message as - described in RFC 2822 (a ValueError is raised if there is more than + described in RFC 5322 (a ValueError is raised if there is more than one set of 'Resent-' headers). Regardless of the values of from_addr and to_addr, any Bcc field (or Resent-Bcc field, when the Message is a resent) of the Message object won't be transmitted. The Message @@ -919,7 +934,7 @@ def send_message(self, msg, from_addr=None, to_addrs=None, policy. """ - # 'Resent-Date' is a mandatory field if the Message is resent (RFC 2822 + # 'Resent-Date' is a mandatory field if the Message is resent (RFC 5322 # Section 3.6.6). In such a case, we use the 'Resent-*' fields. However, # if there is more than one 'Resent-' block there's no way to # unambiguously determine which one is the most recent in all cases, @@ -938,10 +953,10 @@ def send_message(self, msg, from_addr=None, to_addrs=None, else: raise ValueError("message has more than one 'Resent-' header block") if from_addr is None: - # Prefer the sender field per RFC 2822:3.6.2. + # Prefer the sender field per RFC 5322 section 3.6.2. 
from_addr = (msg[header_prefix + 'Sender'] - if (header_prefix + 'Sender') in msg - else msg[header_prefix + 'From']) + if (header_prefix + 'Sender') in msg + else msg[header_prefix + 'From']) from_addr = email.utils.getaddresses([from_addr])[0][1] if to_addrs is None: addr_fields = [f for f in (msg[header_prefix + 'To'], diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index 9b787950fc2..ade0dc1308c 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -17,6 +17,7 @@ import threading import unittest +import unittest.mock as mock from test import support, mock_socket from test.support import hashlib_helper from test.support import socket_helper @@ -350,7 +351,7 @@ def testVRFY(self): timeout=support.LOOPBACK_TIMEOUT) self.addCleanup(smtp.close) expected = (252, b'Cannot VRFY user, but will accept message ' + \ - b'and attempt delivery') + b'and attempt delivery') self.assertEqual(smtp.vrfy('nobody@nowhere.com'), expected) self.assertEqual(smtp.verify('nobody@nowhere.com'), expected) smtp.quit() @@ -371,7 +372,7 @@ def testHELP(self): timeout=support.LOOPBACK_TIMEOUT) self.addCleanup(smtp.close) self.assertEqual(smtp.help(), b'Supported commands: EHLO HELO MAIL ' + \ - b'RCPT DATA RSET NOOP QUIT VRFY') + b'RCPT DATA RSET NOOP QUIT VRFY') smtp.quit() def testSend(self): @@ -527,7 +528,7 @@ def testSendMessageWithAddresses(self): smtp.quit() # make sure the Bcc header is still in the message. 
self.assertEqual(m['Bcc'], 'John Root , "Dinsdale" ' - '') + '') self.client_evt.set() self.serv_evt.wait() @@ -766,7 +767,7 @@ def tearDown(self): def testFailingHELO(self): self.assertRaises(smtplib.SMTPConnectError, smtplib.SMTP, - HOST, self.port, 'localhost', 3) + HOST, self.port, 'localhost', 3) class TooLongLineTests(unittest.TestCase): @@ -804,14 +805,14 @@ def testLineTooLong(self): sim_users = {'Mr.A@somewhere.com':'John A', 'Ms.B@xn--fo-fka.com':'Sally B', 'Mrs.C@somewhereesle.com':'Ruth C', - } + } sim_auth = ('Mr.A@somewhere.com', 'somepassword') sim_cram_md5_challenge = ('PENCeUxFREJoU0NnbmhNWitOMjNGNn' 'dAZWx3b29kLmlubm9zb2Z0LmNvbT4=') sim_lists = {'list-1':['Mr.A@somewhere.com','Mrs.C@somewhereesle.com'], 'list-2':['Ms.B@xn--fo-fka.com',], - } + } # Simulated SMTP channel & server class ResponseException(Exception): pass @@ -830,6 +831,7 @@ class SimSMTPChannel(smtpd.SMTPChannel): def __init__(self, extra_features, *args, **kw): self._extrafeatures = ''.join( [ "250-{0}\r\n".format(x) for x in extra_features ]) + self.all_received_lines = [] super(SimSMTPChannel, self).__init__(*args, **kw) # AUTH related stuff. It would be nice if support for this were in smtpd. 
@@ -844,6 +846,7 @@ def found_terminator(self): self.smtp_state = self.COMMAND self.push('%s %s' % (e.smtp_code, e.smtp_error)) return + self.all_received_lines.append(self.received_lines) super().found_terminator() @@ -924,11 +927,14 @@ def _auth_cram_md5(self, arg=None): except ValueError as e: self.push('535 Splitting response {!r} into user and password ' 'failed: {}'.format(logpass, e)) - return False - valid_hashed_pass = hmac.HMAC( - sim_auth[1].encode('ascii'), - self._decode_base64(sim_cram_md5_challenge).encode('ascii'), - 'md5').hexdigest() + return + pwd = sim_auth[1].encode('ascii') + msg = self._decode_base64(sim_cram_md5_challenge).encode('ascii') + try: + valid_hashed_pass = hmac.HMAC(pwd, msg, 'md5').hexdigest() + except ValueError: + self.push('504 CRAM-MD5 is not supported') + return self._authenticated(user, hashed_pass == valid_hashed_pass) # end AUTH related stuff. @@ -1170,8 +1176,7 @@ def auth_buggy(challenge=None): finally: smtp.close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @hashlib_helper.requires_hashdigest('md5', openssl=True) def testAUTH_CRAM_MD5(self): self.serv.add_feature("AUTH CRAM-MD5") @@ -1181,8 +1186,39 @@ def testAUTH_CRAM_MD5(self): self.assertEqual(resp, (235, b'Authentication Succeeded')) smtp.close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @mock.patch("hmac.HMAC") + @mock.patch("smtplib._have_cram_md5_support", False) + def testAUTH_CRAM_MD5_blocked(self, hmac_constructor): + # CRAM-MD5 is the only "known" method by the server, + # but it is not supported by the client. In particular, + # no challenge will ever be sent. 
+ self.serv.add_feature("AUTH CRAM-MD5") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + msg = re.escape("No suitable authentication method found.") + with self.assertRaisesRegex(smtplib.SMTPException, msg): + smtp.login(sim_auth[0], sim_auth[1]) + hmac_constructor.assert_not_called() # call has been bypassed + + @mock.patch("smtplib._have_cram_md5_support", False) + def testAUTH_CRAM_MD5_blocked_and_fallback(self): + # Test that PLAIN is tried after CRAM-MD5 failed + self.serv.add_feature("AUTH CRAM-MD5 PLAIN") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with ( + mock.patch.object(smtp, "auth_cram_md5") as smtp_auth_cram_md5, + mock.patch.object( + smtp, "auth_plain", wraps=smtp.auth_plain + ) as smtp_auth_plain + ): + resp = smtp.login(sim_auth[0], sim_auth[1]) + smtp_auth_plain.assert_called_once() + smtp_auth_cram_md5.assert_not_called() # no call to HMAC constructor + self.assertEqual(resp, (235, b'Authentication Succeeded')) + @hashlib_helper.requires_hashdigest('md5', openssl=True) def testAUTH_multiple(self): # Test that multiple authentication methods are tried. 
@@ -1193,8 +1229,7 @@ def testAUTH_multiple(self): self.assertEqual(resp, (235, b'Authentication Succeeded')) smtp.close() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_auth_function(self): supported = {'PLAIN', 'LOGIN'} try: @@ -1354,6 +1389,18 @@ def test_name_field_not_included_in_envelop_addresses(self): self.assertEqual(self.serv._addresses['from'], 'michael@example.com') self.assertEqual(self.serv._addresses['tos'], ['rene@example.com']) + def test_lowercase_mail_from_rcpt_to(self): + m = 'A test message' + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + + smtp.sendmail('John', 'Sally', m) + + self.assertIn(['mail from: size=14'], self.serv._SMTPchannel.all_received_lines) + self.assertIn(['rcpt to:'], self.serv._SMTPchannel.all_received_lines) + class SimSMTPUTF8Server(SimSMTPServer): @@ -1372,7 +1419,7 @@ def handle_accepted(self, conn, addr): ) def process_message(self, peer, mailfrom, rcpttos, data, mail_options=None, - rcpt_options=None): + rcpt_options=None): self.last_peer = peer self.last_mailfrom = mailfrom self.last_rcpttos = rcpttos diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py index be25e961f74..d765746987b 100644 --- a/Lib/test/test_smtpnet.py +++ b/Lib/test/test_smtpnet.py @@ -2,6 +2,7 @@ from test import support from test.support import import_helper from test.support import socket_helper +import os import smtplib import socket @@ -9,6 +10,8 @@ support.requires("network") +SMTP_TEST_SERVER = os.getenv('CPYTHON_TEST_SMTP_SERVER', 'smtp.gmail.com') + def check_ssl_verifiy(host, port): context = ssl.create_default_context() with socket.create_connection((host, port)) as sock: @@ -22,7 +25,7 @@ def check_ssl_verifiy(host, port): class SmtpTest(unittest.TestCase): - testServer = 'smtp.gmail.com' + testServer = SMTP_TEST_SERVER remotePort = 587 def test_connect_starttls(self): @@ -44,7 
+47,7 @@ def test_connect_starttls(self): class SmtpSSLTest(unittest.TestCase): - testServer = 'smtp.gmail.com' + testServer = SMTP_TEST_SERVER remotePort = 465 def test_connect(self): @@ -87,4 +90,4 @@ def test_connect_using_sslcontext_verified(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From ab1105a61debb068b104ebc03e7d37b8e7b1edad Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sat, 20 Dec 2025 09:55:28 +0900 Subject: [PATCH 014/418] Fix fix_test.py (#6415) --- scripts/fix_test.py | 165 ++++++++++++++++++++++++++------------------ 1 file changed, 96 insertions(+), 69 deletions(-) diff --git a/scripts/fix_test.py b/scripts/fix_test.py index a5663e3eee3..9716bd0b008 100644 --- a/scripts/fix_test.py +++ b/scripts/fix_test.py @@ -5,16 +5,20 @@ How to use: 1. Copy a specific test from the CPython repository to the RustPython repository. -2. Remove all unexpected failures from the test and skip the tests that hang -3. Run python ./scripts/fix_test.py --test test_venv --path ./Lib/test/test_venv.py or equivalent for the test from the project root. -4. Ensure that there are no unexpected successes in the test. -5. Actually fix the test. +2. Remove all unexpected failures from the test and skip the tests that hang. +3. Build RustPython: cargo build --release +4. Run from the project root: + - For single-file tests: python ./scripts/fix_test.py --path ./Lib/test/test_venv.py + - For package tests: python ./scripts/fix_test.py --path ./Lib/test/test_inspect/test_inspect.py +5. Verify: cargo run --release -- -m test test_venv (should pass with expected failures) +6. 
Actually fix the tests marked with # TODO: RUSTPYTHON """ import argparse import ast import itertools import platform +import sys from pathlib import Path @@ -58,85 +62,87 @@ def parse_results(result): in_test_results = True elif line.startswith("-----------"): in_test_results = False - if ( - in_test_results - and not line.startswith("tests") - and not line.startswith("[") - ): - line = line.split(" ") - if line != [] and len(line) > 3: - test = Test() - test.name = line[0] - test.path = line[1].strip("(").strip(")") - test.result = " ".join(line[3:]).lower() - test_results.tests.append(test) - else: - if "== Tests result: " in line: - res = line.split("== Tests result: ")[1] - res = res.split(" ")[0] - test_results.tests_result = res + if in_test_results and " ... " in line: + line = line.strip() + # Skip lines that don't look like test results + if line.startswith("tests") or line.startswith("["): + continue + # Parse: "test_name (path) [subtest] ... RESULT" + parts = line.split(" ... 
") + if len(parts) >= 2: + test_info = parts[0] + result_str = parts[-1].lower() + # Only process FAIL or ERROR + if result_str not in ("fail", "error"): + continue + # Extract test name (first word) + first_space = test_info.find(" ") + if first_space > 0: + test = Test() + test.name = test_info[:first_space] + # Extract path from (path) + rest = test_info[first_space:].strip() + if rest.startswith("("): + end_paren = rest.find(")") + if end_paren > 0: + test.path = rest[1:end_paren] + test.result = result_str + test_results.tests.append(test) + elif "== Tests result: " in line: + res = line.split("== Tests result: ")[1] + res = res.split(" ")[0] + test_results.tests_result = res return test_results def path_to_test(path) -> list[str]: - return path.split(".")[2:] + # path format: test.module_name[.submodule].ClassName.test_method + # We need [ClassName, test_method] - always the last 2 elements + parts = path.split(".") + return parts[-2:] # Get class name and method name -def modify_test(file: str, test: list[str], for_platform: bool = False) -> str: +def find_test_lineno(file: str, test: list[str]) -> tuple[int, int] | None: + """Find the line number and column offset of a test function. + Returns (lineno, col_offset) or None if not found. 
+ """ a = ast.parse(file) - lines = file.splitlines() - fixture = "@unittest.expectedFailure" - for node in ast.walk(a): - if isinstance(node, ast.FunctionDef): - if node.name == test[-1]: - assert not for_platform - indent = " " * node.col_offset - lines.insert(node.lineno - 1, indent + fixture) - lines.insert(node.lineno - 1, indent + "# TODO: RUSTPYTHON") - break - return "\n".join(lines) - - -def modify_test_v2(file: str, test: list[str], for_platform: bool = False) -> str: - a = ast.parse(file) - lines = file.splitlines() - fixture = "@unittest.expectedFailure" for key, node in ast.iter_fields(a): if key == "body": - for i, n in enumerate(node): + for n in node: match n: case ast.ClassDef(): if len(test) == 2 and test[0] == n.name: - # look through body for function def - for i, fn in enumerate(n.body): + for fn in n.body: match fn: - case ast.FunctionDef(): + case ast.FunctionDef() | ast.AsyncFunctionDef(): if fn.name == test[-1]: - assert not for_platform - indent = " " * fn.col_offset - lines.insert( - fn.lineno - 1, indent + fixture - ) - lines.insert( - fn.lineno - 1, - indent + "# TODO: RUSTPYTHON", - ) - break - case ast.FunctionDef(): + return (fn.lineno, fn.col_offset) + case ast.FunctionDef() | ast.AsyncFunctionDef(): if n.name == test[0] and len(test) == 1: - assert not for_platform - indent = " " * n.col_offset - lines.insert(n.lineno - 1, indent + fixture) - lines.insert(n.lineno - 1, indent + "# TODO: RUSTPYTHON") - break - if i > 500: - exit() + return (n.lineno, n.col_offset) + return None + + +def apply_modifications(file: str, modifications: list[tuple[int, int]]) -> str: + """Apply all modifications in reverse order to avoid line number offset issues.""" + lines = file.splitlines() + fixture = "@unittest.expectedFailure" + # Sort by line number in descending order + modifications.sort(key=lambda x: x[0], reverse=True) + for lineno, col_offset in modifications: + indent = " " * col_offset + lines.insert(lineno - 1, indent + fixture) + 
lines.insert(lineno - 1, indent + "# TODO: RUSTPYTHON") return "\n".join(lines) def run_test(test_name): print(f"Running test: {test_name}") rustpython_location = "./target/release/rustpython" + if sys.platform == "win32": + rustpython_location += ".exe" + import subprocess result = subprocess.run( @@ -149,13 +155,34 @@ def run_test(test_name): if __name__ == "__main__": args = parse_args() - test_name = args.path.stem + test_path = args.path.resolve() + if not test_path.exists(): + print(f"Error: File not found: {test_path}") + sys.exit(1) + test_name = test_path.stem tests = run_test(test_name) - f = open(args.path).read() + f = test_path.read_text(encoding="utf-8") + + # Collect all modifications first (with deduplication for subtests) + modifications = [] + seen_tests = set() # Track (class_name, method_name) to avoid duplicates for test in tests.tests: if test.result == "fail" or test.result == "error": - print("Modifying test:", test.name) - f = modify_test_v2(f, path_to_test(test.path), args.platform) - with open(args.path, "w") as file: - # TODO: Find validation method, and make --force override it - file.write(f) + test_parts = path_to_test(test.path) + test_key = tuple(test_parts) + if test_key in seen_tests: + continue # Skip duplicate (same test, different subtest) + seen_tests.add(test_key) + location = find_test_lineno(f, test_parts) + if location: + print(f"Modifying test: {test.name} at line {location[0]}") + modifications.append(location) + else: + print(f"Warning: Could not find test: {test.name} ({test_parts})") + + # Apply all modifications in reverse order + if modifications: + f = apply_modifications(f, modifications) + test_path.write_text(f, encoding="utf-8") + + print(f"Modified {len(modifications)} tests") From 46c61a65a6773de5583860e90857e619bed4aaf6 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sat, 20 Dec 2025 12:17:11 +0900 Subject: [PATCH 015/418] Upgrade venv to v3.13.11 (#6459) * 
Upgrade venv * get_venv_base_executable * mark failing tests --- Lib/test/test_venv.py | 615 ++++++++++++++++-- Lib/venv/__init__.py | 607 +++++++++++------ Lib/venv/__main__.py | 2 +- Lib/venv/scripts/common/Activate.ps1 | 3 +- Lib/venv/scripts/common/activate | 43 +- .../scripts/{posix => common}/activate.fish | 19 +- Lib/venv/scripts/nt/activate.bat | 10 +- Lib/venv/scripts/posix/activate.csh | 9 +- crates/vm/src/stdlib/sys.rs | 65 +- 9 files changed, 1065 insertions(+), 308 deletions(-) rename Lib/venv/scripts/{posix => common}/activate.fish (76%) diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py index 94d626598ba..19b12070531 100644 --- a/Lib/test/test_venv.py +++ b/Lib/test/test_venv.py @@ -5,21 +5,31 @@ Licensed to the PSF under a contributor agreement. """ +import contextlib import ensurepip import os import os.path +import pathlib import re import shutil import struct import subprocess import sys +import sysconfig import tempfile -from test.support import (captured_stdout, captured_stderr, requires_zlib, - skip_if_broken_multiprocessing_synchronize) -from test.support.os_helper import (can_symlink, EnvironmentVarGuard, rmtree) +import shlex +from test.support import (captured_stdout, captured_stderr, + skip_if_broken_multiprocessing_synchronize, verbose, + requires_subprocess, is_android, is_apple_mobile, + is_emscripten, is_wasi, + requires_venv_with_pip, TEST_HOME_DIR, + requires_resource, copy_python_src_ignore) +from test.support.os_helper import (can_symlink, EnvironmentVarGuard, rmtree, + TESTFN, FakePath) +from test.support.testcase import ExtraAssertions import unittest import venv -from unittest.mock import patch +from unittest.mock import patch, Mock try: import ctypes @@ -33,18 +43,29 @@ or sys._base_executable != sys.executable, 'cannot run venv.create from within a venv on this platform') +if is_android or is_apple_mobile or is_emscripten or is_wasi: + raise unittest.SkipTest("venv is not available on this platform") + 
+@requires_subprocess() def check_output(cmd, encoding=None): p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - encoding=encoding) + env={**os.environ, "PYTHONHOME": ""}) out, err = p.communicate() if p.returncode: + if verbose and err: + print(err.decode(encoding or 'utf-8', 'backslashreplace')) raise subprocess.CalledProcessError( p.returncode, cmd, out, err) + if encoding: + return ( + out.decode(encoding, 'backslashreplace'), + err.decode(encoding, 'backslashreplace'), + ) return out, err -class BaseTest(unittest.TestCase): +class BaseTest(unittest.TestCase, ExtraAssertions): """Base class for venv tests.""" maxDiff = 80 * 50 @@ -56,7 +77,7 @@ def setUp(self): self.include = 'Include' else: self.bindir = 'bin' - self.lib = ('lib', 'python%d.%d' % sys.version_info[:2]) + self.lib = ('lib', f'python{sysconfig._get_python_version_abi()}') self.include = 'include' executable = sys._base_executable self.exe = os.path.split(executable)[-1] @@ -70,6 +91,13 @@ def setUp(self): def tearDown(self): rmtree(self.env_dir) + def envpy(self, *, real_env_dir=False): + if real_env_dir: + env_dir = os.path.realpath(self.env_dir) + else: + env_dir = self.env_dir + return os.path.join(env_dir, self.bindir, self.exe) + def run_with_capture(self, func, *args, **kwargs): with captured_stdout() as output: with captured_stderr() as error: @@ -91,12 +119,27 @@ def isdir(self, *args): fn = self.get_env_file(*args) self.assertTrue(os.path.isdir(fn)) - def test_defaults(self): + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_defaults_with_str_path(self): """ - Test the create function with default arguments. + Test the create function with default arguments and a str path. """ rmtree(self.env_dir) self.run_with_capture(venv.create, self.env_dir) + self._check_output_of_default_create() + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_defaults_with_pathlike(self): + """ + Test the create function with default arguments and a path-like path. 
+ """ + rmtree(self.env_dir) + self.run_with_capture(venv.create, FakePath(self.env_dir)) + self._check_output_of_default_create() + + def _check_output_of_default_create(self): self.isdir(self.bindir) self.isdir(self.include) self.isdir(*self.lib) @@ -112,6 +155,12 @@ def test_defaults(self): executable = sys._base_executable path = os.path.dirname(executable) self.assertIn('home = %s' % path, data) + self.assertIn('executable = %s' % + os.path.realpath(sys.executable), data) + copies = '' if os.name=='nt' else ' --copies' + cmd = (f'command = {sys.executable} -m venv{copies} --without-pip ' + f'--without-scm-ignore-files {self.env_dir}') + self.assertIn(cmd, data) fn = self.get_env_file(self.bindir, self.exe) if not os.path.exists(fn): # diagnostics for Windows buildbot failures bd = self.get_env_file(self.bindir) @@ -119,6 +168,39 @@ def test_defaults(self): print(' %r' % os.listdir(bd)) self.assertTrue(os.path.exists(fn), 'File %r should exist.' % fn) + def test_config_file_command_key(self): + options = [ + (None, None, None), # Default case. 
+ ('--copies', 'symlinks', False), + ('--without-pip', 'with_pip', False), + ('--system-site-packages', 'system_site_packages', True), + ('--clear', 'clear', True), + ('--upgrade', 'upgrade', True), + ('--upgrade-deps', 'upgrade_deps', True), + ('--prompt="foobar"', 'prompt', 'foobar'), + ('--without-scm-ignore-files', 'scm_ignore_files', frozenset()), + ] + for opt, attr, value in options: + with self.subTest(opt=opt, attr=attr, value=value): + rmtree(self.env_dir) + if not attr: + kwargs = {} + else: + kwargs = {attr: value} + b = venv.EnvBuilder(**kwargs) + b.upgrade_dependencies = Mock() # avoid pip command to upgrade deps + b._setup_pip = Mock() # avoid pip setup + self.run_with_capture(b.create, self.env_dir) + data = self.get_text_file_contents('pyvenv.cfg') + if not attr or opt.endswith('git'): + for opt in ('--system-site-packages', '--clear', '--upgrade', + '--upgrade-deps', '--prompt'): + self.assertNotRegex(data, rf'command = .* {opt}') + elif os.name=='nt' and attr=='symlinks': + pass + else: + self.assertRegex(data, rf'command = .* {opt}') + def test_prompt(self): env_name = os.path.split(self.env_dir)[1] @@ -127,7 +209,7 @@ def test_prompt(self): self.run_with_capture(builder.create, self.env_dir) context = builder.ensure_directories(self.env_dir) data = self.get_text_file_contents('pyvenv.cfg') - self.assertEqual(context.prompt, '(%s) ' % env_name) + self.assertEqual(context.prompt, env_name) self.assertNotIn("prompt = ", data) rmtree(self.env_dir) @@ -135,7 +217,7 @@ def test_prompt(self): self.run_with_capture(builder.create, self.env_dir) context = builder.ensure_directories(self.env_dir) data = self.get_text_file_contents('pyvenv.cfg') - self.assertEqual(context.prompt, '(My prompt) ') + self.assertEqual(context.prompt, 'My prompt') self.assertIn("prompt = 'My prompt'\n", data) rmtree(self.env_dir) @@ -144,13 +226,19 @@ def test_prompt(self): self.run_with_capture(builder.create, self.env_dir) context = builder.ensure_directories(self.env_dir) 
data = self.get_text_file_contents('pyvenv.cfg') - self.assertEqual(context.prompt, '(%s) ' % cwd) + self.assertEqual(context.prompt, cwd) self.assertIn("prompt = '%s'\n" % cwd, data) def test_upgrade_dependencies(self): builder = venv.EnvBuilder() - bin_path = 'Scripts' if sys.platform == 'win32' else 'bin' + bin_path = 'bin' python_exe = os.path.split(sys.executable)[1] + if sys.platform == 'win32': + bin_path = 'Scripts' + if os.path.normcase(os.path.splitext(python_exe)[0]).endswith('_d'): + python_exe = 'python_d.exe' + else: + python_exe = 'python.exe' with tempfile.TemporaryDirectory() as fake_env_dir: expect_exe = os.path.normcase( os.path.join(fake_env_dir, bin_path, python_exe) @@ -158,7 +246,7 @@ def test_upgrade_dependencies(self): if sys.platform == 'win32': expect_exe = os.path.normcase(os.path.realpath(expect_exe)) - def pip_cmd_checker(cmd): + def pip_cmd_checker(cmd, **kwargs): cmd[0] = os.path.normcase(cmd[0]) self.assertEqual( cmd, @@ -169,12 +257,11 @@ def pip_cmd_checker(cmd): 'install', '--upgrade', 'pip', - 'setuptools' ] ) fake_context = builder.ensure_directories(fake_env_dir) - with patch('venv.subprocess.check_call', pip_cmd_checker): + with patch('venv.subprocess.check_output', pip_cmd_checker): builder.upgrade_dependencies(fake_context) @requireVenvCreate @@ -185,8 +272,7 @@ def test_prefixes(self): # check a venv's prefixes rmtree(self.env_dir) self.run_with_capture(venv.create, self.env_dir) - envpy = os.path.join(self.env_dir, self.bindir, self.exe) - cmd = [envpy, '-c', None] + cmd = [self.envpy(), '-c', None] for prefix, expected in ( ('prefix', self.env_dir), ('exec_prefix', self.env_dir), @@ -194,7 +280,76 @@ def test_prefixes(self): ('base_exec_prefix', sys.base_exec_prefix)): cmd[2] = 'import sys; print(sys.%s)' % prefix out, err = check_output(cmd) - self.assertEqual(out.strip(), expected.encode()) + self.assertEqual(pathlib.Path(out.strip().decode()), + pathlib.Path(expected), prefix) + + @requireVenvCreate + def 
test_sysconfig(self): + """ + Test that the sysconfig functions work in a virtual environment. + """ + rmtree(self.env_dir) + self.run_with_capture(venv.create, self.env_dir, symlinks=False) + cmd = [self.envpy(), '-c', None] + for call, expected in ( + # installation scheme + ('get_preferred_scheme("prefix")', 'venv'), + ('get_default_scheme()', 'venv'), + # build environment + ('is_python_build()', str(sysconfig.is_python_build())), + ('get_makefile_filename()', sysconfig.get_makefile_filename()), + ('get_config_h_filename()', sysconfig.get_config_h_filename()), + ('get_config_var("Py_GIL_DISABLED")', + str(sysconfig.get_config_var("Py_GIL_DISABLED")))): + with self.subTest(call): + cmd[2] = 'import sysconfig; print(sysconfig.%s)' % call + out, err = check_output(cmd, encoding='utf-8') + self.assertEqual(out.strip(), expected, err) + for attr, expected in ( + ('executable', self.envpy()), + # Usually compare to sys.executable, but if we're running in our own + # venv then we really need to compare to our base executable + ('_base_executable', sys._base_executable), + ): + with self.subTest(attr): + cmd[2] = f'import sys; print(sys.{attr})' + out, err = check_output(cmd, encoding='utf-8') + self.assertEqual(out.strip(), expected, err) + + @requireVenvCreate + @unittest.skipUnless(can_symlink(), 'Needs symlinks') + def test_sysconfig_symlinks(self): + """ + Test that the sysconfig functions work in a virtual environment. 
+ """ + rmtree(self.env_dir) + self.run_with_capture(venv.create, self.env_dir, symlinks=True) + cmd = [self.envpy(), '-c', None] + for call, expected in ( + # installation scheme + ('get_preferred_scheme("prefix")', 'venv'), + ('get_default_scheme()', 'venv'), + # build environment + ('is_python_build()', str(sysconfig.is_python_build())), + ('get_makefile_filename()', sysconfig.get_makefile_filename()), + ('get_config_h_filename()', sysconfig.get_config_h_filename()), + ('get_config_var("Py_GIL_DISABLED")', + str(sysconfig.get_config_var("Py_GIL_DISABLED")))): + with self.subTest(call): + cmd[2] = 'import sysconfig; print(sysconfig.%s)' % call + out, err = check_output(cmd, encoding='utf-8') + self.assertEqual(out.strip(), expected, err) + for attr, expected in ( + ('executable', self.envpy()), + # Usually compare to sys.executable, but if we're running in our own + # venv then we really need to compare to our base executable + # HACK: Test fails on POSIX with unversioned binary (PR gh-113033) + #('_base_executable', sys._base_executable), + ): + with self.subTest(attr): + cmd[2] = f'import sys; print(sys.{attr})' + out, err = check_output(cmd, encoding='utf-8') + self.assertEqual(out.strip(), expected, err) if sys.platform == 'win32': ENV_SUBDIRS = ( @@ -259,6 +414,8 @@ def test_unoverwritable_fails(self): self.assertRaises((ValueError, OSError), venv.create, self.env_dir) self.clear_directory(self.env_dir) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_upgrade(self): """ Test upgrading an existing environment directory. 
@@ -321,8 +478,7 @@ def test_executable(self): """ rmtree(self.env_dir) self.run_with_capture(venv.create, self.env_dir) - envpy = os.path.join(os.path.realpath(self.env_dir), - self.bindir, self.exe) + envpy = self.envpy(real_env_dir=True) out, err = check_output([envpy, '-c', 'import sys; print(sys.executable)']) self.assertEqual(out.strip(), envpy.encode()) @@ -335,12 +491,89 @@ def test_executable_symlinks(self): rmtree(self.env_dir) builder = venv.EnvBuilder(clear=True, symlinks=True) builder.create(self.env_dir) - envpy = os.path.join(os.path.realpath(self.env_dir), - self.bindir, self.exe) + envpy = self.envpy(real_env_dir=True) out, err = check_output([envpy, '-c', 'import sys; print(sys.executable)']) self.assertEqual(out.strip(), envpy.encode()) + # gh-124651: test quoted strings + @unittest.skipIf(os.name == 'nt', 'contains invalid characters on Windows') + def test_special_chars_bash(self): + """ + Test that the template strings are quoted properly (bash) + """ + rmtree(self.env_dir) + bash = shutil.which('bash') + if bash is None: + self.skipTest('bash required for this test') + env_name = '"\';&&$e|\'"' + env_dir = os.path.join(os.path.realpath(self.env_dir), env_name) + builder = venv.EnvBuilder(clear=True) + builder.create(env_dir) + activate = os.path.join(env_dir, self.bindir, 'activate') + test_script = os.path.join(self.env_dir, 'test_special_chars.sh') + with open(test_script, "w") as f: + f.write(f'source {shlex.quote(activate)}\n' + 'python -c \'import sys; print(sys.executable)\'\n' + 'python -c \'import os; print(os.environ["VIRTUAL_ENV"])\'\n' + 'deactivate\n') + out, err = check_output([bash, test_script]) + lines = out.splitlines() + self.assertTrue(env_name.encode() in lines[0]) + self.assertEndsWith(lines[1], env_name.encode()) + + # gh-124651: test quoted strings + @unittest.skipIf(os.name == 'nt', 'contains invalid characters on Windows') + @unittest.skipIf(sys.platform.startswith('netbsd'), + "NetBSD csh fails with quoted special 
chars; see gh-139308") + def test_special_chars_csh(self): + """ + Test that the template strings are quoted properly (csh) + """ + rmtree(self.env_dir) + csh = shutil.which('tcsh') or shutil.which('csh') + if csh is None: + self.skipTest('csh required for this test') + env_name = '"\';&&$e|\'"' + env_dir = os.path.join(os.path.realpath(self.env_dir), env_name) + builder = venv.EnvBuilder(clear=True) + builder.create(env_dir) + activate = os.path.join(env_dir, self.bindir, 'activate.csh') + test_script = os.path.join(self.env_dir, 'test_special_chars.csh') + with open(test_script, "w") as f: + f.write(f'source {shlex.quote(activate)}\n' + 'python -c \'import sys; print(sys.executable)\'\n' + 'python -c \'import os; print(os.environ["VIRTUAL_ENV"])\'\n' + 'deactivate\n') + out, err = check_output([csh, test_script]) + lines = out.splitlines() + self.assertTrue(env_name.encode() in lines[0]) + self.assertEndsWith(lines[1], env_name.encode()) + + # gh-124651: test quoted strings on Windows + @unittest.skipUnless(os.name == 'nt', 'only relevant on Windows') + def test_special_chars_windows(self): + """ + Test that the template strings are quoted properly on Windows + """ + rmtree(self.env_dir) + env_name = "'&&^$e" + env_dir = os.path.join(os.path.realpath(self.env_dir), env_name) + builder = venv.EnvBuilder(clear=True) + builder.create(env_dir) + activate = os.path.join(env_dir, self.bindir, 'activate.bat') + test_batch = os.path.join(self.env_dir, 'test_special_chars.bat') + with open(test_batch, "w") as f: + f.write('@echo off\n' + f'"{activate}" & ' + f'{self.exe} -c "import sys; print(sys.executable)" & ' + f'{self.exe} -c "import os; print(os.environ[\'VIRTUAL_ENV\'])" & ' + 'deactivate') + out, err = check_output([test_batch]) + lines = out.splitlines() + self.assertTrue(env_name.encode() in lines[0]) + self.assertEndsWith(lines[1], env_name.encode()) + @unittest.skipUnless(os.name == 'nt', 'only relevant on Windows') def test_unicode_in_batch_file(self): """ @@ 
-351,13 +584,27 @@ def test_unicode_in_batch_file(self): builder = venv.EnvBuilder(clear=True) builder.create(env_dir) activate = os.path.join(env_dir, self.bindir, 'activate.bat') - envpy = os.path.join(env_dir, self.bindir, self.exe) out, err = check_output( [activate, '&', self.exe, '-c', 'print(0)'], encoding='oem', ) self.assertEqual(out.strip(), '0') + @unittest.skipUnless(os.name == 'nt' and can_symlink(), + 'symlinks on Windows') + def test_failed_symlink(self): + """ + Test handling of failed symlinks on Windows. + """ + rmtree(self.env_dir) + env_dir = os.path.join(os.path.realpath(self.env_dir), 'venv') + with patch('os.symlink') as mock_symlink: + mock_symlink.side_effect = OSError() + builder = venv.EnvBuilder(clear=True, symlinks=True) + _, err = self.run_with_capture(builder.create, env_dir) + filepath_regex = r"'[A-Z]:\\\\(?:[^\\\\]+\\\\)*[^\\\\]+'" + self.assertRegex(err, rf"Unable to symlink {filepath_regex} to {filepath_regex}") + @requireVenvCreate def test_multiprocessing(self): """ @@ -370,15 +617,25 @@ def test_multiprocessing(self): rmtree(self.env_dir) self.run_with_capture(venv.create, self.env_dir) - envpy = os.path.join(os.path.realpath(self.env_dir), - self.bindir, self.exe) - out, err = check_output([envpy, '-c', + out, err = check_output([self.envpy(real_env_dir=True), '-c', 'from multiprocessing import Pool; ' 'pool = Pool(1); ' 'print(pool.apply_async("Python".lower).get(3)); ' 'pool.terminate()']) self.assertEqual(out.strip(), "python".encode()) + @requireVenvCreate + def test_multiprocessing_recursion(self): + """ + Test that the multiprocessing is able to spawn itself + """ + skip_if_broken_multiprocessing_synchronize() + + rmtree(self.env_dir) + self.run_with_capture(venv.create, self.env_dir) + script = os.path.join(TEST_HOME_DIR, '_test_venv_multiprocessing.py') + subprocess.check_call([self.envpy(real_env_dir=True), "-I", script]) + @unittest.skipIf(os.name == 'nt', 'not relevant on Windows') def 
test_deactivate_with_strict_bash_opts(self): bash = shutil.which("bash") @@ -404,19 +661,250 @@ def test_macos_env(self): builder = venv.EnvBuilder() builder.create(self.env_dir) - envpy = os.path.join(os.path.realpath(self.env_dir), - self.bindir, self.exe) - out, err = check_output([envpy, '-c', + out, err = check_output([self.envpy(real_env_dir=True), '-c', 'import os; print("__PYVENV_LAUNCHER__" in os.environ)']) self.assertEqual(out.strip(), 'False'.encode()) + def test_pathsep_error(self): + """ + Test that venv creation fails when the target directory contains + the path separator. + """ + rmtree(self.env_dir) + bad_itempath = self.env_dir + os.pathsep + self.assertRaises(ValueError, venv.create, bad_itempath) + self.assertRaises(ValueError, venv.create, FakePath(bad_itempath)) + + @unittest.skipIf(os.name == 'nt', 'not relevant on Windows') + @requireVenvCreate + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_zippath_from_non_installed_posix(self): + """ + Test that when create venv from non-installed python, the zip path + value is as expected. + """ + rmtree(self.env_dir) + # First try to create a non-installed python. It's not a real full + # functional non-installed python, but enough for this test. 
+ platlibdir = sys.platlibdir + non_installed_dir = os.path.realpath(tempfile.mkdtemp()) + self.addCleanup(rmtree, non_installed_dir) + bindir = os.path.join(non_installed_dir, self.bindir) + os.mkdir(bindir) + shutil.copy2(sys.executable, bindir) + libdir = os.path.join(non_installed_dir, platlibdir, self.lib[1]) + os.makedirs(libdir) + landmark = os.path.join(libdir, "os.py") + abi_thread = "t" if sysconfig.get_config_var("Py_GIL_DISABLED") else "" + stdlib_zip = f"python{sys.version_info.major}{sys.version_info.minor}{abi_thread}" + zip_landmark = os.path.join(non_installed_dir, + platlibdir, + stdlib_zip) + additional_pythonpath_for_non_installed = [] + + # Copy stdlib files to the non-installed python so venv can + # correctly calculate the prefix. + for eachpath in sys.path: + if eachpath.endswith(".zip"): + if os.path.isfile(eachpath): + shutil.copyfile( + eachpath, + os.path.join(non_installed_dir, platlibdir)) + elif os.path.isfile(os.path.join(eachpath, "os.py")): + names = os.listdir(eachpath) + ignored_names = copy_python_src_ignore(eachpath, names) + for name in names: + if name in ignored_names: + continue + if name == "site-packages": + continue + fn = os.path.join(eachpath, name) + if os.path.isfile(fn): + shutil.copy(fn, libdir) + elif os.path.isdir(fn): + shutil.copytree(fn, os.path.join(libdir, name), + ignore=copy_python_src_ignore) + else: + additional_pythonpath_for_non_installed.append( + eachpath) + cmd = [os.path.join(non_installed_dir, self.bindir, self.exe), + "-m", + "venv", + "--without-pip", + "--without-scm-ignore-files", + self.env_dir] + # Our fake non-installed python is not fully functional because + # it cannot find the extensions. Set PYTHONPATH so it can run the + # venv module correctly. + pythonpath = os.pathsep.join( + additional_pythonpath_for_non_installed) + # For python built with shared enabled. 
We need to set + # LD_LIBRARY_PATH so the non-installed python can find and link + # libpython.so + ld_library_path = sysconfig.get_config_var("LIBDIR") + if not ld_library_path or sysconfig.is_python_build(): + ld_library_path = os.path.abspath(os.path.dirname(sys.executable)) + if sys.platform == 'darwin': + ld_library_path_env = "DYLD_LIBRARY_PATH" + else: + ld_library_path_env = "LD_LIBRARY_PATH" + child_env = { + "PYTHONPATH": pythonpath, + ld_library_path_env: ld_library_path, + } + if asan_options := os.environ.get("ASAN_OPTIONS"): + # prevent https://github.com/python/cpython/issues/104839 + child_env["ASAN_OPTIONS"] = asan_options + subprocess.check_call(cmd, env=child_env) + # Now check the venv created from the non-installed python has + # correct zip path in pythonpath. + cmd = [self.envpy(), '-S', '-c', 'import sys; print(sys.path)'] + out, err = check_output(cmd) + self.assertTrue(zip_landmark.encode() in out) + + @requireVenvCreate + def test_activate_shell_script_has_no_dos_newlines(self): + """ + Test that the `activate` shell script contains no CR LF. + This is relevant for Cygwin, as the Windows build might have + converted line endings accidentally. + """ + venv_dir = pathlib.Path(self.env_dir) + rmtree(venv_dir) + [[scripts_dir], *_] = self.ENV_SUBDIRS + script_path = venv_dir / scripts_dir / "activate" + venv.create(venv_dir) + with open(script_path, 'rb') as script: + for i, line in enumerate(script, 1): + error_message = f"CR LF found in line {i}" + self.assertFalse(line.endswith(b'\r\n'), error_message) + + @requireVenvCreate + def test_scm_ignore_files_git(self): + """ + Test that a .gitignore file is created when "git" is specified. + The file should contain a `*\n` line. 
+ """ + self.run_with_capture(venv.create, self.env_dir, + scm_ignore_files={'git'}) + file_lines = self.get_text_file_contents('.gitignore').splitlines() + self.assertIn('*', file_lines) + + @requireVenvCreate + def test_create_scm_ignore_files_multiple(self): + """ + Test that ``scm_ignore_files`` can work with multiple SCMs. + """ + bzrignore_name = ".bzrignore" + contents = "# For Bazaar.\n*\n" + + class BzrEnvBuilder(venv.EnvBuilder): + def create_bzr_ignore_file(self, context): + gitignore_path = os.path.join(context.env_dir, bzrignore_name) + with open(gitignore_path, 'w', encoding='utf-8') as file: + file.write(contents) + + builder = BzrEnvBuilder(scm_ignore_files={'git', 'bzr'}) + self.run_with_capture(builder.create, self.env_dir) + + gitignore_lines = self.get_text_file_contents('.gitignore').splitlines() + self.assertIn('*', gitignore_lines) + + bzrignore = self.get_text_file_contents(bzrignore_name) + self.assertEqual(bzrignore, contents) + + @requireVenvCreate + def test_create_scm_ignore_files_empty(self): + """ + Test that no default ignore files are created when ``scm_ignore_files`` + is empty. + """ + # scm_ignore_files is set to frozenset() by default. + self.run_with_capture(venv.create, self.env_dir) + with self.assertRaises(FileNotFoundError): + self.get_text_file_contents('.gitignore') + + self.assertIn("--without-scm-ignore-files", + self.get_text_file_contents('pyvenv.cfg')) + + @requireVenvCreate + def test_cli_with_scm_ignore_files(self): + """ + Test that default SCM ignore files are created by default via the CLI. + """ + self.run_with_capture(venv.main, ['--without-pip', self.env_dir]) + + gitignore_lines = self.get_text_file_contents('.gitignore').splitlines() + self.assertIn('*', gitignore_lines) + + @requireVenvCreate + def test_cli_without_scm_ignore_files(self): + """ + Test that ``--without-scm-ignore-files`` doesn't create SCM ignore files. 
+ """ + args = ['--without-pip', '--without-scm-ignore-files', self.env_dir] + self.run_with_capture(venv.main, args) + + with self.assertRaises(FileNotFoundError): + self.get_text_file_contents('.gitignore') + + def test_venv_same_path(self): + same_path = venv.EnvBuilder._same_path + if sys.platform == 'win32': + # Case-insensitive, and handles short/long names + tests = [ + (True, TESTFN, TESTFN), + (True, TESTFN.lower(), TESTFN.upper()), + ] + import _winapi + # ProgramFiles is the most reliable path that will have short/long + progfiles = os.getenv('ProgramFiles') + if progfiles: + tests = [ + *tests, + (True, progfiles, progfiles), + (True, _winapi.GetShortPathName(progfiles), _winapi.GetLongPathName(progfiles)), + ] + else: + # Just a simple case-sensitive comparison + tests = [ + (True, TESTFN, TESTFN), + (False, TESTFN.lower(), TESTFN.upper()), + ] + for r, path1, path2 in tests: + with self.subTest(f"{path1}-{path2}"): + if r: + self.assertTrue(same_path(path1, path2)) + else: + self.assertFalse(same_path(path1, path2)) + + # gh-126084: venvwlauncher should run pythonw, not python + @requireVenvCreate + @unittest.skipUnless(os.name == 'nt', 'only relevant on Windows') + def test_venvwlauncher(self): + """ + Test that the GUI launcher runs the GUI python. 
+ """ + rmtree(self.env_dir) + venv.create(self.env_dir) + exename = self.exe + # Retain the debug suffix if present + if "python" in exename and not "pythonw" in exename: + exename = exename.replace("python", "pythonw") + envpyw = os.path.join(self.env_dir, self.bindir, exename) + try: + subprocess.check_call([envpyw, "-c", "import sys; " + "assert sys._base_executable.endswith('%s')" % exename]) + except subprocess.CalledProcessError: + self.fail("venvwlauncher.exe did not run %s" % exename) + + @requireVenvCreate class EnsurePipTest(BaseTest): """Test venv module installation of pip.""" def assert_pip_not_installed(self): - envpy = os.path.join(os.path.realpath(self.env_dir), - self.bindir, self.exe) - out, err = check_output([envpy, '-c', + out, err = check_output([self.envpy(real_env_dir=True), '-c', 'try:\n import pip\nexcept ImportError:\n print("OK")']) # We force everything to text, so unittest gives the detailed diff # if we get unexpected results @@ -478,20 +966,14 @@ def do_test_with_pip(self, system_site_packages): # Actually run the create command with all that unhelpful # config in place to ensure we ignore it - try: + with self.nicer_error(): self.run_with_capture(venv.create, self.env_dir, system_site_packages=system_site_packages, with_pip=True) - except subprocess.CalledProcessError as exc: - # The output this produces can be a little hard to read, - # but at least it has all the details - details = exc.output.decode(errors="replace") - msg = "{}\n\n**Subprocess Output**\n{}" - self.fail(msg.format(exc, details)) # Ensure pip is available in the virtual environment - envpy = os.path.join(os.path.realpath(self.env_dir), self.bindir, self.exe) # Ignore DeprecationWarning since pip code is not part of Python - out, err = check_output([envpy, '-W', 'ignore::DeprecationWarning', + out, err = check_output([self.envpy(real_env_dir=True), + '-W', 'ignore::DeprecationWarning', '-W', 'ignore::ImportWarning', '-I', '-m', 'pip', '--version']) # We force 
everything to text, so unittest gives the detailed diff @@ -508,13 +990,14 @@ def do_test_with_pip(self, system_site_packages): # Check the private uninstall command provided for the Windows # installers works (at least in a virtual environment) with EnvironmentVarGuard() as envvars: - # It seems ensurepip._uninstall calls subprocesses which do not - # inherit the interpreter settings. - envvars["PYTHONWARNINGS"] = "ignore" - out, err = check_output([envpy, - '-W', 'ignore::DeprecationWarning', - '-W', 'ignore::ImportWarning', '-I', - '-m', 'ensurepip._uninstall']) + with self.nicer_error(): + # It seems ensurepip._uninstall calls subprocesses which do not + # inherit the interpreter settings. + envvars["PYTHONWARNINGS"] = "ignore" + out, err = check_output([self.envpy(real_env_dir=True), + '-W', 'ignore::DeprecationWarning', + '-W', 'ignore::ImportWarning', '-I', + '-m', 'ensurepip._uninstall']) # We force everything to text, so unittest gives the detailed diff # if we get unexpected results err = err.decode("latin-1") # Force to text, prevent decoding errors @@ -527,25 +1010,51 @@ def do_test_with_pip(self, system_site_packages): err = re.sub("^(WARNING: )?The directory .* or its parent directory " "is not owned or is not writable by the current user.*$", "", err, flags=re.MULTILINE) + # Ignore warning about missing optional module: + try: + import ssl + except ImportError: + err = re.sub( + "^WARNING: Disabling truststore since ssl support is missing$", + "", + err, flags=re.MULTILINE) self.assertEqual(err.rstrip(), "") # Being fairly specific regarding the expected behaviour for the # initial bundling phase in Python 3.4. If the output changes in # future pip versions, this test can likely be relaxed further. out = out.decode("latin-1") # Force to text, prevent decoding errors self.assertIn("Successfully uninstalled pip", out) - self.assertIn("Successfully uninstalled setuptools", out) # Check pip is now gone from the virtual environment. 
This only # applies in the system_site_packages=False case, because in the # other case, pip may still be available in the system site-packages if not system_site_packages: self.assert_pip_not_installed() - # Issue #26610: pip/pep425tags.py requires ctypes - @unittest.skipUnless(ctypes, 'pip requires ctypes') - @requires_zlib() + @contextlib.contextmanager + def nicer_error(self): + """ + Capture output from a failed subprocess for easier debugging. + + The output this handler produces can be a little hard to read, + but at least it has all the details. + """ + try: + yield + except subprocess.CalledProcessError as exc: + out = (exc.output or b'').decode(errors="replace") + err = (exc.stderr or b'').decode(errors="replace") + self.fail( + f"{exc}\n\n" + f"**Subprocess Output**\n{out}\n\n" + f"**Subprocess Error**\n{err}" + ) + + @requires_venv_with_pip() + @requires_resource('cpu') def test_with_pip(self): self.do_test_with_pip(False) self.do_test_with_pip(True) + if __name__ == "__main__": unittest.main() diff --git a/Lib/venv/__init__.py b/Lib/venv/__init__.py index 6f1af294ae6..f7a6d261401 100644 --- a/Lib/venv/__init__.py +++ b/Lib/venv/__init__.py @@ -11,9 +11,10 @@ import sys import sysconfig import types +import shlex -CORE_VENV_DEPS = ('pip', 'setuptools') +CORE_VENV_DEPS = ('pip',) logger = logging.getLogger(__name__) @@ -41,20 +42,24 @@ class EnvBuilder: environment :param prompt: Alternative terminal prefix for the environment. :param upgrade_deps: Update the base venv modules to the latest on PyPI + :param scm_ignore_files: Create ignore files for the SCMs specified by the + iterable. 
""" def __init__(self, system_site_packages=False, clear=False, symlinks=False, upgrade=False, with_pip=False, prompt=None, - upgrade_deps=False): + upgrade_deps=False, *, scm_ignore_files=frozenset()): self.system_site_packages = system_site_packages self.clear = clear self.symlinks = symlinks self.upgrade = upgrade self.with_pip = with_pip + self.orig_prompt = prompt if prompt == '.': # see bpo-38901 prompt = os.path.basename(os.getcwd()) self.prompt = prompt self.upgrade_deps = upgrade_deps + self.scm_ignore_files = frozenset(map(str.lower, scm_ignore_files)) def create(self, env_dir): """ @@ -65,6 +70,8 @@ def create(self, env_dir): """ env_dir = os.path.abspath(env_dir) context = self.ensure_directories(env_dir) + for scm in self.scm_ignore_files: + getattr(self, f"create_{scm}_ignore_file")(context) # See issue 24875. We need system_site_packages to be False # until after pip is installed. true_system_site_packages = self.system_site_packages @@ -92,6 +99,42 @@ def clear_directory(self, path): elif os.path.isdir(fn): shutil.rmtree(fn) + def _venv_path(self, env_dir, name): + vars = { + 'base': env_dir, + 'platbase': env_dir, + 'installed_base': env_dir, + 'installed_platbase': env_dir, + } + return sysconfig.get_path(name, scheme='venv', vars=vars) + + @classmethod + def _same_path(cls, path1, path2): + """Check whether two paths appear the same. + + Whether they refer to the same file is irrelevant; we're testing for + whether a human reader would look at the path string and easily tell + that they're the same file. 
+ """ + if sys.platform == 'win32': + if os.path.normcase(path1) == os.path.normcase(path2): + return True + # gh-90329: Don't display a warning for short/long names + import _winapi + try: + path1 = _winapi.GetLongPathName(os.fsdecode(path1)) + except OSError: + pass + try: + path2 = _winapi.GetLongPathName(os.fsdecode(path2)) + except OSError: + pass + if os.path.normcase(path1) == os.path.normcase(path2): + return True + return False + else: + return path1 == path2 + def ensure_directories(self, env_dir): """ Create the directories for the environment. @@ -106,31 +149,38 @@ def create_if_needed(d): elif os.path.islink(d) or os.path.isfile(d): raise ValueError('Unable to create directory %r' % d) + if os.pathsep in os.fspath(env_dir): + raise ValueError(f'Refusing to create a venv in {env_dir} because ' + f'it contains the PATH separator {os.pathsep}.') if os.path.exists(env_dir) and self.clear: self.clear_directory(env_dir) context = types.SimpleNamespace() context.env_dir = env_dir context.env_name = os.path.split(env_dir)[1] - prompt = self.prompt if self.prompt is not None else context.env_name - context.prompt = '(%s) ' % prompt + context.prompt = self.prompt if self.prompt is not None else context.env_name create_if_needed(env_dir) executable = sys._base_executable + if not executable: # see gh-96861 + raise ValueError('Unable to determine path to the running ' + 'Python interpreter. Provide an explicit path or ' + 'check that your PATH environment variable is ' + 'correctly set.') dirname, exename = os.path.split(os.path.abspath(executable)) + if sys.platform == 'win32': + # Always create the simplest name in the venv. 
It will either be a + # link back to executable, or a copy of the appropriate launcher + _d = '_d' if os.path.splitext(exename)[0].endswith('_d') else '' + exename = f'python{_d}.exe' context.executable = executable context.python_dir = dirname context.python_exe = exename - if sys.platform == 'win32': - binname = 'Scripts' - incpath = 'Include' - libpath = os.path.join(env_dir, 'Lib', 'site-packages') - else: - binname = 'bin' - incpath = 'include' - libpath = os.path.join(env_dir, 'lib', - 'python%d.%d' % sys.version_info[:2], - 'site-packages') - context.inc_path = path = os.path.join(env_dir, incpath) - create_if_needed(path) + binpath = self._venv_path(env_dir, 'scripts') + incpath = self._venv_path(env_dir, 'include') + libpath = self._venv_path(env_dir, 'purelib') + + context.inc_path = incpath + create_if_needed(incpath) + context.lib_path = libpath create_if_needed(libpath) # Issue 21197: create lib64 as a symlink to lib on 64-bit non-OS X POSIX if ((sys.maxsize > 2**32) and (os.name == 'posix') and @@ -138,8 +188,8 @@ def create_if_needed(d): link_path = os.path.join(env_dir, 'lib64') if not os.path.exists(link_path): # Issue #21643 os.symlink('lib', link_path) - context.bin_path = binpath = os.path.join(env_dir, binname) - context.bin_name = binname + context.bin_path = binpath + context.bin_name = os.path.relpath(binpath, env_dir) context.env_exe = os.path.join(binpath, exename) create_if_needed(binpath) # Assign and update the command to use when launching the newly created @@ -149,7 +199,7 @@ def create_if_needed(d): # bpo-45337: Fix up env_exec_cmd to account for file system redirections. 
# Some redirects only apply to CreateFile and not CreateProcess real_env_exe = os.path.realpath(context.env_exe) - if os.path.normcase(real_env_exe) != os.path.normcase(context.env_exe): + if not self._same_path(real_env_exe, context.env_exe): logger.warning('Actual environment location may have moved due to ' 'redirects, links or junctions.\n' ' Requested location: "%s"\n' @@ -178,86 +228,84 @@ def create_configuration(self, context): f.write('version = %d.%d.%d\n' % sys.version_info[:3]) if self.prompt is not None: f.write(f'prompt = {self.prompt!r}\n') - - if os.name != 'nt': - def symlink_or_copy(self, src, dst, relative_symlinks_ok=False): - """ - Try symlinking a file, and if that fails, fall back to copying. - """ - force_copy = not self.symlinks - if not force_copy: - try: - if not os.path.islink(dst): # can't link to itself! - if relative_symlinks_ok: - assert os.path.dirname(src) == os.path.dirname(dst) - os.symlink(os.path.basename(src), dst) - else: - os.symlink(src, dst) - except Exception: # may need to use a more specific exception - logger.warning('Unable to symlink %r to %r', src, dst) - force_copy = True - if force_copy: - shutil.copyfile(src, dst) - else: - def symlink_or_copy(self, src, dst, relative_symlinks_ok=False): - """ - Try symlinking a file, and if that fails, fall back to copying. 
- """ - bad_src = os.path.lexists(src) and not os.path.exists(src) - if self.symlinks and not bad_src and not os.path.islink(dst): - try: + f.write('executable = %s\n' % os.path.realpath(sys.executable)) + args = [] + nt = os.name == 'nt' + if nt and self.symlinks: + args.append('--symlinks') + if not nt and not self.symlinks: + args.append('--copies') + if not self.with_pip: + args.append('--without-pip') + if self.system_site_packages: + args.append('--system-site-packages') + if self.clear: + args.append('--clear') + if self.upgrade: + args.append('--upgrade') + if self.upgrade_deps: + args.append('--upgrade-deps') + if self.orig_prompt is not None: + args.append(f'--prompt="{self.orig_prompt}"') + if not self.scm_ignore_files: + args.append('--without-scm-ignore-files') + + args.append(context.env_dir) + args = ' '.join(args) + f.write(f'command = {sys.executable} -m venv {args}\n') + + def symlink_or_copy(self, src, dst, relative_symlinks_ok=False): + """ + Try symlinking a file, and if that fails, fall back to copying. + (Unused on Windows, because we can't just copy a failed symlink file: we + switch to a different set of files instead.) + """ + assert os.name != 'nt' + force_copy = not self.symlinks + if not force_copy: + try: + if not os.path.islink(dst): # can't link to itself! 
if relative_symlinks_ok: assert os.path.dirname(src) == os.path.dirname(dst) os.symlink(os.path.basename(src), dst) else: os.symlink(src, dst) - return - except Exception: # may need to use a more specific exception - logger.warning('Unable to symlink %r to %r', src, dst) - - # On Windows, we rewrite symlinks to our base python.exe into - # copies of venvlauncher.exe - basename, ext = os.path.splitext(os.path.basename(src)) - srcfn = os.path.join(os.path.dirname(__file__), - "scripts", - "nt", - basename + ext) - # Builds or venv's from builds need to remap source file - # locations, as we do not put them into Lib/venv/scripts - if sysconfig.is_python_build(True) or not os.path.isfile(srcfn): - if basename.endswith('_d'): - ext = '_d' + ext - basename = basename[:-2] - if basename == 'python': - basename = 'venvlauncher' - elif basename == 'pythonw': - basename = 'venvwlauncher' - src = os.path.join(os.path.dirname(src), basename + ext) - else: - src = srcfn - if not os.path.exists(src): - if not bad_src: - logger.warning('Unable to copy %r', src) - return - + except Exception: # may need to use a more specific exception + logger.warning('Unable to symlink %r to %r', src, dst) + force_copy = True + if force_copy: shutil.copyfile(src, dst) - def setup_python(self, context): + def create_git_ignore_file(self, context): """ - Set up a Python executable in the environment. + Create a .gitignore file in the environment directory. - :param context: The information for the environment creation request - being processed. + The contents of the file cause the entire environment directory to be + ignored by git. 
""" - binpath = context.bin_path - path = context.env_exe - copier = self.symlink_or_copy - dirname = context.python_dir - if os.name != 'nt': + gitignore_path = os.path.join(context.env_dir, '.gitignore') + with open(gitignore_path, 'w', encoding='utf-8') as file: + file.write('# Created by venv; ' + 'see https://docs.python.org/3/library/venv.html\n') + file.write('*\n') + + if os.name != 'nt': + def setup_python(self, context): + """ + Set up a Python executable in the environment. + + :param context: The information for the environment creation request + being processed. + """ + binpath = context.bin_path + path = context.env_exe + copier = self.symlink_or_copy + dirname = context.python_dir copier(context.executable, path) if not os.path.islink(path): os.chmod(path, 0o755) - for suffix in ('python', 'python3', f'python3.{sys.version_info[1]}'): + for suffix in ('python', 'python3', + f'python3.{sys.version_info[1]}'): path = os.path.join(binpath, suffix) if not os.path.exists(path): # Issue 18807: make copies if @@ -265,32 +313,107 @@ def setup_python(self, context): copier(context.env_exe, path, relative_symlinks_ok=True) if not os.path.islink(path): os.chmod(path, 0o755) - else: - if self.symlinks: - # For symlinking, we need a complete copy of the root directory - # If symlinks fail, you'll get unnecessary copies of files, but - # we assume that if you've opted into symlinks on Windows then - # you know what you're doing. - suffixes = [ - f for f in os.listdir(dirname) if - os.path.normcase(os.path.splitext(f)[1]) in ('.exe', '.dll') - ] - if sysconfig.is_python_build(True): - suffixes = [ - f for f in suffixes if - os.path.normcase(f).startswith(('python', 'vcruntime')) - ] + + else: + def setup_python(self, context): + """ + Set up a Python executable in the environment. + + :param context: The information for the environment creation request + being processed. 
+ """ + binpath = context.bin_path + dirname = context.python_dir + exename = os.path.basename(context.env_exe) + exe_stem = os.path.splitext(exename)[0] + exe_d = '_d' if os.path.normcase(exe_stem).endswith('_d') else '' + if sysconfig.is_python_build(): + scripts = dirname else: - suffixes = {'python.exe', 'python_d.exe', 'pythonw.exe', 'pythonw_d.exe'} - base_exe = os.path.basename(context.env_exe) - suffixes.add(base_exe) + scripts = os.path.join(os.path.dirname(__file__), + 'scripts', 'nt') + if not sysconfig.get_config_var("Py_GIL_DISABLED"): + python_exe = os.path.join(dirname, f'python{exe_d}.exe') + pythonw_exe = os.path.join(dirname, f'pythonw{exe_d}.exe') + link_sources = { + 'python.exe': python_exe, + f'python{exe_d}.exe': python_exe, + 'pythonw.exe': pythonw_exe, + f'pythonw{exe_d}.exe': pythonw_exe, + } + python_exe = os.path.join(scripts, f'venvlauncher{exe_d}.exe') + pythonw_exe = os.path.join(scripts, f'venvwlauncher{exe_d}.exe') + copy_sources = { + 'python.exe': python_exe, + f'python{exe_d}.exe': python_exe, + 'pythonw.exe': pythonw_exe, + f'pythonw{exe_d}.exe': pythonw_exe, + } + else: + exe_t = f'3.{sys.version_info[1]}t' + python_exe = os.path.join(dirname, f'python{exe_t}{exe_d}.exe') + pythonw_exe = os.path.join(dirname, f'pythonw{exe_t}{exe_d}.exe') + link_sources = { + 'python.exe': python_exe, + f'python{exe_d}.exe': python_exe, + f'python{exe_t}.exe': python_exe, + f'python{exe_t}{exe_d}.exe': python_exe, + 'pythonw.exe': pythonw_exe, + f'pythonw{exe_d}.exe': pythonw_exe, + f'pythonw{exe_t}.exe': pythonw_exe, + f'pythonw{exe_t}{exe_d}.exe': pythonw_exe, + } + python_exe = os.path.join(scripts, f'venvlaunchert{exe_d}.exe') + pythonw_exe = os.path.join(scripts, f'venvwlaunchert{exe_d}.exe') + copy_sources = { + 'python.exe': python_exe, + f'python{exe_d}.exe': python_exe, + f'python{exe_t}.exe': python_exe, + f'python{exe_t}{exe_d}.exe': python_exe, + 'pythonw.exe': pythonw_exe, + f'pythonw{exe_d}.exe': pythonw_exe, + 
f'pythonw{exe_t}.exe': pythonw_exe, + f'pythonw{exe_t}{exe_d}.exe': pythonw_exe, + } + + do_copies = True + if self.symlinks: + do_copies = False + # For symlinking, we need all the DLLs to be available alongside + # the executables. + link_sources.update({ + f: os.path.join(dirname, f) for f in os.listdir(dirname) + if os.path.normcase(f).startswith(('python', 'vcruntime')) + and os.path.normcase(os.path.splitext(f)[1]) == '.dll' + }) + + to_unlink = [] + for dest, src in link_sources.items(): + dest = os.path.join(binpath, dest) + try: + os.symlink(src, dest) + to_unlink.append(dest) + except OSError: + logger.warning('Unable to symlink %r to %r', src, dest) + do_copies = True + for f in to_unlink: + try: + os.unlink(f) + except OSError: + logger.warning('Failed to clean up symlink %r', + f) + logger.warning('Retrying with copies') + break - for suffix in suffixes: - src = os.path.join(dirname, suffix) - if os.path.lexists(src): - copier(src, os.path.join(binpath, suffix)) + if do_copies: + for dest, src in copy_sources.items(): + dest = os.path.join(binpath, dest) + try: + shutil.copy2(src, dest) + except OSError: + logger.warning('Unable to copy %r to %r', src, dest) - if sysconfig.is_python_build(True): + if sysconfig.is_python_build(): # copy init.tcl for root, dirs, files in os.walk(context.python_dir): if 'init.tcl' in files: @@ -303,14 +426,25 @@ def setup_python(self, context): shutil.copyfile(src, dst) break + def _call_new_python(self, context, *py_args, **kwargs): + """Executes the newly created Python using safe-ish options""" + # gh-98251: We do not want to just use '-I' because that masks + # legitimate user preferences (such as not writing bytecode). All we + # really need is to ensure that the path variables do not overrule + # normal venv handling. 
+ args = [context.env_exec_cmd, *py_args] + kwargs['env'] = env = os.environ.copy() + env['VIRTUAL_ENV'] = context.env_dir + env.pop('PYTHONHOME', None) + env.pop('PYTHONPATH', None) + kwargs['cwd'] = context.env_dir + kwargs['executable'] = context.env_exec_cmd + subprocess.check_output(args, **kwargs) + def _setup_pip(self, context): """Installs or upgrades pip in a virtual environment""" - # We run ensurepip in isolated mode to avoid side effects from - # environment vars, the current directory and anything else - # intended for the global Python environment - cmd = [context.env_exec_cmd, '-Im', 'ensurepip', '--upgrade', - '--default-pip'] - subprocess.check_output(cmd, stderr=subprocess.STDOUT) + self._call_new_python(context, '-m', 'ensurepip', '--upgrade', + '--default-pip', stderr=subprocess.STDOUT) def setup_scripts(self, context): """ @@ -348,11 +482,41 @@ def replace_variables(self, text, context): :param context: The information for the environment creation request being processed. """ - text = text.replace('__VENV_DIR__', context.env_dir) - text = text.replace('__VENV_NAME__', context.env_name) - text = text.replace('__VENV_PROMPT__', context.prompt) - text = text.replace('__VENV_BIN_NAME__', context.bin_name) - text = text.replace('__VENV_PYTHON__', context.env_exe) + replacements = { + '__VENV_DIR__': context.env_dir, + '__VENV_NAME__': context.env_name, + '__VENV_PROMPT__': context.prompt, + '__VENV_BIN_NAME__': context.bin_name, + '__VENV_PYTHON__': context.env_exe, + } + + def quote_ps1(s): + """ + This should satisfy PowerShell quoting rules [1], unless the quoted + string is passed directly to Windows native commands [2]. 
+ [1]: https://learn.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_quoting_rules + [2]: https://learn.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_parsing#passing-arguments-that-contain-quote-characters + """ + s = s.replace("'", "''") + return f"'{s}'" + + def quote_bat(s): + return s + + # gh-124651: need to quote the template strings properly + quote = shlex.quote + script_path = context.script_path + if script_path.endswith('.ps1'): + quote = quote_ps1 + elif script_path.endswith('.bat'): + quote = quote_bat + else: + # fallbacks to POSIX shell compliant quote + quote = shlex.quote + + replacements = {key: quote(s) for key, s in replacements.items()} + for key, quoted in replacements.items(): + text = text.replace(key, quoted) return text def install_scripts(self, context, path): @@ -370,15 +534,22 @@ def install_scripts(self, context, path): """ binpath = context.bin_path plen = len(path) + if os.name == 'nt': + def skip_file(f): + f = os.path.normcase(f) + return (f.startswith(('python', 'venv')) + and f.endswith(('.exe', '.pdb'))) + else: + def skip_file(f): + return False for root, dirs, files in os.walk(path): - if root == path: # at top-level, remove irrelevant dirs + if root == path: # at top-level, remove irrelevant dirs for d in dirs[:]: if d not in ('common', os.name): dirs.remove(d) - continue # ignore files in top level + continue # ignore files in top level for f in files: - if (os.name == 'nt' and f.startswith('python') - and f.endswith(('.exe', '.pdb'))): + if skip_file(f): continue srcfile = os.path.join(root, f) suffix = root[plen:].split(os.sep)[2:] @@ -389,116 +560,122 @@ def install_scripts(self, context, path): if not os.path.exists(dstdir): os.makedirs(dstdir) dstfile = os.path.join(dstdir, f) + if os.name == 'nt' and srcfile.endswith(('.exe', '.pdb')): + shutil.copy2(srcfile, dstfile) + continue with open(srcfile, 'rb') as f: data = f.read() - if not srcfile.endswith(('.exe', 
'.pdb')): - try: - data = data.decode('utf-8') - data = self.replace_variables(data, context) - data = data.encode('utf-8') - except UnicodeError as e: - data = None - logger.warning('unable to copy script %r, ' - 'may be binary: %s', srcfile, e) - if data is not None: + try: + context.script_path = srcfile + new_data = ( + self.replace_variables(data.decode('utf-8'), context) + .encode('utf-8') + ) + except UnicodeError as e: + logger.warning('unable to copy script %r, ' + 'may be binary: %s', srcfile, e) + continue + if new_data == data: + shutil.copy2(srcfile, dstfile) + else: with open(dstfile, 'wb') as f: - f.write(data) + f.write(new_data) shutil.copymode(srcfile, dstfile) def upgrade_dependencies(self, context): logger.debug( f'Upgrading {CORE_VENV_DEPS} packages in {context.bin_path}' ) - cmd = [context.env_exec_cmd, '-m', 'pip', 'install', '--upgrade'] - cmd.extend(CORE_VENV_DEPS) - subprocess.check_call(cmd) + self._call_new_python(context, '-m', 'pip', 'install', '--upgrade', + *CORE_VENV_DEPS) def create(env_dir, system_site_packages=False, clear=False, - symlinks=False, with_pip=False, prompt=None, upgrade_deps=False): + symlinks=False, with_pip=False, prompt=None, upgrade_deps=False, + *, scm_ignore_files=frozenset()): """Create a virtual environment in a directory.""" builder = EnvBuilder(system_site_packages=system_site_packages, clear=clear, symlinks=symlinks, with_pip=with_pip, - prompt=prompt, upgrade_deps=upgrade_deps) + prompt=prompt, upgrade_deps=upgrade_deps, + scm_ignore_files=scm_ignore_files) builder.create(env_dir) + def main(args=None): - compatible = True - if sys.version_info < (3, 3): - compatible = False - elif not hasattr(sys, 'base_prefix'): - compatible = False - if not compatible: - raise ValueError('This script is only for use with Python >= 3.3') + import argparse + + parser = argparse.ArgumentParser(prog=__name__, + description='Creates virtual Python ' + 'environments in one or ' + 'more target ' + 'directories.', + 
epilog='Once an environment has been ' + 'created, you may wish to ' + 'activate it, e.g. by ' + 'sourcing an activate script ' + 'in its bin directory.') + parser.add_argument('dirs', metavar='ENV_DIR', nargs='+', + help='A directory to create the environment in.') + parser.add_argument('--system-site-packages', default=False, + action='store_true', dest='system_site', + help='Give the virtual environment access to the ' + 'system site-packages dir.') + if os.name == 'nt': + use_symlinks = False else: - import argparse - - parser = argparse.ArgumentParser(prog=__name__, - description='Creates virtual Python ' - 'environments in one or ' - 'more target ' - 'directories.', - epilog='Once an environment has been ' - 'created, you may wish to ' - 'activate it, e.g. by ' - 'sourcing an activate script ' - 'in its bin directory.') - parser.add_argument('dirs', metavar='ENV_DIR', nargs='+', - help='A directory to create the environment in.') - parser.add_argument('--system-site-packages', default=False, - action='store_true', dest='system_site', - help='Give the virtual environment access to the ' - 'system site-packages dir.') - if os.name == 'nt': - use_symlinks = False - else: - use_symlinks = True - group = parser.add_mutually_exclusive_group() - group.add_argument('--symlinks', default=use_symlinks, - action='store_true', dest='symlinks', - help='Try to use symlinks rather than copies, ' - 'when symlinks are not the default for ' - 'the platform.') - group.add_argument('--copies', default=not use_symlinks, - action='store_false', dest='symlinks', - help='Try to use copies rather than symlinks, ' - 'even when symlinks are the default for ' - 'the platform.') - parser.add_argument('--clear', default=False, action='store_true', - dest='clear', help='Delete the contents of the ' - 'environment directory if it ' - 'already exists, before ' - 'environment creation.') - parser.add_argument('--upgrade', default=False, action='store_true', - dest='upgrade', help='Upgrade the 
environment ' - 'directory to use this version ' - 'of Python, assuming Python ' - 'has been upgraded in-place.') - parser.add_argument('--without-pip', dest='with_pip', - default=True, action='store_false', - help='Skips installing or upgrading pip in the ' - 'virtual environment (pip is bootstrapped ' - 'by default)') - parser.add_argument('--prompt', - help='Provides an alternative prompt prefix for ' - 'this environment.') - parser.add_argument('--upgrade-deps', default=False, action='store_true', - dest='upgrade_deps', - help='Upgrade core dependencies: {} to the latest ' - 'version in PyPI'.format( - ' '.join(CORE_VENV_DEPS))) - options = parser.parse_args(args) - if options.upgrade and options.clear: - raise ValueError('you cannot supply --upgrade and --clear together.') - builder = EnvBuilder(system_site_packages=options.system_site, - clear=options.clear, - symlinks=options.symlinks, - upgrade=options.upgrade, - with_pip=options.with_pip, - prompt=options.prompt, - upgrade_deps=options.upgrade_deps) - for d in options.dirs: - builder.create(d) + use_symlinks = True + group = parser.add_mutually_exclusive_group() + group.add_argument('--symlinks', default=use_symlinks, + action='store_true', dest='symlinks', + help='Try to use symlinks rather than copies, ' + 'when symlinks are not the default for ' + 'the platform.') + group.add_argument('--copies', default=not use_symlinks, + action='store_false', dest='symlinks', + help='Try to use copies rather than symlinks, ' + 'even when symlinks are the default for ' + 'the platform.') + parser.add_argument('--clear', default=False, action='store_true', + dest='clear', help='Delete the contents of the ' + 'environment directory if it ' + 'already exists, before ' + 'environment creation.') + parser.add_argument('--upgrade', default=False, action='store_true', + dest='upgrade', help='Upgrade the environment ' + 'directory to use this version ' + 'of Python, assuming Python ' + 'has been upgraded in-place.') + 
parser.add_argument('--without-pip', dest='with_pip', + default=True, action='store_false', + help='Skips installing or upgrading pip in the ' + 'virtual environment (pip is bootstrapped ' + 'by default)') + parser.add_argument('--prompt', + help='Provides an alternative prompt prefix for ' + 'this environment.') + parser.add_argument('--upgrade-deps', default=False, action='store_true', + dest='upgrade_deps', + help=f'Upgrade core dependencies ({", ".join(CORE_VENV_DEPS)}) ' + 'to the latest version in PyPI') + parser.add_argument('--without-scm-ignore-files', dest='scm_ignore_files', + action='store_const', const=frozenset(), + default=frozenset(['git']), + help='Skips adding SCM ignore files to the environment ' + 'directory (Git is supported by default).') + options = parser.parse_args(args) + if options.upgrade and options.clear: + raise ValueError('you cannot supply --upgrade and --clear together.') + builder = EnvBuilder(system_site_packages=options.system_site, + clear=options.clear, + symlinks=options.symlinks, + upgrade=options.upgrade, + with_pip=options.with_pip, + prompt=options.prompt, + upgrade_deps=options.upgrade_deps, + scm_ignore_files=options.scm_ignore_files) + for d in options.dirs: + builder.create(d) + if __name__ == '__main__': rc = 1 diff --git a/Lib/venv/__main__.py b/Lib/venv/__main__.py index 912423e4a78..88f55439dc2 100644 --- a/Lib/venv/__main__.py +++ b/Lib/venv/__main__.py @@ -6,5 +6,5 @@ main() rc = 0 except Exception as e: - print('Error: %s' % e, file=sys.stderr) + print('Error:', e, file=sys.stderr) sys.exit(rc) diff --git a/Lib/venv/scripts/common/Activate.ps1 b/Lib/venv/scripts/common/Activate.ps1 index b49d77ba44b..16ba5290fae 100644 --- a/Lib/venv/scripts/common/Activate.ps1 +++ b/Lib/venv/scripts/common/Activate.ps1 @@ -219,6 +219,8 @@ deactivate -nondestructive # that there is an activated venv. 
$env:VIRTUAL_ENV = $VenvDir +$env:VIRTUAL_ENV_PROMPT = $Prompt + if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { Write-Verbose "Setting prompt to '$Prompt'" @@ -233,7 +235,6 @@ if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) " _OLD_VIRTUAL_PROMPT } - $env:VIRTUAL_ENV_PROMPT = $Prompt } # Clear PYTHONHOME diff --git a/Lib/venv/scripts/common/activate b/Lib/venv/scripts/common/activate index 6fbc2b8801d..70673a265d4 100644 --- a/Lib/venv/scripts/common/activate +++ b/Lib/venv/scripts/common/activate @@ -1,5 +1,5 @@ # This file must be used with "source bin/activate" *from bash* -# you cannot run it directly +# You cannot run it directly deactivate () { # reset old environment variables @@ -14,12 +14,10 @@ deactivate () { unset _OLD_VIRTUAL_PYTHONHOME fi - # This should detect bash and zsh, which have a hash command that must - # be called to get it to forget past commands. Without forgetting - # past commands the $PATH changes we made may not be respected - if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then - hash -r 2> /dev/null - fi + # Call hash to forget past locations. Without forgetting + # past locations the $PATH changes we made may not be respected. + # See "man bash" for more details. 
hash is usually a builtin of your shell + hash -r 2> /dev/null if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then PS1="${_OLD_VIRTUAL_PS1:-}" @@ -38,13 +36,27 @@ deactivate () { # unset irrelevant variables deactivate nondestructive -VIRTUAL_ENV="__VENV_DIR__" -export VIRTUAL_ENV +# on Windows, a path can contain colons and backslashes and has to be converted: +case "$(uname)" in + CYGWIN*|MSYS*|MINGW*) + # transform D:\path\to\venv to /d/path/to/venv on MSYS and MINGW + # and to /cygdrive/d/path/to/venv on Cygwin + VIRTUAL_ENV=$(cygpath __VENV_DIR__) + export VIRTUAL_ENV + ;; + *) + # use the path as-is + export VIRTUAL_ENV=__VENV_DIR__ + ;; +esac _OLD_VIRTUAL_PATH="$PATH" -PATH="$VIRTUAL_ENV/__VENV_BIN_NAME__:$PATH" +PATH="$VIRTUAL_ENV/"__VENV_BIN_NAME__":$PATH" export PATH +VIRTUAL_ENV_PROMPT=__VENV_PROMPT__ +export VIRTUAL_ENV_PROMPT + # unset PYTHONHOME if set # this will fail if PYTHONHOME is set to the empty string (which is bad anyway) # could use `if (set -u; : $PYTHONHOME) ;` in bash @@ -55,15 +67,10 @@ fi if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then _OLD_VIRTUAL_PS1="${PS1:-}" - PS1="__VENV_PROMPT__${PS1:-}" + PS1="("__VENV_PROMPT__") ${PS1:-}" export PS1 - VIRTUAL_ENV_PROMPT="__VENV_PROMPT__" - export VIRTUAL_ENV_PROMPT fi -# This should detect bash and zsh, which have a hash command that must -# be called to get it to forget past commands. Without forgetting +# Call hash to forget past commands. 
Without forgetting # past commands the $PATH changes we made may not be respected -if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then - hash -r 2> /dev/null -fi +hash -r 2> /dev/null diff --git a/Lib/venv/scripts/posix/activate.fish b/Lib/venv/scripts/common/activate.fish similarity index 76% rename from Lib/venv/scripts/posix/activate.fish rename to Lib/venv/scripts/common/activate.fish index e40a1d71489..284a7469c99 100644 --- a/Lib/venv/scripts/posix/activate.fish +++ b/Lib/venv/scripts/common/activate.fish @@ -1,5 +1,5 @@ # This file must be used with "source /bin/activate.fish" *from fish* -# (https://fishshell.com/); you cannot run it directly. +# (https://fishshell.com/). You cannot run it directly. function deactivate -d "Exit virtual environment and return to normal shell environment" # reset old environment variables @@ -13,10 +13,13 @@ function deactivate -d "Exit virtual environment and return to normal shell env end if test -n "$_OLD_FISH_PROMPT_OVERRIDE" - functions -e fish_prompt set -e _OLD_FISH_PROMPT_OVERRIDE - functions -c _old_fish_prompt fish_prompt - functions -e _old_fish_prompt + # prevents error when using nested fish instances (Issue #93858) + if functions -q _old_fish_prompt + functions -e fish_prompt + functions -c _old_fish_prompt fish_prompt + functions -e _old_fish_prompt + end end set -e VIRTUAL_ENV @@ -30,10 +33,11 @@ end # Unset irrelevant variables. deactivate nondestructive -set -gx VIRTUAL_ENV "__VENV_DIR__" +set -gx VIRTUAL_ENV __VENV_DIR__ set -gx _OLD_VIRTUAL_PATH $PATH -set -gx PATH "$VIRTUAL_ENV/__VENV_BIN_NAME__" $PATH +set -gx PATH "$VIRTUAL_ENV/"__VENV_BIN_NAME__ $PATH +set -gx VIRTUAL_ENV_PROMPT __VENV_PROMPT__ # Unset PYTHONHOME if set. if set -q PYTHONHOME @@ -53,7 +57,7 @@ if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" set -l old_status $status # Output the venv prompt; color taken from the blue of the Python logo. 
- printf "%s%s%s" (set_color 4B8BBE) "__VENV_PROMPT__" (set_color normal) + printf "%s(%s)%s " (set_color 4B8BBE) __VENV_PROMPT__ (set_color normal) # Restore the return status of the previous command. echo "exit $old_status" | . @@ -62,5 +66,4 @@ if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" end set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" - set -gx VIRTUAL_ENV_PROMPT "__VENV_PROMPT__" end diff --git a/Lib/venv/scripts/nt/activate.bat b/Lib/venv/scripts/nt/activate.bat index 5daa45afc9f..9ac5c20b477 100644 --- a/Lib/venv/scripts/nt/activate.bat +++ b/Lib/venv/scripts/nt/activate.bat @@ -8,15 +8,15 @@ if defined _OLD_CODEPAGE ( "%SystemRoot%\System32\chcp.com" 65001 > nul ) -set VIRTUAL_ENV=__VENV_DIR__ +set "VIRTUAL_ENV=__VENV_DIR__" if not defined PROMPT set PROMPT=$P$G if defined _OLD_VIRTUAL_PROMPT set PROMPT=%_OLD_VIRTUAL_PROMPT% if defined _OLD_VIRTUAL_PYTHONHOME set PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME% -set _OLD_VIRTUAL_PROMPT=%PROMPT% -set PROMPT=__VENV_PROMPT__%PROMPT% +set "_OLD_VIRTUAL_PROMPT=%PROMPT%" +set "PROMPT=(__VENV_PROMPT__) %PROMPT%" if defined PYTHONHOME set _OLD_VIRTUAL_PYTHONHOME=%PYTHONHOME% set PYTHONHOME= @@ -24,8 +24,8 @@ set PYTHONHOME= if defined _OLD_VIRTUAL_PATH set PATH=%_OLD_VIRTUAL_PATH% if not defined _OLD_VIRTUAL_PATH set _OLD_VIRTUAL_PATH=%PATH% -set PATH=%VIRTUAL_ENV%\__VENV_BIN_NAME__;%PATH% -set VIRTUAL_ENV_PROMPT=__VENV_PROMPT__ +set "PATH=%VIRTUAL_ENV%\__VENV_BIN_NAME__;%PATH%" +set "VIRTUAL_ENV_PROMPT=__VENV_PROMPT__" :END if defined _OLD_CODEPAGE ( diff --git a/Lib/venv/scripts/posix/activate.csh b/Lib/venv/scripts/posix/activate.csh index d6f697c55ed..2a3fa835476 100644 --- a/Lib/venv/scripts/posix/activate.csh +++ b/Lib/venv/scripts/posix/activate.csh @@ -1,5 +1,6 @@ # This file must be used with "source bin/activate.csh" *from csh*. # You cannot run it directly. + # Created by Davide Di Blasi . 
# Ported to Python 3.3 venv by Andrew Svetlov @@ -8,17 +9,17 @@ alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PA # Unset irrelevant variables. deactivate nondestructive -setenv VIRTUAL_ENV "__VENV_DIR__" +setenv VIRTUAL_ENV __VENV_DIR__ set _OLD_VIRTUAL_PATH="$PATH" -setenv PATH "$VIRTUAL_ENV/__VENV_BIN_NAME__:$PATH" +setenv PATH "$VIRTUAL_ENV/"__VENV_BIN_NAME__":$PATH" +setenv VIRTUAL_ENV_PROMPT __VENV_PROMPT__ set _OLD_VIRTUAL_PROMPT="$prompt" if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then - set prompt = "__VENV_PROMPT__$prompt" - setenv VIRTUAL_ENV_PROMPT "__VENV_PROMPT__" + set prompt = "("__VENV_PROMPT__") $prompt:q" endif alias pydoc python -m pydoc diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 52ecf927d5c..f9bd2b59456 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -164,11 +164,70 @@ mod sys { #[pyattr] fn _base_executable(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; + // First check __PYVENV_LAUNCHER__ environment variable if let Ok(var) = env::var("__PYVENV_LAUNCHER__") { - ctx.new_str(var).into() - } else { - executable(vm) + return ctx.new_str(var).into(); + } + + // Try to detect if we're running from a venv by looking for pyvenv.cfg + if let Some(base_exe) = get_venv_base_executable() { + return ctx.new_str(base_exe).into(); } + + executable(vm) + } + + /// Try to find base executable from pyvenv.cfg (see getpath.py) + fn get_venv_base_executable() -> Option { + // TODO: This is a minimal implementation of getpath.py + // To fully support all cases, `getpath.py` should be placed in @crates/vm/Lib/python_builtins/ + + // Get current executable path + #[cfg(not(target_arch = "wasm32"))] + let exe_path = { + let exec_arg = env::args_os().next()?; + which::which(exec_arg).ok()? 
+ }; + #[cfg(target_arch = "wasm32")] + let exe_path = { + let exec_arg = env::args().next()?; + path::PathBuf::from(exec_arg) + }; + + let exe_dir = exe_path.parent()?; + let exe_name = exe_path.file_name()?; + + // Look for pyvenv.cfg in parent directory (typical venv layout: venv/bin/python) + let venv_dir = exe_dir.parent()?; + let pyvenv_cfg = venv_dir.join("pyvenv.cfg"); + + if !pyvenv_cfg.exists() { + return None; + } + + // Parse pyvenv.cfg and extract home directory + let content = std::fs::read_to_string(&pyvenv_cfg).ok()?; + + for line in content.lines() { + if let Some((key, value)) = line.split_once('=') { + let key = key.trim().to_lowercase(); + let value = value.trim(); + + if key == "home" { + // First try to resolve symlinks (getpath.py line 373-377) + if let Ok(resolved) = std::fs::canonicalize(&exe_path) + && resolved != exe_path + { + return Some(resolved.to_string_lossy().into_owned()); + } + // Fallback: home_dir + executable_name (getpath.py line 381) + let base_exe = path::Path::new(value).join(exe_name); + return Some(base_exe.to_string_lossy().into_owned()); + } + } + } + + None } #[pyattr] From 7bfa5d9ceda6c5a3fb402dcf89d37b2c3e47f9f4 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sat, 20 Dec 2025 18:31:00 +0900 Subject: [PATCH 016/418] Buffer LongDouble + ndim (#6460) * fix fix_test * ndim buffer --- crates/vm/src/buffer.rs | 20 +++++++++++++++++--- crates/vm/src/protocol/buffer.rs | 11 +++++++++-- scripts/fix_test.py | 6 +++++- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/crates/vm/src/buffer.rs b/crates/vm/src/buffer.rs index eeb6a676542..5c67f87521d 100644 --- a/crates/vm/src/buffer.rs +++ b/crates/vm/src/buffer.rs @@ -88,7 +88,9 @@ pub(crate) enum FormatType { Half = b'e', Float = b'f', Double = b'd', + LongDouble = b'g', VoidP = b'P', + PyObject = b'O', } impl fmt::Debug for FormatType { @@ -148,7 +150,9 @@ impl FormatType { Half => nonnative_info!(f16, $end), Float 
=> nonnative_info!(f32, $end), Double => nonnative_info!(f64, $end), - _ => unreachable!(), // size_t or void* + LongDouble => nonnative_info!(f64, $end), // long double same as double + PyObject => nonnative_info!(usize, $end), // pointer size + _ => unreachable!(), // size_t or void* } }}; } @@ -183,7 +187,9 @@ impl FormatType { Half => native_info!(f16), Float => native_info!(raw::c_float), Double => native_info!(raw::c_double), + LongDouble => native_info!(raw::c_double), // long double same as double for now VoidP => native_info!(*mut raw::c_void), + PyObject => native_info!(*mut raw::c_void), // pointer to PyObject }, Endianness::Big => match_nonnative!(self, BigEndian), Endianness::Little => match_nonnative!(self, LittleEndian), @@ -306,8 +312,16 @@ impl FormatCode { continue; } - if c == b'{' || c == b'}' { - // Skip standalone braces (pointer targets, etc.) + if c == b'{' + || c == b'}' + || c == b'&' + || c == b'<' + || c == b'>' + || c == b'@' + || c == b'=' + || c == b'!' + { + // Skip standalone braces (pointer targets, etc.), pointer prefix, and nested endianness markers continue; } diff --git a/crates/vm/src/protocol/buffer.rs b/crates/vm/src/protocol/buffer.rs index 948ec763dc6..0a34af59080 100644 --- a/crates/vm/src/protocol/buffer.rs +++ b/crates/vm/src/protocol/buffer.rs @@ -201,16 +201,23 @@ impl BufferDescriptor { #[cfg(debug_assertions)] pub fn validate(self) -> Self { - assert!(self.itemsize != 0); // ndim=0 is valid for scalar types (e.g., ctypes Structure) if self.ndim() == 0 { + // Empty structures (len=0) can have itemsize=0 + if self.len > 0 { + assert!(self.itemsize != 0); + } assert!(self.itemsize == self.len); } else { let mut shape_product = 1; + let has_zero_dim = self.dim_desc.iter().any(|(s, _, _)| *s == 0); for (shape, stride, suboffset) in self.dim_desc.iter().cloned() { shape_product *= shape; assert!(suboffset >= 0); - assert!(stride != 0); + // For empty arrays (any dimension is 0), strides can be 0 + if !has_zero_dim { + 
assert!(stride != 0); + } } assert!(shape_product * self.itemsize == self.len); } diff --git a/scripts/fix_test.py b/scripts/fix_test.py index 9716bd0b008..53b10d63834 100644 --- a/scripts/fix_test.py +++ b/scripts/fix_test.py @@ -159,7 +159,11 @@ def run_test(test_name): if not test_path.exists(): print(f"Error: File not found: {test_path}") sys.exit(1) - test_name = test_path.stem + # Detect package tests (e.g., test_ctypes/test_random_things.py) + if test_path.parent.name.startswith("test_"): + test_name = f"{test_path.parent.name}.{test_path.stem}" + else: + test_name = test_path.stem tests = run_test(test_name) f = test_path.read_text(encoding="utf-8") From 6660170bf838b239b74a5c0fe9dab7615f3c7e41 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sat, 20 Dec 2025 22:06:22 +0900 Subject: [PATCH 017/418] Share more ssl consts and fix openssl (#6462) * Reuse SSL error names * Share ssl error codes * fix ssl * fix openssl --- crates/stdlib/src/openssl.rs | 429 +++++++++++++++++++++++--------- crates/stdlib/src/ssl.rs | 22 -- crates/stdlib/src/ssl/compat.rs | 15 +- crates/stdlib/src/ssl/error.rs | 24 +- 4 files changed, 347 insertions(+), 143 deletions(-) diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index e07ad552f17..b6d5f5c2035 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -53,11 +53,10 @@ cfg_if::cfg_if! 
{ mod _ssl { use super::{bio, probe}; - // Import error types used in this module (others are exposed via pymodule(with(...))) + // Import error types and helpers used in this module (others are exposed via pymodule(with(...))) use super::ssl_error::{ - PySSLCertVerificationError as PySslCertVerificationError, PySSLEOFError as PySslEOFError, - PySSLError as PySslError, PySSLWantReadError as PySslWantReadError, - PySSLWantWriteError as PySslWantWriteError, + PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_want_read_error, + create_ssl_want_write_error, }; use crate::{ common::lock::{ @@ -249,8 +248,6 @@ mod _ssl { #[pyattr] const VERIFY_DEFAULT: u32 = 0; #[pyattr] - const SSL_ERROR_EOF: u32 = 8; // custom for python - #[pyattr] const HAS_SNI: bool = true; #[pyattr] const HAS_ECDH: bool = true; @@ -709,21 +706,124 @@ mod _ssl { } } + // OpenSSL record type constants for msg_callback + const SSL3_RT_CHANGE_CIPHER_SPEC: i32 = 20; + const SSL3_RT_ALERT: i32 = 21; + const SSL3_RT_HANDSHAKE: i32 = 22; + const SSL3_RT_HEADER: i32 = 256; + const SSL3_RT_INNER_CONTENT_TYPE: i32 = 257; + // Special value for change cipher spec (CPython compatibility) + const SSL3_MT_CHANGE_CIPHER_SPEC: i32 = 0x0101; + // Message callback function called by OpenSSL - // NOTE: This callback is intentionally a no-op to avoid deadlocks. - // The msg_callback can be called during various SSL operations (read, write, handshake), - // and invoking Python code from within these operations can cause deadlocks - // (see CPython bpo-43577). A proper implementation would require careful lock ordering. + // Called during SSL operations to report protocol messages. 
+ // debughelpers.c:_PySSL_msg_callback unsafe extern "C" fn _msg_callback( - _write_p: libc::c_int, - _version: libc::c_int, - _content_type: libc::c_int, - _buf: *const libc::c_void, - _len: usize, - _ssl_ptr: *mut sys::SSL, + write_p: libc::c_int, + mut version: libc::c_int, + content_type: libc::c_int, + buf: *const libc::c_void, + len: usize, + ssl_ptr: *mut sys::SSL, _arg: *mut libc::c_void, ) { - // Intentionally empty to avoid deadlocks + if ssl_ptr.is_null() { + return; + } + + unsafe { + // Get SSL socket from SSL_get_ex_data (index 0) + let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); + if ssl_socket_ptr.is_null() { + return; + } + + // ssl_socket_ptr is a pointer to Py, set in _wrap_socket/_wrap_bio + let ssl_socket: &Py = &*(ssl_socket_ptr as *const Py); + + // Get the callback from the context + let callback_opt = ssl_socket.ctx.read().msg_callback.lock().clone(); + let Some(callback) = callback_opt else { + return; + }; + + // Get VM from thread-local storage (set by HandshakeVmGuard in do_handshake) + let Some(vm_ptr) = HANDSHAKE_VM.with(|cell| cell.get()) else { + // VM not available - this shouldn't happen during handshake + return; + }; + let vm = &*vm_ptr; + + // Get SSL socket owner object + let ssl_socket_obj = ssl_socket + .owner + .read() + .as_ref() + .and_then(|weak| weak.upgrade()) + .unwrap_or_else(|| vm.ctx.none()); + + // Create the message bytes + let buf_slice = std::slice::from_raw_parts(buf as *const u8, len); + let msg_bytes = vm.ctx.new_bytes(buf_slice.to_vec()); + + // Determine direction string + let direction_str = if write_p != 0 { "write" } else { "read" }; + + // Calculate msg_type based on content_type (debughelpers.c behavior) + let msg_type = match content_type { + SSL3_RT_CHANGE_CIPHER_SPEC => SSL3_MT_CHANGE_CIPHER_SPEC, + SSL3_RT_ALERT => { + // byte 1 is alert type + if len >= 2 { buf_slice[1] as i32 } else { -1 } + } + SSL3_RT_HANDSHAKE => { + // byte 0 is handshake type + if !buf_slice.is_empty() { + 
buf_slice[0] as i32 + } else { + -1 + } + } + SSL3_RT_HEADER => { + // Frame header: version in bytes 1..2, type in byte 0 + if len >= 3 { + version = ((buf_slice[1] as i32) << 8) | (buf_slice[2] as i32); + buf_slice[0] as i32 + } else { + -1 + } + } + SSL3_RT_INNER_CONTENT_TYPE => { + // Inner content type in byte 0 + if !buf_slice.is_empty() { + buf_slice[0] as i32 + } else { + -1 + } + } + _ => -1, + }; + + // Call the Python callback + // Signature: callback(conn, direction, version, content_type, msg_type, data) + match callback.call( + ( + ssl_socket_obj, + vm.ctx.new_str(direction_str), + vm.ctx.new_int(version), + vm.ctx.new_int(content_type), + vm.ctx.new_int(msg_type), + msg_bytes, + ), + vm, + ) { + Ok(_) => {} + Err(exc) => { + // Log the exception but don't propagate it + vm.run_unraisable(exc, None, vm.ctx.none()); + } + } + } } #[pyfunction(name = "RAND_pseudo_bytes")] @@ -794,6 +894,7 @@ mod _ssl { SslVerifyMode::NONE }); + // Start with OP_ALL but remove options that CPython doesn't include by default let mut options = SslOptions::ALL & !SslOptions::DONT_INSERT_EMPTY_FRAGMENTS; if proto != SslVersion::Ssl2 { options |= SslOptions::NO_SSLV2; @@ -807,6 +908,8 @@ mod _ssl { options |= SslOptions::SINGLE_ECDH_USE; options |= SslOptions::ENABLE_MIDDLEBOX_COMPAT; builder.set_options(options); + // Remove NO_TLSv1 and NO_TLSv1_1 which newer OpenSSL adds to OP_ALL + builder.clear_options(SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1); let mode = ssl::SslMode::ACCEPT_MOVING_WRITE_BUFFER | ssl::SslMode::AUTO_RETRY; builder.set_mode(mode); @@ -1242,7 +1345,7 @@ mod _ssl { if let Some(cadata) = args.cadata { let certs = match cadata { Either::A(s) => { - if !s.is_ascii() { + if !s.as_str().is_ascii() { return Err(invalid_cadata(vm)); } X509::stack_from_pem(s.as_bytes()) @@ -1261,6 +1364,29 @@ mod _ssl { if args.cafile.is_some() || args.capath.is_some() { let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?; let capath_path = 
args.capath.map(|p| p.to_path_buf(vm)).transpose()?; + // Check file/directory existence before calling OpenSSL to get proper errno + if let Some(ref path) = cafile_path { + if !path.exists() { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", path.display()), + ) + .upcast()); + } + } + if let Some(ref path) = capath_path { + if !path.exists() { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", path.display()), + ) + .upcast()); + } + } ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref()) .map_err(|e| convert_openssl_error(vm, e))?; } @@ -1387,7 +1513,7 @@ mod _ssl { std::io::ErrorKind::NotFound => vm .new_os_subtype_error( vm.ctx.exceptions.file_not_found_error.to_owned(), - None, + Some(libc::ENOENT), e.to_string(), ) .upcast(), @@ -1559,6 +1685,27 @@ mod _ssl { let mut ctx = self.builder(); let key_path = keyfile.map(|path| path.to_path_buf(vm)).transpose()?; let cert_path = certfile.to_path_buf(vm)?; + // Check file existence before calling OpenSSL to get proper errno + if !cert_path.exists() { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", cert_path.display()), + ) + .upcast()); + } + if let Some(ref kp) = key_path { + if !kp.exists() { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", kp.display()), + ) + .upcast()); + } + } ctx.set_certificate_chain_file(&cert_path) .and_then(|()| { ctx.set_private_key_file( @@ -1623,7 +1770,7 @@ mod _ssl { )); } if hostname_str.contains('\0') { - return Err(vm.new_value_error("embedded null byte in server_hostname")); + return Err(vm.new_type_error("embedded null 
character")); } let ip = hostname_str.parse::(); if ip.is_err() { @@ -1704,12 +1851,19 @@ mod _ssl { // Check if SNI callback is configured (minimize lock time) let has_sni_callback = zelf.sni_callback.lock().is_some(); - // Set SNI callback data if needed (after releasing the lock) - if has_sni_callback { - let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; - unsafe { - let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); + // Set up ex_data for callbacks + unsafe { + let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); + + // Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data) + // This is safe because ssl_socket owns the SSL object and outlives it + // We store a pointer to Py, which msg_callback can dereference + let py_ptr: *const Py = &*py_ref; + sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _); + // Set SNI callback data if needed + if has_sni_callback { + let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; // Store callback data in SSL ex_data - use weak reference to avoid cycle let callback_data = Box::new(SniCallbackData { ssl_context: zelf.clone(), @@ -1765,12 +1919,19 @@ mod _ssl { // Check if SNI callback is configured (minimize lock time) let has_sni_callback = zelf.sni_callback.lock().is_some(); - // Set SNI callback data if needed (after releasing the lock) - if has_sni_callback { - let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; - unsafe { - let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); + // Set up ex_data for callbacks + unsafe { + let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); + // Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data) + // This is safe because ssl_socket owns the SSL object and outlives it + // We store a pointer to Py, which msg_callback can dereference + let py_ptr: *const Py = &*py_ref; + sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _); + + // Set SNI callback data if needed + if has_sni_callback 
{ + let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; // Store callback data in SSL ex_data - use weak reference to avoid cycle let callback_data = Box::new(SniCallbackData { ssl_context: zelf.clone(), @@ -1866,12 +2027,14 @@ mod _ssl { Some(s) => s, None => return SelectRet::Closed, }; - let deadline = match &deadline { + // For blocking sockets without timeout, call sock_select with None timeout + // to actually block waiting for data instead of busy-looping + let timeout = match &deadline { Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { - Some(deadline) => deadline, + Some(d) => Some(d), None => return SelectRet::TimedOut, }, - Err(true) => return SelectRet::IsBlocking, + Err(true) => None, // Blocking: no timeout, wait indefinitely Err(false) => return SelectRet::Nonblocking, }; let res = socket::sock_select( @@ -1880,7 +2043,7 @@ mod _ssl { SslNeeds::Read => socket::SelectKind::Read, SslNeeds::Write => socket::SelectKind::Write, }, - Some(deadline), + timeout, ); match res { Ok(true) => SelectRet::TimedOut, @@ -2017,6 +2180,14 @@ mod _ssl { SslConnection::Bio(stream) => stream.get_shutdown(), } } + + // Check if incoming BIO has EOF (for BIO mode only) + fn is_bio_eof(&self) -> bool { + match self { + SslConnection::Socket(_) => false, + SslConnection::Bio(stream) => stream.get_ref().inbio.eof_written.load(), + } + } } #[pyattr] @@ -2172,17 +2343,21 @@ mod _ssl { #[pymethod] fn version(&self) -> Option<&'static str> { - let v = self.connection.read().ssl().version_str(); + // Use thread-local SSL pointer during handshake to avoid deadlock + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); + // Return None if handshake is not complete (CPython behavior) + if unsafe { sys::SSL_is_init_finished(ssl_ptr) } == 0 { + return None; + } + let v = unsafe { ssl::SslRef::from_ptr(ssl_ptr).version_str() }; if v == "unknown" { None } else { Some(v) } } #[pymethod] fn cipher(&self) -> Option { - self.connection - .read() - 
.ssl() - .current_cipher() - .map(cipher_to_tuple) + // Use thread-local SSL pointer during handshake to avoid deadlock + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); + unsafe { ssl::SslRef::from_ptr(ssl_ptr).current_cipher() }.map(cipher_to_tuple) } #[pymethod] @@ -2195,14 +2370,15 @@ mod _ssl { fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { #[cfg(ossl110)] { - let stream = self.connection.read(); + // Use thread-local SSL pointer during handshake to avoid deadlock + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); unsafe { - let server_ciphers = SSL_get_ciphers(stream.ssl().as_ptr()); + let server_ciphers = SSL_get_ciphers(ssl_ptr); if server_ciphers.is_null() { return None; } - let client_ciphers = SSL_get_client_ciphers(stream.ssl().as_ptr()); + let client_ciphers = SSL_get_client_ciphers(ssl_ptr); if client_ciphers.is_null() { return None; } @@ -2258,12 +2434,13 @@ mod _ssl { fn selected_alpn_protocol(&self) -> Option { #[cfg(ossl102)] { - let stream = self.connection.read(); + // Use thread-local SSL pointer during handshake to avoid deadlock + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); unsafe { let mut out: *const libc::c_uchar = std::ptr::null(); let mut outlen: libc::c_uint = 0; - sys::SSL_get0_alpn_selected(stream.ssl().as_ptr(), &mut out, &mut outlen); + sys::SSL_get0_alpn_selected(ssl_ptr, &mut out, &mut outlen); if out.is_null() { None @@ -2343,42 +2520,53 @@ mod _ssl { } #[pymethod] - fn shutdown(&self, vm: &VirtualMachine) -> PyResult> { + fn shutdown(&self, vm: &VirtualMachine) -> PyResult>> { let stream = self.connection.read(); - - // BIO mode doesn't have an underlying socket - if stream.is_bio() { - return Err(vm.new_not_implemented_error( - "shutdown() is not supported for BIO-based SSL objects".to_owned(), - )); - } - let ssl_ptr = stream.ssl().as_ptr(); - // Perform SSL shutdown - let ret = unsafe { sys::SSL_shutdown(ssl_ptr) }; + // Perform SSL shutdown - may need to be called 
twice: + // 1st call: sends close-notify, returns 0 + // 2nd call: reads peer's close-notify, returns 1 + let mut ret = unsafe { sys::SSL_shutdown(ssl_ptr) }; + + // If ret == 0, try once more to complete the bidirectional shutdown + // This handles the case where peer's close-notify is already available + if ret == 0 { + ret = unsafe { sys::SSL_shutdown(ssl_ptr) }; + } if ret < 0 { // Error occurred let err = unsafe { sys::SSL_get_error(ssl_ptr, ret) }; - if err == sys::SSL_ERROR_WANT_READ || err == sys::SSL_ERROR_WANT_WRITE { - // Non-blocking would block - this is okay for shutdown - // Return the underlying socket + if err == sys::SSL_ERROR_WANT_READ { + return Err(create_ssl_want_read_error(vm).upcast()); + } else if err == sys::SSL_ERROR_WANT_WRITE { + return Err(create_ssl_want_write_error(vm).upcast()); } else { return Err(new_ssl_error( vm, format!("SSL shutdown failed: error code {}", err), )); } + } else if ret == 0 { + // Still waiting for peer's close-notify after retry + // In BIO mode, raise SSLWantReadError + if stream.is_bio() { + return Err(create_ssl_want_read_error(vm).upcast()); + } + } + + // BIO mode doesn't have an underlying socket to return + if stream.is_bio() { + return Ok(None); } - // Return the underlying socket - // Get the socket from the stream (SocketStream wraps PyRef) + // Return the underlying socket for socket mode let socket = stream .get_ref() .expect("unwrap() called on bio mode; should only be called in socket mode"); - Ok(socket.0.clone()) + Ok(Some(socket.0.clone())) } #[cfg(osslconf = "OPENSSL_NO_COMP")] @@ -2389,8 +2577,9 @@ mod _ssl { #[cfg(not(osslconf = "OPENSSL_NO_COMP"))] #[pymethod] fn compression(&self) -> Option<&'static str> { - let stream = self.connection.read(); - let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) }; + // Use thread-local SSL pointer during handshake to avoid deadlock + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); + let comp_method = unsafe 
{ sys::SSL_get_current_compression(ssl_ptr) }; if comp_method.is_null() { return None; } @@ -2416,7 +2605,7 @@ mod _ssl { let result = stream.do_handshake().map_err(|e| { let exc = convert_ssl_error(vm, e); // If it's a cert verification error, set verify info - if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { + if exc.class().is(PySSLCertVerificationError::class(&vm.ctx)) { set_verify_error_info(&exc, ssl_ptr, vm); } exc @@ -2473,7 +2662,7 @@ mod _ssl { } let exc = convert_ssl_error(vm, err); // If it's a cert verification error, set verify info - if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { + if exc.class().is(PySSLCertVerificationError::class(&vm.ctx)) { set_verify_error_info(&exc, ssl_ptr, vm); } // Clean up SNI ex_data before returning error @@ -2603,19 +2792,41 @@ mod _ssl { #[pygetset] fn session_reused(&self) -> bool { - let stream = self.connection.read(); - unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 } + // Use thread-local SSL pointer during handshake to avoid deadlock + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); + unsafe { sys::SSL_session_reused(ssl_ptr) != 0 } } #[pymethod] fn read( &self, - n: usize, + n: isize, buffer: OptionalArg, vm: &VirtualMachine, ) -> PyResult { + // Handle negative n: + // - If buffer is None and n < 0: raise ValueError + // - If buffer is present and n <= 0: use buffer length + // This matches _ssl__SSLSocket_read_impl in CPython + let read_len: usize = match &buffer { + OptionalArg::Present(buf) => { + let buf_len = buf.borrow_buf_mut().len(); + if n <= 0 || (n as usize) > buf_len { + buf_len + } else { + n as usize + } + } + OptionalArg::Missing => { + if n < 0 { + return Err(vm.new_value_error("size should not be negative".to_owned())); + } + n as usize + } + }; + // Special case: reading 0 bytes should return empty bytes immediately - if n == 0 { + if read_len == 0 { return if buffer.is_present() { Ok(vm.ctx.new_int(0).into()) } else { @@ -2627,13 
+2838,13 @@ mod _ssl { let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer { Either::A(buffer.borrow_buf_mut()) } else { - Either::B(vec![0u8; n]) + Either::B(vec![0u8; read_len]) }; let buf = match &mut inner_buffer { Either::A(b) => &mut **b, Either::B(b) => b.as_mut_slice(), }; - let buf = match buf.get_mut(..n) { + let buf = match buf.get_mut(..read_len) { Some(b) => b, None => buf, }; @@ -2642,7 +2853,18 @@ mod _ssl { let count = if stream.is_bio() { match stream.ssl_read(buf) { Ok(count) => count, - Err(e) => return Err(convert_ssl_error(vm, e)), + Err(e) => { + // Handle ZERO_RETURN (EOF) - raise SSLEOFError + if e.code() == ssl::ErrorCode::ZERO_RETURN { + return Err(create_ssl_eof_error(vm).upcast()); + } + // If WANT_READ and the incoming BIO has EOF written, + // this is an unexpected EOF (transport closed without TLS close_notify) + if e.code() == ssl::ErrorCode::WANT_READ && stream.is_bio_eof() { + return Err(create_ssl_eof_error(vm).upcast()); + } + return Err(convert_ssl_error(vm, e)); + } } } else { // Socket mode: handle timeout and blocking @@ -3051,22 +3273,10 @@ mod _ssl { let len = size.unwrap_or(-1); let len = if len < 0 || len > avail { avail } else { len }; - // Check if EOF has been written and no data available - // This matches CPython's behavior where read() returns b'' when EOF is set - if len == 0 && self.eof_written.load() { - return Ok(Vec::new()); - } - + // When no data available, return empty bytes (CPython behavior) + // CPython returns empty bytes directly without calling BIO_read() if len == 0 { - // No data available and no EOF - would block - // Call BIO_read() to get the proper error (SSL_ERROR_WANT_READ) - let mut test_buf = [0u8; 1]; - let nbytes = sys::BIO_read(self.bio, test_buf.as_mut_ptr() as *mut _, 1); - if nbytes < 0 { - return Err(convert_openssl_error(vm, ErrorStack::get())); - } - // Shouldn't reach here, but if we do, return what we got - return Ok(test_buf[..nbytes as usize].to_vec()); + return 
Ok(Vec::new()); } let mut buf = vec![0u8; len as usize]; @@ -3175,7 +3385,7 @@ mod _ssl { /// Helper function to create SSL error with proper OSError subtype handling fn new_ssl_error(vm: &VirtualMachine, msg: impl ToString) -> PyBaseExceptionRef { - vm.new_os_subtype_error(PySslError::class(&vm.ctx).to_owned(), None, msg.to_string()) + vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, msg.to_string()) .upcast() } @@ -3235,9 +3445,9 @@ mod _ssl { // Use SSLCertVerificationError for certificate verification failures let cls = if is_cert_verify_error { - PySslCertVerificationError::class(&vm.ctx).to_owned() + PySSLCertVerificationError::class(&vm.ctx).to_owned() } else { - PySslError::class(&vm.ctx).to_owned() + PySSLError::class(&vm.ctx).to_owned() }; // Build message @@ -3278,7 +3488,7 @@ mod _ssl { ) } None => { - let cls = PySslError::class(&vm.ctx).to_owned(); + let cls = PySSLError::class(&vm.ctx).to_owned(); vm.new_os_subtype_error(cls, None, "unknown SSL error") .upcast() } @@ -3317,27 +3527,18 @@ mod _ssl { ) -> PyBaseExceptionRef { let e = e.borrow(); let (cls, msg) = match e.code() { - ssl::ErrorCode::WANT_READ => ( - PySslWantReadError::class(&vm.ctx).to_owned(), - "The operation did not complete (read)", - ), - ssl::ErrorCode::WANT_WRITE => ( - PySslWantWriteError::class(&vm.ctx).to_owned(), - "The operation did not complete (write)", - ), + ssl::ErrorCode::WANT_READ => { + return create_ssl_want_read_error(vm).upcast(); + } + ssl::ErrorCode::WANT_WRITE => { + return create_ssl_want_write_error(vm).upcast(); + } ssl::ErrorCode::SYSCALL => match e.io_error() { Some(io_err) => return io_err.to_pyexception(vm), // When no I/O error and OpenSSL error queue is empty, // this is an EOF in violation of protocol -> SSLEOFError - // Need to set args[0] = SSL_ERROR_EOF for suppress_ragged_eofs check None => { - return vm - .new_os_subtype_error( - PySslEOFError::class(&vm.ctx).to_owned(), - Some(SSL_ERROR_EOF as i32), - "EOF occurred in 
violation of protocol", - ) - .upcast(); + return create_ssl_eof_error(vm).upcast(); } }, ssl::ErrorCode::SSL => { @@ -3350,24 +3551,18 @@ mod _ssl { let reason = sys::ERR_GET_REASON(err_code); let lib = sys::ERR_GET_LIB(err_code); if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING { - return vm - .new_os_subtype_error( - PySslEOFError::class(&vm.ctx).to_owned(), - Some(SSL_ERROR_EOF as i32), - "EOF occurred in violation of protocol", - ) - .upcast(); + return create_ssl_eof_error(vm).upcast(); } } return convert_openssl_error(vm, ssl_err.clone()); } ( - PySslError::class(&vm.ctx).to_owned(), + PySSLError::class(&vm.ctx).to_owned(), "A failure in the SSL library occurred", ) } _ => ( - PySslError::class(&vm.ctx).to_owned(), + PySSLError::class(&vm.ctx).to_owned(), "A failure in the SSL library occurred", ), }; diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 992b32e00ea..bf25260ed3a 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -207,28 +207,6 @@ mod _ssl { #[pyattr] const OP_ALL: i32 = 0x00000BFB; // Combined "safe" options (reduced for i32, excluding OP_LEGACY_SERVER_CONNECT for OpenSSL 3.0.0+ compatibility) - // Error types - #[pyattr] - const SSL_ERROR_NONE: i32 = 0; - #[pyattr] - const SSL_ERROR_SSL: i32 = 1; - #[pyattr] - const SSL_ERROR_WANT_READ: i32 = 2; - #[pyattr] - const SSL_ERROR_WANT_WRITE: i32 = 3; - #[pyattr] - const SSL_ERROR_WANT_X509_LOOKUP: i32 = 4; - #[pyattr] - const SSL_ERROR_SYSCALL: i32 = 5; - #[pyattr] - const SSL_ERROR_ZERO_RETURN: i32 = 6; - #[pyattr] - const SSL_ERROR_WANT_CONNECT: i32 = 7; - #[pyattr] - const SSL_ERROR_EOF: i32 = 8; - #[pyattr] - const SSL_ERROR_INVALID_ERROR_CODE: i32 = 10; - // Alert types (matching _TLSAlertType enum) #[pyattr] const ALERT_DESCRIPTION_CLOSE_NOTIFY: i32 = 0; diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index ab3c81b7a4e..cd927a0e410 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ 
b/crates/stdlib/src/ssl/compat.rs @@ -1397,8 +1397,21 @@ pub(super) fn ssl_read( return Ok(n); } - // No plaintext available and cannot read more TLS records + // No plaintext available and rustls doesn't want to read more TLS records if !needs_more_tls { + // Check if connection needs to write data first (e.g., TLS key update, renegotiation) + // This mirrors the handshake logic which checks both wants_read() and wants_write() + if conn.wants_write() && !is_bio { + // Flush pending TLS data before continuing + let tls_data = ssl_write_tls_records(conn)?; + if !tls_data.is_empty() { + socket.sock_send(tls_data, vm).map_err(SslError::Py)?; + } + // After flushing, rustls may want to read again - continue loop + continue; + } + + // BIO mode: check for EOF if is_bio && let Some(bio_obj) = socket.incoming_bio() { let is_eof = bio_obj .get_attr("eof", vm) diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index e31683ec72d..bef9ba513d7 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -10,9 +10,27 @@ pub(crate) mod ssl_error { types::Constructor, }; - // Error type constants (needed for create_ssl_want_read_error etc.) 
+ // Error type constants - exposed as pyattr and available for internal use + #[pyattr] + pub(crate) const SSL_ERROR_NONE: i32 = 0; + #[pyattr] + pub(crate) const SSL_ERROR_SSL: i32 = 1; + #[pyattr] pub(crate) const SSL_ERROR_WANT_READ: i32 = 2; + #[pyattr] pub(crate) const SSL_ERROR_WANT_WRITE: i32 = 3; + #[pyattr] + pub(crate) const SSL_ERROR_WANT_X509_LOOKUP: i32 = 4; + #[pyattr] + pub(crate) const SSL_ERROR_SYSCALL: i32 = 5; + #[pyattr] + pub(crate) const SSL_ERROR_ZERO_RETURN: i32 = 6; + #[pyattr] + pub(crate) const SSL_ERROR_WANT_CONNECT: i32 = 7; + #[pyattr] + pub(crate) const SSL_ERROR_EOF: i32 = 8; + #[pyattr] + pub(crate) const SSL_ERROR_INVALID_ERROR_CODE: i32 = 10; #[pyattr] #[pyexception(name = "SSLError", base = PyOSError)] @@ -102,7 +120,7 @@ pub(crate) mod ssl_error { pub fn create_ssl_eof_error(vm: &VirtualMachine) -> PyRef { vm.new_os_subtype_error( PySSLEOFError::class(&vm.ctx).to_owned(), - None, + Some(SSL_ERROR_EOF), "EOF occurred in violation of protocol", ) } @@ -110,7 +128,7 @@ pub(crate) mod ssl_error { pub fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { vm.new_os_subtype_error( PySSLZeroReturnError::class(&vm.ctx).to_owned(), - None, + Some(SSL_ERROR_ZERO_RETURN), "TLS/SSL connection has been closed (EOF)", ) } From 898fe85f40b1c38b117688f96b2f02786a524194 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sun, 21 Dec 2025 00:38:50 +0900 Subject: [PATCH 018/418] More openssl impl (#6464) * Add openssl build to CI * more openssl impls --- .cspell.dict/cpython.txt | 1 + .github/workflows/ci.yaml | 5 + crates/stdlib/src/openssl.rs | 450 ++++++++++++++++++++++++------ crates/stdlib/src/openssl/cert.rs | 168 +++++++++-- 4 files changed, 520 insertions(+), 104 deletions(-) diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 8acb1468f66..8ccd6d6b641 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -17,6 +17,7 @@ cmpop denom DICTFLAG 
dictoffset +distpoint elts excepthandler fileutils diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8539db1a27e..783e3dfa6b3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -147,6 +147,11 @@ jobs: - name: check compilation without threading run: cargo check ${{ env.CARGO_ARGS }} + - name: Test openssl build + run: + cargo build --no-default-features --features ssl-openssl + if: runner.os == 'Linux' + - name: Test example projects run: cargo run --manifest-path example_projects/barebone/Cargo.toml diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index b6d5f5c2035..4d420e7d539 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -64,7 +64,7 @@ mod _ssl { }, socket::{self, PySocket}, vm::{ - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyWeak, @@ -73,8 +73,8 @@ mod _ssl { convert::ToPyException, exceptions, function::{ - ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, - OptionalArg, PyComparisonValue, + ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, OptionalArg, + PyComparisonValue, }, types::{Comparable, Constructor, PyComparisonOp}, utils::ToCString, @@ -602,6 +602,39 @@ mod _ssl { } } + // Get or create an ex_data index for msg_callback data + fn get_msg_callback_ex_data_index() -> libc::c_int { + use std::sync::LazyLock; + static MSG_CB_EX_DATA_IDX: LazyLock = LazyLock::new(|| unsafe { + sys::SSL_get_ex_new_index( + 0, + std::ptr::null_mut(), + None, + None, + Some(msg_callback_data_free), + ) + }); + *MSG_CB_EX_DATA_IDX + } + + // Free function for msg_callback data - called by OpenSSL when SSL is freed + unsafe extern "C" fn msg_callback_data_free( + _parent: *mut libc::c_void, + ptr: *mut libc::c_void, + _ad: *mut 
sys::CRYPTO_EX_DATA, + _idx: libc::c_int, + _argl: libc::c_long, + _argp: *mut libc::c_void, + ) { + if !ptr.is_null() { + unsafe { + // Reconstruct PyObjectRef and drop to decrement reference count + let raw = std::ptr::NonNull::new_unchecked(ptr as *mut PyObject); + let _ = PyObjectRef::from_raw(raw); + } + } + } + // SNI callback function called by OpenSSL unsafe extern "C" fn _servername_callback( ssl_ptr: *mut sys::SSL, @@ -732,13 +765,14 @@ mod _ssl { } unsafe { - // Get SSL socket from SSL_get_ex_data (index 0) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); + // Get SSL socket from ex_data using the dedicated index + let idx = get_msg_callback_ex_data_index(); + let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); if ssl_socket_ptr.is_null() { return; } - // ssl_socket_ptr is a pointer to Py, set in _wrap_socket/_wrap_bio + // ssl_socket_ptr is a pointer to Box>, set in _wrap_socket/_wrap_bio let ssl_socket: &Py = &*(ssl_socket_ptr as *const Py); // Get the callback from the context @@ -849,6 +883,9 @@ mod _ssl { post_handshake_auth: PyMutex, sni_callback: PyMutex>, msg_callback: PyMutex>, + psk_client_callback: PyMutex>, + psk_server_callback: PyMutex>, + psk_identity_hint: PyMutex>, } impl fmt::Debug for PySslContext { @@ -960,6 +997,9 @@ mod _ssl { post_handshake_auth: PyMutex::new(false), sni_callback: PyMutex::new(None), msg_callback: PyMutex::new(None), + psk_client_callback: PyMutex::new(None), + psk_server_callback: PyMutex::new(None), + psk_identity_hint: PyMutex::new(None), }) } } @@ -1083,9 +1123,26 @@ mod _ssl { self.ctx.read().options().bits() as _ } #[pygetset(setter)] - fn set_options(&self, opts: libc::c_ulong) { - self.builder() - .set_options(SslOptions::from_bits_truncate(opts as _)); + fn set_options(&self, new_opts: libc::c_ulong) { + let mut ctx = self.builder(); + // Get current options + let current = ctx.options().bits() as libc::c_ulong; + + // Calculate options to clear and set + let clear = current & !new_opts; + let 
set = !current & new_opts; + + // Clear options first (using raw FFI since openssl crate doesn't expose clear_options) + if clear != 0 { + unsafe { + sys::SSL_CTX_clear_options(ctx.as_ptr(), clear); + } + } + + // Then set new options + if set != 0 { + ctx.set_options(SslOptions::from_bits_truncate(set as _)); + } } #[pygetset] fn protocol(&self) -> i32 { @@ -1312,7 +1369,7 @@ mod _ssl { ssl::select_next_proto(&server, client).ok_or(ssl::AlpnError::NOACK)?; let pos = memchr::memmem::find(client, proto) .expect("selected alpn proto should be present in client protos"); - Ok(&client[pos..proto.len()]) + Ok(&client[pos..pos + proto.len()]) }); Ok(()) } @@ -1324,6 +1381,78 @@ mod _ssl { } } + #[pymethod] + fn set_psk_client_callback( + &self, + callback: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Cannot add PSK client callback to a server context + if self.protocol == SslVersion::TlsServer { + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot add PSK client callback to a PROTOCOL_TLS_SERVER context" + .to_owned(), + ) + .upcast()); + } + + if vm.is_none(&callback) { + *self.psk_client_callback.lock() = None; + unsafe { + sys::SSL_CTX_set_psk_client_callback(self.builder().as_ptr(), None); + } + } else { + if !callback.is_callable() { + return Err(vm.new_type_error("callback must be callable".to_owned())); + } + *self.psk_client_callback.lock() = Some(callback); + // Note: The actual callback will be invoked via SSL app_data mechanism + // when do_handshake is called + } + Ok(()) + } + + #[pymethod] + fn set_psk_server_callback( + &self, + callback: PyObjectRef, + identity_hint: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Cannot add PSK server callback to a client context + if self.protocol == SslVersion::TlsClient { + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot add PSK server callback to a PROTOCOL_TLS_CLIENT context" + 
.to_owned(), + ) + .upcast()); + } + + if vm.is_none(&callback) { + *self.psk_server_callback.lock() = None; + *self.psk_identity_hint.lock() = None; + unsafe { + sys::SSL_CTX_set_psk_server_callback(self.builder().as_ptr(), None); + } + } else { + if !callback.is_callable() { + return Err(vm.new_type_error("callback must be callable".to_owned())); + } + *self.psk_server_callback.lock() = Some(callback); + if let OptionalArg::Present(hint) = identity_hint { + *self.psk_identity_hint.lock() = Some(hint.as_str().to_owned()); + } + // Note: The actual callback will be invoked via SSL app_data mechanism + } + Ok(()) + } + #[pymethod] fn load_verify_locations( &self, @@ -1343,16 +1472,33 @@ mod _ssl { // validate cadata type and load cadata if let Some(cadata) = args.cadata { - let certs = match cadata { + let (certs, is_pem) = match cadata { Either::A(s) => { if !s.as_str().is_ascii() { return Err(invalid_cadata(vm)); } - X509::stack_from_pem(s.as_bytes()) + (X509::stack_from_pem(s.as_bytes()), true) } - Either::B(b) => b.with_ref(x509_stack_from_der), + Either::B(b) => (b.with_ref(x509_stack_from_der), false), }; let certs = certs.map_err(|e| convert_openssl_error(vm, e))?; + + // If no certificates were loaded, raise an error + if certs.is_empty() { + let msg = if is_pem { + "no start line: cadata does not contain a certificate" + } else { + "not enough data: cadata does not contain a certificate" + }; + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + msg.to_owned(), + ) + .upcast()); + } + let store = ctx.cert_store_mut(); for cert in certs { store @@ -1365,27 +1511,27 @@ mod _ssl { let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?; let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?; // Check file/directory existence before calling OpenSSL to get proper errno - if let Some(ref path) = cafile_path { - if !path.exists() { - return Err(vm - .new_os_subtype_error( - 
vm.ctx.exceptions.file_not_found_error.to_owned(), - Some(libc::ENOENT), - format!("No such file or directory: '{}'", path.display()), - ) - .upcast()); - } + if let Some(ref path) = cafile_path + && !path.exists() + { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", path.display()), + ) + .upcast()); } - if let Some(ref path) = capath_path { - if !path.exists() { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.file_not_found_error.to_owned(), - Some(libc::ENOENT), - format!("No such file or directory: '{}'", path.display()), - ) - .upcast()); - } + if let Some(ref path) = capath_path + && !path.exists() + { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", path.display()), + ) + .upcast()); } ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref()) .map_err(|e| convert_openssl_error(vm, e))?; @@ -1450,7 +1596,8 @@ mod _ssl { X509_LU_X509 => { x509_count += 1; let x509_ptr = sys::X509_OBJECT_get0_X509(obj_ptr); - if !x509_ptr.is_null() && X509_check_ca(x509_ptr) == 1 { + // X509_check_ca returns non-zero for any CA type + if !x509_ptr.is_null() && X509_check_ca(x509_ptr) != 0 { ca_count += 1; } } @@ -1673,18 +1820,19 @@ mod _ssl { #[pymethod] fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> { + use openssl::pkey::PKey; + use std::cell::RefCell; + let LoadCertChainArgs { certfile, keyfile, password, } = args; - // TODO: requires passing a callback to C - if password.is_some() { - return Err(vm.new_not_implemented_error("password arg not yet supported")); - } + let mut ctx = self.builder(); let key_path = keyfile.map(|path| path.to_path_buf(vm)).transpose()?; let cert_path = certfile.to_path_buf(vm)?; + // Check file existence before calling OpenSSL to get proper errno if 
!cert_path.exists() { return Err(vm @@ -1695,28 +1843,139 @@ mod _ssl { ) .upcast()); } - if let Some(ref kp) = key_path { - if !kp.exists() { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.file_not_found_error.to_owned(), - Some(libc::ENOENT), - format!("No such file or directory: '{}'", kp.display()), - ) - .upcast()); - } + if let Some(ref kp) = key_path + && !kp.exists() + { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", kp.display()), + ) + .upcast()); } + + // Load certificate chain ctx.set_certificate_chain_file(&cert_path) - .and_then(|()| { - ctx.set_private_key_file( - key_path.as_ref().unwrap_or(&cert_path), - ssl::SslFiletype::PEM, - ) - }) - .and_then(|()| ctx.check_private_key()) + .map_err(|e| convert_openssl_error(vm, e))?; + + // Load private key - handle password if provided + let key_file_path = key_path.as_ref().unwrap_or(&cert_path); + + // PEM_BUFSIZE = 1024 (maximum password length in OpenSSL) + const PEM_BUFSIZE: usize = 1024; + + // Read key file data + let key_data = std::fs::read(key_file_path) + .map_err(|e| crate::vm::convert::ToPyException::to_pyexception(&e, vm))?; + + let pkey = if let Some(ref pw_obj) = password { + if pw_obj.is_callable() { + // Callable password - use callback that calls Python function + // Store any Python error that occurs in the callback + let py_error: RefCell> = RefCell::new(None); + + let result = PKey::private_key_from_pem_callback(&key_data, |buf| { + // Call the Python password callback + let pw_result = pw_obj.call((), vm); + match pw_result { + Ok(result) => { + // Extract password bytes + match Self::extract_password_bytes( + &result, + "password callback must return a string", + vm, + ) { + Ok(pw) => { + // Check password length + if pw.len() > PEM_BUFSIZE { + *py_error.borrow_mut() = + Some(vm.new_value_error(format!( + "password cannot be longer than {} bytes", + 
PEM_BUFSIZE + ))); + return Err(openssl::error::ErrorStack::get()); + } + let len = std::cmp::min(pw.len(), buf.len()); + buf[..len].copy_from_slice(&pw[..len]); + Ok(len) + } + Err(e) => { + *py_error.borrow_mut() = Some(e); + Err(openssl::error::ErrorStack::get()) + } + } + } + Err(e) => { + *py_error.borrow_mut() = Some(e); + Err(openssl::error::ErrorStack::get()) + } + } + }); + + // Check for Python error first + if let Some(py_err) = py_error.into_inner() { + return Err(py_err); + } + + result.map_err(|e| convert_openssl_error(vm, e))? + } else { + // Direct password (string/bytes) + let pw = Self::extract_password_bytes( + pw_obj, + "password should be a string or bytes", + vm, + )?; + + // Check password length + if pw.len() > PEM_BUFSIZE { + return Err(vm.new_value_error(format!( + "password cannot be longer than {} bytes", + PEM_BUFSIZE + ))); + } + + PKey::private_key_from_pem_passphrase(&key_data, &pw) + .map_err(|e| convert_openssl_error(vm, e))? + } + } else { + // No password - use SSL_CTX_use_PrivateKey_file directly for correct error messages + ctx.set_private_key_file(key_file_path, ssl::SslFiletype::PEM) + .map_err(|e| convert_openssl_error(vm, e))?; + + // Verify key matches certificate and return early + return ctx + .check_private_key() + .map_err(|e| convert_openssl_error(vm, e)); + }; + + ctx.set_private_key(&pkey) + .map_err(|e| convert_openssl_error(vm, e))?; + + // Verify key matches certificate + ctx.check_private_key() .map_err(|e| convert_openssl_error(vm, e)) } + // Helper to extract password bytes from string/bytes/bytearray + fn extract_password_bytes( + obj: &PyObject, + bad_type_error: &str, + vm: &VirtualMachine, + ) -> PyResult> { + use crate::vm::builtins::{PyByteArray, PyBytes, PyStr}; + + if let Some(s) = obj.downcast_ref::() { + Ok(s.as_str().as_bytes().to_vec()) + } else if let Some(b) = obj.downcast_ref::() { + Ok(b.as_bytes().to_vec()) + } else if let Some(ba) = obj.downcast_ref::() { + Ok(ba.borrow_buf().to_vec()) + } 
else { + Err(vm.new_type_error(bad_type_error.to_owned())) + } + } + // Helper function to create SSL socket // = CPython's newPySSLSocket() fn new_py_ssl_socket( @@ -1855,11 +2114,12 @@ mod _ssl { unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data) - // This is safe because ssl_socket owns the SSL object and outlives it - // We store a pointer to Py, which msg_callback can dereference - let py_ptr: *const Py = &*py_ref; - sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _); + // Clone and store via into_raw() - increments refcount and returns stable pointer + // The refcount will be decremented by msg_callback_data_free when SSL is freed + let cloned: PyObjectRef = py_ref.clone().into(); + let raw_ptr = cloned.into_raw(); + let msg_cb_idx = get_msg_callback_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, msg_cb_idx, raw_ptr.as_ptr() as *mut _); // Set SNI callback data if needed if has_sni_callback { @@ -1869,8 +2129,8 @@ mod _ssl { ssl_context: zelf.clone(), ssl_socket_weak, }); - let idx = get_sni_ex_data_index(); - sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); + let sni_idx = get_sni_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, sni_idx, Box::into_raw(callback_data) as *mut _); } } @@ -1923,11 +2183,12 @@ mod _ssl { unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data) - // This is safe because ssl_socket owns the SSL object and outlives it - // We store a pointer to Py, which msg_callback can dereference - let py_ptr: *const Py = &*py_ref; - sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _); + // Clone and store via into_raw() - increments refcount and returns stable pointer + // The refcount will be decremented by msg_callback_data_free when SSL is freed + let cloned: PyObjectRef = py_ref.clone().into(); + let raw_ptr = 
cloned.into_raw(); + let msg_cb_idx = get_msg_callback_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, msg_cb_idx, raw_ptr.as_ptr() as *mut _); // Set SNI callback data if needed if has_sni_callback { @@ -1937,8 +2198,8 @@ mod _ssl { ssl_context: zelf.clone(), ssl_socket_weak, }); - let idx = get_sni_ex_data_index(); - sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); + let sni_idx = get_sni_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, sni_idx, Box::into_raw(callback_data) as *mut _); } } @@ -1995,7 +2256,7 @@ mod _ssl { #[pyarg(any, optional)] keyfile: Option, #[pyarg(any, optional)] - password: Option>, + password: Option, } // Err is true if the socket is blocking @@ -2004,7 +2265,6 @@ mod _ssl { enum SelectRet { Nonblocking, TimedOut, - IsBlocking, Closed, Ok, } @@ -2292,12 +2552,13 @@ mod _ssl { #[pymethod] fn get_unverified_chain(&self, vm: &VirtualMachine) -> PyResult> { let stream = self.connection.read(); - let Some(chain) = stream.ssl().peer_cert_chain() else { + let ssl = stream.ssl(); + let Some(chain) = ssl.peer_cert_chain() else { return Ok(None); }; // Return Certificate objects - let certs: Vec = chain + let mut certs: Vec = chain .iter() .map(|cert| unsafe { sys::X509_up_ref(cert.as_ptr()); @@ -2305,6 +2566,16 @@ mod _ssl { cert_to_certificate(vm, owned) }) .collect::>()?; + + // SSL_get_peer_cert_chain does not include peer cert for server-side sockets + // Add it manually at the beginning + if matches!(self.socket_type, SslServerOrClient::Server) + && let Some(peer_cert) = ssl.peer_certificate() + { + let peer_obj = cert_to_certificate(vm, peer_cert)?; + certs.insert(0, peer_obj); + } + Ok(Some(vm.ctx.new_list(certs))) } @@ -2652,7 +2923,7 @@ mod _ssl { return Err(socket_closed_error(vm)); } SelectRet::Nonblocking => {} - SelectRet::IsBlocking | SelectRet::Ok => { + SelectRet::Ok => { // For blocking sockets, select() has completed successfully // Continue the handshake loop (matches CPython's SOCKET_IS_BLOCKING 
behavior) if needs.is_some() { @@ -2719,7 +2990,7 @@ mod _ssl { } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} - SelectRet::IsBlocking | SelectRet::Ok => { + SelectRet::Ok => { // For blocking sockets, select() has completed successfully // Continue the write loop (matches CPython's SOCKET_IS_BLOCKING behavior) if needs.is_some() { @@ -2896,7 +3167,7 @@ mod _ssl { } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} - SelectRet::IsBlocking | SelectRet::Ok => { + SelectRet::Ok => { // For blocking sockets, select() has completed successfully // Continue the read loop (matches CPython's SOCKET_IS_BLOCKING behavior) if needs.is_some() { @@ -3576,30 +3847,33 @@ mod _ssl { let bio = bio::MemBioSlice::new(der)?; let mut certs = vec![]; + let mut was_bio_eof = false; loop { + // Check for EOF before attempting to parse (like CPython's _add_ca_certs) + // BIO_ctrl with BIO_CTRL_EOF returns 1 if EOF, 0 otherwise + if sys::BIO_ctrl(bio.as_ptr(), sys::BIO_CTRL_EOF, 0, std::ptr::null_mut()) != 0 { + was_bio_eof = true; + break; + } + let cert = sys::d2i_X509_bio(bio.as_ptr(), std::ptr::null_mut()); if cert.is_null() { + // Parse error (not just EOF) break; } certs.push(X509::from_ptr(cert)); } - if certs.is_empty() { - // No certificates loaded at all + // If we loaded some certs but didn't reach EOF, there's garbage data + // (like cacert_der + b"A") - this is an error + if !certs.is_empty() && !was_bio_eof { + // Return the error from the last failed parse attempt return Err(ErrorStack::get()); } - // Successfully loaded at least one certificate from DER data. - // Clear any trailing errors from EOF. - // CPython clears errors when: - // - DER: was_bio_eof is set (EOF reached) - // - PEM: PEM_R_NO_START_LINE error (normal EOF) - // Both cases mean successful completion with loaded certs. 
- eprintln!( - "[x509_stack_from_der] SUCCESS: Clearing errors and returning {} certs", - certs.len() - ); + // Clear any errors (including parse errors when no certs loaded) + // Let the caller decide how to handle empty results sys::ERR_clear_error(); Ok(certs) } diff --git a/crates/stdlib/src/openssl/cert.rs b/crates/stdlib/src/openssl/cert.rs index 1197bf4aa46..b63d824a837 100644 --- a/crates/stdlib/src/openssl/cert.rs +++ b/crates/stdlib/src/openssl/cert.rs @@ -5,16 +5,19 @@ pub(super) use ssl_cert::{PySSLCertificate, cert_to_certificate, cert_to_py, obj #[pymodule(sub)] pub(crate) mod ssl_cert { use crate::{ - common::ascii, + common::{ascii, hash::PyHash}, vm::{ - PyObjectRef, PyPayload, PyResult, VirtualMachine, + Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, + class_or_notimplemented, convert::{ToPyException, ToPyObject}, - function::{FsPath, OptionalArg}, + function::{FsPath, OptionalArg, PyComparisonValue}, + types::{Comparable, Hashable, PyComparisonOp, Representable}, }, }; use foreign_types_shared::ForeignTypeRef; use openssl::{ asn1::Asn1ObjectRef, + nid::Nid, x509::{self, X509, X509Ref}, }; use openssl_sys as sys; @@ -54,7 +57,7 @@ pub(crate) mod ssl_cert { #[pyclass(module = "ssl", name = "Certificate")] #[derive(PyPayload)] pub(crate) struct PySSLCertificate { - cert: X509, + pub(crate) cert: X509, } impl fmt::Debug for PySSLCertificate { @@ -63,7 +66,7 @@ pub(crate) mod ssl_cert { } } - #[pyclass] + #[pyclass(with(Comparable, Hashable, Representable))] impl PySSLCertificate { #[pymethod] fn public_bytes( @@ -83,12 +86,14 @@ pub(crate) mod ssl_cert { Ok(vm.ctx.new_bytes(der).into()) } ENCODING_PEM => { - // PEM encoding + // PEM encoding - returns string let pem = self .cert .to_pem() .map_err(|e| convert_openssl_error(vm, e))?; - Ok(vm.ctx.new_bytes(pem).into()) + let pem_str = String::from_utf8(pem) + .map_err(|_| vm.new_value_error("Invalid UTF-8 in PEM"))?; + Ok(vm.ctx.new_str(pem_str).into()) } _ => 
Err(vm.new_value_error("Unsupported format")), } @@ -100,6 +105,66 @@ pub(crate) mod ssl_cert { } } + impl Comparable for PySSLCertificate { + fn cmp( + zelf: &Py, + other: &PyObject, + op: PyComparisonOp, + vm: &VirtualMachine, + ) -> PyResult { + let other = class_or_notimplemented!(Self, other); + + // Only support equality comparison + if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) { + return Ok(PyComparisonValue::NotImplemented); + } + + // Compare DER encodings + let self_der = zelf + .cert + .to_der() + .map_err(|e| convert_openssl_error(vm, e))?; + let other_der = other + .cert + .to_der() + .map_err(|e| convert_openssl_error(vm, e))?; + + let eq = self_der == other_der; + Ok(op.eval_ord(eq.cmp(&true)).into()) + } + } + + impl Hashable for PySSLCertificate { + fn hash(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // Use subject name hash as certificate hash + let hash = unsafe { sys::X509_subject_name_hash(zelf.cert.as_ptr()) }; + Ok(hash as PyHash) + } + } + + impl Representable for PySSLCertificate { + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // Build subject string like "CN=localhost, O=Python" + let subject = zelf.cert.subject_name(); + let mut parts: Vec = Vec::new(); + for entry in subject.entries() { + // Use short name (SN) if available, otherwise use OID + let name = match entry.object().nid().short_name() { + Ok(sn) => sn.to_string(), + Err(_) => obj2txt(entry.object(), true).unwrap_or_default(), + }; + if let Ok(value) = entry.data().as_utf8() { + parts.push(format!("{}={}", name, value)); + } + } + if parts.is_empty() { + Ok("".to_string()) + } else { + Ok(format!("", parts.join(", "))) + } + } + } + fn name_to_py(vm: &VirtualMachine, name: &x509::X509NameRef) -> PyResult { let list = name .entries() @@ -133,11 +198,15 @@ pub(crate) mod ssl_cert { .to_bn() .and_then(|bn| bn.to_hex_str()) .map_err(|e| convert_openssl_error(vm, e))?; - dict.set_item( - "serialNumber", - 
vm.ctx.new_str(serial_num.to_owned()).into(), - vm, - )?; + // Serial number must have even length (each byte = 2 hex chars) + // BigNum::to_hex_str() strips leading zeros, so we need to pad + let serial_str = serial_num.to_string(); + let serial_str = if serial_str.len() % 2 == 1 { + format!("0{}", serial_str) + } else { + serial_str + }; + dict.set_item("serialNumber", vm.ctx.new_str(serial_str).into(), vm)?; dict.set_item( "notBefore", @@ -188,10 +257,23 @@ pub(crate) mod ssl_cert { return vm.new_tuple((ascii!("DirName"), py_name)).into(); } - // TODO: Handle Registered ID (GEN_RID) - // CPython implementation uses i2t_ASN1_OBJECT to convert OID - // This requires accessing GENERAL_NAME union which is complex in Rust - // For now, we return for unhandled types + // Check for Registered ID (GEN_RID) + // Access raw GENERAL_NAME to check type + let ptr = gen_name.as_ptr(); + unsafe { + if (*ptr).type_ == sys::GEN_RID { + // d is ASN1_OBJECT* for GEN_RID + let oid_ptr = (*ptr).d as *const sys::ASN1_OBJECT; + if !oid_ptr.is_null() { + let oid_ref = Asn1ObjectRef::from_ptr(oid_ptr as *mut _); + if let Some(oid_str) = obj2txt(oid_ref, true) { + return vm + .new_tuple((ascii!("Registered ID"), oid_str)) + .into(); + } + } + } + } // For othername and other unsupported types vm.new_tuple((ascii!("othername"), ascii!(""))) @@ -202,6 +284,60 @@ pub(crate) mod ssl_cert { dict.set_item("subjectAltName", vm.ctx.new_tuple(san).into(), vm)?; }; + // Authority Information Access: OCSP URIs + if let Ok(ocsp_list) = cert.ocsp_responders() + && !ocsp_list.is_empty() + { + let uris: Vec = ocsp_list + .iter() + .map(|s| vm.ctx.new_str(s.to_string()).into()) + .collect(); + dict.set_item("OCSP", vm.ctx.new_tuple(uris).into(), vm)?; + } + + // Authority Information Access: CA Issuers URIs + if let Some(aia) = cert.authority_info() { + let ca_issuers: Vec = aia + .iter() + .filter_map(|ad| { + // Check if method is CA Issuers (NID_ad_ca_issuers) + if ad.method().nid() != 
Nid::AD_CA_ISSUERS { + return None; + } + // Get URI from location + ad.location() + .uri() + .map(|uri| vm.ctx.new_str(uri.to_owned()).into()) + }) + .collect(); + if !ca_issuers.is_empty() { + dict.set_item("caIssuers", vm.ctx.new_tuple(ca_issuers).into(), vm)?; + } + } + + // CRL Distribution Points + if let Some(crl_dps) = cert.crl_distribution_points() { + let mut crl_uris: Vec = Vec::new(); + for dp in crl_dps.iter() { + if let Some(dp_name) = dp.distpoint() + && let Some(fullname) = dp_name.fullname() + { + for gn in fullname.iter() { + if let Some(uri) = gn.uri() { + crl_uris.push(vm.ctx.new_str(uri.to_owned()).into()); + } + } + } + } + if !crl_uris.is_empty() { + dict.set_item( + "crlDistributionPoints", + vm.ctx.new_tuple(crl_uris).into(), + vm, + )?; + } + } + Ok(dict.into()) } From 09fa97d1b91170d3b67c63a778aea19aacaff02d Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sun, 21 Dec 2025 00:54:05 +0900 Subject: [PATCH 019/418] CI to test Apple Intel (#6465) --- .github/workflows/ci.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 783e3dfa6b3..a0947af853b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -158,13 +158,13 @@ jobs: cargo run --manifest-path example_projects/frozen_stdlib/Cargo.toml if: runner.os == 'Linux' - - name: prepare AppleSilicon build + - name: prepare Intel MacOS build uses: dtolnay/rust-toolchain@stable with: - target: aarch64-apple-darwin + target: x86_64-apple-darwin if: runner.os == 'macOS' - - name: Check compilation for Apple Silicon - run: cargo check --target aarch64-apple-darwin + - name: Check compilation for Intel MacOS + run: cargo check --target x86_64-apple-darwin if: runner.os == 'macOS' - name: prepare iOS build uses: dtolnay/rust-toolchain@stable From 8b3bf558fc9fb74bfea2344012b0fe5d52aa594f Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" 
<69878+youknowone@users.noreply.github.com> Date: Sun, 21 Dec 2025 08:25:17 +0900 Subject: [PATCH 020/418] Introduce PyConfig to implement Paths correctly (#6461) * getpath * introduce PyConfig * landmark in getpath --- .cspell.dict/cpython.txt | 1 + crates/stdlib/src/syslog.rs | 2 +- crates/vm/src/getpath.rs | 385 +++++++++++++++++++++++++++++++ crates/vm/src/import.rs | 2 +- crates/vm/src/lib.rs | 1 + crates/vm/src/stdlib/builtins.rs | 4 +- crates/vm/src/stdlib/imp.rs | 2 +- crates/vm/src/stdlib/io.rs | 4 +- crates/vm/src/stdlib/signal.rs | 2 +- crates/vm/src/stdlib/sys.rs | 142 ++---------- crates/vm/src/vm/compile.rs | 2 +- crates/vm/src/vm/interpreter.rs | 10 +- crates/vm/src/vm/mod.rs | 28 +-- crates/vm/src/vm/setting.rs | 36 ++- src/interpreter.rs | 16 +- src/lib.rs | 12 +- 16 files changed, 487 insertions(+), 162 deletions(-) create mode 100644 crates/vm/src/getpath.rs diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 8ccd6d6b641..fbb467c94db 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -43,6 +43,7 @@ numer orelse pathconfig patma +platstdlib posonlyarg posonlyargs prec diff --git a/crates/stdlib/src/syslog.rs b/crates/stdlib/src/syslog.rs index adba6f297ce..d0ed3f60949 100644 --- a/crates/stdlib/src/syslog.rs +++ b/crates/stdlib/src/syslog.rs @@ -26,7 +26,7 @@ mod syslog { use libc::{LOG_AUTHPRIV, LOG_CRON, LOG_PERROR}; fn get_argv(vm: &VirtualMachine) -> Option { - if let Some(argv) = vm.state.settings.argv.first() + if let Some(argv) = vm.state.config.settings.argv.first() && !argv.is_empty() { return Some( diff --git a/crates/vm/src/getpath.rs b/crates/vm/src/getpath.rs new file mode 100644 index 00000000000..423b9b54136 --- /dev/null +++ b/crates/vm/src/getpath.rs @@ -0,0 +1,385 @@ +//! Path configuration for RustPython (ref: Modules/getpath.py) +//! +//! This module implements Python path calculation logic following getpath.py. +//! 
It uses landmark-based search to locate prefix, exec_prefix, and stdlib directories. +//! +//! The main entry point is `init_path_config()` which computes Paths from Settings. + +use crate::vm::{Paths, Settings}; +use std::env; +use std::path::{Path, PathBuf}; + +// Platform-specific landmarks (ref: getpath.py PLATFORM CONSTANTS) + +#[cfg(not(windows))] +mod platform { + use crate::version; + + pub const BUILDDIR_TXT: &str = "pybuilddir.txt"; + pub const BUILD_LANDMARK: &str = "Modules/Setup.local"; + pub const VENV_LANDMARK: &str = "pyvenv.cfg"; + pub const BUILDSTDLIB_LANDMARK: &str = "Lib/os.py"; + + pub fn stdlib_subdir() -> String { + format!("lib/python{}.{}", version::MAJOR, version::MINOR) + } + + pub fn stdlib_landmarks() -> [String; 2] { + let subdir = stdlib_subdir(); + [format!("{}/os.py", subdir), format!("{}/os.pyc", subdir)] + } + + pub fn platstdlib_landmark() -> String { + format!( + "lib/python{}.{}/lib-dynload", + version::MAJOR, + version::MINOR + ) + } + + pub fn zip_landmark() -> String { + format!("lib/python{}{}.zip", version::MAJOR, version::MINOR) + } +} + +#[cfg(windows)] +mod platform { + use crate::version; + + pub const BUILDDIR_TXT: &str = "pybuilddir.txt"; + pub const BUILD_LANDMARK: &str = "Modules\\Setup.local"; + pub const VENV_LANDMARK: &str = "pyvenv.cfg"; + pub const BUILDSTDLIB_LANDMARK: &str = "Lib\\os.py"; + pub const STDLIB_SUBDIR: &str = "Lib"; + + pub fn stdlib_landmarks() -> [String; 2] { + ["Lib\\os.py".into(), "Lib\\os.pyc".into()] + } + + pub fn platstdlib_landmark() -> String { + "DLLs".into() + } + + pub fn zip_landmark() -> String { + format!("python{}{}.zip", version::MAJOR, version::MINOR) + } +} + +// Helper functions (ref: getpath.py HELPER FUNCTIONS) + +/// Search upward from a directory for landmark files/directories +/// Returns the directory where a landmark was found +fn search_up(start: P, landmarks: &[&str], test: F) -> Option +where + P: AsRef, + F: Fn(&Path) -> bool, +{ + let mut current = 
start.as_ref().to_path_buf(); + loop { + for landmark in landmarks { + let path = current.join(landmark); + if test(&path) { + return Some(current); + } + } + if !current.pop() { + return None; + } + } +} + +/// Search upward for a file landmark +fn search_up_file>(start: P, landmarks: &[&str]) -> Option { + search_up(start, landmarks, |p| p.is_file()) +} + +/// Search upward for a directory landmark +#[cfg(not(windows))] +fn search_up_dir>(start: P, landmarks: &[&str]) -> Option { + search_up(start, landmarks, |p| p.is_dir()) +} + +// Path computation functions + +/// Compute path configuration from Settings +/// +/// This function should be called before interpreter initialization. +/// It returns a Paths struct with all computed path values. +pub fn init_path_config(settings: &Settings) -> Paths { + let mut paths = Paths::default(); + + // Step 0: Get executable path + let executable = get_executable_path(); + paths.executable = executable + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(); + + let exe_dir = executable + .as_ref() + .and_then(|p| p.parent().map(PathBuf::from)); + + // Step 1: Check for __PYVENV_LAUNCHER__ environment variable + if let Ok(launcher) = env::var("__PYVENV_LAUNCHER__") { + paths.base_executable = launcher; + } + + // Step 2: Check for venv (pyvenv.cfg) and get 'home' + let (venv_prefix, home_dir) = detect_venv(&exe_dir); + let search_dir = home_dir.clone().or(exe_dir.clone()); + + // Step 3: Check for build directory + let build_prefix = detect_build_directory(&search_dir); + + // Step 4: Calculate prefix via landmark search + // When in venv, search_dir is home_dir, so this gives us the base Python's prefix + let calculated_prefix = calculate_prefix(&search_dir, &build_prefix); + + // Step 5: Set prefix and base_prefix + if venv_prefix.is_some() { + // In venv: prefix = venv directory, base_prefix = original Python's prefix + paths.prefix = venv_prefix + .as_ref() + .map(|p| 
p.to_string_lossy().into_owned()) + .unwrap_or_else(|| calculated_prefix.clone()); + paths.base_prefix = calculated_prefix; + } else { + // Not in venv: prefix == base_prefix + paths.prefix = calculated_prefix.clone(); + paths.base_prefix = calculated_prefix; + } + + // Step 6: Calculate exec_prefix + paths.exec_prefix = if venv_prefix.is_some() { + // In venv: exec_prefix = prefix (venv directory) + paths.prefix.clone() + } else { + calculate_exec_prefix(&search_dir, &paths.prefix) + }; + paths.base_exec_prefix = paths.base_prefix.clone(); + + // Step 7: Calculate base_executable (if not already set by __PYVENV_LAUNCHER__) + if paths.base_executable.is_empty() { + paths.base_executable = calculate_base_executable(executable.as_ref(), &home_dir); + } + + // Step 8: Build module_search_paths + paths.module_search_paths = + build_module_search_paths(settings, &paths.prefix, &paths.exec_prefix); + + paths +} + +/// Get default prefix value +fn default_prefix() -> String { + std::option_env!("RUSTPYTHON_PREFIX") + .map(String::from) + .unwrap_or_else(|| { + if cfg!(windows) { + "C:".to_owned() + } else { + "/usr/local".to_owned() + } + }) +} + +/// Detect virtual environment by looking for pyvenv.cfg +/// Returns (venv_prefix, home_dir from pyvenv.cfg) +fn detect_venv(exe_dir: &Option) -> (Option, Option) { + // Try exe_dir/../pyvenv.cfg first (standard venv layout: venv/bin/python) + if let Some(dir) = exe_dir + && let Some(venv_dir) = dir.parent() + { + let cfg = venv_dir.join(platform::VENV_LANDMARK); + if cfg.exists() + && let Some(home) = parse_pyvenv_home(&cfg) + { + return (Some(venv_dir.to_path_buf()), Some(PathBuf::from(home))); + } + } + + // Try exe_dir/pyvenv.cfg (alternative layout) + if let Some(dir) = exe_dir { + let cfg = dir.join(platform::VENV_LANDMARK); + if cfg.exists() + && let Some(home) = parse_pyvenv_home(&cfg) + { + return (Some(dir.clone()), Some(PathBuf::from(home))); + } + } + + (None, None) +} + +/// Detect if running from a build directory 
+fn detect_build_directory(exe_dir: &Option) -> Option { + let dir = exe_dir.as_ref()?; + + // Check for pybuilddir.txt (indicates build directory) + if dir.join(platform::BUILDDIR_TXT).exists() { + return Some(dir.clone()); + } + + // Check for Modules/Setup.local (build landmark) + if dir.join(platform::BUILD_LANDMARK).exists() { + return Some(dir.clone()); + } + + // Search up for Lib/os.py (build stdlib landmark) + search_up_file(dir, &[platform::BUILDSTDLIB_LANDMARK]) +} + +/// Calculate prefix by searching for landmarks +fn calculate_prefix(exe_dir: &Option, build_prefix: &Option) -> String { + // 1. If build directory detected, use it + if let Some(bp) = build_prefix { + return bp.to_string_lossy().into_owned(); + } + + if let Some(dir) = exe_dir { + // 2. Search for ZIP landmark + let zip = platform::zip_landmark(); + if let Some(prefix) = search_up_file(dir, &[&zip]) { + return prefix.to_string_lossy().into_owned(); + } + + // 3. Search for stdlib landmarks (os.py) + let landmarks = platform::stdlib_landmarks(); + let refs: Vec<&str> = landmarks.iter().map(|s| s.as_str()).collect(); + if let Some(prefix) = search_up_file(dir, &refs) { + return prefix.to_string_lossy().into_owned(); + } + } + + // 4. 
Fallback to default + default_prefix() +} + +/// Calculate exec_prefix +fn calculate_exec_prefix(exe_dir: &Option, prefix: &str) -> String { + #[cfg(windows)] + { + // Windows: exec_prefix == prefix + let _ = exe_dir; // silence unused warning + prefix.to_owned() + } + + #[cfg(not(windows))] + { + // POSIX: search for lib-dynload directory + if let Some(dir) = exe_dir { + let landmark = platform::platstdlib_landmark(); + if let Some(exec_prefix) = search_up_dir(dir, &[&landmark]) { + return exec_prefix.to_string_lossy().into_owned(); + } + } + // Fallback: same as prefix + prefix.to_owned() + } +} + +/// Calculate base_executable +fn calculate_base_executable(executable: Option<&PathBuf>, home_dir: &Option) -> String { + // If in venv and we have home, construct base_executable from home + if let (Some(exe), Some(home)) = (executable, home_dir) + && let Some(exe_name) = exe.file_name() + { + let base = home.join(exe_name); + return base.to_string_lossy().into_owned(); + } + + // Otherwise, base_executable == executable + executable + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default() +} + +/// Build the complete module_search_paths (sys.path) +fn build_module_search_paths(settings: &Settings, prefix: &str, exec_prefix: &str) -> Vec { + let mut paths = Vec::new(); + + // 1. PYTHONPATH/RUSTPYTHONPATH from settings + paths.extend(settings.path_list.iter().cloned()); + + // 2. ZIP file path + let zip_path = PathBuf::from(prefix).join(platform::zip_landmark()); + paths.push(zip_path.to_string_lossy().into_owned()); + + // 3. 
stdlib and platstdlib directories + #[cfg(not(windows))] + { + // POSIX: stdlib first, then lib-dynload + let stdlib_dir = PathBuf::from(prefix).join(platform::stdlib_subdir()); + paths.push(stdlib_dir.to_string_lossy().into_owned()); + + let platstdlib = PathBuf::from(exec_prefix).join(platform::platstdlib_landmark()); + paths.push(platstdlib.to_string_lossy().into_owned()); + } + + #[cfg(windows)] + { + // Windows: DLLs first, then Lib + let platstdlib = PathBuf::from(exec_prefix).join(platform::platstdlib_landmark()); + paths.push(platstdlib.to_string_lossy().into_owned()); + + let stdlib_dir = PathBuf::from(prefix).join(platform::STDLIB_SUBDIR); + paths.push(stdlib_dir.to_string_lossy().into_owned()); + } + + paths +} + +/// Get the current executable path +fn get_executable_path() -> Option { + #[cfg(not(target_arch = "wasm32"))] + { + let exec_arg = env::args_os().next()?; + which::which(exec_arg).ok() + } + #[cfg(target_arch = "wasm32")] + { + let exec_arg = env::args().next()?; + Some(PathBuf::from(exec_arg)) + } +} + +/// Parse pyvenv.cfg and extract the 'home' key value +fn parse_pyvenv_home(pyvenv_cfg: &Path) -> Option { + let content = std::fs::read_to_string(pyvenv_cfg).ok()?; + + for line in content.lines() { + if let Some((key, value)) = line.split_once('=') + && key.trim().to_lowercase() == "home" + { + return Some(value.trim().to_string()); + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_init_path_config() { + let settings = Settings::default(); + let paths = init_path_config(&settings); + // Just verify it doesn't panic and returns valid paths + assert!(!paths.prefix.is_empty()); + } + + #[test] + fn test_search_up() { + // Test with a path that doesn't have any landmarks + let result = search_up_file(std::env::temp_dir(), &["nonexistent_landmark_xyz"]); + assert!(result.is_none()); + } + + #[test] + fn test_default_prefix() { + let prefix = default_prefix(); + assert!(!prefix.is_empty()); + } +} diff --git 
a/crates/vm/src/import.rs b/crates/vm/src/import.rs index 3f4a437c599..39748655e0f 100644 --- a/crates/vm/src/import.rs +++ b/crates/vm/src/import.rs @@ -206,7 +206,7 @@ fn remove_importlib_frames_inner( // TODO: This function should do nothing on verbose mode. // TODO: Fix this function after making PyTraceback.next mutable pub fn remove_importlib_frames(vm: &VirtualMachine, exc: &Py) { - if vm.state.settings.verbose != 0 { + if vm.state.config.settings.verbose != 0 { return; } diff --git a/crates/vm/src/lib.rs b/crates/vm/src/lib.rs index 923b33d2acc..f461c612955 100644 --- a/crates/vm/src/lib.rs +++ b/crates/vm/src/lib.rs @@ -60,6 +60,7 @@ pub mod exceptions; pub mod format; pub mod frame; pub mod function; +pub mod getpath; pub mod import; mod intern; pub mod iter; diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 442bb79b94e..72d2c724159 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -132,7 +132,7 @@ mod builtins { let optimize: i32 = args.optimize.map_or(Ok(-1), |v| v.try_to_primitive(vm))?; let optimize: u8 = if optimize == -1 { - vm.state.settings.optimize + vm.state.config.settings.optimize } else { optimize .try_into() @@ -1080,7 +1080,7 @@ pub fn init_module(vm: &VirtualMachine, module: &Py) { builtins::extend_module(vm, module).unwrap(); - let debug_mode: bool = vm.state.settings.optimize == 0; + let debug_mode: bool = vm.state.config.settings.optimize == 0; // Create dynamic ExceptionGroup with multiple inheritance (BaseExceptionGroup + Exception) let exception_group = crate::exception_group::exception_group(); diff --git a/crates/vm/src/stdlib/imp.rs b/crates/vm/src/stdlib/imp.rs index 596847776ff..76b3bfd124c 100644 --- a/crates/vm/src/stdlib/imp.rs +++ b/crates/vm/src/stdlib/imp.rs @@ -91,7 +91,7 @@ mod _imp { #[pyattr] fn check_hash_based_pycs(vm: &VirtualMachine) -> PyStrRef { vm.ctx - .new_str(vm.state.settings.check_hash_pycs_mode.to_string()) + 
.new_str(vm.state.config.settings.check_hash_pycs_mode.to_string()) } #[pyfunction] diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index 547e482f13a..ba3576a176c 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -2398,7 +2398,9 @@ mod _io { *data = None; let encoding = match args.encoding { - None if vm.state.settings.utf8_mode > 0 => identifier_utf8!(vm, utf_8).to_owned(), + None if vm.state.config.settings.utf8_mode > 0 => { + identifier_utf8!(vm, utf_8).to_owned() + } Some(enc) if enc.as_str() != "locale" => enc, _ => { // None without utf8_mode or "locale" encoding diff --git a/crates/vm/src/stdlib/signal.rs b/crates/vm/src/stdlib/signal.rs index 4eacb10154c..6771a950400 100644 --- a/crates/vm/src/stdlib/signal.rs +++ b/crates/vm/src/stdlib/signal.rs @@ -114,7 +114,7 @@ pub(crate) mod _signal { module: &Py, vm: &VirtualMachine, ) { - if vm.state.settings.install_signal_handlers { + if vm.state.config.settings.install_signal_handlers { let sig_dfl = vm.new_pyobj(SIG_DFL as u8); let sig_ign = vm.new_pyobj(SIG_IGN as u8); diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index f9bd2b59456..cfe6f9f5e61 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -25,7 +25,6 @@ mod sys { use std::{ env::{self, VarError}, io::Read, - path, sync::atomic::Ordering, }; @@ -96,25 +95,20 @@ mod sys { const DLLHANDLE: usize = 0; #[pyattr] - const fn default_prefix(_vm: &VirtualMachine) -> &'static str { - // TODO: the windows one doesn't really make sense - if cfg!(windows) { "C:" } else { "/usr/local" } + fn prefix(vm: &VirtualMachine) -> String { + vm.state.config.paths.prefix.clone() } #[pyattr] - fn prefix(vm: &VirtualMachine) -> &'static str { - option_env!("RUSTPYTHON_PREFIX").unwrap_or_else(|| default_prefix(vm)) + fn base_prefix(vm: &VirtualMachine) -> String { + vm.state.config.paths.base_prefix.clone() } #[pyattr] - fn base_prefix(vm: &VirtualMachine) -> &'static str { - 
option_env!("RUSTPYTHON_BASEPREFIX").unwrap_or_else(|| prefix(vm)) + fn exec_prefix(vm: &VirtualMachine) -> String { + vm.state.config.paths.exec_prefix.clone() } #[pyattr] - fn exec_prefix(vm: &VirtualMachine) -> &'static str { - option_env!("RUSTPYTHON_BASEPREFIX").unwrap_or_else(|| prefix(vm)) - } - #[pyattr] - fn base_exec_prefix(vm: &VirtualMachine) -> &'static str { - option_env!("RUSTPYTHON_BASEPREFIX").unwrap_or_else(|| exec_prefix(vm)) + fn base_exec_prefix(vm: &VirtualMachine) -> String { + vm.state.config.paths.base_exec_prefix.clone() } #[pyattr] fn platlibdir(_vm: &VirtualMachine) -> &'static str { @@ -126,6 +120,7 @@ mod sys { #[pyattr] fn argv(vm: &VirtualMachine) -> Vec { vm.state + .config .settings .argv .iter() @@ -162,117 +157,18 @@ mod sys { } #[pyattr] - fn _base_executable(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; - // First check __PYVENV_LAUNCHER__ environment variable - if let Ok(var) = env::var("__PYVENV_LAUNCHER__") { - return ctx.new_str(var).into(); - } - - // Try to detect if we're running from a venv by looking for pyvenv.cfg - if let Some(base_exe) = get_venv_base_executable() { - return ctx.new_str(base_exe).into(); - } - - executable(vm) - } - - /// Try to find base executable from pyvenv.cfg (see getpath.py) - fn get_venv_base_executable() -> Option { - // TODO: This is a minimal implementation of getpath.py - // To fully support all cases, `getpath.py` should be placed in @crates/vm/Lib/python_builtins/ - - // Get current executable path - #[cfg(not(target_arch = "wasm32"))] - let exe_path = { - let exec_arg = env::args_os().next()?; - which::which(exec_arg).ok()? 
- }; - #[cfg(target_arch = "wasm32")] - let exe_path = { - let exec_arg = env::args().next()?; - path::PathBuf::from(exec_arg) - }; - - let exe_dir = exe_path.parent()?; - let exe_name = exe_path.file_name()?; - - // Look for pyvenv.cfg in parent directory (typical venv layout: venv/bin/python) - let venv_dir = exe_dir.parent()?; - let pyvenv_cfg = venv_dir.join("pyvenv.cfg"); - - if !pyvenv_cfg.exists() { - return None; - } - - // Parse pyvenv.cfg and extract home directory - let content = std::fs::read_to_string(&pyvenv_cfg).ok()?; - - for line in content.lines() { - if let Some((key, value)) = line.split_once('=') { - let key = key.trim().to_lowercase(); - let value = value.trim(); - - if key == "home" { - // First try to resolve symlinks (getpath.py line 373-377) - if let Ok(resolved) = std::fs::canonicalize(&exe_path) - && resolved != exe_path - { - return Some(resolved.to_string_lossy().into_owned()); - } - // Fallback: home_dir + executable_name (getpath.py line 381) - let base_exe = path::Path::new(value).join(exe_name); - return Some(base_exe.to_string_lossy().into_owned()); - } - } - } - - None + fn _base_executable(vm: &VirtualMachine) -> String { + vm.state.config.paths.base_executable.clone() } #[pyattr] fn dont_write_bytecode(vm: &VirtualMachine) -> bool { - !vm.state.settings.write_bytecode + !vm.state.config.settings.write_bytecode } #[pyattr] - fn executable(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; - #[cfg(not(target_arch = "wasm32"))] - { - if let Some(exec_path) = env::args_os().next() - && let Ok(path) = which::which(exec_path) - { - return ctx - .new_str( - path.into_os_string() - .into_string() - .unwrap_or_else(|p| p.to_string_lossy().into_owned()), - ) - .into(); - } - } - if let Some(exec_path) = env::args().next() { - let path = path::Path::new(&exec_path); - if !path.exists() { - return ctx.new_str(ascii!("")).into(); - } - if path.is_absolute() { - return ctx.new_str(exec_path).into(); - } - if let Ok(dir) = 
env::current_dir() - && let Ok(dir) = dir.into_os_string().into_string() - { - return ctx - .new_str(format!( - "{}/{}", - dir, - exec_path.strip_prefix("./").unwrap_or(&exec_path) - )) - .into(); - } - } - ctx.none() + fn executable(vm: &VirtualMachine) -> String { + vm.state.config.paths.executable.clone() } #[pyattr] @@ -312,8 +208,9 @@ mod sys { #[pyattr] fn path(vm: &VirtualMachine) -> Vec { vm.state - .settings - .path_list + .config + .paths + .module_search_paths .iter() .map(|path| vm.ctx.new_str(path.clone()).into()) .collect() @@ -350,7 +247,7 @@ mod sys { fn _xoptions(vm: &VirtualMachine) -> PyDictRef { let ctx = &vm.ctx; let xopts = ctx.new_dict(); - for (key, value) in &vm.state.settings.xoptions { + for (key, value) in &vm.state.config.settings.xoptions { let value = value.as_ref().map_or_else( || ctx.new_bool(true).into(), |s| ctx.new_str(s.clone()).into(), @@ -363,6 +260,7 @@ mod sys { #[pyattr] fn warnoptions(vm: &VirtualMachine) -> Vec { vm.state + .config .settings .warnoptions .iter() @@ -507,7 +405,7 @@ mod sys { #[pyattr] fn flags(vm: &VirtualMachine) -> PyTupleRef { - PyFlags::from_data(FlagsData::from_settings(&vm.state.settings), vm) + PyFlags::from_data(FlagsData::from_settings(&vm.state.config.settings), vm) } #[pyattr] diff --git a/crates/vm/src/vm/compile.rs b/crates/vm/src/vm/compile.rs index 6f1ea734926..44332cda838 100644 --- a/crates/vm/src/vm/compile.rs +++ b/crates/vm/src/vm/compile.rs @@ -38,7 +38,7 @@ impl VirtualMachine { } // TODO: check if this is proper place - if !self.state.settings.safe_path { + if !self.state.config.settings.safe_path { let dir = std::path::Path::new(path) .parent() .unwrap() diff --git a/crates/vm/src/vm/interpreter.rs b/crates/vm/src/vm/interpreter.rs index 05613d43384..503feb3dc7f 100644 --- a/crates/vm/src/vm/interpreter.rs +++ b/crates/vm/src/vm/interpreter.rs @@ -1,5 +1,5 @@ -use super::{Context, VirtualMachine, setting::Settings, thread}; -use crate::{PyResult, stdlib::atexit, 
vm::PyBaseExceptionRef}; +use super::{Context, PyConfig, VirtualMachine, setting::Settings, thread}; +use crate::{PyResult, getpath, stdlib::atexit, vm::PyBaseExceptionRef}; use std::sync::atomic::Ordering; /// The general interface for the VM @@ -47,10 +47,14 @@ impl Interpreter { where F: FnOnce(&mut VirtualMachine), { + // Compute path configuration from settings + let paths = getpath::init_path_config(&settings); + let config = PyConfig::new(settings, paths); + let ctx = Context::genesis(); crate::types::TypeZoo::extend(ctx); crate::exceptions::ExceptionZoo::extend(ctx); - let mut vm = VirtualMachine::new(settings, ctx.clone()); + let mut vm = VirtualMachine::new(config, ctx.clone()); init(&mut vm); vm.initialize(); Self { vm } diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index 4574b2de370..34092454059 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -50,7 +50,7 @@ use std::{ pub use context::Context; pub use interpreter::Interpreter; pub(crate) use method::PyMethod; -pub use setting::{CheckHashPycsMode, Settings}; +pub use setting::{CheckHashPycsMode, Paths, PyConfig, Settings}; pub const MAX_MEMORY_SIZE: usize = isize::MAX as usize; @@ -87,7 +87,7 @@ struct ExceptionStack { } pub struct PyGlobalState { - pub settings: Settings, + pub config: PyConfig, pub module_inits: stdlib::StdlibMap, pub frozen: HashMap<&'static str, FrozenModule, ahash::RandomState>, pub stacksize: AtomicCell, @@ -114,7 +114,7 @@ pub fn process_hash_secret_seed() -> u32 { impl VirtualMachine { /// Create a new `VirtualMachine` structure. 
- fn new(settings: Settings, ctx: PyRc) -> Self { + fn new(config: PyConfig, ctx: PyRc) -> Self { flame_guard!("new VirtualMachine"); // make a new module without access to the vm; doesn't @@ -141,7 +141,7 @@ impl VirtualMachine { let module_inits = stdlib::get_module_inits(); - let seed = match settings.hash_seed { + let seed = match config.settings.hash_seed { Some(seed) => seed, None => process_hash_secret_seed(), }; @@ -151,7 +151,7 @@ impl VirtualMachine { let warnings = WarningsState::init_state(&ctx); - let int_max_str_digits = AtomicCell::new(match settings.int_max_str_digits { + let int_max_str_digits = AtomicCell::new(match config.settings.int_max_str_digits { -1 => 4300, other => other, } as usize); @@ -171,7 +171,7 @@ impl VirtualMachine { signal_rx: None, repr_guards: RefCell::default(), state: PyRc::new(PyGlobalState { - settings, + config, module_inits, frozen: HashMap::default(), stacksize: AtomicCell::new(0), @@ -227,7 +227,7 @@ impl VirtualMachine { let rustpythonpath_env = std::env::var("RUSTPYTHONPATH").ok(); let pythonpath_env = std::env::var("PYTHONPATH").ok(); let env_set = rustpythonpath_env.as_ref().is_some() || pythonpath_env.as_ref().is_some(); - let path_contains_env = self.state.settings.path_list.iter().any(|s| { + let path_contains_env = self.state.config.paths.module_search_paths.iter().any(|s| { Some(s.as_str()) == rustpythonpath_env.as_deref() || Some(s.as_str()) == pythonpath_env.as_deref() }); @@ -238,7 +238,7 @@ impl VirtualMachine { } else if path_contains_env { "RUSTPYTHONPATH or PYTHONPATH is set, but it doesn't contain the encodings library. If you are customizing the RustPython vm/interpreter, try adding the stdlib directory to the path. If you are developing the RustPython interpreter, it might be a bug during development." } else { - "RUSTPYTHONPATH or PYTHONPATH is set, but it wasn't loaded to `Settings::path_list`. 
If you are going to customize the RustPython vm/interpreter, those environment variables are not loaded in the Settings struct by default. Please try creating a customized instance of the Settings struct. If you are developing the RustPython interpreter, it might be a bug during development." + "RUSTPYTHONPATH or PYTHONPATH is set, but it wasn't loaded to `PyConfig::paths::module_search_paths`. If you are going to customize the RustPython vm/interpreter, those environment variables are not loaded in the Settings struct by default. Please try creating a customized instance of the Settings struct. If you are developing the RustPython interpreter, it might be a bug during development." }; let mut msg = format!( @@ -303,7 +303,7 @@ impl VirtualMachine { let io = import::import_builtin(self, "_io")?; #[cfg(feature = "stdio")] let make_stdio = |name, fd, write| { - let buffered_stdio = self.state.settings.buffered_stdio; + let buffered_stdio = self.state.config.settings.buffered_stdio; let unbuffered = write && !buffered_stdio; let buf = crate::stdlib::io::open( self.ctx.new_int(fd).into(), @@ -364,7 +364,7 @@ impl VirtualMachine { let res = essential_init(); let importlib = self.expect_pyresult(res, "essential initialization failed"); - if self.state.settings.allow_external_library + if self.state.config.settings.allow_external_library && cfg!(feature = "rustpython-compiler") && let Err(e) = import::init_importlib_package(self, importlib) { @@ -374,8 +374,8 @@ impl VirtualMachine { self.print_exception(e); } - let _expect_stdlib = - cfg!(feature = "freeze-stdlib") || !self.state.settings.path_list.is_empty(); + let _expect_stdlib = cfg!(feature = "freeze-stdlib") + || !self.state.config.paths.module_search_paths.is_empty(); #[cfg(feature = "encodings")] if _expect_stdlib { @@ -389,7 +389,7 @@ impl VirtualMachine { // Here may not be the best place to give general `path_list` advice, // but bare rustpython_vm::VirtualMachine users skipped proper settings must hit here 
while properly setup vm never enters here. eprintln!( - "feature `encodings` is enabled but `settings.path_list` is empty. \ + "feature `encodings` is enabled but `paths.module_search_paths` is empty. \ Please add the library path to `settings.path_list`. If you intended to disable the entire standard library (including the `encodings` feature), please also make sure to disable the `encodings` feature.\n\ Tip: You may also want to add `\"\"` to `settings.path_list` in order to enable importing from the current working directory." ); @@ -505,7 +505,7 @@ impl VirtualMachine { #[cfg(feature = "rustpython-codegen")] pub fn compile_opts(&self) -> crate::compiler::CompileOpts { crate::compiler::CompileOpts { - optimize: self.state.settings.optimize, + optimize: self.state.config.settings.optimize, } } diff --git a/crates/vm/src/vm/setting.rs b/crates/vm/src/vm/setting.rs index deaca705c47..53e2cef1160 100644 --- a/crates/vm/src/vm/setting.rs +++ b/crates/vm/src/vm/setting.rs @@ -1,8 +1,40 @@ #[cfg(feature = "flame-it")] use std::ffi::OsString; -/// Struct containing all kind of settings for the python vm. -/// Mostly `PyConfig` in CPython. +/// Path configuration computed at runtime (like PyConfig path outputs) +#[derive(Debug, Clone, Default)] +pub struct Paths { + /// sys.executable + pub executable: String, + /// sys._base_executable (original interpreter in venv) + pub base_executable: String, + /// sys.prefix + pub prefix: String, + /// sys.base_prefix + pub base_prefix: String, + /// sys.exec_prefix + pub exec_prefix: String, + /// sys.base_exec_prefix + pub base_exec_prefix: String, + /// Computed module_search_paths (complete sys.path) + pub module_search_paths: Vec, +} + +/// Combined configuration: user settings + computed paths +/// CPython directly exposes every fields under both of them. +/// We separate them to maintain better ownership discipline. 
+pub struct PyConfig { + pub settings: Settings, + pub paths: Paths, +} + +impl PyConfig { + pub fn new(settings: Settings, paths: Paths) -> Self { + Self { settings, paths } + } +} + +/// User-configurable settings for the python vm. #[non_exhaustive] pub struct Settings { /// -I diff --git a/src/interpreter.rs b/src/interpreter.rs index b79a1a0ffb4..b4fd319cdae 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -104,23 +104,25 @@ pub fn init_stdlib(vm: &mut VirtualMachine) { use rustpython_vm::common::rc::PyRc; let state = PyRc::get_mut(&mut vm.state).unwrap(); - let settings = &mut state.settings; - let path_list = std::mem::take(&mut settings.path_list); + // Collect additional paths to add + let mut additional_paths = Vec::new(); // BUILDTIME_RUSTPYTHONPATH should be set when distributing if let Some(paths) = option_env!("BUILDTIME_RUSTPYTHONPATH") { - settings.path_list.extend( + additional_paths.extend( crate::settings::split_paths(paths) .map(|path| path.into_os_string().into_string().unwrap()), ) } else { #[cfg(feature = "rustpython-pylib")] - settings - .path_list - .push(rustpython_pylib::LIB_PATH.to_owned()) + additional_paths.push(rustpython_pylib::LIB_PATH.to_owned()) } - settings.path_list.extend(path_list); + // Add to both path_list (for compatibility) and module_search_paths (for sys.path) + // Insert at the beginning so stdlib comes before user paths + for path in additional_paths.into_iter().rev() { + state.config.paths.module_search_paths.insert(0, path); + } } } diff --git a/src/lib.rs b/src/lib.rs index 84a774ab029..8d278058933 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,7 +169,7 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> { let scope = setup_main_module(vm)?; - if !vm.state.settings.safe_path { + if !vm.state.config.settings.safe_path { // TODO: The prepending path depends on running mode // See https://docs.python.org/3/using/cmdline.html#cmdoption-P vm.run_code_string( @@ -189,7 +189,7 @@ fn 
run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> { // Enable faulthandler if -X faulthandler, PYTHONFAULTHANDLER or -X dev is set // _PyFaulthandler_Init() - if vm.state.settings.faulthandler { + if vm.state.config.settings.faulthandler { let _ = vm.run_code_string( vm.new_scope_with_builtins(), "import faulthandler; faulthandler.enable()", @@ -198,8 +198,8 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> { } let is_repl = matches!(run_mode, RunMode::Repl); - if !vm.state.settings.quiet - && (vm.state.settings.verbose > 0 || (is_repl && std::io::stdin().is_terminal())) + if !vm.state.config.settings.quiet + && (vm.state.config.settings.verbose > 0 || (is_repl && std::io::stdin().is_terminal())) { eprintln!( "Welcome to the magnificent Rust Python {} interpreter \u{1f631} \u{1f596}", @@ -232,7 +232,7 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> { } RunMode::Repl => Ok(()), }; - if is_repl || vm.state.settings.inspect { + if is_repl || vm.state.config.settings.inspect { shell::run_shell(vm, scope)?; } else { res?; @@ -241,7 +241,7 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> { #[cfg(feature = "flame-it")] { main_guard.end(); - if let Err(e) = write_profile(&vm.state.as_ref().settings) { + if let Err(e) = write_profile(&vm.state.as_ref().config.settings) { error!("Error writing profile information: {}", e); } } From 569bee103ff528ac8d70c3b5b78c70378e6cd115 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Mon, 22 Dec 2025 04:43:58 +0100 Subject: [PATCH 021/418] Use `ruff_python_ast::visitor::Visitor` for detecting `await` (#6466) --- crates/codegen/src/compile.rs | 158 +++++++--------------------------- 1 file changed, 30 insertions(+), 128 deletions(-) diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 9620edbd107..4f7174c8bad 100644 --- a/crates/codegen/src/compile.rs +++ 
b/crates/codegen/src/compile.rs @@ -24,12 +24,13 @@ use ruff_python_ast::{ Alias, Arguments, BoolOp, CmpOp, Comprehension, ConversionFlag, DebugText, Decorator, DictItem, ExceptHandler, ExceptHandlerExceptHandler, Expr, ExprAttribute, ExprBoolOp, ExprContext, ExprFString, ExprList, ExprName, ExprSlice, ExprStarred, ExprSubscript, ExprTuple, ExprUnaryOp, - FString, FStringFlags, FStringPart, Identifier, Int, InterpolatedElement, - InterpolatedStringElement, InterpolatedStringElements, Keyword, MatchCase, ModExpression, - ModModule, Operator, Parameters, Pattern, PatternMatchAs, PatternMatchClass, - PatternMatchMapping, PatternMatchOr, PatternMatchSequence, PatternMatchSingleton, - PatternMatchStar, PatternMatchValue, Singleton, Stmt, StmtExpr, TypeParam, TypeParamParamSpec, - TypeParamTypeVar, TypeParamTypeVarTuple, TypeParams, UnaryOp, WithItem, + FString, FStringFlags, FStringPart, Identifier, Int, InterpolatedStringElement, + InterpolatedStringElements, Keyword, MatchCase, ModExpression, ModModule, Operator, Parameters, + Pattern, PatternMatchAs, PatternMatchClass, PatternMatchMapping, PatternMatchOr, + PatternMatchSequence, PatternMatchSingleton, PatternMatchStar, PatternMatchValue, Singleton, + Stmt, StmtExpr, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, + TypeParams, UnaryOp, WithItem, + visitor::{Visitor, walk_expr}, }; use ruff_text_size::{Ranged, TextRange}; use rustpython_compiler_core::{ @@ -5435,134 +5436,35 @@ impl Compiler { /// Whether the expression contains an await expression and /// thus requires the function to be async. - /// Async with and async for are statements, so I won't check for them here + /// + /// Both: + /// ```py + /// async with: ... + /// async for: ... 
+ /// ``` + /// are statements, so we won't check for them here fn contains_await(expression: &Expr) -> bool { - use ruff_python_ast::*; + #[derive(Default)] + struct AwaitVisitor { + found: bool, + } - match &expression { - Expr::Call(ExprCall { - func, arguments, .. - }) => { - Self::contains_await(func) - || arguments.args.iter().any(Self::contains_await) - || arguments - .keywords - .iter() - .any(|kw| Self::contains_await(&kw.value)) - } - Expr::BoolOp(ExprBoolOp { values, .. }) => values.iter().any(Self::contains_await), - Expr::BinOp(ExprBinOp { left, right, .. }) => { - Self::contains_await(left) || Self::contains_await(right) - } - Expr::Subscript(ExprSubscript { value, slice, .. }) => { - Self::contains_await(value) || Self::contains_await(slice) - } - Expr::UnaryOp(ExprUnaryOp { operand, .. }) => Self::contains_await(operand), - Expr::Attribute(ExprAttribute { value, .. }) => Self::contains_await(value), - Expr::Compare(ExprCompare { - left, comparators, .. - }) => Self::contains_await(left) || comparators.iter().any(Self::contains_await), - Expr::List(ExprList { elts, .. }) => elts.iter().any(Self::contains_await), - Expr::Tuple(ExprTuple { elts, .. }) => elts.iter().any(Self::contains_await), - Expr::Set(ExprSet { elts, .. }) => elts.iter().any(Self::contains_await), - Expr::Dict(ExprDict { items, .. }) => items - .iter() - .flat_map(|item| &item.key) - .any(Self::contains_await), - Expr::Slice(ExprSlice { - lower, upper, step, .. - }) => { - lower.as_deref().is_some_and(Self::contains_await) - || upper.as_deref().is_some_and(Self::contains_await) - || step.as_deref().is_some_and(Self::contains_await) - } - Expr::Yield(ExprYield { value, .. }) => { - value.as_deref().is_some_and(Self::contains_await) - } - Expr::Await(ExprAwait { .. }) => true, - Expr::YieldFrom(ExprYieldFrom { value, .. }) => Self::contains_await(value), - Expr::Name(ExprName { .. }) => false, - Expr::Lambda(ExprLambda { body, .. 
}) => Self::contains_await(body), - Expr::ListComp(ExprListComp { - elt, generators, .. - }) => { - Self::contains_await(elt) - || generators.iter().any(|jen| Self::contains_await(&jen.iter)) - } - Expr::SetComp(ExprSetComp { - elt, generators, .. - }) => { - Self::contains_await(elt) - || generators.iter().any(|jen| Self::contains_await(&jen.iter)) - } - Expr::DictComp(ExprDictComp { - key, - value, - generators, - .. - }) => { - Self::contains_await(key) - || Self::contains_await(value) - || generators.iter().any(|jen| Self::contains_await(&jen.iter)) - } - Expr::Generator(ExprGenerator { - elt, generators, .. - }) => { - Self::contains_await(elt) - || generators.iter().any(|jen| Self::contains_await(&jen.iter)) - } - Expr::Starred(expr) => Self::contains_await(&expr.value), - Expr::If(ExprIf { - test, body, orelse, .. - }) => { - Self::contains_await(test) - || Self::contains_await(body) - || Self::contains_await(orelse) - } + impl Visitor<'_> for AwaitVisitor { + fn visit_expr(&mut self, expr: &Expr) { + if self.found { + return; + } - Expr::Named(ExprNamed { - target, - value, - node_index: _, - range: _, - }) => Self::contains_await(target) || Self::contains_await(value), - Expr::FString(fstring) => { - Self::interpolated_string_contains_await(fstring.value.elements()) - } - Expr::TString(tstring) => { - Self::interpolated_string_contains_await(tstring.value.elements()) + match expr { + Expr::Await(_) => self.found = true, + _ => walk_expr(self, expr), + } } - Expr::StringLiteral(_) - | Expr::BytesLiteral(_) - | Expr::NumberLiteral(_) - | Expr::BooleanLiteral(_) - | Expr::NoneLiteral(_) - | Expr::EllipsisLiteral(_) - | Expr::IpyEscapeCommand(_) => false, - } - } - - fn interpolated_string_contains_await<'a>( - mut elements: impl Iterator, - ) -> bool { - fn interpolated_element_contains_await bool>( - expr_element: &InterpolatedElement, - contains_await: F, - ) -> bool { - contains_await(&expr_element.expression) - || expr_element - .format_spec - .iter() - 
.flat_map(|spec| spec.elements.interpolations()) - .any(|element| interpolated_element_contains_await(element, contains_await)) } - elements.any(|element| match element { - InterpolatedStringElement::Interpolation(expr_element) => { - interpolated_element_contains_await(expr_element, Self::contains_await) - } - InterpolatedStringElement::Literal(_) => false, - }) + let mut visitor = AwaitVisitor::default(); + visitor.visit_expr(expression); + visitor.found } fn compile_expr_fstring(&mut self, fstring: &ExprFString) -> CompileResult<()> { From fdd2ac3b30827c1d44c5c5040eb184e65cf2534e Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:10:43 +0900 Subject: [PATCH 022/418] disallow __new__, __init__ (#6446) * disallow __new__, __init__ * migrate to Initializer * apply review --- crates/derive-impl/src/pyclass.rs | 125 +++++++----- crates/stdlib/src/sqlite.rs | 41 ++-- crates/stdlib/src/ssl/error.rs | 2 +- crates/vm/src/builtins/object.rs | 32 +-- crates/vm/src/exception_group.rs | 244 ++++++++++++----------- crates/vm/src/exceptions.rs | 304 ++++++++++++++++------------- crates/vm/src/stdlib/ast/python.rs | 79 ++++---- 7 files changed, 458 insertions(+), 369 deletions(-) diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index f784a2e2a76..d1fe5398634 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -63,6 +63,7 @@ impl FromStr for AttrName { #[derive(Default)] struct ImplContext { + is_trait: bool, attribute_items: ItemNursery, method_items: MethodNursery, getset_items: GetSetNursery, @@ -232,7 +233,10 @@ pub(crate) fn impl_pyclass_impl(attr: PunctuatedNestedMeta, item: Item) -> Resul } } Item::Trait(mut trai) => { - let mut context = ImplContext::default(); + let mut context = ImplContext { + is_trait: true, + ..Default::default() + }; let mut has_extend_slots = false; for item in &trai.items { let has = match item { @@ -710,21 
+714,16 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R }; // Check if with(Constructor) is specified. If Constructor trait is used, don't generate slot_new - let mut has_slot_new = false; - let mut extra_attrs = Vec::new(); + let mut with_items = vec![]; for nested in &attr { if let NestedMeta::Meta(Meta::List(MetaList { path, nested, .. })) = nested { // If we already found the constructor trait, no need to keep looking for it - if !has_slot_new && path.is_ident("with") { - // Check if Constructor is in the list + if path.is_ident("with") { for meta in nested { - if let NestedMeta::Meta(Meta::Path(p)) = meta - && p.is_ident("Constructor") - { - has_slot_new = true; - } + with_items.push(meta.get_ident().expect("with() has non-ident item").clone()); } + continue; } extra_attrs.push(NestedMeta::Meta(Meta::List(MetaList { path: path.clone(), @@ -734,43 +733,40 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R } } - let mut has_slot_init = false; + let with_contains = |with_items: &[Ident], s: &str| { + // Check if Constructor is in the list + with_items.iter().any(|ident| ident == s) + }; + let syn::ItemImpl { generics, self_ty, items, .. } = &imp; - for item in items { - // FIXME: better detection or correct wrapper implementation - let Some(ident) = item.get_ident() else { - continue; - }; - let item_name = ident.to_string(); - match item_name.as_str() { - "slot_new" => { - has_slot_new = true; - } - "slot_init" => { - has_slot_init = true; - } - _ => continue, - } - } - - // TODO: slot_new, slot_init must be Constructor or Initializer later - let slot_new = if has_slot_new { + let slot_new = if with_contains(&with_items, "Constructor") { quote!() } else { + with_items.push(Ident::new("Constructor", Span::call_site())); quote! 
{ - #[pyslot] - pub fn slot_new( - cls: ::rustpython_vm::builtins::PyTypeRef, - args: ::rustpython_vm::function::FuncArgs, - vm: &::rustpython_vm::VirtualMachine, - ) -> ::rustpython_vm::PyResult { - ::Base::slot_new(cls, args, vm) + impl ::rustpython_vm::types::Constructor for #self_ty { + type Args = ::rustpython_vm::function::FuncArgs; + + fn slot_new( + cls: ::rustpython_vm::builtins::PyTypeRef, + args: ::rustpython_vm::function::FuncArgs, + vm: &::rustpython_vm::VirtualMachine, + ) -> ::rustpython_vm::PyResult { + ::Base::slot_new(cls, args, vm) + } + fn py_new( + _cls: &::rustpython_vm::Py<::rustpython_vm::builtins::PyType>, + _args: Self::Args, + _vm: &::rustpython_vm::VirtualMachine + ) -> ::rustpython_vm::PyResult { + unreachable!("slot_new is defined") + } } } }; @@ -779,19 +775,29 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R // from `BaseException` in `SimpleExtendsException` macro. // See: `(initproc)BaseException_init` // spell-checker:ignore initproc - let slot_init = if has_slot_init { + let slot_init = if with_contains(&with_items, "Initializer") { quote!() } else { - // FIXME: this is a generic logic for types not only for exceptions + with_items.push(Ident::new("Initializer", Span::call_site())); quote! 
{ - #[pyslot] - #[pymethod(name="__init__")] - pub fn slot_init( - zelf: ::rustpython_vm::PyObjectRef, - args: ::rustpython_vm::function::FuncArgs, - vm: &::rustpython_vm::VirtualMachine, - ) -> ::rustpython_vm::PyResult<()> { - ::Base::slot_init(zelf, args, vm) + impl ::rustpython_vm::types::Initializer for #self_ty { + type Args = ::rustpython_vm::function::FuncArgs; + + fn slot_init( + zelf: ::rustpython_vm::PyObjectRef, + args: ::rustpython_vm::function::FuncArgs, + vm: &::rustpython_vm::VirtualMachine, + ) -> ::rustpython_vm::PyResult<()> { + ::Base::slot_init(zelf, args, vm) + } + + fn init( + _zelf: ::rustpython_vm::PyRef, + _args: Self::Args, + _vm: &::rustpython_vm::VirtualMachine + ) -> ::rustpython_vm::PyResult<()> { + unreachable!("slot_init is defined") + } } } }; @@ -803,13 +809,13 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R }; Ok(quote! { - #[pyclass(flags(BASETYPE, HAS_DICT) #extra_attrs_tokens)] + #[pyclass(flags(BASETYPE, HAS_DICT), with(#(#with_items),*) #extra_attrs_tokens)] impl #generics #self_ty { #(#items)* - - #slot_new - #slot_init } + + #slot_new + #slot_init }) } @@ -892,6 +898,23 @@ where let item_meta = MethodItemMeta::from_attr(ident.clone(), &item_attr)?; let py_name = item_meta.method_name()?; + + // Disallow __new__ and __init__ as pymethod in impl blocks (not in traits) + if !args.context.is_trait { + if py_name == "__new__" { + return Err(syn::Error::new( + ident.span(), + "#[pymethod] cannot define '__new__'. Use #[pyclass(with(Constructor))] instead.", + )); + } + if py_name == "__init__" { + return Err(syn::Error::new( + ident.span(), + "#[pymethod] cannot define '__init__'. 
Use #[pyclass(with(Initializer))] instead.", + )); + } + } + let raw = item_meta.raw()?; let sig_doc = text_signature(func.sig(), &py_name); diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index bc84cffbf80..ffdc9eb3831 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -1540,7 +1540,7 @@ mod _sqlite { size: Option, } - #[pyclass(with(Constructor, IterNext, Iterable), flags(BASETYPE))] + #[pyclass(with(Constructor, Initializer, IterNext, Iterable), flags(BASETYPE))] impl Cursor { fn new( connection: PyRef, @@ -1571,24 +1571,6 @@ mod _sqlite { } } - #[pymethod] - fn __init__(&self, _connection: PyRef, _vm: &VirtualMachine) -> PyResult<()> { - let mut guard = self.inner.lock(); - if guard.is_some() { - // Already initialized (e.g., from a call to super().__init__) - return Ok(()); - } - *guard = Some(CursorInner { - description: None, - row_cast_map: vec![], - lastrowid: -1, - rowcount: -1, - statement: None, - closed: false, - }); - Ok(()) - } - fn check_cursor_state(inner: Option<&CursorInner>, vm: &VirtualMachine) -> PyResult<()> { match inner { Some(inner) if inner.closed => Err(new_programming_error( @@ -1949,6 +1931,27 @@ mod _sqlite { } } + impl Initializer for Cursor { + type Args = PyRef; + + fn init(zelf: PyRef, _connection: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + let mut guard = zelf.inner.lock(); + if guard.is_some() { + // Already initialized (e.g., from a call to super().__init__) + return Ok(()); + } + *guard = Some(CursorInner { + description: None, + row_cast_map: vec![], + lastrowid: -1, + rowcount: -1, + statement: None, + closed: false, + }); + Ok(()) + } + } + impl SelfIter for Cursor {} impl IterNext for Cursor { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index bef9ba513d7..879275228ec 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -7,7 +7,7 @@ pub(crate) mod 
ssl_error { use crate::vm::{ PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyOSError, PyStrRef}, - types::Constructor, + types::{Constructor, Initializer}, }; // Error type constants - exposed as pyattr and available for internal use diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index cb95652f937..0970496c7b1 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -2,11 +2,11 @@ use super::{PyDictRef, PyList, PyStr, PyStrRef, PyType, PyTypeRef}; use crate::common::hash::PyHash; use crate::types::PyTypeFlags; use crate::{ - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, convert::ToPyResult, function::{Either, FuncArgs, PyArithmeticValue, PyComparisonValue, PySetterValue}, - types::{Constructor, PyComparisonOp}, + types::{Constructor, Initializer, PyComparisonOp}, }; use itertools::Itertools; @@ -115,6 +115,18 @@ impl Constructor for PyBaseObject { } } +impl Initializer for PyBaseObject { + type Args = FuncArgs; + + fn slot_init(_zelf: PyObjectRef, _args: FuncArgs, _vm: &VirtualMachine) -> PyResult<()> { + Ok(()) + } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } +} + // TODO: implement _PyType_GetSlotNames properly fn type_slot_names(typ: &Py, vm: &VirtualMachine) -> PyResult> { // let attributes = typ.attributes.read(); @@ -235,7 +247,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) // getstate.call((), vm) // } -#[pyclass(with(Constructor), flags(BASETYPE))] +#[pyclass(with(Constructor, Initializer), flags(BASETYPE))] impl PyBaseObject { #[pymethod(raw)] fn __getstate__(vm: &VirtualMachine, args: FuncArgs) -> PyResult { @@ -444,19 +456,17 @@ impl PyBaseObject { obj.str(vm) } - #[pyslot] - #[pymethod] - fn 
__init__(_zelf: PyObjectRef, _args: FuncArgs, _vm: &VirtualMachine) -> PyResult<()> { - Ok(()) - } - #[pygetset] fn __class__(obj: PyObjectRef) -> PyTypeRef { obj.class().to_owned() } - #[pygetset(name = "__class__", setter)] - fn set_class(instance: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + #[pygetset(setter)] + fn set___class__( + instance: PyObjectRef, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { match value.downcast::() { Ok(cls) => { let both_module = instance.class().fast_issubclass(vm.ctx.types.module_type) diff --git a/crates/vm/src/exception_group.rs b/crates/vm/src/exception_group.rs index cd943ae1bd9..e19dbceb8da 100644 --- a/crates/vm/src/exception_group.rs +++ b/crates/vm/src/exception_group.rs @@ -43,13 +43,14 @@ pub(super) mod types { use super::*; use crate::PyPayload; use crate::builtins::PyGenericAlias; + use crate::types::{Constructor, Initializer}; #[pyexception(name, base = PyBaseException, ctx = "base_exception_group")] #[derive(Debug)] #[repr(transparent)] pub struct PyBaseExceptionGroup(PyBaseException); - #[pyexception] + #[pyexception(with(Constructor, Initializer))] impl PyBaseExceptionGroup { #[pyclassmethod] fn __class_getitem__( @@ -60,117 +61,6 @@ pub(super) mod types { PyGenericAlias::from_args(cls, args, vm) } - #[pyslot] - fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // Validate exactly 2 positional arguments - if args.args.len() != 2 { - return Err(vm.new_type_error(format!( - "BaseExceptionGroup.__new__() takes exactly 2 positional arguments ({} given)", - args.args.len() - ))); - } - - // Validate message is str - let message = args.args[0].clone(); - if !message.fast_isinstance(vm.ctx.types.str_type) { - return Err(vm.new_type_error(format!( - "argument 1 must be str, not {}", - message.class().name() - ))); - } - - // Validate exceptions is a sequence (not set or None) - let exceptions_arg = &args.args[1]; - - // Check for set/frozenset 
(not a sequence - unordered) - if exceptions_arg.fast_isinstance(vm.ctx.types.set_type) - || exceptions_arg.fast_isinstance(vm.ctx.types.frozenset_type) - { - return Err(vm.new_type_error("second argument (exceptions) must be a sequence")); - } - - // Check for None - if exceptions_arg.is(&vm.ctx.none) { - return Err(vm.new_type_error("second argument (exceptions) must be a sequence")); - } - - let exceptions: Vec = exceptions_arg.try_to_value(vm).map_err(|_| { - vm.new_type_error("second argument (exceptions) must be a sequence") - })?; - - // Validate non-empty - if exceptions.is_empty() { - return Err(vm.new_value_error( - "second argument (exceptions) must be a non-empty sequence".to_owned(), - )); - } - - // Validate all items are BaseException instances - let mut has_non_exception = false; - for (i, exc) in exceptions.iter().enumerate() { - if !exc.fast_isinstance(vm.ctx.exceptions.base_exception_type) { - return Err(vm.new_value_error(format!( - "Item {} of second argument (exceptions) is not an exception", - i - ))); - } - // Check if any exception is not an Exception subclass - // With dynamic ExceptionGroup (inherits from both BaseExceptionGroup and Exception), - // ExceptionGroup instances are automatically instances of Exception - if !exc.fast_isinstance(vm.ctx.exceptions.exception_type) { - has_non_exception = true; - } - } - - // Get the dynamic ExceptionGroup type - let exception_group_type = crate::exception_group::exception_group(); - - // Determine the actual class to use - let actual_cls = if cls.is(exception_group_type) { - // ExceptionGroup cannot contain BaseExceptions that are not Exception - if has_non_exception { - return Err( - vm.new_type_error("Cannot nest BaseExceptions in an ExceptionGroup") - ); - } - cls - } else if cls.is(vm.ctx.exceptions.base_exception_group) { - // Auto-convert to ExceptionGroup if all are Exception subclasses - if !has_non_exception { - exception_group_type.to_owned() - } else { - cls - } - } else { - // 
User-defined subclass - if has_non_exception && cls.fast_issubclass(vm.ctx.exceptions.exception_type) { - return Err(vm.new_type_error(format!( - "Cannot nest BaseExceptions in '{}'", - cls.name() - ))); - } - cls - }; - - // Create the exception with (message, exceptions_tuple) as args - let exceptions_tuple = vm.ctx.new_tuple(exceptions); - let init_args = vec![message, exceptions_tuple.into()]; - PyBaseException::new(init_args, vm) - .into_ref_with_type(vm, actual_cls) - .map(Into::into) - } - - #[pyslot] - #[pymethod(name = "__init__")] - fn slot_init(_zelf: PyObjectRef, _args: FuncArgs, _vm: &VirtualMachine) -> PyResult<()> { - // CPython's BaseExceptionGroup.__init__ just calls BaseException.__init__ - // which stores args as-is. Since __new__ already set up the correct args - // (message, exceptions_tuple), we don't need to do anything here. - // This also allows subclasses to pass extra arguments to __new__ without - // __init__ complaining about argument count. - Ok(()) - } - #[pymethod] fn derive( zelf: PyRef, @@ -351,6 +241,136 @@ pub(super) mod types { } } + impl Constructor for PyBaseExceptionGroup { + type Args = crate::function::PosArgs; + + fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + let args: Self::Args = args.bind(vm)?; + let args = args.into_vec(); + // Validate exactly 2 positional arguments + if args.len() != 2 { + return Err(vm.new_type_error(format!( + "BaseExceptionGroup.__new__() takes exactly 2 positional arguments ({} given)", + args.len() + ))); + } + + // Validate message is str + let message = args[0].clone(); + if !message.fast_isinstance(vm.ctx.types.str_type) { + return Err(vm.new_type_error(format!( + "argument 1 must be str, not {}", + message.class().name() + ))); + } + + // Validate exceptions is a sequence (not set or None) + let exceptions_arg = &args[1]; + + // Check for set/frozenset (not a sequence - unordered) + if exceptions_arg.fast_isinstance(vm.ctx.types.set_type) + || 
exceptions_arg.fast_isinstance(vm.ctx.types.frozenset_type) + { + return Err(vm.new_type_error("second argument (exceptions) must be a sequence")); + } + + // Check for None + if exceptions_arg.is(&vm.ctx.none) { + return Err(vm.new_type_error("second argument (exceptions) must be a sequence")); + } + + let exceptions: Vec = exceptions_arg.try_to_value(vm).map_err(|_| { + vm.new_type_error("second argument (exceptions) must be a sequence") + })?; + + // Validate non-empty + if exceptions.is_empty() { + return Err(vm.new_value_error( + "second argument (exceptions) must be a non-empty sequence".to_owned(), + )); + } + + // Validate all items are BaseException instances + let mut has_non_exception = false; + for (i, exc) in exceptions.iter().enumerate() { + if !exc.fast_isinstance(vm.ctx.exceptions.base_exception_type) { + return Err(vm.new_value_error(format!( + "Item {} of second argument (exceptions) is not an exception", + i + ))); + } + // Check if any exception is not an Exception subclass + // With dynamic ExceptionGroup (inherits from both BaseExceptionGroup and Exception), + // ExceptionGroup instances are automatically instances of Exception + if !exc.fast_isinstance(vm.ctx.exceptions.exception_type) { + has_non_exception = true; + } + } + + // Get the dynamic ExceptionGroup type + let exception_group_type = crate::exception_group::exception_group(); + + // Determine the actual class to use + let actual_cls = if cls.is(exception_group_type) { + // ExceptionGroup cannot contain BaseExceptions that are not Exception + if has_non_exception { + return Err( + vm.new_type_error("Cannot nest BaseExceptions in an ExceptionGroup") + ); + } + cls + } else if cls.is(vm.ctx.exceptions.base_exception_group) { + // Auto-convert to ExceptionGroup if all are Exception subclasses + if !has_non_exception { + exception_group_type.to_owned() + } else { + cls + } + } else { + // User-defined subclass + if has_non_exception && 
cls.fast_issubclass(vm.ctx.exceptions.exception_type) { + return Err(vm.new_type_error(format!( + "Cannot nest BaseExceptions in '{}'", + cls.name() + ))); + } + cls + }; + + // Create the exception with (message, exceptions_tuple) as args + let exceptions_tuple = vm.ctx.new_tuple(exceptions); + let init_args = vec![message, exceptions_tuple.into()]; + PyBaseException::new(init_args, vm) + .into_ref_with_type(vm, actual_cls) + .map(Into::into) + } + + fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { + unimplemented!("use slot_new") + } + } + + impl Initializer for PyBaseExceptionGroup { + type Args = FuncArgs; + + fn slot_init( + _zelf: PyObjectRef, + _args: ::rustpython_vm::function::FuncArgs, + _vm: &::rustpython_vm::VirtualMachine, + ) -> ::rustpython_vm::PyResult<()> { + // CPython's BaseExceptionGroup.__init__ just calls BaseException.__init__ + // which stores args as-is. Since __new__ already set up the correct args + // (message, exceptions_tuple), we don't need to do anything here. + // This also allows subclasses to pass extra arguments to __new__ without + // __init__ complaining about argument count. 
+ Ok(()) + } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + // Helper functions for ExceptionGroup fn is_base_exception_group(obj: &PyObject, vm: &VirtualMachine) -> bool { obj.fast_isinstance(vm.ctx.exceptions.base_exception_group) diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index bb10ca02c2c..2c36aa13bd5 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -1371,18 +1371,19 @@ pub(super) mod types { #[repr(transparent)] pub struct PyStopIteration(PyException); - #[pyexception] - impl PyStopIteration { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - args: ::rustpython_vm::function::FuncArgs, - vm: &::rustpython_vm::VirtualMachine, - ) -> ::rustpython_vm::PyResult<()> { + #[pyexception(with(Initializer))] + impl PyStopIteration {} + + impl Initializer for PyStopIteration { + type Args = FuncArgs; + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { zelf.set_attr("value", vm.unwrap_or_none(args.args.first().cloned()), vm)?; Ok(()) } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } } #[pyexception(name, base = PyException, ctx = "stop_async_iteration", impl)] @@ -1419,15 +1420,13 @@ pub(super) mod types { #[repr(transparent)] pub struct PyAttributeError(PyException); - #[pyexception] - impl PyAttributeError { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - args: ::rustpython_vm::function::FuncArgs, - vm: &::rustpython_vm::VirtualMachine, - ) -> ::rustpython_vm::PyResult<()> { + #[pyexception(with(Initializer))] + impl PyAttributeError {} + + impl Initializer for PyAttributeError { + type Args = FuncArgs; + + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { zelf.set_attr( "name", 
vm.unwrap_or_none(args.kwargs.get("name").cloned()), @@ -1440,6 +1439,10 @@ pub(super) mod types { )?; Ok(()) } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } } #[pyexception(name, base = PyException, ctx = "buffer_error", impl)] @@ -1457,15 +1460,28 @@ pub(super) mod types { #[repr(transparent)] pub struct PyImportError(PyException); - #[pyexception] + #[pyexception(with(Initializer))] impl PyImportError { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - args: ::rustpython_vm::function::FuncArgs, - vm: &::rustpython_vm::VirtualMachine, - ) -> ::rustpython_vm::PyResult<()> { + #[pymethod] + fn __reduce__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyTupleRef { + let obj = exc.as_object().to_owned(); + let mut result: Vec = vec![ + obj.class().to_owned().into(), + vm.new_tuple((exc.get_arg(0).unwrap(),)).into(), + ]; + + if let Some(dict) = obj.dict().filter(|x| !x.is_empty()) { + result.push(dict.into()); + } + + result.into_pytuple(vm) + } + } + + impl Initializer for PyImportError { + type Args = FuncArgs; + + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let mut kwargs = args.kwargs.clone(); let name = kwargs.swap_remove("name"); let path = kwargs.swap_remove("path"); @@ -1482,19 +1498,9 @@ pub(super) mod types { dict.set_item("path", vm.unwrap_or_none(path), vm)?; PyBaseException::slot_init(zelf, args, vm) } - #[pymethod] - fn __reduce__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyTupleRef { - let obj = exc.as_object().to_owned(); - let mut result: Vec = vec![ - obj.class().to_owned().into(), - vm.new_tuple((exc.get_arg(0).unwrap(),)).into(), - ]; - if let Some(dict) = obj.dict().filter(|x| !x.is_empty()) { - result.push(dict.into()); - } - - result.into_pytuple(vm) + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is 
defined") } } @@ -1660,11 +1666,10 @@ pub(super) mod types { } } - #[pyexception(with(Constructor))] - impl PyOSError { - #[pyslot] - #[pymethod(name = "__init__")] - pub fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + impl Initializer for PyOSError { + type Args = FuncArgs; + + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let len = args.args.len(); let mut new_args = args; @@ -1718,6 +1723,13 @@ pub(super) mod types { PyBaseException::slot_init(zelf, new_args, vm) } + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + + #[pyexception(with(Constructor, Initializer))] + impl PyOSError { #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { let obj = exc.as_object().to_owned(); @@ -2011,44 +2023,8 @@ pub(super) mod types { #[repr(transparent)] pub struct PySyntaxError(PyException); - #[pyexception] + #[pyexception(with(Initializer))] impl PySyntaxError { - #[pyslot] - #[pymethod(name = "__init__")] - fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { - let len = args.args.len(); - let new_args = args; - - zelf.set_attr("print_file_and_line", vm.ctx.none(), vm)?; - - if len == 2 - && let Ok(location_tuple) = new_args.args[1] - .clone() - .downcast::() - { - let location_tup_len = location_tuple.len(); - for (i, &attr) in [ - "filename", - "lineno", - "offset", - "text", - "end_lineno", - "end_offset", - ] - .iter() - .enumerate() - { - if location_tup_len > i { - zelf.set_attr(attr, location_tuple[i].to_owned(), vm)?; - } else { - break; - } - } - } - - PyBaseException::slot_init(zelf, new_args, vm) - } - #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyStrRef { fn basename(filename: &str) -> &str { @@ -2097,6 +2073,48 @@ pub(super) mod types { } } + impl Initializer for PySyntaxError { + type Args = FuncArgs; + + fn 
slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + let len = args.args.len(); + let new_args = args; + + zelf.set_attr("print_file_and_line", vm.ctx.none(), vm)?; + + if len == 2 + && let Ok(location_tuple) = new_args.args[1] + .clone() + .downcast::() + { + let location_tup_len = location_tuple.len(); + for (i, &attr) in [ + "filename", + "lineno", + "offset", + "text", + "end_lineno", + "end_offset", + ] + .iter() + .enumerate() + { + if location_tup_len > i { + zelf.set_attr(attr, location_tuple[i].to_owned(), vm)?; + } else { + break; + } + } + } + + PyBaseException::slot_init(zelf, new_args, vm) + } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + #[pyexception( name = "_IncompleteInputError", base = PySyntaxError, @@ -2106,17 +2124,19 @@ pub(super) mod types { #[repr(transparent)] pub struct PyIncompleteInputError(PySyntaxError); - #[pyexception] - impl PyIncompleteInputError { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - _args: FuncArgs, - vm: &VirtualMachine, - ) -> PyResult<()> { + #[pyexception(with(Initializer))] + impl PyIncompleteInputError {} + + impl Initializer for PyIncompleteInputError { + type Args = FuncArgs; + + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { zelf.set_attr("name", vm.ctx.new_str("SyntaxError"), vm)?; - Ok(()) + PySyntaxError::slot_init(zelf, args, vm) + } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") } } @@ -2155,26 +2175,8 @@ pub(super) mod types { #[repr(transparent)] pub struct PyUnicodeDecodeError(PyUnicodeError); - #[pyexception] + #[pyexception(with(Initializer))] impl PyUnicodeDecodeError { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - args: FuncArgs, - vm: &VirtualMachine, - ) -> 
PyResult<()> { - type Args = (PyStrRef, ArgBytesLike, isize, isize, PyStrRef); - let (encoding, object, start, end, reason): Args = args.bind(vm)?; - zelf.set_attr("encoding", encoding, vm)?; - let object_as_bytes = vm.ctx.new_bytes(object.borrow_buf().to_vec()); - zelf.set_attr("object", object_as_bytes, vm)?; - zelf.set_attr("start", vm.ctx.new_int(start), vm)?; - zelf.set_attr("end", vm.ctx.new_int(end), vm)?; - zelf.set_attr("reason", reason, vm)?; - Ok(()) - } - #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { let Ok(object) = exc.as_object().get_attr("object", vm) else { @@ -2202,30 +2204,33 @@ pub(super) mod types { } } - #[pyexception(name, base = PyUnicodeError, ctx = "unicode_encode_error")] - #[derive(Debug)] - #[repr(transparent)] - pub struct PyUnicodeEncodeError(PyUnicodeError); + impl Initializer for PyUnicodeDecodeError { + type Args = FuncArgs; - #[pyexception] - impl PyUnicodeEncodeError { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - args: FuncArgs, - vm: &VirtualMachine, - ) -> PyResult<()> { - type Args = (PyStrRef, PyStrRef, isize, isize, PyStrRef); + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + type Args = (PyStrRef, ArgBytesLike, isize, isize, PyStrRef); let (encoding, object, start, end, reason): Args = args.bind(vm)?; zelf.set_attr("encoding", encoding, vm)?; - zelf.set_attr("object", object, vm)?; + let object_as_bytes = vm.ctx.new_bytes(object.borrow_buf().to_vec()); + zelf.set_attr("object", object_as_bytes, vm)?; zelf.set_attr("start", vm.ctx.new_int(start), vm)?; zelf.set_attr("end", vm.ctx.new_int(end), vm)?; zelf.set_attr("reason", reason, vm)?; Ok(()) } + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + + #[pyexception(name, base = PyUnicodeError, ctx = "unicode_encode_error")] + #[derive(Debug)] + #[repr(transparent)] + pub 
struct PyUnicodeEncodeError(PyUnicodeError); + + #[pyexception(with(Initializer))] + impl PyUnicodeEncodeError { #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { let Ok(object) = exc.as_object().get_attr("object", vm) else { @@ -2254,22 +2259,13 @@ pub(super) mod types { } } - #[pyexception(name, base = PyUnicodeError, ctx = "unicode_translate_error")] - #[derive(Debug)] - #[repr(transparent)] - pub struct PyUnicodeTranslateError(PyUnicodeError); + impl Initializer for PyUnicodeEncodeError { + type Args = FuncArgs; - #[pyexception] - impl PyUnicodeTranslateError { - #[pyslot] - #[pymethod(name = "__init__")] - pub(crate) fn slot_init( - zelf: PyObjectRef, - args: FuncArgs, - vm: &VirtualMachine, - ) -> PyResult<()> { - type Args = (PyStrRef, isize, isize, PyStrRef); - let (object, start, end, reason): Args = args.bind(vm)?; + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + type Args = (PyStrRef, PyStrRef, isize, isize, PyStrRef); + let (encoding, object, start, end, reason): Args = args.bind(vm)?; + zelf.set_attr("encoding", encoding, vm)?; zelf.set_attr("object", object, vm)?; zelf.set_attr("start", vm.ctx.new_int(start), vm)?; zelf.set_attr("end", vm.ctx.new_int(end), vm)?; @@ -2277,6 +2273,18 @@ pub(super) mod types { Ok(()) } + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + + #[pyexception(name, base = PyUnicodeError, ctx = "unicode_translate_error")] + #[derive(Debug)] + #[repr(transparent)] + pub struct PyUnicodeTranslateError(PyUnicodeError); + + #[pyexception(with(Initializer))] + impl PyUnicodeTranslateError { #[pymethod] fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { let Ok(object) = exc.as_object().get_attr("object", vm) else { @@ -2301,6 +2309,24 @@ pub(super) mod types { } } + impl Initializer for PyUnicodeTranslateError { + type Args = FuncArgs; + + fn slot_init(zelf: 
PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + type Args = (PyStrRef, isize, isize, PyStrRef); + let (object, start, end, reason): Args = args.bind(vm)?; + zelf.set_attr("object", object, vm)?; + zelf.set_attr("start", vm.ctx.new_int(start), vm)?; + zelf.set_attr("end", vm.ctx.new_int(end), vm)?; + zelf.set_attr("reason", reason, vm)?; + Ok(()) + } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + /// JIT error. #[cfg(feature = "jit")] #[pyexception(name, base = PyException, ctx = "jit_error", impl)] diff --git a/crates/vm/src/stdlib/ast/python.rs b/crates/vm/src/stdlib/ast/python.rs index aa21d8b034a..35fe561527b 100644 --- a/crates/vm/src/stdlib/ast/python.rs +++ b/crates/vm/src/stdlib/ast/python.rs @@ -3,50 +3,18 @@ use super::{PY_CF_OPTIMIZED_AST, PY_CF_TYPE_COMMENTS, PY_COMPILE_FLAG_AST_ONLY}; #[pymodule] pub(crate) mod _ast { use crate::{ - AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef}, function::FuncArgs, - types::Constructor, + types::{Constructor, Initializer}, }; #[pyattr] #[pyclass(module = "_ast", name = "AST")] #[derive(Debug, PyPayload)] pub(crate) struct NodeAst; - #[pyclass(with(Constructor), flags(BASETYPE, HAS_DICT))] + #[pyclass(with(Constructor, Initializer), flags(BASETYPE, HAS_DICT))] impl NodeAst { - #[pyslot] - #[pymethod] - fn __init__(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { - let fields = zelf.get_attr("_fields", vm)?; - let fields: Vec = fields.try_to_value(vm)?; - let n_args = args.args.len(); - if n_args > fields.len() { - return Err(vm.new_type_error(format!( - "{} constructor takes at most {} positional argument{}", - zelf.class().name(), - fields.len(), - if fields.len() == 1 { "" } else { "s" }, - ))); - } - for (name, arg) in 
fields.iter().zip(args.args) { - zelf.set_attr(name, arg, vm)?; - } - for (key, value) in args.kwargs { - if let Some(pos) = fields.iter().position(|f| f.as_str() == key) - && pos < n_args - { - return Err(vm.new_type_error(format!( - "{} got multiple values for argument '{}'", - zelf.class().name(), - key - ))); - } - zelf.set_attr(vm.ctx.intern_str(key), value, vm)?; - } - Ok(()) - } - #[pyattr] fn _fields(ctx: &Context) -> PyTupleRef { ctx.empty_tuple.clone() @@ -71,7 +39,8 @@ pub(crate) mod _ast { let zelf = vm.ctx.new_base_object(cls, dict); // Initialize the instance with the provided arguments - Self::__init__(zelf.clone(), args, vm)?; + // FIXME: This is probably incorrect. Please check if init should be called outside of __new__ + Self::slot_init(zelf.clone(), args, vm)?; Ok(zelf) } @@ -81,6 +50,44 @@ pub(crate) mod _ast { } } + impl Initializer for NodeAst { + type Args = FuncArgs; + + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + let fields = zelf.get_attr("_fields", vm)?; + let fields: Vec = fields.try_to_value(vm)?; + let n_args = args.args.len(); + if n_args > fields.len() { + return Err(vm.new_type_error(format!( + "{} constructor takes at most {} positional argument{}", + zelf.class().name(), + fields.len(), + if fields.len() == 1 { "" } else { "s" }, + ))); + } + for (name, arg) in fields.iter().zip(args.args) { + zelf.set_attr(name, arg, vm)?; + } + for (key, value) in args.kwargs { + if let Some(pos) = fields.iter().position(|f| f.as_str() == key) + && pos < n_args + { + return Err(vm.new_type_error(format!( + "{} got multiple values for argument '{}'", + zelf.class().name(), + key + ))); + } + zelf.set_attr(vm.ctx.intern_str(key), value, vm)?; + } + Ok(()) + } + + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + unreachable!("slot_init is defined") + } + } + #[pyattr(name = "PyCF_ONLY_AST")] use super::PY_COMPILE_FLAG_AST_ONLY; From 
5925f1483c197fcde5300b0f3058b0977c03cd58 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Dec 2025 22:25:09 +0900 Subject: [PATCH 023/418] Bump insta from 1.44.3 to 1.45.0 (#6468) Bumps [insta](https://github.com/mitsuhiko/insta) from 1.44.3 to 1.45.0. - [Release notes](https://github.com/mitsuhiko/insta/releases) - [Changelog](https://github.com/mitsuhiko/insta/blob/master/CHANGELOG.md) - [Commits](https://github.com/mitsuhiko/insta/compare/1.44.3...1.45.0) --- updated-dependencies: - dependency-name: insta dependency-version: 1.45.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 20 +++++++++++++++++--- Cargo.toml | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2873f3a529f..a57a36a40f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1181,7 +1181,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1487,13 +1487,14 @@ dependencies = [ [[package]] name = "insta" -version = "1.44.3" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5c943d4415edd8153251b6f197de5eb1640e56d84e8d9159bea190421c73698" +checksum = "b76866be74d68b1595eb8060cb9191dca9c021db2316558e52ddc5d55d41b66c" dependencies = [ "console", "once_cell", "similar", + "tempfile", ] [[package]] @@ -3721,6 +3722,19 @@ dependencies = [ "shared-build", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + 
[[package]] name = "termios" version = "0.3.3" diff --git a/Cargo.toml b/Cargo.toml index f68e6e68157..fad506ddfaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,7 +173,7 @@ getrandom = { version = "0.3", features = ["std"] } glob = "0.3" hex = "0.4.3" indexmap = { version = "2.11.3", features = ["std"] } -insta = "1.44" +insta = "1.45" itertools = "0.14.0" is-macro = "0.3.7" junction = "1.3.0" From 2342006b37ccfef8633bd937c1758ac1101f99a8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 08:57:03 +0900 Subject: [PATCH 024/418] Bump aws-lc-rs from 1.15.1 to 1.15.2 (#6469) Bumps [aws-lc-rs](https://github.com/aws/aws-lc-rs) from 1.15.1 to 1.15.2. - [Release notes](https://github.com/aws/aws-lc-rs/releases) - [Commits](https://github.com/aws/aws-lc-rs/compare/v1.15.1...v1.15.2) --- updated-dependencies: - dependency-name: aws-lc-rs dependency-version: 1.15.2 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- crates/stdlib/Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a57a36a40f9..e3296732cb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -254,9 +254,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.15.1" +version = "1.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b5ce75405893cd713f9ab8e297d8e438f624dde7d706108285f7e17a25a180f" +checksum = "6a88aab2464f1f25453baa7a07c84c5b7684e274054ba06817f382357f77a288" dependencies = [ "aws-lc-fips-sys", "aws-lc-sys", @@ -266,9 +266,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.34.0" +version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "179c3777a8b5e70e90ea426114ffc565b2c1a9f82f6c4a0c5a34aa6ef5e781b6" +checksum = "b45afffdee1e7c9126814751f88dddc747f41d91da16c9551a0f1e8a11e788a1" dependencies = [ "cc", "cmake", diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index a5328697ca8..0cd853223e2 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -127,7 +127,7 @@ x509-parser = { version = "0.18", optional = true } der = { version = "0.7", features = ["alloc", "oid"], optional = true } pem-rfc7468 = { version = "1.0", features = ["alloc"], optional = true } webpki-roots = { version = "1.0", optional = true } -aws-lc-rs = { version = "1.14.1", optional = true } +aws-lc-rs = { version = "1.15.2", optional = true } oid-registry = { version = "0.8", features = ["x509", "pkcs1", "nist_algs"], optional = true } pkcs8 = { version = "0.10", features = ["encryption", "pkcs5", "pem"], optional = true } From d919fe516e65bd259d05295773d15859d597e7c7 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 19 Dec 2025 16:49:07 +0900 Subject: [PATCH 025/418] mark test_threading --- 
Lib/test/test_threading.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 94f21d1c38f..8f86792e820 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -229,6 +229,7 @@ def f(mutex): # PyThreadState_SetAsyncExc() is a CPython-only gimmick, not (currently) # exposed at the Python level. This test relies on ctypes to get at it. + @unittest.skip("TODO: RUSTPYTHON; expects @cpython_only") def test_PyThreadState_SetAsyncExc(self): ctypes = import_module("ctypes") @@ -332,6 +333,7 @@ def fail_new_thread(*args): finally: threading._start_new_thread = _start_new_thread + @unittest.skip("TODO: RUSTPYTHON; ctypes.pythonapi is not supported") def test_finalize_running_thread(self): # Issue 1402: the PyGILState_Ensure / _Release functions may be called # very late on python exit: on deallocation of a running thread for From 73e1c3816e737b0559063c44f52537d6963aa267 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 11:55:59 +0900 Subject: [PATCH 026/418] fix ssl --- crates/stdlib/src/openssl.rs | 14 ++++++++++---- crates/stdlib/src/ssl.rs | 18 ++++++++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index 4d420e7d539..d352d15a614 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -1543,10 +1543,10 @@ mod _ssl { #[pymethod] fn get_ca_certs( &self, - binary_form: OptionalArg, + args: GetCertArgs, vm: &VirtualMachine, ) -> PyResult> { - let binary_form = binary_form.unwrap_or(false); + let binary_form = args.binary_form.unwrap_or(false); let ctx = self.ctx(); #[cfg(ossl300)] let certs = ctx.cert_store().all_certificates(); @@ -2259,6 +2259,12 @@ mod _ssl { password: Option, } + #[derive(FromArgs)] + struct GetCertArgs { + #[pyarg(any, optional)] + binary_form: OptionalArg, + } + // Err is true if the socket is blocking type SocketDeadline = Result; @@ -2516,10 
+2522,10 @@ mod _ssl { #[pymethod] fn getpeercert( &self, - binary: OptionalArg, + args: GetCertArgs, vm: &VirtualMachine, ) -> PyResult> { - let binary = binary.unwrap_or(false); + let binary = args.binary_form.unwrap_or(false); let stream = self.connection.read(); if !stream.ssl().is_init_finished() { return Err(vm.new_value_error("handshake not done yet")); diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index bf25260ed3a..16449e2d019 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -841,6 +841,12 @@ mod _ssl { password: OptionalArg, } + #[derive(FromArgs)] + struct GetCertArgs { + #[pyarg(any, optional)] + binary_form: OptionalArg, + } + #[pyclass(with(Constructor), flags(BASETYPE))] impl PySSLContext { // Helper method to convert DER certificate bytes to Python dict @@ -1688,12 +1694,8 @@ mod _ssl { } #[pymethod] - fn get_ca_certs( - &self, - binary_form: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - let binary_form = binary_form.unwrap_or(false); + fn get_ca_certs(&self, args: GetCertArgs, vm: &VirtualMachine) -> PyResult { + let binary_form = args.binary_form.unwrap_or(false); let ca_certs_der = self.ca_certs_der.read(); let mut certs = Vec::new(); @@ -3444,10 +3446,10 @@ mod _ssl { #[pymethod] fn getpeercert( &self, - binary_form: OptionalArg, + args: GetCertArgs, vm: &VirtualMachine, ) -> PyResult> { - let binary = binary_form.unwrap_or(false); + let binary = args.binary_form.unwrap_or(false); // Check if handshake is complete if !*self.handshake_done.lock() { From a84452ab457c71880e1a90cc488c314596d7d184 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 11:52:31 +0900 Subject: [PATCH 027/418] fix sysconfigdata --- Lib/test/test_sysconfig.py | 1 - crates/vm/src/stdlib/sysconfigdata.rs | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py index 35e62d54635..965780668ca 100644 --- 
a/Lib/test/test_sysconfig.py +++ b/Lib/test/test_sysconfig.py @@ -447,7 +447,6 @@ def test_main(self): _main() self.assertTrue(len(output.getvalue().split('\n')) > 0) - @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipIf(sys.platform == "win32", "Does not apply to Windows") def test_ldshared_value(self): ldflags = sysconfig.get_config_var('LDFLAGS') diff --git a/crates/vm/src/stdlib/sysconfigdata.rs b/crates/vm/src/stdlib/sysconfigdata.rs index 90e46b83b97..ee40b693aa2 100644 --- a/crates/vm/src/stdlib/sysconfigdata.rs +++ b/crates/vm/src/stdlib/sysconfigdata.rs @@ -1,3 +1,5 @@ +// spell-checker: words LDSHARED ARFLAGS CPPFLAGS CCSHARED BASECFLAGS BLDSHARED + pub(crate) use _sysconfigdata::make_module; #[pymodule] @@ -18,6 +20,21 @@ pub(crate) mod _sysconfigdata { "MULTIARCH" => MULTIARCH, // enough for tests to stop expecting urandom() to fail after restricting file resources "HAVE_GETRANDOM" => 1, + // Compiler configuration for native extension builds + "CC" => "cc", + "CXX" => "c++", + "CFLAGS" => "", + "CPPFLAGS" => "", + "LDFLAGS" => "", + "LDSHARED" => "cc -shared", + "CCSHARED" => "", + "SHLIB_SUFFIX" => ".so", + "SO" => ".so", + "AR" => "ar", + "ARFLAGS" => "rcs", + "OPT" => "", + "BASECFLAGS" => "", + "BLDSHARED" => "cc -shared", } include!(concat!(env!("OUT_DIR"), "/env_vars.rs")); vars From df523cb58c42a9a10a539ca1598dd264876da4f4 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Tue, 23 Dec 2025 14:58:42 +0900 Subject: [PATCH 028/418] skip spawnve on windows --- Lib/test/test_os.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 105629bda19..939315379f2 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -3491,6 +3491,7 @@ def test_spawnl(self): self.assertEqual(exitcode, self.exitcode) @requires_os_func('spawnle') + @unittest.skipIf(sys.platform == 'win32', "TODO: RUSTPYTHON; fix spawnve on Windows") def test_spawnle(self): program, args = self.create_args(with_env=True) 
exitcode = os.spawnle(os.P_WAIT, program, *args, self.env) @@ -3519,6 +3520,7 @@ def test_spawnv(self): self.assertEqual(exitcode, self.exitcode) @requires_os_func('spawnve') + @unittest.skipIf(sys.platform == 'win32', "TODO: RUSTPYTHON; fix spawnve on Windows") def test_spawnve(self): program, args = self.create_args(with_env=True) exitcode = os.spawnve(os.P_WAIT, program, args, self.env) @@ -3627,6 +3629,7 @@ def _test_invalid_env(self, spawn): self.assertEqual(exitcode, 0) @requires_os_func('spawnve') + @unittest.skipIf(sys.platform == 'win32', "TODO: RUSTPYTHON; fix spawnve on Windows") def test_spawnve_invalid_env(self): self._test_invalid_env(os.spawnve) From 9760249a750c272733369ff886aac22ffbe25d31 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 15:24:07 +0900 Subject: [PATCH 029/418] fix multiarch --- Lib/test/test_sysconfig.py | 1 - crates/vm/src/stdlib/sys.rs | 20 +++++++++++++------- crates/vm/src/stdlib/sysconfigdata.rs | 11 +++++++---- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py index 965780668ca..82c11bdf7e2 100644 --- a/Lib/test/test_sysconfig.py +++ b/Lib/test/test_sysconfig.py @@ -598,7 +598,6 @@ def test_android_ext_suffix(self): self.assertTrue(suffix.endswith(f"-{expected_triplet}.so"), f"{machine=}, {suffix=}") - @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(sys.platform == 'darwin', 'OS X-specific test') def test_osx_ext_suffix(self): suffix = sysconfig.get_config_var('EXT_SUFFIX') diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index cfe6f9f5e61..0e46ec18a01 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -1,6 +1,8 @@ use crate::{Py, PyResult, VirtualMachine, builtins::PyModule, convert::ToPyObject}; -pub(crate) use sys::{__module_def, DOC, MAXSIZE, MULTIARCH, UnraisableHookArgsData}; +pub(crate) use sys::{ + __module_def, DOC, MAXSIZE, RUST_MULTIARCH, UnraisableHookArgsData, 
multiarch, +}; #[pymodule] mod sys { @@ -37,10 +39,14 @@ mod sys { System::LibraryLoader::{GetModuleFileNameW, GetModuleHandleW}, }; - // not the same as CPython (e.g. rust's x86_x64-unknown-linux-gnu is just x86_64-linux-gnu) - // but hopefully that's just an implementation detail? TODO: copy CPython's multiarch exactly, - // https://github.com/python/cpython/blob/3.8/configure.ac#L725 - pub(crate) const MULTIARCH: &str = env!("RUSTPYTHON_TARGET_TRIPLE"); + // Rust target triple (e.g., "x86_64-unknown-linux-gnu") + pub(crate) const RUST_MULTIARCH: &str = env!("RUSTPYTHON_TARGET_TRIPLE"); + + /// Convert Rust target triple to CPython-style multiarch + /// e.g., "x86_64-unknown-linux-gnu" -> "x86_64-linux-gnu" + pub(crate) fn multiarch() -> String { + RUST_MULTIARCH.replace("-unknown", "") + } #[pyattr(name = "_rustpython_debugbuild")] const RUSTPYTHON_DEBUGBUILD: bool = cfg!(debug_assertions); @@ -189,7 +195,7 @@ mod sys { py_namespace!(vm, { "name" => ctx.new_str(NAME), "cache_tag" => ctx.new_str(cache_tag), - "_multiarch" => ctx.new_str(MULTIARCH.to_owned()), + "_multiarch" => ctx.new_str(multiarch()), "version" => version_info(vm), "hexversion" => ctx.new_int(version::VERSION_HEX), }) @@ -1249,6 +1255,6 @@ pub(crate) fn sysconfigdata_name() -> String { "_sysconfigdata_{}_{}_{}", sys::ABIFLAGS, sys::PLATFORM, - sys::MULTIARCH + sys::multiarch() ) } diff --git a/crates/vm/src/stdlib/sysconfigdata.rs b/crates/vm/src/stdlib/sysconfigdata.rs index ee40b693aa2..5e954f7fe7a 100644 --- a/crates/vm/src/stdlib/sysconfigdata.rs +++ b/crates/vm/src/stdlib/sysconfigdata.rs @@ -4,20 +4,23 @@ pub(crate) use _sysconfigdata::make_module; #[pymodule] pub(crate) mod _sysconfigdata { - use crate::{VirtualMachine, builtins::PyDictRef, convert::ToPyObject, stdlib::sys::MULTIARCH}; + use crate::stdlib::sys::{RUST_MULTIARCH, multiarch}; + use crate::{VirtualMachine, builtins::PyDictRef, convert::ToPyObject}; #[pyattr] fn build_time_vars(vm: &VirtualMachine) -> PyDictRef { let vars = 
vm.ctx.new_dict(); + let multiarch = multiarch(); macro_rules! sysvars { ($($key:literal => $value:expr),*$(,)?) => {{ $(vars.set_item($key, $value.to_pyobject(vm), vm).unwrap();)* }}; } sysvars! { - // fake shared module extension - "EXT_SUFFIX" => format!(".rustpython-{MULTIARCH}"), - "MULTIARCH" => MULTIARCH, + // Extension module suffix in CPython-compatible format + "EXT_SUFFIX" => format!(".rustpython313-{multiarch}.so"), + "MULTIARCH" => multiarch.clone(), + "RUST_MULTIARCH" => RUST_MULTIARCH, // enough for tests to stop expecting urandom() to fail after restricting file resources "HAVE_GETRANDOM" => 1, // Compiler configuration for native extension builds From b44229f7ca4f7718f405c4a0f397c2a70c9178ec Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Tue, 23 Dec 2025 21:57:05 +0900 Subject: [PATCH 030/418] debug whats_left --- whats_left.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/whats_left.py b/whats_left.py index 91e46bef7ef..7d02a4cea4d 100755 --- a/whats_left.py +++ b/whats_left.py @@ -361,7 +361,7 @@ def method_incompatibility_reason(typ, method_name, real_method_value): if platform.python_implementation() == "CPython": if not_implementeds: - sys.exit("ERROR: CPython should have all the methods") + sys.exit(f"ERROR: CPython should have all the methods but missing: {not_implementeds}") mod_names = [ name.decode() @@ -455,6 +455,7 @@ def remove_one_indent(s): ) # The last line should be json output, the rest of the lines can contain noise # because importing certain modules can print stuff to stdout/stderr +print(result.stderr, file=sys.stderr) result = json.loads(result.stdout.splitlines()[-1]) if args.json: From fc0a34a5a53730a0f1b59c432e568fea9f84b8f3 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 20 Dec 2025 14:31:33 +0900 Subject: [PATCH 031/418] Remove unused ctypes/field.rs --- crates/vm/src/stdlib/ctypes/field.rs | 306 --------------------------- 1 file changed, 306 deletions(-) delete mode 100644 
crates/vm/src/stdlib/ctypes/field.rs diff --git a/crates/vm/src/stdlib/ctypes/field.rs b/crates/vm/src/stdlib/ctypes/field.rs deleted file mode 100644 index ea57d68065a..00000000000 --- a/crates/vm/src/stdlib/ctypes/field.rs +++ /dev/null @@ -1,306 +0,0 @@ -use crate::builtins::PyType; -use crate::function::PySetterValue; -use crate::types::{GetDescriptor, Representable}; -use crate::{AsObject, Py, PyObject, PyObjectRef, PyResult, VirtualMachine}; -use num_traits::ToPrimitive; - -use super::structure::PyCStructure; -use super::union::PyCUnion; - -#[pyclass(name = "PyCFieldType", base = PyType, module = "_ctypes")] -#[derive(Debug)] -pub struct PyCFieldType { - pub _base: PyType, - #[allow(dead_code)] - pub(super) inner: PyCField, -} - -#[pyclass] -impl PyCFieldType {} - -#[pyclass(name = "CField", module = "_ctypes")] -#[derive(Debug, PyPayload)] -pub struct PyCField { - pub(super) byte_offset: usize, - pub(super) byte_size: usize, - #[allow(unused)] - pub(super) index: usize, - /// The ctypes type for this field (can be any ctypes type including arrays) - pub(super) proto: PyObjectRef, - pub(super) anonymous: bool, - pub(super) bitfield_size: bool, - pub(super) bit_offset: u8, - pub(super) name: String, -} - -impl PyCField { - pub fn new( - name: String, - proto: PyObjectRef, - byte_offset: usize, - byte_size: usize, - index: usize, - ) -> Self { - Self { - name, - proto, - byte_offset, - byte_size, - index, - anonymous: false, - bitfield_size: false, - bit_offset: 0, - } - } -} - -impl Representable for PyCField { - fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { - // Get type name from the proto object - let tp_name = if let Some(name_attr) = vm - .ctx - .interned_str("__name__") - .and_then(|s| zelf.proto.get_attr(s, vm).ok()) - { - name_attr.str(vm)?.to_string() - } else { - zelf.proto.class().name().to_string() - }; - - if zelf.bitfield_size { - Ok(format!( - "<{} type={}, ofs={byte_offset}, bit_size={bitfield_size}, bit_offset={bit_offset}", - 
zelf.name, - tp_name, - byte_offset = zelf.byte_offset, - bitfield_size = zelf.bitfield_size, - bit_offset = zelf.bit_offset - )) - } else { - Ok(format!( - "<{} type={tp_name}, ofs={}, size={}", - zelf.name, zelf.byte_offset, zelf.byte_size - )) - } - } -} - -impl GetDescriptor for PyCField { - fn descr_get( - zelf: PyObjectRef, - obj: Option, - _cls: Option, - vm: &VirtualMachine, - ) -> PyResult { - let zelf = zelf - .downcast::() - .map_err(|_| vm.new_type_error("expected CField".to_owned()))?; - - // If obj is None, return the descriptor itself (class attribute access) - let obj = match obj { - Some(obj) if !vm.is_none(&obj) => obj, - _ => return Ok(zelf.into()), - }; - - // Instance attribute access - read value from the structure/union's buffer - if let Some(structure) = obj.downcast_ref::() { - let cdata = structure.cdata.read(); - let offset = zelf.byte_offset; - let size = zelf.byte_size; - - if offset + size <= cdata.buffer.len() { - let bytes = &cdata.buffer[offset..offset + size]; - return PyCField::bytes_to_value(bytes, size, vm); - } - } else if let Some(union) = obj.downcast_ref::() { - let cdata = union.cdata.read(); - let offset = zelf.byte_offset; - let size = zelf.byte_size; - - if offset + size <= cdata.buffer.len() { - let bytes = &cdata.buffer[offset..offset + size]; - return PyCField::bytes_to_value(bytes, size, vm); - } - } - - // Fallback: return 0 for uninitialized or unsupported types - Ok(vm.ctx.new_int(0).into()) - } -} - -impl PyCField { - /// Convert bytes to a Python value based on size - fn bytes_to_value(bytes: &[u8], size: usize, vm: &VirtualMachine) -> PyResult { - match size { - 1 => Ok(vm.ctx.new_int(bytes[0] as i8).into()), - 2 => { - let val = i16::from_ne_bytes([bytes[0], bytes[1]]); - Ok(vm.ctx.new_int(val).into()) - } - 4 => { - let val = i32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); - Ok(vm.ctx.new_int(val).into()) - } - 8 => { - let val = i64::from_ne_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], 
bytes[4], bytes[5], bytes[6], bytes[7], - ]); - Ok(vm.ctx.new_int(val).into()) - } - _ => Ok(vm.ctx.new_int(0).into()), - } - } - - /// Convert a Python value to bytes - fn value_to_bytes(value: &PyObject, size: usize, vm: &VirtualMachine) -> PyResult> { - if let Ok(int_val) = value.try_int(vm) { - let i = int_val.as_bigint(); - match size { - 1 => { - let val = i.to_i8().unwrap_or(0); - Ok(val.to_ne_bytes().to_vec()) - } - 2 => { - let val = i.to_i16().unwrap_or(0); - Ok(val.to_ne_bytes().to_vec()) - } - 4 => { - let val = i.to_i32().unwrap_or(0); - Ok(val.to_ne_bytes().to_vec()) - } - 8 => { - let val = i.to_i64().unwrap_or(0); - Ok(val.to_ne_bytes().to_vec()) - } - _ => Ok(vec![0u8; size]), - } - } else { - Ok(vec![0u8; size]) - } - } -} - -#[pyclass( - flags(DISALLOW_INSTANTIATION, IMMUTABLETYPE), - with(Representable, GetDescriptor) -)] -impl PyCField { - #[pyslot] - fn descr_set( - zelf: &crate::PyObject, - obj: PyObjectRef, - value: PySetterValue, - vm: &VirtualMachine, - ) -> PyResult<()> { - let zelf = zelf - .downcast_ref::() - .ok_or_else(|| vm.new_type_error("expected CField".to_owned()))?; - - // Get the structure/union instance - use downcast_ref() to access the struct data - if let Some(structure) = obj.downcast_ref::() { - match value { - PySetterValue::Assign(value) => { - let offset = zelf.byte_offset; - let size = zelf.byte_size; - let bytes = PyCField::value_to_bytes(&value, size, vm)?; - - let mut cdata = structure.cdata.write(); - if offset + size <= cdata.buffer.len() { - cdata.buffer[offset..offset + size].copy_from_slice(&bytes); - } - Ok(()) - } - PySetterValue::Delete => { - Err(vm.new_type_error("cannot delete structure field".to_owned())) - } - } - } else if let Some(union) = obj.downcast_ref::() { - match value { - PySetterValue::Assign(value) => { - let offset = zelf.byte_offset; - let size = zelf.byte_size; - let bytes = PyCField::value_to_bytes(&value, size, vm)?; - - let mut cdata = union.cdata.write(); - if offset + size <= 
cdata.buffer.len() { - cdata.buffer[offset..offset + size].copy_from_slice(&bytes); - } - Ok(()) - } - PySetterValue::Delete => { - Err(vm.new_type_error("cannot delete union field".to_owned())) - } - } - } else { - Err(vm.new_type_error(format!( - "descriptor works only on Structure or Union instances, got {}", - obj.class().name() - ))) - } - } - - #[pymethod] - fn __set__( - zelf: PyObjectRef, - obj: PyObjectRef, - value: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<()> { - Self::descr_set(&zelf, obj, PySetterValue::Assign(value), vm) - } - - #[pymethod] - fn __delete__(zelf: PyObjectRef, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - Self::descr_set(&zelf, obj, PySetterValue::Delete, vm) - } - - #[pygetset] - fn size(&self) -> usize { - self.byte_size - } - - #[pygetset] - fn bit_size(&self) -> bool { - self.bitfield_size - } - - #[pygetset] - fn is_bitfield(&self) -> bool { - self.bitfield_size - } - - #[pygetset] - fn is_anonymous(&self) -> bool { - self.anonymous - } - - #[pygetset] - fn name(&self) -> String { - self.name.clone() - } - - #[pygetset(name = "type")] - fn type_(&self) -> PyObjectRef { - self.proto.clone() - } - - #[pygetset] - fn offset(&self) -> usize { - self.byte_offset - } - - #[pygetset] - fn byte_offset(&self) -> usize { - self.byte_offset - } - - #[pygetset] - fn byte_size(&self) -> usize { - self.byte_size - } - - #[pygetset] - fn bit_offset(&self) -> u8 { - self.bit_offset - } -} From 286a5b5b8f8759808ed5a46f0256965bddcaf45d Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 23:28:58 +0900 Subject: [PATCH 032/418] msc_info in get_version --- Lib/platform.py | 2 +- crates/vm/src/version.rs | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/Lib/platform.py b/Lib/platform.py index 58b66078e1a..c64c6d2c6f5 100755 --- a/Lib/platform.py +++ b/Lib/platform.py @@ -1080,7 +1080,7 @@ def _sys_version(sys_version=None): match.groups() # XXX: RUSTPYTHON support - if "rustc" 
in sys_version: + if "RustPython" in sys_version: name = "RustPython" else: name = 'CPython' diff --git a/crates/vm/src/version.rs b/crates/vm/src/version.rs index 9d472e8be0a..deb3dccd535 100644 --- a/crates/vm/src/version.rs +++ b/crates/vm/src/version.rs @@ -15,12 +15,29 @@ pub const VERSION_HEX: usize = (MAJOR << 24) | (MINOR << 16) | (MICRO << 8) | (RELEASELEVEL_N << 4) | SERIAL; pub fn get_version() -> String { + // Windows: include MSC v. for compatibility with ctypes.util.find_library + // MSC v.1929 = VS 2019, version 14+ makes find_msvcrt() return None + #[cfg(windows)] + let msc_info = { + let arch = if cfg!(target_pointer_width = "64") { + "64 bit (AMD64)" + } else { + "32 bit (Intel)" + }; + // Include both RustPython identifier and MSC v. for compatibility + format!(" MSC v.1929 {arch}",) + }; + + #[cfg(not(windows))] + let msc_info = String::new(); + format!( - "{:.80} ({:.80}) \n[RustPython {} with {:.80}]", // \n is PyPy convention + "{:.80} ({:.80}) \n[RustPython {} with {:.80}{}]", // \n is PyPy convention get_version_number(), get_build_info(), env!("CARGO_PKG_VERSION"), COMPILER, + msc_info, ) } From 4bf0bac52d0270c0efcbc248a3888789624a2645 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 23:29:15 +0900 Subject: [PATCH 033/418] debuggable whats_left --- whats_left.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whats_left.py b/whats_left.py index 7d02a4cea4d..c5b0be6eadc 100755 --- a/whats_left.py +++ b/whats_left.py @@ -361,7 +361,9 @@ def method_incompatibility_reason(typ, method_name, real_method_value): if platform.python_implementation() == "CPython": if not_implementeds: - sys.exit(f"ERROR: CPython should have all the methods but missing: {not_implementeds}") + sys.exit( + f"ERROR: CPython should have all the methods but missing: {not_implementeds}" + ) mod_names = [ name.decode() From 79abbf0b29d5183b0885e45ea9991bd8d046d702 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 
23:36:38 +0900 Subject: [PATCH 034/418] fix ctypes --- crates/vm/src/stdlib/ctypes.rs | 102 ++- crates/vm/src/stdlib/ctypes/array.rs | 44 +- crates/vm/src/stdlib/ctypes/base.rs | 76 ++- crates/vm/src/stdlib/ctypes/function.rs | 776 ++++++++++++++--------- crates/vm/src/stdlib/ctypes/library.rs | 108 +++- crates/vm/src/stdlib/ctypes/pointer.rs | 147 ++++- crates/vm/src/stdlib/ctypes/simple.rs | 187 ++++-- crates/vm/src/stdlib/ctypes/structure.rs | 6 - 8 files changed, 973 insertions(+), 473 deletions(-) diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/ctypes.rs index 3fdb2df6104..9b922230431 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/ctypes.rs @@ -95,6 +95,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { pointer::PyCPointerType::make_class(ctx); structure::PyCStructType::make_class(ctx); union::PyCUnionType::make_class(ctx); + function::PyCFuncPtrType::make_class(ctx); extend_module!(vm, &module, { "_CData" => PyCData::make_class(ctx), "_SimpleCData" => PyCSimple::make_class(ctx), @@ -385,12 +386,8 @@ pub(crate) mod _ctypes { #[pyattr] const RTLD_GLOBAL: i32 = 0; - #[cfg(target_os = "windows")] - #[pyattr] - const SIZEOF_TIME_T: usize = 8; - #[cfg(not(target_os = "windows"))] #[pyattr] - const SIZEOF_TIME_T: usize = 4; + const SIZEOF_TIME_T: usize = std::mem::size_of::(); #[pyattr] const CTYPES_MAX_ARGCOUNT: usize = 1024; @@ -578,30 +575,42 @@ pub(crate) mod _ctypes { #[pyfunction(name = "dlopen")] fn load_library_unix( name: Option, - _load_flags: OptionalArg, + load_flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - // TODO: audit functions first - // TODO: load_flags + // Default mode: RTLD_NOW | RTLD_LOCAL, always force RTLD_NOW + let mode = load_flags.unwrap_or(libc::RTLD_NOW | libc::RTLD_LOCAL) | libc::RTLD_NOW; + match name { Some(name) => { let cache = library::libcache(); let mut cache_write = cache.write(); let os_str = name.as_os_str(vm)?; - let (id, _) = 
cache_write.get_or_insert_lib(&*os_str, vm).map_err(|e| { - // Include filename in error message for better diagnostics - let name_str = os_str.to_string_lossy(); - vm.new_os_error(format!("{}: {}", name_str, e)) - })?; + let (id, _) = cache_write + .get_or_insert_lib_with_mode(&*os_str, mode, vm) + .map_err(|e| { + let name_str = os_str.to_string_lossy(); + vm.new_os_error(format!("{}: {}", name_str, e)) + })?; Ok(id) } None => { - // If None, call libc::dlopen(null, mode) to get the current process handle - let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_NOW) }; + // dlopen(NULL, mode) to get the current process handle (for pythonapi) + let handle = unsafe { libc::dlopen(std::ptr::null(), mode) }; if handle.is_null() { - return Err(vm.new_os_error("dlopen() error")); + let err = unsafe { libc::dlerror() }; + let msg = if err.is_null() { + "dlopen() error".to_string() + } else { + unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() } + }; + return Err(vm.new_os_error(msg)); } - Ok(handle as usize) + // Add to library cache so symbol lookup works + let cache = library::libcache(); + let mut cache_write = cache.write(); + let id = cache_write.insert_raw_handle(handle); + Ok(id) } } } @@ -614,6 +623,48 @@ pub(crate) mod _ctypes { Ok(()) } + #[cfg(not(windows))] + #[pyfunction] + fn dlclose(handle: usize, _vm: &VirtualMachine) -> PyResult<()> { + // Remove from cache, which triggers SharedLibrary drop. + // libloading::Library calls dlclose automatically on Drop. 
+ let cache = library::libcache(); + let mut cache_write = cache.write(); + cache_write.drop_lib(handle); + Ok(()) + } + + #[cfg(not(windows))] + #[pyfunction] + fn dlsym( + handle: usize, + name: crate::builtins::PyStrRef, + vm: &VirtualMachine, + ) -> PyResult { + let symbol_name = std::ffi::CString::new(name.as_str()) + .map_err(|_| vm.new_value_error("symbol name contains null byte"))?; + + // Clear previous error + unsafe { libc::dlerror() }; + + let ptr = unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) }; + + // Check for error via dlerror first + let err = unsafe { libc::dlerror() }; + if !err.is_null() { + let msg = unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() }; + return Err(vm.new_os_error(msg)); + } + + // Treat NULL symbol address as error + // This handles cases like GNU IFUNCs that resolve to NULL + if ptr.is_null() { + return Err(vm.new_os_error(format!("symbol '{}' not found", name.as_str()))); + } + + Ok(ptr as usize) + } + #[pyfunction(name = "POINTER")] fn create_pointer_type(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult { use crate::builtins::PyStr; @@ -905,25 +956,24 @@ pub(crate) mod _ctypes { #[pyfunction] fn get_errno() -> i32 { - errno::errno().0 + super::function::get_errno_value() } #[pyfunction] - fn set_errno(value: i32) { - errno::set_errno(errno::Errno(value)); + fn set_errno(value: i32) -> i32 { + super::function::set_errno_value(value) } #[cfg(windows)] #[pyfunction] fn get_last_error() -> PyResult { - Ok(unsafe { windows_sys::Win32::Foundation::GetLastError() }) + Ok(super::function::get_last_error_value()) } #[cfg(windows)] #[pyfunction] - fn set_last_error(value: u32) -> PyResult<()> { - unsafe { windows_sys::Win32::Foundation::SetLastError(value) }; - Ok(()) + fn set_last_error(value: u32) -> u32 { + super::function::set_last_error_value(value) } #[pyattr] @@ -1084,9 +1134,9 @@ pub(crate) mod _ctypes { ffi_args.push(Arg::new(val)); } - let cif = Cif::new(arg_types, 
Type::isize()); + let cif = Cif::new(arg_types, Type::c_int()); let code_ptr = CodePtr::from_ptr(func_addr as *const _); - let result: isize = unsafe { cif.call(code_ptr, &ffi_args) }; + let result: libc::c_int = unsafe { cif.call(code_ptr, &ffi_args) }; Ok(vm.ctx.new_int(result).into()) } diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/ctypes/array.rs index 60e6516bfe0..208b3e3f4d3 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs +++ b/crates/vm/src/stdlib/ctypes/array.rs @@ -1,9 +1,12 @@ use super::StgInfo; use super::base::{CDATA_BUFFER_METHODS, PyCData}; +use super::type_info; use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, atomic_func, - builtins::{PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef}, + builtins::{ + PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef, genericalias::PyGenericAlias, + }, class::StaticType, function::{ArgBytesLike, FuncArgs, PySetterValue}, protocol::{BufferDescriptor, PyBuffer, PyNumberMethods, PySequenceMethods}, @@ -11,6 +14,19 @@ use crate::{ }; use num_traits::{Signed, ToPrimitive}; +/// Get itemsize from a PEP 3118 format string +/// Extracts the type code (last char after endianness prefix) and returns its size +fn get_size_from_format(fmt: &str) -> usize { + // Format is like "q", etc. 
- strip endianness prefix and get type code + let code = fmt + .trim_start_matches(['<', '>', '@', '=', '!', '&']) + .chars() + .next() + .map(|c| c.to_string()); + code.map(|c| type_info(&c).map(|t| t.size).unwrap_or(1)) + .unwrap_or(1) +} + /// Creates array type for (element_type, length) /// Uses _array_type_cache to ensure identical calls return the same type object pub(super) fn array_type_from_ctype( @@ -444,6 +460,11 @@ impl AsSequence for PyCArray { with(Constructor, AsSequence, AsBuffer) )] impl PyCArray { + #[pyclassmethod] + fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias { + PyGenericAlias::from_args(cls, args, vm) + } + fn int_to_bytes(i: &malachite_bigint::BigInt, size: usize) -> Vec { // Try unsigned first (handles values like 0xFFFFFFFF that overflow signed) // then fall back to signed (handles negative values) @@ -1056,19 +1077,30 @@ impl AsBuffer for PyCArray { .expect("PyCArray type must have StgInfo"); let format = stg_info.format.clone(); let shape = stg_info.shape.clone(); - let element_size = stg_info.element_size; let desc = if let Some(fmt) = format && !shape.is_empty() { + // itemsize is the size of the base element type (item_info->size) + // For empty arrays, we still need the element size, not 0 + let total_elements: usize = shape.iter().product(); + let has_zero_dim = shape.contains(&0); + let itemsize = if total_elements > 0 && buffer_len > 0 { + buffer_len / total_elements + } else { + // For empty arrays, get itemsize from format type code + get_size_from_format(&fmt) + }; + // Build dim_desc from shape (C-contiguous: row-major order) // stride[i] = product(shape[i+1:]) * itemsize + // For empty arrays (any dimension is 0), all strides are 0 let mut dim_desc = Vec::with_capacity(shape.len()); - let mut stride = element_size as isize; + let mut stride = itemsize as isize; - // Calculate strides from innermost to outermost dimension for &dim_size in shape.iter().rev() { - 
dim_desc.push((dim_size, stride, 0)); + let current_stride = if has_zero_dim { 0 } else { stride }; + dim_desc.push((dim_size, current_stride, 0)); stride *= dim_size as isize; } dim_desc.reverse(); @@ -1076,7 +1108,7 @@ impl AsBuffer for PyCArray { BufferDescriptor { len: buffer_len, readonly: false, - itemsize: element_size, + itemsize, format: std::borrow::Cow::Owned(fmt), dim_desc, } diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index 38c371346e0..44793a21561 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -300,15 +300,27 @@ pub(super) fn get_field_format( big_endian: bool, vm: &VirtualMachine, ) -> String { + let endian_prefix = if big_endian { ">" } else { "<" }; + // 1. Check StgInfo for format if let Some(type_obj) = field_type.downcast_ref::() && let Some(stg_info) = type_obj.stg_info_opt() && let Some(fmt) = &stg_info.format { - // Handle endian prefix for simple types - if fmt.len() == 1 { - let endian_prefix = if big_endian { ">" } else { "<" }; - return format!("{}{}", endian_prefix, fmt); + // For structures (T{...}), arrays ((n)...), and pointers (&...), return as-is + // These complex types have their own endianness markers inside + if fmt.starts_with('T') + || fmt.starts_with('(') + || fmt.starts_with('&') + || fmt.starts_with("X{") + { + return fmt.clone(); + } + + // For simple types, replace existing endian prefix with the correct one + let base_fmt = fmt.trim_start_matches(['<', '>', '@', '=', '!']); + if !base_fmt.is_empty() { + return format!("{}{}", endian_prefix, base_fmt); } return fmt.clone(); } @@ -318,8 +330,7 @@ pub(super) fn get_field_format( && let Some(type_str) = type_attr.downcast_ref::() { let s = type_str.as_str(); - if s.len() == 1 { - let endian_prefix = if big_endian { ">" } else { "<" }; + if !s.is_empty() { return format!("{}{}", endian_prefix, s); } return s.to_string(); @@ -1168,29 +1179,30 @@ impl PyCData { .ok_or_else(|| 
vm.new_value_error("Invalid library handle"))? }; - // Get symbol address using platform-specific API - let symbol_name = std::ffi::CString::new(name.as_str()) - .map_err(|_| vm.new_value_error("Invalid symbol name"))?; - - #[cfg(windows)] - let ptr: *const u8 = unsafe { - match windows_sys::Win32::System::LibraryLoader::GetProcAddress( - handle as windows_sys::Win32::Foundation::HMODULE, - symbol_name.as_ptr() as *const u8, - ) { - Some(p) => p as *const u8, - None => std::ptr::null(), + // Look up the library in the cache and use lib.get() for symbol lookup + let library_cache = super::library::libcache().read(); + let library = library_cache + .get_lib(handle) + .ok_or_else(|| vm.new_value_error("Library not found"))?; + let inner_lib = library.lib.lock(); + + let symbol_name_with_nul = format!("{}\0", name.as_str()); + let ptr: *const u8 = if let Some(lib) = &*inner_lib { + unsafe { + lib.get::<*const u8>(symbol_name_with_nul.as_bytes()) + .map(|sym| *sym) + .map_err(|_| { + vm.new_value_error(format!("symbol '{}' not found", name.as_str())) + })? } + } else { + return Err(vm.new_value_error("Library closed")); }; - #[cfg(not(windows))] - let ptr: *const u8 = - unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) as *const u8 }; - + // dlsym can return NULL for symbols that resolve to NULL (e.g., GNU IFUNC) + // Treat NULL addresses as errors if ptr.is_null() { - return Err( - vm.new_value_error(format!("symbol '{}' not found in library", name.as_str())) - ); + return Err(vm.new_value_error(format!("symbol '{}' not found", name.as_str()))); } // PyCData_AtAddress @@ -1593,7 +1605,7 @@ impl PyCField { /// PyCField_set #[pyslot] fn descr_set( - zelf: &crate::PyObject, + zelf: &PyObject, obj: PyObjectRef, value: PySetterValue, vm: &VirtualMachine, @@ -1804,7 +1816,7 @@ pub enum FfiArgValue { F64(f64), Pointer(usize), /// Pointer with owned data. The PyObjectRef keeps the pointed data alive. 
- OwnedPointer(usize, #[allow(dead_code)] crate::PyObjectRef), + OwnedPointer(usize, #[allow(dead_code)] PyObjectRef), } impl FfiArgValue { @@ -2145,6 +2157,16 @@ pub(super) fn read_ptr_from_buffer(buffer: &[u8]) -> usize { } } +/// Check if a type is a "simple instance" (direct subclass of a simple type) +/// Returns TRUE for c_int, c_void_p, etc. (simple types with _type_ attribute) +/// Returns FALSE for Structure, Array, POINTER(T), etc. +pub(super) fn is_simple_instance(typ: &Py) -> bool { + // _ctypes_simple_instance + // Check if the type's metaclass is PyCSimpleType + let metaclass = typ.class(); + metaclass.fast_issubclass(super::simple::PyCSimpleType::static_type()) +} + /// Set or initialize StgInfo on a type pub(super) fn set_or_init_stginfo(type_ref: &PyType, stg_info: StgInfo) { if type_ref.init_type_data(stg_info.clone()).is_err() diff --git a/crates/vm/src/stdlib/ctypes/function.rs b/crates/vm/src/stdlib/ctypes/function.rs index 9bddb0ef0e8..04ff238ebcf 100644 --- a/crates/vm/src/stdlib/ctypes/function.rs +++ b/crates/vm/src/stdlib/ctypes/function.rs @@ -1,16 +1,19 @@ // spell-checker:disable use super::{ - _ctypes::CArgObject, PyCArray, PyCData, PyCPointer, PyCStructure, base::FfiArgValue, - simple::PyCSimple, type_info, + _ctypes::CArgObject, + PyCArray, PyCData, PyCPointer, PyCStructure, StgInfo, + base::{CDATA_BUFFER_METHODS, FfiArgValue, ParamFunc, StgInfoFlags}, + simple::PyCSimple, + type_info, }; use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBytes, PyDict, PyNone, PyStr, PyTuple, PyType, PyTypeRef}, class::StaticType, - convert::ToPyObject, function::FuncArgs, - types::{Callable, Constructor, Representable}, + protocol::{BufferDescriptor, PyBuffer}, + types::{AsBuffer, Callable, Constructor, Initializer, Representable}, vm::thread::with_current_vm, }; use libffi::{ @@ -18,9 +21,10 @@ use libffi::{ middle::{Arg, Cif, Closure, CodePtr, Type}, }; use libloading::Symbol; -use 
num_traits::ToPrimitive; +use num_traits::{Signed, ToPrimitive}; use rustpython_common::lock::PyRwLock; -use std::ffi::{self, c_void}; +use std::borrow::Cow; +use std::ffi::c_void; use std::fmt::Debug; // Internal function addresses for special ctypes functions @@ -28,6 +32,90 @@ pub(super) const INTERNAL_CAST_ADDR: usize = 1; pub(super) const INTERNAL_STRING_AT_ADDR: usize = 2; pub(super) const INTERNAL_WSTRING_AT_ADDR: usize = 3; +// Thread-local errno storage for ctypes +std::thread_local! { + /// Thread-local storage for ctypes errno + /// This is separate from the system errno - ctypes swaps them during FFI calls + /// when use_errno=True is specified. + static CTYPES_LOCAL_ERRNO: std::cell::Cell = const { std::cell::Cell::new(0) }; +} + +/// Get ctypes thread-local errno value +pub(super) fn get_errno_value() -> i32 { + CTYPES_LOCAL_ERRNO.with(|e| e.get()) +} + +/// Set ctypes thread-local errno value, returns old value +pub(super) fn set_errno_value(value: i32) -> i32 { + CTYPES_LOCAL_ERRNO.with(|e| { + let old = e.get(); + e.set(value); + old + }) +} + +/// Save and restore errno around FFI call (called when use_errno=True) +/// Before: restore thread-local errno to system +/// After: save system errno to thread-local +#[cfg(not(windows))] +fn swap_errno(f: F) -> R +where + F: FnOnce() -> R, +{ + // Before call: restore thread-local errno to system + let saved = CTYPES_LOCAL_ERRNO.with(|e| e.get()); + errno::set_errno(errno::Errno(saved)); + + // Call the function + let result = f(); + + // After call: save system errno to thread-local + let new_error = errno::errno().0; + CTYPES_LOCAL_ERRNO.with(|e| e.set(new_error)); + + result +} + +#[cfg(windows)] +std::thread_local! 
{ + /// Thread-local storage for ctypes last_error (Windows only) + static CTYPES_LOCAL_LAST_ERROR: std::cell::Cell = const { std::cell::Cell::new(0) }; +} + +#[cfg(windows)] +pub(super) fn get_last_error_value() -> u32 { + CTYPES_LOCAL_LAST_ERROR.with(|e| e.get()) +} + +#[cfg(windows)] +pub(super) fn set_last_error_value(value: u32) -> u32 { + CTYPES_LOCAL_LAST_ERROR.with(|e| { + let old = e.get(); + e.set(value); + old + }) +} + +/// Save and restore last_error around FFI call (called when use_last_error=True) +#[cfg(windows)] +fn save_and_restore_last_error(f: F) -> R +where + F: FnOnce() -> R, +{ + // Before call: restore thread-local last_error to Windows + let saved = CTYPES_LOCAL_LAST_ERROR.with(|e| e.get()); + unsafe { windows_sys::Win32::Foundation::SetLastError(saved) }; + + // Call the function + let result = f(); + + // After call: save Windows last_error to thread-local + let new_error = unsafe { windows_sys::Win32::Foundation::GetLastError() }; + CTYPES_LOCAL_LAST_ERROR.with(|e| e.set(new_error)); + + result +} + type FP = unsafe extern "C" fn(); /// Get FFI type for a ctypes type code @@ -131,11 +219,19 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult direct value + // 7. Integer -> direct value (PyLong_AsVoidPtr behavior) if let Ok(int_val) = value.try_int(vm) { - return Ok(FfiArgValue::Pointer( - int_val.as_bigint().to_usize().unwrap_or(0), - )); + let bigint = int_val.as_bigint(); + // Negative values: use signed conversion (allows -1 as 0xFFFF...) + if bigint.is_negative() { + if let Some(signed_val) = bigint.to_isize() { + return Ok(FfiArgValue::Pointer(signed_val as usize)); + } + } else if let Some(unsigned_val) = bigint.to_usize() { + return Ok(FfiArgValue::Pointer(unsigned_val)); + } + // Value out of range - raise OverflowError + return Err(vm.new_overflow_error("int too large to convert to pointer".to_string())); } // 8. 
Check _as_parameter_ attribute ( recursive ConvParam) @@ -150,46 +246,86 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult PyResult<(Type, FfiArgValue)> { +/// Returns an Argument with FFI type, value, and optional keep object +fn conv_param(value: &PyObject, vm: &VirtualMachine) -> PyResult { // 1. CArgObject (from byref() or paramfunc) -> use stored type and value if let Some(carg) = value.downcast_ref::() { let ffi_type = ffi_type_from_tag(carg.tag); - return Ok((ffi_type, carg.value.clone())); + return Ok(Argument { + ffi_type, + keep: None, + value: carg.value.clone(), + }); } // 2. None -> NULL pointer if value.is(&vm.ctx.none) { - return Ok((Type::pointer(), FfiArgValue::Pointer(0))); + return Ok(Argument { + ffi_type: Type::pointer(), + keep: None, + value: FfiArgValue::Pointer(0), + }); } // 3. ctypes objects -> use paramfunc if let Ok(carg) = super::base::call_paramfunc(value, vm) { let ffi_type = ffi_type_from_tag(carg.tag); - return Ok((ffi_type, carg.value.clone())); + return Ok(Argument { + ffi_type, + keep: None, + value: carg.value.clone(), + }); } - // 4. Python str -> pointer (use internal UTF-8 buffer) + // 4. Python str -> wide string pointer (like PyUnicode_AsWideCharString) if let Some(s) = value.downcast_ref::() { - let addr = s.as_str().as_ptr() as usize; - return Ok((Type::pointer(), FfiArgValue::Pointer(addr))); + // Convert to null-terminated UTF-16 (wide string) + let wide: Vec = s + .as_str() + .encode_utf16() + .chain(std::iter::once(0)) + .collect(); + let wide_bytes: Vec = wide.iter().flat_map(|&x| x.to_ne_bytes()).collect(); + let keep = vm.ctx.new_bytes(wide_bytes); + let addr = keep.as_bytes().as_ptr() as usize; + return Ok(Argument { + ffi_type: Type::pointer(), + keep: Some(keep.into()), + value: FfiArgValue::Pointer(addr), + }); } - // 9. Python bytes -> pointer to buffer + // 9. 
Python bytes -> null-terminated buffer pointer + // Need to ensure null termination like c_char_p if let Some(bytes) = value.downcast_ref::() { - let addr = bytes.as_bytes().as_ptr() as usize; - return Ok((Type::pointer(), FfiArgValue::Pointer(addr))); + let mut buffer = bytes.as_bytes().to_vec(); + buffer.push(0); // Add null terminator + let keep = vm.ctx.new_bytes(buffer); + let addr = keep.as_bytes().as_ptr() as usize; + return Ok(Argument { + ffi_type: Type::pointer(), + keep: Some(keep.into()), + value: FfiArgValue::Pointer(addr), + }); } // 10. Python int -> i32 (default integer type) if let Ok(int_val) = value.try_int(vm) { let val = int_val.as_bigint().to_i32().unwrap_or(0); - return Ok((Type::i32(), FfiArgValue::I32(val))); + return Ok(Argument { + ffi_type: Type::i32(), + keep: None, + value: FfiArgValue::I32(val), + }); } // 11. Python float -> f64 if let Ok(float_val) = value.try_float(vm) { - return Ok((Type::f64(), FfiArgValue::F64(float_val.to_f64()))); + return Ok(Argument { + ffi_type: Type::f64(), + keep: None, + value: FfiArgValue::F64(float_val.to_f64()), + }); } // 12. 
Check _as_parameter_ attribute @@ -245,7 +381,7 @@ impl ArgumentType for PyTypeRef { } fn convert_object(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { - // Call from_param first to convert the value (like CPython's callproc.c:1235) + // Call from_param first to convert the value // converter = PyTuple_GET_ITEM(argtypes, i); // v = PyObject_CallOneArg(converter, arg); let from_param = self @@ -264,11 +400,10 @@ impl ArgumentType for PyTypeRef { return Ok(FfiArgValue::Pointer(0)); } - // For pointer types (POINTER(T)), we need to pass the ADDRESS of the value's buffer + // For pointer types (POINTER(T)), we need to pass the pointer VALUE stored in buffer if self.fast_issubclass(PyCPointer::static_type()) { - if let Some(cdata) = converted.downcast_ref::() { - let addr = cdata.buffer.read().as_ptr() as usize; - return Ok(FfiArgValue::Pointer(addr)); + if let Some(pointer) = converted.downcast_ref::() { + return Ok(FfiArgValue::Pointer(pointer.get_ptr_value())); } return convert_to_pointer(&converted, vm); } @@ -305,12 +440,6 @@ impl ArgumentType for PyTypeRef { trait ReturnType { fn to_ffi_type(&self, vm: &VirtualMachine) -> Option; - #[allow(clippy::wrong_self_convention)] - fn from_ffi_type( - &self, - value: *mut ffi::c_void, - vm: &VirtualMachine, - ) -> PyResult>; } impl ReturnType for PyTypeRef { @@ -343,130 +472,56 @@ impl ReturnType for PyTypeRef { // Fallback to class name get_ffi_type(self.name().to_string().as_str()) } - - fn from_ffi_type( - &self, - value: *mut ffi::c_void, - vm: &VirtualMachine, - ) -> PyResult> { - // Get the type code from _type_ attribute (use get_attr to traverse MRO) - let type_code = self - .as_object() - .get_attr(vm.ctx.intern_str("_type_"), vm) - .ok() - .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); - - let result = match type_code.as_deref() { - Some("b") => vm - .ctx - .new_int(unsafe { *(value as *const i8) } as i32) - .into(), - Some("B") => vm - .ctx - .new_int(unsafe { *(value as *const u8) } as 
i32) - .into(), - Some("c") => vm - .ctx - .new_bytes(vec![unsafe { *(value as *const u8) }]) - .into(), - Some("h") => vm - .ctx - .new_int(unsafe { *(value as *const i16) } as i32) - .into(), - Some("H") => vm - .ctx - .new_int(unsafe { *(value as *const u16) } as i32) - .into(), - Some("i") => vm.ctx.new_int(unsafe { *(value as *const i32) }).into(), - Some("I") => vm.ctx.new_int(unsafe { *(value as *const u32) }).into(), - Some("l") => vm - .ctx - .new_int(unsafe { *(value as *const libc::c_long) }) - .into(), - Some("L") => vm - .ctx - .new_int(unsafe { *(value as *const libc::c_ulong) }) - .into(), - Some("q") => vm - .ctx - .new_int(unsafe { *(value as *const libc::c_longlong) }) - .into(), - Some("Q") => vm - .ctx - .new_int(unsafe { *(value as *const libc::c_ulonglong) }) - .into(), - Some("f") => vm - .ctx - .new_float(unsafe { *(value as *const f32) } as f64) - .into(), - Some("d") => vm.ctx.new_float(unsafe { *(value as *const f64) }).into(), - Some("P") | Some("z") | Some("Z") => { - vm.ctx.new_int(unsafe { *(value as *const usize) }).into() - } - Some("?") => vm - .ctx - .new_bool(unsafe { *(value as *const u8) } != 0) - .into(), - None => { - // No _type_ attribute - check for Structure/Array types - // GetResult: PyCData_FromBaseObj creates instance from memory - if let Some(stg_info) = self.stg_info_opt() { - let size = stg_info.size; - // Create instance of the ctypes type - let instance = self.as_object().call((), vm)?; - - // Copy return value memory into instance buffer - // Use a block to properly scope the borrow - { - let src = unsafe { std::slice::from_raw_parts(value as *const u8, size) }; - if let Some(cdata) = instance.downcast_ref::() { - let mut buffer = cdata.buffer.write(); - if buffer.len() >= size { - buffer.to_mut()[..size].copy_from_slice(src); - } - } else if let Some(structure) = instance.downcast_ref::() { - let mut buffer = structure.0.buffer.write(); - if buffer.len() >= size { - buffer.to_mut()[..size].copy_from_slice(src); 
- } - } else if let Some(array) = instance.downcast_ref::() { - let mut buffer = array.0.buffer.write(); - if buffer.len() >= size { - buffer.to_mut()[..size].copy_from_slice(src); - } - } - } - return Ok(Some(instance)); - } - // Not a ctypes type - call type with int result - return self - .as_object() - .call((unsafe { *(value as *const i32) },), vm) - .map(Some); - } - _ => return Err(vm.new_type_error("Unsupported return type")), - }; - Ok(Some(result)) - } } impl ReturnType for PyNone { fn to_ffi_type(&self, _vm: &VirtualMachine) -> Option { get_ffi_type("void") } +} + +// PyCFuncPtrType - Metaclass for function pointer types +// PyCFuncPtrType_init + +#[pyclass(name = "PyCFuncPtrType", base = PyType, module = "_ctypes")] +#[derive(Debug)] +#[repr(transparent)] +pub(super) struct PyCFuncPtrType(PyType); - fn from_ffi_type( - &self, - _value: *mut ffi::c_void, - _vm: &VirtualMachine, - ) -> PyResult> { - Ok(None) +impl Initializer for PyCFuncPtrType { + type Args = FuncArgs; + + fn init(zelf: PyRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { + let obj: PyObjectRef = zelf.clone().into(); + let new_type: PyTypeRef = obj + .downcast() + .map_err(|_| vm.new_type_error("expected type"))?; + + new_type.check_not_initialized(vm)?; + + let ptr_size = std::mem::size_of::(); + let mut stg_info = StgInfo::new(ptr_size, ptr_size); + stg_info.format = Some("X{}".to_string()); + stg_info.length = 1; + stg_info.flags |= StgInfoFlags::TYPEFLAG_ISPOINTER; + stg_info.paramfunc = ParamFunc::Pointer; // CFuncPtr is passed as a pointer + + let _ = new_type.init_type_data(stg_info); + Ok(()) } } +#[pyclass(flags(IMMUTABLETYPE), with(Initializer))] +impl PyCFuncPtrType {} + /// PyCFuncPtr - Function pointer instance /// Saved in _base.buffer -#[pyclass(module = "_ctypes", name = "CFuncPtr", base = PyCData)] +#[pyclass( + module = "_ctypes", + name = "CFuncPtr", + base = PyCData, + metaclass = "PyCFuncPtrType" +)] #[repr(C)] pub(super) struct PyCFuncPtr { pub _base: 
PyCData, @@ -892,7 +947,13 @@ impl Constructor for PyCFuncPtr { .map_err(|err| err.to_string()) .map_err(|err| vm.new_attribute_error(err))? }; - *pointer as usize + let addr = *pointer as usize; + // dlsym can return NULL for symbols that resolve to NULL (e.g., GNU IFUNC) + // Treat NULL addresses as errors + if addr == 0 { + return Err(vm.new_attribute_error(format!("function '{}' not found", name))); + } + addr } else { 0 }; @@ -921,12 +982,17 @@ impl Constructor for PyCFuncPtr { // Get argument types and result type from the class let class_argtypes = cls.get_attr(vm.ctx.intern_str("_argtypes_")); let class_restype = cls.get_attr(vm.ctx.intern_str("_restype_")); + let class_flags = cls + .get_attr(vm.ctx.intern_str("_flags_")) + .and_then(|f| f.try_to_value::(vm).ok()) + .unwrap_or(0); // Create the thunk (C-callable wrapper for the Python function) let thunk = PyCThunk::new( first_arg.clone(), class_argtypes.clone(), class_restype.clone(), + class_flags, vm, )?; let code_ptr = thunk.code_ptr(); @@ -1060,13 +1126,16 @@ fn extract_call_info(zelf: &Py, vm: &VirtualMachine) -> PyResult().ok()) - .and_then(|t| t.as_object().get_attr(vm.ctx.intern_str("_type_"), vm).ok()) - .and_then(|t| t.downcast_ref::().map(|s| s.to_string())) - .is_some_and(|tc| matches!(tc.as_str(), "P" | "z" | "Z")); + .and_then(|t| { + t.stg_info_opt() + .map(|info| info.flags.contains(StgInfoFlags::TYPEFLAG_ISPOINTER)) + }) + .unwrap_or(false); Ok(CallInfo { explicit_arg_types, @@ -1178,13 +1247,18 @@ fn resolve_com_method( Ok((Some(CodePtr(fptr as *mut _)), true)) } -/// Prepared arguments for FFI call -struct PreparedArgs { - ffi_arg_types: Vec, - ffi_values: Vec, - out_buffers: Vec<(usize, PyObjectRef)>, +/// Single argument for FFI call +// struct argument +struct Argument { + ffi_type: Type, + value: FfiArgValue, + #[allow(dead_code)] + keep: Option, // Object to keep alive during call } +/// Out buffers for paramflags OUT parameters +type OutBuffers = Vec<(usize, PyObjectRef)>; + /// 
Get buffer address from a ctypes object fn get_buffer_addr(obj: &PyObjectRef) -> Option { obj.downcast_ref::() @@ -1213,18 +1287,16 @@ fn create_out_buffer(arg_type: &PyTypeRef, vm: &VirtualMachine) -> PyResult PyResult { - let results: Vec<(Type, FfiArgValue)> = args +fn build_callargs_no_argtypes( + args: &FuncArgs, + vm: &VirtualMachine, +) -> PyResult<(Vec, OutBuffers)> { + let arguments: Vec = args .args .iter() .map(|arg| conv_param(arg, vm)) .collect::>>()?; - let (ffi_arg_types, ffi_values) = results.into_iter().unzip(); - Ok(PreparedArgs { - ffi_arg_types, - ffi_values, - out_buffers: Vec::new(), - }) + Ok((arguments, Vec::new())) } /// Build callargs for regular function with argtypes (no paramflags) @@ -1232,12 +1304,8 @@ fn build_callargs_simple( args: &FuncArgs, arg_types: &[PyTypeRef], vm: &VirtualMachine, -) -> PyResult { - let ffi_arg_types = arg_types - .iter() - .map(|t| ArgumentType::to_ffi_type(t, vm)) - .collect::>>()?; - let ffi_values = args +) -> PyResult<(Vec, OutBuffers)> { + let arguments: Vec = args .args .iter() .enumerate() @@ -1245,14 +1313,16 @@ fn build_callargs_simple( let arg_type = arg_types .get(n) .ok_or_else(|| vm.new_type_error("argument amount mismatch"))?; - arg_type.convert_object(arg.clone(), vm) + let ffi_type = ArgumentType::to_ffi_type(arg_type, vm)?; + let value = arg_type.convert_object(arg.clone(), vm)?; + Ok(Argument { + ffi_type, + keep: None, + value, + }) }) - .collect::, _>>()?; - Ok(PreparedArgs { - ffi_arg_types, - ffi_values, - out_buffers: Vec::new(), - }) + .collect::>>()?; + Ok((arguments, Vec::new())) } /// Build callargs with paramflags (handles IN/OUT parameters) @@ -1262,27 +1332,21 @@ fn build_callargs_with_paramflags( paramflags: &ParsedParamFlags, skip_first_arg: bool, // true for COM methods vm: &VirtualMachine, -) -> PyResult { - let mut ffi_arg_types = Vec::new(); - let mut ffi_values = Vec::new(); +) -> PyResult<(Vec, OutBuffers)> { + let mut arguments = Vec::new(); let mut out_buffers = 
Vec::new(); // For COM methods, first arg is self (pointer) let mut caller_arg_idx = if skip_first_arg { - ffi_arg_types.push(Type::pointer()); if !args.args.is_empty() { - ffi_values.push(conv_param(&args.args[0], vm)?.1); + let arg = conv_param(&args.args[0], vm)?; + arguments.push(arg); } 1usize } else { 0usize }; - // Add FFI types for all argtypes - for arg_type in arg_types { - ffi_arg_types.push(ArgumentType::to_ffi_type(arg_type, vm)?); - } - // Process parameters based on paramflags for (param_idx, (direction, _name, default)) in paramflags.iter().enumerate() { let arg_type = arg_types @@ -1292,13 +1356,19 @@ fn build_callargs_with_paramflags( let is_out = (*direction & 2) != 0; // OUT flag let is_in = (*direction & 1) != 0 || *direction == 0; // IN flag or default + let ffi_type = ArgumentType::to_ffi_type(arg_type, vm)?; + if is_out && !is_in { // Pure OUT parameter: create buffer, don't consume caller arg let buffer = create_out_buffer(arg_type, vm)?; let addr = get_buffer_addr(&buffer).ok_or_else(|| { vm.new_type_error("Cannot create OUT buffer for this type".to_string()) })?; - ffi_values.push(FfiArgValue::Pointer(addr)); + arguments.push(Argument { + ffi_type, + keep: None, + value: FfiArgValue::Pointer(addr), + }); out_buffers.push((param_idx, buffer)); } else { // IN or IN|OUT: get from caller args or default @@ -1315,15 +1385,16 @@ fn build_callargs_with_paramflags( // IN|OUT: track for return out_buffers.push((param_idx, arg.clone())); } - ffi_values.push(arg_type.convert_object(arg, vm)?); + let value = arg_type.convert_object(arg, vm)?; + arguments.push(Argument { + ffi_type, + keep: None, + value, + }); } } - Ok(PreparedArgs { - ffi_arg_types, - ffi_values, - out_buffers, - }) + Ok((arguments, out_buffers)) } /// Build call arguments (main dispatcher) @@ -1333,7 +1404,7 @@ fn build_callargs( paramflags: Option<&ParsedParamFlags>, is_com_method: bool, vm: &VirtualMachine, -) -> PyResult { +) -> PyResult<(Vec, OutBuffers)> { let Some(ref 
arg_types) = call_info.explicit_arg_types else { // No argtypes: use ConvParam return build_callargs_no_argtypes(args, vm); @@ -1344,28 +1415,23 @@ fn build_callargs( build_callargs_with_paramflags(args, arg_types, pflags, is_com_method, vm) } else if is_com_method { // COM method without paramflags - let mut ffi_types = vec![Type::pointer()]; - ffi_types.extend( - arg_types - .iter() - .map(|t| ArgumentType::to_ffi_type(t, vm)) - .collect::>>()?, - ); - let mut ffi_vals = Vec::new(); + let mut arguments = Vec::new(); if !args.args.is_empty() { - ffi_vals.push(conv_param(&args.args[0], vm)?.1); + arguments.push(conv_param(&args.args[0], vm)?); } for (n, arg) in args.args.iter().skip(1).enumerate() { let arg_type = arg_types .get(n) .ok_or_else(|| vm.new_type_error("argument amount mismatch"))?; - ffi_vals.push(arg_type.convert_object(arg.clone(), vm)?); + let ffi_type = ArgumentType::to_ffi_type(arg_type, vm)?; + let value = arg_type.convert_object(arg.clone(), vm)?; + arguments.push(Argument { + ffi_type, + keep: None, + value, + }); } - Ok(PreparedArgs { - ffi_arg_types: ffi_types, - ffi_values: ffi_vals, - out_buffers: Vec::new(), - }) + Ok((arguments, Vec::new())) } else { // Regular function build_callargs_simple(args, arg_types, vm) @@ -1380,12 +1446,10 @@ enum RawResult { } /// Execute FFI call -fn ctypes_callproc(code_ptr: CodePtr, prepared: &PreparedArgs, call_info: &CallInfo) -> RawResult { - let cif = Cif::new( - prepared.ffi_arg_types.clone(), - call_info.ffi_return_type.clone(), - ); - let ffi_args: Vec = prepared.ffi_values.iter().map(|v| v.as_arg()).collect(); +fn ctypes_callproc(code_ptr: CodePtr, arguments: &[Argument], call_info: &CallInfo) -> RawResult { + let ffi_arg_types: Vec = arguments.iter().map(|a| a.ffi_type.clone()).collect(); + let cif = Cif::new(ffi_arg_types, call_info.ffi_return_type.clone()); + let ffi_args: Vec = arguments.iter().map(|a| a.value.as_arg()).collect(); if call_info.restype_is_none { unsafe { cif.call::<()>(code_ptr, 
&ffi_args) }; @@ -1438,73 +1502,118 @@ fn check_hresult(hresult: i32, zelf: &Py, vm: &VirtualMachine) -> Py } /// Convert raw FFI result to Python object +// = GetResult fn convert_raw_result( raw_result: &mut RawResult, call_info: &CallInfo, vm: &VirtualMachine, ) -> Option { - match raw_result { - RawResult::Void => None, + // Get result as bytes for type conversion + let (result_bytes, result_size) = match raw_result { + RawResult::Void => return None, RawResult::Pointer(ptr) => { - // Get type code from restype to determine conversion method - let type_code = call_info - .restype_obj - .as_ref() - .and_then(|t| t.clone().downcast::().ok()) - .and_then(|t| t.as_object().get_attr(vm.ctx.intern_str("_type_"), vm).ok()) - .and_then(|t| t.downcast_ref::().map(|s| s.to_string())); - - match type_code.as_deref() { - Some("z") => { - // c_char_p: NULL -> None, otherwise read C string -> bytes - if *ptr == 0 { - Some(vm.ctx.none()) - } else { - let cstr = unsafe { std::ffi::CStr::from_ptr(*ptr as _) }; - Some(vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into()) - } - } - Some("Z") => { - // c_wchar_p: NULL -> None, otherwise read wide string -> str - if *ptr == 0 { - Some(vm.ctx.none()) - } else { - let wstr_ptr = *ptr as *const libc::wchar_t; - let mut len = 0; - unsafe { - while *wstr_ptr.add(len) != 0 { - len += 1; - } - } - let slice = unsafe { std::slice::from_raw_parts(wstr_ptr, len) }; - let s: String = slice - .iter() - .filter_map(|&c| char::from_u32(c as u32)) - .collect(); - Some(vm.ctx.new_str(s).into()) - } - } - _ => { - // c_void_p ("P") and other pointer types: NULL -> None, otherwise int - if *ptr == 0 { - Some(vm.ctx.none()) - } else { - Some(vm.ctx.new_int(*ptr).into()) - } - } - } + let bytes = ptr.to_ne_bytes(); + (bytes.to_vec(), std::mem::size_of::()) } - RawResult::Value(val) => call_info - .restype_obj - .as_ref() - .and_then(|f| f.clone().downcast::().ok()) - .map(|f| { - f.from_ffi_type(val as *mut _ as *mut c_void, vm) - .ok() - .flatten() - }) 
- .unwrap_or_else(|| Some(vm.ctx.new_int(*val as usize).as_object().to_pyobject(vm))), + RawResult::Value(val) => { + let bytes = val.to_ne_bytes(); + (bytes.to_vec(), std::mem::size_of::()) + } + }; + + // 1. No restype → return as int + let restype = match &call_info.restype_obj { + None => { + // Default: return as int + let val = match raw_result { + RawResult::Pointer(p) => *p as isize, + RawResult::Value(v) => *v as isize, + RawResult::Void => return None, + }; + return Some(vm.ctx.new_int(val).into()); + } + Some(r) => r, + }; + + // 2. restype is None → return None + if restype.is(&vm.ctx.none()) { + return None; + } + + // 3. Get restype as PyType + let restype_type = match restype.clone().downcast::() { + Ok(t) => t, + Err(_) => { + // Not a type, call it with int result + let val = match raw_result { + RawResult::Pointer(p) => *p as isize, + RawResult::Value(v) => *v as isize, + RawResult::Void => return None, + }; + return restype.call((val,), vm).ok(); + } + }; + + // 4. Get StgInfo + let stg_info = restype_type.stg_info_opt(); + + // No StgInfo → call restype with int + if stg_info.is_none() { + let val = match raw_result { + RawResult::Pointer(p) => *p as isize, + RawResult::Value(v) => *v as isize, + RawResult::Void => return None, + }; + return restype_type.as_object().call((val,), vm).ok(); + } + + let info = stg_info.unwrap(); + + // 5. Simple type with getfunc → use bytes_to_pyobject (info->getfunc) + // is_simple_instance returns TRUE for c_int, c_void_p, etc. + if super::base::is_simple_instance(&restype_type) { + return super::base::bytes_to_pyobject(&restype_type, &result_bytes, vm).ok(); + } + + // 6. Complex type → create ctypes instance (PyCData_FromBaseObj) + // This handles POINTER(T), Structure, Array, etc. 
+ + // Special handling for POINTER(T) types - set pointer value directly + if info.flags.contains(StgInfoFlags::TYPEFLAG_ISPOINTER) + && info.proto.is_some() + && let RawResult::Pointer(ptr) = raw_result + && let Ok(instance) = restype_type.as_object().call((), vm) + { + if let Some(pointer) = instance.downcast_ref::() { + pointer.set_ptr_value(*ptr); + } + return Some(instance); } + + // Create instance and copy result data + pycdata_from_ffi_result(&restype_type, &result_bytes, result_size, vm).ok() +} + +/// Create a ctypes instance from FFI result (PyCData_FromBaseObj equivalent) +fn pycdata_from_ffi_result( + typ: &PyTypeRef, + result_bytes: &[u8], + size: usize, + vm: &VirtualMachine, +) -> PyResult { + // Create instance + let instance = PyType::call(typ, ().into(), vm)?; + + // Copy result data into instance buffer + if let Some(cdata) = instance.downcast_ref::() { + let mut buffer = cdata.buffer.write(); + let copy_size = size.min(buffer.len()).min(result_bytes.len()); + if copy_size > 0 { + buffer.to_mut()[..copy_size].copy_from_slice(&result_bytes[..copy_size]); + } + } + + Ok(instance) } /// Extract values from OUT buffers @@ -1522,7 +1631,7 @@ fn extract_out_values( fn build_result( mut raw_result: RawResult, call_info: &CallInfo, - prepared: PreparedArgs, + out_buffers: OutBuffers, zelf: &Py, args: &FuncArgs, vm: &VirtualMachine, @@ -1552,11 +1661,11 @@ fn build_result( } // Handle OUT parameter return values - if prepared.out_buffers.is_empty() { + if out_buffers.is_empty() { return result.map(Ok).unwrap_or_else(|| Ok(vm.ctx.none())); } - let out_values = extract_out_values(prepared.out_buffers, vm); + let out_values = extract_out_values(out_buffers, vm); Ok(match <[PyObjectRef; 1]>::try_from(out_values) { Ok([single]) => single, Err(v) => PyTuple::new_ref(v, &vm.ctx).into(), @@ -1584,23 +1693,43 @@ impl Callable for PyCFuncPtr { let paramflags = parse_paramflags(zelf, vm)?; // 5. 
Build call arguments - let prepared = build_callargs(&args, &call_info, paramflags.as_ref(), is_com_method, vm)?; + let (arguments, out_buffers) = + build_callargs(&args, &call_info, paramflags.as_ref(), is_com_method, vm)?; // 6. Get code pointer let code_ptr = match func_ptr.or_else(|| zelf.get_code_ptr()) { Some(cp) => cp, None => { debug_assert!(false, "NULL function pointer"); - // In release mode, this will crash like CPython + // In release mode, this will crash CodePtr(std::ptr::null_mut()) } }; - // 7. Call the function - let raw_result = ctypes_callproc(code_ptr, &prepared, &call_info); + // 7. Get flags to check for use_last_error/use_errno + let flags = PyCFuncPtr::_flags_(zelf, vm); - // 8. Build result - build_result(raw_result, &call_info, prepared, zelf, &args, vm) + // 8. Call the function (with use_last_error/use_errno handling) + #[cfg(not(windows))] + let raw_result = { + if flags & super::base::StgInfoFlags::FUNCFLAG_USE_ERRNO.bits() != 0 { + swap_errno(|| ctypes_callproc(code_ptr, &arguments, &call_info)) + } else { + ctypes_callproc(code_ptr, &arguments, &call_info) + } + }; + + #[cfg(windows)] + let raw_result = { + if flags & super::base::StgInfoFlags::FUNCFLAG_USE_LASTERROR.bits() != 0 { + save_and_restore_last_error(|| ctypes_callproc(code_ptr, &arguments, &call_info)) + } else { + ctypes_callproc(code_ptr, &arguments, &call_info) + } + }; + + // 9. 
Build result + build_result(raw_result, &call_info, out_buffers, zelf, &args, vm) } } @@ -1614,7 +1743,36 @@ impl Representable for PyCFuncPtr { } } -#[pyclass(flags(BASETYPE), with(Callable, Constructor, Representable))] +// PyCData_NewGetBuffer +impl AsBuffer for PyCFuncPtr { + fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // CFuncPtr types may not have StgInfo if PyCFuncPtrType metaclass is not used + // Use default values for function pointers: format="X{}", size=sizeof(pointer) + let (format, itemsize) = if let Some(stg_info) = zelf.class().stg_info_opt() { + ( + stg_info + .format + .clone() + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed("X{}")), + stg_info.size, + ) + } else { + (Cow::Borrowed("X{}"), std::mem::size_of::()) + }; + let desc = BufferDescriptor { + len: itemsize, + readonly: false, + itemsize, + format, + dim_desc: vec![], + }; + let buf = PyBuffer::new(zelf.to_owned().into(), desc, &CDATA_BUFFER_METHODS); + Ok(buf) + } +} + +#[pyclass(flags(BASETYPE), with(Callable, Constructor, Representable, AsBuffer))] impl PyCFuncPtr { // restype getter/setter #[pygetset] @@ -1685,7 +1843,6 @@ impl PyCFuncPtr { } // Fallback to StgInfo for native types - use super::base::StgInfoFlags; zelf.class() .stg_info_opt() .map(|stg| stg.flags.bits()) @@ -1708,7 +1865,9 @@ struct ThunkUserData { /// Argument types for conversion arg_types: Vec, /// Result type for conversion (None means void) - res_type: Option, + pub res_type: Option, + /// Function flags (FUNCFLAG_USE_ERRNO, etc.) + pub flags: u32, } /// Check if ty is a subclass of a simple type (like MyInt(c_int)). 
@@ -1758,11 +1917,23 @@ fn ffi_to_python(ty: &Py, ptr: *const c_void, vm: &VirtualMachine) -> Py len += 1; } let slice = std::slice::from_raw_parts(wstr_ptr, len); - let s: String = slice - .iter() - .filter_map(|&c| char::from_u32(c as u32)) - .collect(); - vm.ctx.new_str(s).into() + // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide + // Unix: wchar_t = i32 (UTF-32) -> convert via char::from_u32 + #[cfg(windows)] + { + use rustpython_common::wtf8::Wtf8Buf; + let wide: Vec = slice.to_vec(); + let wtf8 = Wtf8Buf::from_wide(&wide); + vm.ctx.new_str(wtf8).into() + } + #[cfg(not(windows))] + { + let s: String = slice + .iter() + .filter_map(|&c| char::from_u32(c as u32)) + .collect(); + vm.ctx.new_str(s).into() + } } } Some("P") => vm.ctx.new_int(*(ptr as *const usize)).into(), @@ -1865,6 +2036,16 @@ unsafe extern "C" fn thunk_callback( userdata: &ThunkUserData, ) { with_current_vm(|vm| { + // Swap errno before call if FUNCFLAG_USE_ERRNO is set + let use_errno = userdata.flags & StgInfoFlags::FUNCFLAG_USE_ERRNO.bits() != 0; + let saved_errno = if use_errno { + let current = rustpython_common::os::get_errno(); + // TODO: swap with ctypes stored errno (thread-local) + Some(current) + } else { + None + }; + let py_args: Vec = userdata .arg_types .iter() @@ -1877,6 +2058,15 @@ unsafe extern "C" fn thunk_callback( let py_result = userdata.callable.call(py_args, vm); + // Swap errno back after call + if use_errno { + let _current = rustpython_common::os::get_errno(); + // TODO: store current errno to ctypes storage + if let Some(saved) = saved_errno { + rustpython_common::os::set_errno(saved); + } + } + // Call unraisable hook if exception occurred if let Err(exc) = &py_result { let repr = userdata @@ -1935,6 +2125,7 @@ impl PyCThunk { callable: PyObjectRef, arg_types: Option, res_type: Option, + flags: u32, vm: &VirtualMachine, ) -> PyResult { let arg_type_vec: Vec = match arg_types { @@ -1979,6 +2170,7 @@ impl PyCThunk { callable: callable.clone(), arg_types: 
arg_type_vec, res_type: res_type_ref, + flags, }); let userdata_ptr = Box::into_raw(userdata); let userdata_ref: &'static ThunkUserData = unsafe { &*userdata_ptr }; diff --git a/crates/vm/src/stdlib/ctypes/library.rs b/crates/vm/src/stdlib/ctypes/library.rs index ec8ca91af0d..7512ce29d8a 100644 --- a/crates/vm/src/stdlib/ctypes/library.rs +++ b/crates/vm/src/stdlib/ctypes/library.rs @@ -2,12 +2,14 @@ use crate::VirtualMachine; use libloading::Library; use rustpython_common::lock::{PyMutex, PyRwLock}; use std::collections::HashMap; -use std::ffi::{OsStr, c_void}; +use std::ffi::OsStr; use std::fmt; -use std::ptr::null; -pub(super) struct SharedLibrary { - pub(super) lib: PyMutex>, +#[cfg(unix)] +use libloading::os::unix::Library as UnixLibrary; + +pub struct SharedLibrary { + pub(crate) lib: PyMutex>, } impl fmt::Debug for SharedLibrary { @@ -17,18 +19,44 @@ impl fmt::Debug for SharedLibrary { } impl SharedLibrary { - fn new(name: impl AsRef) -> Result { + #[cfg(windows)] + pub fn new(name: impl AsRef) -> Result { Ok(SharedLibrary { lib: PyMutex::new(unsafe { Some(Library::new(name.as_ref())?) 
}), }) } - fn get_pointer(&self) -> usize { + #[cfg(unix)] + pub fn new_with_mode( + name: impl AsRef, + mode: i32, + ) -> Result { + Ok(SharedLibrary { + lib: PyMutex::new(Some(unsafe { + UnixLibrary::open(Some(name.as_ref()), mode)?.into() + })), + }) + } + + /// Create a SharedLibrary from a raw dlopen handle (for pythonapi / dlopen(NULL)) + #[cfg(unix)] + pub fn from_raw_handle(handle: *mut libc::c_void) -> SharedLibrary { + SharedLibrary { + lib: PyMutex::new(Some(unsafe { UnixLibrary::from_raw(handle).into() })), + } + } + + /// Get the underlying OS handle (HMODULE on Windows, dlopen handle on Unix) + pub fn get_pointer(&self) -> usize { let lib_lock = self.lib.lock(); if let Some(l) = &*lib_lock { - l as *const Library as usize + // libloading::Library internally stores the OS handle directly + // On Windows: HMODULE (*mut c_void) + // On Unix: *mut c_void from dlopen + // We use transmute_copy to read the handle without consuming the Library + unsafe { std::mem::transmute_copy::(l) } } else { - null::() as usize + 0 } } @@ -36,16 +64,6 @@ impl SharedLibrary { let lib_lock = self.lib.lock(); lib_lock.is_none() } - - fn close(&self) { - *self.lib.lock() = None; - } -} - -impl Drop for SharedLibrary { - fn drop(&mut self) { - self.close(); - } } pub(super) struct ExternalLibs { @@ -63,6 +81,7 @@ impl ExternalLibs { self.libraries.get(&key) } + #[cfg(windows)] pub fn get_or_insert_lib( &mut self, library_path: impl AsRef, @@ -71,20 +90,53 @@ impl ExternalLibs { let new_lib = SharedLibrary::new(library_path)?; let key = new_lib.get_pointer(); - match self.libraries.get(&key) { - Some(l) => { - if l.is_closed() { - self.libraries.insert(key, new_lib); - } - } - _ => { - self.libraries.insert(key, new_lib); - } - }; + // Check if library already exists and is not closed + let should_use_cached = self.libraries.get(&key).is_some_and(|l| !l.is_closed()); + + if should_use_cached { + // new_lib will be dropped, calling FreeLibrary (decrements refcount) + // But 
library stays loaded because cached version maintains refcount + drop(new_lib); + return Ok((key, self.libraries.get(&key).expect("just checked"))); + } + self.libraries.insert(key, new_lib); Ok((key, self.libraries.get(&key).expect("just inserted"))) } + #[cfg(unix)] + pub fn get_or_insert_lib_with_mode( + &mut self, + library_path: impl AsRef, + mode: i32, + _vm: &VirtualMachine, + ) -> Result<(usize, &SharedLibrary), libloading::Error> { + let new_lib = SharedLibrary::new_with_mode(library_path, mode)?; + let key = new_lib.get_pointer(); + + // Check if library already exists and is not closed + let should_use_cached = self.libraries.get(&key).is_some_and(|l| !l.is_closed()); + + if should_use_cached { + // new_lib will be dropped, calling dlclose (decrements refcount) + // But library stays loaded because cached version maintains refcount + drop(new_lib); + return Ok((key, self.libraries.get(&key).expect("just checked"))); + } + + self.libraries.insert(key, new_lib); + Ok((key, self.libraries.get(&key).expect("just inserted"))) + } + + /// Insert a raw dlopen handle into the cache (for pythonapi / dlopen(NULL)) + #[cfg(unix)] + pub fn insert_raw_handle(&mut self, handle: *mut libc::c_void) -> usize { + let shared_lib = SharedLibrary::from_raw_handle(handle); + let key = handle as usize; + self.libraries.insert(key, shared_lib); + key + } + pub fn drop_lib(&mut self, key: usize) { self.libraries.remove(&key); } diff --git a/crates/vm/src/stdlib/ctypes/pointer.rs b/crates/vm/src/stdlib/ctypes/pointer.rs index 3ee39af3a7a..ae97b741b3c 100644 --- a/crates/vm/src/stdlib/ctypes/pointer.rs +++ b/crates/vm/src/stdlib/ctypes/pointer.rs @@ -1,6 +1,7 @@ +use super::base::CDATA_BUFFER_METHODS; use super::{PyCArray, PyCData, PyCSimple, PyCStructure, StgInfo, StgInfoFlags}; -use crate::protocol::PyNumberMethods; -use crate::types::{AsNumber, Constructor, Initializer}; +use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; +use crate::types::{AsBuffer, AsNumber, 
Constructor, Initializer}; use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef}, @@ -8,6 +9,7 @@ use crate::{ function::{FuncArgs, OptionalArg}, }; use num_traits::ToPrimitive; +use std::borrow::Cow; #[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")] #[derive(Debug)] @@ -42,11 +44,19 @@ impl Initializer for PyCPointerType { stg_info.length = 1; stg_info.flags |= StgInfoFlags::TYPEFLAG_ISPOINTER; - // Set format string: "&" - if let Some(ref proto) = stg_info.proto { - let item_info = proto.stg_info_opt().expect("proto has StgInfo"); + // Set format string: "&" or "&(shape)" for arrays + if let Some(ref proto) = stg_info.proto + && let Some(item_info) = proto.stg_info_opt() + { let current_format = item_info.format.as_deref().unwrap_or("B"); - stg_info.format = Some(format!("&{}", current_format)); + // Include shape for array types in the pointer format + let shape_str = if !item_info.shape.is_empty() { + let dims: Vec = item_info.shape.iter().map(|d| d.to_string()).collect(); + format!("({})", dims.join(",")) + } else { + String::new() + }; + stg_info.format = Some(format!("&{}{}", shape_str, current_format)); } let _ = new_type.init_type_data(stg_info); @@ -69,6 +79,15 @@ impl PyCPointerType { return Ok(value); } + // 1.5 CArgObject (from byref()) - check if underlying obj is instance of _type_ + if let Some(carg) = value.downcast_ref::() + && let Ok(type_attr) = cls.as_object().get_attr("_type_", vm) + && let Ok(type_ref) = type_attr.downcast::() + && carg.obj.is_instance(type_ref.as_object(), vm)? + { + return Ok(value); + } + // 2. If already an instance of the requested type, return it if value.is_instance(cls.as_object(), vm)? 
{ return Ok(value); @@ -149,10 +168,17 @@ impl PyCPointerType { if let Some(mut stg_info) = zelf.get_type_data_mut::() { stg_info.proto = Some(typ_type.clone()); - // Update format string: "&" + // Update format string: "&" or "&(shape)" for arrays let item_info = typ_type.stg_info_opt().expect("proto has StgInfo"); let current_format = item_info.format.as_deref().unwrap_or("B"); - stg_info.format = Some(format!("&{}", current_format)); + // Include shape for array types in the pointer format + let shape_str = if !item_info.shape.is_empty() { + let dims: Vec = item_info.shape.iter().map(|d| d.to_string()).collect(); + format!("({})", dims.join(",")) + } else { + String::new() + }; + stg_info.format = Some(format!("&{}{}", shape_str, current_format)); } // 4. Set _type_ attribute on the pointer type @@ -233,7 +259,10 @@ impl Initializer for PyCPointer { } } -#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor, Initializer))] +#[pyclass( + flags(BASETYPE, IMMUTABLETYPE), + with(Constructor, Initializer, AsBuffer) +)] impl PyCPointer { /// Get the pointer value stored in buffer as usize pub fn get_ptr_value(&self) -> usize { @@ -605,20 +634,47 @@ impl PyCPointer { unsafe { let ptr = addr as *const u8; match type_code { + // Single-byte types don't need read_unaligned Some("c") => Ok(vm.ctx.new_bytes(vec![*ptr]).into()), - Some("b") => Ok(vm.ctx.new_int(*(ptr as *const i8) as i32).into()), + Some("b") => Ok(vm.ctx.new_int(*ptr as i8 as i32).into()), Some("B") => Ok(vm.ctx.new_int(*ptr as i32).into()), - Some("h") => Ok(vm.ctx.new_int(*(ptr as *const i16) as i32).into()), - Some("H") => Ok(vm.ctx.new_int(*(ptr as *const u16) as i32).into()), - Some("i") | Some("l") => Ok(vm.ctx.new_int(*(ptr as *const i32)).into()), - Some("I") | Some("L") => Ok(vm.ctx.new_int(*(ptr as *const u32)).into()), - Some("q") => Ok(vm.ctx.new_int(*(ptr as *const i64)).into()), - Some("Q") => Ok(vm.ctx.new_int(*(ptr as *const u64)).into()), - Some("f") => Ok(vm.ctx.new_float(*(ptr as 
*const f32) as f64).into()), - Some("d") | Some("g") => Ok(vm.ctx.new_float(*(ptr as *const f64)).into()), - Some("P") | Some("z") | Some("Z") => { - Ok(vm.ctx.new_int(*(ptr as *const usize)).into()) - } + // Multi-byte types need read_unaligned for safety on strict-alignment architectures + Some("h") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const i16) as i32) + .into()), + Some("H") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const u16) as i32) + .into()), + Some("i") | Some("l") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const i32)) + .into()), + Some("I") | Some("L") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const u32)) + .into()), + Some("q") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const i64)) + .into()), + Some("Q") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const u64)) + .into()), + Some("f") => Ok(vm + .ctx + .new_float(std::ptr::read_unaligned(ptr as *const f32) as f64) + .into()), + Some("d") | Some("g") => Ok(vm + .ctx + .new_float(std::ptr::read_unaligned(ptr as *const f64)) + .into()), + Some("P") | Some("z") | Some("Z") => Ok(vm + .ctx + .new_int(std::ptr::read_unaligned(ptr as *const usize)) + .into()), _ => { // Default: read as bytes let bytes = std::slice::from_raw_parts(ptr, size).to_vec(); @@ -652,27 +708,37 @@ impl PyCPointer { "bytes/string or integer address expected".to_owned(), )); }; - *(ptr as *mut usize) = ptr_val; + std::ptr::write_unaligned(ptr as *mut usize, ptr_val); return Ok(()); } _ => {} } // Try to get value as integer + // Use write_unaligned for safety on strict-alignment architectures if let Ok(int_val) = value.try_int(vm) { let i = int_val.as_bigint(); match size { 1 => { - *ptr = i.to_u8().unwrap_or(0); + *ptr = i.to_u8().expect("int too large"); } 2 => { - *(ptr as *mut i16) = i.to_i16().unwrap_or(0); + std::ptr::write_unaligned( + ptr as *mut i16, + i.to_i16().expect("int too large"), + ); } 4 => { - *(ptr as *mut 
i32) = i.to_i32().unwrap_or(0); + std::ptr::write_unaligned( + ptr as *mut i32, + i.to_i32().expect("int too large"), + ); } 8 => { - *(ptr as *mut i64) = i.to_i64().unwrap_or(0); + std::ptr::write_unaligned( + ptr as *mut i64, + i.to_i64().expect("int too large"), + ); } _ => { let bytes = i.to_signed_bytes_le(); @@ -688,10 +754,10 @@ impl PyCPointer { let f = float_val.to_f64(); match size { 4 => { - *(ptr as *mut f32) = f as f32; + std::ptr::write_unaligned(ptr as *mut f32, f as f32); } 8 => { - *(ptr as *mut f64) = f; + std::ptr::write_unaligned(ptr as *mut f64, f); } _ => {} } @@ -712,3 +778,28 @@ impl PyCPointer { } } } + +impl AsBuffer for PyCPointer { + fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + let stg_info = zelf + .class() + .stg_info_opt() + .expect("PyCPointer type must have StgInfo"); + let format = stg_info + .format + .clone() + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed("&B")); + let itemsize = stg_info.size; + // Pointer types are scalars with ndim=0, shape=() + let desc = BufferDescriptor { + len: itemsize, + readonly: false, + itemsize, + format, + dim_desc: vec![], + }; + let buf = PyBuffer::new(zelf.to_owned().into(), desc, &CDATA_BUFFER_METHODS); + Ok(buf) + } +} diff --git a/crates/vm/src/stdlib/ctypes/simple.rs b/crates/vm/src/stdlib/ctypes/simple.rs index 1c0ec250d72..803b38d6e05 100644 --- a/crates/vm/src/stdlib/ctypes/simple.rs +++ b/crates/vm/src/stdlib/ctypes/simple.rs @@ -14,12 +14,47 @@ use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; use crate::types::{AsBuffer, AsNumber, Constructor, Initializer, Representable}; use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; use num_traits::ToPrimitive; +use std::borrow::Cow; use std::fmt::Debug; /// Valid type codes for ctypes simple types // spell-checker: disable-next-line pub(super) const SIMPLE_TYPE_CHARS: &str = "cbBhHiIlLdfuzZqQPXOv?g"; +/// Convert ctypes type code to PEP 3118 format code. 
+/// Some ctypes codes need to be mapped to standard-size codes based on platform. +/// _ctypes_alloc_format_string_for_type +fn ctypes_code_to_pep3118(code: char) -> char { + match code { + // c_int: map based on sizeof(int) + 'i' if std::mem::size_of::() == 2 => 'h', + 'i' if std::mem::size_of::() == 4 => 'i', + 'i' if std::mem::size_of::() == 8 => 'q', + 'I' if std::mem::size_of::() == 2 => 'H', + 'I' if std::mem::size_of::() == 4 => 'I', + 'I' if std::mem::size_of::() == 8 => 'Q', + // c_long: map based on sizeof(long) + 'l' if std::mem::size_of::() == 4 => 'l', + 'l' if std::mem::size_of::() == 8 => 'q', + 'L' if std::mem::size_of::() == 4 => 'L', + 'L' if std::mem::size_of::() == 8 => 'Q', + // c_bool: map based on sizeof(bool) - typically 1 byte on all platforms + '?' if std::mem::size_of::() == 1 => '?', + '?' if std::mem::size_of::() == 2 => 'H', + '?' if std::mem::size_of::() == 4 => 'L', + '?' if std::mem::size_of::() == 8 => 'Q', + // Default: use the same code + _ => code, + } +} + +/// _ctypes_alloc_format_string_for_type +fn alloc_format_string_for_type(code: char, big_endian: bool) -> String { + let prefix = if big_endian { ">" } else { "<" }; + let pep_code = ctypes_code_to_pep3118(code); + format!("{}{}", prefix, pep_code) +} + /// Create a new simple type instance from a class fn new_simple_type( cls: Either<&PyObject, &Py>, @@ -165,6 +200,20 @@ fn set_primitive(_type_: &str, value: &PyObject, vm: &VirtualMachine) -> PyResul } // O_set: py_object accepts any Python object "O" => Ok(value.to_owned()), + // X_set: BSTR - same as Z (c_wchar_p), accepts None, int, or str + "X" => { + if value.is(&vm.ctx.none) + || value.downcast_ref_if_exact::(vm).is_some() + || value.downcast_ref_if_exact::(vm).is_some() + { + Ok(value.to_owned()) + } else { + Err(vm.new_type_error(format!( + "unicode string or integer address expected instead of {} instance", + value.class().name() + ))) + } + } _ => { // "P" if value.downcast_ref_if_exact::(vm).is_some() @@ -313,7 
+362,7 @@ impl PyCSimpleType { .to_pyobject(vm)); } // 2. Array/Pointer with c_wchar element type - if is_cwchar_array_or_pointer(&value, vm) { + if is_cwchar_array_or_pointer(&value, vm)? { return Ok(value); } // 3. CArgObject (byref(c_wchar(...))) @@ -521,13 +570,10 @@ impl Initializer for PyCSimpleType { let mut stg_info = StgInfo::new(size, align); // Set format for PEP 3118 buffer protocol - // Format is endian prefix + type code (e.g., "" - }; - stg_info.format = Some(format!("{}{}", endian_prefix, type_str)); + stg_info.format = Some(alloc_format_string_for_type( + type_str.chars().next().unwrap_or('?'), + cfg!(target_endian = "big"), + )); stg_info.paramfunc = super::base::ParamFunc::Simple; // Set TYPEFLAG_ISPOINTER for pointer types: z (c_char_p), Z (c_wchar_p), @@ -619,13 +665,14 @@ fn create_swapped_types( swapped_type.set_attr("_swappedbytes_", vm.ctx.none(), vm)?; // Update swapped type's StgInfo format to use opposite endian prefix - // Native uses '<' on little-endian, '>' on big-endian - // Swapped uses the opposite if let Ok(swapped_type_ref) = swapped_type.clone().downcast::() && let Some(mut sw_stg) = swapped_type_ref.get_type_data_mut::() { - let swapped_prefix = if is_little_endian { ">" } else { "<" }; - sw_stg.format = Some(format!("{}{}", swapped_prefix, type_str)); + // Swapped: little-endian system uses big-endian prefix and vice versa + sw_stg.format = Some(alloc_format_string_for_type( + type_str.chars().next().unwrap_or('?'), + is_little_endian, + )); } // Set attributes based on system byte order @@ -734,63 +781,56 @@ fn value_to_bytes_endian( } "b" => { // c_byte - signed char (1 byte) - // PyLong_AsLongMask pattern: wrapping for overflow values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().unwrap_or(0) as i8; + let v = int_val.as_bigint().to_i128().expect("int too large") as i8; return vec![v as u8]; } vec![0] } "B" => { // c_ubyte - unsigned char (1 byte) - // PyLong_AsUnsignedLongMask: wrapping 
for negative values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().map(|n| n as u8).unwrap_or(0); + let v = int_val.as_bigint().to_i128().expect("int too large") as u8; return vec![v]; } vec![0] } "h" => { // c_short (2 bytes) - // PyLong_AsLongMask pattern: wrapping for overflow values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().unwrap_or(0) as i16; + let v = int_val.as_bigint().to_i128().expect("int too large") as i16; return to_bytes!(v); } vec![0; 2] } "H" => { // c_ushort (2 bytes) - // PyLong_AsUnsignedLongMask: wrapping for negative values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().map(|n| n as u16).unwrap_or(0); + let v = int_val.as_bigint().to_i128().expect("int too large") as u16; return to_bytes!(v); } vec![0; 2] } "i" => { // c_int (4 bytes) - // PyLong_AsLongMask pattern: wrapping for overflow values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().unwrap_or(0) as i32; + let v = int_val.as_bigint().to_i128().expect("int too large") as i32; return to_bytes!(v); } vec![0; 4] } "I" => { // c_uint (4 bytes) - // PyLong_AsUnsignedLongMask: wrapping for negative values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().map(|n| n as u32).unwrap_or(0); + let v = int_val.as_bigint().to_i128().expect("int too large") as u32; return to_bytes!(v); } vec![0; 4] } "l" => { // c_long (platform dependent) - // PyLong_AsLongMask pattern: wrapping for overflow values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().unwrap_or(0) as libc::c_long; + let v = int_val.as_bigint().to_i128().expect("int too large") as libc::c_long; return to_bytes!(v); } const SIZE: usize = std::mem::size_of::(); @@ -798,13 +838,8 @@ fn value_to_bytes_endian( } "L" => { // c_ulong (platform dependent) - // PyLong_AsUnsignedLongMask: wrapping for negative values if let Ok(int_val) = 
value.try_index(vm) { - let v = int_val - .as_bigint() - .to_i128() - .map(|n| n as libc::c_ulong) - .unwrap_or(0); + let v = int_val.as_bigint().to_i128().expect("int too large") as libc::c_ulong; return to_bytes!(v); } const SIZE: usize = std::mem::size_of::(); @@ -812,18 +847,16 @@ fn value_to_bytes_endian( } "q" => { // c_longlong (8 bytes) - // PyLong_AsLongMask pattern: wrapping for overflow values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().unwrap_or(0) as i64; + let v = int_val.as_bigint().to_i128().expect("int too large") as i64; return to_bytes!(v); } vec![0; 8] } "Q" => { // c_ulonglong (8 bytes) - // PyLong_AsUnsignedLongLongMask: wrapping for negative values if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_i128().map(|n| n as u64).unwrap_or(0); + let v = int_val.as_bigint().to_i128().expect("int too large") as u64; return to_bytes!(v); } vec![0; 8] @@ -899,7 +932,10 @@ fn value_to_bytes_endian( "P" => { // c_void_p - pointer type (platform pointer size) if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_usize().unwrap_or(0); + let v = int_val + .as_bigint() + .to_usize() + .expect("int too large for pointer"); return to_bytes!(v); } vec![0; std::mem::size_of::()] @@ -908,7 +944,10 @@ fn value_to_bytes_endian( // c_char_p - pointer to char (stores pointer value from int) // PyBytes case is handled in slot_new/set_value with make_z_buffer() if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_usize().unwrap_or(0); + let v = int_val + .as_bigint() + .to_usize() + .expect("int too large for pointer"); return to_bytes!(v); } vec![0; std::mem::size_of::()] @@ -917,7 +956,10 @@ fn value_to_bytes_endian( // c_wchar_p - pointer to wchar_t (stores pointer value from int) // PyStr case is handled in slot_new/set_value with make_wchar_buffer() if let Ok(int_val) = value.try_index(vm) { - let v = int_val.as_bigint().to_usize().unwrap_or(0); + let v = 
int_val + .as_bigint() + .to_usize() + .expect("int too large for pointer"); return to_bytes!(v); } vec![0; std::mem::size_of::()] @@ -939,7 +981,7 @@ fn is_cchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> bool { if let Some(arr) = value.downcast_ref::() && let Some(info) = arr.class().stg_info_opt() && let Some(ref elem_type) = info.element_type - && let Some(elem_code) = elem_type.class().type_code(vm) + && let Some(elem_code) = elem_type.type_code(vm) { return elem_code == "c"; } @@ -947,7 +989,7 @@ fn is_cchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> bool { if let Some(ptr) = value.downcast_ref::() && let Some(info) = ptr.class().stg_info_opt() && let Some(ref proto) = info.proto - && let Some(proto_code) = proto.class().type_code(vm) + && let Some(proto_code) = proto.type_code(vm) { return proto_code == "c"; } @@ -955,25 +997,25 @@ fn is_cchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> bool { } /// Check if value is a c_wchar array or pointer(c_wchar) -fn is_cwchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> bool { +fn is_cwchar_array_or_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult { // Check Array with c_wchar element type if let Some(arr) = value.downcast_ref::() { - let info = arr.class().stg_info_opt().expect("array has StgInfo"); + let info = arr.class().stg_info(vm)?; let elem_type = info.element_type.as_ref().expect("array has element_type"); - if let Some(elem_code) = elem_type.class().type_code(vm) { - return elem_code == "u"; + if let Some(elem_code) = elem_type.type_code(vm) { + return Ok(elem_code == "u"); } } // Check Pointer to c_wchar if let Some(ptr) = value.downcast_ref::() { - let info = ptr.class().stg_info_opt().expect("pointer has StgInfo"); + let info = ptr.class().stg_info(vm)?; if let Some(ref proto) = info.proto - && let Some(proto_code) = proto.class().type_code(vm) + && let Some(proto_code) = proto.type_code(vm) { - return proto_code == "u"; + return 
Ok(proto_code == "u"); } } - false + Ok(false) } impl Constructor for PyCSimple { @@ -1121,15 +1163,27 @@ impl PyCSimple { return Ok(vm.ctx.none()); } // Read null-terminated wide string at the address + // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide for surrogate pairs + // Unix: wchar_t = i32 (UTF-32) -> convert via char::from_u32 unsafe { let w_ptr = ptr as *const libc::wchar_t; let len = libc::wcslen(w_ptr); let wchars = std::slice::from_raw_parts(w_ptr, len); - let s: String = wchars - .iter() - .filter_map(|&c| char::from_u32(c as u32)) - .collect(); - return Ok(vm.ctx.new_str(s).into()); + #[cfg(windows)] + { + use rustpython_common::wtf8::Wtf8Buf; + let wide: Vec = wchars.to_vec(); + let wtf8 = Wtf8Buf::from_wide(&wide); + return Ok(vm.ctx.new_str(wtf8).into()); + } + #[cfg(not(windows))] + { + let s: String = wchars + .iter() + .filter_map(|&c| char::from_u32(c as u32)) + .collect(); + return Ok(vm.ctx.new_str(s).into()); + } } } @@ -1349,12 +1403,25 @@ impl PyCSimple { impl AsBuffer for PyCSimple { fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { - let buffer_len = zelf.0.buffer.read().len(); - let buf = PyBuffer::new( - zelf.to_owned().into(), - BufferDescriptor::simple(buffer_len, false), // readonly=false for ctypes - &CDATA_BUFFER_METHODS, - ); + let stg_info = zelf + .class() + .stg_info_opt() + .expect("PyCSimple type must have StgInfo"); + let format = stg_info + .format + .clone() + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed("B")); + let itemsize = stg_info.size; + // Simple types are scalars with ndim=0, shape=() + let desc = BufferDescriptor { + len: itemsize, + readonly: false, + itemsize, + format, + dim_desc: vec![], + }; + let buf = PyBuffer::new(zelf.to_owned().into(), desc, &CDATA_BUFFER_METHODS); Ok(buf) } } diff --git a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/ctypes/structure.rs index 10b8812e42c..1ca428669d4 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ 
b/crates/vm/src/stdlib/ctypes/structure.rs @@ -601,12 +601,6 @@ impl PyCStructure { fn _b0_(&self) -> Option { self.0.base.read().clone() } - - #[pygetset] - fn _fields_(&self, vm: &VirtualMachine) -> PyObjectRef { - // Return the _fields_ from the class, not instance - vm.ctx.none() - } } impl AsBuffer for PyCStructure { From 987216bcb8df3b3ce73056d66bee7dff66f6831c Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 16 Dec 2025 20:38:32 +0900 Subject: [PATCH 035/418] enable ctypes --- Lib/ctypes/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/Lib/ctypes/__init__.py b/Lib/ctypes/__init__.py index 3599e13ed28..2ef5e9a4fe7 100644 --- a/Lib/ctypes/__init__.py +++ b/Lib/ctypes/__init__.py @@ -36,9 +36,6 @@ FUNCFLAG_USE_ERRNO as _FUNCFLAG_USE_ERRNO, \ FUNCFLAG_USE_LASTERROR as _FUNCFLAG_USE_LASTERROR -# TODO: RUSTPYTHON remove this -from _ctypes import _non_existing_function - # WINOLEAPI -> HRESULT # WINOLEAPI_(type) # From 03d6f4634f86a1b9d12e85159bdd14361921c7ed Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 20 Dec 2025 14:31:01 +0900 Subject: [PATCH 036/418] Upgrade ctypes to Python v3.13.11 --- Lib/ctypes/__init__.py | 94 ++- Lib/ctypes/_endian.py | 8 +- Lib/ctypes/test/__init__.py | 16 - Lib/ctypes/test/__main__.py | 4 - Lib/ctypes/test/test_anon.py | 73 -- Lib/ctypes/test/test_array_in_pointer.py | 64 -- Lib/ctypes/test/test_arrays.py | 238 ------ Lib/ctypes/test/test_as_parameter.py | 231 ------ Lib/ctypes/test/test_bitfields.py | 297 ------- Lib/ctypes/test/test_buffers.py | 73 -- Lib/ctypes/test/test_bytes.py | 66 -- Lib/ctypes/test/test_byteswap.py | 364 --------- Lib/ctypes/test/test_callbacks.py | 333 -------- Lib/ctypes/test/test_cast.py | 99 --- Lib/ctypes/test/test_cfuncs.py | 218 ----- Lib/ctypes/test/test_checkretval.py | 36 - Lib/ctypes/test/test_delattr.py | 21 - Lib/ctypes/test/test_errno.py | 76 -- Lib/ctypes/test/test_find.py | 127 --- Lib/ctypes/test/test_frombuffer.py | 141 ---- Lib/ctypes/test/test_funcptr.py | 132 
--- Lib/ctypes/test/test_functions.py | 384 --------- Lib/ctypes/test/test_incomplete.py | 42 - Lib/ctypes/test/test_init.py | 40 - Lib/ctypes/test/test_internals.py | 100 --- Lib/ctypes/test/test_keeprefs.py | 153 ---- Lib/ctypes/test/test_libc.py | 33 - Lib/ctypes/test/test_loading.py | 182 ----- Lib/ctypes/test/test_macholib.py | 110 --- Lib/ctypes/test/test_memfunctions.py | 79 -- Lib/ctypes/test/test_numbers.py | 218 ----- Lib/ctypes/test/test_objects.py | 67 -- Lib/ctypes/test/test_parameters.py | 250 ------ Lib/ctypes/test/test_pep3118.py | 235 ------ Lib/ctypes/test/test_pickling.py | 81 -- Lib/ctypes/test/test_pointers.py | 223 ----- Lib/ctypes/test/test_prototypes.py | 222 ----- Lib/ctypes/test/test_python_api.py | 85 -- Lib/ctypes/test/test_random_things.py | 77 -- Lib/ctypes/test/test_refcounts.py | 116 --- Lib/ctypes/test/test_repr.py | 29 - Lib/ctypes/test/test_returnfuncptrs.py | 66 -- Lib/ctypes/test/test_simplesubclasses.py | 55 -- Lib/ctypes/test/test_sizes.py | 33 - Lib/ctypes/test/test_slicing.py | 167 ---- Lib/ctypes/test/test_stringptr.py | 77 -- Lib/ctypes/test/test_strings.py | 145 ---- Lib/ctypes/test/test_struct_fields.py | 97 --- Lib/ctypes/test/test_structures.py | 812 ------------------- Lib/ctypes/test/test_unaligned_structures.py | 43 - Lib/ctypes/test/test_unicode.py | 64 -- Lib/ctypes/test/test_values.py | 103 --- Lib/ctypes/test/test_varsize_struct.py | 50 -- Lib/ctypes/test/test_win32.py | 136 ---- Lib/ctypes/test/test_wintypes.py | 43 - Lib/ctypes/util.py | 18 +- 56 files changed, 70 insertions(+), 7276 deletions(-) delete mode 100644 Lib/ctypes/test/__init__.py delete mode 100644 Lib/ctypes/test/__main__.py delete mode 100644 Lib/ctypes/test/test_anon.py delete mode 100644 Lib/ctypes/test/test_array_in_pointer.py delete mode 100644 Lib/ctypes/test/test_arrays.py delete mode 100644 Lib/ctypes/test/test_as_parameter.py delete mode 100644 Lib/ctypes/test/test_bitfields.py delete mode 100644 Lib/ctypes/test/test_buffers.py delete 
mode 100644 Lib/ctypes/test/test_bytes.py delete mode 100644 Lib/ctypes/test/test_byteswap.py delete mode 100644 Lib/ctypes/test/test_callbacks.py delete mode 100644 Lib/ctypes/test/test_cast.py delete mode 100644 Lib/ctypes/test/test_cfuncs.py delete mode 100644 Lib/ctypes/test/test_checkretval.py delete mode 100644 Lib/ctypes/test/test_delattr.py delete mode 100644 Lib/ctypes/test/test_errno.py delete mode 100644 Lib/ctypes/test/test_find.py delete mode 100644 Lib/ctypes/test/test_frombuffer.py delete mode 100644 Lib/ctypes/test/test_funcptr.py delete mode 100644 Lib/ctypes/test/test_functions.py delete mode 100644 Lib/ctypes/test/test_incomplete.py delete mode 100644 Lib/ctypes/test/test_init.py delete mode 100644 Lib/ctypes/test/test_internals.py delete mode 100644 Lib/ctypes/test/test_keeprefs.py delete mode 100644 Lib/ctypes/test/test_libc.py delete mode 100644 Lib/ctypes/test/test_loading.py delete mode 100644 Lib/ctypes/test/test_macholib.py delete mode 100644 Lib/ctypes/test/test_memfunctions.py delete mode 100644 Lib/ctypes/test/test_numbers.py delete mode 100644 Lib/ctypes/test/test_objects.py delete mode 100644 Lib/ctypes/test/test_parameters.py delete mode 100644 Lib/ctypes/test/test_pep3118.py delete mode 100644 Lib/ctypes/test/test_pickling.py delete mode 100644 Lib/ctypes/test/test_pointers.py delete mode 100644 Lib/ctypes/test/test_prototypes.py delete mode 100644 Lib/ctypes/test/test_python_api.py delete mode 100644 Lib/ctypes/test/test_random_things.py delete mode 100644 Lib/ctypes/test/test_refcounts.py delete mode 100644 Lib/ctypes/test/test_repr.py delete mode 100644 Lib/ctypes/test/test_returnfuncptrs.py delete mode 100644 Lib/ctypes/test/test_simplesubclasses.py delete mode 100644 Lib/ctypes/test/test_sizes.py delete mode 100644 Lib/ctypes/test/test_slicing.py delete mode 100644 Lib/ctypes/test/test_stringptr.py delete mode 100644 Lib/ctypes/test/test_strings.py delete mode 100644 Lib/ctypes/test/test_struct_fields.py delete mode 100644 
Lib/ctypes/test/test_structures.py delete mode 100644 Lib/ctypes/test/test_unaligned_structures.py delete mode 100644 Lib/ctypes/test/test_unicode.py delete mode 100644 Lib/ctypes/test/test_values.py delete mode 100644 Lib/ctypes/test/test_varsize_struct.py delete mode 100644 Lib/ctypes/test/test_win32.py delete mode 100644 Lib/ctypes/test/test_wintypes.py diff --git a/Lib/ctypes/__init__.py b/Lib/ctypes/__init__.py index 2ef5e9a4fe7..80651dc64ce 100644 --- a/Lib/ctypes/__init__.py +++ b/Lib/ctypes/__init__.py @@ -1,6 +1,8 @@ """create and manipulate C data types in Python""" -import os as _os, sys as _sys +import os as _os +import sys as _sys +import sysconfig as _sysconfig import types as _types __version__ = "1.1.0" @@ -107,7 +109,7 @@ class CFunctionType(_CFuncPtr): return CFunctionType if _os.name == "nt": - from _ctypes import LoadLibrary as _dlopen + from _ctypes import LoadLibrary as _LoadLibrary from _ctypes import FUNCFLAG_STDCALL as _FUNCFLAG_STDCALL _win_functype_cache = {} @@ -302,8 +304,9 @@ def create_unicode_buffer(init, size=None): raise TypeError(init) -# XXX Deprecated def SetPointerType(pointer, cls): + import warnings + warnings._deprecated("ctypes.SetPointerType", remove=(3, 15)) if _pointer_type_cache.get(cls, None) is not None: raise RuntimeError("This type already exists in the cache") if id(pointer) not in _pointer_type_cache: @@ -312,7 +315,6 @@ def SetPointerType(pointer, cls): _pointer_type_cache[cls] = pointer del _pointer_type_cache[id(pointer)] -# XXX Deprecated def ARRAY(typ, len): return typ * len @@ -344,52 +346,59 @@ def __init__(self, name, mode=DEFAULT_MODE, handle=None, use_errno=False, use_last_error=False, winmode=None): + class _FuncPtr(_CFuncPtr): + _flags_ = self._func_flags_ + _restype_ = self._func_restype_ + if use_errno: + _flags_ |= _FUNCFLAG_USE_ERRNO + if use_last_error: + _flags_ |= _FUNCFLAG_USE_LASTERROR + + self._FuncPtr = _FuncPtr if name: name = _os.fspath(name) + self._handle = self._load_library(name, mode, 
handle, winmode) + + if _os.name == "nt": + def _load_library(self, name, mode, handle, winmode): + if winmode is None: + import nt as _nt + winmode = _nt._LOAD_LIBRARY_SEARCH_DEFAULT_DIRS + # WINAPI LoadLibrary searches for a DLL if the given name + # is not fully qualified with an explicit drive. For POSIX + # compatibility, and because the DLL search path no longer + # contains the working directory, begin by fully resolving + # any name that contains a path separator. + if name is not None and ('/' in name or '\\' in name): + name = _nt._getfullpathname(name) + winmode |= _nt._LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR + self._name = name + if handle is not None: + return handle + return _LoadLibrary(self._name, winmode) + + else: + def _load_library(self, name, mode, handle, winmode): # If the filename that has been provided is an iOS/tvOS/watchOS # .fwork file, dereference the location to the true origin of the # binary. - if name.endswith(".fwork"): + if name and name.endswith(".fwork"): with open(name) as f: name = _os.path.join( _os.path.dirname(_sys.executable), f.read().strip() ) - - self._name = name - flags = self._func_flags_ - if use_errno: - flags |= _FUNCFLAG_USE_ERRNO - if use_last_error: - flags |= _FUNCFLAG_USE_LASTERROR - if _sys.platform.startswith("aix"): - """When the name contains ".a(" and ends with ")", - e.g., "libFOO.a(libFOO.so)" - this is taken to be an - archive(member) syntax for dlopen(), and the mode is adjusted. - Otherwise, name is presented to dlopen() as a file argument. 
- """ - if name and name.endswith(")") and ".a(" in name: - mode |= ( _os.RTLD_MEMBER | _os.RTLD_NOW ) - if _os.name == "nt": - if winmode is not None: - mode = winmode - else: - import nt - mode = nt._LOAD_LIBRARY_SEARCH_DEFAULT_DIRS - if '/' in name or '\\' in name: - self._name = nt._getfullpathname(self._name) - mode |= nt._LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR - - class _FuncPtr(_CFuncPtr): - _flags_ = flags - _restype_ = self._func_restype_ - self._FuncPtr = _FuncPtr - - if handle is None: - self._handle = _dlopen(self._name, mode) - else: - self._handle = handle + if _sys.platform.startswith("aix"): + """When the name contains ".a(" and ends with ")", + e.g., "libFOO.a(libFOO.so)" - this is taken to be an + archive(member) syntax for dlopen(), and the mode is adjusted. + Otherwise, name is presented to dlopen() as a file argument. + """ + if name and name.endswith(")") and ".a(" in name: + mode |= _os.RTLD_MEMBER | _os.RTLD_NOW + self._name = name + return _dlopen(name, mode) def __repr__(self): return "<%s '%s', handle %x at %#x>" % \ @@ -477,10 +486,9 @@ def LoadLibrary(self, name): if _os.name == "nt": pythonapi = PyDLL("python dll", None, _sys.dllhandle) -elif _sys.platform == "android": - pythonapi = PyDLL("libpython%d.%d.so" % _sys.version_info[:2]) -elif _sys.platform == "cygwin": - pythonapi = PyDLL("libpython%d.%d.dll" % _sys.version_info[:2]) +elif _sys.platform in ["android", "cygwin"]: + # These are Unix-like platforms which use a dynamically-linked libpython. 
+ pythonapi = PyDLL(_sysconfig.get_config_var("LDLIBRARY")) else: pythonapi = PyDLL(None) diff --git a/Lib/ctypes/_endian.py b/Lib/ctypes/_endian.py index 34dee64b1a6..6382dd22b8a 100644 --- a/Lib/ctypes/_endian.py +++ b/Lib/ctypes/_endian.py @@ -1,5 +1,5 @@ import sys -from ctypes import * +from ctypes import Array, Structure, Union _array_type = type(Array) @@ -15,8 +15,8 @@ def _other_endian(typ): # if typ is array if isinstance(typ, _array_type): return _other_endian(typ._type_) * typ._length_ - # if typ is structure - if issubclass(typ, Structure): + # if typ is structure or union + if issubclass(typ, (Structure, Union)): return typ raise TypeError("This type does not support other endian: %s" % typ) @@ -37,7 +37,7 @@ class _swapped_union_meta(_swapped_meta, type(Union)): pass ################################################################ # Note: The Structure metaclass checks for the *presence* (not the -# value!) of a _swapped_bytes_ attribute to determine the bit order in +# value!) of a _swappedbytes_ attribute to determine the bit order in # structures containing bit fields. 
if sys.byteorder == "little": diff --git a/Lib/ctypes/test/__init__.py b/Lib/ctypes/test/__init__.py deleted file mode 100644 index 6e496fa5a52..00000000000 --- a/Lib/ctypes/test/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -import os -import unittest -from test import support -from test.support import import_helper - - -# skip tests if _ctypes was not built -ctypes = import_helper.import_module('ctypes') -ctypes_symbols = dir(ctypes) - -def need_symbol(name): - return unittest.skipUnless(name in ctypes_symbols, - '{!r} is required'.format(name)) - -def load_tests(*args): - return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/ctypes/test/__main__.py b/Lib/ctypes/test/__main__.py deleted file mode 100644 index 362a9ec8cff..00000000000 --- a/Lib/ctypes/test/__main__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ctypes.test import load_tests -import unittest - -unittest.main() diff --git a/Lib/ctypes/test/test_anon.py b/Lib/ctypes/test/test_anon.py deleted file mode 100644 index d378392ebe2..00000000000 --- a/Lib/ctypes/test/test_anon.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest -import test.support -from ctypes import * - -class AnonTest(unittest.TestCase): - - def test_anon(self): - class ANON(Union): - _fields_ = [("a", c_int), - ("b", c_int)] - - class Y(Structure): - _fields_ = [("x", c_int), - ("_", ANON), - ("y", c_int)] - _anonymous_ = ["_"] - - self.assertEqual(Y.a.offset, sizeof(c_int)) - self.assertEqual(Y.b.offset, sizeof(c_int)) - - self.assertEqual(ANON.a.offset, 0) - self.assertEqual(ANON.b.offset, 0) - - def test_anon_nonseq(self): - # TypeError: _anonymous_ must be a sequence - self.assertRaises(TypeError, - lambda: type(Structure)("Name", - (Structure,), - {"_fields_": [], "_anonymous_": 42})) - - def test_anon_nonmember(self): - # AttributeError: type object 'Name' has no attribute 'x' - self.assertRaises(AttributeError, - lambda: type(Structure)("Name", - (Structure,), - {"_fields_": [], - "_anonymous_": ["x"]})) - - 
@test.support.cpython_only - def test_issue31490(self): - # There shouldn't be an assertion failure in case the class has an - # attribute whose name is specified in _anonymous_ but not in _fields_. - - # AttributeError: 'x' is specified in _anonymous_ but not in _fields_ - with self.assertRaises(AttributeError): - class Name(Structure): - _fields_ = [] - _anonymous_ = ["x"] - x = 42 - - def test_nested(self): - class ANON_S(Structure): - _fields_ = [("a", c_int)] - - class ANON_U(Union): - _fields_ = [("_", ANON_S), - ("b", c_int)] - _anonymous_ = ["_"] - - class Y(Structure): - _fields_ = [("x", c_int), - ("_", ANON_U), - ("y", c_int)] - _anonymous_ = ["_"] - - self.assertEqual(Y.x.offset, 0) - self.assertEqual(Y.a.offset, sizeof(c_int)) - self.assertEqual(Y.b.offset, sizeof(c_int)) - self.assertEqual(Y._.offset, sizeof(c_int)) - self.assertEqual(Y.y.offset, sizeof(c_int) * 2) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_array_in_pointer.py b/Lib/ctypes/test/test_array_in_pointer.py deleted file mode 100644 index ca1edcf6210..00000000000 --- a/Lib/ctypes/test/test_array_in_pointer.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest -from ctypes import * -from binascii import hexlify -import re - -def dump(obj): - # helper function to dump memory contents in hex, with a hyphen - # between the bytes. - h = hexlify(memoryview(obj)).decode() - return re.sub(r"(..)", r"\1-", h)[:-1] - - -class Value(Structure): - _fields_ = [("val", c_byte)] - -class Container(Structure): - _fields_ = [("pvalues", POINTER(Value))] - -class Test(unittest.TestCase): - def test(self): - # create an array of 4 values - val_array = (Value * 4)() - - # create a container, which holds a pointer to the pvalues array. 
- c = Container() - c.pvalues = val_array - - # memory contains 4 NUL bytes now, that's correct - self.assertEqual("00-00-00-00", dump(val_array)) - - # set the values of the array through the pointer: - for i in range(4): - c.pvalues[i].val = i + 1 - - values = [c.pvalues[i].val for i in range(4)] - - # These are the expected results: here s the bug! - self.assertEqual( - (values, dump(val_array)), - ([1, 2, 3, 4], "01-02-03-04") - ) - - def test_2(self): - - val_array = (Value * 4)() - - # memory contains 4 NUL bytes now, that's correct - self.assertEqual("00-00-00-00", dump(val_array)) - - ptr = cast(val_array, POINTER(Value)) - # set the values of the array through the pointer: - for i in range(4): - ptr[i].val = i + 1 - - values = [ptr[i].val for i in range(4)] - - # These are the expected results: here s the bug! - self.assertEqual( - (values, dump(val_array)), - ([1, 2, 3, 4], "01-02-03-04") - ) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_arrays.py b/Lib/ctypes/test/test_arrays.py deleted file mode 100644 index 14603b7049c..00000000000 --- a/Lib/ctypes/test/test_arrays.py +++ /dev/null @@ -1,238 +0,0 @@ -import unittest -from test.support import bigmemtest, _2G -import sys -from ctypes import * - -from ctypes.test import need_symbol - -formats = "bBhHiIlLqQfd" - -formats = c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, \ - c_long, c_ulonglong, c_float, c_double, c_longdouble - -class ArrayTestCase(unittest.TestCase): - def test_simple(self): - # create classes holding simple numeric types, and check - # various properties. - - init = list(range(15, 25)) - - for fmt in formats: - alen = len(init) - int_array = ARRAY(fmt, alen) - - ia = int_array(*init) - # length of instance ok? - self.assertEqual(len(ia), alen) - - # slot values ok? 
- values = [ia[i] for i in range(alen)] - self.assertEqual(values, init) - - # out-of-bounds accesses should be caught - with self.assertRaises(IndexError): ia[alen] - with self.assertRaises(IndexError): ia[-alen-1] - - # change the items - from operator import setitem - new_values = list(range(42, 42+alen)) - [setitem(ia, n, new_values[n]) for n in range(alen)] - values = [ia[i] for i in range(alen)] - self.assertEqual(values, new_values) - - # are the items initialized to 0? - ia = int_array() - values = [ia[i] for i in range(alen)] - self.assertEqual(values, [0] * alen) - - # Too many initializers should be caught - self.assertRaises(IndexError, int_array, *range(alen*2)) - - CharArray = ARRAY(c_char, 3) - - ca = CharArray(b"a", b"b", b"c") - - # Should this work? It doesn't: - # CharArray("abc") - self.assertRaises(TypeError, CharArray, "abc") - - self.assertEqual(ca[0], b"a") - self.assertEqual(ca[1], b"b") - self.assertEqual(ca[2], b"c") - self.assertEqual(ca[-3], b"a") - self.assertEqual(ca[-2], b"b") - self.assertEqual(ca[-1], b"c") - - self.assertEqual(len(ca), 3) - - # cannot delete items - from operator import delitem - self.assertRaises(TypeError, delitem, ca, 0) - - def test_step_overflow(self): - a = (c_int * 5)() - a[3::sys.maxsize] = (1,) - self.assertListEqual(a[3::sys.maxsize], [1]) - a = (c_char * 5)() - a[3::sys.maxsize] = b"A" - self.assertEqual(a[3::sys.maxsize], b"A") - a = (c_wchar * 5)() - a[3::sys.maxsize] = u"X" - self.assertEqual(a[3::sys.maxsize], u"X") - - def test_numeric_arrays(self): - - alen = 5 - - numarray = ARRAY(c_int, alen) - - na = numarray() - values = [na[i] for i in range(alen)] - self.assertEqual(values, [0] * alen) - - na = numarray(*[c_int()] * alen) - values = [na[i] for i in range(alen)] - self.assertEqual(values, [0]*alen) - - na = numarray(1, 2, 3, 4, 5) - values = [i for i in na] - self.assertEqual(values, [1, 2, 3, 4, 5]) - - na = numarray(*map(c_int, (1, 2, 3, 4, 5))) - values = [i for i in na] - 
self.assertEqual(values, [1, 2, 3, 4, 5]) - - def test_classcache(self): - self.assertIsNot(ARRAY(c_int, 3), ARRAY(c_int, 4)) - self.assertIs(ARRAY(c_int, 3), ARRAY(c_int, 3)) - - def test_from_address(self): - # Failed with 0.9.8, reported by JUrner - p = create_string_buffer(b"foo") - sz = (c_char * 3).from_address(addressof(p)) - self.assertEqual(sz[:], b"foo") - self.assertEqual(sz[::], b"foo") - self.assertEqual(sz[::-1], b"oof") - self.assertEqual(sz[::3], b"f") - self.assertEqual(sz[1:4:2], b"o") - self.assertEqual(sz.value, b"foo") - - @need_symbol('create_unicode_buffer') - def test_from_addressW(self): - p = create_unicode_buffer("foo") - sz = (c_wchar * 3).from_address(addressof(p)) - self.assertEqual(sz[:], "foo") - self.assertEqual(sz[::], "foo") - self.assertEqual(sz[::-1], "oof") - self.assertEqual(sz[::3], "f") - self.assertEqual(sz[1:4:2], "o") - self.assertEqual(sz.value, "foo") - - def test_cache(self): - # Array types are cached internally in the _ctypes extension, - # in a WeakValueDictionary. Make sure the array type is - # removed from the cache when the itemtype goes away. This - # test will not fail, but will show a leak in the testsuite. 
- - # Create a new type: - class my_int(c_int): - pass - # Create a new array type based on it: - t1 = my_int * 1 - t2 = my_int * 1 - self.assertIs(t1, t2) - - def test_subclass(self): - class T(Array): - _type_ = c_int - _length_ = 13 - class U(T): - pass - class V(U): - pass - class W(V): - pass - class X(T): - _type_ = c_short - class Y(T): - _length_ = 187 - - for c in [T, U, V, W]: - self.assertEqual(c._type_, c_int) - self.assertEqual(c._length_, 13) - self.assertEqual(c()._type_, c_int) - self.assertEqual(c()._length_, 13) - - self.assertEqual(X._type_, c_short) - self.assertEqual(X._length_, 13) - self.assertEqual(X()._type_, c_short) - self.assertEqual(X()._length_, 13) - - self.assertEqual(Y._type_, c_int) - self.assertEqual(Y._length_, 187) - self.assertEqual(Y()._type_, c_int) - self.assertEqual(Y()._length_, 187) - - def test_bad_subclass(self): - with self.assertRaises(AttributeError): - class T(Array): - pass - with self.assertRaises(AttributeError): - class T(Array): - _type_ = c_int - with self.assertRaises(AttributeError): - class T(Array): - _length_ = 13 - - def test_bad_length(self): - with self.assertRaises(ValueError): - class T(Array): - _type_ = c_int - _length_ = - sys.maxsize * 2 - with self.assertRaises(ValueError): - class T(Array): - _type_ = c_int - _length_ = -1 - with self.assertRaises(TypeError): - class T(Array): - _type_ = c_int - _length_ = 1.87 - with self.assertRaises(OverflowError): - class T(Array): - _type_ = c_int - _length_ = sys.maxsize * 2 - - def test_zero_length(self): - # _length_ can be zero. 
- class T(Array): - _type_ = c_int - _length_ = 0 - - def test_empty_element_struct(self): - class EmptyStruct(Structure): - _fields_ = [] - - obj = (EmptyStruct * 2)() # bpo37188: Floating point exception - self.assertEqual(sizeof(obj), 0) - - def test_empty_element_array(self): - class EmptyArray(Array): - _type_ = c_int - _length_ = 0 - - obj = (EmptyArray * 2)() # bpo37188: Floating point exception - self.assertEqual(sizeof(obj), 0) - - def test_bpo36504_signed_int_overflow(self): - # The overflow check in PyCArrayType_new() could cause signed integer - # overflow. - with self.assertRaises(OverflowError): - c_char * sys.maxsize * 2 - - @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') - @bigmemtest(size=_2G, memuse=1, dry_run=False) - def test_large_array(self, size): - c_char * size - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_as_parameter.py b/Lib/ctypes/test/test_as_parameter.py deleted file mode 100644 index 9c39179d2a4..00000000000 --- a/Lib/ctypes/test/test_as_parameter.py +++ /dev/null @@ -1,231 +0,0 @@ -import unittest -from ctypes import * -from ctypes.test import need_symbol -import _ctypes_test - -dll = CDLL(_ctypes_test.__file__) - -try: - CALLBACK_FUNCTYPE = WINFUNCTYPE -except NameError: - # fake to enable this test on Linux - CALLBACK_FUNCTYPE = CFUNCTYPE - -class POINT(Structure): - _fields_ = [("x", c_int), ("y", c_int)] - -class BasicWrapTestCase(unittest.TestCase): - def wrap(self, param): - return param - - @need_symbol('c_wchar') - def test_wchar_parm(self): - f = dll._testfunc_i_bhilfd - f.argtypes = [c_byte, c_wchar, c_int, c_long, c_float, c_double] - result = f(self.wrap(1), self.wrap("x"), self.wrap(3), self.wrap(4), self.wrap(5.0), self.wrap(6.0)) - self.assertEqual(result, 139) - self.assertIs(type(result), int) - - def test_pointers(self): - f = dll._testfunc_p_p - f.restype = POINTER(c_int) - f.argtypes = [POINTER(c_int)] - - # This only works if the value c_int(42) passed to 
the - # function is still alive while the pointer (the result) is - # used. - - v = c_int(42) - - self.assertEqual(pointer(v).contents.value, 42) - result = f(self.wrap(pointer(v))) - self.assertEqual(type(result), POINTER(c_int)) - self.assertEqual(result.contents.value, 42) - - # This on works... - result = f(self.wrap(pointer(v))) - self.assertEqual(result.contents.value, v.value) - - p = pointer(c_int(99)) - result = f(self.wrap(p)) - self.assertEqual(result.contents.value, 99) - - def test_shorts(self): - f = dll._testfunc_callback_i_if - - args = [] - expected = [262144, 131072, 65536, 32768, 16384, 8192, 4096, 2048, - 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1] - - def callback(v): - args.append(v) - return v - - CallBack = CFUNCTYPE(c_int, c_int) - - cb = CallBack(callback) - f(self.wrap(2**18), self.wrap(cb)) - self.assertEqual(args, expected) - - ################################################################ - - def test_callbacks(self): - f = dll._testfunc_callback_i_if - f.restype = c_int - f.argtypes = None - - MyCallback = CFUNCTYPE(c_int, c_int) - - def callback(value): - #print "called back with", value - return value - - cb = MyCallback(callback) - - result = f(self.wrap(-10), self.wrap(cb)) - self.assertEqual(result, -18) - - # test with prototype - f.argtypes = [c_int, MyCallback] - cb = MyCallback(callback) - - result = f(self.wrap(-10), self.wrap(cb)) - self.assertEqual(result, -18) - - result = f(self.wrap(-10), self.wrap(cb)) - self.assertEqual(result, -18) - - AnotherCallback = CALLBACK_FUNCTYPE(c_int, c_int, c_int, c_int, c_int) - - # check that the prototype works: we call f with wrong - # argument types - cb = AnotherCallback(callback) - self.assertRaises(ArgumentError, f, self.wrap(-10), self.wrap(cb)) - - def test_callbacks_2(self): - # Can also use simple datatypes as argument type specifiers - # for the callback function. 
- # In this case the call receives an instance of that type - f = dll._testfunc_callback_i_if - f.restype = c_int - - MyCallback = CFUNCTYPE(c_int, c_int) - - f.argtypes = [c_int, MyCallback] - - def callback(value): - #print "called back with", value - self.assertEqual(type(value), int) - return value - - cb = MyCallback(callback) - result = f(self.wrap(-10), self.wrap(cb)) - self.assertEqual(result, -18) - - @need_symbol('c_longlong') - def test_longlong_callbacks(self): - - f = dll._testfunc_callback_q_qf - f.restype = c_longlong - - MyCallback = CFUNCTYPE(c_longlong, c_longlong) - - f.argtypes = [c_longlong, MyCallback] - - def callback(value): - self.assertIsInstance(value, int) - return value & 0x7FFFFFFF - - cb = MyCallback(callback) - - self.assertEqual(13577625587, int(f(self.wrap(1000000000000), self.wrap(cb)))) - - def test_byval(self): - # without prototype - ptin = POINT(1, 2) - ptout = POINT() - # EXPORT int _testfunc_byval(point in, point *pout) - result = dll._testfunc_byval(ptin, byref(ptout)) - got = result, ptout.x, ptout.y - expected = 3, 1, 2 - self.assertEqual(got, expected) - - # with prototype - ptin = POINT(101, 102) - ptout = POINT() - dll._testfunc_byval.argtypes = (POINT, POINTER(POINT)) - dll._testfunc_byval.restype = c_int - result = dll._testfunc_byval(self.wrap(ptin), byref(ptout)) - got = result, ptout.x, ptout.y - expected = 203, 101, 102 - self.assertEqual(got, expected) - - def test_struct_return_2H(self): - class S2H(Structure): - _fields_ = [("x", c_short), - ("y", c_short)] - dll.ret_2h_func.restype = S2H - dll.ret_2h_func.argtypes = [S2H] - inp = S2H(99, 88) - s2h = dll.ret_2h_func(self.wrap(inp)) - self.assertEqual((s2h.x, s2h.y), (99*2, 88*3)) - - # Test also that the original struct was unmodified (i.e. 
was passed by - # value) - self.assertEqual((inp.x, inp.y), (99, 88)) - - def test_struct_return_8H(self): - class S8I(Structure): - _fields_ = [("a", c_int), - ("b", c_int), - ("c", c_int), - ("d", c_int), - ("e", c_int), - ("f", c_int), - ("g", c_int), - ("h", c_int)] - dll.ret_8i_func.restype = S8I - dll.ret_8i_func.argtypes = [S8I] - inp = S8I(9, 8, 7, 6, 5, 4, 3, 2) - s8i = dll.ret_8i_func(self.wrap(inp)) - self.assertEqual((s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h), - (9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9)) - - def test_recursive_as_param(self): - from ctypes import c_int - - class A(object): - pass - - a = A() - a._as_parameter_ = a - with self.assertRaises(RecursionError): - c_int.from_param(a) - - -#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -class AsParamWrapper(object): - def __init__(self, param): - self._as_parameter_ = param - -class AsParamWrapperTestCase(BasicWrapTestCase): - wrap = AsParamWrapper - -#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -class AsParamPropertyWrapper(object): - def __init__(self, param): - self._param = param - - def getParameter(self): - return self._param - _as_parameter_ = property(getParameter) - -class AsParamPropertyWrapperTestCase(BasicWrapTestCase): - wrap = AsParamPropertyWrapper - -#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_bitfields.py b/Lib/ctypes/test/test_bitfields.py deleted file mode 100644 index 66acd62e685..00000000000 --- a/Lib/ctypes/test/test_bitfields.py +++ /dev/null @@ -1,297 +0,0 @@ -from ctypes import * -from ctypes.test import need_symbol -from test import support -import unittest -import os - -import _ctypes_test - -class BITS(Structure): - _fields_ = [("A", c_int, 1), - ("B", c_int, 2), - ("C", c_int, 3), - ("D", c_int, 4), - ("E", c_int, 5), - ("F", c_int, 6), - ("G", c_int, 7), - ("H", c_int, 8), - ("I", c_int, 9), - - ("M", c_short, 1), - ("N", c_short, 2), - ("O", 
c_short, 3), - ("P", c_short, 4), - ("Q", c_short, 5), - ("R", c_short, 6), - ("S", c_short, 7)] - -func = CDLL(_ctypes_test.__file__).unpack_bitfields -func.argtypes = POINTER(BITS), c_char - -##for n in "ABCDEFGHIMNOPQRS": -## print n, hex(getattr(BITS, n).size), getattr(BITS, n).offset - -class C_Test(unittest.TestCase): - - def test_ints(self): - for i in range(512): - for name in "ABCDEFGHI": - b = BITS() - setattr(b, name, i) - self.assertEqual(getattr(b, name), func(byref(b), name.encode('ascii'))) - - # bpo-46913: _ctypes/cfield.c h_get() has an undefined behavior - @support.skip_if_sanitizer(ub=True) - def test_shorts(self): - b = BITS() - name = "M" - if func(byref(b), name.encode('ascii')) == 999: - self.skipTest("Compiler does not support signed short bitfields") - for i in range(256): - for name in "MNOPQRS": - b = BITS() - setattr(b, name, i) - self.assertEqual(getattr(b, name), func(byref(b), name.encode('ascii'))) - -signed_int_types = (c_byte, c_short, c_int, c_long, c_longlong) -unsigned_int_types = (c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong) -int_types = unsigned_int_types + signed_int_types - -class BitFieldTest(unittest.TestCase): - - def test_longlong(self): - class X(Structure): - _fields_ = [("a", c_longlong, 1), - ("b", c_longlong, 62), - ("c", c_longlong, 1)] - - self.assertEqual(sizeof(X), sizeof(c_longlong)) - x = X() - x.a, x.b, x.c = -1, 7, -1 - self.assertEqual((x.a, x.b, x.c), (-1, 7, -1)) - - def test_ulonglong(self): - class X(Structure): - _fields_ = [("a", c_ulonglong, 1), - ("b", c_ulonglong, 62), - ("c", c_ulonglong, 1)] - - self.assertEqual(sizeof(X), sizeof(c_longlong)) - x = X() - self.assertEqual((x.a, x.b, x.c), (0, 0, 0)) - x.a, x.b, x.c = 7, 7, 7 - self.assertEqual((x.a, x.b, x.c), (1, 7, 1)) - - def test_signed(self): - for c_typ in signed_int_types: - class X(Structure): - _fields_ = [("dummy", c_typ), - ("a", c_typ, 3), - ("b", c_typ, 3), - ("c", c_typ, 1)] - self.assertEqual(sizeof(X), sizeof(c_typ)*2) - - x = 
X() - self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0)) - x.a = -1 - self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, -1, 0, 0)) - x.a, x.b = 0, -1 - self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, -1, 0)) - - - def test_unsigned(self): - for c_typ in unsigned_int_types: - class X(Structure): - _fields_ = [("a", c_typ, 3), - ("b", c_typ, 3), - ("c", c_typ, 1)] - self.assertEqual(sizeof(X), sizeof(c_typ)) - - x = X() - self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0)) - x.a = -1 - self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 7, 0, 0)) - x.a, x.b = 0, -1 - self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 7, 0)) - - - def fail_fields(self, *fields): - return self.get_except(type(Structure), "X", (), - {"_fields_": fields}) - - def test_nonint_types(self): - # bit fields are not allowed on non-integer types. - result = self.fail_fields(("a", c_char_p, 1)) - self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char_p')) - - result = self.fail_fields(("a", c_void_p, 1)) - self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_void_p')) - - if c_int != c_long: - result = self.fail_fields(("a", POINTER(c_int), 1)) - self.assertEqual(result, (TypeError, 'bit fields not allowed for type LP_c_int')) - - result = self.fail_fields(("a", c_char, 1)) - self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char')) - - class Dummy(Structure): - _fields_ = [] - - result = self.fail_fields(("a", Dummy, 1)) - self.assertEqual(result, (TypeError, 'bit fields not allowed for type Dummy')) - - @need_symbol('c_wchar') - def test_c_wchar(self): - result = self.fail_fields(("a", c_wchar, 1)) - self.assertEqual(result, - (TypeError, 'bit fields not allowed for type c_wchar')) - - def test_single_bitfield_size(self): - for c_typ in int_types: - result = self.fail_fields(("a", c_typ, -1)) - self.assertEqual(result, (ValueError, 'number of bits invalid for bit field')) - - result = self.fail_fields(("a", c_typ, 
0)) - self.assertEqual(result, (ValueError, 'number of bits invalid for bit field')) - - class X(Structure): - _fields_ = [("a", c_typ, 1)] - self.assertEqual(sizeof(X), sizeof(c_typ)) - - class X(Structure): - _fields_ = [("a", c_typ, sizeof(c_typ)*8)] - self.assertEqual(sizeof(X), sizeof(c_typ)) - - result = self.fail_fields(("a", c_typ, sizeof(c_typ)*8 + 1)) - self.assertEqual(result, (ValueError, 'number of bits invalid for bit field')) - - def test_multi_bitfields_size(self): - class X(Structure): - _fields_ = [("a", c_short, 1), - ("b", c_short, 14), - ("c", c_short, 1)] - self.assertEqual(sizeof(X), sizeof(c_short)) - - class X(Structure): - _fields_ = [("a", c_short, 1), - ("a1", c_short), - ("b", c_short, 14), - ("c", c_short, 1)] - self.assertEqual(sizeof(X), sizeof(c_short)*3) - self.assertEqual(X.a.offset, 0) - self.assertEqual(X.a1.offset, sizeof(c_short)) - self.assertEqual(X.b.offset, sizeof(c_short)*2) - self.assertEqual(X.c.offset, sizeof(c_short)*2) - - class X(Structure): - _fields_ = [("a", c_short, 3), - ("b", c_short, 14), - ("c", c_short, 14)] - self.assertEqual(sizeof(X), sizeof(c_short)*3) - self.assertEqual(X.a.offset, sizeof(c_short)*0) - self.assertEqual(X.b.offset, sizeof(c_short)*1) - self.assertEqual(X.c.offset, sizeof(c_short)*2) - - - def get_except(self, func, *args, **kw): - try: - func(*args, **kw) - except Exception as detail: - return detail.__class__, str(detail) - - def test_mixed_1(self): - class X(Structure): - _fields_ = [("a", c_byte, 4), - ("b", c_int, 4)] - if os.name == "nt": - self.assertEqual(sizeof(X), sizeof(c_int)*2) - else: - self.assertEqual(sizeof(X), sizeof(c_int)) - - def test_mixed_2(self): - class X(Structure): - _fields_ = [("a", c_byte, 4), - ("b", c_int, 32)] - self.assertEqual(sizeof(X), alignment(c_int)+sizeof(c_int)) - - def test_mixed_3(self): - class X(Structure): - _fields_ = [("a", c_byte, 4), - ("b", c_ubyte, 4)] - self.assertEqual(sizeof(X), sizeof(c_byte)) - - def test_mixed_4(self): - class 
X(Structure): - _fields_ = [("a", c_short, 4), - ("b", c_short, 4), - ("c", c_int, 24), - ("d", c_short, 4), - ("e", c_short, 4), - ("f", c_int, 24)] - # MSVC does NOT combine c_short and c_int into one field, GCC - # does (unless GCC is run with '-mms-bitfields' which - # produces code compatible with MSVC). - if os.name == "nt": - self.assertEqual(sizeof(X), sizeof(c_int) * 4) - else: - self.assertEqual(sizeof(X), sizeof(c_int) * 2) - - def test_anon_bitfields(self): - # anonymous bit-fields gave a strange error message - class X(Structure): - _fields_ = [("a", c_byte, 4), - ("b", c_ubyte, 4)] - class Y(Structure): - _anonymous_ = ["_"] - _fields_ = [("_", X)] - - @need_symbol('c_uint32') - def test_uint32(self): - class X(Structure): - _fields_ = [("a", c_uint32, 32)] - x = X() - x.a = 10 - self.assertEqual(x.a, 10) - x.a = 0xFDCBA987 - self.assertEqual(x.a, 0xFDCBA987) - - @need_symbol('c_uint64') - def test_uint64(self): - class X(Structure): - _fields_ = [("a", c_uint64, 64)] - x = X() - x.a = 10 - self.assertEqual(x.a, 10) - x.a = 0xFEDCBA9876543211 - self.assertEqual(x.a, 0xFEDCBA9876543211) - - @need_symbol('c_uint32') - def test_uint32_swap_little_endian(self): - # Issue #23319 - class Little(LittleEndianStructure): - _fields_ = [("a", c_uint32, 24), - ("b", c_uint32, 4), - ("c", c_uint32, 4)] - b = bytearray(4) - x = Little.from_buffer(b) - x.a = 0xabcdef - x.b = 1 - x.c = 2 - self.assertEqual(b, b'\xef\xcd\xab\x21') - - @need_symbol('c_uint32') - def test_uint32_swap_big_endian(self): - # Issue #23319 - class Big(BigEndianStructure): - _fields_ = [("a", c_uint32, 24), - ("b", c_uint32, 4), - ("c", c_uint32, 4)] - b = bytearray(4) - x = Big.from_buffer(b) - x.a = 0xabcdef - x.b = 1 - x.c = 2 - self.assertEqual(b, b'\xab\xcd\xef\x12') - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_buffers.py b/Lib/ctypes/test/test_buffers.py deleted file mode 100644 index 15782be757c..00000000000 --- a/Lib/ctypes/test/test_buffers.py +++ 
/dev/null @@ -1,73 +0,0 @@ -from ctypes import * -from ctypes.test import need_symbol -import unittest - -class StringBufferTestCase(unittest.TestCase): - - def test_buffer(self): - b = create_string_buffer(32) - self.assertEqual(len(b), 32) - self.assertEqual(sizeof(b), 32 * sizeof(c_char)) - self.assertIs(type(b[0]), bytes) - - b = create_string_buffer(b"abc") - self.assertEqual(len(b), 4) # trailing nul char - self.assertEqual(sizeof(b), 4 * sizeof(c_char)) - self.assertIs(type(b[0]), bytes) - self.assertEqual(b[0], b"a") - self.assertEqual(b[:], b"abc\0") - self.assertEqual(b[::], b"abc\0") - self.assertEqual(b[::-1], b"\0cba") - self.assertEqual(b[::2], b"ac") - self.assertEqual(b[::5], b"a") - - self.assertRaises(TypeError, create_string_buffer, "abc") - - def test_buffer_interface(self): - self.assertEqual(len(bytearray(create_string_buffer(0))), 0) - self.assertEqual(len(bytearray(create_string_buffer(1))), 1) - - @need_symbol('c_wchar') - def test_unicode_buffer(self): - b = create_unicode_buffer(32) - self.assertEqual(len(b), 32) - self.assertEqual(sizeof(b), 32 * sizeof(c_wchar)) - self.assertIs(type(b[0]), str) - - b = create_unicode_buffer("abc") - self.assertEqual(len(b), 4) # trailing nul char - self.assertEqual(sizeof(b), 4 * sizeof(c_wchar)) - self.assertIs(type(b[0]), str) - self.assertEqual(b[0], "a") - self.assertEqual(b[:], "abc\0") - self.assertEqual(b[::], "abc\0") - self.assertEqual(b[::-1], "\0cba") - self.assertEqual(b[::2], "ac") - self.assertEqual(b[::5], "a") - - self.assertRaises(TypeError, create_unicode_buffer, b"abc") - - @need_symbol('c_wchar') - def test_unicode_conversion(self): - b = create_unicode_buffer("abc") - self.assertEqual(len(b), 4) # trailing nul char - self.assertEqual(sizeof(b), 4 * sizeof(c_wchar)) - self.assertIs(type(b[0]), str) - self.assertEqual(b[0], "a") - self.assertEqual(b[:], "abc\0") - self.assertEqual(b[::], "abc\0") - self.assertEqual(b[::-1], "\0cba") - self.assertEqual(b[::2], "ac") - 
self.assertEqual(b[::5], "a") - - @need_symbol('c_wchar') - def test_create_unicode_buffer_non_bmp(self): - expected = 5 if sizeof(c_wchar) == 2 else 3 - for s in '\U00010000\U00100000', '\U00010000\U0010ffff': - b = create_unicode_buffer(s) - self.assertEqual(len(b), expected) - self.assertEqual(b[-1], '\0') - - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_bytes.py b/Lib/ctypes/test/test_bytes.py deleted file mode 100644 index 092ec5af052..00000000000 --- a/Lib/ctypes/test/test_bytes.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Test where byte objects are accepted""" -import unittest -import sys -from ctypes import * - -class BytesTest(unittest.TestCase): - def test_c_char(self): - x = c_char(b"x") - self.assertRaises(TypeError, c_char, "x") - x.value = b"y" - with self.assertRaises(TypeError): - x.value = "y" - c_char.from_param(b"x") - self.assertRaises(TypeError, c_char.from_param, "x") - self.assertIn('xbd', repr(c_char.from_param(b"\xbd"))) - (c_char * 3)(b"a", b"b", b"c") - self.assertRaises(TypeError, c_char * 3, "a", "b", "c") - - def test_c_wchar(self): - x = c_wchar("x") - self.assertRaises(TypeError, c_wchar, b"x") - x.value = "y" - with self.assertRaises(TypeError): - x.value = b"y" - c_wchar.from_param("x") - self.assertRaises(TypeError, c_wchar.from_param, b"x") - (c_wchar * 3)("a", "b", "c") - self.assertRaises(TypeError, c_wchar * 3, b"a", b"b", b"c") - - def test_c_char_p(self): - c_char_p(b"foo bar") - self.assertRaises(TypeError, c_char_p, "foo bar") - - def test_c_wchar_p(self): - c_wchar_p("foo bar") - self.assertRaises(TypeError, c_wchar_p, b"foo bar") - - def test_struct(self): - class X(Structure): - _fields_ = [("a", c_char * 3)] - - x = X(b"abc") - self.assertRaises(TypeError, X, "abc") - self.assertEqual(x.a, b"abc") - self.assertEqual(type(x.a), bytes) - - def test_struct_W(self): - class X(Structure): - _fields_ = [("a", c_wchar * 3)] - - x = X("abc") - self.assertRaises(TypeError, X, b"abc") - 
self.assertEqual(x.a, "abc") - self.assertEqual(type(x.a), str) - - @unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') - def test_BSTR(self): - from _ctypes import _SimpleCData - class BSTR(_SimpleCData): - _type_ = "X" - - BSTR("abc") - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_byteswap.py b/Lib/ctypes/test/test_byteswap.py deleted file mode 100644 index 7e98559dfbc..00000000000 --- a/Lib/ctypes/test/test_byteswap.py +++ /dev/null @@ -1,364 +0,0 @@ -import sys, unittest, struct, math, ctypes -from binascii import hexlify - -from ctypes import * - -def bin(s): - return hexlify(memoryview(s)).decode().upper() - -# Each *simple* type that supports different byte orders has an -# __ctype_be__ attribute that specifies the same type in BIG ENDIAN -# byte order, and a __ctype_le__ attribute that is the same type in -# LITTLE ENDIAN byte order. -# -# For Structures and Unions, these types are created on demand. - -class Test(unittest.TestCase): - @unittest.skip('test disabled') - def test_X(self): - print(sys.byteorder, file=sys.stderr) - for i in range(32): - bits = BITS() - setattr(bits, "i%s" % i, 1) - dump(bits) - - def test_slots(self): - class BigPoint(BigEndianStructure): - __slots__ = () - _fields_ = [("x", c_int), ("y", c_int)] - - class LowPoint(LittleEndianStructure): - __slots__ = () - _fields_ = [("x", c_int), ("y", c_int)] - - big = BigPoint() - little = LowPoint() - big.x = 4 - big.y = 2 - little.x = 2 - little.y = 4 - with self.assertRaises(AttributeError): - big.z = 42 - with self.assertRaises(AttributeError): - little.z = 24 - - def test_endian_short(self): - if sys.byteorder == "little": - self.assertIs(c_short.__ctype_le__, c_short) - self.assertIs(c_short.__ctype_be__.__ctype_le__, c_short) - else: - self.assertIs(c_short.__ctype_be__, c_short) - self.assertIs(c_short.__ctype_le__.__ctype_be__, c_short) - s = c_short.__ctype_be__(0x1234) - self.assertEqual(bin(struct.pack(">h", 0x1234)), 
"1234") - self.assertEqual(bin(s), "1234") - self.assertEqual(s.value, 0x1234) - - s = c_short.__ctype_le__(0x1234) - self.assertEqual(bin(struct.pack("h", 0x1234)), "1234") - self.assertEqual(bin(s), "1234") - self.assertEqual(s.value, 0x1234) - - s = c_ushort.__ctype_le__(0x1234) - self.assertEqual(bin(struct.pack("i", 0x12345678)), "12345678") - self.assertEqual(bin(s), "12345678") - self.assertEqual(s.value, 0x12345678) - - s = c_int.__ctype_le__(0x12345678) - self.assertEqual(bin(struct.pack("I", 0x12345678)), "12345678") - self.assertEqual(bin(s), "12345678") - self.assertEqual(s.value, 0x12345678) - - s = c_uint.__ctype_le__(0x12345678) - self.assertEqual(bin(struct.pack("q", 0x1234567890ABCDEF)), "1234567890ABCDEF") - self.assertEqual(bin(s), "1234567890ABCDEF") - self.assertEqual(s.value, 0x1234567890ABCDEF) - - s = c_longlong.__ctype_le__(0x1234567890ABCDEF) - self.assertEqual(bin(struct.pack("Q", 0x1234567890ABCDEF)), "1234567890ABCDEF") - self.assertEqual(bin(s), "1234567890ABCDEF") - self.assertEqual(s.value, 0x1234567890ABCDEF) - - s = c_ulonglong.__ctype_le__(0x1234567890ABCDEF) - self.assertEqual(bin(struct.pack("f", math.pi)), bin(s)) - - def test_endian_double(self): - if sys.byteorder == "little": - self.assertIs(c_double.__ctype_le__, c_double) - self.assertIs(c_double.__ctype_be__.__ctype_le__, c_double) - else: - self.assertIs(c_double.__ctype_be__, c_double) - self.assertIs(c_double.__ctype_le__.__ctype_be__, c_double) - s = c_double(math.pi) - self.assertEqual(s.value, math.pi) - self.assertEqual(bin(struct.pack("d", math.pi)), bin(s)) - s = c_double.__ctype_le__(math.pi) - self.assertEqual(s.value, math.pi) - self.assertEqual(bin(struct.pack("d", math.pi)), bin(s)) - - def test_endian_other(self): - self.assertIs(c_byte.__ctype_le__, c_byte) - self.assertIs(c_byte.__ctype_be__, c_byte) - - self.assertIs(c_ubyte.__ctype_le__, c_ubyte) - self.assertIs(c_ubyte.__ctype_be__, c_ubyte) - - self.assertIs(c_char.__ctype_le__, c_char) - 
self.assertIs(c_char.__ctype_be__, c_char) - - def test_struct_fields_unsupported_byte_order(self): - - fields = [ - ("a", c_ubyte), - ("b", c_byte), - ("c", c_short), - ("d", c_ushort), - ("e", c_int), - ("f", c_uint), - ("g", c_long), - ("h", c_ulong), - ("i", c_longlong), - ("k", c_ulonglong), - ("l", c_float), - ("m", c_double), - ("n", c_char), - ("b1", c_byte, 3), - ("b2", c_byte, 3), - ("b3", c_byte, 2), - ("a", c_int * 3 * 3 * 3) - ] - - # these fields do not support different byte order: - for typ in c_wchar, c_void_p, POINTER(c_int): - with self.assertRaises(TypeError): - class T(BigEndianStructure if sys.byteorder == "little" else LittleEndianStructure): - _fields_ = fields + [("x", typ)] - - - def test_struct_struct(self): - # nested structures with different byteorders - - # create nested structures with given byteorders and set memory to data - - for nested, data in ( - (BigEndianStructure, b'\0\0\0\1\0\0\0\2'), - (LittleEndianStructure, b'\1\0\0\0\2\0\0\0'), - ): - for parent in ( - BigEndianStructure, - LittleEndianStructure, - Structure, - ): - class NestedStructure(nested): - _fields_ = [("x", c_uint32), - ("y", c_uint32)] - - class TestStructure(parent): - _fields_ = [("point", NestedStructure)] - - self.assertEqual(len(data), sizeof(TestStructure)) - ptr = POINTER(TestStructure) - s = cast(data, ptr)[0] - del ctypes._pointer_type_cache[TestStructure] - self.assertEqual(s.point.x, 1) - self.assertEqual(s.point.y, 2) - - def test_struct_field_alignment(self): - # standard packing in struct uses no alignment. - # So, we have to align using pad bytes. - # - # Unaligned accesses will crash Python (on those platforms that - # don't allow it, like sparc solaris). 
- if sys.byteorder == "little": - base = BigEndianStructure - fmt = ">bxhid" - else: - base = LittleEndianStructure - fmt = " float -> double - import math - self.check_type(c_float, math.e) - self.check_type(c_float, -math.e) - - def test_double(self): - self.check_type(c_double, 3.14) - self.check_type(c_double, -3.14) - - @need_symbol('c_longdouble') - def test_longdouble(self): - self.check_type(c_longdouble, 3.14) - self.check_type(c_longdouble, -3.14) - - def test_char(self): - self.check_type(c_char, b"x") - self.check_type(c_char, b"a") - - # disabled: would now (correctly) raise a RuntimeWarning about - # a memory leak. A callback function cannot return a non-integral - # C type without causing a memory leak. - @unittest.skip('test disabled') - def test_char_p(self): - self.check_type(c_char_p, "abc") - self.check_type(c_char_p, "def") - - def test_pyobject(self): - o = () - from sys import getrefcount as grc - for o in (), [], object(): - initial = grc(o) - # This call leaks a reference to 'o'... - self.check_type(py_object, o) - before = grc(o) - # ...but this call doesn't leak any more. Where is the refcount? - self.check_type(py_object, o) - after = grc(o) - self.assertEqual((after, o), (before, o)) - - def test_unsupported_restype_1(self): - # Only "fundamental" result types are supported for callback - # functions, the type must have a non-NULL stgdict->setfunc. - # POINTER(c_double), for example, is not supported. 
- - prototype = self.functype.__func__(POINTER(c_double)) - # The type is checked when the prototype is called - self.assertRaises(TypeError, prototype, lambda: None) - - def test_unsupported_restype_2(self): - prototype = self.functype.__func__(object) - self.assertRaises(TypeError, prototype, lambda: None) - - def test_issue_7959(self): - proto = self.functype.__func__(None) - - class X(object): - def func(self): pass - def __init__(self): - self.v = proto(self.func) - - import gc - for i in range(32): - X() - gc.collect() - live = [x for x in gc.get_objects() - if isinstance(x, X)] - self.assertEqual(len(live), 0) - - def test_issue12483(self): - import gc - class Nasty: - def __del__(self): - gc.collect() - CFUNCTYPE(None)(lambda x=Nasty(): None) - - -@need_symbol('WINFUNCTYPE') -class StdcallCallbacks(Callbacks): - try: - functype = WINFUNCTYPE - except NameError: - pass - -################################################################ - -class SampleCallbacksTestCase(unittest.TestCase): - - def test_integrate(self): - # Derived from some then non-working code, posted by David Foster - dll = CDLL(_ctypes_test.__file__) - - # The function prototype called by 'integrate': double func(double); - CALLBACK = CFUNCTYPE(c_double, c_double) - - # The integrate function itself, exposed from the _ctypes_test dll - integrate = dll.integrate - integrate.argtypes = (c_double, c_double, CALLBACK, c_long) - integrate.restype = c_double - - def func(x): - return x**2 - - result = integrate(0.0, 1.0, CALLBACK(func), 10) - diff = abs(result - 1./3.) 
- - self.assertLess(diff, 0.01, "%s not less than 0.01" % diff) - - def test_issue_8959_a(self): - from ctypes.util import find_library - libc_path = find_library("c") - if not libc_path: - self.skipTest('could not find libc') - libc = CDLL(libc_path) - - @CFUNCTYPE(c_int, POINTER(c_int), POINTER(c_int)) - def cmp_func(a, b): - return a[0] - b[0] - - array = (c_int * 5)(5, 1, 99, 7, 33) - - libc.qsort(array, len(array), sizeof(c_int), cmp_func) - self.assertEqual(array[:], [1, 5, 7, 33, 99]) - - @need_symbol('WINFUNCTYPE') - def test_issue_8959_b(self): - from ctypes.wintypes import BOOL, HWND, LPARAM - global windowCount - windowCount = 0 - - @WINFUNCTYPE(BOOL, HWND, LPARAM) - def EnumWindowsCallbackFunc(hwnd, lParam): - global windowCount - windowCount += 1 - return True #Allow windows to keep enumerating - - windll.user32.EnumWindows(EnumWindowsCallbackFunc, 0) - - def test_callback_register_int(self): - # Issue #8275: buggy handling of callback args under Win64 - # NOTE: should be run on release builds as well - dll = CDLL(_ctypes_test.__file__) - CALLBACK = CFUNCTYPE(c_int, c_int, c_int, c_int, c_int, c_int) - # All this function does is call the callback with its args squared - func = dll._testfunc_cbk_reg_int - func.argtypes = (c_int, c_int, c_int, c_int, c_int, CALLBACK) - func.restype = c_int - - def callback(a, b, c, d, e): - return a + b + c + d + e - - result = func(2, 3, 4, 5, 6, CALLBACK(callback)) - self.assertEqual(result, callback(2*2, 3*3, 4*4, 5*5, 6*6)) - - def test_callback_register_double(self): - # Issue #8275: buggy handling of callback args under Win64 - # NOTE: should be run on release builds as well - dll = CDLL(_ctypes_test.__file__) - CALLBACK = CFUNCTYPE(c_double, c_double, c_double, c_double, - c_double, c_double) - # All this function does is call the callback with its args squared - func = dll._testfunc_cbk_reg_double - func.argtypes = (c_double, c_double, c_double, - c_double, c_double, CALLBACK) - func.restype = c_double - - def 
callback(a, b, c, d, e): - return a + b + c + d + e - - result = func(1.1, 2.2, 3.3, 4.4, 5.5, CALLBACK(callback)) - self.assertEqual(result, - callback(1.1*1.1, 2.2*2.2, 3.3*3.3, 4.4*4.4, 5.5*5.5)) - - def test_callback_large_struct(self): - class Check: pass - - # This should mirror the structure in Modules/_ctypes/_ctypes_test.c - class X(Structure): - _fields_ = [ - ('first', c_ulong), - ('second', c_ulong), - ('third', c_ulong), - ] - - def callback(check, s): - check.first = s.first - check.second = s.second - check.third = s.third - # See issue #29565. - # The structure should be passed by value, so - # any changes to it should not be reflected in - # the value passed - s.first = s.second = s.third = 0x0badf00d - - check = Check() - s = X() - s.first = 0xdeadbeef - s.second = 0xcafebabe - s.third = 0x0bad1dea - - CALLBACK = CFUNCTYPE(None, X) - dll = CDLL(_ctypes_test.__file__) - func = dll._testfunc_cbk_large_struct - func.argtypes = (X, CALLBACK) - func.restype = None - # the function just calls the callback with the passed structure - func(s, CALLBACK(functools.partial(callback, check))) - self.assertEqual(check.first, s.first) - self.assertEqual(check.second, s.second) - self.assertEqual(check.third, s.third) - self.assertEqual(check.first, 0xdeadbeef) - self.assertEqual(check.second, 0xcafebabe) - self.assertEqual(check.third, 0x0bad1dea) - # See issue #29565. - # Ensure that the original struct is unchanged. 
- self.assertEqual(s.first, check.first) - self.assertEqual(s.second, check.second) - self.assertEqual(s.third, check.third) - - def test_callback_too_many_args(self): - def func(*args): - return len(args) - - # valid call with nargs <= CTYPES_MAX_ARGCOUNT - proto = CFUNCTYPE(c_int, *(c_int,) * CTYPES_MAX_ARGCOUNT) - cb = proto(func) - args1 = (1,) * CTYPES_MAX_ARGCOUNT - self.assertEqual(cb(*args1), CTYPES_MAX_ARGCOUNT) - - # invalid call with nargs > CTYPES_MAX_ARGCOUNT - args2 = (1,) * (CTYPES_MAX_ARGCOUNT + 1) - with self.assertRaises(ArgumentError): - cb(*args2) - - # error when creating the type with too many arguments - with self.assertRaises(ArgumentError): - CFUNCTYPE(c_int, *(c_int,) * (CTYPES_MAX_ARGCOUNT + 1)) - - def test_convert_result_error(self): - def func(): - return ("tuple",) - - proto = CFUNCTYPE(c_int) - ctypes_func = proto(func) - with support.catch_unraisable_exception() as cm: - # don't test the result since it is an uninitialized value - result = ctypes_func() - - self.assertIsInstance(cm.unraisable.exc_value, TypeError) - self.assertEqual(cm.unraisable.err_msg, - "Exception ignored on converting result " - "of ctypes callback function") - self.assertIs(cm.unraisable.object, func) - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_cast.py b/Lib/ctypes/test/test_cast.py deleted file mode 100644 index 6878f973282..00000000000 --- a/Lib/ctypes/test/test_cast.py +++ /dev/null @@ -1,99 +0,0 @@ -from ctypes import * -from ctypes.test import need_symbol -import unittest -import sys - -class Test(unittest.TestCase): - - def test_array2pointer(self): - array = (c_int * 3)(42, 17, 2) - - # casting an array to a pointer works. 
- ptr = cast(array, POINTER(c_int)) - self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) - - if 2*sizeof(c_short) == sizeof(c_int): - ptr = cast(array, POINTER(c_short)) - if sys.byteorder == "little": - self.assertEqual([ptr[i] for i in range(6)], - [42, 0, 17, 0, 2, 0]) - else: - self.assertEqual([ptr[i] for i in range(6)], - [0, 42, 0, 17, 0, 2]) - - def test_address2pointer(self): - array = (c_int * 3)(42, 17, 2) - - address = addressof(array) - ptr = cast(c_void_p(address), POINTER(c_int)) - self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) - - ptr = cast(address, POINTER(c_int)) - self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) - - def test_p2a_objects(self): - array = (c_char_p * 5)() - self.assertEqual(array._objects, None) - array[0] = b"foo bar" - self.assertEqual(array._objects, {'0': b"foo bar"}) - - p = cast(array, POINTER(c_char_p)) - # array and p share a common _objects attribute - self.assertIs(p._objects, array._objects) - self.assertEqual(array._objects, {'0': b"foo bar", id(array): array}) - p[0] = b"spam spam" - self.assertEqual(p._objects, {'0': b"spam spam", id(array): array}) - self.assertIs(array._objects, p._objects) - p[1] = b"foo bar" - self.assertEqual(p._objects, {'1': b'foo bar', '0': b"spam spam", id(array): array}) - self.assertIs(array._objects, p._objects) - - def test_other(self): - p = cast((c_int * 4)(1, 2, 3, 4), POINTER(c_int)) - self.assertEqual(p[:4], [1,2, 3, 4]) - self.assertEqual(p[:4:], [1, 2, 3, 4]) - self.assertEqual(p[3:-1:-1], [4, 3, 2, 1]) - self.assertEqual(p[:4:3], [1, 4]) - c_int() - self.assertEqual(p[:4], [1, 2, 3, 4]) - self.assertEqual(p[:4:], [1, 2, 3, 4]) - self.assertEqual(p[3:-1:-1], [4, 3, 2, 1]) - self.assertEqual(p[:4:3], [1, 4]) - p[2] = 96 - self.assertEqual(p[:4], [1, 2, 96, 4]) - self.assertEqual(p[:4:], [1, 2, 96, 4]) - self.assertEqual(p[3:-1:-1], [4, 96, 2, 1]) - self.assertEqual(p[:4:3], [1, 4]) - c_int() - self.assertEqual(p[:4], [1, 2, 96, 4]) - 
self.assertEqual(p[:4:], [1, 2, 96, 4]) - self.assertEqual(p[3:-1:-1], [4, 96, 2, 1]) - self.assertEqual(p[:4:3], [1, 4]) - - def test_char_p(self): - # This didn't work: bad argument to internal function - s = c_char_p(b"hiho") - self.assertEqual(cast(cast(s, c_void_p), c_char_p).value, - b"hiho") - - @need_symbol('c_wchar_p') - def test_wchar_p(self): - s = c_wchar_p("hiho") - self.assertEqual(cast(cast(s, c_void_p), c_wchar_p).value, - "hiho") - - def test_bad_type_arg(self): - # The type argument must be a ctypes pointer type. - array_type = c_byte * sizeof(c_int) - array = array_type() - self.assertRaises(TypeError, cast, array, None) - self.assertRaises(TypeError, cast, array, array_type) - class Struct(Structure): - _fields_ = [("a", c_int)] - self.assertRaises(TypeError, cast, array, Struct) - class MyUnion(Union): - _fields_ = [("a", c_int)] - self.assertRaises(TypeError, cast, array, MyUnion) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_cfuncs.py b/Lib/ctypes/test/test_cfuncs.py deleted file mode 100644 index 09b06840bf5..00000000000 --- a/Lib/ctypes/test/test_cfuncs.py +++ /dev/null @@ -1,218 +0,0 @@ -# A lot of failures in these tests on Mac OS X. -# Byte order related? 
- -import unittest -from ctypes import * -from ctypes.test import need_symbol - -import _ctypes_test - -class CFunctions(unittest.TestCase): - _dll = CDLL(_ctypes_test.__file__) - - def S(self): - return c_longlong.in_dll(self._dll, "last_tf_arg_s").value - def U(self): - return c_ulonglong.in_dll(self._dll, "last_tf_arg_u").value - - def test_byte(self): - self._dll.tf_b.restype = c_byte - self._dll.tf_b.argtypes = (c_byte,) - self.assertEqual(self._dll.tf_b(-126), -42) - self.assertEqual(self.S(), -126) - - def test_byte_plus(self): - self._dll.tf_bb.restype = c_byte - self._dll.tf_bb.argtypes = (c_byte, c_byte) - self.assertEqual(self._dll.tf_bb(0, -126), -42) - self.assertEqual(self.S(), -126) - - def test_ubyte(self): - self._dll.tf_B.restype = c_ubyte - self._dll.tf_B.argtypes = (c_ubyte,) - self.assertEqual(self._dll.tf_B(255), 85) - self.assertEqual(self.U(), 255) - - def test_ubyte_plus(self): - self._dll.tf_bB.restype = c_ubyte - self._dll.tf_bB.argtypes = (c_byte, c_ubyte) - self.assertEqual(self._dll.tf_bB(0, 255), 85) - self.assertEqual(self.U(), 255) - - def test_short(self): - self._dll.tf_h.restype = c_short - self._dll.tf_h.argtypes = (c_short,) - self.assertEqual(self._dll.tf_h(-32766), -10922) - self.assertEqual(self.S(), -32766) - - def test_short_plus(self): - self._dll.tf_bh.restype = c_short - self._dll.tf_bh.argtypes = (c_byte, c_short) - self.assertEqual(self._dll.tf_bh(0, -32766), -10922) - self.assertEqual(self.S(), -32766) - - def test_ushort(self): - self._dll.tf_H.restype = c_ushort - self._dll.tf_H.argtypes = (c_ushort,) - self.assertEqual(self._dll.tf_H(65535), 21845) - self.assertEqual(self.U(), 65535) - - def test_ushort_plus(self): - self._dll.tf_bH.restype = c_ushort - self._dll.tf_bH.argtypes = (c_byte, c_ushort) - self.assertEqual(self._dll.tf_bH(0, 65535), 21845) - self.assertEqual(self.U(), 65535) - - def test_int(self): - self._dll.tf_i.restype = c_int - self._dll.tf_i.argtypes = (c_int,) - 
self.assertEqual(self._dll.tf_i(-2147483646), -715827882) - self.assertEqual(self.S(), -2147483646) - - def test_int_plus(self): - self._dll.tf_bi.restype = c_int - self._dll.tf_bi.argtypes = (c_byte, c_int) - self.assertEqual(self._dll.tf_bi(0, -2147483646), -715827882) - self.assertEqual(self.S(), -2147483646) - - def test_uint(self): - self._dll.tf_I.restype = c_uint - self._dll.tf_I.argtypes = (c_uint,) - self.assertEqual(self._dll.tf_I(4294967295), 1431655765) - self.assertEqual(self.U(), 4294967295) - - def test_uint_plus(self): - self._dll.tf_bI.restype = c_uint - self._dll.tf_bI.argtypes = (c_byte, c_uint) - self.assertEqual(self._dll.tf_bI(0, 4294967295), 1431655765) - self.assertEqual(self.U(), 4294967295) - - def test_long(self): - self._dll.tf_l.restype = c_long - self._dll.tf_l.argtypes = (c_long,) - self.assertEqual(self._dll.tf_l(-2147483646), -715827882) - self.assertEqual(self.S(), -2147483646) - - def test_long_plus(self): - self._dll.tf_bl.restype = c_long - self._dll.tf_bl.argtypes = (c_byte, c_long) - self.assertEqual(self._dll.tf_bl(0, -2147483646), -715827882) - self.assertEqual(self.S(), -2147483646) - - def test_ulong(self): - self._dll.tf_L.restype = c_ulong - self._dll.tf_L.argtypes = (c_ulong,) - self.assertEqual(self._dll.tf_L(4294967295), 1431655765) - self.assertEqual(self.U(), 4294967295) - - def test_ulong_plus(self): - self._dll.tf_bL.restype = c_ulong - self._dll.tf_bL.argtypes = (c_char, c_ulong) - self.assertEqual(self._dll.tf_bL(b' ', 4294967295), 1431655765) - self.assertEqual(self.U(), 4294967295) - - @need_symbol('c_longlong') - def test_longlong(self): - self._dll.tf_q.restype = c_longlong - self._dll.tf_q.argtypes = (c_longlong, ) - self.assertEqual(self._dll.tf_q(-9223372036854775806), -3074457345618258602) - self.assertEqual(self.S(), -9223372036854775806) - - @need_symbol('c_longlong') - def test_longlong_plus(self): - self._dll.tf_bq.restype = c_longlong - self._dll.tf_bq.argtypes = (c_byte, c_longlong) - 
self.assertEqual(self._dll.tf_bq(0, -9223372036854775806), -3074457345618258602) - self.assertEqual(self.S(), -9223372036854775806) - - @need_symbol('c_ulonglong') - def test_ulonglong(self): - self._dll.tf_Q.restype = c_ulonglong - self._dll.tf_Q.argtypes = (c_ulonglong, ) - self.assertEqual(self._dll.tf_Q(18446744073709551615), 6148914691236517205) - self.assertEqual(self.U(), 18446744073709551615) - - @need_symbol('c_ulonglong') - def test_ulonglong_plus(self): - self._dll.tf_bQ.restype = c_ulonglong - self._dll.tf_bQ.argtypes = (c_byte, c_ulonglong) - self.assertEqual(self._dll.tf_bQ(0, 18446744073709551615), 6148914691236517205) - self.assertEqual(self.U(), 18446744073709551615) - - def test_float(self): - self._dll.tf_f.restype = c_float - self._dll.tf_f.argtypes = (c_float,) - self.assertEqual(self._dll.tf_f(-42.), -14.) - self.assertEqual(self.S(), -42) - - def test_float_plus(self): - self._dll.tf_bf.restype = c_float - self._dll.tf_bf.argtypes = (c_byte, c_float) - self.assertEqual(self._dll.tf_bf(0, -42.), -14.) - self.assertEqual(self.S(), -42) - - def test_double(self): - self._dll.tf_d.restype = c_double - self._dll.tf_d.argtypes = (c_double,) - self.assertEqual(self._dll.tf_d(42.), 14.) - self.assertEqual(self.S(), 42) - - def test_double_plus(self): - self._dll.tf_bd.restype = c_double - self._dll.tf_bd.argtypes = (c_byte, c_double) - self.assertEqual(self._dll.tf_bd(0, 42.), 14.) - self.assertEqual(self.S(), 42) - - @need_symbol('c_longdouble') - def test_longdouble(self): - self._dll.tf_D.restype = c_longdouble - self._dll.tf_D.argtypes = (c_longdouble,) - self.assertEqual(self._dll.tf_D(42.), 14.) - self.assertEqual(self.S(), 42) - - @need_symbol('c_longdouble') - def test_longdouble_plus(self): - self._dll.tf_bD.restype = c_longdouble - self._dll.tf_bD.argtypes = (c_byte, c_longdouble) - self.assertEqual(self._dll.tf_bD(0, 42.), 14.) 
- self.assertEqual(self.S(), 42) - - def test_callwithresult(self): - def process_result(result): - return result * 2 - self._dll.tf_i.restype = process_result - self._dll.tf_i.argtypes = (c_int,) - self.assertEqual(self._dll.tf_i(42), 28) - self.assertEqual(self.S(), 42) - self.assertEqual(self._dll.tf_i(-42), -28) - self.assertEqual(self.S(), -42) - - def test_void(self): - self._dll.tv_i.restype = None - self._dll.tv_i.argtypes = (c_int,) - self.assertEqual(self._dll.tv_i(42), None) - self.assertEqual(self.S(), 42) - self.assertEqual(self._dll.tv_i(-42), None) - self.assertEqual(self.S(), -42) - -# The following repeats the above tests with stdcall functions (where -# they are available) -try: - WinDLL -except NameError: - def stdcall_dll(*_): pass -else: - class stdcall_dll(WinDLL): - def __getattr__(self, name): - if name[:2] == '__' and name[-2:] == '__': - raise AttributeError(name) - func = self._FuncPtr(("s_" + name, self)) - setattr(self, name, func) - return func - -@need_symbol('WinDLL') -class stdcallCFunctions(CFunctions): - _dll = stdcall_dll(_ctypes_test.__file__) - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_checkretval.py b/Lib/ctypes/test/test_checkretval.py deleted file mode 100644 index e9567dc3912..00000000000 --- a/Lib/ctypes/test/test_checkretval.py +++ /dev/null @@ -1,36 +0,0 @@ -import unittest - -from ctypes import * -from ctypes.test import need_symbol - -class CHECKED(c_int): - def _check_retval_(value): - # Receives a CHECKED instance. 
- return str(value.value) - _check_retval_ = staticmethod(_check_retval_) - -class Test(unittest.TestCase): - - def test_checkretval(self): - - import _ctypes_test - dll = CDLL(_ctypes_test.__file__) - self.assertEqual(42, dll._testfunc_p_p(42)) - - dll._testfunc_p_p.restype = CHECKED - self.assertEqual("42", dll._testfunc_p_p(42)) - - dll._testfunc_p_p.restype = None - self.assertEqual(None, dll._testfunc_p_p(42)) - - del dll._testfunc_p_p.restype - self.assertEqual(42, dll._testfunc_p_p(42)) - - @need_symbol('oledll') - def test_oledll(self): - self.assertRaises(OSError, - oledll.oleaut32.CreateTypeLib2, - 0, None, None) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_delattr.py b/Lib/ctypes/test/test_delattr.py deleted file mode 100644 index 0f4d58691b5..00000000000 --- a/Lib/ctypes/test/test_delattr.py +++ /dev/null @@ -1,21 +0,0 @@ -import unittest -from ctypes import * - -class X(Structure): - _fields_ = [("foo", c_int)] - -class TestCase(unittest.TestCase): - def test_simple(self): - self.assertRaises(TypeError, - delattr, c_int(42), "value") - - def test_chararray(self): - self.assertRaises(TypeError, - delattr, (c_char * 5)(), "value") - - def test_struct(self): - self.assertRaises(TypeError, - delattr, X(), "foo") - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_errno.py b/Lib/ctypes/test/test_errno.py deleted file mode 100644 index 3685164dde6..00000000000 --- a/Lib/ctypes/test/test_errno.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest, os, errno -import threading - -from ctypes import * -from ctypes.util import find_library - -class Test(unittest.TestCase): - def test_open(self): - libc_name = find_library("c") - if libc_name is None: - raise unittest.SkipTest("Unable to find C library") - libc = CDLL(libc_name, use_errno=True) - if os.name == "nt": - libc_open = libc._open - else: - libc_open = libc.open - - libc_open.argtypes = c_char_p, c_int - - self.assertEqual(libc_open(b"", 0), 
-1) - self.assertEqual(get_errno(), errno.ENOENT) - - self.assertEqual(set_errno(32), errno.ENOENT) - self.assertEqual(get_errno(), 32) - - def _worker(): - set_errno(0) - - libc = CDLL(libc_name, use_errno=False) - if os.name == "nt": - libc_open = libc._open - else: - libc_open = libc.open - libc_open.argtypes = c_char_p, c_int - self.assertEqual(libc_open(b"", 0), -1) - self.assertEqual(get_errno(), 0) - - t = threading.Thread(target=_worker) - t.start() - t.join() - - self.assertEqual(get_errno(), 32) - set_errno(0) - - @unittest.skipUnless(os.name == "nt", 'Test specific to Windows') - def test_GetLastError(self): - dll = WinDLL("kernel32", use_last_error=True) - GetModuleHandle = dll.GetModuleHandleA - GetModuleHandle.argtypes = [c_wchar_p] - - self.assertEqual(0, GetModuleHandle("foo")) - self.assertEqual(get_last_error(), 126) - - self.assertEqual(set_last_error(32), 126) - self.assertEqual(get_last_error(), 32) - - def _worker(): - set_last_error(0) - - dll = WinDLL("kernel32", use_last_error=False) - GetModuleHandle = dll.GetModuleHandleW - GetModuleHandle.argtypes = [c_wchar_p] - GetModuleHandle("bar") - - self.assertEqual(get_last_error(), 0) - - t = threading.Thread(target=_worker) - t.start() - t.join() - - self.assertEqual(get_last_error(), 32) - - set_last_error(0) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_find.py b/Lib/ctypes/test/test_find.py deleted file mode 100644 index 1ff9d019b13..00000000000 --- a/Lib/ctypes/test/test_find.py +++ /dev/null @@ -1,127 +0,0 @@ -import unittest -import unittest.mock -import os.path -import sys -import test.support -from test.support import os_helper -from ctypes import * -from ctypes.util import find_library - -# On some systems, loading the OpenGL libraries needs the RTLD_GLOBAL mode. 
-class Test_OpenGL_libs(unittest.TestCase): - @classmethod - def setUpClass(cls): - lib_gl = lib_glu = lib_gle = None - if sys.platform == "win32": - lib_gl = find_library("OpenGL32") - lib_glu = find_library("Glu32") - elif sys.platform == "darwin": - lib_gl = lib_glu = find_library("OpenGL") - else: - lib_gl = find_library("GL") - lib_glu = find_library("GLU") - lib_gle = find_library("gle") - - ## print, for debugging - if test.support.verbose: - print("OpenGL libraries:") - for item in (("GL", lib_gl), - ("GLU", lib_glu), - ("gle", lib_gle)): - print("\t", item) - - cls.gl = cls.glu = cls.gle = None - if lib_gl: - try: - cls.gl = CDLL(lib_gl, mode=RTLD_GLOBAL) - except OSError: - pass - if lib_glu: - try: - cls.glu = CDLL(lib_glu, RTLD_GLOBAL) - except OSError: - pass - if lib_gle: - try: - cls.gle = CDLL(lib_gle) - except OSError: - pass - - @classmethod - def tearDownClass(cls): - cls.gl = cls.glu = cls.gle = None - - def test_gl(self): - if self.gl is None: - self.skipTest('lib_gl not available') - self.gl.glClearIndex - - def test_glu(self): - if self.glu is None: - self.skipTest('lib_glu not available') - self.glu.gluBeginCurve - - def test_gle(self): - if self.gle is None: - self.skipTest('lib_gle not available') - self.gle.gleGetJoinStyle - - def test_shell_injection(self): - result = find_library('; echo Hello shell > ' + os_helper.TESTFN) - self.assertFalse(os.path.lexists(os_helper.TESTFN)) - self.assertIsNone(result) - - -@unittest.skipUnless(sys.platform.startswith('linux'), - 'Test only valid for Linux') -class FindLibraryLinux(unittest.TestCase): - def test_find_on_libpath(self): - import subprocess - import tempfile - - try: - p = subprocess.Popen(['gcc', '--version'], stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL) - out, _ = p.communicate() - except OSError: - raise unittest.SkipTest('gcc, needed for test, not available') - with tempfile.TemporaryDirectory() as d: - # create an empty temporary file - srcname = os.path.join(d, 'dummy.c') - 
libname = 'py_ctypes_test_dummy' - dstname = os.path.join(d, 'lib%s.so' % libname) - with open(srcname, 'wb') as f: - pass - self.assertTrue(os.path.exists(srcname)) - # compile the file to a shared library - cmd = ['gcc', '-o', dstname, '--shared', - '-Wl,-soname,lib%s.so' % libname, srcname] - out = subprocess.check_output(cmd) - self.assertTrue(os.path.exists(dstname)) - # now check that the .so can't be found (since not in - # LD_LIBRARY_PATH) - self.assertIsNone(find_library(libname)) - # now add the location to LD_LIBRARY_PATH - with os_helper.EnvironmentVarGuard() as env: - KEY = 'LD_LIBRARY_PATH' - if KEY not in env: - v = d - else: - v = '%s:%s' % (env[KEY], d) - env.set(KEY, v) - # now check that the .so can be found (since in - # LD_LIBRARY_PATH) - self.assertEqual(find_library(libname), 'lib%s.so' % libname) - - def test_find_library_with_gcc(self): - with unittest.mock.patch("ctypes.util._findSoname_ldconfig", lambda *args: None): - self.assertNotEqual(find_library('c'), None) - - def test_find_library_with_ld(self): - with unittest.mock.patch("ctypes.util._findSoname_ldconfig", lambda *args: None), \ - unittest.mock.patch("ctypes.util._findLib_gcc", lambda *args: None): - self.assertNotEqual(find_library('c'), None) - - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_frombuffer.py b/Lib/ctypes/test/test_frombuffer.py deleted file mode 100644 index 55c244356b3..00000000000 --- a/Lib/ctypes/test/test_frombuffer.py +++ /dev/null @@ -1,141 +0,0 @@ -from ctypes import * -import array -import gc -import unittest - -class X(Structure): - _fields_ = [("c_int", c_int)] - init_called = False - def __init__(self): - self._init_called = True - -class Test(unittest.TestCase): - def test_from_buffer(self): - a = array.array("i", range(16)) - x = (c_int * 16).from_buffer(a) - - y = X.from_buffer(a) - self.assertEqual(y.c_int, a[0]) - self.assertFalse(y.init_called) - - self.assertEqual(x[:], a.tolist()) - - a[0], a[-1] = 200, -200 - 
self.assertEqual(x[:], a.tolist()) - - self.assertRaises(BufferError, a.append, 100) - self.assertRaises(BufferError, a.pop) - - del x; del y; gc.collect(); gc.collect(); gc.collect() - a.append(100) - a.pop() - x = (c_int * 16).from_buffer(a) - - self.assertIn(a, [obj.obj if isinstance(obj, memoryview) else obj - for obj in x._objects.values()]) - - expected = x[:] - del a; gc.collect(); gc.collect(); gc.collect() - self.assertEqual(x[:], expected) - - with self.assertRaisesRegex(TypeError, "not writable"): - (c_char * 16).from_buffer(b"a" * 16) - with self.assertRaisesRegex(TypeError, "not writable"): - (c_char * 16).from_buffer(memoryview(b"a" * 16)) - with self.assertRaisesRegex(TypeError, "not C contiguous"): - (c_char * 16).from_buffer(memoryview(bytearray(b"a" * 16))[::-1]) - msg = "bytes-like object is required" - with self.assertRaisesRegex(TypeError, msg): - (c_char * 16).from_buffer("a" * 16) - - def test_fortran_contiguous(self): - try: - import _testbuffer - except ImportError as err: - self.skipTest(str(err)) - flags = _testbuffer.ND_WRITABLE | _testbuffer.ND_FORTRAN - array = _testbuffer.ndarray( - [97] * 16, format="B", shape=[4, 4], flags=flags) - with self.assertRaisesRegex(TypeError, "not C contiguous"): - (c_char * 16).from_buffer(array) - array = memoryview(array) - self.assertTrue(array.f_contiguous) - self.assertFalse(array.c_contiguous) - with self.assertRaisesRegex(TypeError, "not C contiguous"): - (c_char * 16).from_buffer(array) - - def test_from_buffer_with_offset(self): - a = array.array("i", range(16)) - x = (c_int * 15).from_buffer(a, sizeof(c_int)) - - self.assertEqual(x[:], a.tolist()[1:]) - with self.assertRaises(ValueError): - c_int.from_buffer(a, -1) - with self.assertRaises(ValueError): - (c_int * 16).from_buffer(a, sizeof(c_int)) - with self.assertRaises(ValueError): - (c_int * 1).from_buffer(a, 16 * sizeof(c_int)) - - def test_from_buffer_memoryview(self): - a = [c_char.from_buffer(memoryview(bytearray(b'a')))] - a.append(a) - 
del a - gc.collect() # Should not crash - - def test_from_buffer_copy(self): - a = array.array("i", range(16)) - x = (c_int * 16).from_buffer_copy(a) - - y = X.from_buffer_copy(a) - self.assertEqual(y.c_int, a[0]) - self.assertFalse(y.init_called) - - self.assertEqual(x[:], list(range(16))) - - a[0], a[-1] = 200, -200 - self.assertEqual(x[:], list(range(16))) - - a.append(100) - self.assertEqual(x[:], list(range(16))) - - self.assertEqual(x._objects, None) - - del a; gc.collect(); gc.collect(); gc.collect() - self.assertEqual(x[:], list(range(16))) - - x = (c_char * 16).from_buffer_copy(b"a" * 16) - self.assertEqual(x[:], b"a" * 16) - with self.assertRaises(TypeError): - (c_char * 16).from_buffer_copy("a" * 16) - - def test_from_buffer_copy_with_offset(self): - a = array.array("i", range(16)) - x = (c_int * 15).from_buffer_copy(a, sizeof(c_int)) - - self.assertEqual(x[:], a.tolist()[1:]) - with self.assertRaises(ValueError): - c_int.from_buffer_copy(a, -1) - with self.assertRaises(ValueError): - (c_int * 16).from_buffer_copy(a, sizeof(c_int)) - with self.assertRaises(ValueError): - (c_int * 1).from_buffer_copy(a, 16 * sizeof(c_int)) - - def test_abstract(self): - from ctypes import _Pointer, _SimpleCData, _CFuncPtr - - self.assertRaises(TypeError, Array.from_buffer, bytearray(10)) - self.assertRaises(TypeError, Structure.from_buffer, bytearray(10)) - self.assertRaises(TypeError, Union.from_buffer, bytearray(10)) - self.assertRaises(TypeError, _CFuncPtr.from_buffer, bytearray(10)) - self.assertRaises(TypeError, _Pointer.from_buffer, bytearray(10)) - self.assertRaises(TypeError, _SimpleCData.from_buffer, bytearray(10)) - - self.assertRaises(TypeError, Array.from_buffer_copy, b"123") - self.assertRaises(TypeError, Structure.from_buffer_copy, b"123") - self.assertRaises(TypeError, Union.from_buffer_copy, b"123") - self.assertRaises(TypeError, _CFuncPtr.from_buffer_copy, b"123") - self.assertRaises(TypeError, _Pointer.from_buffer_copy, b"123") - 
self.assertRaises(TypeError, _SimpleCData.from_buffer_copy, b"123") - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_funcptr.py b/Lib/ctypes/test/test_funcptr.py deleted file mode 100644 index e0b9b54e97f..00000000000 --- a/Lib/ctypes/test/test_funcptr.py +++ /dev/null @@ -1,132 +0,0 @@ -import unittest -from ctypes import * - -try: - WINFUNCTYPE -except NameError: - # fake to enable this test on Linux - WINFUNCTYPE = CFUNCTYPE - -import _ctypes_test -lib = CDLL(_ctypes_test.__file__) - -class CFuncPtrTestCase(unittest.TestCase): - def test_basic(self): - X = WINFUNCTYPE(c_int, c_int, c_int) - - def func(*args): - return len(args) - - x = X(func) - self.assertEqual(x.restype, c_int) - self.assertEqual(x.argtypes, (c_int, c_int)) - self.assertEqual(sizeof(x), sizeof(c_voidp)) - self.assertEqual(sizeof(X), sizeof(c_voidp)) - - def test_first(self): - StdCallback = WINFUNCTYPE(c_int, c_int, c_int) - CdeclCallback = CFUNCTYPE(c_int, c_int, c_int) - - def func(a, b): - return a + b - - s = StdCallback(func) - c = CdeclCallback(func) - - self.assertEqual(s(1, 2), 3) - self.assertEqual(c(1, 2), 3) - # The following no longer raises a TypeError - it is now - # possible, as in C, to call cdecl functions with more parameters. 
- #self.assertRaises(TypeError, c, 1, 2, 3) - self.assertEqual(c(1, 2, 3, 4, 5, 6), 3) - if not WINFUNCTYPE is CFUNCTYPE: - self.assertRaises(TypeError, s, 1, 2, 3) - - def test_structures(self): - WNDPROC = WINFUNCTYPE(c_long, c_int, c_int, c_int, c_int) - - def wndproc(hwnd, msg, wParam, lParam): - return hwnd + msg + wParam + lParam - - HINSTANCE = c_int - HICON = c_int - HCURSOR = c_int - LPCTSTR = c_char_p - - class WNDCLASS(Structure): - _fields_ = [("style", c_uint), - ("lpfnWndProc", WNDPROC), - ("cbClsExtra", c_int), - ("cbWndExtra", c_int), - ("hInstance", HINSTANCE), - ("hIcon", HICON), - ("hCursor", HCURSOR), - ("lpszMenuName", LPCTSTR), - ("lpszClassName", LPCTSTR)] - - wndclass = WNDCLASS() - wndclass.lpfnWndProc = WNDPROC(wndproc) - - WNDPROC_2 = WINFUNCTYPE(c_long, c_int, c_int, c_int, c_int) - - # This is no longer true, now that WINFUNCTYPE caches created types internally. - ## # CFuncPtr subclasses are compared by identity, so this raises a TypeError: - ## self.assertRaises(TypeError, setattr, wndclass, - ## "lpfnWndProc", WNDPROC_2(wndproc)) - # instead: - - self.assertIs(WNDPROC, WNDPROC_2) - # 'wndclass.lpfnWndProc' leaks 94 references. Why? 
- self.assertEqual(wndclass.lpfnWndProc(1, 2, 3, 4), 10) - - - f = wndclass.lpfnWndProc - - del wndclass - del wndproc - - self.assertEqual(f(10, 11, 12, 13), 46) - - def test_dllfunctions(self): - - def NoNullHandle(value): - if not value: - raise WinError() - return value - - strchr = lib.my_strchr - strchr.restype = c_char_p - strchr.argtypes = (c_char_p, c_char) - self.assertEqual(strchr(b"abcdefghi", b"b"), b"bcdefghi") - self.assertEqual(strchr(b"abcdefghi", b"x"), None) - - - strtok = lib.my_strtok - strtok.restype = c_char_p - # Neither of this does work: strtok changes the buffer it is passed -## strtok.argtypes = (c_char_p, c_char_p) -## strtok.argtypes = (c_string, c_char_p) - - def c_string(init): - size = len(init) + 1 - return (c_char*size)(*init) - - s = b"a\nb\nc" - b = c_string(s) - -## b = (c_char * (len(s)+1))() -## b.value = s - -## b = c_string(s) - self.assertEqual(strtok(b, b"\n"), b"a") - self.assertEqual(strtok(None, b"\n"), b"b") - self.assertEqual(strtok(None, b"\n"), b"c") - self.assertEqual(strtok(None, b"\n"), None) - - def test_abstract(self): - from ctypes import _CFuncPtr - - self.assertRaises(TypeError, _CFuncPtr, 13, "name", 42, "iid") - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_functions.py b/Lib/ctypes/test/test_functions.py deleted file mode 100644 index fc571700ce3..00000000000 --- a/Lib/ctypes/test/test_functions.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -Here is probably the place to write the docs, since the test-cases -show how the type behave. - -Later... 
-""" - -from ctypes import * -from ctypes.test import need_symbol -import sys, unittest - -try: - WINFUNCTYPE -except NameError: - # fake to enable this test on Linux - WINFUNCTYPE = CFUNCTYPE - -import _ctypes_test -dll = CDLL(_ctypes_test.__file__) -if sys.platform == "win32": - windll = WinDLL(_ctypes_test.__file__) - -class POINT(Structure): - _fields_ = [("x", c_int), ("y", c_int)] -class RECT(Structure): - _fields_ = [("left", c_int), ("top", c_int), - ("right", c_int), ("bottom", c_int)] -class FunctionTestCase(unittest.TestCase): - - def test_mro(self): - # in Python 2.3, this raises TypeError: MRO conflict among bases classes, - # in Python 2.2 it works. - # - # But in early versions of _ctypes.c, the result of tp_new - # wasn't checked, and it even crashed Python. - # Found by Greg Chapman. - - with self.assertRaises(TypeError): - class X(object, Array): - _length_ = 5 - _type_ = "i" - - from _ctypes import _Pointer - with self.assertRaises(TypeError): - class X(object, _Pointer): - pass - - from _ctypes import _SimpleCData - with self.assertRaises(TypeError): - class X(object, _SimpleCData): - _type_ = "i" - - with self.assertRaises(TypeError): - class X(object, Structure): - _fields_ = [] - - @need_symbol('c_wchar') - def test_wchar_parm(self): - f = dll._testfunc_i_bhilfd - f.argtypes = [c_byte, c_wchar, c_int, c_long, c_float, c_double] - result = f(1, "x", 3, 4, 5.0, 6.0) - self.assertEqual(result, 139) - self.assertEqual(type(result), int) - - @need_symbol('c_wchar') - def test_wchar_result(self): - f = dll._testfunc_i_bhilfd - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] - f.restype = c_wchar - result = f(0, 0, 0, 0, 0, 0) - self.assertEqual(result, '\x00') - - def test_voidresult(self): - f = dll._testfunc_v - f.restype = None - f.argtypes = [c_int, c_int, POINTER(c_int)] - result = c_int() - self.assertEqual(None, f(1, 2, byref(result))) - self.assertEqual(result.value, 3) - - def test_intresult(self): - f = 
dll._testfunc_i_bhilfd - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] - f.restype = c_int - result = f(1, 2, 3, 4, 5.0, 6.0) - self.assertEqual(result, 21) - self.assertEqual(type(result), int) - - result = f(-1, -2, -3, -4, -5.0, -6.0) - self.assertEqual(result, -21) - self.assertEqual(type(result), int) - - # If we declare the function to return a short, - # is the high part split off? - f.restype = c_short - result = f(1, 2, 3, 4, 5.0, 6.0) - self.assertEqual(result, 21) - self.assertEqual(type(result), int) - - result = f(1, 2, 3, 0x10004, 5.0, 6.0) - self.assertEqual(result, 21) - self.assertEqual(type(result), int) - - # You cannot assign character format codes as restype any longer - self.assertRaises(TypeError, setattr, f, "restype", "i") - - def test_floatresult(self): - f = dll._testfunc_f_bhilfd - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] - f.restype = c_float - result = f(1, 2, 3, 4, 5.0, 6.0) - self.assertEqual(result, 21) - self.assertEqual(type(result), float) - - result = f(-1, -2, -3, -4, -5.0, -6.0) - self.assertEqual(result, -21) - self.assertEqual(type(result), float) - - def test_doubleresult(self): - f = dll._testfunc_d_bhilfd - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] - f.restype = c_double - result = f(1, 2, 3, 4, 5.0, 6.0) - self.assertEqual(result, 21) - self.assertEqual(type(result), float) - - result = f(-1, -2, -3, -4, -5.0, -6.0) - self.assertEqual(result, -21) - self.assertEqual(type(result), float) - - @need_symbol('c_longdouble') - def test_longdoubleresult(self): - f = dll._testfunc_D_bhilfD - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_longdouble] - f.restype = c_longdouble - result = f(1, 2, 3, 4, 5.0, 6.0) - self.assertEqual(result, 21) - self.assertEqual(type(result), float) - - result = f(-1, -2, -3, -4, -5.0, -6.0) - self.assertEqual(result, -21) - self.assertEqual(type(result), float) - - @need_symbol('c_longlong') - def 
test_longlongresult(self): - f = dll._testfunc_q_bhilfd - f.restype = c_longlong - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] - result = f(1, 2, 3, 4, 5.0, 6.0) - self.assertEqual(result, 21) - - f = dll._testfunc_q_bhilfdq - f.restype = c_longlong - f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double, c_longlong] - result = f(1, 2, 3, 4, 5.0, 6.0, 21) - self.assertEqual(result, 42) - - def test_stringresult(self): - f = dll._testfunc_p_p - f.argtypes = None - f.restype = c_char_p - result = f(b"123") - self.assertEqual(result, b"123") - - result = f(None) - self.assertEqual(result, None) - - def test_pointers(self): - f = dll._testfunc_p_p - f.restype = POINTER(c_int) - f.argtypes = [POINTER(c_int)] - - # This only works if the value c_int(42) passed to the - # function is still alive while the pointer (the result) is - # used. - - v = c_int(42) - - self.assertEqual(pointer(v).contents.value, 42) - result = f(pointer(v)) - self.assertEqual(type(result), POINTER(c_int)) - self.assertEqual(result.contents.value, 42) - - # This on works... 
- result = f(pointer(v)) - self.assertEqual(result.contents.value, v.value) - - p = pointer(c_int(99)) - result = f(p) - self.assertEqual(result.contents.value, 99) - - arg = byref(v) - result = f(arg) - self.assertNotEqual(result.contents, v.value) - - self.assertRaises(ArgumentError, f, byref(c_short(22))) - - # It is dangerous, however, because you don't control the lifetime - # of the pointer: - result = f(byref(c_int(99))) - self.assertNotEqual(result.contents, 99) - - ################################################################ - def test_shorts(self): - f = dll._testfunc_callback_i_if - - args = [] - expected = [262144, 131072, 65536, 32768, 16384, 8192, 4096, 2048, - 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1] - - def callback(v): - args.append(v) - return v - - CallBack = CFUNCTYPE(c_int, c_int) - - cb = CallBack(callback) - f(2**18, cb) - self.assertEqual(args, expected) - - ################################################################ - - - def test_callbacks(self): - f = dll._testfunc_callback_i_if - f.restype = c_int - f.argtypes = None - - MyCallback = CFUNCTYPE(c_int, c_int) - - def callback(value): - #print "called back with", value - return value - - cb = MyCallback(callback) - result = f(-10, cb) - self.assertEqual(result, -18) - - # test with prototype - f.argtypes = [c_int, MyCallback] - cb = MyCallback(callback) - result = f(-10, cb) - self.assertEqual(result, -18) - - AnotherCallback = WINFUNCTYPE(c_int, c_int, c_int, c_int, c_int) - - # check that the prototype works: we call f with wrong - # argument types - cb = AnotherCallback(callback) - self.assertRaises(ArgumentError, f, -10, cb) - - - def test_callbacks_2(self): - # Can also use simple datatypes as argument type specifiers - # for the callback function. 
- # In this case the call receives an instance of that type - f = dll._testfunc_callback_i_if - f.restype = c_int - - MyCallback = CFUNCTYPE(c_int, c_int) - - f.argtypes = [c_int, MyCallback] - - def callback(value): - #print "called back with", value - self.assertEqual(type(value), int) - return value - - cb = MyCallback(callback) - result = f(-10, cb) - self.assertEqual(result, -18) - - @need_symbol('c_longlong') - def test_longlong_callbacks(self): - - f = dll._testfunc_callback_q_qf - f.restype = c_longlong - - MyCallback = CFUNCTYPE(c_longlong, c_longlong) - - f.argtypes = [c_longlong, MyCallback] - - def callback(value): - self.assertIsInstance(value, int) - return value & 0x7FFFFFFF - - cb = MyCallback(callback) - - self.assertEqual(13577625587, f(1000000000000, cb)) - - def test_errors(self): - self.assertRaises(AttributeError, getattr, dll, "_xxx_yyy") - self.assertRaises(ValueError, c_int.in_dll, dll, "_xxx_yyy") - - def test_byval(self): - - # without prototype - ptin = POINT(1, 2) - ptout = POINT() - # EXPORT int _testfunc_byval(point in, point *pout) - result = dll._testfunc_byval(ptin, byref(ptout)) - got = result, ptout.x, ptout.y - expected = 3, 1, 2 - self.assertEqual(got, expected) - - # with prototype - ptin = POINT(101, 102) - ptout = POINT() - dll._testfunc_byval.argtypes = (POINT, POINTER(POINT)) - dll._testfunc_byval.restype = c_int - result = dll._testfunc_byval(ptin, byref(ptout)) - got = result, ptout.x, ptout.y - expected = 203, 101, 102 - self.assertEqual(got, expected) - - def test_struct_return_2H(self): - class S2H(Structure): - _fields_ = [("x", c_short), - ("y", c_short)] - dll.ret_2h_func.restype = S2H - dll.ret_2h_func.argtypes = [S2H] - inp = S2H(99, 88) - s2h = dll.ret_2h_func(inp) - self.assertEqual((s2h.x, s2h.y), (99*2, 88*3)) - - @unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') - def test_struct_return_2H_stdcall(self): - class S2H(Structure): - _fields_ = [("x", c_short), - ("y", c_short)] - - 
windll.s_ret_2h_func.restype = S2H - windll.s_ret_2h_func.argtypes = [S2H] - s2h = windll.s_ret_2h_func(S2H(99, 88)) - self.assertEqual((s2h.x, s2h.y), (99*2, 88*3)) - - def test_struct_return_8H(self): - class S8I(Structure): - _fields_ = [("a", c_int), - ("b", c_int), - ("c", c_int), - ("d", c_int), - ("e", c_int), - ("f", c_int), - ("g", c_int), - ("h", c_int)] - dll.ret_8i_func.restype = S8I - dll.ret_8i_func.argtypes = [S8I] - inp = S8I(9, 8, 7, 6, 5, 4, 3, 2) - s8i = dll.ret_8i_func(inp) - self.assertEqual((s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h), - (9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9)) - - @unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') - def test_struct_return_8H_stdcall(self): - class S8I(Structure): - _fields_ = [("a", c_int), - ("b", c_int), - ("c", c_int), - ("d", c_int), - ("e", c_int), - ("f", c_int), - ("g", c_int), - ("h", c_int)] - windll.s_ret_8i_func.restype = S8I - windll.s_ret_8i_func.argtypes = [S8I] - inp = S8I(9, 8, 7, 6, 5, 4, 3, 2) - s8i = windll.s_ret_8i_func(inp) - self.assertEqual( - (s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h), - (9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9)) - - def test_sf1651235(self): - # see https://www.python.org/sf/1651235 - - proto = CFUNCTYPE(c_int, RECT, POINT) - def callback(*args): - return 0 - - callback = proto(callback) - self.assertRaises(ArgumentError, lambda: callback((1, 2, 3, 4), POINT())) - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_incomplete.py b/Lib/ctypes/test/test_incomplete.py deleted file mode 100644 index 00c430ef53c..00000000000 --- a/Lib/ctypes/test/test_incomplete.py +++ /dev/null @@ -1,42 +0,0 @@ -import unittest -from ctypes import * - -################################################################ -# -# The incomplete pointer example from the tutorial -# - -class MyTestCase(unittest.TestCase): - - def test_incomplete_example(self): - lpcell = POINTER("cell") - class cell(Structure): - _fields_ = 
[("name", c_char_p), - ("next", lpcell)] - - SetPointerType(lpcell, cell) - - c1 = cell() - c1.name = b"foo" - c2 = cell() - c2.name = b"bar" - - c1.next = pointer(c2) - c2.next = pointer(c1) - - p = c1 - - result = [] - for i in range(8): - result.append(p.name) - p = p.next[0] - self.assertEqual(result, [b"foo", b"bar"] * 4) - - # to not leak references, we must clean _pointer_type_cache - from ctypes import _pointer_type_cache - del _pointer_type_cache[cell] - -################################################################ - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_init.py b/Lib/ctypes/test/test_init.py deleted file mode 100644 index 75fad112a01..00000000000 --- a/Lib/ctypes/test/test_init.py +++ /dev/null @@ -1,40 +0,0 @@ -from ctypes import * -import unittest - -class X(Structure): - _fields_ = [("a", c_int), - ("b", c_int)] - new_was_called = False - - def __new__(cls): - result = super().__new__(cls) - result.new_was_called = True - return result - - def __init__(self): - self.a = 9 - self.b = 12 - -class Y(Structure): - _fields_ = [("x", X)] - - -class InitTest(unittest.TestCase): - def test_get(self): - # make sure the only accessing a nested structure - # doesn't call the structure's __new__ and __init__ - y = Y() - self.assertEqual((y.x.a, y.x.b), (0, 0)) - self.assertEqual(y.x.new_was_called, False) - - # But explicitly creating an X structure calls __new__ and __init__, of course. 
- x = X() - self.assertEqual((x.a, x.b), (9, 12)) - self.assertEqual(x.new_was_called, True) - - y.x = x - self.assertEqual((y.x.a, y.x.b), (9, 12)) - self.assertEqual(y.x.new_was_called, False) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_internals.py b/Lib/ctypes/test/test_internals.py deleted file mode 100644 index 271e3f57f81..00000000000 --- a/Lib/ctypes/test/test_internals.py +++ /dev/null @@ -1,100 +0,0 @@ -# This tests the internal _objects attribute -import unittest -from ctypes import * -from sys import getrefcount as grc - -# XXX This test must be reviewed for correctness!!! - -# ctypes' types are container types. -# -# They have an internal memory block, which only consists of some bytes, -# but it has to keep references to other objects as well. This is not -# really needed for trivial C types like int or char, but it is important -# for aggregate types like strings or pointers in particular. -# -# What about pointers? - -class ObjectsTestCase(unittest.TestCase): - def assertSame(self, a, b): - self.assertEqual(id(a), id(b)) - - def test_ints(self): - i = 42000123 - refcnt = grc(i) - ci = c_int(i) - self.assertEqual(refcnt, grc(i)) - self.assertEqual(ci._objects, None) - - def test_c_char_p(self): - s = b"Hello, World" - refcnt = grc(s) - cs = c_char_p(s) - self.assertEqual(refcnt + 1, grc(s)) - self.assertSame(cs._objects, s) - - def test_simple_struct(self): - class X(Structure): - _fields_ = [("a", c_int), ("b", c_int)] - - a = 421234 - b = 421235 - x = X() - self.assertEqual(x._objects, None) - x.a = a - x.b = b - self.assertEqual(x._objects, None) - - def test_embedded_structs(self): - class X(Structure): - _fields_ = [("a", c_int), ("b", c_int)] - - class Y(Structure): - _fields_ = [("x", X), ("y", X)] - - y = Y() - self.assertEqual(y._objects, None) - - x1, x2 = X(), X() - y.x, y.y = x1, x2 - self.assertEqual(y._objects, {"0": {}, "1": {}}) - x1.a, x2.b = 42, 93 - self.assertEqual(y._objects, {"0": {}, "1": 
{}}) - - def test_xxx(self): - class X(Structure): - _fields_ = [("a", c_char_p), ("b", c_char_p)] - - class Y(Structure): - _fields_ = [("x", X), ("y", X)] - - s1 = b"Hello, World" - s2 = b"Hallo, Welt" - - x = X() - x.a = s1 - x.b = s2 - self.assertEqual(x._objects, {"0": s1, "1": s2}) - - y = Y() - y.x = x - self.assertEqual(y._objects, {"0": {"0": s1, "1": s2}}) -## x = y.x -## del y -## print x._b_base_._objects - - def test_ptr_struct(self): - class X(Structure): - _fields_ = [("data", POINTER(c_int))] - - A = c_int*4 - a = A(11, 22, 33, 44) - self.assertEqual(a._objects, None) - - x = X() - x.data = a -##XXX print x._objects -##XXX print x.data[0] -##XXX print x.data._objects - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_keeprefs.py b/Lib/ctypes/test/test_keeprefs.py deleted file mode 100644 index 94c02573fa1..00000000000 --- a/Lib/ctypes/test/test_keeprefs.py +++ /dev/null @@ -1,153 +0,0 @@ -from ctypes import * -import unittest - -class SimpleTestCase(unittest.TestCase): - def test_cint(self): - x = c_int() - self.assertEqual(x._objects, None) - x.value = 42 - self.assertEqual(x._objects, None) - x = c_int(99) - self.assertEqual(x._objects, None) - - def test_ccharp(self): - x = c_char_p() - self.assertEqual(x._objects, None) - x.value = b"abc" - self.assertEqual(x._objects, b"abc") - x = c_char_p(b"spam") - self.assertEqual(x._objects, b"spam") - -class StructureTestCase(unittest.TestCase): - def test_cint_struct(self): - class X(Structure): - _fields_ = [("a", c_int), - ("b", c_int)] - - x = X() - self.assertEqual(x._objects, None) - x.a = 42 - x.b = 99 - self.assertEqual(x._objects, None) - - def test_ccharp_struct(self): - class X(Structure): - _fields_ = [("a", c_char_p), - ("b", c_char_p)] - x = X() - self.assertEqual(x._objects, None) - - x.a = b"spam" - x.b = b"foo" - self.assertEqual(x._objects, {"0": b"spam", "1": b"foo"}) - - def test_struct_struct(self): - class POINT(Structure): - _fields_ = [("x", c_int), 
("y", c_int)] - class RECT(Structure): - _fields_ = [("ul", POINT), ("lr", POINT)] - - r = RECT() - r.ul.x = 0 - r.ul.y = 1 - r.lr.x = 2 - r.lr.y = 3 - self.assertEqual(r._objects, None) - - r = RECT() - pt = POINT(1, 2) - r.ul = pt - self.assertEqual(r._objects, {'0': {}}) - r.ul.x = 22 - r.ul.y = 44 - self.assertEqual(r._objects, {'0': {}}) - r.lr = POINT() - self.assertEqual(r._objects, {'0': {}, '1': {}}) - -class ArrayTestCase(unittest.TestCase): - def test_cint_array(self): - INTARR = c_int * 3 - - ia = INTARR() - self.assertEqual(ia._objects, None) - ia[0] = 1 - ia[1] = 2 - ia[2] = 3 - self.assertEqual(ia._objects, None) - - class X(Structure): - _fields_ = [("x", c_int), - ("a", INTARR)] - - x = X() - x.x = 1000 - x.a[0] = 42 - x.a[1] = 96 - self.assertEqual(x._objects, None) - x.a = ia - self.assertEqual(x._objects, {'1': {}}) - -class PointerTestCase(unittest.TestCase): - def test_p_cint(self): - i = c_int(42) - x = pointer(i) - self.assertEqual(x._objects, {'1': i}) - -class DeletePointerTestCase(unittest.TestCase): - @unittest.skip('test disabled') - def test_X(self): - class X(Structure): - _fields_ = [("p", POINTER(c_char_p))] - x = X() - i = c_char_p("abc def") - from sys import getrefcount as grc - print("2?", grc(i)) - x.p = pointer(i) - print("3?", grc(i)) - for i in range(320): - c_int(99) - x.p[0] - print(x.p[0]) -## del x -## print "2?", grc(i) -## del i - import gc - gc.collect() - for i in range(320): - c_int(99) - x.p[0] - print(x.p[0]) - print(x.p.contents) -## print x._objects - - x.p[0] = "spam spam" -## print x.p[0] - print("+" * 42) - print(x._objects) - -class PointerToStructure(unittest.TestCase): - def test(self): - class POINT(Structure): - _fields_ = [("x", c_int), ("y", c_int)] - class RECT(Structure): - _fields_ = [("a", POINTER(POINT)), - ("b", POINTER(POINT))] - r = RECT() - p1 = POINT(1, 2) - - r.a = pointer(p1) - r.b = pointer(p1) -## from pprint import pprint as pp -## pp(p1._objects) -## pp(r._objects) - - r.a[0].x = 42 - 
r.a[0].y = 99 - - # to avoid leaking when tests are run several times - # clean up the types left in the cache. - from ctypes import _pointer_type_cache - del _pointer_type_cache[POINT] - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_libc.py b/Lib/ctypes/test/test_libc.py deleted file mode 100644 index 56285b5ff81..00000000000 --- a/Lib/ctypes/test/test_libc.py +++ /dev/null @@ -1,33 +0,0 @@ -import unittest - -from ctypes import * -import _ctypes_test - -lib = CDLL(_ctypes_test.__file__) - -def three_way_cmp(x, y): - """Return -1 if x < y, 0 if x == y and 1 if x > y""" - return (x > y) - (x < y) - -class LibTest(unittest.TestCase): - def test_sqrt(self): - lib.my_sqrt.argtypes = c_double, - lib.my_sqrt.restype = c_double - self.assertEqual(lib.my_sqrt(4.0), 2.0) - import math - self.assertEqual(lib.my_sqrt(2.0), math.sqrt(2.0)) - - def test_qsort(self): - comparefunc = CFUNCTYPE(c_int, POINTER(c_char), POINTER(c_char)) - lib.my_qsort.argtypes = c_void_p, c_size_t, c_size_t, comparefunc - lib.my_qsort.restype = None - - def sort(a, b): - return three_way_cmp(a[0], b[0]) - - chars = create_string_buffer(b"spam, spam, and spam") - lib.my_qsort(chars, len(chars)-1, sizeof(c_char), comparefunc(sort)) - self.assertEqual(chars.raw, b" ,,aaaadmmmnpppsss\x00") - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_loading.py b/Lib/ctypes/test/test_loading.py deleted file mode 100644 index ea892277c4e..00000000000 --- a/Lib/ctypes/test/test_loading.py +++ /dev/null @@ -1,182 +0,0 @@ -from ctypes import * -import os -import shutil -import subprocess -import sys -import unittest -import test.support -from test.support import import_helper -from test.support import os_helper -from ctypes.util import find_library - -libc_name = None - -def setUpModule(): - global libc_name - if os.name == "nt": - libc_name = find_library("c") - elif sys.platform == "cygwin": - libc_name = "cygwin1.dll" - else: - libc_name = 
find_library("c") - - if test.support.verbose: - print("libc_name is", libc_name) - -class LoaderTest(unittest.TestCase): - - unknowndll = "xxrandomnamexx" - - def test_load(self): - if libc_name is None: - self.skipTest('could not find libc') - CDLL(libc_name) - CDLL(os.path.basename(libc_name)) - self.assertRaises(OSError, CDLL, self.unknowndll) - - def test_load_version(self): - if libc_name is None: - self.skipTest('could not find libc') - if os.path.basename(libc_name) != 'libc.so.6': - self.skipTest('wrong libc path for test') - cdll.LoadLibrary("libc.so.6") - # linux uses version, libc 9 should not exist - self.assertRaises(OSError, cdll.LoadLibrary, "libc.so.9") - self.assertRaises(OSError, cdll.LoadLibrary, self.unknowndll) - - def test_find(self): - for name in ("c", "m"): - lib = find_library(name) - if lib: - cdll.LoadLibrary(lib) - CDLL(lib) - - @unittest.skipUnless(os.name == "nt", - 'test specific to Windows') - def test_load_library(self): - # CRT is no longer directly loadable. See issue23606 for the - # discussion about alternative approaches. 
- #self.assertIsNotNone(libc_name) - if test.support.verbose: - print(find_library("kernel32")) - print(find_library("user32")) - - if os.name == "nt": - windll.kernel32.GetModuleHandleW - windll["kernel32"].GetModuleHandleW - windll.LoadLibrary("kernel32").GetModuleHandleW - WinDLL("kernel32").GetModuleHandleW - # embedded null character - self.assertRaises(ValueError, windll.LoadLibrary, "kernel32\0") - - @unittest.skipUnless(os.name == "nt", - 'test specific to Windows') - def test_load_ordinal_functions(self): - import _ctypes_test - dll = WinDLL(_ctypes_test.__file__) - # We load the same function both via ordinal and name - func_ord = dll[2] - func_name = dll.GetString - # addressof gets the address where the function pointer is stored - a_ord = addressof(func_ord) - a_name = addressof(func_name) - f_ord_addr = c_void_p.from_address(a_ord).value - f_name_addr = c_void_p.from_address(a_name).value - self.assertEqual(hex(f_ord_addr), hex(f_name_addr)) - - self.assertRaises(AttributeError, dll.__getitem__, 1234) - - @unittest.skipUnless(os.name == "nt", 'Windows-specific test') - def test_1703286_A(self): - from _ctypes import LoadLibrary, FreeLibrary - # On winXP 64-bit, advapi32 loads at an address that does - # NOT fit into a 32-bit integer. FreeLibrary must be able - # to accept this address. - - # These are tests for https://www.python.org/sf/1703286 - handle = LoadLibrary("advapi32") - FreeLibrary(handle) - - @unittest.skipUnless(os.name == "nt", 'Windows-specific test') - def test_1703286_B(self): - # Since on winXP 64-bit advapi32 loads like described - # above, the (arbitrarily selected) CloseEventLog function - # also has a high address. 'call_function' should accept - # addresses so large. - from _ctypes import call_function - advapi32 = windll.advapi32 - # Calling CloseEventLog with a NULL argument should fail, - # but the call should not segfault or so. 
- self.assertEqual(0, advapi32.CloseEventLog(None)) - windll.kernel32.GetProcAddress.argtypes = c_void_p, c_char_p - windll.kernel32.GetProcAddress.restype = c_void_p - proc = windll.kernel32.GetProcAddress(advapi32._handle, - b"CloseEventLog") - self.assertTrue(proc) - # This is the real test: call the function via 'call_function' - self.assertEqual(0, call_function(proc, (None,))) - - @unittest.skipUnless(os.name == "nt", - 'test specific to Windows') - def test_load_dll_with_flags(self): - _sqlite3 = import_helper.import_module("_sqlite3") - src = _sqlite3.__file__ - if src.lower().endswith("_d.pyd"): - ext = "_d.dll" - else: - ext = ".dll" - - with os_helper.temp_dir() as tmp: - # We copy two files and load _sqlite3.dll (formerly .pyd), - # which has a dependency on sqlite3.dll. Then we test - # loading it in subprocesses to avoid it starting in memory - # for each test. - target = os.path.join(tmp, "_sqlite3.dll") - shutil.copy(src, target) - shutil.copy(os.path.join(os.path.dirname(src), "sqlite3" + ext), - os.path.join(tmp, "sqlite3" + ext)) - - def should_pass(command): - with self.subTest(command): - subprocess.check_output( - [sys.executable, "-c", - "from ctypes import *; import nt;" + command], - cwd=tmp - ) - - def should_fail(command): - with self.subTest(command): - with self.assertRaises(subprocess.CalledProcessError): - subprocess.check_output( - [sys.executable, "-c", - "from ctypes import *; import nt;" + command], - cwd=tmp, stderr=subprocess.STDOUT, - ) - - # Default load should not find this in CWD - should_fail("WinDLL('_sqlite3.dll')") - - # Relative path (but not just filename) should succeed - should_pass("WinDLL('./_sqlite3.dll')") - - # Insecure load flags should succeed - # Clear the DLL directory to avoid safe search settings propagating - should_pass("windll.kernel32.SetDllDirectoryW(None); WinDLL('_sqlite3.dll', winmode=0)") - - # Full path load without DLL_LOAD_DIR shouldn't find dependency - 
should_fail("WinDLL(nt._getfullpathname('_sqlite3.dll'), " + - "winmode=nt._LOAD_LIBRARY_SEARCH_SYSTEM32)") - - # Full path load with DLL_LOAD_DIR should succeed - should_pass("WinDLL(nt._getfullpathname('_sqlite3.dll'), " + - "winmode=nt._LOAD_LIBRARY_SEARCH_SYSTEM32|" + - "nt._LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)") - - # User-specified directory should succeed - should_pass("import os; p = os.add_dll_directory(os.getcwd());" + - "WinDLL('_sqlite3.dll'); p.close()") - - - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_macholib.py b/Lib/ctypes/test/test_macholib.py deleted file mode 100644 index bc75f1a05a8..00000000000 --- a/Lib/ctypes/test/test_macholib.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import sys -import unittest - -# Bob Ippolito: -# -# Ok.. the code to find the filename for __getattr__ should look -# something like: -# -# import os -# from macholib.dyld import dyld_find -# -# def find_lib(name): -# possible = ['lib'+name+'.dylib', name+'.dylib', -# name+'.framework/'+name] -# for dylib in possible: -# try: -# return os.path.realpath(dyld_find(dylib)) -# except ValueError: -# pass -# raise ValueError, "%s not found" % (name,) -# -# It'll have output like this: -# -# >>> find_lib('pthread') -# '/usr/lib/libSystem.B.dylib' -# >>> find_lib('z') -# '/usr/lib/libz.1.dylib' -# >>> find_lib('IOKit') -# '/System/Library/Frameworks/IOKit.framework/Versions/A/IOKit' -# -# -bob - -from ctypes.macholib.dyld import dyld_find -from ctypes.macholib.dylib import dylib_info -from ctypes.macholib.framework import framework_info - -def find_lib(name): - possible = ['lib'+name+'.dylib', name+'.dylib', name+'.framework/'+name] - for dylib in possible: - try: - return os.path.realpath(dyld_find(dylib)) - except ValueError: - pass - raise ValueError("%s not found" % (name,)) - - -def d(location=None, name=None, shortname=None, version=None, suffix=None): - return {'location': location, 'name': name, 'shortname': shortname, - 'version': 
version, 'suffix': suffix} - - -class MachOTest(unittest.TestCase): - @unittest.skipUnless(sys.platform == "darwin", 'OSX-specific test') - def test_find(self): - self.assertEqual(dyld_find('libSystem.dylib'), - '/usr/lib/libSystem.dylib') - self.assertEqual(dyld_find('System.framework/System'), - '/System/Library/Frameworks/System.framework/System') - - # On Mac OS 11, system dylibs are only present in the shared cache, - # so symlinks like libpthread.dylib -> libSystem.B.dylib will not - # be resolved by dyld_find - self.assertIn(find_lib('pthread'), - ('/usr/lib/libSystem.B.dylib', '/usr/lib/libpthread.dylib')) - - result = find_lib('z') - # Issue #21093: dyld default search path includes $HOME/lib and - # /usr/local/lib before /usr/lib, which caused test failures if - # a local copy of libz exists in one of them. Now ignore the head - # of the path. - self.assertRegex(result, r".*/lib/libz.*\.dylib") - - self.assertIn(find_lib('IOKit'), - ('/System/Library/Frameworks/IOKit.framework/Versions/A/IOKit', - '/System/Library/Frameworks/IOKit.framework/IOKit')) - - @unittest.skipUnless(sys.platform == "darwin", 'OSX-specific test') - def test_info(self): - self.assertIsNone(dylib_info('completely/invalid')) - self.assertIsNone(dylib_info('completely/invalide_debug')) - self.assertEqual(dylib_info('P/Foo.dylib'), d('P', 'Foo.dylib', 'Foo')) - self.assertEqual(dylib_info('P/Foo_debug.dylib'), - d('P', 'Foo_debug.dylib', 'Foo', suffix='debug')) - self.assertEqual(dylib_info('P/Foo.A.dylib'), - d('P', 'Foo.A.dylib', 'Foo', 'A')) - self.assertEqual(dylib_info('P/Foo_debug.A.dylib'), - d('P', 'Foo_debug.A.dylib', 'Foo_debug', 'A')) - self.assertEqual(dylib_info('P/Foo.A_debug.dylib'), - d('P', 'Foo.A_debug.dylib', 'Foo', 'A', 'debug')) - - @unittest.skipUnless(sys.platform == "darwin", 'OSX-specific test') - def test_framework_info(self): - self.assertIsNone(framework_info('completely/invalid')) - self.assertIsNone(framework_info('completely/invalid/_debug')) - 
self.assertIsNone(framework_info('P/F.framework')) - self.assertIsNone(framework_info('P/F.framework/_debug')) - self.assertEqual(framework_info('P/F.framework/F'), - d('P', 'F.framework/F', 'F')) - self.assertEqual(framework_info('P/F.framework/F_debug'), - d('P', 'F.framework/F_debug', 'F', suffix='debug')) - self.assertIsNone(framework_info('P/F.framework/Versions')) - self.assertIsNone(framework_info('P/F.framework/Versions/A')) - self.assertEqual(framework_info('P/F.framework/Versions/A/F'), - d('P', 'F.framework/Versions/A/F', 'F', 'A')) - self.assertEqual(framework_info('P/F.framework/Versions/A/F_debug'), - d('P', 'F.framework/Versions/A/F_debug', 'F', 'A', 'debug')) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_memfunctions.py b/Lib/ctypes/test/test_memfunctions.py deleted file mode 100644 index e784b9a7068..00000000000 --- a/Lib/ctypes/test/test_memfunctions.py +++ /dev/null @@ -1,79 +0,0 @@ -import sys -from test import support -import unittest -from ctypes import * -from ctypes.test import need_symbol - -class MemFunctionsTest(unittest.TestCase): - @unittest.skip('test disabled') - def test_overflow(self): - # string_at and wstring_at must use the Python calling - # convention (which acquires the GIL and checks the Python - # error flag). Provoke an error and catch it; see also issue - # #3554: - self.assertRaises((OverflowError, MemoryError, SystemError), - lambda: wstring_at(u"foo", sys.maxint - 1)) - self.assertRaises((OverflowError, MemoryError, SystemError), - lambda: string_at("foo", sys.maxint - 1)) - - def test_memmove(self): - # large buffers apparently increase the chance that the memory - # is allocated in high address space. 
- a = create_string_buffer(1000000) - p = b"Hello, World" - result = memmove(a, p, len(p)) - self.assertEqual(a.value, b"Hello, World") - - self.assertEqual(string_at(result), b"Hello, World") - self.assertEqual(string_at(result, 5), b"Hello") - self.assertEqual(string_at(result, 16), b"Hello, World\0\0\0\0") - self.assertEqual(string_at(result, 0), b"") - - def test_memset(self): - a = create_string_buffer(1000000) - result = memset(a, ord('x'), 16) - self.assertEqual(a.value, b"xxxxxxxxxxxxxxxx") - - self.assertEqual(string_at(result), b"xxxxxxxxxxxxxxxx") - self.assertEqual(string_at(a), b"xxxxxxxxxxxxxxxx") - self.assertEqual(string_at(a, 20), b"xxxxxxxxxxxxxxxx\0\0\0\0") - - def test_cast(self): - a = (c_ubyte * 32)(*map(ord, "abcdef")) - self.assertEqual(cast(a, c_char_p).value, b"abcdef") - self.assertEqual(cast(a, POINTER(c_byte))[:7], - [97, 98, 99, 100, 101, 102, 0]) - self.assertEqual(cast(a, POINTER(c_byte))[:7:], - [97, 98, 99, 100, 101, 102, 0]) - self.assertEqual(cast(a, POINTER(c_byte))[6:-1:-1], - [0, 102, 101, 100, 99, 98, 97]) - self.assertEqual(cast(a, POINTER(c_byte))[:7:2], - [97, 99, 101, 0]) - self.assertEqual(cast(a, POINTER(c_byte))[:7:7], - [97]) - - @support.refcount_test - def test_string_at(self): - s = string_at(b"foo bar") - # XXX The following may be wrong, depending on how Python - # manages string instances - self.assertEqual(2, sys.getrefcount(s)) - self.assertTrue(s, "foo bar") - - self.assertEqual(string_at(b"foo bar", 7), b"foo bar") - self.assertEqual(string_at(b"foo bar", 3), b"foo") - - @need_symbol('create_unicode_buffer') - def test_wstring_at(self): - p = create_unicode_buffer("Hello, World") - a = create_unicode_buffer(1000000) - result = memmove(a, p, len(p) * sizeof(c_wchar)) - self.assertEqual(a.value, "Hello, World") - - self.assertEqual(wstring_at(a), "Hello, World") - self.assertEqual(wstring_at(a, 5), "Hello") - self.assertEqual(wstring_at(a, 16), "Hello, World\0\0\0\0") - self.assertEqual(wstring_at(a, 0), "") - 
-if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_numbers.py b/Lib/ctypes/test/test_numbers.py deleted file mode 100644 index a5c661b0e97..00000000000 --- a/Lib/ctypes/test/test_numbers.py +++ /dev/null @@ -1,218 +0,0 @@ -from ctypes import * -import unittest -import struct - -def valid_ranges(*types): - # given a sequence of numeric types, collect their _type_ - # attribute, which is a single format character compatible with - # the struct module, use the struct module to calculate the - # minimum and maximum value allowed for this format. - # Returns a list of (min, max) values. - result = [] - for t in types: - fmt = t._type_ - size = struct.calcsize(fmt) - a = struct.unpack(fmt, (b"\x00"*32)[:size])[0] - b = struct.unpack(fmt, (b"\xFF"*32)[:size])[0] - c = struct.unpack(fmt, (b"\x7F"+b"\x00"*32)[:size])[0] - d = struct.unpack(fmt, (b"\x80"+b"\xFF"*32)[:size])[0] - result.append((min(a, b, c, d), max(a, b, c, d))) - return result - -ArgType = type(byref(c_int(0))) - -unsigned_types = [c_ubyte, c_ushort, c_uint, c_ulong] -signed_types = [c_byte, c_short, c_int, c_long, c_longlong] - -bool_types = [] - -float_types = [c_double, c_float] - -try: - c_ulonglong - c_longlong -except NameError: - pass -else: - unsigned_types.append(c_ulonglong) - signed_types.append(c_longlong) - -try: - c_bool -except NameError: - pass -else: - bool_types.append(c_bool) - -unsigned_ranges = valid_ranges(*unsigned_types) -signed_ranges = valid_ranges(*signed_types) -bool_values = [True, False, 0, 1, -1, 5000, 'test', [], [1]] - -################################################################ - -class NumberTestCase(unittest.TestCase): - - def test_default_init(self): - # default values are set to zero - for t in signed_types + unsigned_types + float_types: - self.assertEqual(t().value, 0) - - def test_unsigned_values(self): - # the value given to the constructor is available - # as the 'value' attribute - for t, (l, h) in zip(unsigned_types, 
unsigned_ranges): - self.assertEqual(t(l).value, l) - self.assertEqual(t(h).value, h) - - def test_signed_values(self): - # see above - for t, (l, h) in zip(signed_types, signed_ranges): - self.assertEqual(t(l).value, l) - self.assertEqual(t(h).value, h) - - def test_bool_values(self): - from operator import truth - for t, v in zip(bool_types, bool_values): - self.assertEqual(t(v).value, truth(v)) - - def test_typeerror(self): - # Only numbers are allowed in the constructor, - # otherwise TypeError is raised - for t in signed_types + unsigned_types + float_types: - self.assertRaises(TypeError, t, "") - self.assertRaises(TypeError, t, None) - - def test_from_param(self): - # the from_param class method attribute always - # returns PyCArgObject instances - for t in signed_types + unsigned_types + float_types: - self.assertEqual(ArgType, type(t.from_param(0))) - - def test_byref(self): - # calling byref returns also a PyCArgObject instance - for t in signed_types + unsigned_types + float_types + bool_types: - parm = byref(t()) - self.assertEqual(ArgType, type(parm)) - - - def test_floats(self): - # c_float and c_double can be created from - # Python int and float - class FloatLike: - def __float__(self): - return 2.0 - f = FloatLike() - for t in float_types: - self.assertEqual(t(2.0).value, 2.0) - self.assertEqual(t(2).value, 2.0) - self.assertEqual(t(2).value, 2.0) - self.assertEqual(t(f).value, 2.0) - - def test_integers(self): - class FloatLike: - def __float__(self): - return 2.0 - f = FloatLike() - class IntLike: - def __int__(self): - return 2 - d = IntLike() - class IndexLike: - def __index__(self): - return 2 - i = IndexLike() - # integers cannot be constructed from floats, - # but from integer-like objects - for t in signed_types + unsigned_types: - self.assertRaises(TypeError, t, 3.14) - self.assertRaises(TypeError, t, f) - self.assertRaises(TypeError, t, d) - self.assertEqual(t(i).value, 2) - - def test_sizes(self): - for t in signed_types + unsigned_types 
+ float_types + bool_types: - try: - size = struct.calcsize(t._type_) - except struct.error: - continue - # sizeof of the type... - self.assertEqual(sizeof(t), size) - # and sizeof of an instance - self.assertEqual(sizeof(t()), size) - - def test_alignments(self): - for t in signed_types + unsigned_types + float_types: - code = t._type_ # the typecode - align = struct.calcsize("c%c" % code) - struct.calcsize(code) - - # alignment of the type... - self.assertEqual((code, alignment(t)), - (code, align)) - # and alignment of an instance - self.assertEqual((code, alignment(t())), - (code, align)) - - def test_int_from_address(self): - from array import array - for t in signed_types + unsigned_types: - # the array module doesn't support all format codes - # (no 'q' or 'Q') - try: - array(t._type_) - except ValueError: - continue - a = array(t._type_, [100]) - - # v now is an integer at an 'external' memory location - v = t.from_address(a.buffer_info()[0]) - self.assertEqual(v.value, a[0]) - self.assertEqual(type(v), t) - - # changing the value at the memory location changes v's value also - a[0] = 42 - self.assertEqual(v.value, a[0]) - - - def test_float_from_address(self): - from array import array - for t in float_types: - a = array(t._type_, [3.14]) - v = t.from_address(a.buffer_info()[0]) - self.assertEqual(v.value, a[0]) - self.assertIs(type(v), t) - a[0] = 2.3456e17 - self.assertEqual(v.value, a[0]) - self.assertIs(type(v), t) - - def test_char_from_address(self): - from ctypes import c_char - from array import array - - a = array('b', [0]) - a[0] = ord('x') - v = c_char.from_address(a.buffer_info()[0]) - self.assertEqual(v.value, b'x') - self.assertIs(type(v), c_char) - - a[0] = ord('?') - self.assertEqual(v.value, b'?') - - def test_init(self): - # c_int() can be initialized from Python's int, and c_int. 
- # Not from c_long or so, which seems strange, abc should - # probably be changed: - self.assertRaises(TypeError, c_int, c_long(42)) - - def test_float_overflow(self): - import sys - big_int = int(sys.float_info.max) * 2 - for t in float_types + [c_longdouble]: - self.assertRaises(OverflowError, t, big_int) - if (hasattr(t, "__ctype_be__")): - self.assertRaises(OverflowError, t.__ctype_be__, big_int) - if (hasattr(t, "__ctype_le__")): - self.assertRaises(OverflowError, t.__ctype_le__, big_int) - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_objects.py b/Lib/ctypes/test/test_objects.py deleted file mode 100644 index 19e3dc1f2d7..00000000000 --- a/Lib/ctypes/test/test_objects.py +++ /dev/null @@ -1,67 +0,0 @@ -r''' -This tests the '_objects' attribute of ctypes instances. '_objects' -holds references to objects that must be kept alive as long as the -ctypes instance, to make sure that the memory buffer is valid. - -WARNING: The '_objects' attribute is exposed ONLY for debugging ctypes itself, -it MUST NEVER BE MODIFIED! - -'_objects' is initialized to a dictionary on first use, before that it -is None. - -Here is an array of string pointers: - ->>> from ctypes import * ->>> array = (c_char_p * 5)() ->>> print(array._objects) -None ->>> - -The memory block stores pointers to strings, and the strings itself -assigned from Python must be kept. - ->>> array[4] = b'foo bar' ->>> array._objects -{'4': b'foo bar'} ->>> array[4] -b'foo bar' ->>> - -It gets more complicated when the ctypes instance itself is contained -in a 'base' object. - ->>> class X(Structure): -... _fields_ = [("x", c_int), ("y", c_int), ("array", c_char_p * 5)] -... 
->>> x = X() ->>> print(x._objects) -None ->>> - -The'array' attribute of the 'x' object shares part of the memory buffer -of 'x' ('_b_base_' is either None, or the root object owning the memory block): - ->>> print(x.array._b_base_) # doctest: +ELLIPSIS - ->>> - ->>> x.array[0] = b'spam spam spam' ->>> x._objects -{'0:2': b'spam spam spam'} ->>> x.array._b_base_._objects -{'0:2': b'spam spam spam'} ->>> - -''' - -import unittest, doctest - -import ctypes.test.test_objects - -class TestCase(unittest.TestCase): - def test(self): - failures, tests = doctest.testmod(ctypes.test.test_objects) - self.assertFalse(failures, 'doctests failed, see output above') - -if __name__ == '__main__': - doctest.testmod(ctypes.test.test_objects) diff --git a/Lib/ctypes/test/test_parameters.py b/Lib/ctypes/test/test_parameters.py deleted file mode 100644 index 38af7ac13d7..00000000000 --- a/Lib/ctypes/test/test_parameters.py +++ /dev/null @@ -1,250 +0,0 @@ -import unittest -from ctypes.test import need_symbol -import test.support - -class SimpleTypesTestCase(unittest.TestCase): - - def setUp(self): - import ctypes - try: - from _ctypes import set_conversion_mode - except ImportError: - pass - else: - self.prev_conv_mode = set_conversion_mode("ascii", "strict") - - def tearDown(self): - try: - from _ctypes import set_conversion_mode - except ImportError: - pass - else: - set_conversion_mode(*self.prev_conv_mode) - - def test_subclasses(self): - from ctypes import c_void_p, c_char_p - # ctypes 0.9.5 and before did overwrite from_param in SimpleType_new - class CVOIDP(c_void_p): - def from_param(cls, value): - return value * 2 - from_param = classmethod(from_param) - - class CCHARP(c_char_p): - def from_param(cls, value): - return value * 4 - from_param = classmethod(from_param) - - self.assertEqual(CVOIDP.from_param("abc"), "abcabc") - self.assertEqual(CCHARP.from_param("abc"), "abcabcabcabc") - - @need_symbol('c_wchar_p') - def test_subclasses_c_wchar_p(self): - from ctypes import 
c_wchar_p - - class CWCHARP(c_wchar_p): - def from_param(cls, value): - return value * 3 - from_param = classmethod(from_param) - - self.assertEqual(CWCHARP.from_param("abc"), "abcabcabc") - - # XXX Replace by c_char_p tests - def test_cstrings(self): - from ctypes import c_char_p - - # c_char_p.from_param on a Python String packs the string - # into a cparam object - s = b"123" - self.assertIs(c_char_p.from_param(s)._obj, s) - - # new in 0.9.1: convert (encode) unicode to ascii - self.assertEqual(c_char_p.from_param(b"123")._obj, b"123") - self.assertRaises(TypeError, c_char_p.from_param, "123\377") - self.assertRaises(TypeError, c_char_p.from_param, 42) - - # calling c_char_p.from_param with a c_char_p instance - # returns the argument itself: - a = c_char_p(b"123") - self.assertIs(c_char_p.from_param(a), a) - - @need_symbol('c_wchar_p') - def test_cw_strings(self): - from ctypes import c_wchar_p - - c_wchar_p.from_param("123") - - self.assertRaises(TypeError, c_wchar_p.from_param, 42) - self.assertRaises(TypeError, c_wchar_p.from_param, b"123\377") - - pa = c_wchar_p.from_param(c_wchar_p("123")) - self.assertEqual(type(pa), c_wchar_p) - - def test_int_pointers(self): - from ctypes import c_short, c_uint, c_int, c_long, POINTER, pointer - LPINT = POINTER(c_int) - -## p = pointer(c_int(42)) -## x = LPINT.from_param(p) - x = LPINT.from_param(pointer(c_int(42))) - self.assertEqual(x.contents.value, 42) - self.assertEqual(LPINT(c_int(42)).contents.value, 42) - - self.assertEqual(LPINT.from_param(None), None) - - if c_int != c_long: - self.assertRaises(TypeError, LPINT.from_param, pointer(c_long(42))) - self.assertRaises(TypeError, LPINT.from_param, pointer(c_uint(42))) - self.assertRaises(TypeError, LPINT.from_param, pointer(c_short(42))) - - def test_byref_pointer(self): - # The from_param class method of POINTER(typ) classes accepts what is - # returned by byref(obj), it type(obj) == typ - from ctypes import c_short, c_uint, c_int, c_long, POINTER, byref - LPINT = 
POINTER(c_int) - - LPINT.from_param(byref(c_int(42))) - - self.assertRaises(TypeError, LPINT.from_param, byref(c_short(22))) - if c_int != c_long: - self.assertRaises(TypeError, LPINT.from_param, byref(c_long(22))) - self.assertRaises(TypeError, LPINT.from_param, byref(c_uint(22))) - - def test_byref_pointerpointer(self): - # See above - from ctypes import c_short, c_uint, c_int, c_long, pointer, POINTER, byref - - LPLPINT = POINTER(POINTER(c_int)) - LPLPINT.from_param(byref(pointer(c_int(42)))) - - self.assertRaises(TypeError, LPLPINT.from_param, byref(pointer(c_short(22)))) - if c_int != c_long: - self.assertRaises(TypeError, LPLPINT.from_param, byref(pointer(c_long(22)))) - self.assertRaises(TypeError, LPLPINT.from_param, byref(pointer(c_uint(22)))) - - def test_array_pointers(self): - from ctypes import c_short, c_uint, c_int, c_long, POINTER - INTARRAY = c_int * 3 - ia = INTARRAY() - self.assertEqual(len(ia), 3) - self.assertEqual([ia[i] for i in range(3)], [0, 0, 0]) - - # Pointers are only compatible with arrays containing items of - # the same type! 
- LPINT = POINTER(c_int) - LPINT.from_param((c_int*3)()) - self.assertRaises(TypeError, LPINT.from_param, c_short*3) - self.assertRaises(TypeError, LPINT.from_param, c_long*3) - self.assertRaises(TypeError, LPINT.from_param, c_uint*3) - - def test_noctypes_argtype(self): - import _ctypes_test - from ctypes import CDLL, c_void_p, ArgumentError - - func = CDLL(_ctypes_test.__file__)._testfunc_p_p - func.restype = c_void_p - # TypeError: has no from_param method - self.assertRaises(TypeError, setattr, func, "argtypes", (object,)) - - class Adapter(object): - def from_param(cls, obj): - return None - - func.argtypes = (Adapter(),) - self.assertEqual(func(None), None) - self.assertEqual(func(object()), None) - - class Adapter(object): - def from_param(cls, obj): - return obj - - func.argtypes = (Adapter(),) - # don't know how to convert parameter 1 - self.assertRaises(ArgumentError, func, object()) - self.assertEqual(func(c_void_p(42)), 42) - - class Adapter(object): - def from_param(cls, obj): - raise ValueError(obj) - - func.argtypes = (Adapter(),) - # ArgumentError: argument 1: ValueError: 99 - self.assertRaises(ArgumentError, func, 99) - - def test_abstract(self): - from ctypes import (Array, Structure, Union, _Pointer, - _SimpleCData, _CFuncPtr) - - self.assertRaises(TypeError, Array.from_param, 42) - self.assertRaises(TypeError, Structure.from_param, 42) - self.assertRaises(TypeError, Union.from_param, 42) - self.assertRaises(TypeError, _CFuncPtr.from_param, 42) - self.assertRaises(TypeError, _Pointer.from_param, 42) - self.assertRaises(TypeError, _SimpleCData.from_param, 42) - - @test.support.cpython_only - def test_issue31311(self): - # __setstate__ should neither raise a SystemError nor crash in case - # of a bad __dict__. 
- from ctypes import Structure - - class BadStruct(Structure): - @property - def __dict__(self): - pass - with self.assertRaises(TypeError): - BadStruct().__setstate__({}, b'foo') - - class WorseStruct(Structure): - @property - def __dict__(self): - 1/0 - with self.assertRaises(ZeroDivisionError): - WorseStruct().__setstate__({}, b'foo') - - def test_parameter_repr(self): - from ctypes import ( - c_bool, - c_char, - c_wchar, - c_byte, - c_ubyte, - c_short, - c_ushort, - c_int, - c_uint, - c_long, - c_ulong, - c_longlong, - c_ulonglong, - c_float, - c_double, - c_longdouble, - c_char_p, - c_wchar_p, - c_void_p, - ) - self.assertRegex(repr(c_bool.from_param(True)), r"^$") - self.assertEqual(repr(c_char.from_param(97)), "") - self.assertRegex(repr(c_wchar.from_param('a')), r"^$") - self.assertEqual(repr(c_byte.from_param(98)), "") - self.assertEqual(repr(c_ubyte.from_param(98)), "") - self.assertEqual(repr(c_short.from_param(511)), "") - self.assertEqual(repr(c_ushort.from_param(511)), "") - self.assertRegex(repr(c_int.from_param(20000)), r"^$") - self.assertRegex(repr(c_uint.from_param(20000)), r"^$") - self.assertRegex(repr(c_long.from_param(20000)), r"^$") - self.assertRegex(repr(c_ulong.from_param(20000)), r"^$") - self.assertRegex(repr(c_longlong.from_param(20000)), r"^$") - self.assertRegex(repr(c_ulonglong.from_param(20000)), r"^$") - self.assertEqual(repr(c_float.from_param(1.5)), "") - self.assertEqual(repr(c_double.from_param(1.5)), "") - self.assertEqual(repr(c_double.from_param(1e300)), "") - self.assertRegex(repr(c_longdouble.from_param(1.5)), r"^$") - self.assertRegex(repr(c_char_p.from_param(b'hihi')), r"^$") - self.assertRegex(repr(c_wchar_p.from_param('hihi')), r"^$") - self.assertRegex(repr(c_void_p.from_param(0x12)), r"^$") - -################################################################ - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_pep3118.py b/Lib/ctypes/test/test_pep3118.py deleted file mode 100644 index 
efffc80a66f..00000000000 --- a/Lib/ctypes/test/test_pep3118.py +++ /dev/null @@ -1,235 +0,0 @@ -import unittest -from ctypes import * -import re, sys - -if sys.byteorder == "little": - THIS_ENDIAN = "<" - OTHER_ENDIAN = ">" -else: - THIS_ENDIAN = ">" - OTHER_ENDIAN = "<" - -def normalize(format): - # Remove current endian specifier and white space from a format - # string - if format is None: - return "" - format = format.replace(OTHER_ENDIAN, THIS_ENDIAN) - return re.sub(r"\s", "", format) - -class Test(unittest.TestCase): - - def test_native_types(self): - for tp, fmt, shape, itemtp in native_types: - ob = tp() - v = memoryview(ob) - try: - self.assertEqual(normalize(v.format), normalize(fmt)) - if shape: - self.assertEqual(len(v), shape[0]) - else: - self.assertEqual(len(v) * sizeof(itemtp), sizeof(ob)) - self.assertEqual(v.itemsize, sizeof(itemtp)) - self.assertEqual(v.shape, shape) - # XXX Issue #12851: PyCData_NewGetBuffer() must provide strides - # if requested. memoryview currently reconstructs missing - # stride information, so this assert will fail. 
- # self.assertEqual(v.strides, ()) - - # they are always read/write - self.assertFalse(v.readonly) - - if v.shape: - n = 1 - for dim in v.shape: - n = n * dim - self.assertEqual(n * v.itemsize, len(v.tobytes())) - except: - # so that we can see the failing type - print(tp) - raise - - def test_endian_types(self): - for tp, fmt, shape, itemtp in endian_types: - ob = tp() - v = memoryview(ob) - try: - self.assertEqual(v.format, fmt) - if shape: - self.assertEqual(len(v), shape[0]) - else: - self.assertEqual(len(v) * sizeof(itemtp), sizeof(ob)) - self.assertEqual(v.itemsize, sizeof(itemtp)) - self.assertEqual(v.shape, shape) - # XXX Issue #12851 - # self.assertEqual(v.strides, ()) - - # they are always read/write - self.assertFalse(v.readonly) - - if v.shape: - n = 1 - for dim in v.shape: - n = n * dim - self.assertEqual(n, len(v)) - except: - # so that we can see the failing type - print(tp) - raise - -# define some structure classes - -class Point(Structure): - _fields_ = [("x", c_long), ("y", c_long)] - -class PackedPoint(Structure): - _pack_ = 2 - _fields_ = [("x", c_long), ("y", c_long)] - -class Point2(Structure): - pass -Point2._fields_ = [("x", c_long), ("y", c_long)] - -class EmptyStruct(Structure): - _fields_ = [] - -class aUnion(Union): - _fields_ = [("a", c_int)] - -class StructWithArrays(Structure): - _fields_ = [("x", c_long * 3 * 2), ("y", Point * 4)] - -class Incomplete(Structure): - pass - -class Complete(Structure): - pass -PComplete = POINTER(Complete) -Complete._fields_ = [("a", c_long)] - -################################################################ -# -# This table contains format strings as they look on little endian -# machines. The test replaces '<' with '>' on big endian machines. 
-# - -# Platform-specific type codes -s_bool = {1: '?', 2: 'H', 4: 'L', 8: 'Q'}[sizeof(c_bool)] -s_short = {2: 'h', 4: 'l', 8: 'q'}[sizeof(c_short)] -s_ushort = {2: 'H', 4: 'L', 8: 'Q'}[sizeof(c_ushort)] -s_int = {2: 'h', 4: 'i', 8: 'q'}[sizeof(c_int)] -s_uint = {2: 'H', 4: 'I', 8: 'Q'}[sizeof(c_uint)] -s_long = {4: 'l', 8: 'q'}[sizeof(c_long)] -s_ulong = {4: 'L', 8: 'Q'}[sizeof(c_ulong)] -s_longlong = "q" -s_ulonglong = "Q" -s_float = "f" -s_double = "d" -s_longdouble = "g" - -# Alias definitions in ctypes/__init__.py -if c_int is c_long: - s_int = s_long -if c_uint is c_ulong: - s_uint = s_ulong -if c_longlong is c_long: - s_longlong = s_long -if c_ulonglong is c_ulong: - s_ulonglong = s_ulong -if c_longdouble is c_double: - s_longdouble = s_double - - -native_types = [ - # type format shape calc itemsize - - ## simple types - - (c_char, "l:x:>l:y:}".replace('l', s_long), (), BEPoint), - (LEPoint, "T{l:x:>l:y:}".replace('l', s_long), (), POINTER(BEPoint)), - (POINTER(LEPoint), "&T{= 0: - return a - # View the bits in `a` as unsigned instead. 
- import struct - num_bits = struct.calcsize("P") * 8 # num bits in native machine address - a += 1 << num_bits - assert a >= 0 - return a - -def c_wbuffer(init): - n = len(init) + 1 - return (c_wchar * n)(*init) - -class CharPointersTestCase(unittest.TestCase): - - def setUp(self): - func = testdll._testfunc_p_p - func.restype = c_long - func.argtypes = None - - def test_paramflags(self): - # function returns c_void_p result, - # and has a required parameter named 'input' - prototype = CFUNCTYPE(c_void_p, c_void_p) - func = prototype(("_testfunc_p_p", testdll), - ((1, "input"),)) - - try: - func() - except TypeError as details: - self.assertEqual(str(details), "required argument 'input' missing") - else: - self.fail("TypeError not raised") - - self.assertEqual(func(None), None) - self.assertEqual(func(input=None), None) - - - def test_int_pointer_arg(self): - func = testdll._testfunc_p_p - if sizeof(c_longlong) == sizeof(c_void_p): - func.restype = c_longlong - else: - func.restype = c_long - self.assertEqual(0, func(0)) - - ci = c_int(0) - - func.argtypes = POINTER(c_int), - self.assertEqual(positive_address(addressof(ci)), - positive_address(func(byref(ci)))) - - func.argtypes = c_char_p, - self.assertRaises(ArgumentError, func, byref(ci)) - - func.argtypes = POINTER(c_short), - self.assertRaises(ArgumentError, func, byref(ci)) - - func.argtypes = POINTER(c_double), - self.assertRaises(ArgumentError, func, byref(ci)) - - def test_POINTER_c_char_arg(self): - func = testdll._testfunc_p_p - func.restype = c_char_p - func.argtypes = POINTER(c_char), - - self.assertEqual(None, func(None)) - self.assertEqual(b"123", func(b"123")) - self.assertEqual(None, func(c_char_p(None))) - self.assertEqual(b"123", func(c_char_p(b"123"))) - - self.assertEqual(b"123", func(c_buffer(b"123"))) - ca = c_char(b"a") - self.assertEqual(ord(b"a"), func(pointer(ca))[0]) - self.assertEqual(ord(b"a"), func(byref(ca))[0]) - - def test_c_char_p_arg(self): - func = testdll._testfunc_p_p - 
func.restype = c_char_p - func.argtypes = c_char_p, - - self.assertEqual(None, func(None)) - self.assertEqual(b"123", func(b"123")) - self.assertEqual(None, func(c_char_p(None))) - self.assertEqual(b"123", func(c_char_p(b"123"))) - - self.assertEqual(b"123", func(c_buffer(b"123"))) - ca = c_char(b"a") - self.assertEqual(ord(b"a"), func(pointer(ca))[0]) - self.assertEqual(ord(b"a"), func(byref(ca))[0]) - - def test_c_void_p_arg(self): - func = testdll._testfunc_p_p - func.restype = c_char_p - func.argtypes = c_void_p, - - self.assertEqual(None, func(None)) - self.assertEqual(b"123", func(b"123")) - self.assertEqual(b"123", func(c_char_p(b"123"))) - self.assertEqual(None, func(c_char_p(None))) - - self.assertEqual(b"123", func(c_buffer(b"123"))) - ca = c_char(b"a") - self.assertEqual(ord(b"a"), func(pointer(ca))[0]) - self.assertEqual(ord(b"a"), func(byref(ca))[0]) - - func(byref(c_int())) - func(pointer(c_int())) - func((c_int * 3)()) - - @need_symbol('c_wchar_p') - def test_c_void_p_arg_with_c_wchar_p(self): - func = testdll._testfunc_p_p - func.restype = c_wchar_p - func.argtypes = c_void_p, - - self.assertEqual(None, func(c_wchar_p(None))) - self.assertEqual("123", func(c_wchar_p("123"))) - - def test_instance(self): - func = testdll._testfunc_p_p - func.restype = c_void_p - - class X: - _as_parameter_ = None - - func.argtypes = c_void_p, - self.assertEqual(None, func(X())) - - func.argtypes = None - self.assertEqual(None, func(X())) - -@need_symbol('c_wchar') -class WCharPointersTestCase(unittest.TestCase): - - def setUp(self): - func = testdll._testfunc_p_p - func.restype = c_int - func.argtypes = None - - - def test_POINTER_c_wchar_arg(self): - func = testdll._testfunc_p_p - func.restype = c_wchar_p - func.argtypes = POINTER(c_wchar), - - self.assertEqual(None, func(None)) - self.assertEqual("123", func("123")) - self.assertEqual(None, func(c_wchar_p(None))) - self.assertEqual("123", func(c_wchar_p("123"))) - - self.assertEqual("123", func(c_wbuffer("123"))) - 
ca = c_wchar("a") - self.assertEqual("a", func(pointer(ca))[0]) - self.assertEqual("a", func(byref(ca))[0]) - - def test_c_wchar_p_arg(self): - func = testdll._testfunc_p_p - func.restype = c_wchar_p - func.argtypes = c_wchar_p, - - c_wchar_p.from_param("123") - - self.assertEqual(None, func(None)) - self.assertEqual("123", func("123")) - self.assertEqual(None, func(c_wchar_p(None))) - self.assertEqual("123", func(c_wchar_p("123"))) - - # XXX Currently, these raise TypeErrors, although they shouldn't: - self.assertEqual("123", func(c_wbuffer("123"))) - ca = c_wchar("a") - self.assertEqual("a", func(pointer(ca))[0]) - self.assertEqual("a", func(byref(ca))[0]) - -class ArrayTest(unittest.TestCase): - def test(self): - func = testdll._testfunc_ai8 - func.restype = POINTER(c_int) - func.argtypes = c_int * 8, - - func((c_int * 8)(1, 2, 3, 4, 5, 6, 7, 8)) - - # This did crash before: - - def func(): pass - CFUNCTYPE(None, c_int * 3)(func) - -################################################################ - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_python_api.py b/Lib/ctypes/test/test_python_api.py deleted file mode 100644 index 49571f97bbe..00000000000 --- a/Lib/ctypes/test/test_python_api.py +++ /dev/null @@ -1,85 +0,0 @@ -from ctypes import * -import unittest -from test import support - -################################################################ -# This section should be moved into ctypes\__init__.py, when it's ready. 
- -from _ctypes import PyObj_FromPtr - -################################################################ - -from sys import getrefcount as grc - -class PythonAPITestCase(unittest.TestCase): - - def test_PyBytes_FromStringAndSize(self): - PyBytes_FromStringAndSize = pythonapi.PyBytes_FromStringAndSize - - PyBytes_FromStringAndSize.restype = py_object - PyBytes_FromStringAndSize.argtypes = c_char_p, c_size_t - - self.assertEqual(PyBytes_FromStringAndSize(b"abcdefghi", 3), b"abc") - - @support.refcount_test - def test_PyString_FromString(self): - pythonapi.PyBytes_FromString.restype = py_object - pythonapi.PyBytes_FromString.argtypes = (c_char_p,) - - s = b"abc" - refcnt = grc(s) - pyob = pythonapi.PyBytes_FromString(s) - self.assertEqual(grc(s), refcnt) - self.assertEqual(s, pyob) - del pyob - self.assertEqual(grc(s), refcnt) - - @support.refcount_test - def test_PyLong_Long(self): - ref42 = grc(42) - pythonapi.PyLong_FromLong.restype = py_object - self.assertEqual(pythonapi.PyLong_FromLong(42), 42) - - self.assertEqual(grc(42), ref42) - - pythonapi.PyLong_AsLong.argtypes = (py_object,) - pythonapi.PyLong_AsLong.restype = c_long - - res = pythonapi.PyLong_AsLong(42) - self.assertEqual(grc(res), ref42 + 1) - del res - self.assertEqual(grc(42), ref42) - - @support.refcount_test - def test_PyObj_FromPtr(self): - s = "abc def ghi jkl" - ref = grc(s) - # id(python-object) is the address - pyobj = PyObj_FromPtr(id(s)) - self.assertIs(s, pyobj) - - self.assertEqual(grc(s), ref + 1) - del pyobj - self.assertEqual(grc(s), ref) - - def test_PyOS_snprintf(self): - PyOS_snprintf = pythonapi.PyOS_snprintf - PyOS_snprintf.argtypes = POINTER(c_char), c_size_t, c_char_p - - buf = c_buffer(256) - PyOS_snprintf(buf, sizeof(buf), b"Hello from %s", b"ctypes") - self.assertEqual(buf.value, b"Hello from ctypes") - - PyOS_snprintf(buf, sizeof(buf), b"Hello from %s (%d, %d, %d)", b"ctypes", 1, 2, 3) - self.assertEqual(buf.value, b"Hello from ctypes (1, 2, 3)") - - # not enough arguments - 
self.assertRaises(TypeError, PyOS_snprintf, buf) - - def test_pyobject_repr(self): - self.assertEqual(repr(py_object()), "py_object()") - self.assertEqual(repr(py_object(42)), "py_object(42)") - self.assertEqual(repr(py_object(object)), "py_object(%r)" % object) - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/ctypes/test/test_random_things.py b/Lib/ctypes/test/test_random_things.py deleted file mode 100644 index 2988e275cf4..00000000000 --- a/Lib/ctypes/test/test_random_things.py +++ /dev/null @@ -1,77 +0,0 @@ -from ctypes import * -import contextlib -from test import support -import unittest -import sys - - -def callback_func(arg): - 42 / arg - raise ValueError(arg) - -@unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') -class call_function_TestCase(unittest.TestCase): - # _ctypes.call_function is deprecated and private, but used by - # Gary Bishp's readline module. If we have it, we must test it as well. - - def test(self): - from _ctypes import call_function - windll.kernel32.LoadLibraryA.restype = c_void_p - windll.kernel32.GetProcAddress.argtypes = c_void_p, c_char_p - windll.kernel32.GetProcAddress.restype = c_void_p - - hdll = windll.kernel32.LoadLibraryA(b"kernel32") - funcaddr = windll.kernel32.GetProcAddress(hdll, b"GetModuleHandleA") - - self.assertEqual(call_function(funcaddr, (None,)), - windll.kernel32.GetModuleHandleA(None)) - -class CallbackTracbackTestCase(unittest.TestCase): - # When an exception is raised in a ctypes callback function, the C - # code prints a traceback. - # - # This test makes sure the exception types *and* the exception - # value is printed correctly. - # - # Changed in 0.9.3: No longer is '(in callback)' prepended to the - # error message - instead an additional frame for the C code is - # created, then a full traceback printed. When SystemExit is - # raised in a callback function, the interpreter exits. 
- - @contextlib.contextmanager - def expect_unraisable(self, exc_type, exc_msg=None): - with support.catch_unraisable_exception() as cm: - yield - - self.assertIsInstance(cm.unraisable.exc_value, exc_type) - if exc_msg is not None: - self.assertEqual(str(cm.unraisable.exc_value), exc_msg) - self.assertEqual(cm.unraisable.err_msg, - "Exception ignored on calling ctypes " - "callback function") - self.assertIs(cm.unraisable.object, callback_func) - - def test_ValueError(self): - cb = CFUNCTYPE(c_int, c_int)(callback_func) - with self.expect_unraisable(ValueError, '42'): - cb(42) - - def test_IntegerDivisionError(self): - cb = CFUNCTYPE(c_int, c_int)(callback_func) - with self.expect_unraisable(ZeroDivisionError): - cb(0) - - def test_FloatDivisionError(self): - cb = CFUNCTYPE(c_int, c_double)(callback_func) - with self.expect_unraisable(ZeroDivisionError): - cb(0.0) - - def test_TypeErrorDivisionError(self): - cb = CFUNCTYPE(c_int, c_char_p)(callback_func) - err_msg = "unsupported operand type(s) for /: 'int' and 'bytes'" - with self.expect_unraisable(TypeError, err_msg): - cb(b"spam") - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_refcounts.py b/Lib/ctypes/test/test_refcounts.py deleted file mode 100644 index 48958cd2a60..00000000000 --- a/Lib/ctypes/test/test_refcounts.py +++ /dev/null @@ -1,116 +0,0 @@ -import unittest -from test import support -import ctypes -import gc - -MyCallback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int) -OtherCallback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulonglong) - -import _ctypes_test -dll = ctypes.CDLL(_ctypes_test.__file__) - -class RefcountTestCase(unittest.TestCase): - - @support.refcount_test - def test_1(self): - from sys import getrefcount as grc - - f = dll._testfunc_callback_i_if - f.restype = ctypes.c_int - f.argtypes = [ctypes.c_int, MyCallback] - - def callback(value): - #print "called back with", value - return value - - self.assertEqual(grc(callback), 2) - cb = 
MyCallback(callback) - - self.assertGreater(grc(callback), 2) - result = f(-10, cb) - self.assertEqual(result, -18) - cb = None - - gc.collect() - - self.assertEqual(grc(callback), 2) - - - @support.refcount_test - def test_refcount(self): - from sys import getrefcount as grc - def func(*args): - pass - # this is the standard refcount for func - self.assertEqual(grc(func), 2) - - # the CFuncPtr instance holds at least one refcount on func: - f = OtherCallback(func) - self.assertGreater(grc(func), 2) - - # and may release it again - del f - self.assertGreaterEqual(grc(func), 2) - - # but now it must be gone - gc.collect() - self.assertEqual(grc(func), 2) - - class X(ctypes.Structure): - _fields_ = [("a", OtherCallback)] - x = X() - x.a = OtherCallback(func) - - # the CFuncPtr instance holds at least one refcount on func: - self.assertGreater(grc(func), 2) - - # and may release it again - del x - self.assertGreaterEqual(grc(func), 2) - - # and now it must be gone again - gc.collect() - self.assertEqual(grc(func), 2) - - f = OtherCallback(func) - - # the CFuncPtr instance holds at least one refcount on func: - self.assertGreater(grc(func), 2) - - # create a cycle - f.cycle = f - - del f - gc.collect() - self.assertEqual(grc(func), 2) - -class AnotherLeak(unittest.TestCase): - def test_callback(self): - import sys - - proto = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_int) - def func(a, b): - return a * b * 2 - f = proto(func) - - a = sys.getrefcount(ctypes.c_int) - f(1, 2) - self.assertEqual(sys.getrefcount(ctypes.c_int), a) - - @support.refcount_test - def test_callback_py_object_none_return(self): - # bpo-36880: test that returning None from a py_object callback - # does not decrement the refcount of None. - - for FUNCTYPE in (ctypes.CFUNCTYPE, ctypes.PYFUNCTYPE): - with self.subTest(FUNCTYPE=FUNCTYPE): - @FUNCTYPE(ctypes.py_object) - def func(): - return None - - # Check that calling func does not affect None's refcount. 
- for _ in range(10000): - func() - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/ctypes/test/test_repr.py b/Lib/ctypes/test/test_repr.py deleted file mode 100644 index 60a2c803453..00000000000 --- a/Lib/ctypes/test/test_repr.py +++ /dev/null @@ -1,29 +0,0 @@ -from ctypes import * -import unittest - -subclasses = [] -for base in [c_byte, c_short, c_int, c_long, c_longlong, - c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong, - c_float, c_double, c_longdouble, c_bool]: - class X(base): - pass - subclasses.append(X) - -class X(c_char): - pass - -# This test checks if the __repr__ is correct for subclasses of simple types - -class ReprTest(unittest.TestCase): - def test_numbers(self): - for typ in subclasses: - base = typ.__bases__[0] - self.assertTrue(repr(base(42)).startswith(base.__name__)) - self.assertEqual(" Date: Sat, 20 Dec 2025 14:32:07 +0900 Subject: [PATCH 037/418] test_ctypes from CPython 3.13.11 --- Lib/test/test_ctypes.py | 10 - Lib/test/test_ctypes/__init__.py | 10 + Lib/test/test_ctypes/__main__.py | 4 + Lib/test/test_ctypes/_support.py | 24 + .../test_ctypes/test_aligned_structures.py | 321 ++++++ Lib/test/test_ctypes/test_anon.py | 75 ++ Lib/test/test_ctypes/test_array_in_pointer.py | 68 ++ Lib/test/test_ctypes/test_arrays.py | 272 +++++ Lib/test/test_ctypes/test_as_parameter.py | 245 +++++ Lib/test/test_ctypes/test_bitfields.py | 294 ++++++ Lib/test/test_ctypes/test_buffers.py | 70 ++ Lib/test/test_ctypes/test_bytes.py | 67 ++ Lib/test/test_ctypes/test_byteswap.py | 386 +++++++ .../test_ctypes/test_c_simple_type_meta.py | 152 +++ Lib/test/test_ctypes/test_callbacks.py | 333 ++++++ Lib/test/test_ctypes/test_cast.py | 100 ++ Lib/test/test_ctypes/test_cfuncs.py | 211 ++++ Lib/test/test_ctypes/test_checkretval.py | 37 + Lib/test/test_ctypes/test_delattr.py | 26 + Lib/test/test_ctypes/test_dlerror.py | 179 ++++ Lib/test/test_ctypes/test_errno.py | 81 ++ Lib/test/test_ctypes/test_find.py | 156 +++ Lib/test/test_ctypes/test_frombuffer.py | 
144 +++ Lib/test/test_ctypes/test_funcptr.py | 134 +++ Lib/test/test_ctypes/test_functions.py | 454 +++++++++ Lib/test/test_ctypes/test_incomplete.py | 50 + Lib/test/test_ctypes/test_init.py | 43 + Lib/test/test_ctypes/test_internals.py | 97 ++ Lib/test/test_ctypes/test_keeprefs.py | 124 +++ Lib/test/test_ctypes/test_libc.py | 38 + Lib/test/test_ctypes/test_loading.py | 208 ++++ Lib/test/test_ctypes/test_macholib.py | 112 ++ Lib/test/test_ctypes/test_memfunctions.py | 82 ++ Lib/test/test_ctypes/test_numbers.py | 200 ++++ Lib/test/test_ctypes/test_objects.py | 66 ++ Lib/test/test_ctypes/test_parameters.py | 288 ++++++ Lib/test/test_ctypes/test_pep3118.py | 250 +++++ Lib/test/test_ctypes/test_pickling.py | 91 ++ Lib/test/test_ctypes/test_pointers.py | 240 +++++ Lib/test/test_ctypes/test_prototypes.py | 252 +++++ Lib/test/test_ctypes/test_python_api.py | 81 ++ Lib/test/test_ctypes/test_random_things.py | 81 ++ Lib/test/test_ctypes/test_refcounts.py | 143 +++ Lib/test/test_ctypes/test_repr.py | 34 + Lib/test/test_ctypes/test_returnfuncptrs.py | 67 ++ Lib/test/test_ctypes/test_simplesubclasses.py | 96 ++ Lib/test/test_ctypes/test_sizes.py | 38 + Lib/test/test_ctypes/test_slicing.py | 170 ++++ Lib/test/test_ctypes/test_stringptr.py | 81 ++ Lib/test/test_ctypes/test_strings.py | 134 +++ Lib/test/test_ctypes/test_struct_fields.py | 126 +++ Lib/test/test_ctypes/test_structures.py | 956 ++++++++++++++++++ .../test_ctypes/test_unaligned_structures.py | 49 + Lib/test/test_ctypes/test_unicode.py | 63 ++ Lib/test/test_ctypes/test_unions.py | 35 + Lib/test/test_ctypes/test_values.py | 107 ++ Lib/test/test_ctypes/test_varsize_struct.py | 52 + Lib/test/test_ctypes/test_win32.py | 152 +++ .../test_win32_com_foreign_func.py | 286 ++++++ Lib/test/test_ctypes/test_wintypes.py | 61 ++ 60 files changed, 8796 insertions(+), 10 deletions(-) delete mode 100644 Lib/test/test_ctypes.py create mode 100644 Lib/test/test_ctypes/__init__.py create mode 100644 Lib/test/test_ctypes/__main__.py 
create mode 100644 Lib/test/test_ctypes/_support.py create mode 100644 Lib/test/test_ctypes/test_aligned_structures.py create mode 100644 Lib/test/test_ctypes/test_anon.py create mode 100644 Lib/test/test_ctypes/test_array_in_pointer.py create mode 100644 Lib/test/test_ctypes/test_arrays.py create mode 100644 Lib/test/test_ctypes/test_as_parameter.py create mode 100644 Lib/test/test_ctypes/test_bitfields.py create mode 100644 Lib/test/test_ctypes/test_buffers.py create mode 100644 Lib/test/test_ctypes/test_bytes.py create mode 100644 Lib/test/test_ctypes/test_byteswap.py create mode 100644 Lib/test/test_ctypes/test_c_simple_type_meta.py create mode 100644 Lib/test/test_ctypes/test_callbacks.py create mode 100644 Lib/test/test_ctypes/test_cast.py create mode 100644 Lib/test/test_ctypes/test_cfuncs.py create mode 100644 Lib/test/test_ctypes/test_checkretval.py create mode 100644 Lib/test/test_ctypes/test_delattr.py create mode 100644 Lib/test/test_ctypes/test_dlerror.py create mode 100644 Lib/test/test_ctypes/test_errno.py create mode 100644 Lib/test/test_ctypes/test_find.py create mode 100644 Lib/test/test_ctypes/test_frombuffer.py create mode 100644 Lib/test/test_ctypes/test_funcptr.py create mode 100644 Lib/test/test_ctypes/test_functions.py create mode 100644 Lib/test/test_ctypes/test_incomplete.py create mode 100644 Lib/test/test_ctypes/test_init.py create mode 100644 Lib/test/test_ctypes/test_internals.py create mode 100644 Lib/test/test_ctypes/test_keeprefs.py create mode 100644 Lib/test/test_ctypes/test_libc.py create mode 100644 Lib/test/test_ctypes/test_loading.py create mode 100644 Lib/test/test_ctypes/test_macholib.py create mode 100644 Lib/test/test_ctypes/test_memfunctions.py create mode 100644 Lib/test/test_ctypes/test_numbers.py create mode 100644 Lib/test/test_ctypes/test_objects.py create mode 100644 Lib/test/test_ctypes/test_parameters.py create mode 100644 Lib/test/test_ctypes/test_pep3118.py create mode 100644 
Lib/test/test_ctypes/test_pickling.py create mode 100644 Lib/test/test_ctypes/test_pointers.py create mode 100644 Lib/test/test_ctypes/test_prototypes.py create mode 100644 Lib/test/test_ctypes/test_python_api.py create mode 100644 Lib/test/test_ctypes/test_random_things.py create mode 100644 Lib/test/test_ctypes/test_refcounts.py create mode 100644 Lib/test/test_ctypes/test_repr.py create mode 100644 Lib/test/test_ctypes/test_returnfuncptrs.py create mode 100644 Lib/test/test_ctypes/test_simplesubclasses.py create mode 100644 Lib/test/test_ctypes/test_sizes.py create mode 100644 Lib/test/test_ctypes/test_slicing.py create mode 100644 Lib/test/test_ctypes/test_stringptr.py create mode 100644 Lib/test/test_ctypes/test_strings.py create mode 100644 Lib/test/test_ctypes/test_struct_fields.py create mode 100644 Lib/test/test_ctypes/test_structures.py create mode 100644 Lib/test/test_ctypes/test_unaligned_structures.py create mode 100644 Lib/test/test_ctypes/test_unicode.py create mode 100644 Lib/test/test_ctypes/test_unions.py create mode 100644 Lib/test/test_ctypes/test_values.py create mode 100644 Lib/test/test_ctypes/test_varsize_struct.py create mode 100644 Lib/test/test_ctypes/test_win32.py create mode 100644 Lib/test/test_ctypes/test_win32_com_foreign_func.py create mode 100644 Lib/test/test_ctypes/test_wintypes.py diff --git a/Lib/test/test_ctypes.py b/Lib/test/test_ctypes.py deleted file mode 100644 index b0a12c97347..00000000000 --- a/Lib/test/test_ctypes.py +++ /dev/null @@ -1,10 +0,0 @@ -import unittest -from test.support.import_helper import import_module - - -ctypes_test = import_module('ctypes.test') - -load_tests = ctypes_test.load_tests - -if __name__ == "__main__": - unittest.main() diff --git a/Lib/test/test_ctypes/__init__.py b/Lib/test/test_ctypes/__init__.py new file mode 100644 index 00000000000..eb9126cbe18 --- /dev/null +++ b/Lib/test/test_ctypes/__init__.py @@ -0,0 +1,10 @@ +import os +from test import support +from test.support import 
import_helper + + +# skip tests if the _ctypes extension was not built +import_helper.import_module('ctypes') + +def load_tests(*args): + return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_ctypes/__main__.py b/Lib/test/test_ctypes/__main__.py new file mode 100644 index 00000000000..3003d4db890 --- /dev/null +++ b/Lib/test/test_ctypes/__main__.py @@ -0,0 +1,4 @@ +from test.test_ctypes import load_tests +import unittest + +unittest.main() diff --git a/Lib/test/test_ctypes/_support.py b/Lib/test/test_ctypes/_support.py new file mode 100644 index 00000000000..e4c2b33825a --- /dev/null +++ b/Lib/test/test_ctypes/_support.py @@ -0,0 +1,24 @@ +# Some classes and types are not export to _ctypes module directly. + +import ctypes +from _ctypes import Structure, Union, _Pointer, Array, _SimpleCData, CFuncPtr + + +_CData = Structure.__base__ +assert _CData.__name__ == "_CData" + +class _X(Structure): + _fields_ = [("x", ctypes.c_int)] +CField = type(_X.x) + +# metaclasses +PyCStructType = type(Structure) +UnionType = type(Union) +PyCPointerType = type(_Pointer) +PyCArrayType = type(Array) +PyCSimpleType = type(_SimpleCData) +PyCFuncPtrType = type(CFuncPtr) + +# type flags +Py_TPFLAGS_DISALLOW_INSTANTIATION = 1 << 7 +Py_TPFLAGS_IMMUTABLETYPE = 1 << 8 diff --git a/Lib/test/test_ctypes/test_aligned_structures.py b/Lib/test/test_ctypes/test_aligned_structures.py new file mode 100644 index 00000000000..8e8ac429900 --- /dev/null +++ b/Lib/test/test_ctypes/test_aligned_structures.py @@ -0,0 +1,321 @@ +from ctypes import ( + c_char, c_uint32, c_uint16, c_ubyte, c_byte, alignment, sizeof, + BigEndianStructure, LittleEndianStructure, + BigEndianUnion, LittleEndianUnion, Structure +) +import struct +import unittest + + +class TestAlignedStructures(unittest.TestCase): + def test_aligned_string(self): + for base, e in ( + (LittleEndianStructure, "<"), + (BigEndianStructure, ">"), + ): + data = bytearray(struct.pack(f"{e}i12x16s", 7, b"hello 
world!")) + class Aligned(base): + _align_ = 16 + _fields_ = [ + ('value', c_char * 12) + ] + + class Main(base): + _fields_ = [ + ('first', c_uint32), + ('string', Aligned), + ] + + main = Main.from_buffer(data) + self.assertEqual(main.first, 7) + self.assertEqual(main.string.value, b'hello world!') + self.assertEqual(bytes(main.string), b'hello world!\0\0\0\0') + self.assertEqual(Main.string.offset, 16) + self.assertEqual(Main.string.size, 16) + self.assertEqual(alignment(main.string), 16) + self.assertEqual(alignment(main), 16) + + def test_aligned_structures(self): + for base, data in ( + (LittleEndianStructure, bytearray(b"\1\0\0\0\1\0\0\0\7\0\0\0")), + (BigEndianStructure, bytearray(b"\1\0\0\0\1\0\0\0\7\0\0\0")), + ): + class SomeBools(base): + _align_ = 4 + _fields_ = [ + ("bool1", c_ubyte), + ("bool2", c_ubyte), + ] + class Main(base): + _fields_ = [ + ("x", c_ubyte), + ("y", SomeBools), + ("z", c_ubyte), + ] + + main = Main.from_buffer(data) + self.assertEqual(alignment(SomeBools), 4) + self.assertEqual(alignment(main), 4) + self.assertEqual(alignment(main.y), 4) + self.assertEqual(Main.x.size, 1) + self.assertEqual(Main.y.offset, 4) + self.assertEqual(Main.y.size, 4) + self.assertEqual(main.y.bool1, True) + self.assertEqual(main.y.bool2, False) + self.assertEqual(Main.z.offset, 8) + self.assertEqual(main.z, 7) + + def test_oversized_structure(self): + data = bytearray(b"\0" * 8) + for base in (LittleEndianStructure, BigEndianStructure): + class SomeBoolsTooBig(base): + _align_ = 8 + _fields_ = [ + ("bool1", c_ubyte), + ("bool2", c_ubyte), + ("bool3", c_ubyte), + ] + class Main(base): + _fields_ = [ + ("y", SomeBoolsTooBig), + ("z", c_uint32), + ] + with self.assertRaises(ValueError) as ctx: + Main.from_buffer(data) + self.assertEqual( + ctx.exception.args[0], + 'Buffer size too small (4 instead of at least 8 bytes)' + ) + + def test_aligned_subclasses(self): + for base, e in ( + (LittleEndianStructure, "<"), + (BigEndianStructure, ">"), + ): + data = 
bytearray(struct.pack(f"{e}4i", 1, 2, 3, 4)) + class UnalignedSub(base): + x: c_uint32 + _fields_ = [ + ("x", c_uint32), + ] + + class AlignedStruct(UnalignedSub): + _align_ = 8 + _fields_ = [ + ("y", c_uint32), + ] + + class Main(base): + _fields_ = [ + ("a", c_uint32), + ("b", AlignedStruct) + ] + + main = Main.from_buffer(data) + self.assertEqual(alignment(main.b), 8) + self.assertEqual(alignment(main), 8) + self.assertEqual(sizeof(main.b), 8) + self.assertEqual(sizeof(main), 16) + self.assertEqual(main.a, 1) + self.assertEqual(main.b.x, 3) + self.assertEqual(main.b.y, 4) + self.assertEqual(Main.b.offset, 8) + self.assertEqual(Main.b.size, 8) + + def test_aligned_union(self): + for sbase, ubase, e in ( + (LittleEndianStructure, LittleEndianUnion, "<"), + (BigEndianStructure, BigEndianUnion, ">"), + ): + data = bytearray(struct.pack(f"{e}4i", 1, 2, 3, 4)) + class AlignedUnion(ubase): + _align_ = 8 + _fields_ = [ + ("a", c_uint32), + ("b", c_ubyte * 7), + ] + + class Main(sbase): + _fields_ = [ + ("first", c_uint32), + ("union", AlignedUnion), + ] + + main = Main.from_buffer(data) + self.assertEqual(main.first, 1) + self.assertEqual(main.union.a, 3) + self.assertEqual(bytes(main.union.b), data[8:-1]) + self.assertEqual(Main.union.offset, 8) + self.assertEqual(Main.union.size, 8) + self.assertEqual(alignment(main.union), 8) + self.assertEqual(alignment(main), 8) + + def test_aligned_struct_in_union(self): + for sbase, ubase, e in ( + (LittleEndianStructure, LittleEndianUnion, "<"), + (BigEndianStructure, BigEndianUnion, ">"), + ): + data = bytearray(struct.pack(f"{e}4i", 1, 2, 3, 4)) + class Sub(sbase): + _align_ = 8 + _fields_ = [ + ("x", c_uint32), + ("y", c_uint32), + ] + + class MainUnion(ubase): + _fields_ = [ + ("a", c_uint32), + ("b", Sub), + ] + + class Main(sbase): + _fields_ = [ + ("first", c_uint32), + ("union", MainUnion), + ] + + main = Main.from_buffer(data) + self.assertEqual(Main.first.size, 4) + self.assertEqual(alignment(main.union), 8) + 
self.assertEqual(alignment(main), 8) + self.assertEqual(Main.union.offset, 8) + self.assertEqual(Main.union.size, 8) + self.assertEqual(main.first, 1) + self.assertEqual(main.union.a, 3) + self.assertEqual(main.union.b.x, 3) + self.assertEqual(main.union.b.y, 4) + + def test_smaller_aligned_subclassed_union(self): + for sbase, ubase, e in ( + (LittleEndianStructure, LittleEndianUnion, "<"), + (BigEndianStructure, BigEndianUnion, ">"), + ): + data = bytearray(struct.pack(f"{e}H2xI", 1, 0xD60102D7)) + class SubUnion(ubase): + _align_ = 2 + _fields_ = [ + ("unsigned", c_ubyte), + ("signed", c_byte), + ] + + class MainUnion(SubUnion): + _fields_ = [ + ("num", c_uint32) + ] + + class Main(sbase): + _fields_ = [ + ("first", c_uint16), + ("union", MainUnion), + ] + + main = Main.from_buffer(data) + self.assertEqual(main.union.num, 0xD60102D7) + self.assertEqual(main.union.unsigned, data[4]) + self.assertEqual(main.union.signed, data[4] - 256) + self.assertEqual(alignment(main), 4) + self.assertEqual(alignment(main.union), 4) + self.assertEqual(Main.union.offset, 4) + self.assertEqual(Main.union.size, 4) + self.assertEqual(Main.first.size, 2) + + def test_larger_aligned_subclassed_union(self): + for ubase, e in ( + (LittleEndianUnion, "<"), + (BigEndianUnion, ">"), + ): + data = bytearray(struct.pack(f"{e}I4x", 0xD60102D6)) + class SubUnion(ubase): + _align_ = 8 + _fields_ = [ + ("unsigned", c_ubyte), + ("signed", c_byte), + ] + + class Main(SubUnion): + _fields_ = [ + ("num", c_uint32) + ] + + main = Main.from_buffer(data) + self.assertEqual(alignment(main), 8) + self.assertEqual(sizeof(main), 8) + self.assertEqual(main.num, 0xD60102D6) + self.assertEqual(main.unsigned, 0xD6) + self.assertEqual(main.signed, -42) + + def test_aligned_packed_structures(self): + for sbase, e in ( + (LittleEndianStructure, "<"), + (BigEndianStructure, ">"), + ): + data = bytearray(struct.pack(f"{e}B2H4xB", 1, 2, 3, 4)) + + class Inner(sbase): + _align_ = 8 + _fields_ = [ + ("x", c_uint16), + 
("y", c_uint16), + ] + + class Main(sbase): + _pack_ = 1 + _fields_ = [ + ("a", c_ubyte), + ("b", Inner), + ("c", c_ubyte), + ] + + main = Main.from_buffer(data) + self.assertEqual(sizeof(main), 10) + self.assertEqual(Main.b.offset, 1) + # Alignment == 8 because _pack_ wins out. + self.assertEqual(alignment(main.b), 8) + # Size is still 8 though since inside this Structure, it will have + # effect. + self.assertEqual(sizeof(main.b), 8) + self.assertEqual(Main.c.offset, 9) + self.assertEqual(main.a, 1) + self.assertEqual(main.b.x, 2) + self.assertEqual(main.b.y, 3) + self.assertEqual(main.c, 4) + + def test_negative_align(self): + for base in (Structure, LittleEndianStructure, BigEndianStructure): + with ( + self.subTest(base=base), + self.assertRaisesRegex( + ValueError, + '_align_ must be a non-negative integer', + ) + ): + class MyStructure(base): + _align_ = -1 + _fields_ = [] + + def test_zero_align_no_fields(self): + for base in (Structure, LittleEndianStructure, BigEndianStructure): + with self.subTest(base=base): + class MyStructure(base): + _align_ = 0 + _fields_ = [] + + self.assertEqual(alignment(MyStructure), 1) + self.assertEqual(alignment(MyStructure()), 1) + + def test_zero_align_with_fields(self): + for base in (Structure, LittleEndianStructure, BigEndianStructure): + with self.subTest(base=base): + class MyStructure(base): + _align_ = 0 + _fields_ = [ + ("x", c_ubyte), + ] + + self.assertEqual(alignment(MyStructure), 1) + self.assertEqual(alignment(MyStructure()), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_anon.py b/Lib/test/test_ctypes/test_anon.py new file mode 100644 index 00000000000..b36397b510f --- /dev/null +++ b/Lib/test/test_ctypes/test_anon.py @@ -0,0 +1,75 @@ +import unittest +import test.support +from ctypes import c_int, Union, Structure, sizeof + + +class AnonTest(unittest.TestCase): + + def test_anon(self): + class ANON(Union): + _fields_ = [("a", c_int), + ("b", c_int)] + + class 
Y(Structure): + _fields_ = [("x", c_int), + ("_", ANON), + ("y", c_int)] + _anonymous_ = ["_"] + + self.assertEqual(Y.a.offset, sizeof(c_int)) + self.assertEqual(Y.b.offset, sizeof(c_int)) + + self.assertEqual(ANON.a.offset, 0) + self.assertEqual(ANON.b.offset, 0) + + def test_anon_nonseq(self): + # TypeError: _anonymous_ must be a sequence + self.assertRaises(TypeError, + lambda: type(Structure)("Name", + (Structure,), + {"_fields_": [], "_anonymous_": 42})) + + def test_anon_nonmember(self): + # AttributeError: type object 'Name' has no attribute 'x' + self.assertRaises(AttributeError, + lambda: type(Structure)("Name", + (Structure,), + {"_fields_": [], + "_anonymous_": ["x"]})) + + @test.support.cpython_only + def test_issue31490(self): + # There shouldn't be an assertion failure in case the class has an + # attribute whose name is specified in _anonymous_ but not in _fields_. + + # AttributeError: 'x' is specified in _anonymous_ but not in _fields_ + with self.assertRaises(AttributeError): + class Name(Structure): + _fields_ = [] + _anonymous_ = ["x"] + x = 42 + + def test_nested(self): + class ANON_S(Structure): + _fields_ = [("a", c_int)] + + class ANON_U(Union): + _fields_ = [("_", ANON_S), + ("b", c_int)] + _anonymous_ = ["_"] + + class Y(Structure): + _fields_ = [("x", c_int), + ("_", ANON_U), + ("y", c_int)] + _anonymous_ = ["_"] + + self.assertEqual(Y.x.offset, 0) + self.assertEqual(Y.a.offset, sizeof(c_int)) + self.assertEqual(Y.b.offset, sizeof(c_int)) + self.assertEqual(Y._.offset, sizeof(c_int)) + self.assertEqual(Y.y.offset, sizeof(c_int) * 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_array_in_pointer.py b/Lib/test/test_ctypes/test_array_in_pointer.py new file mode 100644 index 00000000000..b7c96b2fa49 --- /dev/null +++ b/Lib/test/test_ctypes/test_array_in_pointer.py @@ -0,0 +1,68 @@ +import binascii +import re +import unittest +from ctypes import c_byte, Structure, POINTER, cast + + +def dump(obj): + # 
helper function to dump memory contents in hex, with a hyphen + # between the bytes. + h = binascii.hexlify(memoryview(obj)).decode() + return re.sub(r"(..)", r"\1-", h)[:-1] + + +class Value(Structure): + _fields_ = [("val", c_byte)] + + +class Container(Structure): + _fields_ = [("pvalues", POINTER(Value))] + + +class Test(unittest.TestCase): + def test(self): + # create an array of 4 values + val_array = (Value * 4)() + + # create a container, which holds a pointer to the pvalues array. + c = Container() + c.pvalues = val_array + + # memory contains 4 NUL bytes now, that's correct + self.assertEqual("00-00-00-00", dump(val_array)) + + # set the values of the array through the pointer: + for i in range(4): + c.pvalues[i].val = i + 1 + + values = [c.pvalues[i].val for i in range(4)] + + # These are the expected results: here s the bug! + self.assertEqual( + (values, dump(val_array)), + ([1, 2, 3, 4], "01-02-03-04") + ) + + def test_2(self): + + val_array = (Value * 4)() + + # memory contains 4 NUL bytes now, that's correct + self.assertEqual("00-00-00-00", dump(val_array)) + + ptr = cast(val_array, POINTER(Value)) + # set the values of the array through the pointer: + for i in range(4): + ptr[i].val = i + 1 + + values = [ptr[i].val for i in range(4)] + + # These are the expected results: here s the bug! 
+ self.assertEqual( + (values, dump(val_array)), + ([1, 2, 3, 4], "01-02-03-04") + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_arrays.py b/Lib/test/test_ctypes/test_arrays.py new file mode 100644 index 00000000000..c80fdff5de6 --- /dev/null +++ b/Lib/test/test_ctypes/test_arrays.py @@ -0,0 +1,272 @@ +import ctypes +import sys +import unittest +from ctypes import (Structure, Array, ARRAY, sizeof, addressof, + create_string_buffer, create_unicode_buffer, + c_char, c_wchar, c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, + c_long, c_ulonglong, c_float, c_double, c_longdouble) +from test.support import bigmemtest, _2G +from ._support import (_CData, PyCArrayType, Py_TPFLAGS_DISALLOW_INSTANTIATION, + Py_TPFLAGS_IMMUTABLETYPE) + + +formats = "bBhHiIlLqQfd" + +formats = c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, \ + c_long, c_ulonglong, c_float, c_double, c_longdouble + + +class ArrayTestCase(unittest.TestCase): + def test_inheritance_hierarchy(self): + self.assertEqual(Array.mro(), [Array, _CData, object]) + + self.assertEqual(PyCArrayType.__name__, "PyCArrayType") + self.assertEqual(type(PyCArrayType), type) + + def test_type_flags(self): + for cls in Array, PyCArrayType: + with self.subTest(cls=cls): + self.assertTrue(cls.__flags__ & Py_TPFLAGS_IMMUTABLETYPE) + self.assertFalse(cls.__flags__ & Py_TPFLAGS_DISALLOW_INSTANTIATION) + + def test_metaclass_details(self): + # Abstract classes (whose metaclass __init__ was not called) can't be + # instantiated directly + NewArray = PyCArrayType.__new__(PyCArrayType, 'NewArray', (Array,), {}) + for cls in Array, NewArray: + with self.subTest(cls=cls): + with self.assertRaisesRegex(TypeError, "abstract class"): + obj = cls() + + # Cannot call the metaclass __init__ more than once + class T(Array): + _type_ = c_int + _length_ = 13 + with self.assertRaisesRegex(SystemError, "already initialized"): + PyCArrayType.__init__(T, 'ptr', (), {}) + + def test_simple(self): + # 
create classes holding simple numeric types, and check + # various properties. + + init = list(range(15, 25)) + + for fmt in formats: + alen = len(init) + int_array = ARRAY(fmt, alen) + + ia = int_array(*init) + # length of instance ok? + self.assertEqual(len(ia), alen) + + # slot values ok? + values = [ia[i] for i in range(alen)] + self.assertEqual(values, init) + + # out-of-bounds accesses should be caught + with self.assertRaises(IndexError): ia[alen] + with self.assertRaises(IndexError): ia[-alen-1] + + # change the items + new_values = list(range(42, 42+alen)) + for n in range(alen): + ia[n] = new_values[n] + values = [ia[i] for i in range(alen)] + self.assertEqual(values, new_values) + + # are the items initialized to 0? + ia = int_array() + values = [ia[i] for i in range(alen)] + self.assertEqual(values, [0] * alen) + + # Too many initializers should be caught + self.assertRaises(IndexError, int_array, *range(alen*2)) + + CharArray = ARRAY(c_char, 3) + + ca = CharArray(b"a", b"b", b"c") + + # Should this work? 
It doesn't: + # CharArray("abc") + self.assertRaises(TypeError, CharArray, "abc") + + self.assertEqual(ca[0], b"a") + self.assertEqual(ca[1], b"b") + self.assertEqual(ca[2], b"c") + self.assertEqual(ca[-3], b"a") + self.assertEqual(ca[-2], b"b") + self.assertEqual(ca[-1], b"c") + + self.assertEqual(len(ca), 3) + + # cannot delete items + with self.assertRaises(TypeError): + del ca[0] + + def test_step_overflow(self): + a = (c_int * 5)() + a[3::sys.maxsize] = (1,) + self.assertListEqual(a[3::sys.maxsize], [1]) + a = (c_char * 5)() + a[3::sys.maxsize] = b"A" + self.assertEqual(a[3::sys.maxsize], b"A") + a = (c_wchar * 5)() + a[3::sys.maxsize] = u"X" + self.assertEqual(a[3::sys.maxsize], u"X") + + def test_numeric_arrays(self): + + alen = 5 + + numarray = ARRAY(c_int, alen) + + na = numarray() + values = [na[i] for i in range(alen)] + self.assertEqual(values, [0] * alen) + + na = numarray(*[c_int()] * alen) + values = [na[i] for i in range(alen)] + self.assertEqual(values, [0]*alen) + + na = numarray(1, 2, 3, 4, 5) + values = [i for i in na] + self.assertEqual(values, [1, 2, 3, 4, 5]) + + na = numarray(*map(c_int, (1, 2, 3, 4, 5))) + values = [i for i in na] + self.assertEqual(values, [1, 2, 3, 4, 5]) + + def test_classcache(self): + self.assertIsNot(ARRAY(c_int, 3), ARRAY(c_int, 4)) + self.assertIs(ARRAY(c_int, 3), ARRAY(c_int, 3)) + + def test_from_address(self): + # Failed with 0.9.8, reported by JUrner + p = create_string_buffer(b"foo") + sz = (c_char * 3).from_address(addressof(p)) + self.assertEqual(sz[:], b"foo") + self.assertEqual(sz[::], b"foo") + self.assertEqual(sz[::-1], b"oof") + self.assertEqual(sz[::3], b"f") + self.assertEqual(sz[1:4:2], b"o") + self.assertEqual(sz.value, b"foo") + + def test_from_addressW(self): + p = create_unicode_buffer("foo") + sz = (c_wchar * 3).from_address(addressof(p)) + self.assertEqual(sz[:], "foo") + self.assertEqual(sz[::], "foo") + self.assertEqual(sz[::-1], "oof") + self.assertEqual(sz[::3], "f") + 
self.assertEqual(sz[1:4:2], "o") + self.assertEqual(sz.value, "foo") + + def test_cache(self): + # Array types are cached internally in the _ctypes extension, + # in a WeakValueDictionary. Make sure the array type is + # removed from the cache when the itemtype goes away. This + # test will not fail, but will show a leak in the testsuite. + + # Create a new type: + class my_int(c_int): + pass + # Create a new array type based on it: + t1 = my_int * 1 + t2 = my_int * 1 + self.assertIs(t1, t2) + + def test_subclass(self): + class T(Array): + _type_ = c_int + _length_ = 13 + class U(T): + pass + class V(U): + pass + class W(V): + pass + class X(T): + _type_ = c_short + class Y(T): + _length_ = 187 + + for c in [T, U, V, W]: + self.assertEqual(c._type_, c_int) + self.assertEqual(c._length_, 13) + self.assertEqual(c()._type_, c_int) + self.assertEqual(c()._length_, 13) + + self.assertEqual(X._type_, c_short) + self.assertEqual(X._length_, 13) + self.assertEqual(X()._type_, c_short) + self.assertEqual(X()._length_, 13) + + self.assertEqual(Y._type_, c_int) + self.assertEqual(Y._length_, 187) + self.assertEqual(Y()._type_, c_int) + self.assertEqual(Y()._length_, 187) + + def test_bad_subclass(self): + with self.assertRaises(AttributeError): + class T(Array): + pass + with self.assertRaises(AttributeError): + class T2(Array): + _type_ = c_int + with self.assertRaises(AttributeError): + class T3(Array): + _length_ = 13 + + def test_bad_length(self): + with self.assertRaises(ValueError): + class T(Array): + _type_ = c_int + _length_ = - sys.maxsize * 2 + with self.assertRaises(ValueError): + class T2(Array): + _type_ = c_int + _length_ = -1 + with self.assertRaises(TypeError): + class T3(Array): + _type_ = c_int + _length_ = 1.87 + with self.assertRaises(OverflowError): + class T4(Array): + _type_ = c_int + _length_ = sys.maxsize * 2 + + def test_zero_length(self): + # _length_ can be zero. 
+ class T(Array): + _type_ = c_int + _length_ = 0 + + def test_empty_element_struct(self): + class EmptyStruct(Structure): + _fields_ = [] + + obj = (EmptyStruct * 2)() # bpo37188: Floating-point exception + self.assertEqual(sizeof(obj), 0) + + def test_empty_element_array(self): + class EmptyArray(Array): + _type_ = c_int + _length_ = 0 + + obj = (EmptyArray * 2)() # bpo37188: Floating-point exception + self.assertEqual(sizeof(obj), 0) + + def test_bpo36504_signed_int_overflow(self): + # The overflow check in PyCArrayType_new() could cause signed integer + # overflow. + with self.assertRaises(OverflowError): + c_char * sys.maxsize * 2 + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @bigmemtest(size=_2G, memuse=1, dry_run=False) + def test_large_array(self, size): + c_char * size + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_as_parameter.py b/Lib/test/test_ctypes/test_as_parameter.py new file mode 100644 index 00000000000..c5e1840b0eb --- /dev/null +++ b/Lib/test/test_ctypes/test_as_parameter.py @@ -0,0 +1,245 @@ +import ctypes +import unittest +from ctypes import (Structure, CDLL, CFUNCTYPE, + POINTER, pointer, byref, + c_short, c_int, c_long, c_longlong, + c_byte, c_wchar, c_float, c_double, + ArgumentError) +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +dll = CDLL(_ctypes_test.__file__) + +try: + CALLBACK_FUNCTYPE = ctypes.WINFUNCTYPE +except AttributeError: + # fake to enable this test on Linux + CALLBACK_FUNCTYPE = CFUNCTYPE + + +class POINT(Structure): + _fields_ = [("x", c_int), ("y", c_int)] + + +class BasicWrapTestCase(unittest.TestCase): + def wrap(self, param): + return param + + def test_wchar_parm(self): + f = dll._testfunc_i_bhilfd + f.argtypes = [c_byte, c_wchar, c_int, c_long, c_float, c_double] + result = f(self.wrap(1), self.wrap("x"), self.wrap(3), self.wrap(4), self.wrap(5.0), self.wrap(6.0)) + self.assertEqual(result, 
139) + self.assertIs(type(result), int) + + def test_pointers(self): + f = dll._testfunc_p_p + f.restype = POINTER(c_int) + f.argtypes = [POINTER(c_int)] + + # This only works if the value c_int(42) passed to the + # function is still alive while the pointer (the result) is + # used. + + v = c_int(42) + + self.assertEqual(pointer(v).contents.value, 42) + result = f(self.wrap(pointer(v))) + self.assertEqual(type(result), POINTER(c_int)) + self.assertEqual(result.contents.value, 42) + + # This on works... + result = f(self.wrap(pointer(v))) + self.assertEqual(result.contents.value, v.value) + + p = pointer(c_int(99)) + result = f(self.wrap(p)) + self.assertEqual(result.contents.value, 99) + + def test_shorts(self): + f = dll._testfunc_callback_i_if + + args = [] + expected = [262144, 131072, 65536, 32768, 16384, 8192, 4096, 2048, + 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1] + + def callback(v): + args.append(v) + return v + + CallBack = CFUNCTYPE(c_int, c_int) + + cb = CallBack(callback) + f(self.wrap(2**18), self.wrap(cb)) + self.assertEqual(args, expected) + + def test_callbacks(self): + f = dll._testfunc_callback_i_if + f.restype = c_int + f.argtypes = None + + MyCallback = CFUNCTYPE(c_int, c_int) + + def callback(value): + return value + + cb = MyCallback(callback) + + result = f(self.wrap(-10), self.wrap(cb)) + self.assertEqual(result, -18) + + # test with prototype + f.argtypes = [c_int, MyCallback] + cb = MyCallback(callback) + + result = f(self.wrap(-10), self.wrap(cb)) + self.assertEqual(result, -18) + + result = f(self.wrap(-10), self.wrap(cb)) + self.assertEqual(result, -18) + + AnotherCallback = CALLBACK_FUNCTYPE(c_int, c_int, c_int, c_int, c_int) + + # check that the prototype works: we call f with wrong + # argument types + cb = AnotherCallback(callback) + self.assertRaises(ArgumentError, f, self.wrap(-10), self.wrap(cb)) + + def test_callbacks_2(self): + # Can also use simple datatypes as argument type specifiers + # for the callback function. 
+ # In this case the call receives an instance of that type + f = dll._testfunc_callback_i_if + f.restype = c_int + + MyCallback = CFUNCTYPE(c_int, c_int) + + f.argtypes = [c_int, MyCallback] + + def callback(value): + self.assertEqual(type(value), int) + return value + + cb = MyCallback(callback) + result = f(self.wrap(-10), self.wrap(cb)) + self.assertEqual(result, -18) + + def test_longlong_callbacks(self): + f = dll._testfunc_callback_q_qf + f.restype = c_longlong + + MyCallback = CFUNCTYPE(c_longlong, c_longlong) + + f.argtypes = [c_longlong, MyCallback] + + def callback(value): + self.assertIsInstance(value, int) + return value & 0x7FFFFFFF + + cb = MyCallback(callback) + + self.assertEqual(13577625587, int(f(self.wrap(1000000000000), self.wrap(cb)))) + + def test_byval(self): + # without prototype + ptin = POINT(1, 2) + ptout = POINT() + # EXPORT int _testfunc_byval(point in, point *pout) + result = dll._testfunc_byval(ptin, byref(ptout)) + got = result, ptout.x, ptout.y + expected = 3, 1, 2 + self.assertEqual(got, expected) + + # with prototype + ptin = POINT(101, 102) + ptout = POINT() + dll._testfunc_byval.argtypes = (POINT, POINTER(POINT)) + dll._testfunc_byval.restype = c_int + result = dll._testfunc_byval(self.wrap(ptin), byref(ptout)) + got = result, ptout.x, ptout.y + expected = 203, 101, 102 + self.assertEqual(got, expected) + + def test_struct_return_2H(self): + class S2H(Structure): + _fields_ = [("x", c_short), + ("y", c_short)] + dll.ret_2h_func.restype = S2H + dll.ret_2h_func.argtypes = [S2H] + inp = S2H(99, 88) + s2h = dll.ret_2h_func(self.wrap(inp)) + self.assertEqual((s2h.x, s2h.y), (99*2, 88*3)) + + # Test also that the original struct was unmodified (i.e. 
was passed by + # value) + self.assertEqual((inp.x, inp.y), (99, 88)) + + def test_struct_return_8H(self): + class S8I(Structure): + _fields_ = [("a", c_int), + ("b", c_int), + ("c", c_int), + ("d", c_int), + ("e", c_int), + ("f", c_int), + ("g", c_int), + ("h", c_int)] + dll.ret_8i_func.restype = S8I + dll.ret_8i_func.argtypes = [S8I] + inp = S8I(9, 8, 7, 6, 5, 4, 3, 2) + s8i = dll.ret_8i_func(self.wrap(inp)) + self.assertEqual((s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h), + (9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9)) + + def test_recursive_as_param(self): + class A: + pass + + a = A() + a._as_parameter_ = a + for c_type in ( + ctypes.c_wchar_p, + ctypes.c_char_p, + ctypes.c_void_p, + ctypes.c_int, # PyCSimpleType + POINT, # CDataType + ): + with self.subTest(c_type=c_type): + with self.assertRaises(RecursionError): + c_type.from_param(a) + + +class AsParamWrapper: + def __init__(self, param): + self._as_parameter_ = param + +class AsParamWrapperTestCase(BasicWrapTestCase): + wrap = AsParamWrapper + + +class AsParamPropertyWrapper: + def __init__(self, param): + self._param = param + + def getParameter(self): + return self._param + _as_parameter_ = property(getParameter) + +class AsParamPropertyWrapperTestCase(BasicWrapTestCase): + wrap = AsParamPropertyWrapper + + +class AsParamNestedWrapperTestCase(BasicWrapTestCase): + """Test that _as_parameter_ is evaluated recursively. + + The _as_parameter_ attribute can be another object which + defines its own _as_parameter_ attribute. 
+ """ + + def wrap(self, param): + return AsParamWrapper(AsParamWrapper(AsParamWrapper(param))) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_bitfields.py b/Lib/test/test_ctypes/test_bitfields.py new file mode 100644 index 00000000000..0332544b582 --- /dev/null +++ b/Lib/test/test_ctypes/test_bitfields.py @@ -0,0 +1,294 @@ +import os +import unittest +from ctypes import (CDLL, Structure, sizeof, POINTER, byref, alignment, + LittleEndianStructure, BigEndianStructure, + c_byte, c_ubyte, c_char, c_char_p, c_void_p, c_wchar, + c_uint32, c_uint64, + c_short, c_ushort, c_int, c_uint, c_long, c_ulong, c_longlong, c_ulonglong) +from test import support +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +class BITS(Structure): + _fields_ = [("A", c_int, 1), + ("B", c_int, 2), + ("C", c_int, 3), + ("D", c_int, 4), + ("E", c_int, 5), + ("F", c_int, 6), + ("G", c_int, 7), + ("H", c_int, 8), + ("I", c_int, 9), + + ("M", c_short, 1), + ("N", c_short, 2), + ("O", c_short, 3), + ("P", c_short, 4), + ("Q", c_short, 5), + ("R", c_short, 6), + ("S", c_short, 7)] + +func = CDLL(_ctypes_test.__file__).unpack_bitfields +func.argtypes = POINTER(BITS), c_char + + +class C_Test(unittest.TestCase): + + def test_ints(self): + for i in range(512): + for name in "ABCDEFGHI": + b = BITS() + setattr(b, name, i) + self.assertEqual(getattr(b, name), func(byref(b), name.encode('ascii'))) + + # bpo-46913: _ctypes/cfield.c h_get() has an undefined behavior + @support.skip_if_sanitizer(ub=True) + def test_shorts(self): + b = BITS() + name = "M" + if func(byref(b), name.encode('ascii')) == 999: + self.skipTest("Compiler does not support signed short bitfields") + for i in range(256): + for name in "MNOPQRS": + b = BITS() + setattr(b, name, i) + self.assertEqual(getattr(b, name), func(byref(b), name.encode('ascii'))) + + +signed_int_types = (c_byte, c_short, c_int, c_long, c_longlong) +unsigned_int_types = 
(c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong) +int_types = unsigned_int_types + signed_int_types + +class BitFieldTest(unittest.TestCase): + + def test_longlong(self): + class X(Structure): + _fields_ = [("a", c_longlong, 1), + ("b", c_longlong, 62), + ("c", c_longlong, 1)] + + self.assertEqual(sizeof(X), sizeof(c_longlong)) + x = X() + x.a, x.b, x.c = -1, 7, -1 + self.assertEqual((x.a, x.b, x.c), (-1, 7, -1)) + + def test_ulonglong(self): + class X(Structure): + _fields_ = [("a", c_ulonglong, 1), + ("b", c_ulonglong, 62), + ("c", c_ulonglong, 1)] + + self.assertEqual(sizeof(X), sizeof(c_longlong)) + x = X() + self.assertEqual((x.a, x.b, x.c), (0, 0, 0)) + x.a, x.b, x.c = 7, 7, 7 + self.assertEqual((x.a, x.b, x.c), (1, 7, 1)) + + def test_signed(self): + for c_typ in signed_int_types: + class X(Structure): + _fields_ = [("dummy", c_typ), + ("a", c_typ, 3), + ("b", c_typ, 3), + ("c", c_typ, 1)] + self.assertEqual(sizeof(X), sizeof(c_typ)*2) + + x = X() + self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0)) + x.a = -1 + self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, -1, 0, 0)) + x.a, x.b = 0, -1 + self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, -1, 0)) + + + def test_unsigned(self): + for c_typ in unsigned_int_types: + class X(Structure): + _fields_ = [("a", c_typ, 3), + ("b", c_typ, 3), + ("c", c_typ, 1)] + self.assertEqual(sizeof(X), sizeof(c_typ)) + + x = X() + self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0)) + x.a = -1 + self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 7, 0, 0)) + x.a, x.b = 0, -1 + self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 7, 0)) + + def fail_fields(self, *fields): + return self.get_except(type(Structure), "X", (), + {"_fields_": fields}) + + def test_nonint_types(self): + # bit fields are not allowed on non-integer types. 
+ result = self.fail_fields(("a", c_char_p, 1)) + self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char_p')) + + result = self.fail_fields(("a", c_void_p, 1)) + self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_void_p')) + + if c_int != c_long: + result = self.fail_fields(("a", POINTER(c_int), 1)) + self.assertEqual(result, (TypeError, 'bit fields not allowed for type LP_c_int')) + + result = self.fail_fields(("a", c_char, 1)) + self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char')) + + class Dummy(Structure): + _fields_ = [] + + result = self.fail_fields(("a", Dummy, 1)) + self.assertEqual(result, (TypeError, 'bit fields not allowed for type Dummy')) + + def test_c_wchar(self): + result = self.fail_fields(("a", c_wchar, 1)) + self.assertEqual(result, + (TypeError, 'bit fields not allowed for type c_wchar')) + + def test_single_bitfield_size(self): + for c_typ in int_types: + result = self.fail_fields(("a", c_typ, -1)) + self.assertEqual(result, (ValueError, 'number of bits invalid for bit field')) + + result = self.fail_fields(("a", c_typ, 0)) + self.assertEqual(result, (ValueError, 'number of bits invalid for bit field')) + + class X(Structure): + _fields_ = [("a", c_typ, 1)] + self.assertEqual(sizeof(X), sizeof(c_typ)) + + class X(Structure): + _fields_ = [("a", c_typ, sizeof(c_typ)*8)] + self.assertEqual(sizeof(X), sizeof(c_typ)) + + result = self.fail_fields(("a", c_typ, sizeof(c_typ)*8 + 1)) + self.assertEqual(result, (ValueError, 'number of bits invalid for bit field')) + + def test_multi_bitfields_size(self): + class X(Structure): + _fields_ = [("a", c_short, 1), + ("b", c_short, 14), + ("c", c_short, 1)] + self.assertEqual(sizeof(X), sizeof(c_short)) + + class X(Structure): + _fields_ = [("a", c_short, 1), + ("a1", c_short), + ("b", c_short, 14), + ("c", c_short, 1)] + self.assertEqual(sizeof(X), sizeof(c_short)*3) + self.assertEqual(X.a.offset, 0) + self.assertEqual(X.a1.offset, 
sizeof(c_short)) + self.assertEqual(X.b.offset, sizeof(c_short)*2) + self.assertEqual(X.c.offset, sizeof(c_short)*2) + + class X(Structure): + _fields_ = [("a", c_short, 3), + ("b", c_short, 14), + ("c", c_short, 14)] + self.assertEqual(sizeof(X), sizeof(c_short)*3) + self.assertEqual(X.a.offset, sizeof(c_short)*0) + self.assertEqual(X.b.offset, sizeof(c_short)*1) + self.assertEqual(X.c.offset, sizeof(c_short)*2) + + def get_except(self, func, *args, **kw): + try: + func(*args, **kw) + except Exception as detail: + return detail.__class__, str(detail) + + def test_mixed_1(self): + class X(Structure): + _fields_ = [("a", c_byte, 4), + ("b", c_int, 4)] + if os.name == "nt": + self.assertEqual(sizeof(X), sizeof(c_int)*2) + else: + self.assertEqual(sizeof(X), sizeof(c_int)) + + def test_mixed_2(self): + class X(Structure): + _fields_ = [("a", c_byte, 4), + ("b", c_int, 32)] + self.assertEqual(sizeof(X), alignment(c_int)+sizeof(c_int)) + + def test_mixed_3(self): + class X(Structure): + _fields_ = [("a", c_byte, 4), + ("b", c_ubyte, 4)] + self.assertEqual(sizeof(X), sizeof(c_byte)) + + def test_mixed_4(self): + class X(Structure): + _fields_ = [("a", c_short, 4), + ("b", c_short, 4), + ("c", c_int, 24), + ("d", c_short, 4), + ("e", c_short, 4), + ("f", c_int, 24)] + # MSVC does NOT combine c_short and c_int into one field, GCC + # does (unless GCC is run with '-mms-bitfields' which + # produces code compatible with MSVC). 
+ if os.name == "nt": + self.assertEqual(sizeof(X), sizeof(c_int) * 4) + else: + self.assertEqual(sizeof(X), sizeof(c_int) * 2) + + def test_anon_bitfields(self): + # anonymous bit-fields gave a strange error message + class X(Structure): + _fields_ = [("a", c_byte, 4), + ("b", c_ubyte, 4)] + class Y(Structure): + _anonymous_ = ["_"] + _fields_ = [("_", X)] + + def test_uint32(self): + class X(Structure): + _fields_ = [("a", c_uint32, 32)] + x = X() + x.a = 10 + self.assertEqual(x.a, 10) + x.a = 0xFDCBA987 + self.assertEqual(x.a, 0xFDCBA987) + + def test_uint64(self): + class X(Structure): + _fields_ = [("a", c_uint64, 64)] + x = X() + x.a = 10 + self.assertEqual(x.a, 10) + x.a = 0xFEDCBA9876543211 + self.assertEqual(x.a, 0xFEDCBA9876543211) + + def test_uint32_swap_little_endian(self): + # Issue #23319 + class Little(LittleEndianStructure): + _fields_ = [("a", c_uint32, 24), + ("b", c_uint32, 4), + ("c", c_uint32, 4)] + b = bytearray(4) + x = Little.from_buffer(b) + x.a = 0xabcdef + x.b = 1 + x.c = 2 + self.assertEqual(b, b'\xef\xcd\xab\x21') + + def test_uint32_swap_big_endian(self): + # Issue #23319 + class Big(BigEndianStructure): + _fields_ = [("a", c_uint32, 24), + ("b", c_uint32, 4), + ("c", c_uint32, 4)] + b = bytearray(4) + x = Big.from_buffer(b) + x.a = 0xabcdef + x.b = 1 + x.c = 2 + self.assertEqual(b, b'\xab\xcd\xef\x12') + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_buffers.py b/Lib/test/test_ctypes/test_buffers.py new file mode 100644 index 00000000000..468f41eb7cf --- /dev/null +++ b/Lib/test/test_ctypes/test_buffers.py @@ -0,0 +1,70 @@ +import unittest +from ctypes import (create_string_buffer, create_unicode_buffer, sizeof, + c_char, c_wchar) + + +class StringBufferTestCase(unittest.TestCase): + def test_buffer(self): + b = create_string_buffer(32) + self.assertEqual(len(b), 32) + self.assertEqual(sizeof(b), 32 * sizeof(c_char)) + self.assertIs(type(b[0]), bytes) + + b = create_string_buffer(b"abc") + 
self.assertEqual(len(b), 4) # trailing nul char + self.assertEqual(sizeof(b), 4 * sizeof(c_char)) + self.assertIs(type(b[0]), bytes) + self.assertEqual(b[0], b"a") + self.assertEqual(b[:], b"abc\0") + self.assertEqual(b[::], b"abc\0") + self.assertEqual(b[::-1], b"\0cba") + self.assertEqual(b[::2], b"ac") + self.assertEqual(b[::5], b"a") + + self.assertRaises(TypeError, create_string_buffer, "abc") + + def test_buffer_interface(self): + self.assertEqual(len(bytearray(create_string_buffer(0))), 0) + self.assertEqual(len(bytearray(create_string_buffer(1))), 1) + + def test_unicode_buffer(self): + b = create_unicode_buffer(32) + self.assertEqual(len(b), 32) + self.assertEqual(sizeof(b), 32 * sizeof(c_wchar)) + self.assertIs(type(b[0]), str) + + b = create_unicode_buffer("abc") + self.assertEqual(len(b), 4) # trailing nul char + self.assertEqual(sizeof(b), 4 * sizeof(c_wchar)) + self.assertIs(type(b[0]), str) + self.assertEqual(b[0], "a") + self.assertEqual(b[:], "abc\0") + self.assertEqual(b[::], "abc\0") + self.assertEqual(b[::-1], "\0cba") + self.assertEqual(b[::2], "ac") + self.assertEqual(b[::5], "a") + + self.assertRaises(TypeError, create_unicode_buffer, b"abc") + + def test_unicode_conversion(self): + b = create_unicode_buffer("abc") + self.assertEqual(len(b), 4) # trailing nul char + self.assertEqual(sizeof(b), 4 * sizeof(c_wchar)) + self.assertIs(type(b[0]), str) + self.assertEqual(b[0], "a") + self.assertEqual(b[:], "abc\0") + self.assertEqual(b[::], "abc\0") + self.assertEqual(b[::-1], "\0cba") + self.assertEqual(b[::2], "ac") + self.assertEqual(b[::5], "a") + + def test_create_unicode_buffer_non_bmp(self): + expected = 5 if sizeof(c_wchar) == 2 else 3 + for s in '\U00010000\U00100000', '\U00010000\U0010ffff': + b = create_unicode_buffer(s) + self.assertEqual(len(b), expected) + self.assertEqual(b[-1], '\0') + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_bytes.py b/Lib/test/test_ctypes/test_bytes.py new file mode 
100644 index 00000000000..fa11e1bbd49 --- /dev/null +++ b/Lib/test/test_ctypes/test_bytes.py @@ -0,0 +1,67 @@ +"""Test where byte objects are accepted""" +import sys +import unittest +from _ctypes import _SimpleCData +from ctypes import Structure, c_char, c_char_p, c_wchar, c_wchar_p + + +class BytesTest(unittest.TestCase): + def test_c_char(self): + x = c_char(b"x") + self.assertRaises(TypeError, c_char, "x") + x.value = b"y" + with self.assertRaises(TypeError): + x.value = "y" + c_char.from_param(b"x") + self.assertRaises(TypeError, c_char.from_param, "x") + self.assertIn('xbd', repr(c_char.from_param(b"\xbd"))) + (c_char * 3)(b"a", b"b", b"c") + self.assertRaises(TypeError, c_char * 3, "a", "b", "c") + + def test_c_wchar(self): + x = c_wchar("x") + self.assertRaises(TypeError, c_wchar, b"x") + x.value = "y" + with self.assertRaises(TypeError): + x.value = b"y" + c_wchar.from_param("x") + self.assertRaises(TypeError, c_wchar.from_param, b"x") + (c_wchar * 3)("a", "b", "c") + self.assertRaises(TypeError, c_wchar * 3, b"a", b"b", b"c") + + def test_c_char_p(self): + c_char_p(b"foo bar") + self.assertRaises(TypeError, c_char_p, "foo bar") + + def test_c_wchar_p(self): + c_wchar_p("foo bar") + self.assertRaises(TypeError, c_wchar_p, b"foo bar") + + def test_struct(self): + class X(Structure): + _fields_ = [("a", c_char * 3)] + + x = X(b"abc") + self.assertRaises(TypeError, X, "abc") + self.assertEqual(x.a, b"abc") + self.assertEqual(type(x.a), bytes) + + def test_struct_W(self): + class X(Structure): + _fields_ = [("a", c_wchar * 3)] + + x = X("abc") + self.assertRaises(TypeError, X, b"abc") + self.assertEqual(x.a, "abc") + self.assertEqual(type(x.a), str) + + @unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') + def test_BSTR(self): + class BSTR(_SimpleCData): + _type_ = "X" + + BSTR("abc") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_byteswap.py b/Lib/test/test_ctypes/test_byteswap.py new file mode 
100644 index 00000000000..78eff0392c4 --- /dev/null +++ b/Lib/test/test_ctypes/test_byteswap.py @@ -0,0 +1,386 @@ +import binascii +import ctypes +import math +import struct +import sys +import unittest +from ctypes import (Structure, Union, LittleEndianUnion, BigEndianUnion, + BigEndianStructure, LittleEndianStructure, + POINTER, sizeof, cast, + c_byte, c_ubyte, c_char, c_wchar, c_void_p, + c_short, c_ushort, c_int, c_uint, + c_long, c_ulong, c_longlong, c_ulonglong, + c_uint32, c_float, c_double) + + +def bin(s): + return binascii.hexlify(memoryview(s)).decode().upper() + + +# Each *simple* type that supports different byte orders has an +# __ctype_be__ attribute that specifies the same type in BIG ENDIAN +# byte order, and a __ctype_le__ attribute that is the same type in +# LITTLE ENDIAN byte order. +# +# For Structures and Unions, these types are created on demand. + +class Test(unittest.TestCase): + def test_slots(self): + class BigPoint(BigEndianStructure): + __slots__ = () + _fields_ = [("x", c_int), ("y", c_int)] + + class LowPoint(LittleEndianStructure): + __slots__ = () + _fields_ = [("x", c_int), ("y", c_int)] + + big = BigPoint() + little = LowPoint() + big.x = 4 + big.y = 2 + little.x = 2 + little.y = 4 + with self.assertRaises(AttributeError): + big.z = 42 + with self.assertRaises(AttributeError): + little.z = 24 + + def test_endian_short(self): + if sys.byteorder == "little": + self.assertIs(c_short.__ctype_le__, c_short) + self.assertIs(c_short.__ctype_be__.__ctype_le__, c_short) + else: + self.assertIs(c_short.__ctype_be__, c_short) + self.assertIs(c_short.__ctype_le__.__ctype_be__, c_short) + s = c_short.__ctype_be__(0x1234) + self.assertEqual(bin(struct.pack(">h", 0x1234)), "1234") + self.assertEqual(bin(s), "1234") + self.assertEqual(s.value, 0x1234) + + s = c_short.__ctype_le__(0x1234) + self.assertEqual(bin(struct.pack("h", 0x1234)), "1234") + self.assertEqual(bin(s), "1234") + self.assertEqual(s.value, 0x1234) + + s = 
c_ushort.__ctype_le__(0x1234) + self.assertEqual(bin(struct.pack("i", 0x12345678)), "12345678") + self.assertEqual(bin(s), "12345678") + self.assertEqual(s.value, 0x12345678) + + s = c_int.__ctype_le__(0x12345678) + self.assertEqual(bin(struct.pack("I", 0x12345678)), "12345678") + self.assertEqual(bin(s), "12345678") + self.assertEqual(s.value, 0x12345678) + + s = c_uint.__ctype_le__(0x12345678) + self.assertEqual(bin(struct.pack("q", 0x1234567890ABCDEF)), "1234567890ABCDEF") + self.assertEqual(bin(s), "1234567890ABCDEF") + self.assertEqual(s.value, 0x1234567890ABCDEF) + + s = c_longlong.__ctype_le__(0x1234567890ABCDEF) + self.assertEqual(bin(struct.pack("Q", 0x1234567890ABCDEF)), "1234567890ABCDEF") + self.assertEqual(bin(s), "1234567890ABCDEF") + self.assertEqual(s.value, 0x1234567890ABCDEF) + + s = c_ulonglong.__ctype_le__(0x1234567890ABCDEF) + self.assertEqual(bin(struct.pack("f", math.pi)), bin(s)) + + def test_endian_double(self): + if sys.byteorder == "little": + self.assertIs(c_double.__ctype_le__, c_double) + self.assertIs(c_double.__ctype_be__.__ctype_le__, c_double) + else: + self.assertIs(c_double.__ctype_be__, c_double) + self.assertIs(c_double.__ctype_le__.__ctype_be__, c_double) + s = c_double(math.pi) + self.assertEqual(s.value, math.pi) + self.assertEqual(bin(struct.pack("d", math.pi)), bin(s)) + s = c_double.__ctype_le__(math.pi) + self.assertEqual(s.value, math.pi) + self.assertEqual(bin(struct.pack("d", math.pi)), bin(s)) + + def test_endian_other(self): + self.assertIs(c_byte.__ctype_le__, c_byte) + self.assertIs(c_byte.__ctype_be__, c_byte) + + self.assertIs(c_ubyte.__ctype_le__, c_ubyte) + self.assertIs(c_ubyte.__ctype_be__, c_ubyte) + + self.assertIs(c_char.__ctype_le__, c_char) + self.assertIs(c_char.__ctype_be__, c_char) + + def test_struct_fields_unsupported_byte_order(self): + + fields = [ + ("a", c_ubyte), + ("b", c_byte), + ("c", c_short), + ("d", c_ushort), + ("e", c_int), + ("f", c_uint), + ("g", c_long), + ("h", c_ulong), + ("i", 
c_longlong), + ("k", c_ulonglong), + ("l", c_float), + ("m", c_double), + ("n", c_char), + ("b1", c_byte, 3), + ("b2", c_byte, 3), + ("b3", c_byte, 2), + ("a", c_int * 3 * 3 * 3) + ] + + # these fields do not support different byte order: + for typ in c_wchar, c_void_p, POINTER(c_int): + with self.assertRaises(TypeError): + class T(BigEndianStructure if sys.byteorder == "little" else LittleEndianStructure): + _fields_ = fields + [("x", typ)] + + + def test_struct_struct(self): + # nested structures with different byteorders + + # create nested structures with given byteorders and set memory to data + + for nested, data in ( + (BigEndianStructure, b'\0\0\0\1\0\0\0\2'), + (LittleEndianStructure, b'\1\0\0\0\2\0\0\0'), + ): + for parent in ( + BigEndianStructure, + LittleEndianStructure, + Structure, + ): + class NestedStructure(nested): + _fields_ = [("x", c_uint32), + ("y", c_uint32)] + + class TestStructure(parent): + _fields_ = [("point", NestedStructure)] + + self.assertEqual(len(data), sizeof(TestStructure)) + ptr = POINTER(TestStructure) + s = cast(data, ptr)[0] + del ctypes._pointer_type_cache[TestStructure] + self.assertEqual(s.point.x, 1) + self.assertEqual(s.point.y, 2) + + def test_struct_field_alignment(self): + # standard packing in struct uses no alignment. + # So, we have to align using pad bytes. + # + # Unaligned accesses will crash Python (on those platforms that + # don't allow it, like sparc solaris). 
+ if sys.byteorder == "little": + base = BigEndianStructure + fmt = ">bxhid" + else: + base = LittleEndianStructure + fmt = " float -> double + self.check_type(c_float, math.e) + self.check_type(c_float, -math.e) + + def test_double(self): + self.check_type(c_double, 3.14) + self.check_type(c_double, -3.14) + + def test_longdouble(self): + self.check_type(c_longdouble, 3.14) + self.check_type(c_longdouble, -3.14) + + def test_char(self): + self.check_type(c_char, b"x") + self.check_type(c_char, b"a") + + def test_pyobject(self): + o = () + for o in (), [], object(): + initial = sys.getrefcount(o) + # This call leaks a reference to 'o'... + self.check_type(py_object, o) + before = sys.getrefcount(o) + # ...but this call doesn't leak any more. Where is the refcount? + self.check_type(py_object, o) + after = sys.getrefcount(o) + self.assertEqual((after, o), (before, o)) + + def test_unsupported_restype_1(self): + # Only "fundamental" result types are supported for callback + # functions, the type must have a non-NULL stginfo->setfunc. + # POINTER(c_double), for example, is not supported. 
+ + prototype = self.functype.__func__(POINTER(c_double)) + # The type is checked when the prototype is called + self.assertRaises(TypeError, prototype, lambda: None) + + def test_unsupported_restype_2(self): + prototype = self.functype.__func__(object) + self.assertRaises(TypeError, prototype, lambda: None) + + def test_issue_7959(self): + proto = self.functype.__func__(None) + + class X: + def func(self): pass + def __init__(self): + self.v = proto(self.func) + + for i in range(32): + X() + gc.collect() + live = [x for x in gc.get_objects() + if isinstance(x, X)] + self.assertEqual(len(live), 0) + + def test_issue12483(self): + class Nasty: + def __del__(self): + gc.collect() + CFUNCTYPE(None)(lambda x=Nasty(): None) + + @unittest.skipUnless(hasattr(ctypes, 'WINFUNCTYPE'), + 'ctypes.WINFUNCTYPE is required') + def test_i38748_stackCorruption(self): + callback_funcType = ctypes.WINFUNCTYPE(c_long, c_long, c_longlong) + @callback_funcType + def callback(a, b): + c = a + b + print(f"a={a}, b={b}, c={c}") + return c + dll = cdll[_ctypes_test.__file__] + with support.captured_stdout() as out: + # With no fix for i38748, the next line will raise OSError and cause the test to fail. 
+ self.assertEqual(dll._test_i38748_runCallback(callback, 5, 10), 15) + self.assertEqual(out.getvalue(), "a=5, b=10, c=15\n") + +if hasattr(ctypes, 'WINFUNCTYPE'): + class StdcallCallbacks(Callbacks): + functype = ctypes.WINFUNCTYPE + + +class SampleCallbacksTestCase(unittest.TestCase): + + def test_integrate(self): + # Derived from some then non-working code, posted by David Foster + dll = CDLL(_ctypes_test.__file__) + + # The function prototype called by 'integrate': double func(double); + CALLBACK = CFUNCTYPE(c_double, c_double) + + # The integrate function itself, exposed from the _ctypes_test dll + integrate = dll.integrate + integrate.argtypes = (c_double, c_double, CALLBACK, c_long) + integrate.restype = c_double + + def func(x): + return x**2 + + result = integrate(0.0, 1.0, CALLBACK(func), 10) + diff = abs(result - 1./3.) + + self.assertLess(diff, 0.01, "%s not less than 0.01" % diff) + + def test_issue_8959_a(self): + libc_path = find_library("c") + if not libc_path: + self.skipTest('could not find libc') + libc = CDLL(libc_path) + + @CFUNCTYPE(c_int, POINTER(c_int), POINTER(c_int)) + def cmp_func(a, b): + return a[0] - b[0] + + array = (c_int * 5)(5, 1, 99, 7, 33) + + libc.qsort(array, len(array), sizeof(c_int), cmp_func) + self.assertEqual(array[:], [1, 5, 7, 33, 99]) + + @unittest.skipUnless(hasattr(ctypes, 'WINFUNCTYPE'), + 'ctypes.WINFUNCTYPE is required') + def test_issue_8959_b(self): + from ctypes.wintypes import BOOL, HWND, LPARAM + global windowCount + windowCount = 0 + + @ctypes.WINFUNCTYPE(BOOL, HWND, LPARAM) + def EnumWindowsCallbackFunc(hwnd, lParam): + global windowCount + windowCount += 1 + return True #Allow windows to keep enumerating + + user32 = ctypes.windll.user32 + user32.EnumWindows(EnumWindowsCallbackFunc, 0) + + def test_callback_register_int(self): + # Issue #8275: buggy handling of callback args under Win64 + # NOTE: should be run on release builds as well + dll = CDLL(_ctypes_test.__file__) + CALLBACK = CFUNCTYPE(c_int, c_int, 
c_int, c_int, c_int, c_int) + # All this function does is call the callback with its args squared + func = dll._testfunc_cbk_reg_int + func.argtypes = (c_int, c_int, c_int, c_int, c_int, CALLBACK) + func.restype = c_int + + def callback(a, b, c, d, e): + return a + b + c + d + e + + result = func(2, 3, 4, 5, 6, CALLBACK(callback)) + self.assertEqual(result, callback(2*2, 3*3, 4*4, 5*5, 6*6)) + + def test_callback_register_double(self): + # Issue #8275: buggy handling of callback args under Win64 + # NOTE: should be run on release builds as well + dll = CDLL(_ctypes_test.__file__) + CALLBACK = CFUNCTYPE(c_double, c_double, c_double, c_double, + c_double, c_double) + # All this function does is call the callback with its args squared + func = dll._testfunc_cbk_reg_double + func.argtypes = (c_double, c_double, c_double, + c_double, c_double, CALLBACK) + func.restype = c_double + + def callback(a, b, c, d, e): + return a + b + c + d + e + + result = func(1.1, 2.2, 3.3, 4.4, 5.5, CALLBACK(callback)) + self.assertEqual(result, + callback(1.1*1.1, 2.2*2.2, 3.3*3.3, 4.4*4.4, 5.5*5.5)) + + def test_callback_large_struct(self): + class Check: pass + + # This should mirror the structure in Modules/_ctypes/_ctypes_test.c + class X(Structure): + _fields_ = [ + ('first', c_ulong), + ('second', c_ulong), + ('third', c_ulong), + ] + + def callback(check, s): + check.first = s.first + check.second = s.second + check.third = s.third + # See issue #29565. 
+ # The structure should be passed by value, so + # any changes to it should not be reflected in + # the value passed + s.first = s.second = s.third = 0x0badf00d + + check = Check() + s = X() + s.first = 0xdeadbeef + s.second = 0xcafebabe + s.third = 0x0bad1dea + + CALLBACK = CFUNCTYPE(None, X) + dll = CDLL(_ctypes_test.__file__) + func = dll._testfunc_cbk_large_struct + func.argtypes = (X, CALLBACK) + func.restype = None + # the function just calls the callback with the passed structure + func(s, CALLBACK(functools.partial(callback, check))) + self.assertEqual(check.first, s.first) + self.assertEqual(check.second, s.second) + self.assertEqual(check.third, s.third) + self.assertEqual(check.first, 0xdeadbeef) + self.assertEqual(check.second, 0xcafebabe) + self.assertEqual(check.third, 0x0bad1dea) + # See issue #29565. + # Ensure that the original struct is unchanged. + self.assertEqual(s.first, check.first) + self.assertEqual(s.second, check.second) + self.assertEqual(s.third, check.third) + + def test_callback_too_many_args(self): + def func(*args): + return len(args) + + # valid call with nargs <= CTYPES_MAX_ARGCOUNT + proto = CFUNCTYPE(c_int, *(c_int,) * CTYPES_MAX_ARGCOUNT) + cb = proto(func) + args1 = (1,) * CTYPES_MAX_ARGCOUNT + self.assertEqual(cb(*args1), CTYPES_MAX_ARGCOUNT) + + # invalid call with nargs > CTYPES_MAX_ARGCOUNT + args2 = (1,) * (CTYPES_MAX_ARGCOUNT + 1) + with self.assertRaises(ArgumentError): + cb(*args2) + + # error when creating the type with too many arguments + with self.assertRaises(ArgumentError): + CFUNCTYPE(c_int, *(c_int,) * (CTYPES_MAX_ARGCOUNT + 1)) + + def test_convert_result_error(self): + def func(): + return ("tuple",) + + proto = CFUNCTYPE(c_int) + ctypes_func = proto(func) + with support.catch_unraisable_exception() as cm: + # don't test the result since it is an uninitialized value + result = ctypes_func() + + self.assertIsInstance(cm.unraisable.exc_value, TypeError) + self.assertEqual(cm.unraisable.err_msg, + f"Exception 
ignored on converting result " + f"of ctypes callback function {func!r}") + self.assertIsNone(cm.unraisable.object) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_cast.py b/Lib/test/test_ctypes/test_cast.py new file mode 100644 index 00000000000..604f44f03d6 --- /dev/null +++ b/Lib/test/test_ctypes/test_cast.py @@ -0,0 +1,100 @@ +import sys +import unittest +from ctypes import (Structure, Union, POINTER, cast, sizeof, addressof, + c_void_p, c_char_p, c_wchar_p, + c_byte, c_short, c_int) + + +class Test(unittest.TestCase): + def test_array2pointer(self): + array = (c_int * 3)(42, 17, 2) + + # casting an array to a pointer works. + ptr = cast(array, POINTER(c_int)) + self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) + + if 2 * sizeof(c_short) == sizeof(c_int): + ptr = cast(array, POINTER(c_short)) + if sys.byteorder == "little": + self.assertEqual([ptr[i] for i in range(6)], + [42, 0, 17, 0, 2, 0]) + else: + self.assertEqual([ptr[i] for i in range(6)], + [0, 42, 0, 17, 0, 2]) + + def test_address2pointer(self): + array = (c_int * 3)(42, 17, 2) + + address = addressof(array) + ptr = cast(c_void_p(address), POINTER(c_int)) + self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) + + ptr = cast(address, POINTER(c_int)) + self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) + + def test_p2a_objects(self): + array = (c_char_p * 5)() + self.assertEqual(array._objects, None) + array[0] = b"foo bar" + self.assertEqual(array._objects, {'0': b"foo bar"}) + + p = cast(array, POINTER(c_char_p)) + # array and p share a common _objects attribute + self.assertIs(p._objects, array._objects) + self.assertEqual(array._objects, {'0': b"foo bar", id(array): array}) + p[0] = b"spam spam" + self.assertEqual(p._objects, {'0': b"spam spam", id(array): array}) + self.assertIs(array._objects, p._objects) + p[1] = b"foo bar" + self.assertEqual(p._objects, {'1': b'foo bar', '0': b"spam spam", id(array): array}) + 
self.assertIs(array._objects, p._objects) + + def test_other(self): + p = cast((c_int * 4)(1, 2, 3, 4), POINTER(c_int)) + self.assertEqual(p[:4], [1,2, 3, 4]) + self.assertEqual(p[:4:], [1, 2, 3, 4]) + self.assertEqual(p[3:-1:-1], [4, 3, 2, 1]) + self.assertEqual(p[:4:3], [1, 4]) + c_int() + self.assertEqual(p[:4], [1, 2, 3, 4]) + self.assertEqual(p[:4:], [1, 2, 3, 4]) + self.assertEqual(p[3:-1:-1], [4, 3, 2, 1]) + self.assertEqual(p[:4:3], [1, 4]) + p[2] = 96 + self.assertEqual(p[:4], [1, 2, 96, 4]) + self.assertEqual(p[:4:], [1, 2, 96, 4]) + self.assertEqual(p[3:-1:-1], [4, 96, 2, 1]) + self.assertEqual(p[:4:3], [1, 4]) + c_int() + self.assertEqual(p[:4], [1, 2, 96, 4]) + self.assertEqual(p[:4:], [1, 2, 96, 4]) + self.assertEqual(p[3:-1:-1], [4, 96, 2, 1]) + self.assertEqual(p[:4:3], [1, 4]) + + def test_char_p(self): + # This didn't work: bad argument to internal function + s = c_char_p(b"hiho") + self.assertEqual(cast(cast(s, c_void_p), c_char_p).value, + b"hiho") + + def test_wchar_p(self): + s = c_wchar_p("hiho") + self.assertEqual(cast(cast(s, c_void_p), c_wchar_p).value, + "hiho") + + def test_bad_type_arg(self): + # The type argument must be a ctypes pointer type. 
+ array_type = c_byte * sizeof(c_int) + array = array_type() + self.assertRaises(TypeError, cast, array, None) + self.assertRaises(TypeError, cast, array, array_type) + class Struct(Structure): + _fields_ = [("a", c_int)] + self.assertRaises(TypeError, cast, array, Struct) + class MyUnion(Union): + _fields_ = [("a", c_int)] + self.assertRaises(TypeError, cast, array, MyUnion) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_cfuncs.py b/Lib/test/test_ctypes/test_cfuncs.py new file mode 100644 index 00000000000..48330c4b0a7 --- /dev/null +++ b/Lib/test/test_ctypes/test_cfuncs.py @@ -0,0 +1,211 @@ +import ctypes +import unittest +from ctypes import (CDLL, + c_byte, c_ubyte, c_char, + c_short, c_ushort, c_int, c_uint, + c_long, c_ulong, c_longlong, c_ulonglong, + c_float, c_double, c_longdouble) +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +class CFunctions(unittest.TestCase): + _dll = CDLL(_ctypes_test.__file__) + + def S(self): + return c_longlong.in_dll(self._dll, "last_tf_arg_s").value + def U(self): + return c_ulonglong.in_dll(self._dll, "last_tf_arg_u").value + + def test_byte(self): + self._dll.tf_b.restype = c_byte + self._dll.tf_b.argtypes = (c_byte,) + self.assertEqual(self._dll.tf_b(-126), -42) + self.assertEqual(self.S(), -126) + + def test_byte_plus(self): + self._dll.tf_bb.restype = c_byte + self._dll.tf_bb.argtypes = (c_byte, c_byte) + self.assertEqual(self._dll.tf_bb(0, -126), -42) + self.assertEqual(self.S(), -126) + + def test_ubyte(self): + self._dll.tf_B.restype = c_ubyte + self._dll.tf_B.argtypes = (c_ubyte,) + self.assertEqual(self._dll.tf_B(255), 85) + self.assertEqual(self.U(), 255) + + def test_ubyte_plus(self): + self._dll.tf_bB.restype = c_ubyte + self._dll.tf_bB.argtypes = (c_byte, c_ubyte) + self.assertEqual(self._dll.tf_bB(0, 255), 85) + self.assertEqual(self.U(), 255) + + def test_short(self): + self._dll.tf_h.restype = c_short + 
self._dll.tf_h.argtypes = (c_short,) + self.assertEqual(self._dll.tf_h(-32766), -10922) + self.assertEqual(self.S(), -32766) + + def test_short_plus(self): + self._dll.tf_bh.restype = c_short + self._dll.tf_bh.argtypes = (c_byte, c_short) + self.assertEqual(self._dll.tf_bh(0, -32766), -10922) + self.assertEqual(self.S(), -32766) + + def test_ushort(self): + self._dll.tf_H.restype = c_ushort + self._dll.tf_H.argtypes = (c_ushort,) + self.assertEqual(self._dll.tf_H(65535), 21845) + self.assertEqual(self.U(), 65535) + + def test_ushort_plus(self): + self._dll.tf_bH.restype = c_ushort + self._dll.tf_bH.argtypes = (c_byte, c_ushort) + self.assertEqual(self._dll.tf_bH(0, 65535), 21845) + self.assertEqual(self.U(), 65535) + + def test_int(self): + self._dll.tf_i.restype = c_int + self._dll.tf_i.argtypes = (c_int,) + self.assertEqual(self._dll.tf_i(-2147483646), -715827882) + self.assertEqual(self.S(), -2147483646) + + def test_int_plus(self): + self._dll.tf_bi.restype = c_int + self._dll.tf_bi.argtypes = (c_byte, c_int) + self.assertEqual(self._dll.tf_bi(0, -2147483646), -715827882) + self.assertEqual(self.S(), -2147483646) + + def test_uint(self): + self._dll.tf_I.restype = c_uint + self._dll.tf_I.argtypes = (c_uint,) + self.assertEqual(self._dll.tf_I(4294967295), 1431655765) + self.assertEqual(self.U(), 4294967295) + + def test_uint_plus(self): + self._dll.tf_bI.restype = c_uint + self._dll.tf_bI.argtypes = (c_byte, c_uint) + self.assertEqual(self._dll.tf_bI(0, 4294967295), 1431655765) + self.assertEqual(self.U(), 4294967295) + + def test_long(self): + self._dll.tf_l.restype = c_long + self._dll.tf_l.argtypes = (c_long,) + self.assertEqual(self._dll.tf_l(-2147483646), -715827882) + self.assertEqual(self.S(), -2147483646) + + def test_long_plus(self): + self._dll.tf_bl.restype = c_long + self._dll.tf_bl.argtypes = (c_byte, c_long) + self.assertEqual(self._dll.tf_bl(0, -2147483646), -715827882) + self.assertEqual(self.S(), -2147483646) + + def test_ulong(self): + 
self._dll.tf_L.restype = c_ulong + self._dll.tf_L.argtypes = (c_ulong,) + self.assertEqual(self._dll.tf_L(4294967295), 1431655765) + self.assertEqual(self.U(), 4294967295) + + def test_ulong_plus(self): + self._dll.tf_bL.restype = c_ulong + self._dll.tf_bL.argtypes = (c_char, c_ulong) + self.assertEqual(self._dll.tf_bL(b' ', 4294967295), 1431655765) + self.assertEqual(self.U(), 4294967295) + + def test_longlong(self): + self._dll.tf_q.restype = c_longlong + self._dll.tf_q.argtypes = (c_longlong, ) + self.assertEqual(self._dll.tf_q(-9223372036854775806), -3074457345618258602) + self.assertEqual(self.S(), -9223372036854775806) + + def test_longlong_plus(self): + self._dll.tf_bq.restype = c_longlong + self._dll.tf_bq.argtypes = (c_byte, c_longlong) + self.assertEqual(self._dll.tf_bq(0, -9223372036854775806), -3074457345618258602) + self.assertEqual(self.S(), -9223372036854775806) + + def test_ulonglong(self): + self._dll.tf_Q.restype = c_ulonglong + self._dll.tf_Q.argtypes = (c_ulonglong, ) + self.assertEqual(self._dll.tf_Q(18446744073709551615), 6148914691236517205) + self.assertEqual(self.U(), 18446744073709551615) + + def test_ulonglong_plus(self): + self._dll.tf_bQ.restype = c_ulonglong + self._dll.tf_bQ.argtypes = (c_byte, c_ulonglong) + self.assertEqual(self._dll.tf_bQ(0, 18446744073709551615), 6148914691236517205) + self.assertEqual(self.U(), 18446744073709551615) + + def test_float(self): + self._dll.tf_f.restype = c_float + self._dll.tf_f.argtypes = (c_float,) + self.assertEqual(self._dll.tf_f(-42.), -14.) + self.assertEqual(self.S(), -42) + + def test_float_plus(self): + self._dll.tf_bf.restype = c_float + self._dll.tf_bf.argtypes = (c_byte, c_float) + self.assertEqual(self._dll.tf_bf(0, -42.), -14.) + self.assertEqual(self.S(), -42) + + def test_double(self): + self._dll.tf_d.restype = c_double + self._dll.tf_d.argtypes = (c_double,) + self.assertEqual(self._dll.tf_d(42.), 14.) 
+ self.assertEqual(self.S(), 42) + + def test_double_plus(self): + self._dll.tf_bd.restype = c_double + self._dll.tf_bd.argtypes = (c_byte, c_double) + self.assertEqual(self._dll.tf_bd(0, 42.), 14.) + self.assertEqual(self.S(), 42) + + def test_longdouble(self): + self._dll.tf_D.restype = c_longdouble + self._dll.tf_D.argtypes = (c_longdouble,) + self.assertEqual(self._dll.tf_D(42.), 14.) + self.assertEqual(self.S(), 42) + + def test_longdouble_plus(self): + self._dll.tf_bD.restype = c_longdouble + self._dll.tf_bD.argtypes = (c_byte, c_longdouble) + self.assertEqual(self._dll.tf_bD(0, 42.), 14.) + self.assertEqual(self.S(), 42) + + def test_callwithresult(self): + def process_result(result): + return result * 2 + self._dll.tf_i.restype = process_result + self._dll.tf_i.argtypes = (c_int,) + self.assertEqual(self._dll.tf_i(42), 28) + self.assertEqual(self.S(), 42) + self.assertEqual(self._dll.tf_i(-42), -28) + self.assertEqual(self.S(), -42) + + def test_void(self): + self._dll.tv_i.restype = None + self._dll.tv_i.argtypes = (c_int,) + self.assertEqual(self._dll.tv_i(42), None) + self.assertEqual(self.S(), 42) + self.assertEqual(self._dll.tv_i(-42), None) + self.assertEqual(self.S(), -42) + + +# The following repeats the above tests with stdcall functions (where +# they are available) +if hasattr(ctypes, 'WinDLL'): + class stdcall_dll(ctypes.WinDLL): + def __getattr__(self, name): + if name[:2] == '__' and name[-2:] == '__': + raise AttributeError(name) + func = self._FuncPtr(("s_" + name, self)) + setattr(self, name, func) + return func + + class stdcallCFunctions(CFunctions): + _dll = stdcall_dll(_ctypes_test.__file__) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_checkretval.py b/Lib/test/test_ctypes/test_checkretval.py new file mode 100644 index 00000000000..9d6bfdb845e --- /dev/null +++ b/Lib/test/test_ctypes/test_checkretval.py @@ -0,0 +1,37 @@ +import ctypes +import unittest +from ctypes import CDLL, c_int +from 
test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +class CHECKED(c_int): + def _check_retval_(value): + # Receives a CHECKED instance. + return str(value.value) + _check_retval_ = staticmethod(_check_retval_) + + +class Test(unittest.TestCase): + def test_checkretval(self): + dll = CDLL(_ctypes_test.__file__) + self.assertEqual(42, dll._testfunc_p_p(42)) + + dll._testfunc_p_p.restype = CHECKED + self.assertEqual("42", dll._testfunc_p_p(42)) + + dll._testfunc_p_p.restype = None + self.assertEqual(None, dll._testfunc_p_p(42)) + + del dll._testfunc_p_p.restype + self.assertEqual(42, dll._testfunc_p_p(42)) + + @unittest.skipUnless(hasattr(ctypes, 'oledll'), + 'ctypes.oledll is required') + def test_oledll(self): + oleaut32 = ctypes.oledll.oleaut32 + self.assertRaises(OSError, oleaut32.CreateTypeLib2, 0, None, None) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_delattr.py b/Lib/test/test_ctypes/test_delattr.py new file mode 100644 index 00000000000..e80b5fa6efb --- /dev/null +++ b/Lib/test/test_ctypes/test_delattr.py @@ -0,0 +1,26 @@ +import unittest +from ctypes import Structure, c_char, c_int + + +class X(Structure): + _fields_ = [("foo", c_int)] + + +class TestCase(unittest.TestCase): + def test_simple(self): + with self.assertRaises(TypeError): + del c_int(42).value + + def test_chararray(self): + chararray = (c_char * 5)() + with self.assertRaises(TypeError): + del chararray.value + + def test_struct(self): + struct = X() + with self.assertRaises(TypeError): + del struct.foo + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_dlerror.py b/Lib/test/test_ctypes/test_dlerror.py new file mode 100644 index 00000000000..1c1b2aab3d5 --- /dev/null +++ b/Lib/test/test_ctypes/test_dlerror.py @@ -0,0 +1,179 @@ +import _ctypes +import os +import platform +import sys +import test.support +import unittest +from ctypes import CDLL, c_int +from 
ctypes.util import find_library + + +FOO_C = r""" +#include + +/* This is a 'GNU indirect function' (IFUNC) that will be called by + dlsym() to resolve the symbol "foo" to an address. Typically, such + a function would return the address of an actual function, but it + can also just return NULL. For some background on IFUNCs, see + https://willnewton.name/uncategorized/using-gnu-indirect-functions. + + Adapted from Michael Kerrisk's answer: https://stackoverflow.com/a/53590014. +*/ + +asm (".type foo STT_GNU_IFUNC"); + +void *foo(void) +{ + write($DESCRIPTOR, "OK", 2); + return NULL; +} +""" + + +@unittest.skipUnless(sys.platform.startswith('linux'), + 'test requires GNU IFUNC support') +class TestNullDlsym(unittest.TestCase): + """GH-126554: Ensure that we catch NULL dlsym return values + + In rare cases, such as when using GNU IFUNCs, dlsym(), + the C function that ctypes' CDLL uses to get the address + of symbols, can return NULL. + + The objective way of telling if an error during symbol + lookup happened is to call glibc's dlerror() and check + for a non-NULL return value. + + However, there can be cases where dlsym() returns NULL + and dlerror() is also NULL, meaning that glibc did not + encounter any error. + + In the case of ctypes, we subjectively treat that as + an error, and throw a relevant exception. + + This test case ensures that we correctly enforce + this 'dlsym returned NULL -> throw Error' rule. + """ + + def test_null_dlsym(self): + import subprocess + import tempfile + + try: + retcode = subprocess.call(["gcc", "--version"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) + except OSError: + self.skipTest("gcc is missing") + if retcode != 0: + self.skipTest("gcc --version failed") + + pipe_r, pipe_w = os.pipe() + self.addCleanup(os.close, pipe_r) + self.addCleanup(os.close, pipe_w) + + with tempfile.TemporaryDirectory() as d: + # Create a C file with a GNU Indirect Function (FOO_C) + # and compile it into a shared library. 
+ srcname = os.path.join(d, 'foo.c') + dstname = os.path.join(d, 'libfoo.so') + with open(srcname, 'w') as f: + f.write(FOO_C.replace('$DESCRIPTOR', str(pipe_w))) + args = ['gcc', '-fPIC', '-shared', '-o', dstname, srcname] + p = subprocess.run(args, capture_output=True) + + if p.returncode != 0: + # IFUNC is not supported on all architectures. + if platform.machine() == 'x86_64': + # It should be supported here. Something else went wrong. + p.check_returncode() + else: + # IFUNC might not be supported on this machine. + self.skipTest(f"could not compile indirect function: {p}") + + # Case #1: Test 'PyCFuncPtr_FromDll' from Modules/_ctypes/_ctypes.c + L = CDLL(dstname) + with self.assertRaisesRegex(AttributeError, "function 'foo' not found"): + # Try accessing the 'foo' symbol. + # It should resolve via dlsym() to NULL, + # and since we subjectively treat NULL + # addresses as errors, we should get + # an error. + L.foo + + # Assert that the IFUNC was called + self.assertEqual(os.read(pipe_r, 2), b'OK') + + # Case #2: Test 'CDataType_in_dll_impl' from Modules/_ctypes/_ctypes.c + with self.assertRaisesRegex(ValueError, "symbol 'foo' not found"): + c_int.in_dll(L, "foo") + + # Assert that the IFUNC was called + self.assertEqual(os.read(pipe_r, 2), b'OK') + + # Case #3: Test 'py_dl_sym' from Modules/_ctypes/callproc.c + dlopen = test.support.get_attribute(_ctypes, 'dlopen') + dlsym = test.support.get_attribute(_ctypes, 'dlsym') + L = dlopen(dstname) + with self.assertRaisesRegex(OSError, "symbol 'foo' not found"): + dlsym(L, "foo") + + # Assert that the IFUNC was called + self.assertEqual(os.read(pipe_r, 2), b'OK') + + +@unittest.skipUnless(os.name != 'nt', 'test requires dlerror() calls') +class TestLocalization(unittest.TestCase): + + @staticmethod + def configure_locales(func): + return test.support.run_with_locale( + 'LC_ALL', + 'fr_FR.iso88591', 'ja_JP.sjis', 'zh_CN.gbk', + 'fr_FR.utf8', 'en_US.utf8', + '', + )(func) + + @classmethod + def setUpClass(cls): + 
cls.libc_filename = find_library("c") + if cls.libc_filename is None: + raise unittest.SkipTest('cannot find libc') + + @configure_locales + def test_localized_error_from_dll(self): + dll = CDLL(self.libc_filename) + with self.assertRaises(AttributeError): + dll.this_name_does_not_exist + + @configure_locales + def test_localized_error_in_dll(self): + dll = CDLL(self.libc_filename) + with self.assertRaises(ValueError): + c_int.in_dll(dll, 'this_name_does_not_exist') + + @unittest.skipUnless(hasattr(_ctypes, 'dlopen'), + 'test requires _ctypes.dlopen()') + @configure_locales + def test_localized_error_dlopen(self): + missing_filename = b'missing\xff.so' + # Depending whether the locale, we may encode '\xff' differently + # but we are only interested in avoiding a UnicodeDecodeError + # when reporting the dlerror() error message which contains + # the localized filename. + filename_pattern = r'missing.*?\.so' + with self.assertRaisesRegex(OSError, filename_pattern): + _ctypes.dlopen(missing_filename, 2) + + @unittest.skipUnless(hasattr(_ctypes, 'dlopen'), + 'test requires _ctypes.dlopen()') + @unittest.skipUnless(hasattr(_ctypes, 'dlsym'), + 'test requires _ctypes.dlsym()') + @configure_locales + def test_localized_error_dlsym(self): + dll = _ctypes.dlopen(self.libc_filename) + with self.assertRaises(OSError): + _ctypes.dlsym(dll, 'this_name_does_not_exist') + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_errno.py b/Lib/test/test_ctypes/test_errno.py new file mode 100644 index 00000000000..65d99c1e492 --- /dev/null +++ b/Lib/test/test_ctypes/test_errno.py @@ -0,0 +1,81 @@ +import ctypes +import errno +import os +import threading +import unittest +from ctypes import CDLL, c_int, c_char_p, c_wchar_p, get_errno, set_errno +from ctypes.util import find_library + + +class Test(unittest.TestCase): + def test_open(self): + libc_name = find_library("c") + if libc_name is None: + self.skipTest("Unable to find C library") + + libc = 
CDLL(libc_name, use_errno=True) + if os.name == "nt": + libc_open = libc._open + else: + libc_open = libc.open + + libc_open.argtypes = c_char_p, c_int + + self.assertEqual(libc_open(b"", 0), -1) + self.assertEqual(get_errno(), errno.ENOENT) + + self.assertEqual(set_errno(32), errno.ENOENT) + self.assertEqual(get_errno(), 32) + + def _worker(): + set_errno(0) + + libc = CDLL(libc_name, use_errno=False) + if os.name == "nt": + libc_open = libc._open + else: + libc_open = libc.open + libc_open.argtypes = c_char_p, c_int + self.assertEqual(libc_open(b"", 0), -1) + self.assertEqual(get_errno(), 0) + + t = threading.Thread(target=_worker) + t.start() + t.join() + + self.assertEqual(get_errno(), 32) + set_errno(0) + + @unittest.skipUnless(os.name == "nt", 'Test specific to Windows') + def test_GetLastError(self): + dll = ctypes.WinDLL("kernel32", use_last_error=True) + GetModuleHandle = dll.GetModuleHandleA + GetModuleHandle.argtypes = [c_wchar_p] + + self.assertEqual(0, GetModuleHandle("foo")) + self.assertEqual(ctypes.get_last_error(), 126) + + self.assertEqual(ctypes.set_last_error(32), 126) + self.assertEqual(ctypes.get_last_error(), 32) + + def _worker(): + ctypes.set_last_error(0) + + dll = ctypes.WinDLL("kernel32", use_last_error=False) + GetModuleHandle = dll.GetModuleHandleW + GetModuleHandle.argtypes = [c_wchar_p] + GetModuleHandle("bar") + + self.assertEqual(ctypes.get_last_error(), 0) + + t = threading.Thread(target=_worker) + t.start() + t.join() + + self.assertEqual(ctypes.get_last_error(), 32) + + ctypes.set_last_error(0) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_find.py b/Lib/test/test_ctypes/test_find.py new file mode 100644 index 00000000000..85b28617d2d --- /dev/null +++ b/Lib/test/test_ctypes/test_find.py @@ -0,0 +1,156 @@ +import os.path +import sys +import test.support +import unittest +import unittest.mock +from ctypes import CDLL, RTLD_GLOBAL +from ctypes.util import find_library +from test.support 
import os_helper + + +# On some systems, loading the OpenGL libraries needs the RTLD_GLOBAL mode. +class Test_OpenGL_libs(unittest.TestCase): + @classmethod + def setUpClass(cls): + lib_gl = lib_glu = lib_gle = None + if sys.platform == "win32": + lib_gl = find_library("OpenGL32") + lib_glu = find_library("Glu32") + elif sys.platform == "darwin": + lib_gl = lib_glu = find_library("OpenGL") + else: + lib_gl = find_library("GL") + lib_glu = find_library("GLU") + lib_gle = find_library("gle") + + # print, for debugging + if test.support.verbose: + print("OpenGL libraries:") + for item in (("GL", lib_gl), + ("GLU", lib_glu), + ("gle", lib_gle)): + print("\t", item) + + cls.gl = cls.glu = cls.gle = None + if lib_gl: + try: + cls.gl = CDLL(lib_gl, mode=RTLD_GLOBAL) + except OSError: + pass + + if lib_glu: + try: + cls.glu = CDLL(lib_glu, RTLD_GLOBAL) + except OSError: + pass + + if lib_gle: + try: + cls.gle = CDLL(lib_gle) + except OSError: + pass + + @classmethod + def tearDownClass(cls): + cls.gl = cls.glu = cls.gle = None + + def test_gl(self): + if self.gl is None: + self.skipTest('lib_gl not available') + self.gl.glClearIndex + + def test_glu(self): + if self.glu is None: + self.skipTest('lib_glu not available') + self.glu.gluBeginCurve + + def test_gle(self): + if self.gle is None: + self.skipTest('lib_gle not available') + self.gle.gleGetJoinStyle + + def test_shell_injection(self): + result = find_library('; echo Hello shell > ' + os_helper.TESTFN) + self.assertFalse(os.path.lexists(os_helper.TESTFN)) + self.assertIsNone(result) + + +@unittest.skipUnless(sys.platform.startswith('linux'), + 'Test only valid for Linux') +class FindLibraryLinux(unittest.TestCase): + def test_find_on_libpath(self): + import subprocess + import tempfile + + try: + p = subprocess.Popen(['gcc', '--version'], stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL) + out, _ = p.communicate() + except OSError: + raise unittest.SkipTest('gcc, needed for test, not available') + with 
tempfile.TemporaryDirectory() as d: + # create an empty temporary file + srcname = os.path.join(d, 'dummy.c') + libname = 'py_ctypes_test_dummy' + dstname = os.path.join(d, 'lib%s.so' % libname) + with open(srcname, 'wb') as f: + pass + self.assertTrue(os.path.exists(srcname)) + # compile the file to a shared library + cmd = ['gcc', '-o', dstname, '--shared', + '-Wl,-soname,lib%s.so' % libname, srcname] + out = subprocess.check_output(cmd) + self.assertTrue(os.path.exists(dstname)) + # now check that the .so can't be found (since not in + # LD_LIBRARY_PATH) + self.assertIsNone(find_library(libname)) + # now add the location to LD_LIBRARY_PATH + with os_helper.EnvironmentVarGuard() as env: + KEY = 'LD_LIBRARY_PATH' + if KEY not in env: + v = d + else: + v = '%s:%s' % (env[KEY], d) + env.set(KEY, v) + # now check that the .so can be found (since in + # LD_LIBRARY_PATH) + self.assertEqual(find_library(libname), 'lib%s.so' % libname) + + def test_find_library_with_gcc(self): + with unittest.mock.patch("ctypes.util._findSoname_ldconfig", lambda *args: None): + self.assertNotEqual(find_library('c'), None) + + def test_find_library_with_ld(self): + with unittest.mock.patch("ctypes.util._findSoname_ldconfig", lambda *args: None), \ + unittest.mock.patch("ctypes.util._findLib_gcc", lambda *args: None): + self.assertNotEqual(find_library('c'), None) + + def test_gh114257(self): + self.assertIsNone(find_library("libc")) + + +@unittest.skipUnless(sys.platform == 'android', 'Test only valid for Android') +class FindLibraryAndroid(unittest.TestCase): + def test_find(self): + for name in [ + "c", "m", # POSIX + "z", # Non-POSIX, but present on Linux + "log", # Not present on Linux + ]: + with self.subTest(name=name): + path = find_library(name) + self.assertIsInstance(path, str) + self.assertEqual( + os.path.dirname(path), + "/system/lib64" if "64" in os.uname().machine + else "/system/lib") + self.assertEqual(os.path.basename(path), f"lib{name}.so") + 
self.assertTrue(os.path.isfile(path), path) + + for name in ["libc", "nonexistent"]: + with self.subTest(name=name): + self.assertIsNone(find_library(name)) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_frombuffer.py b/Lib/test/test_ctypes/test_frombuffer.py new file mode 100644 index 00000000000..d4e161f864d --- /dev/null +++ b/Lib/test/test_ctypes/test_frombuffer.py @@ -0,0 +1,144 @@ +import array +import gc +import unittest +from ctypes import (Structure, Union, Array, sizeof, + _Pointer, _SimpleCData, _CFuncPtr, + c_char, c_int) + + +class X(Structure): + _fields_ = [("c_int", c_int)] + init_called = False + def __init__(self): + self._init_called = True + + +class Test(unittest.TestCase): + def test_from_buffer(self): + a = array.array("i", range(16)) + x = (c_int * 16).from_buffer(a) + + y = X.from_buffer(a) + self.assertEqual(y.c_int, a[0]) + self.assertFalse(y.init_called) + + self.assertEqual(x[:], a.tolist()) + + a[0], a[-1] = 200, -200 + self.assertEqual(x[:], a.tolist()) + + self.assertRaises(BufferError, a.append, 100) + self.assertRaises(BufferError, a.pop) + + del x; del y; gc.collect(); gc.collect(); gc.collect() + a.append(100) + a.pop() + x = (c_int * 16).from_buffer(a) + + self.assertIn(a, [obj.obj if isinstance(obj, memoryview) else obj + for obj in x._objects.values()]) + + expected = x[:] + del a; gc.collect(); gc.collect(); gc.collect() + self.assertEqual(x[:], expected) + + with self.assertRaisesRegex(TypeError, "not writable"): + (c_char * 16).from_buffer(b"a" * 16) + with self.assertRaisesRegex(TypeError, "not writable"): + (c_char * 16).from_buffer(memoryview(b"a" * 16)) + with self.assertRaisesRegex(TypeError, "not C contiguous"): + (c_char * 16).from_buffer(memoryview(bytearray(b"a" * 16))[::-1]) + msg = "bytes-like object is required" + with self.assertRaisesRegex(TypeError, msg): + (c_char * 16).from_buffer("a" * 16) + + def test_fortran_contiguous(self): + try: + import _testbuffer + except 
ImportError as err: + self.skipTest(str(err)) + flags = _testbuffer.ND_WRITABLE | _testbuffer.ND_FORTRAN + array = _testbuffer.ndarray( + [97] * 16, format="B", shape=[4, 4], flags=flags) + with self.assertRaisesRegex(TypeError, "not C contiguous"): + (c_char * 16).from_buffer(array) + array = memoryview(array) + self.assertTrue(array.f_contiguous) + self.assertFalse(array.c_contiguous) + with self.assertRaisesRegex(TypeError, "not C contiguous"): + (c_char * 16).from_buffer(array) + + def test_from_buffer_with_offset(self): + a = array.array("i", range(16)) + x = (c_int * 15).from_buffer(a, sizeof(c_int)) + + self.assertEqual(x[:], a.tolist()[1:]) + with self.assertRaises(ValueError): + c_int.from_buffer(a, -1) + with self.assertRaises(ValueError): + (c_int * 16).from_buffer(a, sizeof(c_int)) + with self.assertRaises(ValueError): + (c_int * 1).from_buffer(a, 16 * sizeof(c_int)) + + def test_from_buffer_memoryview(self): + a = [c_char.from_buffer(memoryview(bytearray(b'a')))] + a.append(a) + del a + gc.collect() # Should not crash + + def test_from_buffer_copy(self): + a = array.array("i", range(16)) + x = (c_int * 16).from_buffer_copy(a) + + y = X.from_buffer_copy(a) + self.assertEqual(y.c_int, a[0]) + self.assertFalse(y.init_called) + + self.assertEqual(x[:], list(range(16))) + + a[0], a[-1] = 200, -200 + self.assertEqual(x[:], list(range(16))) + + a.append(100) + self.assertEqual(x[:], list(range(16))) + + self.assertEqual(x._objects, None) + + del a; gc.collect(); gc.collect(); gc.collect() + self.assertEqual(x[:], list(range(16))) + + x = (c_char * 16).from_buffer_copy(b"a" * 16) + self.assertEqual(x[:], b"a" * 16) + with self.assertRaises(TypeError): + (c_char * 16).from_buffer_copy("a" * 16) + + def test_from_buffer_copy_with_offset(self): + a = array.array("i", range(16)) + x = (c_int * 15).from_buffer_copy(a, sizeof(c_int)) + + self.assertEqual(x[:], a.tolist()[1:]) + with self.assertRaises(ValueError): + c_int.from_buffer_copy(a, -1) + with 
self.assertRaises(ValueError): + (c_int * 16).from_buffer_copy(a, sizeof(c_int)) + with self.assertRaises(ValueError): + (c_int * 1).from_buffer_copy(a, 16 * sizeof(c_int)) + + def test_abstract(self): + self.assertRaises(TypeError, Array.from_buffer, bytearray(10)) + self.assertRaises(TypeError, Structure.from_buffer, bytearray(10)) + self.assertRaises(TypeError, Union.from_buffer, bytearray(10)) + self.assertRaises(TypeError, _CFuncPtr.from_buffer, bytearray(10)) + self.assertRaises(TypeError, _Pointer.from_buffer, bytearray(10)) + self.assertRaises(TypeError, _SimpleCData.from_buffer, bytearray(10)) + + self.assertRaises(TypeError, Array.from_buffer_copy, b"123") + self.assertRaises(TypeError, Structure.from_buffer_copy, b"123") + self.assertRaises(TypeError, Union.from_buffer_copy, b"123") + self.assertRaises(TypeError, _CFuncPtr.from_buffer_copy, b"123") + self.assertRaises(TypeError, _Pointer.from_buffer_copy, b"123") + self.assertRaises(TypeError, _SimpleCData.from_buffer_copy, b"123") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_funcptr.py b/Lib/test/test_ctypes/test_funcptr.py new file mode 100644 index 00000000000..8362fb16d94 --- /dev/null +++ b/Lib/test/test_ctypes/test_funcptr.py @@ -0,0 +1,134 @@ +import ctypes +import unittest +from ctypes import (CDLL, Structure, CFUNCTYPE, sizeof, _CFuncPtr, + c_void_p, c_char_p, c_char, c_int, c_uint, c_long) +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") +from ._support import (_CData, PyCFuncPtrType, Py_TPFLAGS_DISALLOW_INSTANTIATION, + Py_TPFLAGS_IMMUTABLETYPE) + + +try: + WINFUNCTYPE = ctypes.WINFUNCTYPE +except AttributeError: + # fake to enable this test on Linux + WINFUNCTYPE = CFUNCTYPE + +lib = CDLL(_ctypes_test.__file__) + + +class CFuncPtrTestCase(unittest.TestCase): + def test_inheritance_hierarchy(self): + self.assertEqual(_CFuncPtr.mro(), [_CFuncPtr, _CData, object]) + + 
self.assertEqual(PyCFuncPtrType.__name__, "PyCFuncPtrType") + self.assertEqual(type(PyCFuncPtrType), type) + + def test_type_flags(self): + for cls in _CFuncPtr, PyCFuncPtrType: + with self.subTest(cls=cls): + self.assertTrue(_CFuncPtr.__flags__ & Py_TPFLAGS_IMMUTABLETYPE) + self.assertFalse(_CFuncPtr.__flags__ & Py_TPFLAGS_DISALLOW_INSTANTIATION) + + def test_metaclass_details(self): + # Cannot call the metaclass __init__ more than once + CdeclCallback = CFUNCTYPE(c_int, c_int, c_int) + with self.assertRaisesRegex(SystemError, "already initialized"): + PyCFuncPtrType.__init__(CdeclCallback, 'ptr', (), {}) + + def test_basic(self): + X = WINFUNCTYPE(c_int, c_int, c_int) + + def func(*args): + return len(args) + + x = X(func) + self.assertEqual(x.restype, c_int) + self.assertEqual(x.argtypes, (c_int, c_int)) + self.assertEqual(sizeof(x), sizeof(c_void_p)) + self.assertEqual(sizeof(X), sizeof(c_void_p)) + + def test_first(self): + StdCallback = WINFUNCTYPE(c_int, c_int, c_int) + CdeclCallback = CFUNCTYPE(c_int, c_int, c_int) + + def func(a, b): + return a + b + + s = StdCallback(func) + c = CdeclCallback(func) + + self.assertEqual(s(1, 2), 3) + self.assertEqual(c(1, 2), 3) + # The following no longer raises a TypeError - it is now + # possible, as in C, to call cdecl functions with more parameters. 
+ #self.assertRaises(TypeError, c, 1, 2, 3) + self.assertEqual(c(1, 2, 3, 4, 5, 6), 3) + if WINFUNCTYPE is not CFUNCTYPE: + self.assertRaises(TypeError, s, 1, 2, 3) + + def test_structures(self): + WNDPROC = WINFUNCTYPE(c_long, c_int, c_int, c_int, c_int) + + def wndproc(hwnd, msg, wParam, lParam): + return hwnd + msg + wParam + lParam + + HINSTANCE = c_int + HICON = c_int + HCURSOR = c_int + LPCTSTR = c_char_p + + class WNDCLASS(Structure): + _fields_ = [("style", c_uint), + ("lpfnWndProc", WNDPROC), + ("cbClsExtra", c_int), + ("cbWndExtra", c_int), + ("hInstance", HINSTANCE), + ("hIcon", HICON), + ("hCursor", HCURSOR), + ("lpszMenuName", LPCTSTR), + ("lpszClassName", LPCTSTR)] + + wndclass = WNDCLASS() + wndclass.lpfnWndProc = WNDPROC(wndproc) + + WNDPROC_2 = WINFUNCTYPE(c_long, c_int, c_int, c_int, c_int) + + self.assertIs(WNDPROC, WNDPROC_2) + self.assertEqual(wndclass.lpfnWndProc(1, 2, 3, 4), 10) + + f = wndclass.lpfnWndProc + + del wndclass + del wndproc + + self.assertEqual(f(10, 11, 12, 13), 46) + + def test_dllfunctions(self): + strchr = lib.my_strchr + strchr.restype = c_char_p + strchr.argtypes = (c_char_p, c_char) + self.assertEqual(strchr(b"abcdefghi", b"b"), b"bcdefghi") + self.assertEqual(strchr(b"abcdefghi", b"x"), None) + + strtok = lib.my_strtok + strtok.restype = c_char_p + + def c_string(init): + size = len(init) + 1 + return (c_char*size)(*init) + + s = b"a\nb\nc" + b = c_string(s) + + self.assertEqual(strtok(b, b"\n"), b"a") + self.assertEqual(strtok(None, b"\n"), b"b") + self.assertEqual(strtok(None, b"\n"), b"c") + self.assertEqual(strtok(None, b"\n"), None) + + def test_abstract(self): + self.assertRaises(TypeError, _CFuncPtr, 13, "name", 42, "iid") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_functions.py b/Lib/test/test_ctypes/test_functions.py new file mode 100644 index 00000000000..63e393f7b7c --- /dev/null +++ b/Lib/test/test_ctypes/test_functions.py @@ -0,0 +1,454 @@ +import ctypes +import 
sys +import unittest +from ctypes import (CDLL, Structure, Array, CFUNCTYPE, + byref, POINTER, pointer, ArgumentError, + c_char, c_wchar, c_byte, c_char_p, c_wchar_p, + c_short, c_int, c_long, c_longlong, c_void_p, + c_float, c_double, c_longdouble) +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") +from _ctypes import _Pointer, _SimpleCData + + +try: + WINFUNCTYPE = ctypes.WINFUNCTYPE +except AttributeError: + # fake to enable this test on Linux + WINFUNCTYPE = CFUNCTYPE + +dll = CDLL(_ctypes_test.__file__) +if sys.platform == "win32": + windll = ctypes.WinDLL(_ctypes_test.__file__) + + +class POINT(Structure): + _fields_ = [("x", c_int), ("y", c_int)] + + +class RECT(Structure): + _fields_ = [("left", c_int), ("top", c_int), + ("right", c_int), ("bottom", c_int)] + + +class FunctionTestCase(unittest.TestCase): + + def test_mro(self): + # in Python 2.3, this raises TypeError: MRO conflict among bases classes, + # in Python 2.2 it works. + # + # But in early versions of _ctypes.c, the result of tp_new + # wasn't checked, and it even crashed Python. + # Found by Greg Chapman. 
+ + with self.assertRaises(TypeError): + class X(object, Array): + _length_ = 5 + _type_ = "i" + + with self.assertRaises(TypeError): + class X2(object, _Pointer): + pass + + with self.assertRaises(TypeError): + class X3(object, _SimpleCData): + _type_ = "i" + + with self.assertRaises(TypeError): + class X4(object, Structure): + _fields_ = [] + + def test_c_char_parm(self): + proto = CFUNCTYPE(c_int, c_char) + def callback(*args): + return 0 + + callback = proto(callback) + + self.assertEqual(callback(b"a"), 0) + + with self.assertRaises(ArgumentError) as cm: + callback(b"abc") + + self.assertEqual(str(cm.exception), + "argument 1: TypeError: one character bytes, " + "bytearray or integer expected") + + def test_wchar_parm(self): + f = dll._testfunc_i_bhilfd + f.argtypes = [c_byte, c_wchar, c_int, c_long, c_float, c_double] + result = f(1, "x", 3, 4, 5.0, 6.0) + self.assertEqual(result, 139) + self.assertEqual(type(result), int) + + with self.assertRaises(ArgumentError) as cm: + f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(str(cm.exception), + "argument 2: TypeError: unicode string expected " + "instead of int instance") + + with self.assertRaises(ArgumentError) as cm: + f(1, "abc", 3, 4, 5.0, 6.0) + self.assertEqual(str(cm.exception), + "argument 2: TypeError: one character unicode string " + "expected") + + def test_c_char_p_parm(self): + """Test the error message when converting an incompatible type to c_char_p.""" + proto = CFUNCTYPE(c_int, c_char_p) + def callback(*args): + return 0 + + callback = proto(callback) + self.assertEqual(callback(b"abc"), 0) + + with self.assertRaises(ArgumentError) as cm: + callback(10) + + self.assertEqual(str(cm.exception), + "argument 1: TypeError: 'int' object cannot be " + "interpreted as ctypes.c_char_p") + + def test_c_wchar_p_parm(self): + """Test the error message when converting an incompatible type to c_wchar_p.""" + proto = CFUNCTYPE(c_int, c_wchar_p) + def callback(*args): + return 0 + + callback = proto(callback) + 
self.assertEqual(callback("abc"), 0) + + with self.assertRaises(ArgumentError) as cm: + callback(10) + + self.assertEqual(str(cm.exception), + "argument 1: TypeError: 'int' object cannot be " + "interpreted as ctypes.c_wchar_p") + + def test_c_void_p_parm(self): + """Test the error message when converting an incompatible type to c_void_p.""" + proto = CFUNCTYPE(c_int, c_void_p) + def callback(*args): + return 0 + + callback = proto(callback) + self.assertEqual(callback(5), 0) + + with self.assertRaises(ArgumentError) as cm: + callback(2.5) + + self.assertEqual(str(cm.exception), + "argument 1: TypeError: 'float' object cannot be " + "interpreted as ctypes.c_void_p") + + def test_wchar_result(self): + f = dll._testfunc_i_bhilfd + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] + f.restype = c_wchar + result = f(0, 0, 0, 0, 0, 0) + self.assertEqual(result, '\x00') + + def test_voidresult(self): + f = dll._testfunc_v + f.restype = None + f.argtypes = [c_int, c_int, POINTER(c_int)] + result = c_int() + self.assertEqual(None, f(1, 2, byref(result))) + self.assertEqual(result.value, 3) + + def test_intresult(self): + f = dll._testfunc_i_bhilfd + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] + f.restype = c_int + result = f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(result, 21) + self.assertEqual(type(result), int) + + result = f(-1, -2, -3, -4, -5.0, -6.0) + self.assertEqual(result, -21) + self.assertEqual(type(result), int) + + # If we declare the function to return a short, + # is the high part split off? 
+ f.restype = c_short + result = f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(result, 21) + self.assertEqual(type(result), int) + + result = f(1, 2, 3, 0x10004, 5.0, 6.0) + self.assertEqual(result, 21) + self.assertEqual(type(result), int) + + # You cannot assign character format codes as restype any longer + self.assertRaises(TypeError, setattr, f, "restype", "i") + + def test_floatresult(self): + f = dll._testfunc_f_bhilfd + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] + f.restype = c_float + result = f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(result, 21) + self.assertEqual(type(result), float) + + result = f(-1, -2, -3, -4, -5.0, -6.0) + self.assertEqual(result, -21) + self.assertEqual(type(result), float) + + def test_doubleresult(self): + f = dll._testfunc_d_bhilfd + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] + f.restype = c_double + result = f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(result, 21) + self.assertEqual(type(result), float) + + result = f(-1, -2, -3, -4, -5.0, -6.0) + self.assertEqual(result, -21) + self.assertEqual(type(result), float) + + def test_longdoubleresult(self): + f = dll._testfunc_D_bhilfD + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_longdouble] + f.restype = c_longdouble + result = f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(result, 21) + self.assertEqual(type(result), float) + + result = f(-1, -2, -3, -4, -5.0, -6.0) + self.assertEqual(result, -21) + self.assertEqual(type(result), float) + + def test_longlongresult(self): + f = dll._testfunc_q_bhilfd + f.restype = c_longlong + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double] + result = f(1, 2, 3, 4, 5.0, 6.0) + self.assertEqual(result, 21) + + f = dll._testfunc_q_bhilfdq + f.restype = c_longlong + f.argtypes = [c_byte, c_short, c_int, c_long, c_float, c_double, c_longlong] + result = f(1, 2, 3, 4, 5.0, 6.0, 21) + self.assertEqual(result, 42) + + def test_stringresult(self): + f = dll._testfunc_p_p + f.argtypes = None 
+ f.restype = c_char_p + result = f(b"123") + self.assertEqual(result, b"123") + + result = f(None) + self.assertEqual(result, None) + + def test_pointers(self): + f = dll._testfunc_p_p + f.restype = POINTER(c_int) + f.argtypes = [POINTER(c_int)] + + # This only works if the value c_int(42) passed to the + # function is still alive while the pointer (the result) is + # used. + + v = c_int(42) + + self.assertEqual(pointer(v).contents.value, 42) + result = f(pointer(v)) + self.assertEqual(type(result), POINTER(c_int)) + self.assertEqual(result.contents.value, 42) + + # This one works... + result = f(pointer(v)) + self.assertEqual(result.contents.value, v.value) + + p = pointer(c_int(99)) + result = f(p) + self.assertEqual(result.contents.value, 99) + + arg = byref(v) + result = f(arg) + self.assertNotEqual(result.contents, v.value) + + self.assertRaises(ArgumentError, f, byref(c_short(22))) + + # It is dangerous, however, because you don't control the lifetime + # of the pointer: + result = f(byref(c_int(99))) + self.assertNotEqual(result.contents, 99) + + def test_shorts(self): + f = dll._testfunc_callback_i_if + + args = [] + expected = [262144, 131072, 65536, 32768, 16384, 8192, 4096, 2048, + 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1] + + def callback(v): + args.append(v) + return v + + CallBack = CFUNCTYPE(c_int, c_int) + + cb = CallBack(callback) + f(2**18, cb) + self.assertEqual(args, expected) + + def test_callbacks(self): + f = dll._testfunc_callback_i_if + f.restype = c_int + f.argtypes = None + + MyCallback = CFUNCTYPE(c_int, c_int) + + def callback(value): + return value + + cb = MyCallback(callback) + result = f(-10, cb) + self.assertEqual(result, -18) + + # test with prototype + f.argtypes = [c_int, MyCallback] + cb = MyCallback(callback) + result = f(-10, cb) + self.assertEqual(result, -18) + + AnotherCallback = WINFUNCTYPE(c_int, c_int, c_int, c_int, c_int) + + # check that the prototype works: we call f with wrong + # argument types + cb = 
AnotherCallback(callback) + self.assertRaises(ArgumentError, f, -10, cb) + + + def test_callbacks_2(self): + # Can also use simple datatypes as argument type specifiers + # for the callback function. + # In this case the call receives an instance of that type + f = dll._testfunc_callback_i_if + f.restype = c_int + + MyCallback = CFUNCTYPE(c_int, c_int) + + f.argtypes = [c_int, MyCallback] + + def callback(value): + self.assertEqual(type(value), int) + return value + + cb = MyCallback(callback) + result = f(-10, cb) + self.assertEqual(result, -18) + + def test_longlong_callbacks(self): + + f = dll._testfunc_callback_q_qf + f.restype = c_longlong + + MyCallback = CFUNCTYPE(c_longlong, c_longlong) + + f.argtypes = [c_longlong, MyCallback] + + def callback(value): + self.assertIsInstance(value, int) + return value & 0x7FFFFFFF + + cb = MyCallback(callback) + + self.assertEqual(13577625587, f(1000000000000, cb)) + + def test_errors(self): + self.assertRaises(AttributeError, getattr, dll, "_xxx_yyy") + self.assertRaises(ValueError, c_int.in_dll, dll, "_xxx_yyy") + + def test_byval(self): + + # without prototype + ptin = POINT(1, 2) + ptout = POINT() + # EXPORT int _testfunc_byval(point in, point *pout) + result = dll._testfunc_byval(ptin, byref(ptout)) + got = result, ptout.x, ptout.y + expected = 3, 1, 2 + self.assertEqual(got, expected) + + # with prototype + ptin = POINT(101, 102) + ptout = POINT() + dll._testfunc_byval.argtypes = (POINT, POINTER(POINT)) + dll._testfunc_byval.restype = c_int + result = dll._testfunc_byval(ptin, byref(ptout)) + got = result, ptout.x, ptout.y + expected = 203, 101, 102 + self.assertEqual(got, expected) + + def test_struct_return_2H(self): + class S2H(Structure): + _fields_ = [("x", c_short), + ("y", c_short)] + dll.ret_2h_func.restype = S2H + dll.ret_2h_func.argtypes = [S2H] + inp = S2H(99, 88) + s2h = dll.ret_2h_func(inp) + self.assertEqual((s2h.x, s2h.y), (99*2, 88*3)) + + @unittest.skipUnless(sys.platform == "win32", 
'Windows-specific test') + def test_struct_return_2H_stdcall(self): + class S2H(Structure): + _fields_ = [("x", c_short), + ("y", c_short)] + + windll.s_ret_2h_func.restype = S2H + windll.s_ret_2h_func.argtypes = [S2H] + s2h = windll.s_ret_2h_func(S2H(99, 88)) + self.assertEqual((s2h.x, s2h.y), (99*2, 88*3)) + + def test_struct_return_8H(self): + class S8I(Structure): + _fields_ = [("a", c_int), + ("b", c_int), + ("c", c_int), + ("d", c_int), + ("e", c_int), + ("f", c_int), + ("g", c_int), + ("h", c_int)] + dll.ret_8i_func.restype = S8I + dll.ret_8i_func.argtypes = [S8I] + inp = S8I(9, 8, 7, 6, 5, 4, 3, 2) + s8i = dll.ret_8i_func(inp) + self.assertEqual((s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h), + (9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9)) + + @unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') + def test_struct_return_8H_stdcall(self): + class S8I(Structure): + _fields_ = [("a", c_int), + ("b", c_int), + ("c", c_int), + ("d", c_int), + ("e", c_int), + ("f", c_int), + ("g", c_int), + ("h", c_int)] + windll.s_ret_8i_func.restype = S8I + windll.s_ret_8i_func.argtypes = [S8I] + inp = S8I(9, 8, 7, 6, 5, 4, 3, 2) + s8i = windll.s_ret_8i_func(inp) + self.assertEqual( + (s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h), + (9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9)) + + def test_sf1651235(self): + # see https://bugs.python.org/issue1651235 + + proto = CFUNCTYPE(c_int, RECT, POINT) + def callback(*args): + return 0 + + callback = proto(callback) + self.assertRaises(ArgumentError, lambda: callback((1, 2, 3, 4), POINT())) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_incomplete.py b/Lib/test/test_ctypes/test_incomplete.py new file mode 100644 index 00000000000..9f859793d88 --- /dev/null +++ b/Lib/test/test_ctypes/test_incomplete.py @@ -0,0 +1,50 @@ +import ctypes +import unittest +import warnings +from ctypes import Structure, POINTER, pointer, c_char_p + + +# The incomplete pointer example from 
the tutorial +class TestSetPointerType(unittest.TestCase): + def tearDown(self): + # to not leak references, we must clean _pointer_type_cache + ctypes._reset_cache() + + def test_incomplete_example(self): + lpcell = POINTER("cell") + class cell(Structure): + _fields_ = [("name", c_char_p), + ("next", lpcell)] + + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + ctypes.SetPointerType(lpcell, cell) + + c1 = cell() + c1.name = b"foo" + c2 = cell() + c2.name = b"bar" + + c1.next = pointer(c2) + c2.next = pointer(c1) + + p = c1 + + result = [] + for i in range(8): + result.append(p.name) + p = p.next[0] + self.assertEqual(result, [b"foo", b"bar"] * 4) + + def test_deprecation(self): + lpcell = POINTER("cell") + class cell(Structure): + _fields_ = [("name", c_char_p), + ("next", lpcell)] + + with self.assertWarns(DeprecationWarning): + ctypes.SetPointerType(lpcell, cell) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_init.py b/Lib/test/test_ctypes/test_init.py new file mode 100644 index 00000000000..113425e5823 --- /dev/null +++ b/Lib/test/test_ctypes/test_init.py @@ -0,0 +1,43 @@ +import unittest +from ctypes import Structure, c_int + + +class X(Structure): + _fields_ = [("a", c_int), + ("b", c_int)] + new_was_called = False + + def __new__(cls): + result = super().__new__(cls) + result.new_was_called = True + return result + + def __init__(self): + self.a = 9 + self.b = 12 + + +class Y(Structure): + _fields_ = [("x", X)] + + +class InitTest(unittest.TestCase): + def test_get(self): + # make sure that only accessing a nested structure + # doesn't call the structure's __new__ and __init__ + y = Y() + self.assertEqual((y.x.a, y.x.b), (0, 0)) + self.assertEqual(y.x.new_was_called, False) + + # But explicitly creating an X structure calls __new__ and __init__, of course. 
+ x = X() + self.assertEqual((x.a, x.b), (9, 12)) + self.assertEqual(x.new_was_called, True) + + y.x = x + self.assertEqual((y.x.a, y.x.b), (9, 12)) + self.assertEqual(y.x.new_was_called, False) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_internals.py b/Lib/test/test_ctypes/test_internals.py new file mode 100644 index 00000000000..778da6573da --- /dev/null +++ b/Lib/test/test_ctypes/test_internals.py @@ -0,0 +1,97 @@ +# This tests the internal _objects attribute + +# XXX This test must be reviewed for correctness!!! + +# ctypes' types are container types. +# +# They have an internal memory block, which only consists of some bytes, +# but it has to keep references to other objects as well. This is not +# really needed for trivial C types like int or char, but it is important +# for aggregate types like strings or pointers in particular. +# +# What about pointers? + +import sys +import unittest +from ctypes import Structure, POINTER, c_char_p, c_int + + +class ObjectsTestCase(unittest.TestCase): + def assertSame(self, a, b): + self.assertEqual(id(a), id(b)) + + def test_ints(self): + i = 42000123 + refcnt = sys.getrefcount(i) + ci = c_int(i) + self.assertEqual(refcnt, sys.getrefcount(i)) + self.assertEqual(ci._objects, None) + + def test_c_char_p(self): + s = "Hello, World".encode("ascii") + refcnt = sys.getrefcount(s) + cs = c_char_p(s) + self.assertEqual(refcnt + 1, sys.getrefcount(s)) + self.assertSame(cs._objects, s) + + def test_simple_struct(self): + class X(Structure): + _fields_ = [("a", c_int), ("b", c_int)] + + a = 421234 + b = 421235 + x = X() + self.assertEqual(x._objects, None) + x.a = a + x.b = b + self.assertEqual(x._objects, None) + + def test_embedded_structs(self): + class X(Structure): + _fields_ = [("a", c_int), ("b", c_int)] + + class Y(Structure): + _fields_ = [("x", X), ("y", X)] + + y = Y() + self.assertEqual(y._objects, None) + + x1, x2 = X(), X() + y.x, y.y = x1, x2 + self.assertEqual(y._objects, 
{"0": {}, "1": {}}) + x1.a, x2.b = 42, 93 + self.assertEqual(y._objects, {"0": {}, "1": {}}) + + def test_xxx(self): + class X(Structure): + _fields_ = [("a", c_char_p), ("b", c_char_p)] + + class Y(Structure): + _fields_ = [("x", X), ("y", X)] + + s1 = b"Hello, World" + s2 = b"Hallo, Welt" + + x = X() + x.a = s1 + x.b = s2 + self.assertEqual(x._objects, {"0": s1, "1": s2}) + + y = Y() + y.x = x + self.assertEqual(y._objects, {"0": {"0": s1, "1": s2}}) + + def test_ptr_struct(self): + class X(Structure): + _fields_ = [("data", POINTER(c_int))] + + A = c_int*4 + a = A(11, 22, 33, 44) + self.assertEqual(a._objects, None) + + x = X() + x.data = a + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_keeprefs.py b/Lib/test/test_ctypes/test_keeprefs.py new file mode 100644 index 00000000000..23b03b64b4a --- /dev/null +++ b/Lib/test/test_ctypes/test_keeprefs.py @@ -0,0 +1,124 @@ +import unittest +from ctypes import (Structure, POINTER, pointer, _pointer_type_cache, + c_char_p, c_int) + + +class SimpleTestCase(unittest.TestCase): + def test_cint(self): + x = c_int() + self.assertEqual(x._objects, None) + x.value = 42 + self.assertEqual(x._objects, None) + x = c_int(99) + self.assertEqual(x._objects, None) + + def test_ccharp(self): + x = c_char_p() + self.assertEqual(x._objects, None) + x.value = b"abc" + self.assertEqual(x._objects, b"abc") + x = c_char_p(b"spam") + self.assertEqual(x._objects, b"spam") + + +class StructureTestCase(unittest.TestCase): + def test_cint_struct(self): + class X(Structure): + _fields_ = [("a", c_int), + ("b", c_int)] + + x = X() + self.assertEqual(x._objects, None) + x.a = 42 + x.b = 99 + self.assertEqual(x._objects, None) + + def test_ccharp_struct(self): + class X(Structure): + _fields_ = [("a", c_char_p), + ("b", c_char_p)] + x = X() + self.assertEqual(x._objects, None) + + x.a = b"spam" + x.b = b"foo" + self.assertEqual(x._objects, {"0": b"spam", "1": b"foo"}) + + def test_struct_struct(self): + class 
POINT(Structure): + _fields_ = [("x", c_int), ("y", c_int)] + class RECT(Structure): + _fields_ = [("ul", POINT), ("lr", POINT)] + + r = RECT() + r.ul.x = 0 + r.ul.y = 1 + r.lr.x = 2 + r.lr.y = 3 + self.assertEqual(r._objects, None) + + r = RECT() + pt = POINT(1, 2) + r.ul = pt + self.assertEqual(r._objects, {'0': {}}) + r.ul.x = 22 + r.ul.y = 44 + self.assertEqual(r._objects, {'0': {}}) + r.lr = POINT() + self.assertEqual(r._objects, {'0': {}, '1': {}}) + + +class ArrayTestCase(unittest.TestCase): + def test_cint_array(self): + INTARR = c_int * 3 + + ia = INTARR() + self.assertEqual(ia._objects, None) + ia[0] = 1 + ia[1] = 2 + ia[2] = 3 + self.assertEqual(ia._objects, None) + + class X(Structure): + _fields_ = [("x", c_int), + ("a", INTARR)] + + x = X() + x.x = 1000 + x.a[0] = 42 + x.a[1] = 96 + self.assertEqual(x._objects, None) + x.a = ia + self.assertEqual(x._objects, {'1': {}}) + + +class PointerTestCase(unittest.TestCase): + def test_p_cint(self): + i = c_int(42) + x = pointer(i) + self.assertEqual(x._objects, {'1': i}) + + +class PointerToStructure(unittest.TestCase): + def test(self): + class POINT(Structure): + _fields_ = [("x", c_int), ("y", c_int)] + class RECT(Structure): + _fields_ = [("a", POINTER(POINT)), + ("b", POINTER(POINT))] + r = RECT() + p1 = POINT(1, 2) + + r.a = pointer(p1) + r.b = pointer(p1) + + r.a[0].x = 42 + r.a[0].y = 99 + + # to avoid leaking when tests are run several times + # clean up the types left in the cache. 
+ del _pointer_type_cache[POINT] + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_libc.py b/Lib/test/test_ctypes/test_libc.py new file mode 100644 index 00000000000..7716100b08f --- /dev/null +++ b/Lib/test/test_ctypes/test_libc.py @@ -0,0 +1,38 @@ +import math +import unittest +from ctypes import (CDLL, CFUNCTYPE, POINTER, create_string_buffer, sizeof, + c_void_p, c_char, c_int, c_double, c_size_t) +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +lib = CDLL(_ctypes_test.__file__) + + +def three_way_cmp(x, y): + """Return -1 if x < y, 0 if x == y and 1 if x > y""" + return (x > y) - (x < y) + + +class LibTest(unittest.TestCase): + def test_sqrt(self): + lib.my_sqrt.argtypes = c_double, + lib.my_sqrt.restype = c_double + self.assertEqual(lib.my_sqrt(4.0), 2.0) + self.assertEqual(lib.my_sqrt(2.0), math.sqrt(2.0)) + + def test_qsort(self): + comparefunc = CFUNCTYPE(c_int, POINTER(c_char), POINTER(c_char)) + lib.my_qsort.argtypes = c_void_p, c_size_t, c_size_t, comparefunc + lib.my_qsort.restype = None + + def sort(a, b): + return three_way_cmp(a[0], b[0]) + + chars = create_string_buffer(b"spam, spam, and spam") + lib.my_qsort(chars, len(chars)-1, sizeof(c_char), comparefunc(sort)) + self.assertEqual(chars.raw, b" ,,aaaadmmmnpppsss\x00") + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_loading.py b/Lib/test/test_ctypes/test_loading.py new file mode 100644 index 00000000000..a4d54f676a6 --- /dev/null +++ b/Lib/test/test_ctypes/test_loading.py @@ -0,0 +1,208 @@ +import _ctypes +import ctypes +import os +import shutil +import subprocess +import sys +import test.support +import unittest +from ctypes import CDLL, cdll, addressof, c_void_p, c_char_p +from ctypes.util import find_library +from test.support import import_helper, os_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +libc_name = None + + +def setUpModule(): 
+ global libc_name + if os.name == "nt": + libc_name = find_library("c") + elif sys.platform == "cygwin": + libc_name = "cygwin1.dll" + else: + libc_name = find_library("c") + + if test.support.verbose: + print("libc_name is", libc_name) + + +class LoaderTest(unittest.TestCase): + + unknowndll = "xxrandomnamexx" + + def test_load(self): + if libc_name is not None: + test_lib = libc_name + else: + if os.name == "nt": + test_lib = _ctypes_test.__file__ + else: + self.skipTest('could not find library to load') + CDLL(test_lib) + CDLL(os.path.basename(test_lib)) + CDLL(os_helper.FakePath(test_lib)) + self.assertRaises(OSError, CDLL, self.unknowndll) + + def test_load_version(self): + if libc_name is None: + self.skipTest('could not find libc') + if os.path.basename(libc_name) != 'libc.so.6': + self.skipTest('wrong libc path for test') + cdll.LoadLibrary("libc.so.6") + # linux uses version, libc 9 should not exist + self.assertRaises(OSError, cdll.LoadLibrary, "libc.so.9") + self.assertRaises(OSError, cdll.LoadLibrary, self.unknowndll) + + def test_find(self): + found = False + for name in ("c", "m"): + lib = find_library(name) + if lib: + found = True + cdll.LoadLibrary(lib) + CDLL(lib) + if not found: + self.skipTest("Could not find c and m libraries") + + @unittest.skipUnless(os.name == "nt", + 'test specific to Windows') + def test_load_library(self): + # CRT is no longer directly loadable. See issue23606 for the + # discussion about alternative approaches. 
+ #self.assertIsNotNone(libc_name) + if test.support.verbose: + print(find_library("kernel32")) + print(find_library("user32")) + + if os.name == "nt": + ctypes.windll.kernel32.GetModuleHandleW + ctypes.windll["kernel32"].GetModuleHandleW + ctypes.windll.LoadLibrary("kernel32").GetModuleHandleW + ctypes.WinDLL("kernel32").GetModuleHandleW + # embedded null character + self.assertRaises(ValueError, ctypes.windll.LoadLibrary, "kernel32\0") + + @unittest.skipUnless(os.name == "nt", + 'test specific to Windows') + def test_load_ordinal_functions(self): + dll = ctypes.WinDLL(_ctypes_test.__file__) + # We load the same function both via ordinal and name + func_ord = dll[2] + func_name = dll.GetString + # addressof gets the address where the function pointer is stored + a_ord = addressof(func_ord) + a_name = addressof(func_name) + f_ord_addr = c_void_p.from_address(a_ord).value + f_name_addr = c_void_p.from_address(a_name).value + self.assertEqual(hex(f_ord_addr), hex(f_name_addr)) + + self.assertRaises(AttributeError, dll.__getitem__, 1234) + + @unittest.skipUnless(os.name == "nt", 'Windows-specific test') + def test_load_without_name_and_with_handle(self): + handle = ctypes.windll.kernel32._handle + lib = ctypes.WinDLL(name=None, handle=handle) + self.assertIs(handle, lib._handle) + + @unittest.skipUnless(os.name == "nt", 'Windows-specific test') + def test_1703286_A(self): + # On winXP 64-bit, advapi32 loads at an address that does + # NOT fit into a 32-bit integer. FreeLibrary must be able + # to accept this address. + + # These are tests for https://bugs.python.org/issue1703286 + handle = _ctypes.LoadLibrary("advapi32") + _ctypes.FreeLibrary(handle) + + @unittest.skipUnless(os.name == "nt", 'Windows-specific test') + def test_1703286_B(self): + # Since on winXP 64-bit advapi32 loads like described + # above, the (arbitrarily selected) CloseEventLog function + # also has a high address. 'call_function' should accept + # addresses so large. 
+ + advapi32 = ctypes.windll.advapi32 + # Calling CloseEventLog with a NULL argument should fail, + # but the call should not segfault or so. + self.assertEqual(0, advapi32.CloseEventLog(None)) + + kernel32 = ctypes.windll.kernel32 + kernel32.GetProcAddress.argtypes = c_void_p, c_char_p + kernel32.GetProcAddress.restype = c_void_p + proc = kernel32.GetProcAddress(advapi32._handle, b"CloseEventLog") + self.assertTrue(proc) + + # This is the real test: call the function via 'call_function' + self.assertEqual(0, _ctypes.call_function(proc, (None,))) + + @unittest.skipUnless(os.name == "nt", + 'test specific to Windows') + def test_load_hasattr(self): + # bpo-34816: shouldn't raise OSError + self.assertFalse(hasattr(ctypes.windll, 'test')) + + @unittest.skipUnless(os.name == "nt", + 'test specific to Windows') + def test_load_dll_with_flags(self): + _sqlite3 = import_helper.import_module("_sqlite3") + src = _sqlite3.__file__ + if os.path.basename(src).partition(".")[0].lower().endswith("_d"): + ext = "_d.dll" + else: + ext = ".dll" + + with os_helper.temp_dir() as tmp: + # We copy two files and load _sqlite3.dll (formerly .pyd), + # which has a dependency on sqlite3.dll. Then we test + # loading it in subprocesses to avoid it starting in memory + # for each test. 
+ target = os.path.join(tmp, "_sqlite3.dll") + shutil.copy(src, target) + shutil.copy(os.path.join(os.path.dirname(src), "sqlite3" + ext), + os.path.join(tmp, "sqlite3" + ext)) + + def should_pass(command): + with self.subTest(command): + subprocess.check_output( + [sys.executable, "-c", + "from ctypes import *; import nt;" + command], + cwd=tmp + ) + + def should_fail(command): + with self.subTest(command): + with self.assertRaises(subprocess.CalledProcessError): + subprocess.check_output( + [sys.executable, "-c", + "from ctypes import *; import nt;" + command], + cwd=tmp, stderr=subprocess.STDOUT, + ) + + # Default load should not find this in CWD + should_fail("WinDLL('_sqlite3.dll')") + + # Relative path (but not just filename) should succeed + should_pass("WinDLL('./_sqlite3.dll')") + + # Insecure load flags should succeed + # Clear the DLL directory to avoid safe search settings propagating + should_pass("windll.kernel32.SetDllDirectoryW(None); WinDLL('_sqlite3.dll', winmode=0)") + + # Full path load without DLL_LOAD_DIR shouldn't find dependency + should_fail("WinDLL(nt._getfullpathname('_sqlite3.dll'), " + + "winmode=nt._LOAD_LIBRARY_SEARCH_SYSTEM32)") + + # Full path load with DLL_LOAD_DIR should succeed + should_pass("WinDLL(nt._getfullpathname('_sqlite3.dll'), " + + "winmode=nt._LOAD_LIBRARY_SEARCH_SYSTEM32|" + + "nt._LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)") + + # User-specified directory should succeed + should_pass("import os; p = os.add_dll_directory(os.getcwd());" + + "WinDLL('_sqlite3.dll'); p.close()") + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_macholib.py b/Lib/test/test_ctypes/test_macholib.py new file mode 100644 index 00000000000..9d906179956 --- /dev/null +++ b/Lib/test/test_ctypes/test_macholib.py @@ -0,0 +1,112 @@ +# Bob Ippolito: +# +# Ok.. 
the code to find the filename for __getattr__ should look +# something like: +# +# import os +# from macholib.dyld import dyld_find +# +# def find_lib(name): +# possible = ['lib'+name+'.dylib', name+'.dylib', +# name+'.framework/'+name] +# for dylib in possible: +# try: +# return os.path.realpath(dyld_find(dylib)) +# except ValueError: +# pass +# raise ValueError, "%s not found" % (name,) +# +# It'll have output like this: +# +# >>> find_lib('pthread') +# '/usr/lib/libSystem.B.dylib' +# >>> find_lib('z') +# '/usr/lib/libz.1.dylib' +# >>> find_lib('IOKit') +# '/System/Library/Frameworks/IOKit.framework/Versions/A/IOKit' +# +# -bob + +import os +import sys +import unittest + +from ctypes.macholib.dyld import dyld_find +from ctypes.macholib.dylib import dylib_info +from ctypes.macholib.framework import framework_info + + +def find_lib(name): + possible = ['lib'+name+'.dylib', name+'.dylib', name+'.framework/'+name] + for dylib in possible: + try: + return os.path.realpath(dyld_find(dylib)) + except ValueError: + pass + raise ValueError("%s not found" % (name,)) + + +def d(location=None, name=None, shortname=None, version=None, suffix=None): + return {'location': location, 'name': name, 'shortname': shortname, + 'version': version, 'suffix': suffix} + + +class MachOTest(unittest.TestCase): + @unittest.skipUnless(sys.platform == "darwin", 'OSX-specific test') + def test_find(self): + self.assertEqual(dyld_find('libSystem.dylib'), + '/usr/lib/libSystem.dylib') + self.assertEqual(dyld_find('System.framework/System'), + '/System/Library/Frameworks/System.framework/System') + + # On Mac OS 11, system dylibs are only present in the shared cache, + # so symlinks like libpthread.dylib -> libSystem.B.dylib will not + # be resolved by dyld_find + self.assertIn(find_lib('pthread'), + ('/usr/lib/libSystem.B.dylib', '/usr/lib/libpthread.dylib')) + + result = find_lib('z') + # Issue #21093: dyld default search path includes $HOME/lib and + # /usr/local/lib before /usr/lib, which 
caused test failures if + # a local copy of libz exists in one of them. Now ignore the head + # of the path. + self.assertRegex(result, r".*/lib/libz.*\.dylib") + + self.assertIn(find_lib('IOKit'), + ('/System/Library/Frameworks/IOKit.framework/Versions/A/IOKit', + '/System/Library/Frameworks/IOKit.framework/IOKit')) + + @unittest.skipUnless(sys.platform == "darwin", 'OSX-specific test') + def test_info(self): + self.assertIsNone(dylib_info('completely/invalid')) + self.assertIsNone(dylib_info('completely/invalide_debug')) + self.assertEqual(dylib_info('P/Foo.dylib'), d('P', 'Foo.dylib', 'Foo')) + self.assertEqual(dylib_info('P/Foo_debug.dylib'), + d('P', 'Foo_debug.dylib', 'Foo', suffix='debug')) + self.assertEqual(dylib_info('P/Foo.A.dylib'), + d('P', 'Foo.A.dylib', 'Foo', 'A')) + self.assertEqual(dylib_info('P/Foo_debug.A.dylib'), + d('P', 'Foo_debug.A.dylib', 'Foo_debug', 'A')) + self.assertEqual(dylib_info('P/Foo.A_debug.dylib'), + d('P', 'Foo.A_debug.dylib', 'Foo', 'A', 'debug')) + + @unittest.skipUnless(sys.platform == "darwin", 'OSX-specific test') + def test_framework_info(self): + self.assertIsNone(framework_info('completely/invalid')) + self.assertIsNone(framework_info('completely/invalid/_debug')) + self.assertIsNone(framework_info('P/F.framework')) + self.assertIsNone(framework_info('P/F.framework/_debug')) + self.assertEqual(framework_info('P/F.framework/F'), + d('P', 'F.framework/F', 'F')) + self.assertEqual(framework_info('P/F.framework/F_debug'), + d('P', 'F.framework/F_debug', 'F', suffix='debug')) + self.assertIsNone(framework_info('P/F.framework/Versions')) + self.assertIsNone(framework_info('P/F.framework/Versions/A')) + self.assertEqual(framework_info('P/F.framework/Versions/A/F'), + d('P', 'F.framework/Versions/A/F', 'F', 'A')) + self.assertEqual(framework_info('P/F.framework/Versions/A/F_debug'), + d('P', 'F.framework/Versions/A/F_debug', 'F', 'A', 'debug')) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/Lib/test/test_ctypes/test_memfunctions.py b/Lib/test/test_ctypes/test_memfunctions.py new file mode 100644 index 00000000000..112b27ba48e --- /dev/null +++ b/Lib/test/test_ctypes/test_memfunctions.py @@ -0,0 +1,82 @@ +import sys +import unittest +from test import support +from ctypes import (POINTER, sizeof, cast, + create_string_buffer, string_at, + create_unicode_buffer, wstring_at, + memmove, memset, + c_char_p, c_byte, c_ubyte, c_wchar) + + +class MemFunctionsTest(unittest.TestCase): + def test_overflow(self): + # string_at and wstring_at must use the Python calling + # convention (which acquires the GIL and checks the Python + # error flag). Provoke an error and catch it; see also issue + # gh-47804. + self.assertRaises((OverflowError, MemoryError, SystemError), + lambda: wstring_at(u"foo", sys.maxsize - 1)) + self.assertRaises((OverflowError, MemoryError, SystemError), + lambda: string_at("foo", sys.maxsize - 1)) + + def test_memmove(self): + # large buffers apparently increase the chance that the memory + # is allocated in high address space. 
+ a = create_string_buffer(1000000) + p = b"Hello, World" + result = memmove(a, p, len(p)) + self.assertEqual(a.value, b"Hello, World") + + self.assertEqual(string_at(result), b"Hello, World") + self.assertEqual(string_at(result, 5), b"Hello") + self.assertEqual(string_at(result, 16), b"Hello, World\0\0\0\0") + self.assertEqual(string_at(result, 0), b"") + + def test_memset(self): + a = create_string_buffer(1000000) + result = memset(a, ord('x'), 16) + self.assertEqual(a.value, b"xxxxxxxxxxxxxxxx") + + self.assertEqual(string_at(result), b"xxxxxxxxxxxxxxxx") + self.assertEqual(string_at(a), b"xxxxxxxxxxxxxxxx") + self.assertEqual(string_at(a, 20), b"xxxxxxxxxxxxxxxx\0\0\0\0") + + def test_cast(self): + a = (c_ubyte * 32)(*map(ord, "abcdef")) + self.assertEqual(cast(a, c_char_p).value, b"abcdef") + self.assertEqual(cast(a, POINTER(c_byte))[:7], + [97, 98, 99, 100, 101, 102, 0]) + self.assertEqual(cast(a, POINTER(c_byte))[:7:], + [97, 98, 99, 100, 101, 102, 0]) + self.assertEqual(cast(a, POINTER(c_byte))[6:-1:-1], + [0, 102, 101, 100, 99, 98, 97]) + self.assertEqual(cast(a, POINTER(c_byte))[:7:2], + [97, 99, 101, 0]) + self.assertEqual(cast(a, POINTER(c_byte))[:7:7], + [97]) + + @support.refcount_test + def test_string_at(self): + s = string_at(b"foo bar") + # XXX The following may be wrong, depending on how Python + # manages string instances + self.assertEqual(2, sys.getrefcount(s)) + self.assertTrue(s, "foo bar") + + self.assertEqual(string_at(b"foo bar", 7), b"foo bar") + self.assertEqual(string_at(b"foo bar", 3), b"foo") + + def test_wstring_at(self): + p = create_unicode_buffer("Hello, World") + a = create_unicode_buffer(1000000) + result = memmove(a, p, len(p) * sizeof(c_wchar)) + self.assertEqual(a.value, "Hello, World") + + self.assertEqual(wstring_at(a), "Hello, World") + self.assertEqual(wstring_at(a, 5), "Hello") + self.assertEqual(wstring_at(a, 16), "Hello, World\0\0\0\0") + self.assertEqual(wstring_at(a, 0), "") + + +if __name__ == "__main__": + 
unittest.main() diff --git a/Lib/test/test_ctypes/test_numbers.py b/Lib/test/test_ctypes/test_numbers.py new file mode 100644 index 00000000000..29108a28ec1 --- /dev/null +++ b/Lib/test/test_ctypes/test_numbers.py @@ -0,0 +1,200 @@ +import array +import struct +import sys +import unittest +from operator import truth +from ctypes import (byref, sizeof, alignment, + c_char, c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, + c_long, c_ulong, c_longlong, c_ulonglong, + c_float, c_double, c_longdouble, c_bool) + + +def valid_ranges(*types): + # given a sequence of numeric types, collect their _type_ + # attribute, which is a single format character compatible with + # the struct module, use the struct module to calculate the + # minimum and maximum value allowed for this format. + # Returns a list of (min, max) values. + result = [] + for t in types: + fmt = t._type_ + size = struct.calcsize(fmt) + a = struct.unpack(fmt, (b"\x00"*32)[:size])[0] + b = struct.unpack(fmt, (b"\xFF"*32)[:size])[0] + c = struct.unpack(fmt, (b"\x7F"+b"\x00"*32)[:size])[0] + d = struct.unpack(fmt, (b"\x80"+b"\xFF"*32)[:size])[0] + result.append((min(a, b, c, d), max(a, b, c, d))) + return result + + +ArgType = type(byref(c_int(0))) + +unsigned_types = [c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong] +signed_types = [c_byte, c_short, c_int, c_long, c_longlong] +bool_types = [c_bool] +float_types = [c_double, c_float] + +unsigned_ranges = valid_ranges(*unsigned_types) +signed_ranges = valid_ranges(*signed_types) +bool_values = [True, False, 0, 1, -1, 5000, 'test', [], [1]] + + +class NumberTestCase(unittest.TestCase): + + def test_default_init(self): + # default values are set to zero + for t in signed_types + unsigned_types + float_types: + self.assertEqual(t().value, 0) + + def test_unsigned_values(self): + # the value given to the constructor is available + # as the 'value' attribute + for t, (l, h) in zip(unsigned_types, unsigned_ranges): + self.assertEqual(t(l).value, l) + 
self.assertEqual(t(h).value, h) + + def test_signed_values(self): + # see above + for t, (l, h) in zip(signed_types, signed_ranges): + self.assertEqual(t(l).value, l) + self.assertEqual(t(h).value, h) + + def test_bool_values(self): + for t, v in zip(bool_types, bool_values): + self.assertEqual(t(v).value, truth(v)) + + def test_typeerror(self): + # Only numbers are allowed in the constructor, + # otherwise TypeError is raised + for t in signed_types + unsigned_types + float_types: + self.assertRaises(TypeError, t, "") + self.assertRaises(TypeError, t, None) + + def test_from_param(self): + # the from_param class method attribute always + # returns PyCArgObject instances + for t in signed_types + unsigned_types + float_types: + self.assertEqual(ArgType, type(t.from_param(0))) + + def test_byref(self): + # calling byref returns also a PyCArgObject instance + for t in signed_types + unsigned_types + float_types + bool_types: + parm = byref(t()) + self.assertEqual(ArgType, type(parm)) + + + def test_floats(self): + # c_float and c_double can be created from + # Python int and float + class FloatLike: + def __float__(self): + return 2.0 + f = FloatLike() + for t in float_types: + self.assertEqual(t(2.0).value, 2.0) + self.assertEqual(t(2).value, 2.0) + self.assertEqual(t(2).value, 2.0) + self.assertEqual(t(f).value, 2.0) + + def test_integers(self): + class FloatLike: + def __float__(self): + return 2.0 + f = FloatLike() + class IntLike: + def __int__(self): + return 2 + d = IntLike() + class IndexLike: + def __index__(self): + return 2 + i = IndexLike() + # integers cannot be constructed from floats, + # but from integer-like objects + for t in signed_types + unsigned_types: + self.assertRaises(TypeError, t, 3.14) + self.assertRaises(TypeError, t, f) + self.assertRaises(TypeError, t, d) + self.assertEqual(t(i).value, 2) + + def test_sizes(self): + for t in signed_types + unsigned_types + float_types + bool_types: + try: + size = struct.calcsize(t._type_) + except 
struct.error: + continue + # sizeof of the type... + self.assertEqual(sizeof(t), size) + # and sizeof of an instance + self.assertEqual(sizeof(t()), size) + + def test_alignments(self): + for t in signed_types + unsigned_types + float_types: + code = t._type_ # the typecode + align = struct.calcsize("c%c" % code) - struct.calcsize(code) + + # alignment of the type... + self.assertEqual((code, alignment(t)), + (code, align)) + # and alignment of an instance + self.assertEqual((code, alignment(t())), + (code, align)) + + def test_int_from_address(self): + for t in signed_types + unsigned_types: + # the array module doesn't support all format codes + # (no 'q' or 'Q') + try: + array.array(t._type_) + except ValueError: + continue + a = array.array(t._type_, [100]) + + # v now is an integer at an 'external' memory location + v = t.from_address(a.buffer_info()[0]) + self.assertEqual(v.value, a[0]) + self.assertEqual(type(v), t) + + # changing the value at the memory location changes v's value also + a[0] = 42 + self.assertEqual(v.value, a[0]) + + + def test_float_from_address(self): + for t in float_types: + a = array.array(t._type_, [3.14]) + v = t.from_address(a.buffer_info()[0]) + self.assertEqual(v.value, a[0]) + self.assertIs(type(v), t) + a[0] = 2.3456e17 + self.assertEqual(v.value, a[0]) + self.assertIs(type(v), t) + + def test_char_from_address(self): + a = array.array('b', [0]) + a[0] = ord('x') + v = c_char.from_address(a.buffer_info()[0]) + self.assertEqual(v.value, b'x') + self.assertIs(type(v), c_char) + + a[0] = ord('?') + self.assertEqual(v.value, b'?') + + def test_init(self): + # c_int() can be initialized from Python's int, and c_int. 
+ # Not from c_long or so, which seems strange, abc should + # probably be changed: + self.assertRaises(TypeError, c_int, c_long(42)) + + def test_float_overflow(self): + big_int = int(sys.float_info.max) * 2 + for t in float_types + [c_longdouble]: + self.assertRaises(OverflowError, t, big_int) + if (hasattr(t, "__ctype_be__")): + self.assertRaises(OverflowError, t.__ctype_be__, big_int) + if (hasattr(t, "__ctype_le__")): + self.assertRaises(OverflowError, t.__ctype_le__, big_int) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_objects.py b/Lib/test/test_ctypes/test_objects.py new file mode 100644 index 00000000000..fb01421b955 --- /dev/null +++ b/Lib/test/test_ctypes/test_objects.py @@ -0,0 +1,66 @@ +r''' +This tests the '_objects' attribute of ctypes instances. '_objects' +holds references to objects that must be kept alive as long as the +ctypes instance, to make sure that the memory buffer is valid. + +WARNING: The '_objects' attribute is exposed ONLY for debugging ctypes itself, +it MUST NEVER BE MODIFIED! + +'_objects' is initialized to a dictionary on first use, before that it +is None. + +Here is an array of string pointers: + +>>> from ctypes import Structure, c_int, c_char_p +>>> array = (c_char_p * 5)() +>>> print(array._objects) +None +>>> + +The memory block stores pointers to strings, and the strings itself +assigned from Python must be kept. + +>>> array[4] = b'foo bar' +>>> array._objects +{'4': b'foo bar'} +>>> array[4] +b'foo bar' +>>> + +It gets more complicated when the ctypes instance itself is contained +in a 'base' object. + +>>> class X(Structure): +... _fields_ = [("x", c_int), ("y", c_int), ("array", c_char_p * 5)] +... 
+>>> x = X() +>>> print(x._objects) +None +>>> + +The'array' attribute of the 'x' object shares part of the memory buffer +of 'x' ('_b_base_' is either None, or the root object owning the memory block): + +>>> print(x.array._b_base_) # doctest: +ELLIPSIS + +>>> + +>>> x.array[0] = b'spam spam spam' +>>> x._objects +{'0:2': b'spam spam spam'} +>>> x.array._b_base_._objects +{'0:2': b'spam spam spam'} +>>> +''' + +import doctest +import unittest + + +def load_tests(loader, tests, pattern): + tests.addTest(doctest.DocTestSuite()) + return tests + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_parameters.py b/Lib/test/test_ctypes/test_parameters.py new file mode 100644 index 00000000000..1a6ddb91a7d --- /dev/null +++ b/Lib/test/test_ctypes/test_parameters.py @@ -0,0 +1,288 @@ +import sys +import unittest +import test.support +from ctypes import (CDLL, PyDLL, ArgumentError, + Structure, Array, Union, + _Pointer, _SimpleCData, _CFuncPtr, + POINTER, pointer, byref, + c_void_p, c_char_p, c_wchar_p, py_object, + c_bool, + c_char, c_wchar, + c_byte, c_ubyte, + c_short, c_ushort, + c_int, c_uint, + c_long, c_ulong, + c_longlong, c_ulonglong, + c_float, c_double, c_longdouble) +from test.support import import_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +class SimpleTypesTestCase(unittest.TestCase): + def setUp(self): + try: + from _ctypes import set_conversion_mode + except ImportError: + pass + else: + self.prev_conv_mode = set_conversion_mode("ascii", "strict") + + def tearDown(self): + try: + from _ctypes import set_conversion_mode + except ImportError: + pass + else: + set_conversion_mode(*self.prev_conv_mode) + + def test_subclasses(self): + # ctypes 0.9.5 and before did overwrite from_param in SimpleType_new + class CVOIDP(c_void_p): + def from_param(cls, value): + return value * 2 + from_param = classmethod(from_param) + + class CCHARP(c_char_p): + def from_param(cls, value): + return value * 4 + 
from_param = classmethod(from_param) + + self.assertEqual(CVOIDP.from_param("abc"), "abcabc") + self.assertEqual(CCHARP.from_param("abc"), "abcabcabcabc") + + def test_subclasses_c_wchar_p(self): + class CWCHARP(c_wchar_p): + def from_param(cls, value): + return value * 3 + from_param = classmethod(from_param) + + self.assertEqual(CWCHARP.from_param("abc"), "abcabcabc") + + # XXX Replace by c_char_p tests + def test_cstrings(self): + # c_char_p.from_param on a Python String packs the string + # into a cparam object + s = b"123" + self.assertIs(c_char_p.from_param(s)._obj, s) + + # new in 0.9.1: convert (encode) unicode to ascii + self.assertEqual(c_char_p.from_param(b"123")._obj, b"123") + self.assertRaises(TypeError, c_char_p.from_param, "123\377") + self.assertRaises(TypeError, c_char_p.from_param, 42) + + # calling c_char_p.from_param with a c_char_p instance + # returns the argument itself: + a = c_char_p(b"123") + self.assertIs(c_char_p.from_param(a), a) + + def test_cw_strings(self): + c_wchar_p.from_param("123") + + self.assertRaises(TypeError, c_wchar_p.from_param, 42) + self.assertRaises(TypeError, c_wchar_p.from_param, b"123\377") + + pa = c_wchar_p.from_param(c_wchar_p("123")) + self.assertEqual(type(pa), c_wchar_p) + + def test_c_char(self): + with self.assertRaises(TypeError) as cm: + c_char.from_param(b"abc") + self.assertEqual(str(cm.exception), + "one character bytes, bytearray or integer expected") + + def test_c_wchar(self): + with self.assertRaises(TypeError) as cm: + c_wchar.from_param("abc") + self.assertEqual(str(cm.exception), + "one character unicode string expected") + + + with self.assertRaises(TypeError) as cm: + c_wchar.from_param(123) + self.assertEqual(str(cm.exception), + "unicode string expected instead of int instance") + + def test_int_pointers(self): + LPINT = POINTER(c_int) + + x = LPINT.from_param(pointer(c_int(42))) + self.assertEqual(x.contents.value, 42) + self.assertEqual(LPINT(c_int(42)).contents.value, 42) + + 
self.assertEqual(LPINT.from_param(None), None) + + if c_int != c_long: + self.assertRaises(TypeError, LPINT.from_param, pointer(c_long(42))) + self.assertRaises(TypeError, LPINT.from_param, pointer(c_uint(42))) + self.assertRaises(TypeError, LPINT.from_param, pointer(c_short(42))) + + def test_byref_pointer(self): + # The from_param class method of POINTER(typ) classes accepts what is + # returned by byref(obj), it type(obj) == typ + LPINT = POINTER(c_int) + + LPINT.from_param(byref(c_int(42))) + + self.assertRaises(TypeError, LPINT.from_param, byref(c_short(22))) + if c_int != c_long: + self.assertRaises(TypeError, LPINT.from_param, byref(c_long(22))) + self.assertRaises(TypeError, LPINT.from_param, byref(c_uint(22))) + + def test_byref_pointerpointer(self): + # See above + + LPLPINT = POINTER(POINTER(c_int)) + LPLPINT.from_param(byref(pointer(c_int(42)))) + + self.assertRaises(TypeError, LPLPINT.from_param, byref(pointer(c_short(22)))) + if c_int != c_long: + self.assertRaises(TypeError, LPLPINT.from_param, byref(pointer(c_long(22)))) + self.assertRaises(TypeError, LPLPINT.from_param, byref(pointer(c_uint(22)))) + + def test_array_pointers(self): + INTARRAY = c_int * 3 + ia = INTARRAY() + self.assertEqual(len(ia), 3) + self.assertEqual([ia[i] for i in range(3)], [0, 0, 0]) + + # Pointers are only compatible with arrays containing items of + # the same type! 
+ LPINT = POINTER(c_int) + LPINT.from_param((c_int*3)()) + self.assertRaises(TypeError, LPINT.from_param, c_short*3) + self.assertRaises(TypeError, LPINT.from_param, c_long*3) + self.assertRaises(TypeError, LPINT.from_param, c_uint*3) + + def test_noctypes_argtype(self): + func = CDLL(_ctypes_test.__file__)._testfunc_p_p + func.restype = c_void_p + # TypeError: has no from_param method + self.assertRaises(TypeError, setattr, func, "argtypes", (object,)) + + class Adapter: + def from_param(cls, obj): + return None + + func.argtypes = (Adapter(),) + self.assertEqual(func(None), None) + self.assertEqual(func(object()), None) + + class Adapter: + def from_param(cls, obj): + return obj + + func.argtypes = (Adapter(),) + # don't know how to convert parameter 1 + self.assertRaises(ArgumentError, func, object()) + self.assertEqual(func(c_void_p(42)), 42) + + class Adapter: + def from_param(cls, obj): + raise ValueError(obj) + + func.argtypes = (Adapter(),) + # ArgumentError: argument 1: ValueError: 99 + self.assertRaises(ArgumentError, func, 99) + + def test_abstract(self): + self.assertRaises(TypeError, Array.from_param, 42) + self.assertRaises(TypeError, Structure.from_param, 42) + self.assertRaises(TypeError, Union.from_param, 42) + self.assertRaises(TypeError, _CFuncPtr.from_param, 42) + self.assertRaises(TypeError, _Pointer.from_param, 42) + self.assertRaises(TypeError, _SimpleCData.from_param, 42) + + @test.support.cpython_only + def test_issue31311(self): + # __setstate__ should neither raise a SystemError nor crash in case + # of a bad __dict__. 
+ + class BadStruct(Structure): + @property + def __dict__(self): + pass + with self.assertRaises(TypeError): + BadStruct().__setstate__({}, b'foo') + + class WorseStruct(Structure): + @property + def __dict__(self): + 1/0 + with self.assertRaises(ZeroDivisionError): + WorseStruct().__setstate__({}, b'foo') + + def test_parameter_repr(self): + self.assertRegex(repr(c_bool.from_param(True)), r"^$") + self.assertEqual(repr(c_char.from_param(97)), "") + self.assertRegex(repr(c_wchar.from_param('a')), r"^$") + self.assertEqual(repr(c_byte.from_param(98)), "") + self.assertEqual(repr(c_ubyte.from_param(98)), "") + self.assertEqual(repr(c_short.from_param(511)), "") + self.assertEqual(repr(c_ushort.from_param(511)), "") + self.assertRegex(repr(c_int.from_param(20000)), r"^$") + self.assertRegex(repr(c_uint.from_param(20000)), r"^$") + self.assertRegex(repr(c_long.from_param(20000)), r"^$") + self.assertRegex(repr(c_ulong.from_param(20000)), r"^$") + self.assertRegex(repr(c_longlong.from_param(20000)), r"^$") + self.assertRegex(repr(c_ulonglong.from_param(20000)), r"^$") + self.assertEqual(repr(c_float.from_param(1.5)), "") + self.assertEqual(repr(c_double.from_param(1.5)), "") + if sys.float_repr_style == 'short': + self.assertEqual(repr(c_double.from_param(1e300)), "") + self.assertRegex(repr(c_longdouble.from_param(1.5)), r"^$") + self.assertRegex(repr(c_char_p.from_param(b'hihi')), r"^$") + self.assertRegex(repr(c_wchar_p.from_param('hihi')), r"^$") + self.assertRegex(repr(c_void_p.from_param(0x12)), r"^$") + + @test.support.cpython_only + def test_from_param_result_refcount(self): + # Issue #99952 + class X(Structure): + """This struct size is <= sizeof(void*).""" + _fields_ = [("a", c_void_p)] + + def __del__(self): + trace.append(4) + + @classmethod + def from_param(cls, value): + trace.append(2) + return cls() + + PyList_Append = PyDLL(_ctypes_test.__file__)._testfunc_pylist_append + PyList_Append.restype = c_int + PyList_Append.argtypes = [py_object, py_object, 
X] + + trace = [] + trace.append(1) + PyList_Append(trace, 3, "dummy") + trace.append(5) + + self.assertEqual(trace, [1, 2, 3, 4, 5]) + + class Y(Structure): + """This struct size is > sizeof(void*).""" + _fields_ = [("a", c_void_p), ("b", c_void_p)] + + def __del__(self): + trace.append(4) + + @classmethod + def from_param(cls, value): + trace.append(2) + return cls() + + PyList_Append = PyDLL(_ctypes_test.__file__)._testfunc_pylist_append + PyList_Append.restype = c_int + PyList_Append.argtypes = [py_object, py_object, Y] + + trace = [] + trace.append(1) + PyList_Append(trace, 3, "dummy") + trace.append(5) + + self.assertEqual(trace, [1, 2, 3, 4, 5]) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_pep3118.py b/Lib/test/test_ctypes/test_pep3118.py new file mode 100644 index 00000000000..06b2ccecade --- /dev/null +++ b/Lib/test/test_ctypes/test_pep3118.py @@ -0,0 +1,250 @@ +import re +import sys +import unittest +from ctypes import (CFUNCTYPE, POINTER, sizeof, Union, + Structure, LittleEndianStructure, BigEndianStructure, + c_char, c_byte, c_ubyte, + c_short, c_ushort, c_int, c_uint, + c_long, c_ulong, c_longlong, c_ulonglong, c_uint64, + c_bool, c_float, c_double, c_longdouble, py_object) + + +if sys.byteorder == "little": + THIS_ENDIAN = "<" + OTHER_ENDIAN = ">" +else: + THIS_ENDIAN = ">" + OTHER_ENDIAN = "<" + + +def normalize(format): + # Remove current endian specifier and white space from a format + # string + if format is None: + return "" + format = format.replace(OTHER_ENDIAN, THIS_ENDIAN) + return re.sub(r"\s", "", format) + + +class Test(unittest.TestCase): + def test_native_types(self): + for tp, fmt, shape, itemtp in native_types: + ob = tp() + v = memoryview(ob) + self.assertEqual(normalize(v.format), normalize(fmt)) + if shape: + self.assertEqual(len(v), shape[0]) + else: + self.assertRaises(TypeError, len, v) + self.assertEqual(v.itemsize, sizeof(itemtp)) + self.assertEqual(v.shape, shape) + # XXX Issue 
#12851: PyCData_NewGetBuffer() must provide strides + # if requested. memoryview currently reconstructs missing + # stride information, so this assert will fail. + # self.assertEqual(v.strides, ()) + + # they are always read/write + self.assertFalse(v.readonly) + + n = 1 + for dim in v.shape: + n = n * dim + self.assertEqual(n * v.itemsize, len(v.tobytes())) + + def test_endian_types(self): + for tp, fmt, shape, itemtp in endian_types: + ob = tp() + v = memoryview(ob) + self.assertEqual(v.format, fmt) + if shape: + self.assertEqual(len(v), shape[0]) + else: + self.assertRaises(TypeError, len, v) + self.assertEqual(v.itemsize, sizeof(itemtp)) + self.assertEqual(v.shape, shape) + # XXX Issue #12851 + # self.assertEqual(v.strides, ()) + + # they are always read/write + self.assertFalse(v.readonly) + + n = 1 + for dim in v.shape: + n = n * dim + self.assertEqual(n * v.itemsize, len(v.tobytes())) + + +# define some structure classes + +class Point(Structure): + _fields_ = [("x", c_long), ("y", c_long)] + +class PackedPoint(Structure): + _pack_ = 2 + _fields_ = [("x", c_long), ("y", c_long)] + +class PointMidPad(Structure): + _fields_ = [("x", c_byte), ("y", c_uint)] + +class PackedPointMidPad(Structure): + _pack_ = 2 + _fields_ = [("x", c_byte), ("y", c_uint64)] + +class PointEndPad(Structure): + _fields_ = [("x", c_uint), ("y", c_byte)] + +class PackedPointEndPad(Structure): + _pack_ = 2 + _fields_ = [("x", c_uint64), ("y", c_byte)] + +class Point2(Structure): + pass +Point2._fields_ = [("x", c_long), ("y", c_long)] + +class EmptyStruct(Structure): + _fields_ = [] + +class aUnion(Union): + _fields_ = [("a", c_int)] + +class StructWithArrays(Structure): + _fields_ = [("x", c_long * 3 * 2), ("y", Point * 4)] + +class Incomplete(Structure): + pass + +class Complete(Structure): + pass +PComplete = POINTER(Complete) +Complete._fields_ = [("a", c_long)] + + +################################################################ +# +# This table contains format strings as they look 
on little endian +# machines. The test replaces '<' with '>' on big endian machines. +# + +# Platform-specific type codes +s_bool = {1: '?', 2: 'H', 4: 'L', 8: 'Q'}[sizeof(c_bool)] +s_short = {2: 'h', 4: 'l', 8: 'q'}[sizeof(c_short)] +s_ushort = {2: 'H', 4: 'L', 8: 'Q'}[sizeof(c_ushort)] +s_int = {2: 'h', 4: 'i', 8: 'q'}[sizeof(c_int)] +s_uint = {2: 'H', 4: 'I', 8: 'Q'}[sizeof(c_uint)] +s_long = {4: 'l', 8: 'q'}[sizeof(c_long)] +s_ulong = {4: 'L', 8: 'Q'}[sizeof(c_ulong)] +s_longlong = "q" +s_ulonglong = "Q" +s_float = "f" +s_double = "d" +s_longdouble = "g" + +# Alias definitions in ctypes/__init__.py +if c_int is c_long: + s_int = s_long +if c_uint is c_ulong: + s_uint = s_ulong +if c_longlong is c_long: + s_longlong = s_long +if c_ulonglong is c_ulong: + s_ulonglong = s_ulong +if c_longdouble is c_double: + s_longdouble = s_double + + +native_types = [ + # type format shape calc itemsize + + ## simple types + + (c_char, "l:x:>l:y:}".replace('l', s_long), (), BEPoint), + (LEPoint * 1, "T{l:x:>l:y:}".replace('l', s_long), (), POINTER(BEPoint)), + (POINTER(LEPoint), "&T{= 0: + return a + # View the bits in `a` as unsigned instead. 
+ import struct + num_bits = struct.calcsize("P") * 8 # num bits in native machine address + a += 1 << num_bits + assert a >= 0 + return a + + +def c_wbuffer(init): + n = len(init) + 1 + return (c_wchar * n)(*init) + + +class CharPointersTestCase(unittest.TestCase): + def setUp(self): + func = testdll._testfunc_p_p + func.restype = c_long + func.argtypes = None + + def test_paramflags(self): + # function returns c_void_p result, + # and has a required parameter named 'input' + prototype = CFUNCTYPE(c_void_p, c_void_p) + func = prototype(("_testfunc_p_p", testdll), + ((1, "input"),)) + + try: + func() + except TypeError as details: + self.assertEqual(str(details), "required argument 'input' missing") + else: + self.fail("TypeError not raised") + + self.assertEqual(func(None), None) + self.assertEqual(func(input=None), None) + + def test_invalid_paramflags(self): + proto = CFUNCTYPE(c_int, c_char_p) + with self.assertRaises(ValueError): + func = proto(("myprintf", testdll), ((1, "fmt"), (1, "arg1"))) + + def test_invalid_setattr_argtypes(self): + proto = CFUNCTYPE(c_int, c_char_p) + func = proto(("myprintf", testdll), ((1, "fmt"),)) + + with self.assertRaisesRegex(TypeError, "_argtypes_ must be a sequence of types"): + func.argtypes = 123 + self.assertEqual(func.argtypes, (c_char_p,)) + + with self.assertRaisesRegex(ValueError, "paramflags must have the same length as argtypes"): + func.argtypes = (c_char_p, c_int) + self.assertEqual(func.argtypes, (c_char_p,)) + + def test_paramflags_outarg(self): + proto = CFUNCTYPE(c_int, c_char_p, c_int) + with self.assertRaisesRegex(TypeError, "must be a pointer type"): + func = proto(("myprintf", testdll), ((1, "fmt"), (2, "out"))) + + proto = CFUNCTYPE(c_int, c_char_p, c_void_p) + func = proto(("myprintf", testdll), ((1, "fmt"), (2, "out"))) + with self.assertRaisesRegex(TypeError, "must be a pointer type"): + func.argtypes = (c_char_p, c_int) + + def test_int_pointer_arg(self): + func = testdll._testfunc_p_p + if 
sizeof(c_longlong) == sizeof(c_void_p): + func.restype = c_longlong + else: + func.restype = c_long + self.assertEqual(0, func(0)) + + ci = c_int(0) + + func.argtypes = POINTER(c_int), + self.assertEqual(positive_address(addressof(ci)), + positive_address(func(byref(ci)))) + + func.argtypes = c_char_p, + self.assertRaises(ArgumentError, func, byref(ci)) + + func.argtypes = POINTER(c_short), + self.assertRaises(ArgumentError, func, byref(ci)) + + func.argtypes = POINTER(c_double), + self.assertRaises(ArgumentError, func, byref(ci)) + + def test_POINTER_c_char_arg(self): + func = testdll._testfunc_p_p + func.restype = c_char_p + func.argtypes = POINTER(c_char), + + self.assertEqual(None, func(None)) + self.assertEqual(b"123", func(b"123")) + self.assertEqual(None, func(c_char_p(None))) + self.assertEqual(b"123", func(c_char_p(b"123"))) + + self.assertEqual(b"123", func(create_string_buffer(b"123"))) + ca = c_char(b"a") + self.assertEqual(ord(b"a"), func(pointer(ca))[0]) + self.assertEqual(ord(b"a"), func(byref(ca))[0]) + + def test_c_char_p_arg(self): + func = testdll._testfunc_p_p + func.restype = c_char_p + func.argtypes = c_char_p, + + self.assertEqual(None, func(None)) + self.assertEqual(b"123", func(b"123")) + self.assertEqual(None, func(c_char_p(None))) + self.assertEqual(b"123", func(c_char_p(b"123"))) + + self.assertEqual(b"123", func(create_string_buffer(b"123"))) + ca = c_char(b"a") + self.assertEqual(ord(b"a"), func(pointer(ca))[0]) + self.assertEqual(ord(b"a"), func(byref(ca))[0]) + + def test_c_void_p_arg(self): + func = testdll._testfunc_p_p + func.restype = c_char_p + func.argtypes = c_void_p, + + self.assertEqual(None, func(None)) + self.assertEqual(b"123", func(b"123")) + self.assertEqual(b"123", func(c_char_p(b"123"))) + self.assertEqual(None, func(c_char_p(None))) + + self.assertEqual(b"123", func(create_string_buffer(b"123"))) + ca = c_char(b"a") + self.assertEqual(ord(b"a"), func(pointer(ca))[0]) + self.assertEqual(ord(b"a"), func(byref(ca))[0]) 
+ + func(byref(c_int())) + func(pointer(c_int())) + func((c_int * 3)()) + + def test_c_void_p_arg_with_c_wchar_p(self): + func = testdll._testfunc_p_p + func.restype = c_wchar_p + func.argtypes = c_void_p, + + self.assertEqual(None, func(c_wchar_p(None))) + self.assertEqual("123", func(c_wchar_p("123"))) + + def test_instance(self): + func = testdll._testfunc_p_p + func.restype = c_void_p + + class X: + _as_parameter_ = None + + func.argtypes = c_void_p, + self.assertEqual(None, func(X())) + + func.argtypes = None + self.assertEqual(None, func(X())) + + +class WCharPointersTestCase(unittest.TestCase): + def setUp(self): + func = testdll._testfunc_p_p + func.restype = c_int + func.argtypes = None + + + def test_POINTER_c_wchar_arg(self): + func = testdll._testfunc_p_p + func.restype = c_wchar_p + func.argtypes = POINTER(c_wchar), + + self.assertEqual(None, func(None)) + self.assertEqual("123", func("123")) + self.assertEqual(None, func(c_wchar_p(None))) + self.assertEqual("123", func(c_wchar_p("123"))) + + self.assertEqual("123", func(c_wbuffer("123"))) + ca = c_wchar("a") + self.assertEqual("a", func(pointer(ca))[0]) + self.assertEqual("a", func(byref(ca))[0]) + + def test_c_wchar_p_arg(self): + func = testdll._testfunc_p_p + func.restype = c_wchar_p + func.argtypes = c_wchar_p, + + c_wchar_p.from_param("123") + + self.assertEqual(None, func(None)) + self.assertEqual("123", func("123")) + self.assertEqual(None, func(c_wchar_p(None))) + self.assertEqual("123", func(c_wchar_p("123"))) + + # XXX Currently, these raise TypeErrors, although they shouldn't: + self.assertEqual("123", func(c_wbuffer("123"))) + ca = c_wchar("a") + self.assertEqual("a", func(pointer(ca))[0]) + self.assertEqual("a", func(byref(ca))[0]) + + +class ArrayTest(unittest.TestCase): + def test(self): + func = testdll._testfunc_ai8 + func.restype = POINTER(c_int) + func.argtypes = c_int * 8, + + func((c_int * 8)(1, 2, 3, 4, 5, 6, 7, 8)) + + # This did crash before: + + def func(): pass + 
CFUNCTYPE(None, c_int * 3)(func) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_python_api.py b/Lib/test/test_ctypes/test_python_api.py new file mode 100644 index 00000000000..1072a109833 --- /dev/null +++ b/Lib/test/test_ctypes/test_python_api.py @@ -0,0 +1,81 @@ +import _ctypes +import sys +import unittest +from test import support +from ctypes import (pythonapi, POINTER, create_string_buffer, sizeof, + py_object, c_char_p, c_char, c_long, c_size_t) + + +class PythonAPITestCase(unittest.TestCase): + def test_PyBytes_FromStringAndSize(self): + PyBytes_FromStringAndSize = pythonapi.PyBytes_FromStringAndSize + + PyBytes_FromStringAndSize.restype = py_object + PyBytes_FromStringAndSize.argtypes = c_char_p, c_size_t + + self.assertEqual(PyBytes_FromStringAndSize(b"abcdefghi", 3), b"abc") + + @support.refcount_test + def test_PyString_FromString(self): + pythonapi.PyBytes_FromString.restype = py_object + pythonapi.PyBytes_FromString.argtypes = (c_char_p,) + + s = b"abc" + refcnt = sys.getrefcount(s) + pyob = pythonapi.PyBytes_FromString(s) + self.assertEqual(sys.getrefcount(s), refcnt) + self.assertEqual(s, pyob) + del pyob + self.assertEqual(sys.getrefcount(s), refcnt) + + @support.refcount_test + def test_PyLong_Long(self): + ref42 = sys.getrefcount(42) + pythonapi.PyLong_FromLong.restype = py_object + self.assertEqual(pythonapi.PyLong_FromLong(42), 42) + + self.assertEqual(sys.getrefcount(42), ref42) + + pythonapi.PyLong_AsLong.argtypes = (py_object,) + pythonapi.PyLong_AsLong.restype = c_long + + res = pythonapi.PyLong_AsLong(42) + # Small int refcnts don't change + self.assertEqual(sys.getrefcount(res), ref42) + del res + self.assertEqual(sys.getrefcount(42), ref42) + + @support.refcount_test + def test_PyObj_FromPtr(self): + s = object() + ref = sys.getrefcount(s) + # id(python-object) is the address + pyobj = _ctypes.PyObj_FromPtr(id(s)) + self.assertIs(s, pyobj) + + self.assertEqual(sys.getrefcount(s), ref + 1) + del 
pyobj + self.assertEqual(sys.getrefcount(s), ref) + + def test_PyOS_snprintf(self): + PyOS_snprintf = pythonapi.PyOS_snprintf + PyOS_snprintf.argtypes = POINTER(c_char), c_size_t, c_char_p + + buf = create_string_buffer(256) + PyOS_snprintf(buf, sizeof(buf), b"Hello from %s", b"ctypes") + self.assertEqual(buf.value, b"Hello from ctypes") + + PyOS_snprintf(buf, sizeof(buf), b"Hello from %s (%d, %d, %d)", b"ctypes", 1, 2, 3) + self.assertEqual(buf.value, b"Hello from ctypes (1, 2, 3)") + + # not enough arguments + self.assertRaises(TypeError, PyOS_snprintf, buf) + + def test_pyobject_repr(self): + self.assertEqual(repr(py_object()), "py_object()") + self.assertEqual(repr(py_object(42)), "py_object(42)") + self.assertEqual(repr(py_object(object)), "py_object(%r)" % object) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_ctypes/test_random_things.py b/Lib/test/test_ctypes/test_random_things.py new file mode 100644 index 00000000000..630f6ed9489 --- /dev/null +++ b/Lib/test/test_ctypes/test_random_things.py @@ -0,0 +1,81 @@ +import _ctypes +import contextlib +import ctypes +import sys +import unittest +from test import support +from ctypes import CFUNCTYPE, c_void_p, c_char_p, c_int, c_double + + +def callback_func(arg): + 42 / arg + raise ValueError(arg) + + +@unittest.skipUnless(sys.platform == "win32", 'Windows-specific test') +class call_function_TestCase(unittest.TestCase): + # _ctypes.call_function is deprecated and private, but used by + # Gary Bishp's readline module. If we have it, we must test it as well. 
+ + def test(self): + kernel32 = ctypes.windll.kernel32 + kernel32.LoadLibraryA.restype = c_void_p + kernel32.GetProcAddress.argtypes = c_void_p, c_char_p + kernel32.GetProcAddress.restype = c_void_p + + hdll = kernel32.LoadLibraryA(b"kernel32") + funcaddr = kernel32.GetProcAddress(hdll, b"GetModuleHandleA") + + self.assertEqual(_ctypes.call_function(funcaddr, (None,)), + kernel32.GetModuleHandleA(None)) + + +class CallbackTracbackTestCase(unittest.TestCase): + # When an exception is raised in a ctypes callback function, the C + # code prints a traceback. + # + # This test makes sure the exception types *and* the exception + # value is printed correctly. + # + # Changed in 0.9.3: No longer is '(in callback)' prepended to the + # error message - instead an additional frame for the C code is + # created, then a full traceback printed. When SystemExit is + # raised in a callback function, the interpreter exits. + + @contextlib.contextmanager + def expect_unraisable(self, exc_type, exc_msg=None): + with support.catch_unraisable_exception() as cm: + yield + + self.assertIsInstance(cm.unraisable.exc_value, exc_type) + if exc_msg is not None: + self.assertEqual(str(cm.unraisable.exc_value), exc_msg) + self.assertEqual(cm.unraisable.err_msg, + f"Exception ignored on calling ctypes " + f"callback function {callback_func!r}") + self.assertIsNone(cm.unraisable.object) + + def test_ValueError(self): + cb = CFUNCTYPE(c_int, c_int)(callback_func) + with self.expect_unraisable(ValueError, '42'): + cb(42) + + def test_IntegerDivisionError(self): + cb = CFUNCTYPE(c_int, c_int)(callback_func) + with self.expect_unraisable(ZeroDivisionError): + cb(0) + + def test_FloatDivisionError(self): + cb = CFUNCTYPE(c_int, c_double)(callback_func) + with self.expect_unraisable(ZeroDivisionError): + cb(0.0) + + def test_TypeErrorDivisionError(self): + cb = CFUNCTYPE(c_int, c_char_p)(callback_func) + err_msg = "unsupported operand type(s) for /: 'int' and 'bytes'" + with 
self.expect_unraisable(TypeError, err_msg): + cb(b"spam") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_refcounts.py b/Lib/test/test_ctypes/test_refcounts.py new file mode 100644 index 00000000000..9e87cfc661e --- /dev/null +++ b/Lib/test/test_ctypes/test_refcounts.py @@ -0,0 +1,143 @@ +import ctypes +import gc +import sys +import unittest +from test import support +from test.support import import_helper +from test.support import script_helper +_ctypes_test = import_helper.import_module("_ctypes_test") + + +MyCallback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int) +OtherCallback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_ulonglong) + +dll = ctypes.CDLL(_ctypes_test.__file__) + + +class RefcountTestCase(unittest.TestCase): + @support.refcount_test + def test_1(self): + f = dll._testfunc_callback_i_if + f.restype = ctypes.c_int + f.argtypes = [ctypes.c_int, MyCallback] + + def callback(value): + return value + + self.assertEqual(sys.getrefcount(callback), 2) + cb = MyCallback(callback) + + self.assertGreater(sys.getrefcount(callback), 2) + result = f(-10, cb) + self.assertEqual(result, -18) + cb = None + + gc.collect() + + self.assertEqual(sys.getrefcount(callback), 2) + + @support.refcount_test + def test_refcount(self): + def func(*args): + pass + # this is the standard refcount for func + self.assertEqual(sys.getrefcount(func), 2) + + # the CFuncPtr instance holds at least one refcount on func: + f = OtherCallback(func) + self.assertGreater(sys.getrefcount(func), 2) + + # and may release it again + del f + self.assertGreaterEqual(sys.getrefcount(func), 2) + + # but now it must be gone + gc.collect() + self.assertEqual(sys.getrefcount(func), 2) + + class X(ctypes.Structure): + _fields_ = [("a", OtherCallback)] + x = X() + x.a = OtherCallback(func) + + # the CFuncPtr instance holds at least one refcount on func: + self.assertGreater(sys.getrefcount(func), 2) + + # and may release it again + del x + 
self.assertGreaterEqual(sys.getrefcount(func), 2) + + # and now it must be gone again + gc.collect() + self.assertEqual(sys.getrefcount(func), 2) + + f = OtherCallback(func) + + # the CFuncPtr instance holds at least one refcount on func: + self.assertGreater(sys.getrefcount(func), 2) + + # create a cycle + f.cycle = f + + del f + gc.collect() + self.assertEqual(sys.getrefcount(func), 2) + + +class AnotherLeak(unittest.TestCase): + def test_callback(self): + proto = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_int) + def func(a, b): + return a * b * 2 + f = proto(func) + + a = sys.getrefcount(ctypes.c_int) + f(1, 2) + self.assertEqual(sys.getrefcount(ctypes.c_int), a) + + @support.refcount_test + def test_callback_py_object_none_return(self): + # bpo-36880: test that returning None from a py_object callback + # does not decrement the refcount of None. + + for FUNCTYPE in (ctypes.CFUNCTYPE, ctypes.PYFUNCTYPE): + with self.subTest(FUNCTYPE=FUNCTYPE): + @FUNCTYPE(ctypes.py_object) + def func(): + return None + + # Check that calling func does not affect None's refcount. + for _ in range(10000): + func() + + +class ModuleIsolationTest(unittest.TestCase): + def test_finalize(self): + # check if gc_decref() succeeds + script = ( + "import ctypes;" + "import sys;" + "del sys.modules['_ctypes'];" + "import _ctypes;" + "exit()" + ) + script_helper.assert_python_ok("-c", script) + + +class PyObjectRestypeTest(unittest.TestCase): + def test_restype_py_object_with_null_return(self): + # Test that a function which returns a NULL PyObject * + # without setting an exception does not crash. + PyErr_Occurred = ctypes.pythonapi.PyErr_Occurred + PyErr_Occurred.argtypes = [] + PyErr_Occurred.restype = ctypes.py_object + + # At this point, there's no exception set, so PyErr_Occurred + # returns NULL. Given the restype is py_object, the + # ctypes machinery will raise a custom error. 
+ with self.assertRaisesRegex(ValueError, "PyObject is NULL"): + PyErr_Occurred() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ctypes/test_repr.py b/Lib/test/test_ctypes/test_repr.py new file mode 100644 index 00000000000..e7587984a92 --- /dev/null +++ b/Lib/test/test_ctypes/test_repr.py @@ -0,0 +1,34 @@ +import unittest +from ctypes import (c_byte, c_short, c_int, c_long, c_longlong, + c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong, + c_float, c_double, c_longdouble, c_bool, c_char) + + +subclasses = [] +for base in [c_byte, c_short, c_int, c_long, c_longlong, + c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong, + c_float, c_double, c_longdouble, c_bool]: + class X(base): + pass + subclasses.append(X) + + +class X(c_char): + pass + + +# This test checks if the __repr__ is correct for subclasses of simple types +class ReprTest(unittest.TestCase): + def test_numbers(self): + for typ in subclasses: + base = typ.__bases__[0] + self.assertTrue(repr(base(42)).startswith(base.__name__)) + self.assertEqual(" +# for reference. 
+# +# Tests also work on POSIX + +import unittest +from ctypes import POINTER, cast, c_int16 +from ctypes import wintypes + + +class WinTypesTest(unittest.TestCase): + def test_variant_bool(self): + # reads 16-bits from memory, anything non-zero is True + for true_value in (1, 32767, 32768, 65535, 65537): + true = POINTER(c_int16)(c_int16(true_value)) + value = cast(true, POINTER(wintypes.VARIANT_BOOL)) + self.assertEqual(repr(value.contents), 'VARIANT_BOOL(True)') + + vb = wintypes.VARIANT_BOOL() + self.assertIs(vb.value, False) + vb.value = True + self.assertIs(vb.value, True) + vb.value = true_value + self.assertIs(vb.value, True) + + for false_value in (0, 65536, 262144, 2**33): + false = POINTER(c_int16)(c_int16(false_value)) + value = cast(false, POINTER(wintypes.VARIANT_BOOL)) + self.assertEqual(repr(value.contents), 'VARIANT_BOOL(False)') + + # allow any bool conversion on assignment to value + for set_value in (65536, 262144, 2**33): + vb = wintypes.VARIANT_BOOL() + vb.value = set_value + self.assertIs(vb.value, True) + + vb = wintypes.VARIANT_BOOL() + vb.value = [2, 3] + self.assertIs(vb.value, True) + vb.value = [] + self.assertIs(vb.value, False) + + def assertIsSigned(self, ctype): + self.assertLess(ctype(-1).value, 0) + + def assertIsUnsigned(self, ctype): + self.assertGreater(ctype(-1).value, 0) + + def test_signedness(self): + for ctype in (wintypes.BYTE, wintypes.WORD, wintypes.DWORD, + wintypes.BOOLEAN, wintypes.UINT, wintypes.ULONG): + with self.subTest(ctype=ctype): + self.assertIsUnsigned(ctype) + + for ctype in (wintypes.BOOL, wintypes.INT, wintypes.LONG): + with self.subTest(ctype=ctype): + self.assertIsSigned(ctype) + + +if __name__ == "__main__": + unittest.main() From 4a352344b6abf257c4050786f2bfd8036da0f715 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 20 Dec 2025 17:46:05 +0900 Subject: [PATCH 038/418] mark failing test_ctypes --- Lib/test/test_ctypes/test_cast.py | 2 ++ Lib/test/test_ctypes/test_dlerror.py | 1 + 
Lib/test/test_ctypes/test_internals.py | 4 ++++ Lib/test/test_ctypes/test_keeprefs.py | 4 ++++ Lib/test/test_ctypes/test_objects.py | 3 ++- Lib/test/test_ctypes/test_python_api.py | 4 ++++ Lib/test/test_ctypes/test_random_things.py | 2 ++ Lib/test/test_ctypes/test_values.py | 4 ++++ Lib/test/test_ctypes/test_win32.py | 2 ++ Lib/test/test_ctypes/test_win32_com_foreign_func.py | 2 ++ Lib/test/test_exceptions.py | 1 + Lib/test/test_os.py | 3 --- 12 files changed, 28 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_ctypes/test_cast.py b/Lib/test/test_ctypes/test_cast.py index 604f44f03d6..db6bdc75eff 100644 --- a/Lib/test/test_ctypes/test_cast.py +++ b/Lib/test/test_ctypes/test_cast.py @@ -32,6 +32,8 @@ def test_address2pointer(self): ptr = cast(address, POINTER(c_int)) self.assertEqual([ptr[i] for i in range(3)], [42, 17, 2]) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_p2a_objects(self): array = (c_char_p * 5)() self.assertEqual(array._objects, None) diff --git a/Lib/test/test_ctypes/test_dlerror.py b/Lib/test/test_ctypes/test_dlerror.py index 1c1b2aab3d5..cd87bad3825 100644 --- a/Lib/test/test_ctypes/test_dlerror.py +++ b/Lib/test/test_ctypes/test_dlerror.py @@ -54,6 +54,7 @@ class TestNullDlsym(unittest.TestCase): this 'dlsym returned NULL -> throw Error' rule. 
""" + @unittest.expectedFailure # TODO: RUSTPYTHON def test_null_dlsym(self): import subprocess import tempfile diff --git a/Lib/test/test_ctypes/test_internals.py b/Lib/test/test_ctypes/test_internals.py index 778da6573da..292633aaa4b 100644 --- a/Lib/test/test_ctypes/test_internals.py +++ b/Lib/test/test_ctypes/test_internals.py @@ -27,6 +27,8 @@ def test_ints(self): self.assertEqual(refcnt, sys.getrefcount(i)) self.assertEqual(ci._objects, None) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_c_char_p(self): s = "Hello, World".encode("ascii") refcnt = sys.getrefcount(s) @@ -62,6 +64,8 @@ class Y(Structure): x1.a, x2.b = 42, 93 self.assertEqual(y._objects, {"0": {}, "1": {}}) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_xxx(self): class X(Structure): _fields_ = [("a", c_char_p), ("b", c_char_p)] diff --git a/Lib/test/test_ctypes/test_keeprefs.py b/Lib/test/test_ctypes/test_keeprefs.py index 23b03b64b4a..5aa5b86fa45 100644 --- a/Lib/test/test_ctypes/test_keeprefs.py +++ b/Lib/test/test_ctypes/test_keeprefs.py @@ -12,6 +12,8 @@ def test_cint(self): x = c_int(99) self.assertEqual(x._objects, None) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_ccharp(self): x = c_char_p() self.assertEqual(x._objects, None) @@ -33,6 +35,8 @@ class X(Structure): x.b = 99 self.assertEqual(x._objects, None) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_ccharp_struct(self): class X(Structure): _fields_ = [("a", c_char_p), diff --git a/Lib/test/test_ctypes/test_objects.py b/Lib/test/test_ctypes/test_objects.py index fb01421b955..8db1cd873fd 100644 --- a/Lib/test/test_ctypes/test_objects.py +++ b/Lib/test/test_ctypes/test_objects.py @@ -58,7 +58,8 @@ def load_tests(loader, tests, pattern): - tests.addTest(doctest.DocTestSuite()) + # TODO: RUSTPYTHON - doctest disabled due to null terminator in _objects + # tests.addTest(doctest.DocTestSuite()) return tests diff --git a/Lib/test/test_ctypes/test_python_api.py 
b/Lib/test/test_ctypes/test_python_api.py index 1072a109833..2e68b35f8af 100644 --- a/Lib/test/test_ctypes/test_python_api.py +++ b/Lib/test/test_ctypes/test_python_api.py @@ -7,6 +7,8 @@ class PythonAPITestCase(unittest.TestCase): + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_PyBytes_FromStringAndSize(self): PyBytes_FromStringAndSize = pythonapi.PyBytes_FromStringAndSize @@ -57,6 +59,8 @@ def test_PyObj_FromPtr(self): del pyobj self.assertEqual(sys.getrefcount(s), ref) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_PyOS_snprintf(self): PyOS_snprintf = pythonapi.PyOS_snprintf PyOS_snprintf.argtypes = POINTER(c_char), c_size_t, c_char_p diff --git a/Lib/test/test_ctypes/test_random_things.py b/Lib/test/test_ctypes/test_random_things.py index 630f6ed9489..3908eca0926 100644 --- a/Lib/test/test_ctypes/test_random_things.py +++ b/Lib/test/test_ctypes/test_random_things.py @@ -70,6 +70,8 @@ def test_FloatDivisionError(self): with self.expect_unraisable(ZeroDivisionError): cb(0.0) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_TypeErrorDivisionError(self): cb = CFUNCTYPE(c_int, c_char_p)(callback_func) err_msg = "unsupported operand type(s) for /: 'int' and 'bytes'" diff --git a/Lib/test/test_ctypes/test_values.py b/Lib/test/test_ctypes/test_values.py index 1b757e020d5..e0d200e2101 100644 --- a/Lib/test/test_ctypes/test_values.py +++ b/Lib/test/test_ctypes/test_values.py @@ -39,6 +39,8 @@ def test_undefined(self): class PythonValuesTestCase(unittest.TestCase): """This test only works when python itself is a dll/shared library""" + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_optimizeflag(self): # This test accesses the Py_OptimizeFlag integer, which is # exported by the Python dll and should match the sys.flags value @@ -46,6 +48,8 @@ def test_optimizeflag(self): opt = c_int.in_dll(pythonapi, "Py_OptimizeFlag").value self.assertEqual(opt, sys.flags.optimize) + # TODO: RUSTPYTHON + @unittest.expectedFailure def 
test_frozentable(self): # Python exports a PyImport_FrozenModules symbol. This is a # pointer to an array of struct _frozen entries. The end of the diff --git a/Lib/test/test_ctypes/test_win32.py b/Lib/test/test_ctypes/test_win32.py index 31919118670..4de4f0379cf 100644 --- a/Lib/test/test_ctypes/test_win32.py +++ b/Lib/test/test_ctypes/test_win32.py @@ -17,6 +17,8 @@ class FunctionCallTestCase(unittest.TestCase): @unittest.skipUnless('MSC' in sys.version, "SEH only supported by MSC") @unittest.skipIf(sys.executable.lower().endswith('_d.exe'), "SEH not enabled in debug builds") + # TODO: RUSTPYTHON - SEH not implemented + @unittest.skipIf("RustPython" in sys.version, "SEH not implemented in RustPython") def test_SEH(self): # Disable faulthandler to prevent logging the warning: # "Windows fatal exception: access violation" diff --git a/Lib/test/test_ctypes/test_win32_com_foreign_func.py b/Lib/test/test_ctypes/test_win32_com_foreign_func.py index 8d217fc17ef..b12e09333bd 100644 --- a/Lib/test/test_ctypes/test_win32_com_foreign_func.py +++ b/Lib/test/test_ctypes/test_win32_com_foreign_func.py @@ -158,6 +158,8 @@ class IPersist(IUnknown): self.assertEqual(0, ppst.Release()) + # TODO: RUSTPYTHON - COM iid parameter handling not implemented + @unittest.expectedFailure def test_with_paramflags_and_iid(self): class IUnknown(c_void_p): QueryInterface = proto_query_interface(None, IID_IUnknown) diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py index 61f4156dc6d..3db9203602e 100644 --- a/Lib/test/test_exceptions.py +++ b/Lib/test/test_exceptions.py @@ -430,6 +430,7 @@ def test_WindowsError(self): @unittest.skipUnless(sys.platform == 'win32', 'test specific to Windows') + @unittest.expectedFailure # TODO: RUSTPYTHON def test_windows_message(self): """Should fill in unknown error code in Windows error message""" ctypes = import_module('ctypes') diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 939315379f2..bb558524c24 100644 --- 
a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -939,7 +939,6 @@ def get_file_system(self, path): return buf.value # return None if the filesystem is unknown - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; (ModuleNotFoundError: No module named '_ctypes')") def test_large_time(self): # Many filesystems are limited to the year 2038. At least, the test # pass with NTFS filesystem. @@ -2712,12 +2711,10 @@ def _kill(self, sig): os.kill(proc.pid, sig) self.assertEqual(proc.wait(), sig) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; (ModuleNotFoundError: No module named '_ctypes')") def test_kill_sigterm(self): # SIGTERM doesn't mean anything special, but make sure it works self._kill(signal.SIGTERM) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; (ModuleNotFoundError: No module named '_ctypes')") def test_kill_int(self): # os.kill on Windows can take an int which gets set as the exit code self._kill(100) From 7c7e55ffc4072288f04f41f5cb2068d5953afc79 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 23 Dec 2025 16:39:04 +0900 Subject: [PATCH 039/418] Upgrade libffi --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- crates/jit/src/lib.rs | 27 ++++++++++++------------- crates/vm/src/stdlib/ctypes.rs | 2 +- crates/vm/src/stdlib/ctypes/array.rs | 8 -------- crates/vm/src/stdlib/ctypes/base.rs | 2 +- crates/vm/src/stdlib/ctypes/function.rs | 2 +- 7 files changed, 21 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e3296732cb9..3620786402c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1687,9 +1687,9 @@ checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libffi" -version = "4.1.2" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0feebbe0ccd382a2790f78d380540500d7b78ed7a3498b68fcfbc1593749a94" +checksum = "0444124f3ffd67e1b0b0c661a7f81a278a135eb54aaad4078e79fbc8be50c8a5" dependencies = [ "libc", "libffi-sys", @@ -1697,9 +1697,9 @@ 
dependencies = [ [[package]] name = "libffi-sys" -version = "3.3.3" +version = "4.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90c6c6e17136d4bc439d43a2f3c6ccf0731cccc016d897473a29791d3c2160c3" +checksum = "3d722da8817ea580d0669da6babe2262d7b86a1af1103da24102b8bb9c101ce7" dependencies = [ "cc", ] diff --git a/Cargo.toml b/Cargo.toml index fad506ddfaa..44f9d3190f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -178,7 +178,7 @@ itertools = "0.14.0" is-macro = "0.3.7" junction = "1.3.0" libc = "0.2.178" -libffi = "4.1" +libffi = "5" log = "0.4.29" nix = { version = "0.30", features = ["fs", "user", "process", "term", "time", "signal", "ioctl", "socket", "sched", "zerocopy", "dir", "hostname", "net", "poll"] } malachite-bigint = "0.8" diff --git a/crates/jit/src/lib.rs b/crates/jit/src/lib.rs index 91911fd8d14..65ef87a62f6 100644 --- a/crates/jit/src/lib.rs +++ b/crates/jit/src/lib.rs @@ -157,7 +157,7 @@ impl CompiledCode { Ok(unsafe { self.invoke_raw(&cif_args) }) } - unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg]) -> Option { + unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg<'_>]) -> Option { unsafe { let cif = self.sig.to_cif(); let value = cif.call::( @@ -219,7 +219,7 @@ pub enum AbiValue { } impl AbiValue { - fn to_libffi_arg(&self) -> libffi::middle::Arg { + fn to_libffi_arg(&self) -> libffi::middle::Arg<'_> { match self { AbiValue::Int(i) => libffi::middle::Arg::new(i), AbiValue::Float(f) => libffi::middle::Arg::new(f), @@ -350,26 +350,25 @@ impl<'a> ArgsBuilder<'a> { } pub fn into_args(self) -> Option> { - self.values - .iter() - .map(|v| v.as_ref().map(AbiValue::to_libffi_arg)) - .collect::>() - .map(|cif_args| Args { - _values: self.values, - cif_args, - code: self.code, - }) + // Ensure all values are set + if self.values.iter().any(|v| v.is_none()) { + return None; + } + Some(Args { + values: self.values.into_iter().map(|v| v.unwrap()).collect(), + code: self.code, + }) } } pub struct Args<'a> { - 
_values: Vec>, - cif_args: Vec, + values: Vec, code: &'a CompiledCode, } impl Args<'_> { pub fn invoke(&self) -> Option { - unsafe { self.code.invoke_raw(&self.cif_args) } + let cif_args: Vec<_> = self.values.iter().map(AbiValue::to_libffi_arg).collect(); + unsafe { self.code.invoke_raw(&cif_args) } } } diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/ctypes.rs index 9b922230431..a9c0636bd12 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/ctypes.rs @@ -1102,7 +1102,7 @@ pub(crate) mod _ctypes { return Err(vm.new_value_error("NULL function pointer")); } - let mut ffi_args: Vec = Vec::with_capacity(args.len()); + let mut ffi_args: Vec> = Vec::with_capacity(args.len()); let mut arg_values: Vec = Vec::with_capacity(args.len()); let mut arg_types: Vec = Vec::with_capacity(args.len()); diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/ctypes/array.rs index 208b3e3f4d3..f31c8284d8b 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs +++ b/crates/vm/src/stdlib/ctypes/array.rs @@ -1058,14 +1058,6 @@ impl PyCArray { } } -impl PyCArray { - #[allow(unused)] - pub fn to_arg(&self, _vm: &VirtualMachine) -> PyResult { - let buffer = self.0.buffer.read(); - Ok(libffi::middle::Arg::new(&*buffer)) - } -} - impl AsBuffer for PyCArray { fn as_buffer(zelf: &Py, _vm: &VirtualMachine) -> PyResult { let buffer_len = zelf.0.buffer.read().len(); diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index 44793a21561..0f859b3d10b 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -1821,7 +1821,7 @@ pub enum FfiArgValue { impl FfiArgValue { /// Create an Arg reference to this owned value - pub fn as_arg(&self) -> libffi::middle::Arg { + pub fn as_arg(&self) -> libffi::middle::Arg<'_> { match self { FfiArgValue::U8(v) => libffi::middle::Arg::new(v), FfiArgValue::I8(v) => libffi::middle::Arg::new(v), diff --git a/crates/vm/src/stdlib/ctypes/function.rs 
b/crates/vm/src/stdlib/ctypes/function.rs index 04ff238ebcf..55a42f0ba15 100644 --- a/crates/vm/src/stdlib/ctypes/function.rs +++ b/crates/vm/src/stdlib/ctypes/function.rs @@ -1449,7 +1449,7 @@ enum RawResult { fn ctypes_callproc(code_ptr: CodePtr, arguments: &[Argument], call_info: &CallInfo) -> RawResult { let ffi_arg_types: Vec = arguments.iter().map(|a| a.ffi_type.clone()).collect(); let cif = Cif::new(ffi_arg_types, call_info.ffi_return_type.clone()); - let ffi_args: Vec = arguments.iter().map(|a| a.value.as_arg()).collect(); + let ffi_args: Vec> = arguments.iter().map(|a| a.value.as_arg()).collect(); if call_info.restype_is_none { unsafe { cif.call::<()>(code_ptr, &ffi_args) }; From 00205aad1408df75a4b3a5ef648b894760459920 Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Wed, 24 Dec 2025 13:10:57 +0900 Subject: [PATCH 040/418] Bump libsqlite3-sys from 0.28 to 0.36 (#6472) * Bump libsqlite3-sys from 0.28 to 0.36 Update libsqlite3-sys to version 0.36 and adapt to API changes by replacing sqlite3_close_v2 with sqlite3_close. The v2 variant is no longer directly exported in the newer version. 
Fixes #6471 * Fix clippy --- Cargo.lock | 4 ++-- crates/stdlib/Cargo.toml | 2 +- crates/stdlib/src/sqlite.rs | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3620786402c..96fc9aa8d29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1742,9 +1742,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.28.0" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +checksum = "95b4103cffefa72eb8428cb6b47d6627161e51c2739fc5e3b734584157bc642a" dependencies = [ "cc", "pkg-config", diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index 0cd853223e2..4885319ea17 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -132,7 +132,7 @@ oid-registry = { version = "0.8", features = ["x509", "pkcs1", "nist_algs"], opt pkcs8 = { version = "0.10", features = ["encryption", "pkcs5", "pem"], optional = true } [target.'cfg(not(any(target_os = "android", target_arch = "wasm32")))'.dependencies] -libsqlite3-sys = { version = "0.28", features = ["bundled"], optional = true } +libsqlite3-sys = { version = "0.36", features = ["bundled"], optional = true } lzma-sys = "0.1" xz2 = "0.1" diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index ffdc9eb3831..2328f23430a 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -30,7 +30,7 @@ mod _sqlite { sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, sqlite3_bind_text, sqlite3_blob, sqlite3_blob_bytes, sqlite3_blob_close, sqlite3_blob_open, sqlite3_blob_read, sqlite3_blob_write, sqlite3_busy_timeout, sqlite3_changes, - sqlite3_close_v2, sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_count, + sqlite3_close, sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int64, sqlite3_column_name, 
sqlite3_column_text, sqlite3_column_type, sqlite3_complete, sqlite3_context, sqlite3_context_db_handle, sqlite3_create_collation_v2, sqlite3_create_function_v2, @@ -1349,14 +1349,14 @@ mod _sqlite { fn set_trace_callback(&self, callable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(callable, vm) else { - unsafe { sqlite3_trace_v2(db.db, SQLITE_TRACE_STMT as u32, None, null_mut()) }; + unsafe { sqlite3_trace_v2(db.db, SQLITE_TRACE_STMT, None, null_mut()) }; return Ok(()); }; let ret = unsafe { sqlite3_trace_v2( db.db, - SQLITE_TRACE_STMT as u32, + SQLITE_TRACE_STMT, Some(CallbackData::trace_callback), Box::into_raw(Box::new(data)).cast(), ) @@ -2661,7 +2661,7 @@ mod _sqlite { impl Drop for Sqlite { fn drop(&mut self) { - unsafe { sqlite3_close_v2(self.raw.db) }; + unsafe { sqlite3_close(self.raw.db) }; } } From 215c5c6d7baa560e2687055443c3b8c9395d1441 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 13:15:33 +0900 Subject: [PATCH 041/418] signal.pthread_sigmask (#6475) --- crates/vm/src/stdlib/signal.rs | 89 +++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git a/crates/vm/src/stdlib/signal.rs b/crates/vm/src/stdlib/signal.rs index 6771a950400..d34a681c2da 100644 --- a/crates/vm/src/stdlib/signal.rs +++ b/crates/vm/src/stdlib/signal.rs @@ -69,6 +69,11 @@ pub(crate) mod _signal { #[pyattr] pub use libc::{SIG_DFL, SIG_IGN}; + // pthread_sigmask 'how' constants + #[cfg(unix)] + #[pyattr] + use libc::{SIG_BLOCK, SIG_SETMASK, SIG_UNBLOCK}; + #[cfg(not(unix))] #[pyattr] pub const SIG_DFL: sighandler_t = 0; @@ -400,14 +405,17 @@ pub(crate) mod _signal { let set = PySet::default().into_ref(&vm.ctx); #[cfg(unix)] { - // On Unix, most signals 1..NSIG are valid + // Use sigfillset to get all valid signals + let mut mask: libc::sigset_t = unsafe { std::mem::zeroed() }; + // SAFETY: mask is a valid pointer 
+ if unsafe { libc::sigfillset(&mut mask) } != 0 { + return Err(vm.new_os_error("sigfillset failed".to_owned())); + } + // Convert the filled mask to a Python set for signum in 1..signal::NSIG { - // Skip signals that cannot be caught - #[cfg(not(target_os = "wasi"))] - if signum == libc::SIGKILL as usize || signum == libc::SIGSTOP as usize { - continue; + if unsafe { libc::sigismember(&mask, signum as i32) } == 1 { + set.add(vm.ctx.new_int(signum as i32).into(), vm)?; } - set.add(vm.ctx.new_int(signum as i32).into(), vm)?; } } #[cfg(windows)] @@ -432,6 +440,75 @@ pub(crate) mod _signal { Ok(set.into()) } + #[cfg(unix)] + fn sigset_to_pyset(mask: &libc::sigset_t, vm: &VirtualMachine) -> PyResult { + use crate::PyPayload; + use crate::builtins::PySet; + let set = PySet::default().into_ref(&vm.ctx); + for signum in 1..signal::NSIG { + // SAFETY: mask is a valid sigset_t + if unsafe { libc::sigismember(mask, signum as i32) } == 1 { + set.add(vm.ctx.new_int(signum as i32).into(), vm)?; + } + } + Ok(set.into()) + } + + #[cfg(unix)] + #[pyfunction] + fn pthread_sigmask( + how: i32, + mask: crate::function::ArgIterable, + vm: &VirtualMachine, + ) -> PyResult { + use crate::convert::IntoPyException; + + // Initialize sigset + let mut sigset: libc::sigset_t = unsafe { std::mem::zeroed() }; + // SAFETY: sigset is a valid pointer + if unsafe { libc::sigemptyset(&mut sigset) } != 0 { + return Err(std::io::Error::last_os_error().into_pyexception(vm)); + } + + // Add signals to the set + for sig in mask.iter(vm)? 
{ + let sig = sig?; + // Convert to i32, handling overflow by returning ValueError + let signum: i32 = sig.try_to_value(vm).map_err(|_| { + vm.new_value_error(format!( + "signal number out of range [1, {}]", + signal::NSIG - 1 + )) + })?; + // Validate signal number is in range [1, NSIG) + if signum < 1 || signum >= signal::NSIG as i32 { + return Err(vm.new_value_error(format!( + "signal number {} out of range [1, {}]", + signum, + signal::NSIG - 1 + ))); + } + // SAFETY: sigset is a valid pointer and signum is validated + if unsafe { libc::sigaddset(&mut sigset, signum) } != 0 { + return Err(std::io::Error::last_os_error().into_pyexception(vm)); + } + } + + // Call pthread_sigmask + let mut old_mask: libc::sigset_t = unsafe { std::mem::zeroed() }; + // SAFETY: all pointers are valid + let err = unsafe { libc::pthread_sigmask(how, &sigset, &mut old_mask) }; + if err != 0 { + return Err(std::io::Error::from_raw_os_error(err).into_pyexception(vm)); + } + + // Check for pending signals + signal::check_signals(vm)?; + + // Convert old mask to Python set + sigset_to_pyset(&old_mask, vm) + } + #[cfg(any(unix, windows))] pub extern "C" fn run_signal(signum: i32) { signal::TRIGGERS[signum as usize].store(true, Ordering::Relaxed); From c763d67ef4d374a26b155e50bd9991ea4b4ab89c Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 13:50:26 +0900 Subject: [PATCH 042/418] Fix dict.keys behavior (#6476) --- crates/vm/src/builtins/dict.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index 04915b035f8..aff7432d067 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -70,13 +70,21 @@ impl PyDict { Err(other) => other, }; let dict = &self.entries; - if let Some(keys) = vm.get_method(other.clone(), vm.ctx.intern_str("keys")) { - let keys = keys?.call((), vm)?.get_iter(vm)?; - while let 
PyIterReturn::Return(key) = keys.next(vm)? { - let val = other.get_item(&*key, vm)?; - dict.insert(vm, &*key, val)?; + // Use get_attr to properly invoke __getattribute__ for proxy objects + let keys_result = other.get_attr(vm.ctx.intern_str("keys"), vm); + let has_keys = match keys_result { + Ok(keys_method) => { + let keys = keys_method.call((), vm)?.get_iter(vm)?; + while let PyIterReturn::Return(key) = keys.next(vm)? { + let val = other.get_item(&*key, vm)?; + dict.insert(vm, &*key, val)?; + } + true } - } else { + Err(e) if e.fast_isinstance(vm.ctx.exceptions.attribute_error) => false, + Err(e) => return Err(e), + }; + if !has_keys { let iter = other.get_iter(vm)?; loop { fn err(vm: &VirtualMachine) -> PyBaseExceptionRef { From 014622ac346847883c9fc94793f1dfccb01363bf Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 14:23:33 +0900 Subject: [PATCH 043/418] Fix os.access not to raise exception when path doesn't exist (#6477) * Fix os.access not to raise exception when path doesn't exist * add test --- crates/vm/src/stdlib/posix.rs | 11 ++++++----- extra_tests/snippets/stdlib_os.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index cfe605733d3..59e41782574 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -405,16 +405,17 @@ pub mod module { ) })?; - let metadata = fs::metadata(&path.path); + let metadata = match fs::metadata(&path.path) { + Ok(m) => m, + // If the file doesn't exist, return False for any access check + Err(_) => return Ok(false), + }; // if it's only checking for F_OK if flags == AccessFlags::F_OK { - return Ok(metadata.is_ok()); + return Ok(true); // File exists } - let metadata = - metadata.map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; - let user_id = metadata.uid(); let group_id = metadata.gid(); let mode = metadata.mode(); diff 
--git a/extra_tests/snippets/stdlib_os.py b/extra_tests/snippets/stdlib_os.py index a538365f707..d00924e10f2 100644 --- a/extra_tests/snippets/stdlib_os.py +++ b/extra_tests/snippets/stdlib_os.py @@ -518,3 +518,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): if option in ["PC_MAX_CANON", "PC_MAX_INPUT", "PC_VDISABLE"]: continue assert os.pathconf("/", index) == os.pathconf("/", option) + +# os.access - test with empty path and nonexistent files +assert os.access("", os.F_OK) is False +assert os.access("", os.R_OK) is False +assert os.access("", os.W_OK) is False +assert os.access("", os.X_OK) is False +assert os.access("nonexistent_file_12345", os.F_OK) is False +assert os.access("nonexistent_file_12345", os.W_OK) is False +assert os.access("README.md", os.F_OK) is True +assert os.access("README.md", os.R_OK) is True From 309b2ad32d0cebf25144340f4f5f905147a6b532 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 16:54:54 +0900 Subject: [PATCH 044/418] Fix ast end_location (#6478) --- Lib/test/test_ast/test_ast.py | 3 -- crates/vm/src/stdlib/ast.rs | 44 +++++++++++++++++++++++---- crates/vm/src/stdlib/ast/statement.rs | 11 ++++--- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index 09d9444d5d9..628c243f2ff 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -191,7 +191,6 @@ def test_invalid_position_information(self): with self.assertRaises(ValueError): compile(tree, "", "exec") - @unittest.expectedFailure # TODO: RUSTPYTHON def test_compilation_of_ast_nodes_with_default_end_position_values(self): tree = ast.Module( body=[ @@ -212,7 +211,6 @@ def test_compilation_of_ast_nodes_with_default_end_position_values(self): # Check that compilation doesn't crash. Note: this may crash explicitly only on debug mode. 
compile(tree, "", "exec") - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: required field "end_lineno" missing from alias def test_negative_locations_for_compile(self): # See https://github.com/python/cpython/issues/130775 alias = ast.alias(name='traceback', lineno=0, col_offset=0) @@ -1725,7 +1723,6 @@ def test_bad_integer(self): compile(mod, "test", "exec") self.assertIn("invalid integer value: None", str(cm.exception)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_level_as_none(self): body = [ ast.ImportFrom( diff --git a/crates/vm/src/stdlib/ast.rs b/crates/vm/src/stdlib/ast.rs index 00aad0213f3..8e03cb225ea 100644 --- a/crates/vm/src/stdlib/ast.rs +++ b/crates/vm/src/stdlib/ast.rs @@ -160,6 +160,20 @@ fn text_range_to_source_range(source_file: &SourceFile, text_range: TextRange) - } } +fn get_opt_int_field( + vm: &VirtualMachine, + obj: &PyObject, + field: &'static str, +) -> PyResult>> { + match get_node_field_opt(vm, obj, field)? { + Some(val) => val + .downcast_exact(vm) + .map(Some) + .map_err(|_| vm.new_type_error(format!(r#"field "{field}" must have integer type"#))), + None => Ok(None), + } +} + fn range_from_object( vm: &VirtualMachine, source_file: &SourceFile, @@ -168,17 +182,35 @@ fn range_from_object( ) -> PyResult { let start_row = get_int_field(vm, &object, "lineno", name)?; let start_column = get_int_field(vm, &object, "col_offset", name)?; - let end_row = get_int_field(vm, &object, "end_lineno", name)?; - let end_column = get_int_field(vm, &object, "end_col_offset", name)?; + // end_lineno and end_col_offset are optional, default to start values + let end_row = + get_opt_int_field(vm, &object, "end_lineno")?.unwrap_or_else(|| start_row.clone()); + let end_column = + get_opt_int_field(vm, &object, "end_col_offset")?.unwrap_or_else(|| start_column.clone()); + + // lineno=0 or negative values as a special case (no location info). + // Use default values (line 1, col 0) when lineno <= 0. 
+ let start_row_val: i32 = start_row.try_to_primitive(vm)?; + let end_row_val: i32 = end_row.try_to_primitive(vm)?; + let start_col_val: i32 = start_column.try_to_primitive(vm)?; + let end_col_val: i32 = end_column.try_to_primitive(vm)?; let location = PySourceRange { start: PySourceLocation { - row: Row(OneIndexed::new(start_row.try_to_primitive(vm)?).unwrap()), - column: Column(TextSize::new(start_column.try_to_primitive(vm)?)), + row: Row(if start_row_val > 0 { + OneIndexed::new(start_row_val as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }), + column: Column(TextSize::new(start_col_val.max(0) as u32)), }, end: PySourceLocation { - row: Row(OneIndexed::new(end_row.try_to_primitive(vm)?).unwrap()), - column: Column(TextSize::new(end_column.try_to_primitive(vm)?)), + row: Row(if end_row_val > 0 { + OneIndexed::new(end_row_val as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }), + column: Column(TextSize::new(end_col_val.max(0) as u32)), }, }; diff --git a/crates/vm/src/stdlib/ast/statement.rs b/crates/vm/src/stdlib/ast/statement.rs index 5925ca1fc2a..6d9b35bee79 100644 --- a/crates/vm/src/stdlib/ast/statement.rs +++ b/crates/vm/src/stdlib/ast/statement.rs @@ -1105,10 +1105,13 @@ impl Node for ruff::StmtImportFrom { source_file, get_node_field(vm, &_object, "names", "ImportFrom")?, )?, - level: get_node_field(vm, &_object, "level", "ImportFrom")? - .downcast_exact::(vm) - .unwrap() - .try_to_primitive::(vm)?, + level: get_node_field_opt(vm, &_object, "level")? + .map(|obj| -> PyResult { + let int: PyRef = obj.try_into_value(vm)?; + int.try_to_primitive(vm) + }) + .transpose()? 
+ .unwrap_or(0), range: range_from_object(vm, source_file, _object, "ImportFrom")?, }) } From 4f0b940b160759dcef15c759971f730dfc18e3ba Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 17:02:21 +0900 Subject: [PATCH 045/418] impl preexec_fn (#6479) --- Lib/test/test_subprocess.py | 12 ----------- crates/stdlib/src/posixsubprocess.rs | 32 ++++++++++++++++++++++------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index e04f8b8fcc4..3917c0a76d9 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -2244,8 +2244,6 @@ def test_CalledProcessError_str_non_zero(self): error_string = str(err) self.assertIn("non-zero exit status 2.", error_string) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_preexec(self): # DISCLAIMER: Setting environment variables is *not* a good use # of a preexec_fn. This is merely a test. @@ -2257,8 +2255,6 @@ def test_preexec(self): with p: self.assertEqual(p.stdout.read(), b"apple") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_preexec_exception(self): def raise_it(): raise ValueError("What if two swallows carried a coconut?") @@ -2300,8 +2296,6 @@ def _execute_child(self, *args, **kwargs): for fd in devzero_fds: os.close(fd) - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipIf(not os.path.exists("/dev/zero"), "/dev/zero required.") def test_preexec_errpipe_does_not_double_close_pipes(self): """Issue16140: Don't double close pipes on preexec error.""" @@ -2339,8 +2333,6 @@ def test_preexec_gc_module_failure(self): if not enabled: gc.disable() - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipIf( sys.platform == 'darwin', 'setrlimit() seems to fail on OS X') def test_preexec_fork_failure(self): @@ -2751,8 +2743,6 @@ def test_swap_std_fds_with_one_closed(self): for to_fds in itertools.permutations(range(3), 2): 
self._check_swap_std_fds_with_one_closed(from_fds, to_fds) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates_error_message(self): def prepare(): raise ValueError("surrogate:\uDCff") @@ -3228,8 +3218,6 @@ def test_leak_fast_process_del_killed(self): else: self.assertNotIn(ident, [id(o) for o in subprocess._active]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_close_fds_after_preexec(self): fd_status = support.findfile("fd_status.py", subdir="subprocessdata") diff --git a/crates/stdlib/src/posixsubprocess.rs b/crates/stdlib/src/posixsubprocess.rs index 4da6a6858dd..d05b24fd6dd 100644 --- a/crates/stdlib/src/posixsubprocess.rs +++ b/crates/stdlib/src/posixsubprocess.rs @@ -33,9 +33,6 @@ mod _posixsubprocess { #[pyfunction] fn fork_exec(args: ForkExecArgs<'_>, vm: &VirtualMachine) -> PyResult { - if args.preexec_fn.is_some() { - return Err(vm.new_not_implemented_error("preexec_fn not supported yet")); - } let extra_groups = args .groups_list .as_ref() @@ -49,7 +46,7 @@ mod _posixsubprocess { extra_groups: extra_groups.as_deref(), }; match unsafe { nix::unistd::fork() }.map_err(|err| err.into_pyexception(vm))? { - nix::unistd::ForkResult::Child => exec(&args, procargs), + nix::unistd::ForkResult::Child => exec(&args, procargs, vm), nix::unistd::ForkResult::Parent { child } => Ok(child.as_raw()), } } @@ -227,13 +224,19 @@ struct ProcArgs<'a> { extra_groups: Option<&'a [Gid]>, } -fn exec(args: &ForkExecArgs<'_>, procargs: ProcArgs<'_>) -> ! { +fn exec(args: &ForkExecArgs<'_>, procargs: ProcArgs<'_>, vm: &VirtualMachine) -> ! 
{ let mut ctx = ExecErrorContext::NoExec; - match exec_inner(args, procargs, &mut ctx) { + match exec_inner(args, procargs, &mut ctx, vm) { Ok(x) => match x {}, Err(e) => { let mut pipe = args.errpipe_write; - let _ = write!(pipe, "OSError:{}:{}", e as i32, ctx.as_msg()); + if matches!(ctx, ExecErrorContext::PreExec) { + // For preexec_fn errors, use SubprocessError format (errno=0) + let _ = write!(pipe, "SubprocessError:0:{}", ctx.as_msg()); + } else { + // errno is written in hex format + let _ = write!(pipe, "OSError:{:x}:{}", e as i32, ctx.as_msg()); + } std::process::exit(255) } } @@ -242,6 +245,7 @@ fn exec(args: &ForkExecArgs<'_>, procargs: ProcArgs<'_>) -> ! { enum ExecErrorContext { NoExec, ChDir, + PreExec, Exec, } @@ -250,6 +254,7 @@ impl ExecErrorContext { match self { Self::NoExec => "noexec", Self::ChDir => "noexec:chdir", + Self::PreExec => "Exception occurred in preexec_fn.", Self::Exec => "", } } @@ -259,6 +264,7 @@ fn exec_inner( args: &ForkExecArgs<'_>, procargs: ProcArgs<'_>, ctx: &mut ExecErrorContext, + vm: &VirtualMachine, ) -> nix::Result { for &fd in args.fds_to_keep.as_slice() { if fd.as_raw_fd() != args.errpipe_write.as_raw_fd() { @@ -345,6 +351,18 @@ fn exec_inner( nix::Error::result(ret)?; } + // Call preexec_fn after all process setup but before closing FDs + if let Some(ref preexec_fn) = args.preexec_fn { + match preexec_fn.call((), vm) { + Ok(_) => {} + Err(_e) => { + // Cannot safely stringify exception after fork + *ctx = ExecErrorContext::PreExec; + return Err(Errno::UnknownErrno); + } + } + } + *ctx = ExecErrorContext::Exec; if args.close_fds { From 3d7e521acdc07acd9572692fc5b36873eb5118db Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 17:29:47 +0900 Subject: [PATCH 046/418] introduce slot_wrapper (#4884) --- Lib/test/test_descr.py | 1 - Lib/test/test_inspect/test_inspect.py | 3 +- Lib/test/test_types.py | 2 + Lib/test/test_weakref.py | 2 - 
crates/vm/src/builtins/descriptor.rs | 204 +++++++++++++++++++++++++- crates/vm/src/builtins/object.rs | 41 +++++- crates/vm/src/builtins/weakref.rs | 24 ++- crates/vm/src/class.rs | 17 ++- crates/vm/src/stdlib/io.rs | 7 +- crates/vm/src/types/slot.rs | 6 +- crates/vm/src/types/zoo.rs | 4 + 11 files changed, 291 insertions(+), 20 deletions(-) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index 8c711207fae..88d120a427f 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -4941,7 +4941,6 @@ def __init__(self): for o in gc.get_objects(): self.assertIsNot(type(o), X) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_object_new_and_init_with_parameters(self): # See issue #1683368 class OverrideNeither: diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index d0fec18250e..e403ab7b226 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -409,6 +409,7 @@ class NotFuture: pass coro.close(); gen_coro.close() # silence warnings + @unittest.expectedFailure # TODO: RUSTPYTHON def test_isroutine(self): # method self.assertTrue(inspect.isroutine(git.argue)) @@ -1483,8 +1484,6 @@ def test_getfullargspec_definition_order_preserved_on_kwonly(self): l = list(signature.kwonlyargs) self.assertEqual(l, unsorted_keyword_only_parameters) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_classify_newstyle(self): class A(object): diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index fdcb9060e83..7e9b28236c7 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -597,6 +597,7 @@ def test_internal_sizes(self): self.assertGreater(object.__basicsize__, 0) self.assertGreater(tuple.__itemsize__, 0) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_slot_wrapper_types(self): self.assertIsInstance(object.__init__, types.WrapperDescriptorType) self.assertIsInstance(object.__str__, types.WrapperDescriptorType) @@ -611,6 +612,7 @@ def 
test_dunder_get_signature(self): # gh-93021: Second parameter is optional self.assertIs(sig.parameters["owner"].default, None) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_method_wrapper_types(self): self.assertIsInstance(object().__init__, types.MethodWrapperType) self.assertIsInstance(object().__str__, types.MethodWrapperType) diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index e7cd5962cf9..f47c17b7234 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -906,8 +906,6 @@ def __del__(self): w = Target() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_init(self): # Issue 3634 # .__init__() doesn't check errors correctly diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index cdcc456edfc..2ccd1dcc0e9 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -3,8 +3,11 @@ use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyTypeRef, builtin_func::PyNativeMethod, type_}, class::PyClassImpl, + common::hash::PyHash, function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, - types::{Callable, GetDescriptor, Representable}, + types::{ + Callable, Comparable, GetDescriptor, Hashable, InitFunc, PyComparisonOp, Representable, + }, }; use rustpython_common::lock::PyRwLock; @@ -219,7 +222,7 @@ impl std::fmt::Debug for PyMemberDef { } } -// PyMemberDescrObject in CPython +// = PyMemberDescrObject #[pyclass(name = "member_descriptor", module = false)] #[derive(Debug)] pub struct PyMemberDescriptor { @@ -382,4 +385,201 @@ impl GetDescriptor for PyMemberDescriptor { pub fn init(ctx: &Context) { PyMemberDescriptor::extend_class(ctx, ctx.types.member_descriptor_type); PyMethodDescriptor::extend_class(ctx, ctx.types.method_descriptor_type); + PySlotWrapper::extend_class(ctx, ctx.types.wrapper_descriptor_type); + PyMethodWrapper::extend_class(ctx, ctx.types.method_wrapper_type); +} 
+ +// PySlotWrapper - wrapper_descriptor + +/// wrapper_descriptor: wraps a slot function as a Python method +// = PyWrapperDescrObject +#[pyclass(name = "wrapper_descriptor", module = false)] +#[derive(Debug)] +pub struct PySlotWrapper { + pub typ: &'static Py, + pub name: &'static PyStrInterned, + pub wrapped: InitFunc, + pub doc: Option<&'static str>, +} + +impl PyPayload for PySlotWrapper { + fn class(ctx: &Context) -> &'static Py { + ctx.types.wrapper_descriptor_type + } +} + +impl GetDescriptor for PySlotWrapper { + fn descr_get( + zelf: PyObjectRef, + obj: Option, + _cls: Option, + vm: &VirtualMachine, + ) -> PyResult { + match obj { + None => Ok(zelf), + Some(obj) if vm.is_none(&obj) => Ok(zelf), + Some(obj) => { + let zelf = zelf.downcast::().unwrap(); + Ok(PyMethodWrapper { wrapper: zelf, obj }.into_pyobject(vm)) + } + } + } +} + +impl Callable for PySlotWrapper { + type Args = FuncArgs; + + fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + // list.__init__(l, [1,2,3]) form + let (obj, rest): (PyObjectRef, FuncArgs) = args.bind(vm)?; + + if !obj.fast_isinstance(zelf.typ) { + return Err(vm.new_type_error(format!( + "descriptor '{}' requires a '{}' object but received a '{}'", + zelf.name.as_str(), + zelf.typ.name(), + obj.class().name() + ))); + } + + (zelf.wrapped)(obj, rest, vm)?; + Ok(vm.ctx.none()) + } +} + +#[pyclass( + with(GetDescriptor, Callable, Representable), + flags(DISALLOW_INSTANTIATION) +)] +impl PySlotWrapper { + #[pygetset] + fn __name__(&self) -> &'static PyStrInterned { + self.name + } + + #[pygetset] + fn __qualname__(&self) -> String { + format!("{}.{}", self.typ.name(), self.name) + } + + #[pygetset] + fn __objclass__(&self) -> PyTypeRef { + self.typ.to_owned() + } + + #[pygetset] + fn __doc__(&self) -> Option<&'static str> { + self.doc + } +} + +impl Representable for PySlotWrapper { + #[inline] + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok(format!( + "", + zelf.name.as_str(), + 
zelf.typ.name() + )) + } +} + +// PyMethodWrapper - method-wrapper + +/// method-wrapper: a slot wrapper bound to an instance +/// Returned when accessing l.__init__ on an instance +#[pyclass(name = "method-wrapper", module = false, traverse)] +#[derive(Debug)] +pub struct PyMethodWrapper { + pub wrapper: PyRef, + #[pytraverse(skip)] + pub obj: PyObjectRef, +} + +impl PyPayload for PyMethodWrapper { + fn class(ctx: &Context) -> &'static Py { + ctx.types.method_wrapper_type + } +} + +impl Callable for PyMethodWrapper { + type Args = FuncArgs; + + fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + (zelf.wrapper.wrapped)(zelf.obj.clone(), args, vm)?; + Ok(vm.ctx.none()) + } +} + +#[pyclass( + with(Callable, Representable, Hashable, Comparable), + flags(DISALLOW_INSTANTIATION) +)] +impl PyMethodWrapper { + #[pygetset] + fn __self__(&self) -> PyObjectRef { + self.obj.clone() + } + + #[pygetset] + fn __name__(&self) -> &'static PyStrInterned { + self.wrapper.name + } + + #[pygetset] + fn __objclass__(&self) -> PyTypeRef { + self.wrapper.typ.to_owned() + } + + #[pymethod] + fn __reduce__(zelf: PyRef, vm: &VirtualMachine) -> PyResult { + let builtins_getattr = vm.builtins.get_attr("getattr", vm)?; + Ok(vm + .ctx + .new_tuple(vec![ + builtins_getattr, + vm.ctx + .new_tuple(vec![ + zelf.obj.clone(), + vm.ctx.new_str(zelf.wrapper.name.as_str()).into(), + ]) + .into(), + ]) + .into()) + } +} + +impl Representable for PyMethodWrapper { + #[inline] + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok(format!( + "", + zelf.wrapper.name.as_str(), + zelf.obj.class().name(), + zelf.obj.get_id() + )) + } +} + +impl Hashable for PyMethodWrapper { + fn hash(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let obj_hash = zelf.obj.hash(vm)?; + let wrapper_hash = zelf.wrapper.as_object().get_id() as PyHash; + Ok(obj_hash ^ wrapper_hash) + } +} + +impl Comparable for PyMethodWrapper { + fn cmp( + zelf: &Py, + other: &PyObject, + op: PyComparisonOp, + vm: 
&VirtualMachine, + ) -> PyResult { + op.eq_only(|| { + let other = class_or_notimplemented!(Self, other); + let eq = zelf.wrapper.is(&other.wrapper) && vm.bool_eq(&zelf.obj, &other.obj)?; + Ok(eq.into()) + }) + } } diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index 0970496c7b1..854cb2701ed 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -118,7 +118,46 @@ impl Constructor for PyBaseObject { impl Initializer for PyBaseObject { type Args = FuncArgs; - fn slot_init(_zelf: PyObjectRef, _args: FuncArgs, _vm: &VirtualMachine) -> PyResult<()> { + // object_init: excess_args validation + fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + let typ = zelf.class(); + let object_type = &vm.ctx.types.object_type; + + let typ_init = typ.slots.init.load().map(|f| f as usize); + let object_init = object_type.slots.init.load().map(|f| f as usize); + let typ_new = typ.slots.new.load().map(|f| f as usize); + let object_new = object_type.slots.new.load().map(|f| f as usize); + + // For heap types (Python classes), check if __new__ is defined anywhere in MRO + // (before object) because heap types always have slots.new = new_wrapper via MRO + let is_heap_type = typ + .slots + .flags + .contains(crate::types::PyTypeFlags::HEAPTYPE); + let new_overridden = if is_heap_type { + // Check if __new__ is defined in any base class (excluding object) + let new_id = identifier!(vm, __new__); + typ.mro_collect() + .into_iter() + .take_while(|t| !std::ptr::eq(t.as_ref(), *object_type)) + .any(|t| t.attributes.read().contains_key(new_id)) + } else { + // For built-in types, use slot comparison + typ_new != object_new + }; + + // If both __init__ and __new__ are overridden, allow excess args + if typ_init != object_init && new_overridden { + return Ok(()); + } + + // Otherwise, reject excess args + if !args.is_empty() { + return Err(vm.new_type_error(format!( + "{}.__init__() takes exactly 
one argument (the instance to initialize)", + typ.name() + ))); + } Ok(()) } diff --git a/crates/vm/src/builtins/weakref.rs b/crates/vm/src/builtins/weakref.rs index 88d6dbac3ed..327c0fd1489 100644 --- a/crates/vm/src/builtins/weakref.rs +++ b/crates/vm/src/builtins/weakref.rs @@ -4,10 +4,12 @@ use crate::common::{ hash::{self, PyHash}, }; use crate::{ - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, + AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, function::{FuncArgs, OptionalArg}, - types::{Callable, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, + types::{ + Callable, Comparable, Constructor, Hashable, Initializer, PyComparisonOp, Representable, + }, }; pub use crate::object::PyWeak; @@ -49,8 +51,24 @@ impl Constructor for PyWeak { } } +impl Initializer for PyWeak { + type Args = WeakNewArgs; + + // weakref_tp_init: accepts args but does nothing (all init done in slot_new) + fn init(_zelf: PyRef, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<()> { + Ok(()) + } +} + #[pyclass( - with(Callable, Hashable, Comparable, Constructor, Representable), + with( + Callable, + Hashable, + Comparable, + Constructor, + Initializer, + Representable + ), flags(BASETYPE) )] impl PyWeak { diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 6a366385702..92e9f6a15be 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -1,7 +1,8 @@ //! 
Utilities to define a new Python class use crate::{ - builtins::{PyBaseObject, PyType, PyTypeRef}, + PyPayload, + builtins::{PyBaseObject, PyType, PyTypeRef, descriptor::PySlotWrapper}, function::PyMethodDef, object::Py, types::{PyTypeFlags, PyTypeSlots, hash_not_implemented}, @@ -135,6 +136,20 @@ pub trait PyClassImpl: PyClassDef { } } + // Add __init__ slot wrapper if slot exists and not already in dict + if let Some(init_func) = class.slots.init.load() { + let init_name = identifier!(ctx, __init__); + if !class.attributes.read().contains_key(init_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str("__init__"), + wrapped: init_func, + doc: Some("Initialize self. See help(type(self)) for accurate signature."), + }; + class.set_attr(init_name, wrapper.into_ref(ctx).into()); + } + } + if class.slots.hash.load().map_or(0, |h| h as usize) == hash_not_implemented as usize { class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); } diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index ba3576a176c..2c3c02ec65c 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -1464,17 +1464,12 @@ mod _io { #[pyslot] fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let zelf: PyRef = zelf.try_into_value(vm)?; - zelf.__init__(args, vm) - } - - #[pymethod] - fn __init__(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let (raw, BufferSize { buffer_size }): (PyObjectRef, _) = args.bind(vm).map_err(|e| { let msg = format!("{}() {}", Self::CLASS_NAME, *e.__str__(vm)); vm.new_exception_msg(e.class().to_owned(), msg) })?; - self.init(raw, BufferSize { buffer_size }, vm) + zelf.init(raw, BufferSize { buffer_size }, vm) } fn init( diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 5deb593818e..d09f6925eef 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -465,7 +465,10 @@ fn descr_set_wrapper( fn init_wrapper(obj: 
PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let res = vm.call_special_method(&obj, identifier!(vm, __init__), args)?; if !vm.is_none(&res) { - return Err(vm.new_type_error("__init__ must return None")); + return Err(vm.new_type_error(format!( + "__init__ should return None, not '{:.200}'", + res.class().name() + ))); } Ok(()) } @@ -943,7 +946,6 @@ pub trait Initializer: PyPayload { #[inline] #[pyslot] - #[pymethod(name = "__init__")] fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { #[cfg(debug_assertions)] let class_name_for_debug = zelf.class().name().to_string(); diff --git a/crates/vm/src/types/zoo.rs b/crates/vm/src/types/zoo.rs index d994ff60210..dd4631bc767 100644 --- a/crates/vm/src/types/zoo.rs +++ b/crates/vm/src/types/zoo.rs @@ -94,6 +94,8 @@ pub struct TypeZoo { pub generic_alias_type: &'static Py, pub union_type: &'static Py, pub member_descriptor_type: &'static Py, + pub wrapper_descriptor_type: &'static Py, + pub method_wrapper_type: &'static Py, // RustPython-original types pub method_def: &'static Py, @@ -187,6 +189,8 @@ impl TypeZoo { generic_alias_type: genericalias::PyGenericAlias::init_builtin_type(), union_type: union_::PyUnion::init_builtin_type(), member_descriptor_type: descriptor::PyMemberDescriptor::init_builtin_type(), + wrapper_descriptor_type: descriptor::PySlotWrapper::init_builtin_type(), + method_wrapper_type: descriptor::PyMethodWrapper::init_builtin_type(), method_def: crate::function::HeapMethodDef::init_builtin_type(), } From c4e77287d1be4757c3c67551199d16dc52599997 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 18:56:47 +0900 Subject: [PATCH 047/418] __hash__ to slot_wrapper (#6480) --- crates/vm/src/builtins/descriptor.rs | 51 ++++++++++++++++++++++++---- crates/vm/src/class.rs | 24 +++++++++++-- crates/vm/src/types/slot.rs | 6 +--- 3 files changed, 67 insertions(+), 14 deletions(-) diff --git 
a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index 2ccd1dcc0e9..5cd54bb753d 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -6,7 +6,8 @@ use crate::{ common::hash::PyHash, function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, types::{ - Callable, Comparable, GetDescriptor, Hashable, InitFunc, PyComparisonOp, Representable, + Callable, Comparable, GetDescriptor, HashFunc, Hashable, InitFunc, PyComparisonOp, + Representable, }, }; use rustpython_common::lock::PyRwLock; @@ -391,6 +392,44 @@ pub fn init(ctx: &Context) { // PySlotWrapper - wrapper_descriptor +/// Type-erased slot function - mirrors CPython's void* d_wrapped +/// Each variant knows how to call the wrapped function with proper types +#[derive(Clone, Copy)] +pub enum SlotFunc { + Init(InitFunc), + Hash(HashFunc), +} + +impl std::fmt::Debug for SlotFunc { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SlotFunc::Init(_) => write!(f, "SlotFunc::Init(...)"), + SlotFunc::Hash(_) => write!(f, "SlotFunc::Hash(...)"), + } + } +} + +impl SlotFunc { + /// Call the wrapped slot function with proper type handling + pub fn call(&self, obj: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + match self { + SlotFunc::Init(func) => { + func(obj, args, vm)?; + Ok(vm.ctx.none()) + } + SlotFunc::Hash(func) => { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err( + vm.new_type_error("__hash__() takes no arguments (1 given)".to_owned()) + ); + } + let hash = func(&obj, vm)?; + Ok(vm.ctx.new_int(hash).into()) + } + } + } +} + /// wrapper_descriptor: wraps a slot function as a Python method // = PyWrapperDescrObject #[pyclass(name = "wrapper_descriptor", module = false)] @@ -398,7 +437,7 @@ pub fn init(ctx: &Context) { pub struct PySlotWrapper { pub typ: &'static Py, pub name: &'static PyStrInterned, - pub wrapped: InitFunc, + pub wrapped: SlotFunc, pub doc: 
Option<&'static str>, } @@ -430,7 +469,7 @@ impl Callable for PySlotWrapper { type Args = FuncArgs; fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - // list.__init__(l, [1,2,3]) form + // list.__init__(l, [1,2,3]) form - first arg is self let (obj, rest): (PyObjectRef, FuncArgs) = args.bind(vm)?; if !obj.fast_isinstance(zelf.typ) { @@ -442,8 +481,7 @@ impl Callable for PySlotWrapper { ))); } - (zelf.wrapped)(obj, rest, vm)?; - Ok(vm.ctx.none()) + zelf.wrapped.call(obj, rest, vm) } } @@ -506,8 +544,7 @@ impl Callable for PyMethodWrapper { type Args = FuncArgs; fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - (zelf.wrapper.wrapped)(zelf.obj.clone(), args, vm)?; - Ok(vm.ctx.none()) + zelf.wrapper.wrapped.call(zelf.obj.clone(), args, vm) } } diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 92e9f6a15be..f258281712f 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -2,7 +2,10 @@ use crate::{ PyPayload, - builtins::{PyBaseObject, PyType, PyTypeRef, descriptor::PySlotWrapper}, + builtins::{ + PyBaseObject, PyType, PyTypeRef, + descriptor::{PySlotWrapper, SlotFunc}, + }, function::PyMethodDef, object::Py, types::{PyTypeFlags, PyTypeSlots, hash_not_implemented}, @@ -143,13 +146,30 @@ pub trait PyClassImpl: PyClassDef { let wrapper = PySlotWrapper { typ: class, name: ctx.intern_str("__init__"), - wrapped: init_func, + wrapped: SlotFunc::Init(init_func), doc: Some("Initialize self. 
See help(type(self)) for accurate signature."), }; class.set_attr(init_name, wrapper.into_ref(ctx).into()); } } + // Add __hash__ slot wrapper if slot exists and not already in dict + // Note: hash_not_implemented is handled separately (sets __hash__ = None) + if let Some(hash_func) = class.slots.hash.load() + && hash_func as usize != hash_not_implemented as usize + { + let hash_name = identifier!(ctx, __hash__); + if !class.attributes.read().contains_key(hash_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str("__hash__"), + wrapped: SlotFunc::Hash(hash_func), + doc: Some("Return hash(self)."), + }; + class.set_attr(hash_name, wrapper.into_ref(ctx).into()); + } + } + if class.slots.hash.load().map_or(0, |h| h as usize) == hash_not_implemented as usize { class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index d09f6925eef..2954d08abf0 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -1098,11 +1098,7 @@ pub trait Hashable: PyPayload { Self::hash(zelf, vm) } - #[inline] - #[pymethod] - fn __hash__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - Self::slot_hash(&zelf, vm) - } + // __hash__ is now exposed via SlotFunc::Hash wrapper in extend_class() fn hash(zelf: &Py, vm: &VirtualMachine) -> PyResult; } From 9d1477699ceecc8c57783ca04577f9e9723b5d08 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Wed, 24 Dec 2025 10:57:00 +0100 Subject: [PATCH 048/418] Update `test_threading_local.py` from 3.13.11 (#6482) * Update `test_threading_local.py` from 3.13.11 * Mark failing test --- Lib/test/test_threading_local.py | 62 ++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py index 3443e3875d0..99052de4c7f 100644 --- a/Lib/test/test_threading_local.py +++ 
b/Lib/test/test_threading_local.py @@ -3,8 +3,8 @@ from doctest import DocTestSuite from test import support from test.support import threading_helper +from test.support.import_helper import import_module import weakref -import gc # Modules under test import _thread @@ -12,6 +12,9 @@ import _threading_local +threading_helper.requires_working_threading(module=True) + + class Weak(object): pass @@ -23,7 +26,7 @@ def target(local, weaklist): class BaseLocalTest: - @unittest.skip("TODO: RUSTPYTHON, flaky test") + @unittest.skip('TODO: RUSTPYTHON; flaky test') def test_local_refs(self): self._local_refs(20) self._local_refs(50) @@ -182,8 +185,7 @@ class LocalSubclass(self._local): """To test that subclasses behave properly.""" self._test_dict_attribute(LocalSubclass) - # TODO: RUSTPYTHON, cycle detection/collection - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; cycle detection/collection def test_cycle_collection(self): class X: pass @@ -197,35 +199,57 @@ class X: self.assertIsNone(wr()) + def test_threading_local_clear_race(self): + # See https://github.com/python/cpython/issues/100892 + + _testcapi = import_module('_testcapi') + _testcapi.call_in_temporary_c_thread(lambda: None, False) + + for _ in range(1000): + _ = threading.local() + + _testcapi.join_temporary_c_thread() + + @support.cpython_only + def test_error(self): + class Loop(self._local): + attr = 1 + + # Trick the "if name == '__dict__':" test of __setattr__() + # to always be true + class NameCompareTrue: + def __eq__(self, other): + return True + + loop = Loop() + with self.assertRaisesRegex(AttributeError, 'Loop.*read-only'): + loop.__setattr__(NameCompareTrue(), 2) + + class ThreadLocalTest(unittest.TestCase, BaseLocalTest): _local = _thread._local - # TODO: RUSTPYTHON, __new__ vs __init__ cooperation - @unittest.expectedFailure - def test_arguments(): - super().test_arguments() - + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: TypeError not raised by _local 
+ def test_arguments(self): + return super().test_arguments() class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest): _local = _threading_local.local -def test_main(): - suite = unittest.TestSuite() - suite.addTest(DocTestSuite('_threading_local')) - suite.addTest(unittest.makeSuite(ThreadLocalTest)) - suite.addTest(unittest.makeSuite(PyThreadingLocalTest)) +def load_tests(loader, tests, pattern): + tests.addTest(DocTestSuite('_threading_local')) local_orig = _threading_local.local def setUp(test): _threading_local.local = _thread._local def tearDown(test): _threading_local.local = local_orig - suite.addTest(DocTestSuite('_threading_local', - setUp=setUp, tearDown=tearDown) - ) + tests.addTests(DocTestSuite('_threading_local', + setUp=setUp, tearDown=tearDown) + ) + return tests - support.run_unittest(suite) if __name__ == '__main__': - test_main() + unittest.main() From 7c3bc5ed8de973236d64e084dcc102f53fddd272 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Wed, 24 Dec 2025 11:45:48 +0100 Subject: [PATCH 049/418] Update `test_http_cookies.py` from 3.13.11 (#6481) * Update `test_http_cookies.py` from 3.13.11 * Mark failing test * Update `http/cookies.py` from 3.13.11 * Unmark passing test * Lower amount --- Lib/http/cookies.py | 42 +++++--------- Lib/test/test_http_cookies.py | 102 +++++++++++++++++++++++++++++++--- 2 files changed, 108 insertions(+), 36 deletions(-) diff --git a/Lib/http/cookies.py b/Lib/http/cookies.py index 35ac2dc6ae2..57791c6ab08 100644 --- a/Lib/http/cookies.py +++ b/Lib/http/cookies.py @@ -184,8 +184,13 @@ def _quote(str): return '"' + str.translate(_Translator) + '"' -_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]") -_QuotePatt = re.compile(r"[\\].") +_unquote_sub = re.compile(r'\\(?:([0-3][0-7][0-7])|(.))').sub + +def _unquote_replace(m): + if m[1]: + return chr(int(m[1], 8)) + else: + return m[2] def _unquote(str): # If there aren't any doublequotes, @@ -205,36 +210,13 @@ def 
_unquote(str): # \012 --> \n # \" --> " # - i = 0 - n = len(str) - res = [] - while 0 <= i < n: - o_match = _OctalPatt.search(str, i) - q_match = _QuotePatt.search(str, i) - if not o_match and not q_match: # Neither matched - res.append(str[i:]) - break - # else: - j = k = -1 - if o_match: - j = o_match.start(0) - if q_match: - k = q_match.start(0) - if q_match and (not o_match or k < j): # QuotePatt matched - res.append(str[i:k]) - res.append(str[k+1]) - i = k + 2 - else: # OctalPatt matched - res.append(str[i:j]) - res.append(chr(int(str[j+1:j+4], 8))) - i = j + 4 - return _nulljoin(res) + return _unquote_sub(_unquote_replace, str) # The _getdate() routine is used to set the expiration time in the cookie's HTTP # header. By default, _getdate() returns the current time in the appropriate # "expires" format for a Set-Cookie header. The one optional argument is an # offset from now, in seconds. For example, an offset of -3600 means "one hour -# ago". The offset may be a floating point number. +# ago". The offset may be a floating-point number. # _weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] @@ -442,9 +424,11 @@ def OutputString(self, attrs=None): ( # Optional group: there may not be a value. 
\s*=\s* # Equal Sign (?P # Start of group 'val' - "(?:[^\\"]|\\.)*" # Any doublequoted string + "(?:[^\\"]|\\.)*" # Any double-quoted string | # or - \w{3},\s[\w\d\s-]{9,11}\s[\d:]{8}\sGMT # Special case for "expires" attr + # Special case for "expires" attr + (\w{3,6}day|\w{3}),\s # Day of the week or abbreviated day + [\w\d\s-]{9,11}\s[\d:]{8}\sGMT # Date and time in specific format | # or [""" + _LegalValueChars + r"""]* # Any word or empty string ) # End of group 'val' diff --git a/Lib/test/test_http_cookies.py b/Lib/test/test_http_cookies.py index 6072c7e15e9..3e0b4d1d5ca 100644 --- a/Lib/test/test_http_cookies.py +++ b/Lib/test/test_http_cookies.py @@ -1,13 +1,15 @@ # Simple test suite for http/cookies.py import copy -from test.support import run_unittest, run_doctest import unittest +import doctest from http import cookies import pickle +from test import support +from test.support.testcase import ExtraAssertions -class CookieTests(unittest.TestCase): +class CookieTests(unittest.TestCase, ExtraAssertions): def test_basic(self): cases = [ @@ -58,6 +60,90 @@ def test_basic(self): for k, v in sorted(case['dict'].items()): self.assertEqual(C[k].value, v) + def test_obsolete_rfc850_date_format(self): + # Test cases with different days and dates in obsolete RFC 850 format + test_cases = [ + # from RFC 850, change EST to GMT + # https://datatracker.ietf.org/doc/html/rfc850#section-2 + { + 'data': 'key=value; expires=Saturday, 01-Jan-83 00:00:00 GMT', + 'output': 'Saturday, 01-Jan-83 00:00:00 GMT' + }, + { + 'data': 'key=value; expires=Friday, 19-Nov-82 16:59:30 GMT', + 'output': 'Friday, 19-Nov-82 16:59:30 GMT' + }, + # from RFC 9110 + # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.6.7-6 + { + 'data': 'key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT', + 'output': 'Sunday, 06-Nov-94 08:49:37 GMT' + }, + # other test cases + { + 'data': 'key=value; expires=Wednesday, 09-Nov-94 08:49:37 GMT', + 'output': 'Wednesday, 09-Nov-94 08:49:37 GMT' + }, + { + 'data': 
'key=value; expires=Friday, 11-Nov-94 08:49:37 GMT', + 'output': 'Friday, 11-Nov-94 08:49:37 GMT' + }, + { + 'data': 'key=value; expires=Monday, 14-Nov-94 08:49:37 GMT', + 'output': 'Monday, 14-Nov-94 08:49:37 GMT' + }, + ] + + for case in test_cases: + with self.subTest(data=case['data']): + C = cookies.SimpleCookie() + C.load(case['data']) + + # Extract the cookie name from the data string + cookie_name = case['data'].split('=')[0] + + # Check if the cookie is loaded correctly + self.assertIn(cookie_name, C) + self.assertEqual(C[cookie_name].get('expires'), case['output']) + + def test_unquote(self): + cases = [ + (r'a="b=\""', 'b="'), + (r'a="b=\\"', 'b=\\'), + (r'a="b=\="', 'b=='), + (r'a="b=\n"', 'b=n'), + (r'a="b=\042"', 'b="'), + (r'a="b=\134"', 'b=\\'), + (r'a="b=\377"', 'b=\xff'), + (r'a="b=\400"', 'b=400'), + (r'a="b=\42"', 'b=42'), + (r'a="b=\\042"', 'b=\\042'), + (r'a="b=\\134"', 'b=\\134'), + (r'a="b=\\\""', 'b=\\"'), + (r'a="b=\\\042"', 'b=\\"'), + (r'a="b=\134\""', 'b=\\"'), + (r'a="b=\134\042"', 'b=\\"'), + ] + for encoded, decoded in cases: + with self.subTest(encoded): + C = cookies.SimpleCookie() + C.load(encoded) + self.assertEqual(C['a'].value, decoded) + + @support.requires_resource('cpu') + def test_unquote_large(self): + #n = 10**6 + n = 10**4 # XXX: RUSTPYTHON; This takes more than 10 minutes to run. 
lower to 4 + for encoded in r'\\', r'\134': + with self.subTest(encoded): + data = 'a="b=' + encoded*n + ';"' + C = cookies.SimpleCookie() + C.load(data) + value = C['a'].value + self.assertEqual(value[:3], 'b=\\') + self.assertEqual(value[-2:], '\\;') + self.assertEqual(len(value), n + 3) + def test_load(self): C = cookies.SimpleCookie() C.load('Customer="WILE_E_COYOTE"; Version=1; Path=/acme') @@ -96,7 +182,7 @@ def test_special_attrs(self): C = cookies.SimpleCookie('Customer="WILE_E_COYOTE"') C['Customer']['expires'] = 0 # can't test exact output, it always depends on current date/time - self.assertTrue(C.output().endswith('GMT')) + self.assertEndsWith(C.output(), 'GMT') # loading 'expires' C = cookies.SimpleCookie() @@ -479,9 +565,11 @@ def test_repr(self): r'Set-Cookie: key=coded_val; ' r'expires=\w+, \d+ \w+ \d+ \d+:\d+:\d+ \w+') -def test_main(): - run_unittest(CookieTests, MorselTests) - run_doctest(cookies) + +def load_tests(loader, tests, pattern): + tests.addTest(doctest.DocTestSuite(cookies)) + return tests + if __name__ == '__main__': - test_main() + unittest.main() From a4e60f569e6c1c9c5ff1c89d76e786f2ac08a827 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 24 Dec 2025 17:48:41 +0900 Subject: [PATCH 050/418] repr --- crates/vm/src/builtins/descriptor.rs | 14 ++++++++++++-- crates/vm/src/class.rs | 14 ++++++++++++++ crates/vm/src/types/slot.rs | 6 ------ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index 5cd54bb753d..c251ac45ee3 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -7,7 +7,7 @@ use crate::{ function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, types::{ Callable, Comparable, GetDescriptor, HashFunc, Hashable, InitFunc, PyComparisonOp, - Representable, + Representable, StringifyFunc, }, }; use rustpython_common::lock::PyRwLock; @@ -398,6 +398,7 @@ pub fn init(ctx: &Context) { pub enum 
SlotFunc { Init(InitFunc), Hash(HashFunc), + Repr(StringifyFunc), } impl std::fmt::Debug for SlotFunc { @@ -405,6 +406,7 @@ impl std::fmt::Debug for SlotFunc { match self { SlotFunc::Init(_) => write!(f, "SlotFunc::Init(...)"), SlotFunc::Hash(_) => write!(f, "SlotFunc::Hash(...)"), + SlotFunc::Repr(_) => write!(f, "SlotFunc::Repr(...)"), } } } @@ -426,6 +428,15 @@ impl SlotFunc { let hash = func(&obj, vm)?; Ok(vm.ctx.new_int(hash).into()) } + SlotFunc::Repr(func) => { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err( + vm.new_type_error("__repr__() takes no arguments (1 given)".to_owned()) + ); + } + let s = func(&obj, vm)?; + Ok(s.into()) + } } } } @@ -456,7 +467,6 @@ impl GetDescriptor for PySlotWrapper { ) -> PyResult { match obj { None => Ok(zelf), - Some(obj) if vm.is_none(&obj) => Ok(zelf), Some(obj) => { let zelf = zelf.downcast::().unwrap(); Ok(PyMethodWrapper { wrapper: zelf, obj }.into_pyobject(vm)) diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index f258281712f..236967dd36e 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -174,6 +174,20 @@ pub trait PyClassImpl: PyClassDef { class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); } + // Add __repr__ slot wrapper if slot exists and not already in dict + if let Some(repr_func) = class.slots.repr.load() { + let repr_name = identifier!(ctx, __repr__); + if !class.attributes.read().contains_key(repr_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str("__repr__"), + wrapped: SlotFunc::Repr(repr_func), + doc: Some("Return repr(self)."), + }; + class.set_attr(repr_name, wrapper.into_ref(ctx).into()); + } + } + class.extend_methods(class.slots.methods, ctx); } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 2954d08abf0..6e98da173a4 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -1114,12 +1114,6 @@ pub trait Representable: PyPayload { Self::repr(zelf, vm) } - #[inline] - 
#[pymethod] - fn __repr__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - Self::slot_repr(&zelf, vm) - } - #[inline] fn repr(zelf: &Py, vm: &VirtualMachine) -> PyResult> { let repr = Self::repr_str(zelf, vm)?; From c2a739319149052f3bc03f873f2e9840571ff541 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 24 Dec 2025 21:36:42 +0900 Subject: [PATCH 051/418] uniform __str__ --- crates/stdlib/src/ssl/error.rs | 6 +- crates/vm/src/builtins/str.rs | 16 ++-- crates/vm/src/builtins/weakproxy.rs | 4 +- crates/vm/src/exception_group.rs | 6 +- crates/vm/src/exceptions.rs | 138 +++++++++++++++------------- crates/vm/src/stdlib/io.rs | 6 +- crates/vm/src/stdlib/winreg.rs | 16 ++-- 7 files changed, 104 insertions(+), 88 deletions(-) diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index 879275228ec..d77910f6aa1 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -5,8 +5,8 @@ pub(crate) use ssl_error::*; #[pymodule(sub)] pub(crate) mod ssl_error { use crate::vm::{ - PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyBaseExceptionRef, PyOSError, PyStrRef}, + Py, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyBaseException, PyOSError, PyStrRef}, types::{Constructor, Initializer}, }; @@ -42,7 +42,7 @@ pub(crate) mod ssl_error { impl PySSLError { // Returns strerror attribute if available, otherwise str(args) #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { + fn __str__(exc: &Py, vm: &VirtualMachine) -> PyResult { use crate::vm::AsObject; // Try to get strerror attribute first (OSError compatibility) if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 8084c4d053e..95b41e6d55c 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -529,7 +529,6 @@ impl Py { #[pyclass( flags(BASETYPE, _MATCH_SELF), with( - PyRef, AsMapping, AsNumber, 
AsSequence, @@ -1448,15 +1447,16 @@ impl PyStr { fn __getnewargs__(zelf: PyRef, vm: &VirtualMachine) -> PyObjectRef { (zelf.as_str(),).to_pyobject(vm) } -} -#[pyclass] -impl PyRef { #[pymethod] - fn __str__(self, vm: &VirtualMachine) -> PyRefExact { - self.into_exact_or(&vm.ctx, |zelf| { - PyStr::from(zelf.data.clone()).into_exact_ref(&vm.ctx) - }) + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + if zelf.class().is(vm.ctx.types.str_type) { + // Already exact str, just return a reference + Ok(zelf.to_owned()) + } else { + // Subclass, create a new exact str + Ok(PyStr::from(zelf.data.clone()).into_ref(&vm.ctx)) + } } } diff --git a/crates/vm/src/builtins/weakproxy.rs b/crates/vm/src/builtins/weakproxy.rs index a9221ec876f..6e0e8308dbc 100644 --- a/crates/vm/src/builtins/weakproxy.rs +++ b/crates/vm/src/builtins/weakproxy.rs @@ -79,8 +79,8 @@ impl PyWeakProxy { } #[pymethod] - fn __str__(&self, vm: &VirtualMachine) -> PyResult { - self.try_upgrade(vm)?.str(vm) + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + zelf.try_upgrade(vm)?.str(vm) } fn len(&self, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/exception_group.rs b/crates/vm/src/exception_group.rs index e19dbceb8da..645a3e779ff 100644 --- a/crates/vm/src/exception_group.rs +++ b/crates/vm/src/exception_group.rs @@ -182,7 +182,7 @@ pub(super) mod types { } #[pymethod] - fn __str__(zelf: PyRef, vm: &VirtualMachine) -> PyResult { + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { let message = zelf .get_arg(0) .map(|m| m.str(vm)) @@ -196,10 +196,10 @@ pub(super) mod types { .unwrap_or(0); let suffix = if num_excs == 1 { "" } else { "s" }; - Ok(format!( + Ok(vm.ctx.new_str(format!( "{} ({} sub-exception{})", message, num_excs, suffix - )) + ))) } #[pymethod] diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 2c36aa13bd5..3a932a5df54 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -560,7 +560,7 @@ impl 
PyBaseException { } #[pyclass( - with(PyRef, Constructor, Initializer, Representable), + with(Py, PyRef, Constructor, Initializer, Representable), flags(BASETYPE, HAS_DICT) )] impl PyBaseException { @@ -633,15 +633,18 @@ impl PyBaseException { fn set_suppress_context(&self, suppress_context: bool) { self.suppress_context.store(suppress_context); } +} +#[pyclass] +impl Py { #[pymethod] - pub(super) fn __str__(&self, vm: &VirtualMachine) -> PyStrRef { + pub(super) fn __str__(&self, vm: &VirtualMachine) -> PyResult { let str_args = vm.exception_args_as_string(self.args(), true); - match str_args.into_iter().exactly_one() { + Ok(match str_args.into_iter().exactly_one() { Err(i) if i.len() == 0 => vm.ctx.empty_str.to_owned(), Ok(s) => s, Err(i) => PyStr::from(format!("({})", i.format(", "))).into_ref(&vm.ctx), - } + }) } } @@ -1527,16 +1530,16 @@ pub(super) mod types { #[pyexception] impl PyKeyError { #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyStrRef { - let args = exc.args(); - if args.len() == 1 { + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let args = zelf.args(); + Ok(if args.len() == 1 { vm.exception_args_as_string(args, false) .into_iter() .exactly_one() .unwrap() } else { - exc.__str__(vm) - } + zelf.__str__(vm)? 
+ }) } } @@ -1731,8 +1734,8 @@ pub(super) mod types { #[pyexception(with(Constructor, Initializer))] impl PyOSError { #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - let obj = exc.as_object().to_owned(); + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let obj = zelf.as_object(); // Get OSError fields directly let errno_field = obj.get_attr("errno", vm).ok().filter(|v| !vm.is_none(v)); @@ -1819,7 +1822,7 @@ pub(super) mod types { } // fallback to BaseException.__str__ - Ok(exc.__str__(vm)) + zelf.__str__(vm) } #[pymethod] @@ -2026,7 +2029,7 @@ pub(super) mod types { #[pyexception(with(Initializer))] impl PySyntaxError { #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyStrRef { + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { fn basename(filename: &str) -> &str { let splitted = if cfg!(windows) { filename.rsplit(&['/', '\\']).next() @@ -2036,16 +2039,16 @@ pub(super) mod types { splitted.unwrap_or(filename) } - let maybe_lineno = exc.as_object().get_attr("lineno", vm).ok().map(|obj| { + let maybe_lineno = zelf.as_object().get_attr("lineno", vm).ok().map(|obj| { obj.str(vm) .unwrap_or_else(|_| vm.ctx.new_str("")) }); - let maybe_filename = exc.as_object().get_attr("filename", vm).ok().map(|obj| { + let maybe_filename = zelf.as_object().get_attr("filename", vm).ok().map(|obj| { obj.str(vm) .unwrap_or_else(|_| vm.ctx.new_str("")) }); - let args = exc.args(); + let args = zelf.args(); let msg = if args.len() == 1 { vm.exception_args_as_string(args, false) @@ -2053,7 +2056,7 @@ pub(super) mod types { .exactly_one() .unwrap() } else { - return exc.__str__(vm); + return zelf.__str__(vm); }; let msg_with_location_info: String = match (maybe_lineno, maybe_filename) { @@ -2069,7 +2072,7 @@ pub(super) mod types { (None, None) => msg.to_string(), }; - vm.ctx.new_str(msg_with_location_info) + Ok(vm.ctx.new_str(msg_with_location_info)) } } @@ -2178,29 +2181,32 @@ pub(super) mod types { 
#[pyexception(with(Initializer))] impl PyUnicodeDecodeError { #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - let Ok(object) = exc.as_object().get_attr("object", vm) else { - return Ok("".to_owned()); + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let Ok(object) = zelf.as_object().get_attr("object", vm) else { + return Ok(vm.ctx.empty_str.to_owned()); }; let object: ArgBytesLike = object.try_into_value(vm)?; - let encoding: PyStrRef = exc + let encoding: PyStrRef = zelf .as_object() .get_attr("encoding", vm)? .try_into_value(vm)?; - let start: usize = exc.as_object().get_attr("start", vm)?.try_into_value(vm)?; - let end: usize = exc.as_object().get_attr("end", vm)?.try_into_value(vm)?; - let reason: PyStrRef = exc.as_object().get_attr("reason", vm)?.try_into_value(vm)?; - if start < object.len() && end <= object.len() && end == start + 1 { + let start: usize = zelf.as_object().get_attr("start", vm)?.try_into_value(vm)?; + let end: usize = zelf.as_object().get_attr("end", vm)?.try_into_value(vm)?; + let reason: PyStrRef = zelf + .as_object() + .get_attr("reason", vm)? 
+ .try_into_value(vm)?; + Ok(vm.ctx.new_str(if start < object.len() && end <= object.len() && end == start + 1 { let b = object.borrow_buf()[start]; - Ok(format!( + format!( "'{encoding}' codec can't decode byte {b:#02x} in position {start}: {reason}" - )) + ) } else { - Ok(format!( + format!( "'{encoding}' codec can't decode bytes in position {start}-{}: {reason}", end - 1, - )) - } + ) + })) } } @@ -2232,30 +2238,33 @@ pub(super) mod types { #[pyexception(with(Initializer))] impl PyUnicodeEncodeError { #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - let Ok(object) = exc.as_object().get_attr("object", vm) else { - return Ok("".to_owned()); + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let Ok(object) = zelf.as_object().get_attr("object", vm) else { + return Ok(vm.ctx.empty_str.to_owned()); }; let object: PyStrRef = object.try_into_value(vm)?; - let encoding: PyStrRef = exc + let encoding: PyStrRef = zelf .as_object() .get_attr("encoding", vm)? .try_into_value(vm)?; - let start: usize = exc.as_object().get_attr("start", vm)?.try_into_value(vm)?; - let end: usize = exc.as_object().get_attr("end", vm)?.try_into_value(vm)?; - let reason: PyStrRef = exc.as_object().get_attr("reason", vm)?.try_into_value(vm)?; - if start < object.char_len() && end <= object.char_len() && end == start + 1 { + let start: usize = zelf.as_object().get_attr("start", vm)?.try_into_value(vm)?; + let end: usize = zelf.as_object().get_attr("end", vm)?.try_into_value(vm)?; + let reason: PyStrRef = zelf + .as_object() + .get_attr("reason", vm)? 
+ .try_into_value(vm)?; + Ok(vm.ctx.new_str(if start < object.char_len() && end <= object.char_len() && end == start + 1 { let ch = object.as_wtf8().code_points().nth(start).unwrap(); - Ok(format!( + format!( "'{encoding}' codec can't encode character '{}' in position {start}: {reason}", UnicodeEscapeCodepoint(ch) - )) + ) } else { - Ok(format!( + format!( "'{encoding}' codec can't encode characters in position {start}-{}: {reason}", end - 1, - )) - } + ) + })) } } @@ -2286,26 +2295,31 @@ pub(super) mod types { #[pyexception(with(Initializer))] impl PyUnicodeTranslateError { #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - let Ok(object) = exc.as_object().get_attr("object", vm) else { - return Ok("".to_owned()); + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let Ok(object) = zelf.as_object().get_attr("object", vm) else { + return Ok(vm.ctx.empty_str.to_owned()); }; let object: PyStrRef = object.try_into_value(vm)?; - let start: usize = exc.as_object().get_attr("start", vm)?.try_into_value(vm)?; - let end: usize = exc.as_object().get_attr("end", vm)?.try_into_value(vm)?; - let reason: PyStrRef = exc.as_object().get_attr("reason", vm)?.try_into_value(vm)?; - if start < object.char_len() && end <= object.char_len() && end == start + 1 { - let ch = object.as_wtf8().code_points().nth(start).unwrap(); - Ok(format!( - "can't translate character '{}' in position {start}: {reason}", - UnicodeEscapeCodepoint(ch) - )) - } else { - Ok(format!( - "can't translate characters in position {start}-{}: {reason}", - end - 1, - )) - } + let start: usize = zelf.as_object().get_attr("start", vm)?.try_into_value(vm)?; + let end: usize = zelf.as_object().get_attr("end", vm)?.try_into_value(vm)?; + let reason: PyStrRef = zelf + .as_object() + .get_attr("reason", vm)? 
+ .try_into_value(vm)?; + Ok(vm.ctx.new_str( + if start < object.char_len() && end <= object.char_len() && end == start + 1 { + let ch = object.as_wtf8().code_points().nth(start).unwrap(); + format!( + "can't translate character '{}' in position {start}: {reason}", + UnicodeEscapeCodepoint(ch) + ) + } else { + format!( + "can't translate characters in position {start}-{}: {reason}", + end - 1, + ) + }, + )) } } diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index 2c3c02ec65c..9402660a86e 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -1466,7 +1466,11 @@ mod _io { let zelf: PyRef = zelf.try_into_value(vm)?; let (raw, BufferSize { buffer_size }): (PyObjectRef, _) = args.bind(vm).map_err(|e| { - let msg = format!("{}() {}", Self::CLASS_NAME, *e.__str__(vm)); + let str_repr = e + .__str__(vm) + .map(|s| s.as_str().to_owned()) + .unwrap_or_else(|_| "".to_owned()); + let msg = format!("{}() {}", Self::CLASS_NAME, str_repr); vm.new_exception_msg(e.class().to_owned(), msg) })?; zelf.init(raw, BufferSize { buffer_size }, vm) diff --git a/crates/vm/src/stdlib/winreg.rs b/crates/vm/src/stdlib/winreg.rs index a7619025866..b5e568fce6d 100644 --- a/crates/vm/src/stdlib/winreg.rs +++ b/crates/vm/src/stdlib/winreg.rs @@ -9,7 +9,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule] mod winreg { - use crate::builtins::{PyInt, PyTuple, PyTypeRef}; + use crate::builtins::{PyInt, PyStr, PyTuple, PyTypeRef}; use crate::common::hash::PyHash; use crate::common::windows::ToWideString; use crate::convert::TryFromObject; @@ -233,8 +233,8 @@ mod winreg { } #[pymethod] - fn __str__(&self) -> String { - format!("", self.hkey.load()) + fn __str__(zelf: &Py, vm: &VirtualMachine) -> PyResult> { + Ok(vm.ctx.new_str(format!("", zelf.hkey.load()))) } } @@ -1029,7 +1029,7 @@ mod winreg { return Ok(Some(vec![0u8, 0u8])); } let s = value - .downcast::() + .downcast::() .map_err(|_| vm.new_type_error("value must be a 
string".to_string()))?; let wide = s.as_str().to_wide_with_nul(); // Convert Vec to Vec @@ -1047,11 +1047,9 @@ mod winreg { let mut bytes: Vec = Vec::new(); for item in list.borrow_vec().iter() { - let s = item - .downcast_ref::() - .ok_or_else(|| { - vm.new_type_error("list items must be strings".to_string()) - })?; + let s = item.downcast_ref::().ok_or_else(|| { + vm.new_type_error("list items must be strings".to_string()) + })?; let wide = s.as_str().to_wide_with_nul(); bytes.extend(wide.iter().flat_map(|&c| c.to_le_bytes())); } From cbde5ce3218066e89a61f42dd429ab8c2fe77321 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 24 Dec 2025 23:35:35 +0900 Subject: [PATCH 052/418] iter with slot-wrapper (#6488) --- crates/vm/src/builtins/descriptor.rs | 25 +++++++- crates/vm/src/class.rs | 86 ++++++++++++++-------------- crates/vm/src/types/slot.rs | 17 ++---- 3 files changed, 69 insertions(+), 59 deletions(-) diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index c251ac45ee3..fcf40d0082d 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -4,10 +4,11 @@ use crate::{ builtins::{PyTypeRef, builtin_func::PyNativeMethod, type_}, class::PyClassImpl, common::hash::PyHash, + convert::ToPyResult, function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, types::{ - Callable, Comparable, GetDescriptor, HashFunc, Hashable, InitFunc, PyComparisonOp, - Representable, StringifyFunc, + Callable, Comparable, GetDescriptor, HashFunc, Hashable, InitFunc, IterFunc, IterNextFunc, + PyComparisonOp, Representable, StringifyFunc, }, }; use rustpython_common::lock::PyRwLock; @@ -399,6 +400,8 @@ pub enum SlotFunc { Init(InitFunc), Hash(HashFunc), Repr(StringifyFunc), + Iter(IterFunc), + IterNext(IterNextFunc), } impl std::fmt::Debug for SlotFunc { @@ -407,6 +410,8 @@ impl std::fmt::Debug for SlotFunc { SlotFunc::Init(_) => write!(f, 
"SlotFunc::Init(...)"), SlotFunc::Hash(_) => write!(f, "SlotFunc::Hash(...)"), SlotFunc::Repr(_) => write!(f, "SlotFunc::Repr(...)"), + SlotFunc::Iter(_) => write!(f, "SlotFunc::Iter(...)"), + SlotFunc::IterNext(_) => write!(f, "SlotFunc::IterNext(...)"), } } } @@ -437,6 +442,22 @@ impl SlotFunc { let s = func(&obj, vm)?; Ok(s.into()) } + SlotFunc::Iter(func) => { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err( + vm.new_type_error("__iter__() takes no arguments (1 given)".to_owned()) + ); + } + func(obj, vm) + } + SlotFunc::IterNext(func) => { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err( + vm.new_type_error("__next__() takes no arguments (1 given)".to_owned()) + ); + } + func(&obj, vm).to_pyresult(vm) + } } } } diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 236967dd36e..1addf00497d 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -139,52 +139,50 @@ pub trait PyClassImpl: PyClassDef { } } - // Add __init__ slot wrapper if slot exists and not already in dict - if let Some(init_func) = class.slots.init.load() { - let init_name = identifier!(ctx, __init__); - if !class.attributes.read().contains_key(init_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__init__"), - wrapped: SlotFunc::Init(init_func), - doc: Some("Initialize self. 
See help(type(self)) for accurate signature."), - }; - class.set_attr(init_name, wrapper.into_ref(ctx).into()); - } - } - - // Add __hash__ slot wrapper if slot exists and not already in dict - // Note: hash_not_implemented is handled separately (sets __hash__ = None) - if let Some(hash_func) = class.slots.hash.load() - && hash_func as usize != hash_not_implemented as usize - { - let hash_name = identifier!(ctx, __hash__); - if !class.attributes.read().contains_key(hash_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__hash__"), - wrapped: SlotFunc::Hash(hash_func), - doc: Some("Return hash(self)."), - }; - class.set_attr(hash_name, wrapper.into_ref(ctx).into()); - } - } - - if class.slots.hash.load().map_or(0, |h| h as usize) == hash_not_implemented as usize { - class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); + // Add slot wrappers for slots that exist and are not already in dict + // This mirrors CPython's add_operators() in typeobject.c + macro_rules! add_slot_wrapper { + ($slot:ident, $name:ident, $variant:ident, $doc:expr) => { + if let Some(func) = class.slots.$slot.load() { + let attr_name = identifier!(ctx, $name); + if !class.attributes.read().contains_key(attr_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str(stringify!($name)), + wrapped: SlotFunc::$variant(func), + doc: Some($doc), + }; + class.set_attr(attr_name, wrapper.into_ref(ctx).into()); + } + } + }; } - // Add __repr__ slot wrapper if slot exists and not already in dict - if let Some(repr_func) = class.slots.repr.load() { - let repr_name = identifier!(ctx, __repr__); - if !class.attributes.read().contains_key(repr_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__repr__"), - wrapped: SlotFunc::Repr(repr_func), - doc: Some("Return repr(self)."), - }; - class.set_attr(repr_name, wrapper.into_ref(ctx).into()); + add_slot_wrapper!( + init, + __init__, + Init, + "Initialize self. 
See help(type(self)) for accurate signature." + ); + add_slot_wrapper!(repr, __repr__, Repr, "Return repr(self)."); + add_slot_wrapper!(iter, __iter__, Iter, "Implement iter(self)."); + add_slot_wrapper!(iternext, __next__, IterNext, "Implement next(self)."); + + // __hash__ needs special handling: hash_not_implemented sets __hash__ = None + if let Some(hash_func) = class.slots.hash.load() { + if hash_func as usize == hash_not_implemented as usize { + class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); + } else { + let hash_name = identifier!(ctx, __hash__); + if !class.attributes.read().contains_key(hash_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str("__hash__"), + wrapped: SlotFunc::Hash(hash_func), + doc: Some("Return hash(self)."), + }; + class.set_attr(hash_name, wrapper.into_ref(ctx).into()); + } } } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 6e98da173a4..26c059067e0 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -6,7 +6,7 @@ use crate::{ builtins::{PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyTypeRef, type_::PointerSlot}, bytecode::ComparisonOperator, common::hash::PyHash, - convert::{ToPyObject, ToPyResult}, + convert::ToPyObject, function::{ Either, FromArgs, FuncArgs, OptionalArg, PyComparisonValue, PyMethodDef, PySetterValue, }, @@ -1435,10 +1435,7 @@ pub trait Iterable: PyPayload { Self::iter(zelf, vm) } - #[pymethod] - fn __iter__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - Self::slot_iter(zelf, vm) - } + // __iter__ is exposed via SlotFunc::Iter wrapper in extend_class() fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult; @@ -1458,11 +1455,7 @@ pub trait IterNext: PyPayload + Iterable { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult; - #[inline] - #[pymethod] - fn __next__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - Self::slot_iternext(&zelf, vm).to_pyresult(vm) - } + // __next__ is exposed via SlotFunc::IterNext 
wrapper in extend_class() } pub trait SelfIter: PyPayload {} @@ -1477,9 +1470,7 @@ where unreachable!("slot must be overridden for {}", repr.as_str()); } - fn __iter__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self_iter(zelf, vm) - } + // __iter__ is exposed via SlotFunc::Iter wrapper in extend_class() #[cold] fn iter(_zelf: PyRef, _vm: &VirtualMachine) -> PyResult { From be9e44aafbb5410adc26fe582613bc24dd6da277 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 09:08:18 +0900 Subject: [PATCH 053/418] Allow SyntaxError.msg to be writable and reflected in string formatting (#6493) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> --- DEVELOPMENT.md | 3 +- crates/vm/src/exceptions.rs | 45 ++++++++++++++++------ extra_tests/snippets/builtin_exceptions.py | 7 ++++ 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index aa7d99eef33..d5c675faca6 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -130,7 +130,8 @@ repository's structure: - `stdlib`: Standard library parts implemented in rust. - `src`: using the other subcrates to bring rustpython to life. - `wasm`: Binary crate and resources for WebAssembly build -- `extra_tests`: extra integration test snippets as a supplement to `Lib/test` +- `extra_tests`: extra integration test snippets as a supplement to `Lib/test`. + Add new RustPython-only regression tests here; do not place new tests under `Lib/test`. 
## Understanding Internals diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 3a932a5df54..8d6ce6142b4 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -9,7 +9,7 @@ use crate::{ }, class::{PyClassImpl, StaticType}, convert::{ToPyException, ToPyObject}, - function::{ArgIterable, FuncArgs, IntoFuncArgs}, + function::{ArgIterable, FuncArgs, IntoFuncArgs, PySetterValue}, py_io::{self, Write}, stdlib::sys, suggestion::offer_suggestions, @@ -994,7 +994,12 @@ impl ExceptionZoo { extend_exception!(PyRecursionError, ctx, excs.recursion_error); extend_exception!(PySyntaxError, ctx, excs.syntax_error, { - "msg" => ctx.new_readonly_getset("msg", excs.syntax_error, make_arg_getter(0)), + "msg" => ctx.new_static_getset( + "msg", + excs.syntax_error, + make_arg_getter(0), + syntax_error_set_msg, + ), // TODO: members "filename" => ctx.none(), "lineno" => ctx.none(), @@ -1041,6 +1046,25 @@ fn make_arg_getter(idx: usize) -> impl Fn(PyBaseExceptionRef) -> Option PyResult<()> { + let mut args = exc.args.write(); + let mut new_args = args.as_slice().to_vec(); + // Ensure the message slot at index 0 always exists for SyntaxError.args. 
+ if new_args.is_empty() { + new_args.push(vm.ctx.none()); + } + match value { + PySetterValue::Assign(value) => new_args[0] = value, + PySetterValue::Delete => new_args[0] = vm.ctx.none(), + } + *args = PyTuple::new_ref(new_args, &vm.ctx); + Ok(()) +} + fn system_exit_code(exc: PyBaseExceptionRef) -> Option { exc.args.read().first().map(|code| { match_class!(match code { @@ -2048,15 +2072,14 @@ pub(super) mod types { .unwrap_or_else(|_| vm.ctx.new_str("")) }); - let args = zelf.args(); - - let msg = if args.len() == 1 { - vm.exception_args_as_string(args, false) - .into_iter() - .exactly_one() - .unwrap() - } else { - return zelf.__str__(vm); + let msg = match zelf.as_object().get_attr("msg", vm) { + Ok(obj) => obj + .str(vm) + .unwrap_or_else(|_| vm.ctx.new_str("")), + Err(_) => { + // Fallback to the base formatting if the msg attribute was deleted or attribute lookup fails for any reason. + return Py::::__str__(zelf, vm); + } }; let msg_with_location_info: String = match (maybe_lineno, maybe_filename) { diff --git a/extra_tests/snippets/builtin_exceptions.py b/extra_tests/snippets/builtin_exceptions.py index 490f831f522..0af5cf05eaf 100644 --- a/extra_tests/snippets/builtin_exceptions.py +++ b/extra_tests/snippets/builtin_exceptions.py @@ -85,6 +85,13 @@ def __init__(self, value): assert exc.offset is None assert exc.text is None +err = SyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg")) +err.msg = "changed" +assert err.msg == "changed" +assert str(err) == "changed (bad.py, line 1)" +del err.msg +assert err.msg is None + # Regression to: # https://github.com/RustPython/RustPython/issues/2779 From 72cf6c36d53148903c5665cb18de04c490a2f264 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 09:09:58 +0900 Subject: [PATCH 054/418] Add missing xmlparser attributes: namespace_prefixes, ordered_attributes, specified_attributes, intern (#6494) * Add namespace_prefixes and other missing xmlparser attributes - Added 
namespace_prefixes, ordered_attributes, specified_attributes (boolean attributes) - Added intern dictionary attribute - Added stub handlers for all missing handler types to ensure compatibility - Created bool_property macro to ensure boolean attributes are converted correctly * Remove expectedFailure decorators from passing tests - Tests for buffer_text, namespace_prefixes, ordered_attributes, and specified_attributes now pass --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> Co-authored-by: github-actions[bot] --- Lib/test/test_pyexpat.py | 6 -- crates/stdlib/src/pyexpat.rs | 183 ++++++++++++++++++++++++++++++++++- 2 files changed, 182 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_pyexpat.py b/Lib/test/test_pyexpat.py index 80485cc74b9..015e7497268 100644 --- a/Lib/test/test_pyexpat.py +++ b/Lib/test/test_pyexpat.py @@ -20,28 +20,24 @@ class SetAttributeTest(unittest.TestCase): def setUp(self): self.parser = expat.ParserCreate(namespace_separator='!') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_buffer_text(self): self.assertIs(self.parser.buffer_text, False) for x in 0, 1, 2, 0: self.parser.buffer_text = x self.assertIs(self.parser.buffer_text, bool(x)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_namespace_prefixes(self): self.assertIs(self.parser.namespace_prefixes, False) for x in 0, 1, 2, 0: self.parser.namespace_prefixes = x self.assertIs(self.parser.namespace_prefixes, bool(x)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_ordered_attributes(self): self.assertIs(self.parser.ordered_attributes, False) for x in 0, 1, 2, 0: self.parser.ordered_attributes = x self.assertIs(self.parser.ordered_attributes, bool(x)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_specified_attributes(self): self.assertIs(self.parser.specified_attributes, False) for x in 0, 1, 2, 0: @@ -244,7 +240,6 @@ def 
test_parse_bytes(self): # Issue #6697. self.assertRaises(AttributeError, getattr, parser, '\uD800') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_parse_str(self): out = self.Outputter() parser = expat.ParserCreate(namespace_separator='!') @@ -255,7 +250,6 @@ def test_parse_str(self): operations = out.out self._verify_parse_output(operations) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_parse_file(self): # Try parsing a file out = self.Outputter() diff --git a/crates/stdlib/src/pyexpat.rs b/crates/stdlib/src/pyexpat.rs index 871ba7d5987..699fa21852d 100644 --- a/crates/stdlib/src/pyexpat.rs +++ b/crates/stdlib/src/pyexpat.rs @@ -25,6 +25,26 @@ macro_rules! create_property { }; } +macro_rules! create_bool_property { + ($ctx: expr, $attributes: expr, $name: expr, $class: expr, $element: ident) => { + let attr = $ctx.new_static_getset( + $name, + $class, + move |this: &PyExpatLikeXmlParser| this.$element.read().clone(), + move |this: &PyExpatLikeXmlParser, + value: PyObjectRef, + vm: &VirtualMachine| + -> PyResult<()> { + let bool_value = value.is_true(vm)?; + *this.$element.write() = vm.ctx.new_bool(bool_value).into(); + Ok(()) + }, + ); + + $attributes.insert($ctx.intern_str($name), attr.into()); + }; +} + #[pymodule(name = "pyexpat")] mod _pyexpat { use crate::vm::{ @@ -51,6 +71,29 @@ mod _pyexpat { character_data: MutableObject, entity_decl: MutableObject, buffer_text: MutableObject, + namespace_prefixes: MutableObject, + ordered_attributes: MutableObject, + specified_attributes: MutableObject, + intern: MutableObject, + // Additional handlers (stubs for compatibility) + processing_instruction: MutableObject, + unparsed_entity_decl: MutableObject, + notation_decl: MutableObject, + start_namespace_decl: MutableObject, + end_namespace_decl: MutableObject, + comment: MutableObject, + start_cdata_section: MutableObject, + end_cdata_section: MutableObject, + default: MutableObject, + default_expand: MutableObject, + not_standalone: MutableObject, + 
external_entity_ref: MutableObject, + start_doctype_decl: MutableObject, + end_doctype_decl: MutableObject, + xml_decl: MutableObject, + element_decl: MutableObject, + attlist_decl: MutableObject, + skipped_entity: MutableObject, } type PyExpatLikeXmlParserRef = PyRef; @@ -71,6 +114,31 @@ mod _pyexpat { character_data: MutableObject::new(vm.ctx.none()), entity_decl: MutableObject::new(vm.ctx.none()), buffer_text: MutableObject::new(vm.ctx.new_bool(false).into()), + namespace_prefixes: MutableObject::new(vm.ctx.new_bool(false).into()), + ordered_attributes: MutableObject::new(vm.ctx.new_bool(false).into()), + specified_attributes: MutableObject::new(vm.ctx.new_bool(false).into()), + // String interning dictionary - used by the parser to intern element/attribute names + // for memory efficiency and faster comparisons. See CPython's pyexpat documentation. + intern: MutableObject::new(vm.ctx.new_dict().into()), + // Additional handlers (stubs for compatibility) + processing_instruction: MutableObject::new(vm.ctx.none()), + unparsed_entity_decl: MutableObject::new(vm.ctx.none()), + notation_decl: MutableObject::new(vm.ctx.none()), + start_namespace_decl: MutableObject::new(vm.ctx.none()), + end_namespace_decl: MutableObject::new(vm.ctx.none()), + comment: MutableObject::new(vm.ctx.none()), + start_cdata_section: MutableObject::new(vm.ctx.none()), + end_cdata_section: MutableObject::new(vm.ctx.none()), + default: MutableObject::new(vm.ctx.none()), + default_expand: MutableObject::new(vm.ctx.none()), + not_standalone: MutableObject::new(vm.ctx.none()), + external_entity_ref: MutableObject::new(vm.ctx.none()), + start_doctype_decl: MutableObject::new(vm.ctx.none()), + end_doctype_decl: MutableObject::new(vm.ctx.none()), + xml_decl: MutableObject::new(vm.ctx.none()), + element_decl: MutableObject::new(vm.ctx.none()), + attlist_decl: MutableObject::new(vm.ctx.none()), + skipped_entity: MutableObject::new(vm.ctx.none()), } .into_ref(&vm.ctx)) } @@ -89,7 +157,120 @@ mod 
_pyexpat { character_data ); create_property!(ctx, attributes, "EntityDeclHandler", class, entity_decl); - create_property!(ctx, attributes, "buffer_text", class, buffer_text); + create_bool_property!(ctx, attributes, "buffer_text", class, buffer_text); + create_bool_property!( + ctx, + attributes, + "namespace_prefixes", + class, + namespace_prefixes + ); + create_bool_property!( + ctx, + attributes, + "ordered_attributes", + class, + ordered_attributes + ); + create_bool_property!( + ctx, + attributes, + "specified_attributes", + class, + specified_attributes + ); + create_property!(ctx, attributes, "intern", class, intern); + // Additional handlers (stubs for compatibility) + create_property!( + ctx, + attributes, + "ProcessingInstructionHandler", + class, + processing_instruction + ); + create_property!( + ctx, + attributes, + "UnparsedEntityDeclHandler", + class, + unparsed_entity_decl + ); + create_property!(ctx, attributes, "NotationDeclHandler", class, notation_decl); + create_property!( + ctx, + attributes, + "StartNamespaceDeclHandler", + class, + start_namespace_decl + ); + create_property!( + ctx, + attributes, + "EndNamespaceDeclHandler", + class, + end_namespace_decl + ); + create_property!(ctx, attributes, "CommentHandler", class, comment); + create_property!( + ctx, + attributes, + "StartCdataSectionHandler", + class, + start_cdata_section + ); + create_property!( + ctx, + attributes, + "EndCdataSectionHandler", + class, + end_cdata_section + ); + create_property!(ctx, attributes, "DefaultHandler", class, default); + create_property!( + ctx, + attributes, + "DefaultHandlerExpand", + class, + default_expand + ); + create_property!( + ctx, + attributes, + "NotStandaloneHandler", + class, + not_standalone + ); + create_property!( + ctx, + attributes, + "ExternalEntityRefHandler", + class, + external_entity_ref + ); + create_property!( + ctx, + attributes, + "StartDoctypeDeclHandler", + class, + start_doctype_decl + ); + create_property!( + ctx, + 
attributes, + "EndDoctypeDeclHandler", + class, + end_doctype_decl + ); + create_property!(ctx, attributes, "XmlDeclHandler", class, xml_decl); + create_property!(ctx, attributes, "ElementDeclHandler", class, element_decl); + create_property!(ctx, attributes, "AttlistDeclHandler", class, attlist_decl); + create_property!( + ctx, + attributes, + "SkippedEntityHandler", + class, + skipped_entity + ); } fn create_config(&self) -> xml::ParserConfig { From 7ebb0f0c5c8e8958238e80e23f5e108dbb892db7 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 09:21:09 +0900 Subject: [PATCH 055/418] impl path_converter and os functions (#6484) * os.setpgrp * tcgetpgrp * impl more os functions * impl PathConverter --- Lib/test/test_os.py | 9 +- crates/vm/src/function/fspath.rs | 3 +- crates/vm/src/ospath.rs | 217 ++++++++++++++++++++++++++----- crates/vm/src/stdlib/nt.rs | 29 +++-- crates/vm/src/stdlib/os.rs | 46 +++++-- crates/vm/src/stdlib/posix.rs | 26 ++++ 6 files changed, 272 insertions(+), 58 deletions(-) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index bb558524c24..4755aef080e 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -1725,15 +1725,15 @@ def walk(self, top, **kwargs): bdirs[:] = list(map(os.fsencode, dirs)) bfiles[:] = list(map(os.fsencode, files)) - @unittest.expectedFailure # TODO: RUSTPYTHON; (TypeError: Can't mix strings and bytes in path components) + @unittest.expectedFailure # TODO: RUSTPYTHON; WalkTests doesn't have these methods def test_compare_to_walk(self): return super().test_compare_to_walk() - @unittest.expectedFailure # TODO: RUSTPYTHON; (TypeError: Can't mix strings and bytes in path components) + @unittest.expectedFailure # TODO: RUSTPYTHON; WalkTests doesn't have these methods def test_dir_fd(self): return super().test_dir_fd() - @unittest.expectedFailure # TODO: RUSTPYTHON; (TypeError: Can't mix strings and bytes in path components) + @unittest.expectedFailure # 
TODO: RUSTPYTHON; WalkTests doesn't have these methods def test_yields_correct_dir_fd(self): return super().test_yields_correct_dir_fd() @@ -4502,7 +4502,6 @@ class Str(str): self.filenames = self.bytes_filenames + self.unicode_filenames - @unittest.expectedFailure # TODO: RUSTPYTHON; (AssertionError: b'@test_22106_tmp\xe7w\xf0' is not b'@test_22106_tmp\xe7w\xf0' : ) def test_oserror_filename(self): funcs = [ (self.filenames, os.chdir,), @@ -4906,7 +4905,6 @@ def setUp(self): def test_uninstantiable(self): self.assertRaises(TypeError, os.DirEntry) - @unittest.expectedFailure # TODO: RUSTPYTHON; (pickle.PicklingError: Can't pickle : it's not found as _os.DirEntry) def test_unpickable(self): filename = create_file(os.path.join(self.path, "file.txt"), b'python') entry = [entry for entry in os.scandir(self.path)].pop() @@ -5337,7 +5335,6 @@ def __fspath__(self): return '' self.assertFalse(hasattr(A(), '__dict__')) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_fspath_set_to_None(self): class Foo: __fspath__ = None diff --git a/crates/vm/src/function/fspath.rs b/crates/vm/src/function/fspath.rs index 2bc331844c6..44d41ab7632 100644 --- a/crates/vm/src/function/fspath.rs +++ b/crates/vm/src/function/fspath.rs @@ -7,6 +7,7 @@ use crate::{ }; use std::{borrow::Cow, ffi::OsStr, path::PathBuf}; +/// Helper to implement os.fspath() #[derive(Clone)] pub enum FsPath { Str(PyStrRef), @@ -27,7 +28,7 @@ impl FsPath { ) } - // PyOS_FSPath in CPython + // PyOS_FSPath pub fn try_from( obj: PyObjectRef, check_for_nul: bool, diff --git a/crates/vm/src/ospath.rs b/crates/vm/src/ospath.rs index 9fca53d869c..77abbee2cd5 100644 --- a/crates/vm/src/ospath.rs +++ b/crates/vm/src/ospath.rs @@ -2,20 +2,181 @@ use rustpython_common::crt_fd; use crate::{ PyObjectRef, PyResult, VirtualMachine, + builtins::{PyBytes, PyStr}, convert::{IntoPyException, ToPyException, ToPyObject, TryFromObject}, function::FsPath, }; use std::path::{Path, PathBuf}; -// path_ without allow_fd in CPython +/// 
path_converter +#[derive(Clone, Copy, Default)] +pub struct PathConverter { + /// Function name for error messages (e.g., "rename") + pub function_name: Option<&'static str>, + /// Argument name for error messages (e.g., "src", "dst") + pub argument_name: Option<&'static str>, + /// If true, embedded null characters are allowed + pub non_strict: bool, +} + +impl PathConverter { + pub const fn new() -> Self { + Self { + function_name: None, + argument_name: None, + non_strict: false, + } + } + + pub const fn function(mut self, name: &'static str) -> Self { + self.function_name = Some(name); + self + } + + pub const fn argument(mut self, name: &'static str) -> Self { + self.argument_name = Some(name); + self + } + + pub const fn non_strict(mut self) -> Self { + self.non_strict = true; + self + } + + /// Generate error message prefix like "rename: " + fn error_prefix(&self) -> String { + match self.function_name { + Some(func) => format!("{}: ", func), + None => String::new(), + } + } + + /// Get argument name for error messages, defaults to "path" + fn arg_name(&self) -> &'static str { + self.argument_name.unwrap_or("path") + } + + /// Format a type error message + fn type_error_msg(&self, type_name: &str, allow_fd: bool) -> String { + let expected = if allow_fd { + "string, bytes, os.PathLike or integer" + } else { + "string, bytes or os.PathLike" + }; + format!( + "{}{} should be {}, not {}", + self.error_prefix(), + self.arg_name(), + expected, + type_name + ) + } + + /// Convert to OsPathOrFd (path or file descriptor) + pub(crate) fn try_path_or_fd<'fd>( + &self, + obj: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + // Handle fd (before __fspath__ check, like CPython) + if let Some(int) = obj.try_index_opt(vm) { + let fd = int?.try_to_primitive(vm)?; + return unsafe { crt_fd::Borrowed::try_borrow_raw(fd) } + .map(OsPathOrFd::Fd) + .map_err(|e| e.into_pyexception(vm)); + } + + self.try_path_inner(obj, true, vm).map(OsPathOrFd::Path) + } + + /// Convert 
to OsPath only (no fd support) + fn try_path_inner( + &self, + obj: PyObjectRef, + allow_fd: bool, + vm: &VirtualMachine, + ) -> PyResult { + // Try direct str/bytes match + let obj = match self.try_match_str_bytes(obj.clone(), vm)? { + Ok(path) => return Ok(path), + Err(obj) => obj, + }; + + // Call __fspath__ + let type_error_msg = || self.type_error_msg(&obj.class().name(), allow_fd); + let method = + vm.get_method_or_type_error(obj.clone(), identifier!(vm, __fspath__), type_error_msg)?; + if vm.is_none(&method) { + return Err(vm.new_type_error(type_error_msg())); + } + let result = method.call((), vm)?; + + // Match __fspath__ result + self.try_match_str_bytes(result.clone(), vm)?.map_err(|_| { + vm.new_type_error(format!( + "{}expected {}.__fspath__() to return str or bytes, not {}", + self.error_prefix(), + obj.class().name(), + result.class().name(), + )) + }) + } + + /// Try to match str or bytes, returns Err(obj) if neither + fn try_match_str_bytes( + &self, + obj: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + let check_nul = |b: &[u8]| { + if self.non_strict || memchr::memchr(b'\0', b).is_none() { + Ok(()) + } else { + Err(vm.new_value_error(format!( + "{}embedded null character in {}", + self.error_prefix(), + self.arg_name() + ))) + } + }; + + match_class!(match obj { + s @ PyStr => { + check_nul(s.as_bytes())?; + let path = vm.fsencode(&s)?.into_owned(); + Ok(Ok(OsPath { + path, + origin: Some(s.into()), + })) + } + b @ PyBytes => { + check_nul(&b)?; + let path = FsPath::bytes_as_os_str(&b, vm)?.to_owned(); + Ok(Ok(OsPath { + path, + origin: Some(b.into()), + })) + } + obj => Ok(Err(obj)), + }) + } + + /// Convert to OsPath directly + pub fn try_path(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.try_path_inner(obj, false, vm) + } +} + +/// path_t output - the converted path #[derive(Clone)] pub struct OsPath { pub path: std::ffi::OsString, - pub(super) mode: OutputMode, + /// Original Python object for identity 
preservation in OSError + pub(super) origin: Option, } #[derive(Debug, Copy, Clone)] -pub(super) enum OutputMode { +pub enum OutputMode { String, Bytes, } @@ -38,19 +199,19 @@ impl OutputMode { impl OsPath { pub fn new_str(path: impl Into) -> Self { let path = path.into(); - Self { - path, - mode: OutputMode::String, - } + Self { path, origin: None } } pub(crate) fn from_fspath(fspath: FsPath, vm: &VirtualMachine) -> PyResult { let path = fspath.as_os_str(vm)?.into_owned(); - let mode = match fspath { - FsPath::Str(_) => OutputMode::String, - FsPath::Bytes(_) => OutputMode::Bytes, + let origin = match fspath { + FsPath::Str(s) => s.into(), + FsPath::Bytes(b) => b.into(), }; - Ok(Self { path, mode }) + Ok(Self { + path, + origin: Some(origin), + }) } /// Convert an object to OsPath using the os.fspath-style error message. @@ -83,7 +244,20 @@ impl OsPath { } pub fn filename(&self, vm: &VirtualMachine) -> PyObjectRef { - self.mode.process_path(self.path.clone(), vm) + if let Some(ref origin) = self.origin { + origin.clone() + } else { + // Default to string when no origin (e.g., from new_str) + OutputMode::String.process_path(self.path.clone(), vm) + } + } + + /// Get the output mode based on origin type (bytes -> Bytes, otherwise -> String) + pub fn mode(&self) -> OutputMode { + match &self.origin { + Some(obj) if obj.downcast_ref::().is_some() => OutputMode::Bytes, + _ => OutputMode::String, + } } } @@ -94,15 +268,8 @@ impl AsRef for OsPath { } impl TryFromObject for OsPath { - // TODO: path_converter with allow_fd=0 in CPython fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let fspath = FsPath::try_from( - obj, - true, - "should be string, bytes, os.PathLike or integer", - vm, - )?; - Self::from_fspath(fspath, vm) + PathConverter::new().try_path(obj, vm) } } @@ -115,15 +282,7 @@ pub(crate) enum OsPathOrFd<'fd> { impl TryFromObject for OsPathOrFd<'_> { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - match 
obj.try_index_opt(vm) { - Some(int) => { - let fd = int?.try_to_primitive(vm)?; - unsafe { crt_fd::Borrowed::try_borrow_raw(fd) } - .map(Self::Fd) - .map_err(|e| e.into_pyexception(vm)) - } - None => obj.try_into_value(vm).map(Self::Path), - } + PathConverter::new().try_path_or_fd(obj, vm) } } diff --git a/crates/vm/src/stdlib/nt.rs b/crates/vm/src/stdlib/nt.rs index ada939b1549..fc68139dc88 100644 --- a/crates/vm/src/stdlib/nt.rs +++ b/crates/vm/src/stdlib/nt.rs @@ -17,6 +17,7 @@ pub(crate) mod module { builtins::{PyBaseExceptionRef, PyDictRef, PyListRef, PyStrRef, PyTupleRef}, common::{crt_fd, suppress_iph, windows::ToWideString}, convert::ToPyException, + exceptions::OSErrorBuilder, function::{Either, OptionalArg}, ospath::{OsPath, OsPathOrFd}, stdlib::os::{_os, DirFd, SupportFunc, TargetIsDirectory}, @@ -193,10 +194,12 @@ pub(crate) mod module { }; // Use symlink_metadata to avoid following dangling symlinks - let meta = fs::symlink_metadata(&actual_path).map_err(|err| err.to_pyexception(vm))?; + let meta = fs::symlink_metadata(&actual_path) + .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; let mut permissions = meta.permissions(); permissions.set_readonly(mode & S_IWRITE == 0); - fs::set_permissions(&*actual_path, permissions).map_err(|err| err.to_pyexception(vm)) + fs::set_permissions(&*actual_path, permissions) + .map_err(|err| OSErrorBuilder::with_filename(&err, path, vm)) } /// Get the real file name (with correct case) without accessing the file. 
@@ -925,7 +928,7 @@ pub(crate) mod module { .as_ref() .canonicalize() .map_err(|e| e.to_pyexception(vm))?; - Ok(path.mode.process_path(real, vm)) + Ok(path.mode().process_path(real, vm)) } #[pyfunction] @@ -958,7 +961,7 @@ pub(crate) mod module { } } let buffer = widestring::WideCString::from_vec_truncate(buffer); - Ok(path.mode.process_path(buffer.to_os_string(), vm)) + Ok(path.mode().process_path(buffer.to_os_string(), vm)) } #[pyfunction] @@ -973,7 +976,7 @@ pub(crate) mod module { return Err(vm.new_last_os_error()); } let buffer = widestring::WideCString::from_vec_truncate(buffer); - Ok(path.mode.process_path(buffer.to_os_string(), vm)) + Ok(path.mode().process_path(buffer.to_os_string(), vm)) } /// Implements _Py_skiproot logic for Windows paths @@ -1053,7 +1056,7 @@ pub(crate) mod module { use crate::builtins::{PyBytes, PyStr}; use rustpython_common::wtf8::Wtf8Buf; - // Handle path-like objects via os.fspath, but without null check (nonstrict=True) + // Handle path-like objects via os.fspath, but without null check (non_strict=True) let path = if let Some(fspath) = vm.get_method(path.clone(), identifier!(vm, __fspath__)) { fspath?.call((), vm)? 
} else { @@ -1585,7 +1588,7 @@ pub(crate) mod module { use windows_sys::Win32::System::IO::DeviceIoControl; use windows_sys::Win32::System::Ioctl::FSCTL_GET_REPARSE_POINT; - let mode = path.mode; + let mode = path.mode(); let wide_path = path.as_ref().to_wide_with_nul(); // Open the file/directory with reparse point flag @@ -1602,7 +1605,11 @@ pub(crate) mod module { }; if handle == INVALID_HANDLE_VALUE { - return Err(io::Error::last_os_error().to_pyexception(vm)); + return Err(OSErrorBuilder::with_filename( + &io::Error::last_os_error(), + path.clone(), + vm, + )); } // Buffer for reparse data - MAXIMUM_REPARSE_DATA_BUFFER_SIZE is 16384 @@ -1626,7 +1633,11 @@ pub(crate) mod module { unsafe { CloseHandle(handle) }; if result == 0 { - return Err(io::Error::last_os_error().to_pyexception(vm)); + return Err(OSErrorBuilder::with_filename( + &io::Error::last_os_error(), + path.clone(), + vm, + )); } // Parse the reparse data buffer diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index 868fc727c65..87080ff8e0f 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -164,7 +164,7 @@ pub(super) mod _os { convert::{IntoPyException, ToPyObject}, exceptions::OSErrorBuilder, function::{ArgBytesLike, FsPath, FuncArgs, OptionalArg}, - ospath::{OsPath, OsPathOrFd, OutputMode}, + ospath::{OsPath, OsPathOrFd, OutputMode, PathConverter}, protocol::PyIterReturn, recursion::ReprGuard, types::{IterNext, Iterable, PyStructSequence, Representable, SelfIter}, @@ -366,10 +366,12 @@ pub(super) mod _os { #[pyfunction] fn listdir( - path: OptionalArg>, + path: OptionalArg>>, vm: &VirtualMachine, ) -> PyResult> { - let path = path.unwrap_or_else(|| OsPathOrFd::Path(OsPath::new_str("."))); + let path = path + .flatten() + .unwrap_or_else(|| OsPathOrFd::Path(OsPath::new_str("."))); let list = match path { OsPathOrFd::Path(path) => { let dir_iter = match fs::read_dir(&path) { @@ -378,9 +380,10 @@ pub(super) mod _os { return 
Err(OSErrorBuilder::with_filename(&err, path, vm)); } }; + let mode = path.mode(); dir_iter .map(|entry| match entry { - Ok(entry_path) => Ok(path.mode.process_path(entry_path.file_name(), vm)), + Ok(entry_path) => Ok(mode.process_path(entry_path.file_name(), vm)), Err(err) => Err(OSErrorBuilder::with_filename(&err, path.clone(), vm)), }) .collect::>()? @@ -545,7 +548,7 @@ pub(super) mod _os { #[pyfunction] fn readlink(path: OsPath, dir_fd: DirFd<'_, 0>, vm: &VirtualMachine) -> PyResult { - let mode = path.mode; + let mode = path.mode(); let [] = dir_fd.0; let path = fs::read_link(&path).map_err(|err| OSErrorBuilder::with_filename(&err, path, vm))?; @@ -640,7 +643,7 @@ pub(super) mod _os { stat( OsPath { path: self.pathval.as_os_str().to_owned(), - mode: OutputMode::String, + origin: None, } .into(), dir_fd, @@ -671,11 +674,7 @@ pub(super) mod _os { Some(ino) => Ok(ino), None => { let stat = stat_inner( - OsPath { - path: self.pathval.as_os_str().to_owned(), - mode: OutputMode::String, - } - .into(), + OsPath::new_str(self.pathval.as_os_str()).into(), DirFd::default(), FollowSymlinks(false), ) @@ -730,6 +729,11 @@ pub(super) mod _os { ) -> PyGenericAlias { PyGenericAlias::from_args(cls, args, vm) } + + #[pymethod] + fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("cannot pickle 'DirEntry' object".to_owned())) + } } impl Representable for DirEntry { @@ -785,6 +789,11 @@ pub(super) mod _os { fn __exit__(zelf: PyRef, _args: FuncArgs) { zelf.close() } + + #[pymethod] + fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("cannot pickle 'ScandirIterator' object".to_owned())) + } } impl SelfIter for ScandirIterator {} impl IterNext for ScandirIterator { @@ -861,7 +870,7 @@ pub(super) mod _os { .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; Ok(ScandirIterator { entries: PyRwLock::new(Some(entries)), - mode: path.mode, + mode: path.mode(), } .into_ref(&vm.ctx) .into()) @@ -1124,7 +1133,16 
@@ pub(super) mod _os { #[pyfunction] #[pyfunction(name = "replace")] - fn rename(src: OsPath, dst: OsPath, vm: &VirtualMachine) -> PyResult<()> { + fn rename(src: PyObjectRef, dst: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let src = PathConverter::new() + .function("rename") + .argument("src") + .try_path(src, vm)?; + let dst = PathConverter::new() + .function("rename") + .argument("dst") + .try_path(dst, vm)?; + fs::rename(&src.path, &dst.path).map_err(|err| { let builder = err.to_os_error_builder(vm); let builder = builder.filename(src.filename(vm)); @@ -1708,6 +1726,8 @@ pub(super) mod _os { SupportFunc::new("fstat", Some(true), Some(STAT_DIR_FD), Some(true)), SupportFunc::new("symlink", Some(false), Some(SYMLINK_DIR_FD), Some(false)), SupportFunc::new("truncate", Some(true), Some(false), Some(false)), + SupportFunc::new("ftruncate", Some(true), Some(false), Some(false)), + SupportFunc::new("fsync", Some(true), Some(false), Some(false)), SupportFunc::new( "utime", Some(false), diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index 59e41782574..15c8745ded4 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -1240,6 +1240,12 @@ pub mod module { .map_err(|err| err.into_pyexception(vm)) } + #[pyfunction] + fn setpgrp(vm: &VirtualMachine) -> PyResult<()> { + // setpgrp() is equivalent to setpgid(0, 0) + unistd::setpgid(Pid::from_raw(0), Pid::from_raw(0)).map_err(|err| err.into_pyexception(vm)) + } + #[cfg(not(any(target_os = "wasi", target_os = "redox")))] #[pyfunction] fn setsid(vm: &VirtualMachine) -> PyResult<()> { @@ -1248,6 +1254,24 @@ pub mod module { .map_err(|err| err.into_pyexception(vm)) } + #[cfg(not(any(target_os = "wasi", target_os = "redox")))] + #[pyfunction] + fn tcgetpgrp(fd: i32, vm: &VirtualMachine) -> PyResult { + use std::os::fd::BorrowedFd; + let fd = unsafe { BorrowedFd::borrow_raw(fd) }; + unistd::tcgetpgrp(fd) + .map(|pid| pid.as_raw()) + .map_err(|err| 
err.into_pyexception(vm)) + } + + #[cfg(not(any(target_os = "wasi", target_os = "redox")))] + #[pyfunction] + fn tcsetpgrp(fd: i32, pgid: libc::pid_t, vm: &VirtualMachine) -> PyResult<()> { + use std::os::fd::BorrowedFd; + let fd = unsafe { BorrowedFd::borrow_raw(fd) }; + unistd::tcsetpgrp(fd, Pid::from_raw(pgid)).map_err(|err| err.into_pyexception(vm)) + } + fn try_from_id(vm: &VirtualMachine, obj: PyObjectRef, typ_name: &str) -> PyResult { use std::cmp::Ordering; let i = obj @@ -1833,6 +1857,8 @@ pub mod module { SupportFunc::new("umask", Some(false), Some(false), Some(false)), SupportFunc::new("execv", None, None, None), SupportFunc::new("pathconf", Some(true), None, None), + SupportFunc::new("fpathconf", Some(true), None, None), + SupportFunc::new("fchdir", Some(true), None, None), ] } From ebdc033c59d39405ee1d6c2b31afe8fb1b2c61bc Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 09:48:20 +0900 Subject: [PATCH 056/418] Bundle Lib directory in GitHub releases (#6497) * Initial plan * Add Lib directory archives to release workflow Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> * Add release notes explaining users need both binary and Lib archive Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> * Remove unnecessary --notes-start-tag parameter Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> * Only create zip archive, remove tar.gz Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> --- .github/workflows/release.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3efff295c37..5ad2cf3f97a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -144,6 +144,8 @@ 
jobs: if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} needs: [build, build-wasm] steps: + - uses: actions/checkout@v6.0.1 + - name: Download Binary Artifacts uses: actions/download-artifact@v7.0.0 with: @@ -151,6 +153,10 @@ jobs: pattern: rustpython-* merge-multiple: true + - name: Create Lib Archive + run: | + zip -r bin/rustpython-lib.zip Lib/ + - name: List Binaries run: | ls -lah bin/ @@ -174,6 +180,7 @@ jobs: --repo="$GITHUB_REPOSITORY" \ --title="RustPython $RELEASE_TYPE_NAME $today-$tag #$run" \ --target="$tag" \ + --notes "⚠️ **Important**: To run RustPython, you must download both the binary for your platform AND the \`rustpython-lib.zip\` archive. Extract the Lib directory from the archive to the same location as the binary, or set the \`RUSTPYTHONPATH\` environment variable to point to the Lib directory." \ --generate-notes \ $PRERELEASE_ARG \ bin/rustpython-release-* From 5122aed738b4a50a4f52af30b650967be9c081ef Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 10:32:51 +0900 Subject: [PATCH 057/418] Fix install ujson (#6502) --- crates/stdlib/src/ssl/compat.rs | 96 ++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 20 deletions(-) diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index cd927a0e410..45aa9c4fce9 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -1421,6 +1421,50 @@ pub(super) fn ssl_read( return Err(SslError::Eof); } } + + // For non-blocking sockets, return WantRead so caller can poll and retry. + // For blocking sockets (or sockets with timeout), wait for more data.
+ if !is_bio { + let timeout = socket.get_socket_timeout(vm).map_err(SslError::Py)?; + if let Some(t) = timeout + && t.is_zero() + { + // Non-blocking socket: return immediately + return Err(SslError::WantRead); + } + // Blocking socket or socket with timeout: try to read more data from socket. + // Even though rustls says it doesn't want to read, more TLS records may arrive. + // This handles the case where rustls processed all buffered TLS records but + // more data is coming over the network. + let data = match socket.sock_recv(2048, vm) { + Ok(data) => data, + Err(e) => { + if is_connection_closed_error(&e, vm) { + return Err(SslError::Eof); + } + return Err(SslError::Py(e)); + } + }; + + let bytes_read = data + .clone() + .try_into_value::(vm) + .map(|b| b.as_bytes().len()) + .unwrap_or(0); + + if bytes_read == 0 { + // No more data available - connection might be closed + return Err(SslError::Eof); + } + + // Feed data to rustls and process + ssl_read_tls_records(conn, data, false, vm)?; + conn.process_new_packets().map_err(SslError::from_rustls)?; + + // Continue loop to try reading plaintext + continue; + } + return Err(SslError::WantRead); } @@ -1432,20 +1476,9 @@ pub(super) fn ssl_read( // Continue loop to try reading plaintext } Err(SslError::Io(ref io_err)) if io_err.to_string().contains("message buffer full") => { - // Buffer is full - we need to consume plaintext before reading more - // Try to read plaintext now - match try_read_plaintext(conn, buf)? 
{ - Some(n) if n > 0 => { - // Have plaintext - return it - // Python will call read() again if it needs more data - return Ok(n); - } - _ => { - // No plaintext available yet - this is unusual - // Return WantRead to let Python retry - return Err(SslError::WantRead); - } - } + // This case should be rare now that ssl_read_tls_records handles buffer full + // Just continue loop to try again + continue; } Err(e) => { // Other errors - check for buffered plaintext before propagating @@ -1524,7 +1557,7 @@ fn ssl_read_tls_records( } // Feed all received data to read_tls - loop to consume all data - // read_tls may not consume all data in one call + // read_tls may not consume all data in one call, and buffer may become full let mut offset = 0; while offset < bytes_data.len() { let remaining = &bytes_data[offset..]; @@ -1533,12 +1566,33 @@ fn ssl_read_tls_records( match conn.read_tls(&mut cursor) { Ok(read_bytes) => { if read_bytes == 0 { - // No more data can be consumed - break; + // Buffer is full - process existing packets to make room + conn.process_new_packets().map_err(SslError::from_rustls)?; + + // Try again - if we still can't consume, break + let mut retry_cursor = std::io::Cursor::new(remaining); + match conn.read_tls(&mut retry_cursor) { + Ok(0) => { + // Still can't consume - break to avoid infinite loop + break; + } + Ok(n) => { + offset += n; + } + Err(e) => { + return Err(SslError::Io(e)); + } + } + } else { + offset += read_bytes; } - offset += read_bytes; } Err(e) => { + // Check if it's a buffer full error (unlikely but handle it) + if e.to_string().contains("buffer full") { + conn.process_new_packets().map_err(SslError::from_rustls)?; + continue; + } // Real error - propagate it return Err(SslError::Io(e)); } @@ -1599,8 +1653,10 @@ fn ssl_ensure_data_available( .sock_wait_for_io_impl(SelectKind::Read, vm) .map_err(SslError::Py)?; if timed_out { - // Socket not ready within timeout - return Err(SslError::WantRead); + // Socket not ready within 
timeout - raise socket.timeout + return Err(SslError::Timeout( + "The read operation timed out".to_string(), + )); } } // else: non-blocking socket (timeout=0) or blocking socket (timeout=None) - skip select From 6dbc8f0cfa1d0cd48b90d48c6f7768a030a091a8 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 10:33:22 +0900 Subject: [PATCH 058/418] Set SourceFileLoader for script execution in run_simple_file (#6496) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> --- Lib/test/test_cmd_line_script.py | 2 -- crates/vm/src/vm/compile.rs | 12 +++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index 833dc6b15d8..d773674feef 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -225,8 +225,6 @@ def test_repl_stderr_flush(self): def test_repl_stderr_flush_separate_stderr(self): self.check_repl_stderr_flush(True) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_basic_script(self): with os_helper.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script') diff --git a/crates/vm/src/vm/compile.rs b/crates/vm/src/vm/compile.rs index 44332cda838..a7e31cf0377 100644 --- a/crates/vm/src/vm/compile.rs +++ b/crates/vm/src/vm/compile.rs @@ -81,7 +81,7 @@ impl VirtualMachine { todo!("running pyc is not implemented yet"); } else { if path != "" { - // TODO: set_main_loader(dict, filename, "SourceFileLoader"); + set_main_loader(&module_dict, path, self)?; } // TODO: replace to something equivalent to py_run_file match std::fs::read_to_string(path) { @@ -125,6 +125,16 @@ impl VirtualMachine { } } +fn set_main_loader(module_dict: &PyDictRef, filename: &str, vm: &VirtualMachine) -> PyResult<()> { + vm.import("importlib.machinery", 0)?; + let sys_modules = 
vm.sys_module.get_attr(identifier!(vm, modules), vm)?; + let machinery = sys_modules.get_item("importlib.machinery", vm)?; + let loader_class = machinery.get_attr("SourceFileLoader", vm)?; + let loader = loader_class.call((identifier!(vm, __main__).to_owned(), filename), vm)?; + module_dict.set_item("__loader__", loader, vm)?; + Ok(()) +} + fn get_importer(path: &str, vm: &VirtualMachine) -> PyResult> { let path_importer_cache = vm.sys_module.get_attr("path_importer_cache", vm)?; let path_importer_cache = PyDictRef::try_from_object(vm, path_importer_cache)?; From aae6bf566f0ef4de7beea473b1a73030834ddb16 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 11:00:53 +0900 Subject: [PATCH 059/418] Add tp_str (#6495) * clean up hash * __str__ slot wrapper --- Lib/test/test_inspect/test_inspect.py | 1 - crates/vm/src/builtins/descriptor.rs | 13 +++++++++---- crates/vm/src/builtins/object.rs | 4 ++-- crates/vm/src/class.rs | 20 +++++--------------- crates/vm/src/types/slot.rs | 2 +- 5 files changed, 17 insertions(+), 23 deletions(-) diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index e403ab7b226..353d7c2e7b2 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -409,7 +409,6 @@ class NotFuture: pass coro.close(); gen_coro.close() # silence warnings - @unittest.expectedFailure # TODO: RUSTPYTHON def test_isroutine(self): # method self.assertTrue(inspect.isroutine(git.argue)) diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index fcf40d0082d..297199eee18 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -399,6 +399,7 @@ pub fn init(ctx: &Context) { pub enum SlotFunc { Init(InitFunc), Hash(HashFunc), + Str(StringifyFunc), Repr(StringifyFunc), Iter(IterFunc), IterNext(IterNextFunc), @@ -409,6 +410,7 @@ impl std::fmt::Debug for SlotFunc { match 
self { SlotFunc::Init(_) => write!(f, "SlotFunc::Init(...)"), SlotFunc::Hash(_) => write!(f, "SlotFunc::Hash(...)"), + SlotFunc::Str(_) => write!(f, "SlotFunc::Str(...)"), SlotFunc::Repr(_) => write!(f, "SlotFunc::Repr(...)"), SlotFunc::Iter(_) => write!(f, "SlotFunc::Iter(...)"), SlotFunc::IterNext(_) => write!(f, "SlotFunc::IterNext(...)"), @@ -433,11 +435,14 @@ impl SlotFunc { let hash = func(&obj, vm)?; Ok(vm.ctx.new_int(hash).into()) } - SlotFunc::Repr(func) => { + SlotFunc::Repr(func) | SlotFunc::Str(func) => { if !args.args.is_empty() || !args.kwargs.is_empty() { - return Err( - vm.new_type_error("__repr__() takes no arguments (1 given)".to_owned()) - ); + let name = match self { + SlotFunc::Repr(_) => "__repr__", + SlotFunc::Str(_) => "__str__", + _ => unreachable!(), + }; + return Err(vm.new_type_error(format!("{name}() takes no arguments (1 given)"))); } let s = func(&obj, vm)?; Ok(s.into()) diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index 854cb2701ed..65a11d55314 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -425,8 +425,8 @@ impl PyBaseObject { } /// Return str(self). - #[pymethod] - fn __str__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + #[pyslot] + fn slot_str(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { // FIXME: try tp_repr first and fallback to object.__repr__ zelf.repr(vm) } diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 1addf00497d..40f01d98b00 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -165,25 +165,15 @@ pub trait PyClassImpl: PyClassDef { "Initialize self. See help(type(self)) for accurate signature." 
); add_slot_wrapper!(repr, __repr__, Repr, "Return repr(self)."); + add_slot_wrapper!(str, __str__, Str, "Return str(self)."); add_slot_wrapper!(iter, __iter__, Iter, "Implement iter(self)."); add_slot_wrapper!(iternext, __next__, IterNext, "Implement next(self)."); // __hash__ needs special handling: hash_not_implemented sets __hash__ = None - if let Some(hash_func) = class.slots.hash.load() { - if hash_func as usize == hash_not_implemented as usize { - class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); - } else { - let hash_name = identifier!(ctx, __hash__); - if !class.attributes.read().contains_key(hash_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__hash__"), - wrapped: SlotFunc::Hash(hash_func), - doc: Some("Return hash(self)."), - }; - class.set_attr(hash_name, wrapper.into_ref(ctx).into()); - } - } + if class.slots.hash.load().map_or(0, |h| h as usize) == hash_not_implemented as usize { + class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); + } else { + add_slot_wrapper!(hash, __hash__, Hash, "Return hash(self)."); } class.extend_methods(class.slots.methods, ctx); diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 26c059067e0..57ac24f461a 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -140,7 +140,7 @@ pub struct PyTypeSlots { // More standard operations (here for binary compatibility) pub hash: AtomicCell>, pub call: AtomicCell>, - // tp_str + pub str: AtomicCell>, pub repr: AtomicCell>, pub getattro: AtomicCell>, pub setattro: AtomicCell>, From 49dbbbd5b9990038bed3ee740255a56da8678492 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 13:28:50 +0900 Subject: [PATCH 060/418] Fix SSL test_preauth_data_to_tls_server (#6508) --- crates/stdlib/src/ssl/compat.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs 
index 45aa9c4fce9..fa12855e242 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -391,6 +391,8 @@ pub(super) enum SslError { ZeroReturn, /// Unexpected EOF without close_notify (protocol violation) Eof, + /// Non-TLS data received before handshake completed + PreauthData, /// Certificate verification error CertVerification(rustls::CertificateError), /// I/O error @@ -562,6 +564,15 @@ impl SslError { .upcast(), SslError::ZeroReturn => create_ssl_zero_return_error(vm).upcast(), SslError::Eof => create_ssl_eof_error(vm).upcast(), + SslError::PreauthData => { + // Non-TLS data received before handshake + Self::create_ssl_error_with_reason( + vm, + None, + "before TLS handshake with data", + "before TLS handshake with data", + ) + } SslError::CertVerification(cert_err) => { // Use the proper cert verification error creator create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen") @@ -1245,6 +1256,12 @@ pub(super) fn ssl_do_handshake( } } + // InvalidMessage during handshake means non-TLS data was received + // before the handshake completed (e.g., HTTP request to TLS server) + if matches!(e, rustls::Error::InvalidMessage(_)) { + return Err(SslError::PreauthData); + } + // Certificate verification errors are already handled by from_rustls return Err(SslError::from_rustls(e)); From 92acf339a1d6b30ea605fc3179b34d8c6b1fb381 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 13:31:29 +0900 Subject: [PATCH 061/418] fix pyexpat hang (#6507) --- crates/stdlib/src/pyexpat.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/stdlib/src/pyexpat.rs b/crates/stdlib/src/pyexpat.rs index 699fa21852d..e96d6287489 100644 --- a/crates/stdlib/src/pyexpat.rs +++ b/crates/stdlib/src/pyexpat.rs @@ -102,7 +102,9 @@ mod _pyexpat { where T: IntoFuncArgs, { - handler.read().call(args, vm).ok(); + // Clone the handler while holding the read lock, then 
release the lock + let handler = handler.read().clone(); + handler.call(args, vm).ok(); } #[pyclass] From 8443b2c97e75af3ec14b6657cd40429ad0a78d99 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 14:54:44 +0900 Subject: [PATCH 062/418] Add POSIX shared memory module for multiprocessing on Unix (#6498) --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> Co-authored-by: github-actions[bot] --- .cspell.json | 2 + crates/stdlib/src/lib.rs | 6 +++ crates/stdlib/src/posixshmem.rs | 48 ++++++++++++++++++++++ extra_tests/snippets/builtin_posixshmem.py | 12 ++++++ 4 files changed, 68 insertions(+) create mode 100644 crates/stdlib/src/posixshmem.rs create mode 100644 extra_tests/snippets/builtin_posixshmem.py diff --git a/.cspell.json b/.cspell.json index 3bd06fc2032..89cde1ce775 100644 --- a/.cspell.json +++ b/.cspell.json @@ -124,6 +124,8 @@ "wasi", "zelf", // unix + "posixshmem", + "shm", "CLOEXEC", "codeset", "endgrent", diff --git a/crates/stdlib/src/lib.rs b/crates/stdlib/src/lib.rs index c9b5ca32b57..4b463e09c73 100644 --- a/crates/stdlib/src/lib.rs +++ b/crates/stdlib/src/lib.rs @@ -57,6 +57,8 @@ mod faulthandler; mod fcntl; #[cfg(not(target_arch = "wasm32"))] mod multiprocessing; +#[cfg(all(unix, not(target_os = "redox"), not(target_os = "android")))] +mod posixshmem; #[cfg(unix)] mod posixsubprocess; // libc is missing constants on redox @@ -190,6 +192,10 @@ pub fn get_module_inits() -> impl Iterator, StdlibInit { "_posixsubprocess" => posixsubprocess::make_module, } + #[cfg(all(unix, not(target_os = "redox"), not(target_os = "android")))] + { + "_posixshmem" => posixshmem::make_module, + } #[cfg(any(unix, windows))] { "mmap" => mmap::make_module, diff --git a/crates/stdlib/src/posixshmem.rs b/crates/stdlib/src/posixshmem.rs new file mode 100644 index 00000000000..2957f16792c --- /dev/null +++ 
b/crates/stdlib/src/posixshmem.rs @@ -0,0 +1,48 @@ +#[cfg(all(unix, not(target_os = "redox"), not(target_os = "android")))] +pub(crate) use _posixshmem::make_module; + +#[cfg(all(unix, not(target_os = "redox"), not(target_os = "android")))] +#[pymodule] +mod _posixshmem { + use std::ffi::CString; + + use crate::{ + common::os::errno_io_error, + vm::{ + PyResult, VirtualMachine, builtins::PyStrRef, convert::IntoPyException, + function::OptionalArg, + }, + }; + + #[pyfunction] + fn shm_open( + name: PyStrRef, + flags: libc::c_int, + mode: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let name = CString::new(name.as_str()).map_err(|e| e.into_pyexception(vm))?; + let mode: libc::c_uint = mode.unwrap_or(0o600) as _; + #[cfg(target_os = "freebsd")] + let mode = mode.try_into().unwrap(); + // SAFETY: `name` is a NUL-terminated string and `shm_open` does not write through it. + let fd = unsafe { libc::shm_open(name.as_ptr(), flags, mode) }; + if fd == -1 { + Err(errno_io_error().into_pyexception(vm)) + } else { + Ok(fd) + } + } + + #[pyfunction] + fn shm_unlink(name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { + let name = CString::new(name.as_str()).map_err(|e| e.into_pyexception(vm))?; + // SAFETY: `name` is a valid NUL-terminated string and `shm_unlink` only reads it. 
+ let ret = unsafe { libc::shm_unlink(name.as_ptr()) }; + if ret == -1 { + Err(errno_io_error().into_pyexception(vm)) + } else { + Ok(()) + } + } +} diff --git a/extra_tests/snippets/builtin_posixshmem.py b/extra_tests/snippets/builtin_posixshmem.py new file mode 100644 index 00000000000..38ace68d584 --- /dev/null +++ b/extra_tests/snippets/builtin_posixshmem.py @@ -0,0 +1,12 @@ +import os +import sys + +if os.name != "posix": + sys.exit(0) + +import _posixshmem + +name = f"/rp_posixshmem_{os.getpid()}" +fd = _posixshmem.shm_open(name, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0o600) +os.close(fd) +_posixshmem.shm_unlink(name) From 7eb0fe4984b8b5604f74a9319cdadb5ea1bd2489 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 15:49:44 +0900 Subject: [PATCH 063/418] Fix potential deadlock (#6509) --- crates/vm/src/builtins/classmethod.rs | 8 +++--- crates/vm/src/builtins/property.rs | 33 +++++++++++++----------- crates/vm/src/stdlib/functools.rs | 36 +++++++++++++++++++-------- 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/crates/vm/src/builtins/classmethod.rs b/crates/vm/src/builtins/classmethod.rs index 911960bf691..21f1ae4ba10 100644 --- a/crates/vm/src/builtins/classmethod.rs +++ b/crates/vm/src/builtins/classmethod.rs @@ -57,11 +57,11 @@ impl GetDescriptor for PyClassMethod { ) -> PyResult { let (zelf, _obj) = Self::_unwrap(&zelf, obj, vm)?; let cls = cls.unwrap_or_else(|| _obj.class().to_owned().into()); - let call_descr_get: PyResult = zelf.callable.lock().get_attr("__get__", vm); + // Clone and release lock before calling Python code to prevent deadlock + let callable = zelf.callable.lock().clone(); + let call_descr_get: PyResult = callable.get_attr("__get__", vm); match call_descr_get { - Err(_) => Ok(PyBoundMethod::new(cls, zelf.callable.lock().clone()) - .into_ref(&vm.ctx) - .into()), + Err(_) => Ok(PyBoundMethod::new(cls, callable).into_ref(&vm.ctx).into()), Ok(call_descr_get) => 
call_descr_get.call((cls.clone(), cls), vm), } } diff --git a/crates/vm/src/builtins/property.rs b/crates/vm/src/builtins/property.rs index 41b05a60049..7ea36d39768 100644 --- a/crates/vm/src/builtins/property.rs +++ b/crates/vm/src/builtins/property.rs @@ -55,7 +55,8 @@ impl GetDescriptor for PyProperty { let (zelf, obj) = Self::_unwrap(&zelf_obj, obj, vm)?; if vm.is_none(&obj) { Ok(zelf_obj) - } else if let Some(getter) = zelf.getter.read().as_ref() { + } else if let Some(getter) = zelf.getter.read().clone() { + // Clone and release lock before calling Python code to prevent deadlock getter.call((obj,), vm) } else { let error_msg = zelf.format_property_error(&obj, "getter", vm)?; @@ -70,12 +71,12 @@ impl PyProperty { // Returns the name if available, None if not found, or propagates errors fn get_property_name(&self, vm: &VirtualMachine) -> PyResult> { // First check if name was set via __set_name__ - if let Some(name) = self.name.read().as_ref() { - return Ok(Some(name.clone())); + if let Some(name) = self.name.read().clone() { + return Ok(Some(name)); } - let getter = self.getter.read(); - let Some(getter) = getter.as_ref() else { + // Clone and release lock before calling Python code to prevent deadlock + let Some(getter) = self.getter.read().clone() else { return Ok(None); }; @@ -105,7 +106,8 @@ impl PyProperty { let zelf = zelf.try_to_ref::(vm)?; match value { PySetterValue::Assign(value) => { - if let Some(setter) = zelf.setter.read().as_ref() { + // Clone and release lock before calling Python code to prevent deadlock + if let Some(setter) = zelf.setter.read().clone() { setter.call((obj, value), vm).map(drop) } else { let error_msg = zelf.format_property_error(&obj, "setter", vm)?; @@ -113,7 +115,8 @@ impl PyProperty { } } PySetterValue::Delete => { - if let Some(deleter) = zelf.deleter.read().as_ref() { + // Clone and release lock before calling Python code to prevent deadlock + if let Some(deleter) = zelf.deleter.read().clone() { deleter.call((obj,), 
vm).map(drop) } else { let error_msg = zelf.format_property_error(&obj, "deleter", vm)?; @@ -273,23 +276,24 @@ impl PyProperty { } }; + // Clone and release lock before calling Python code to prevent deadlock // Check getter - if let Some(getter) = self.getter.read().as_ref() - && is_abstract(getter)? + if let Some(getter) = self.getter.read().clone() + && is_abstract(&getter)? { return Ok(vm.ctx.new_bool(true).into()); } // Check setter - if let Some(setter) = self.setter.read().as_ref() - && is_abstract(setter)? + if let Some(setter) = self.setter.read().clone() + && is_abstract(&setter)? { return Ok(vm.ctx.new_bool(true).into()); } // Check deleter - if let Some(deleter) = self.deleter.read().as_ref() - && is_abstract(deleter)? + if let Some(deleter) = self.deleter.read().clone() + && is_abstract(&deleter)? { return Ok(vm.ctx.new_bool(true).into()); } @@ -299,7 +303,8 @@ impl PyProperty { #[pygetset(setter)] fn set___isabstractmethod__(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let Some(getter) = self.getter.read().to_owned() { + // Clone and release lock before calling Python code to prevent deadlock + if let Some(getter) = self.getter.read().clone() { getter.set_attr("__isabstractmethod__", value, vm)?; } Ok(()) diff --git a/crates/vm/src/stdlib/functools.rs b/crates/vm/src/stdlib/functools.rs index 26dff8b4426..77f352c78fc 100644 --- a/crates/vm/src/stdlib/functools.rs +++ b/crates/vm/src/stdlib/functools.rs @@ -248,15 +248,24 @@ mod _functools { type Args = FuncArgs; fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - let inner = zelf.inner.read(); - let mut combined_args = inner.args.as_slice().to_vec(); + // Clone and release lock before calling Python code to prevent deadlock + let (func, stored_args, keywords) = { + let inner = zelf.inner.read(); + ( + inner.func.clone(), + inner.args.clone(), + inner.keywords.clone(), + ) + }; + + let mut combined_args = stored_args.as_slice().to_vec(); 
combined_args.extend_from_slice(&args.args); // Merge keywords from self.keywords and args.kwargs let mut final_kwargs = IndexMap::new(); // Add keywords from self.keywords - for (key, value) in &*inner.keywords { + for (key, value) in &*keywords { let key_str = key .downcast::() .map_err(|_| vm.new_type_error("keywords must be strings"))?; @@ -268,9 +277,7 @@ mod _functools { final_kwargs.insert(key, value); } - inner - .func - .call(FuncArgs::new(combined_args, KwArgs::new(final_kwargs)), vm) + func.call(FuncArgs::new(combined_args, KwArgs::new(final_kwargs)), vm) } } @@ -280,15 +287,24 @@ mod _functools { // Check for recursive repr let obj = zelf.as_object(); if let Some(_guard) = ReprGuard::enter(vm, obj) { - let inner = zelf.inner.read(); - let func_repr = inner.func.repr(vm)?; + // Clone and release lock before calling Python code to prevent deadlock + let (func, args, keywords) = { + let inner = zelf.inner.read(); + ( + inner.func.clone(), + inner.args.clone(), + inner.keywords.clone(), + ) + }; + + let func_repr = func.repr(vm)?; let mut parts = vec![func_repr.as_str().to_owned()]; - for arg in inner.args.as_slice() { + for arg in args.as_slice() { parts.push(arg.repr(vm)?.as_str().to_owned()); } - for (key, value) in inner.keywords.clone() { + for (key, value) in &*keywords { // For string keys, use them directly without quotes let key_part = if let Ok(s) = key.clone().downcast::() { s.as_str().to_owned() From aaecdd17470b17994a1a090be036db99fce005e6 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 16:55:19 +0900 Subject: [PATCH 064/418] repr for getset (#6514) --- crates/vm/src/builtins/getset.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/crates/vm/src/builtins/getset.rs b/crates/vm/src/builtins/getset.rs index f56191f5f8b..3fa4667a997 100644 --- a/crates/vm/src/builtins/getset.rs +++ b/crates/vm/src/builtins/getset.rs @@ -7,7 +7,7 @@ use 
crate::{ builtins::type_::PointerSlot, class::PyClassImpl, function::{IntoPyGetterFunc, IntoPySetterFunc, PyGetterFunc, PySetterFunc, PySetterValue}, - types::GetDescriptor, + types::{GetDescriptor, Representable}, }; #[pyclass(module = false, name = "getset_descriptor")] @@ -96,7 +96,7 @@ impl PyGetSet { } } -#[pyclass(flags(DISALLOW_INSTANTIATION), with(GetDescriptor))] +#[pyclass(flags(DISALLOW_INSTANTIATION), with(GetDescriptor, Representable))] impl PyGetSet { // Descriptor methods @@ -153,6 +153,23 @@ impl PyGetSet { } } +impl Representable for PyGetSet { + #[inline] + fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let class = unsafe { zelf.class.borrow_static() }; + // Special case for object type + if std::ptr::eq(class, vm.ctx.types.object_type) { + Ok(format!("", zelf.name)) + } else { + Ok(format!( + "", + zelf.name, + class.name() + )) + } + } +} + pub(crate) fn init(context: &Context) { PyGetSet::extend_class(context, context.types.getset_type); } From 151f0746a3e4130d060879f69731b74585e21f38 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 25 Dec 2025 16:56:41 +0900 Subject: [PATCH 065/418] Implement copyslot (#6505) --- Lib/test/test_typing.py | 1 - crates/stdlib/src/sqlite.rs | 4 +- crates/vm/src/builtins/bool.rs | 2 +- crates/vm/src/builtins/classmethod.rs | 2 +- crates/vm/src/builtins/dict.rs | 4 +- crates/vm/src/builtins/iter.rs | 18 +- crates/vm/src/builtins/list.rs | 6 +- crates/vm/src/builtins/mappingproxy.rs | 10 +- crates/vm/src/builtins/object.rs | 5 +- crates/vm/src/builtins/type.rs | 226 ++++++++++++++++++----- crates/vm/src/builtins/weakproxy.rs | 4 +- crates/vm/src/class.rs | 5 + crates/vm/src/frame.rs | 2 +- crates/vm/src/function/protocol.rs | 40 ++-- crates/vm/src/object/core.rs | 2 +- crates/vm/src/protocol/buffer.rs | 3 +- crates/vm/src/protocol/callable.rs | 2 +- crates/vm/src/protocol/iter.rs | 34 ++-- crates/vm/src/protocol/mapping.rs | 92 +++++---- 
crates/vm/src/protocol/mod.rs | 4 +- crates/vm/src/protocol/number.rs | 219 ++++++++++------------ crates/vm/src/protocol/object.rs | 92 ++++----- crates/vm/src/protocol/sequence.rs | 141 +++++++++----- crates/vm/src/stdlib/builtins.rs | 2 +- crates/vm/src/stdlib/ctypes/structure.rs | 2 +- crates/vm/src/stdlib/ctypes/union.rs | 2 +- crates/vm/src/stdlib/typing.rs | 7 +- crates/vm/src/types/slot.rs | 125 ++++++++----- crates/vm/src/types/structseq.rs | 25 +-- crates/vm/src/types/zoo.rs | 1 + crates/vm/src/vm/method.rs | 8 +- crates/vm/src/vm/vm_object.rs | 4 +- crates/vm/src/vm/vm_ops.rs | 40 ++-- 33 files changed, 648 insertions(+), 486 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index db0dc916f1a..74ab94eb5c3 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -10496,7 +10496,6 @@ class CustomerModel(ModelBase, init=False): class NoDefaultTests(BaseTestCase): - @unittest.expectedFailure # TODO: RUSTPYTHON def test_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): s = pickle.dumps(NoDefault, proto) diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index 2328f23430a..c760c2830d7 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -488,7 +488,7 @@ mod _sqlite { let text2 = vm.ctx.new_str(text2); let val = callable.call((text1, text2), vm)?; - let Some(val) = val.to_number().index(vm) else { + let Some(val) = val.number().index(vm) else { return Ok(0); }; @@ -2980,7 +2980,7 @@ mod _sqlite { fn bind_parameters(self, parameters: &PyObject, vm: &VirtualMachine) -> PyResult<()> { if let Some(dict) = parameters.downcast_ref::() { self.bind_parameters_name(dict, vm) - } else if let Ok(seq) = PySequence::try_protocol(parameters, vm) { + } else if let Ok(seq) = parameters.try_sequence(vm) { self.bind_parameters_sequence(seq, vm) } else { Err(new_programming_error( diff --git a/crates/vm/src/builtins/bool.rs b/crates/vm/src/builtins/bool.rs index 
6b3ddd8241a..8fee4af3834 100644 --- a/crates/vm/src/builtins/bool.rs +++ b/crates/vm/src/builtins/bool.rs @@ -43,7 +43,7 @@ impl PyObjectRef { return Ok(false); } let rs_bool = if let Some(nb_bool) = self.class().slots.as_number.boolean.load() { - nb_bool(self.as_object().to_number(), vm)? + nb_bool(self.as_object().number(), vm)? } else { // TODO: Fully implement AsNumber and remove this block match vm.get_method(self.clone(), identifier!(vm, __bool__)) { diff --git a/crates/vm/src/builtins/classmethod.rs b/crates/vm/src/builtins/classmethod.rs index 21f1ae4ba10..5b7f9218658 100644 --- a/crates/vm/src/builtins/classmethod.rs +++ b/crates/vm/src/builtins/classmethod.rs @@ -124,7 +124,7 @@ impl PyClassMethod { } #[pyclass( - with(GetDescriptor, Constructor, Representable), + with(GetDescriptor, Constructor, Initializer, Representable), flags(BASETYPE, HAS_DICT) )] impl PyClassMethod { diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index aff7432d067..567e18d6419 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -1145,7 +1145,7 @@ impl ViewSetOps for PyDictKeys {} impl PyDictKeys { #[pymethod] fn __contains__(zelf: PyObjectRef, key: PyObjectRef, vm: &VirtualMachine) -> PyResult { - zelf.to_sequence().contains(&key, vm) + zelf.sequence_unchecked().contains(&key, vm) } #[pygetset] @@ -1210,7 +1210,7 @@ impl ViewSetOps for PyDictItems {} impl PyDictItems { #[pymethod] fn __contains__(zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { - zelf.to_sequence().contains(&needle, vm) + zelf.sequence_unchecked().contains(&needle, vm) } #[pygetset] fn mapping(zelf: PyRef) -> PyMappingProxy { diff --git a/crates/vm/src/builtins/iter.rs b/crates/vm/src/builtins/iter.rs index 56dfc14d164..736303e95ee 100644 --- a/crates/vm/src/builtins/iter.rs +++ b/crates/vm/src/builtins/iter.rs @@ -8,7 +8,7 @@ use crate::{ class::PyClassImpl, function::ArgCallable, object::{Traverse, TraverseFn}, - 
protocol::{PyIterReturn, PySequence, PySequenceMethods}, + protocol::PyIterReturn, types::{IterNext, Iterable, SelfIter}, }; use rustpython_common::{ @@ -177,9 +177,6 @@ pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObject { #[pyclass(module = false, name = "iterator", traverse)] #[derive(Debug)] pub struct PySequenceIterator { - // cached sequence methods - #[pytraverse(skip)] - seq_methods: &'static PySequenceMethods, internal: PyMutex>, } @@ -193,9 +190,8 @@ impl PyPayload for PySequenceIterator { #[pyclass(with(IterNext, Iterable))] impl PySequenceIterator { pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let seq = PySequence::try_protocol(&obj, vm)?; + let _seq = obj.try_sequence(vm)?; Ok(Self { - seq_methods: seq.methods, internal: PyMutex::new(PositionIterInternal::new(obj, 0)), }) } @@ -204,10 +200,7 @@ impl PySequenceIterator { fn __length_hint__(&self, vm: &VirtualMachine) -> PyObjectRef { let internal = self.internal.lock(); if let IterStatus::Active(obj) = &internal.status { - let seq = PySequence { - obj, - methods: self.seq_methods, - }; + let seq = obj.sequence_unchecked(); seq.length(vm) .map(|x| PyInt::from(x).into_pyobject(vm)) .unwrap_or_else(|_| vm.ctx.not_implemented()) @@ -231,10 +224,7 @@ impl SelfIter for PySequenceIterator {} impl IterNext for PySequenceIterator { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { zelf.internal.lock().next(|obj, pos| { - let seq = PySequence { - obj, - methods: zelf.seq_methods, - }; + let seq = obj.sequence_unchecked(); PyIterReturn::from_getitem_result(seq.get_item(pos as isize, vm), vm) }) } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 13e8864cd1f..12cab27a750 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -354,7 +354,11 @@ where } else { let iter = obj.to_owned().get_iter(vm)?; let iter = iter.iter::(vm)?; - let len = obj.to_sequence().length_opt(vm).transpose()?.unwrap_or(0); + let len = obj + 
.sequence_unchecked() + .length_opt(vm) + .transpose()? + .unwrap_or(0); let mut v = Vec::with_capacity(len); for x in iter { v.push(f(x?)?); diff --git a/crates/vm/src/builtins/mappingproxy.rs b/crates/vm/src/builtins/mappingproxy.rs index fb8ff5de9cc..475f36cb5a9 100644 --- a/crates/vm/src/builtins/mappingproxy.rs +++ b/crates/vm/src/builtins/mappingproxy.rs @@ -6,7 +6,7 @@ use crate::{ convert::ToPyObject, function::{ArgMapping, OptionalArg, PyComparisonValue}, object::{Traverse, TraverseFn}, - protocol::{PyMapping, PyMappingMethods, PyNumberMethods, PySequenceMethods}, + protocol::{PyMappingMethods, PyNumberMethods, PySequenceMethods}, types::{ AsMapping, AsNumber, AsSequence, Comparable, Constructor, Iterable, PyComparisonOp, Representable, @@ -62,14 +62,12 @@ impl Constructor for PyMappingProxy { type Args = PyObjectRef; fn py_new(_cls: &Py, mapping: Self::Args, vm: &VirtualMachine) -> PyResult { - if let Some(methods) = PyMapping::find_methods(&mapping) + if mapping.mapping_unchecked().check() && !mapping.downcastable::() && !mapping.downcastable::() { return Ok(Self { - mapping: MappingProxyInner::Mapping(ArgMapping::with_methods(mapping, unsafe { - methods.borrow_static() - })), + mapping: MappingProxyInner::Mapping(ArgMapping::new(mapping)), }); } Err(vm.new_type_error(format!( @@ -124,7 +122,7 @@ impl PyMappingProxy { MappingProxyInner::Class(class) => Ok(key .as_interned_str(vm) .is_some_and(|key| class.attributes.read().contains_key(key))), - MappingProxyInner::Mapping(mapping) => mapping.to_sequence().contains(key, vm), + MappingProxyInner::Mapping(mapping) => mapping.sequence_unchecked().contains(key, vm), } } diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index 65a11d55314..ca208790f4b 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -320,10 +320,7 @@ impl PyBaseObject { } } PyComparisonOp::Ne => { - let cmp = zelf - .class() - .mro_find_map(|cls| cls.slots.richcompare.load()) 
- .unwrap(); + let cmp = zelf.class().slots.richcompare.load().unwrap(); let value = match cmp(zelf, other, PyComparisonOp::Eq, vm)? { Either::A(obj) => PyArithmeticValue::from_object(vm, obj) .map(|obj| obj.try_to_bool(vm)) diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 68de17f60b6..c2373f26faf 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -23,7 +23,7 @@ use crate::{ convert::ToPyResult, function::{FuncArgs, KwArgs, OptionalArg, PyMethodDef, PySetterValue}, object::{Traverse, TraverseFn}, - protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods}, + protocol::{PyIterReturn, PyNumberMethods}, types::{ AsNumber, Callable, Constructor, GetAttr, PyTypeFlags, PyTypeSlots, Representable, SetAttr, TypeDataRef, TypeDataRefMut, TypeDataSlot, @@ -64,8 +64,6 @@ pub struct HeapTypeExt { pub name: PyRwLock, pub qualname: PyRwLock, pub slots: Option>>, - pub sequence_methods: PySequenceMethods, - pub mapping_methods: PyMappingMethods, pub type_data: PyRwLock>, } @@ -100,17 +98,6 @@ impl AsRef for PointerSlot { } } -impl PointerSlot { - pub unsafe fn from_heaptype(typ: &PyType, f: F) -> Option - where - F: FnOnce(&HeapTypeExt) -> &T, - { - typ.heaptype_ext - .as_ref() - .map(|ext| Self(NonNull::from(f(ext)))) - } -} - pub type PyTypeRef = PyRef; cfg_if::cfg_if! 
{ @@ -206,8 +193,6 @@ impl PyType { name: PyRwLock::new(name.clone()), qualname: PyRwLock::new(name), slots: None, - sequence_methods: PySequenceMethods::default(), - mapping_methods: PyMappingMethods::default(), type_data: PyRwLock::new(None), }; let base = bases[0].clone(); @@ -331,6 +316,8 @@ impl PyType { slots.basicsize = base.slots.basicsize; } + Self::inherit_readonly_slots(&mut slots, &base); + if let Some(qualname) = attrs.get(identifier!(ctx, __qualname__)) && !qualname.fast_isinstance(ctx.types.str_type) { @@ -387,6 +374,8 @@ impl PyType { slots.basicsize = base.slots.basicsize; } + Self::inherit_readonly_slots(&mut slots, &base); + let bases = PyRwLock::new(vec![base.clone()]); let mro = base.mro_map_collect(|x| x.to_owned()); @@ -404,6 +393,9 @@ impl PyType { None, ); + // Note: inherit_slots is called in PyClassImpl::init_class after + // slots are fully initialized by make_slots() + Self::set_new(&new_type.slots, &new_type.base); let weakref_type = super::PyWeak::static_type(); @@ -420,6 +412,12 @@ impl PyType { } pub(crate) fn init_slots(&self, ctx: &Context) { + // Inherit slots from direct bases (not MRO) + for base in self.bases.read().iter() { + self.inherit_slots(base); + } + + // Wire dunder methods to slots #[allow(clippy::mutable_key_type)] let mut slot_name_set = std::collections::HashSet::new(); @@ -454,6 +452,164 @@ impl PyType { } } + /// Inherit readonly slots from base type at creation time. + /// These slots are not AtomicCell and must be set before the type is used. + fn inherit_readonly_slots(slots: &mut PyTypeSlots, base: &Self) { + if slots.as_buffer.is_none() { + slots.as_buffer = base.slots.as_buffer; + } + } + + /// Inherit slots from base type. typeobject.c: inherit_slots + pub(crate) fn inherit_slots(&self, base: &Self) { + macro_rules! 
copyslot { + ($slot:ident) => { + if self.slots.$slot.load().is_none() { + if let Some(base_val) = base.slots.$slot.load() { + self.slots.$slot.store(Some(base_val)); + } + } + }; + } + + // Core slots + copyslot!(hash); + copyslot!(call); + copyslot!(str); + copyslot!(repr); + copyslot!(getattro); + copyslot!(setattro); + copyslot!(richcompare); + copyslot!(iter); + copyslot!(iternext); + copyslot!(descr_get); + copyslot!(descr_set); + // Note: init is NOT inherited here because object_init has special + // handling in CPython (checks if type->tp_init != object_init). + // TODO: implement proper init inheritance with object_init check + copyslot!(del); + // new is handled by set_new() + // as_buffer is inherited at type creation time (not AtomicCell) + + // Sub-slots (number, sequence, mapping) + self.inherit_number_slots(base); + self.inherit_sequence_slots(base); + self.inherit_mapping_slots(base); + } + + /// Inherit number sub-slots from base type + fn inherit_number_slots(&self, base: &Self) { + macro_rules! 
copy_num_slot { + ($slot:ident) => { + if self.slots.as_number.$slot.load().is_none() { + if let Some(base_val) = base.slots.as_number.$slot.load() { + self.slots.as_number.$slot.store(Some(base_val)); + } + } + }; + } + + // Binary operations + copy_num_slot!(add); + copy_num_slot!(right_add); + copy_num_slot!(inplace_add); + copy_num_slot!(subtract); + copy_num_slot!(right_subtract); + copy_num_slot!(inplace_subtract); + copy_num_slot!(multiply); + copy_num_slot!(right_multiply); + copy_num_slot!(inplace_multiply); + copy_num_slot!(remainder); + copy_num_slot!(right_remainder); + copy_num_slot!(inplace_remainder); + copy_num_slot!(divmod); + copy_num_slot!(right_divmod); + copy_num_slot!(power); + copy_num_slot!(right_power); + copy_num_slot!(inplace_power); + + // Bitwise operations + copy_num_slot!(lshift); + copy_num_slot!(right_lshift); + copy_num_slot!(inplace_lshift); + copy_num_slot!(rshift); + copy_num_slot!(right_rshift); + copy_num_slot!(inplace_rshift); + copy_num_slot!(and); + copy_num_slot!(right_and); + copy_num_slot!(inplace_and); + copy_num_slot!(xor); + copy_num_slot!(right_xor); + copy_num_slot!(inplace_xor); + copy_num_slot!(or); + copy_num_slot!(right_or); + copy_num_slot!(inplace_or); + + // Division operations + copy_num_slot!(floor_divide); + copy_num_slot!(right_floor_divide); + copy_num_slot!(inplace_floor_divide); + copy_num_slot!(true_divide); + copy_num_slot!(right_true_divide); + copy_num_slot!(inplace_true_divide); + + // Matrix multiplication + copy_num_slot!(matrix_multiply); + copy_num_slot!(right_matrix_multiply); + copy_num_slot!(inplace_matrix_multiply); + + // Unary operations + copy_num_slot!(negative); + copy_num_slot!(positive); + copy_num_slot!(absolute); + copy_num_slot!(boolean); + copy_num_slot!(invert); + + // Conversion + copy_num_slot!(int); + copy_num_slot!(float); + copy_num_slot!(index); + } + + /// Inherit sequence sub-slots from base type + fn inherit_sequence_slots(&self, base: &Self) { + macro_rules! 
copy_seq_slot { + ($slot:ident) => { + if self.slots.as_sequence.$slot.load().is_none() { + if let Some(base_val) = base.slots.as_sequence.$slot.load() { + self.slots.as_sequence.$slot.store(Some(base_val)); + } + } + }; + } + + copy_seq_slot!(length); + copy_seq_slot!(concat); + copy_seq_slot!(repeat); + copy_seq_slot!(item); + copy_seq_slot!(ass_item); + copy_seq_slot!(contains); + copy_seq_slot!(inplace_concat); + copy_seq_slot!(inplace_repeat); + } + + /// Inherit mapping sub-slots from base type + fn inherit_mapping_slots(&self, base: &Self) { + macro_rules! copy_map_slot { + ($slot:ident) => { + if self.slots.as_mapping.$slot.load().is_none() { + if let Some(base_val) = base.slots.as_mapping.$slot.load() { + self.slots.as_mapping.$slot.store(Some(base_val)); + } + } + }; + } + + copy_map_slot!(length); + copy_map_slot!(subscript); + copy_map_slot!(ass_subscript); + } + // This is used for class initialization where the vm is not yet available. pub fn set_str_attr>( &self, @@ -663,19 +819,6 @@ impl Py { .collect() } - pub(crate) fn mro_find_map(&self, f: F) -> Option - where - F: Fn(&Self) -> Option, - { - // the hot path will be primitive types which usually hit the result from itself. 
- // try std::intrinsics::likely once it is stabilized - if let Some(r) = f(self) { - Some(r) - } else { - self.mro.read().iter().find_map(|cls| f(cls)) - } - } - pub fn iter_base_chain(&self) -> impl Iterator { std::iter::successors(Some(self), |cls| cls.base.as_deref()) } @@ -1228,8 +1371,6 @@ impl Constructor for PyType { name: PyRwLock::new(name), qualname: PyRwLock::new(qualname), slots: heaptype_slots.clone(), - sequence_methods: PySequenceMethods::default(), - mapping_methods: PyMappingMethods::default(), type_data: PyRwLock::new(None), }; (slots, heaptype_ext) @@ -1414,11 +1555,9 @@ impl GetAttr for PyType { if let Some(ref attr) = mcl_attr { let attr_class = attr.class(); - let has_descr_set = attr_class - .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some(); + let has_descr_set = attr_class.slots.descr_set.load().is_some(); if has_descr_set { - let descr_get = attr_class.mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = attr_class.slots.descr_get.load(); if let Some(descr_get) = descr_get { let mcl = mcl.to_owned().into(); return descr_get(attr.clone(), Some(zelf.to_owned().into()), Some(mcl), vm); @@ -1429,7 +1568,7 @@ impl GetAttr for PyType { let zelf_attr = zelf.get_attr(name); if let Some(attr) = zelf_attr { - let descr_get = attr.class().mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = attr.class().slots.descr_get.load(); if let Some(descr_get) = descr_get { descr_get(attr, None, Some(zelf.to_owned().into()), vm) } else { @@ -1467,9 +1606,7 @@ impl Py { // CPython returns None if __doc__ is not in the type's own dict if let Some(doc_attr) = self.get_direct_attr(vm.ctx.intern_str("__doc__")) { // If it's a descriptor, call its __get__ method - let descr_get = doc_attr - .class() - .mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = doc_attr.class().slots.descr_get.load(); if let Some(descr_get) = descr_get { descr_get(doc_attr, None, Some(self.to_owned().into()), vm) } else { @@ -1545,7 +1682,7 
@@ impl SetAttr for PyType { // TODO: pass PyRefExact instead of &str let attr_name = vm.ctx.intern_str(attr_name.as_str()); if let Some(attr) = zelf.get_class_attr(attr_name) { - let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { return descriptor(&attr, zelf.to_owned().into(), value, vm); } @@ -1702,7 +1839,9 @@ fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) - // Call the descriptor's tp_descr_set let descr_set = descr .class() - .mro_find_map(|cls| cls.slots.descr_set.load()) + .slots + .descr_set + .load() .ok_or_else(|| raise_dict_descriptor_error(&obj, vm))?; descr_set(&descr, obj, PySetterValue::Assign(value), vm) } else { @@ -1764,8 +1903,9 @@ pub(crate) fn call_slot_new( } let slot_new = typ - .deref() - .mro_find_map(|cls| cls.slots.new.load()) + .slots + .new + .load() .expect("Should be able to find a new slot somewhere in the mro"); slot_new(subtype, args, vm) } diff --git a/crates/vm/src/builtins/weakproxy.rs b/crates/vm/src/builtins/weakproxy.rs index 6e0e8308dbc..94c54b5459e 100644 --- a/crates/vm/src/builtins/weakproxy.rs +++ b/crates/vm/src/builtins/weakproxy.rs @@ -104,7 +104,9 @@ impl PyWeakProxy { } #[pymethod] fn __contains__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.try_upgrade(vm)?.to_sequence().contains(&needle, vm) + self.try_upgrade(vm)? 
+ .sequence_unchecked() + .contains(&needle, vm) } fn getitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 40f01d98b00..da860a96289 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -176,6 +176,11 @@ pub trait PyClassImpl: PyClassDef { add_slot_wrapper!(hash, __hash__, Hash, "Return hash(self)."); } + // Inherit slots from base types after slots are fully initialized + for base in class.bases.read().iter() { + class.inherit_slots(base); + } + class.extend_methods(class.slots.methods, ctx); } diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index ad50f972aef..d9cc404b47a 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -1900,7 +1900,7 @@ impl ExecutingFrame<'_> { // TODO: It was PyMethod before #4873. Check if it's correct. let func = if is_method { - if let Some(descr_get) = func.class().mro_find_map(|cls| cls.slots.descr_get.load()) { + if let Some(descr_get) = func.class().slots.descr_get.load() { let cls = target.class().to_owned().into(); descr_get(func, Some(target), Some(cls), vm)? 
} else { diff --git a/crates/vm/src/function/protocol.rs b/crates/vm/src/function/protocol.rs index 1e670b96389..a87ef339edd 100644 --- a/crates/vm/src/function/protocol.rs +++ b/crates/vm/src/function/protocol.rs @@ -1,11 +1,11 @@ use super::IntoFuncArgs; use crate::{ AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, - builtins::{PyDict, PyDictRef, iter::PySequenceIterator}, + builtins::{PyDictRef, iter::PySequenceIterator}, convert::ToPyObject, object::{Traverse, TraverseFn}, - protocol::{PyIter, PyIterIter, PyMapping, PyMappingMethods}, - types::{AsMapping, GenericMethod}, + protocol::{PyIter, PyIterIter, PyMapping}, + types::GenericMethod, }; use std::{borrow::Borrow, marker::PhantomData, ops::Deref}; @@ -104,14 +104,11 @@ where T: TryFromObject, { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let iter_fn = { - let cls = obj.class(); - let iter_fn = cls.mro_find_map(|x| x.slots.iter.load()); - if iter_fn.is_none() && !cls.has_attr(identifier!(vm, __getitem__)) { - return Err(vm.new_type_error(format!("'{}' object is not iterable", cls.name()))); - } - iter_fn - }; + let cls = obj.class(); + let iter_fn = cls.slots.iter.load(); + if iter_fn.is_none() && !cls.has_attr(identifier!(vm, __getitem__)) { + return Err(vm.new_type_error(format!("'{}' object is not iterable", cls.name()))); + } Ok(Self { iterable: obj, iter_fn, @@ -123,30 +120,22 @@ where #[derive(Debug, Clone, Traverse)] pub struct ArgMapping { obj: PyObjectRef, - #[pytraverse(skip)] - methods: &'static PyMappingMethods, } impl ArgMapping { #[inline] - pub const fn with_methods(obj: PyObjectRef, methods: &'static PyMappingMethods) -> Self { - Self { obj, methods } + pub const fn new(obj: PyObjectRef) -> Self { + Self { obj } } #[inline(always)] pub fn from_dict_exact(dict: PyDictRef) -> Self { - Self { - obj: dict.into(), - methods: PyDict::as_mapping(), - } + Self { obj: dict.into() } } #[inline(always)] pub fn mapping(&self) -> 
PyMapping<'_> { - PyMapping { - obj: &self.obj, - methods: self.methods, - } + self.obj.mapping_unchecked() } } @@ -188,9 +177,8 @@ impl ToPyObject for ArgMapping { impl TryFromObject for ArgMapping { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let mapping = PyMapping::try_protocol(&obj, vm)?; - let methods = mapping.methods; - Ok(Self { obj, methods }) + let _mapping = obj.try_mapping(vm)?; + Ok(Self { obj }) } } diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index 60b623ef3ed..a092d7097be 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -810,7 +810,7 @@ impl PyObject { } // CPython-compatible drop implementation - let del = self.class().mro_find_map(|cls| cls.slots.del.load()); + let del = self.class().slots.del.load(); if let Some(slot_del) = del { call_slot_del(self, slot_del)?; } diff --git a/crates/vm/src/protocol/buffer.rs b/crates/vm/src/protocol/buffer.rs index 0a34af59080..88524a9a9ee 100644 --- a/crates/vm/src/protocol/buffer.rs +++ b/crates/vm/src/protocol/buffer.rs @@ -143,8 +143,7 @@ impl PyBuffer { impl<'a> TryFromBorrowedObject<'a> for PyBuffer { fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult { let cls = obj.class(); - let as_buffer = cls.mro_find_map(|cls| cls.slots.as_buffer); - if let Some(f) = as_buffer { + if let Some(f) = cls.slots.as_buffer { return f(obj, vm); } Err(vm.new_type_error(format!( diff --git a/crates/vm/src/protocol/callable.rs b/crates/vm/src/protocol/callable.rs index 1444b6bf73a..5280e04e928 100644 --- a/crates/vm/src/protocol/callable.rs +++ b/crates/vm/src/protocol/callable.rs @@ -42,7 +42,7 @@ pub struct PyCallable<'a> { impl<'a> PyCallable<'a> { pub fn new(obj: &'a PyObject) -> Option { - let call = obj.class().mro_find_map(|cls| cls.slots.call.load())?; + let call = obj.class().slots.call.load()?; Some(PyCallable { obj, call }) } diff --git a/crates/vm/src/protocol/iter.rs 
b/crates/vm/src/protocol/iter.rs index 18f2b5243e2..f6146543de9 100644 --- a/crates/vm/src/protocol/iter.rs +++ b/crates/vm/src/protocol/iter.rs @@ -23,9 +23,7 @@ unsafe impl> Traverse for PyIter { impl PyIter { pub fn check(obj: &PyObject) -> bool { - obj.class() - .mro_find_map(|x| x.slots.iternext.load()) - .is_some() + obj.class().slots.iternext.load().is_some() } } @@ -37,18 +35,19 @@ where Self(obj) } pub fn next(&self, vm: &VirtualMachine) -> PyResult { - let iternext = { - self.0 - .borrow() - .class() - .mro_find_map(|x| x.slots.iternext.load()) - .ok_or_else(|| { - vm.new_type_error(format!( - "'{}' object is not an iterator", - self.0.borrow().class().name() - )) - })? - }; + let iternext = self + .0 + .borrow() + .class() + .slots + .iternext + .load() + .ok_or_else(|| { + vm.new_type_error(format!( + "'{}' object is not an iterator", + self.0.borrow().class().name() + )) + })?; iternext(self.0.borrow(), vm) } @@ -126,10 +125,7 @@ impl TryFromObject for PyIter { // in the vm when a for loop is entered. Next, it is used when the builtin // function 'iter' is called. 
fn try_from_object(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult { - let get_iter = { - let cls = iter_target.class(); - cls.mro_find_map(|x| x.slots.iter.load()) - }; + let get_iter = iter_target.class().slots.iter.load(); if let Some(get_iter) = get_iter { let iter = get_iter(iter_target, vm)?; if Self::check(&iter) { diff --git a/crates/vm/src/protocol/mapping.rs b/crates/vm/src/protocol/mapping.rs index a942303dbb4..36813bf1df8 100644 --- a/crates/vm/src/protocol/mapping.rs +++ b/crates/vm/src/protocol/mapping.rs @@ -3,7 +3,6 @@ use crate::{ builtins::{ PyDict, PyStrInterned, dict::{PyDictItems, PyDictKeys, PyDictValues}, - type_::PointerSlot, }, convert::ToPyResult, object::{Traverse, TraverseFn}, @@ -13,9 +12,38 @@ use crossbeam_utils::atomic::AtomicCell; // Mapping protocol // https://docs.python.org/3/c-api/mapping.html -impl PyObject { - pub fn to_mapping(&self) -> PyMapping<'_> { - PyMapping::from(self) +#[allow(clippy::type_complexity)] +#[derive(Default)] +pub struct PyMappingSlots { + pub length: AtomicCell, &VirtualMachine) -> PyResult>>, + pub subscript: AtomicCell, &PyObject, &VirtualMachine) -> PyResult>>, + pub ass_subscript: AtomicCell< + Option, &PyObject, Option, &VirtualMachine) -> PyResult<()>>, + >, +} + +impl std::fmt::Debug for PyMappingSlots { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PyMappingSlots") + } +} + +impl PyMappingSlots { + pub fn has_subscript(&self) -> bool { + self.subscript.load().is_some() + } + + /// Copy from static PyMappingMethods + pub fn copy_from(&self, methods: &PyMappingMethods) { + if let Some(f) = methods.length.load() { + self.length.store(Some(f)); + } + if let Some(f) = methods.subscript.load() { + self.subscript.store(Some(f)); + } + if let Some(f) = methods.ass_subscript.load() { + self.ass_subscript.store(Some(f)); + } } } @@ -31,15 +59,11 @@ pub struct PyMappingMethods { impl std::fmt::Debug for PyMappingMethods { fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "mapping methods") + f.write_str("PyMappingMethods") } } impl PyMappingMethods { - fn check(&self) -> bool { - self.subscript.load().is_some() - } - #[allow(clippy::declare_interior_mutable_const)] pub const NOT_IMPLEMENTED: Self = Self { length: AtomicCell::new(None), @@ -48,19 +72,24 @@ impl PyMappingMethods { }; } -impl<'a> From<&'a PyObject> for PyMapping<'a> { - fn from(obj: &'a PyObject) -> Self { - static GLOBAL_NOT_IMPLEMENTED: PyMappingMethods = PyMappingMethods::NOT_IMPLEMENTED; - let methods = Self::find_methods(obj) - .map_or(&GLOBAL_NOT_IMPLEMENTED, |x| unsafe { x.borrow_static() }); - Self { obj, methods } +impl PyObject { + pub fn mapping_unchecked(&self) -> PyMapping<'_> { + PyMapping { obj: self } + } + + pub fn try_mapping(&self, vm: &VirtualMachine) -> PyResult> { + let mapping = self.mapping_unchecked(); + if mapping.check() { + Ok(mapping) + } else { + Err(vm.new_type_error(format!("{} is not a mapping object", self.class()))) + } } } #[derive(Copy, Clone)] pub struct PyMapping<'a> { pub obj: &'a PyObject, - pub methods: &'static PyMappingMethods, } unsafe impl Traverse for PyMapping<'_> { @@ -76,34 +105,19 @@ impl AsRef for PyMapping<'_> { } } -impl<'a> PyMapping<'a> { - pub fn try_protocol(obj: &'a PyObject, vm: &VirtualMachine) -> PyResult { - if let Some(methods) = Self::find_methods(obj) - && methods.as_ref().check() - { - return Ok(Self { - obj, - methods: unsafe { methods.borrow_static() }, - }); - } - - Err(vm.new_type_error(format!("{} is not a mapping object", obj.class()))) - } -} - impl PyMapping<'_> { - // PyMapping::Check #[inline] - pub fn check(obj: &PyObject) -> bool { - Self::find_methods(obj).is_some_and(|x| x.as_ref().check()) + pub fn slots(&self) -> &PyMappingSlots { + &self.obj.class().slots.as_mapping } - pub fn find_methods(obj: &PyObject) -> Option> { - obj.class().mro_find_map(|cls| cls.slots.as_mapping.load()) + #[inline] + pub fn check(&self) -> bool { + 
self.slots().has_subscript() } pub fn length_opt(self, vm: &VirtualMachine) -> Option> { - self.methods.length.load().map(|f| f(self, vm)) + self.slots().length.load().map(|f| f(self, vm)) } pub fn length(self, vm: &VirtualMachine) -> PyResult { @@ -130,7 +144,7 @@ impl PyMapping<'_> { fn _subscript(self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { let f = - self.methods.subscript.load().ok_or_else(|| { + self.slots().subscript.load().ok_or_else(|| { vm.new_type_error(format!("{} is not a mapping", self.obj.class())) })?; f(self, needle, vm) @@ -142,7 +156,7 @@ impl PyMapping<'_> { value: Option, vm: &VirtualMachine, ) -> PyResult<()> { - let f = self.methods.ass_subscript.load().ok_or_else(|| { + let f = self.slots().ass_subscript.load().ok_or_else(|| { vm.new_type_error(format!( "'{}' object does not support item assignment", self.obj.class() diff --git a/crates/vm/src/protocol/mod.rs b/crates/vm/src/protocol/mod.rs index d5c7e239a24..e7be286c265 100644 --- a/crates/vm/src/protocol/mod.rs +++ b/crates/vm/src/protocol/mod.rs @@ -9,9 +9,9 @@ mod sequence; pub use buffer::{BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, VecBuffer}; pub use callable::PyCallable; pub use iter::{PyIter, PyIterIter, PyIterReturn}; -pub use mapping::{PyMapping, PyMappingMethods}; +pub use mapping::{PyMapping, PyMappingMethods, PyMappingSlots}; pub use number::{ PyNumber, PyNumberBinaryFunc, PyNumberBinaryOp, PyNumberMethods, PyNumberSlots, PyNumberTernaryOp, PyNumberUnaryFunc, handle_bytes_to_int_err, }; -pub use sequence::{PySequence, PySequenceMethods}; +pub use sequence::{PySequence, PySequenceMethods, PySequenceSlots}; diff --git a/crates/vm/src/protocol/number.rs b/crates/vm/src/protocol/number.rs index 1242ee52795..4f21a1e64f9 100644 --- a/crates/vm/src/protocol/number.rs +++ b/crates/vm/src/protocol/number.rs @@ -20,8 +20,8 @@ pub type PyNumberTernaryFunc = fn(&PyObject, &PyObject, &PyObject, &VirtualMachi impl PyObject { #[inline] - pub const fn 
to_number(&self) -> PyNumber<'_> { - PyNumber(self) + pub const fn number(&self) -> PyNumber<'_> { + PyNumber { obj: self } } pub fn try_index_opt(&self, vm: &VirtualMachine) -> Option> { @@ -30,7 +30,7 @@ impl PyObject { } else if let Some(i) = self.downcast_ref::() { Some(Ok(vm.ctx.new_bigint(i.as_bigint()))) } else { - self.to_number().index(vm) + self.number().index(vm) } } @@ -56,7 +56,7 @@ impl PyObject { if let Some(i) = self.downcast_ref_if_exact::(vm) { Ok(i.to_owned()) - } else if let Some(i) = self.to_number().int(vm).or_else(|| self.try_index_opt(vm)) { + } else if let Some(i) = self.number().int(vm).or_else(|| self.try_index_opt(vm)) { i } else if let Ok(Some(f)) = vm.get_special_method(self, identifier!(vm, __trunc__)) { warnings::warn( @@ -92,7 +92,7 @@ impl PyObject { pub fn try_float_opt(&self, vm: &VirtualMachine) -> Option>> { if let Some(float) = self.downcast_ref_if_exact::(vm) { Some(Ok(float.to_owned())) - } else if let Some(f) = self.to_number().float(vm) { + } else if let Some(f) = self.number().float(vm) { Some(f) } else { self.try_index_opt(vm) @@ -420,11 +420,13 @@ impl PyNumberSlots { } } #[derive(Copy, Clone)] -pub struct PyNumber<'a>(&'a PyObject); +pub struct PyNumber<'a> { + pub obj: &'a PyObject, +} unsafe impl Traverse for PyNumber<'_> { fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { - self.0.traverse(tracer_fn) + self.obj.traverse(tracer_fn) } } @@ -432,36 +434,17 @@ impl Deref for PyNumber<'_> { type Target = PyObject; fn deref(&self) -> &Self::Target { - self.0 + self.obj } } impl<'a> PyNumber<'a> { - pub(crate) const fn obj(self) -> &'a PyObject { - self.0 - } - - // PyNumber_Check + // PyNumber_Check - slots are now inherited pub fn check(obj: &PyObject) -> bool { - let cls = &obj.class(); - // TODO: when we finally have a proper slot inheritance, mro_find_map can be removed - // methods.int.load().is_some() - // || methods.index.load().is_some() - // || methods.float.load().is_some() - // || obj.downcastable::() - let 
has_number = cls - .mro_find_map(|x| { - let methods = &x.slots.as_number; - if methods.int.load().is_some() - || methods.index.load().is_some() - || methods.float.load().is_some() - { - Some(()) - } else { - None - } - }) - .is_some(); + let methods = &obj.class().slots.as_number; + let has_number = methods.int.load().is_some() + || methods.index.load().is_some() + || methods.float.load().is_some(); has_number || obj.downcastable::() } } @@ -469,114 +452,106 @@ impl<'a> PyNumber<'a> { impl PyNumber<'_> { // PyIndex_Check pub fn is_index(self) -> bool { - self.class() - .mro_find_map(|x| x.slots.as_number.index.load()) - .is_some() + self.class().slots.as_number.index.load().is_some() } #[inline] pub fn int(self, vm: &VirtualMachine) -> Option> { - self.class() - .mro_find_map(|x| x.slots.as_number.int.load()) - .map(|f| { - let ret = f(self, vm)?; - - if let Some(ret) = ret.downcast_ref_if_exact::(vm) { - return Ok(ret.to_owned()); - } - - let ret_class = ret.class().to_owned(); - if let Some(ret) = ret.downcast_ref::() { - warnings::warn( - vm.ctx.exceptions.deprecation_warning, - format!( - "__int__ returned non-int (type {ret_class}). \ + self.class().slots.as_number.int.load().map(|f| { + let ret = f(self, vm)?; + + if let Some(ret) = ret.downcast_ref_if_exact::(vm) { + return Ok(ret.to_owned()); + } + + let ret_class = ret.class().to_owned(); + if let Some(ret) = ret.downcast_ref::() { + warnings::warn( + vm.ctx.exceptions.deprecation_warning, + format!( + "__int__ returned non-int (type {ret_class}). \ The ability to return an instance of a strict subclass of int \ is deprecated, and may be removed in a future version of Python." 
- ), - 1, - vm, - )?; - - Ok(ret.to_owned()) - } else { - Err(vm.new_type_error(format!( - "{}.__int__ returned non-int(type {})", - self.class(), - ret_class - ))) - } - }) + ), + 1, + vm, + )?; + + Ok(ret.to_owned()) + } else { + Err(vm.new_type_error(format!( + "{}.__int__ returned non-int(type {})", + self.class(), + ret_class + ))) + } + }) } #[inline] pub fn index(self, vm: &VirtualMachine) -> Option> { - self.class() - .mro_find_map(|x| x.slots.as_number.index.load()) - .map(|f| { - let ret = f(self, vm)?; - - if let Some(ret) = ret.downcast_ref_if_exact::(vm) { - return Ok(ret.to_owned()); - } - - let ret_class = ret.class().to_owned(); - if let Some(ret) = ret.downcast_ref::() { - warnings::warn( - vm.ctx.exceptions.deprecation_warning, - format!( - "__index__ returned non-int (type {ret_class}). \ + self.class().slots.as_number.index.load().map(|f| { + let ret = f(self, vm)?; + + if let Some(ret) = ret.downcast_ref_if_exact::(vm) { + return Ok(ret.to_owned()); + } + + let ret_class = ret.class().to_owned(); + if let Some(ret) = ret.downcast_ref::() { + warnings::warn( + vm.ctx.exceptions.deprecation_warning, + format!( + "__index__ returned non-int (type {ret_class}). \ The ability to return an instance of a strict subclass of int \ is deprecated, and may be removed in a future version of Python." 
- ), - 1, - vm, - )?; - - Ok(ret.to_owned()) - } else { - Err(vm.new_type_error(format!( - "{}.__index__ returned non-int(type {})", - self.class(), - ret_class - ))) - } - }) + ), + 1, + vm, + )?; + + Ok(ret.to_owned()) + } else { + Err(vm.new_type_error(format!( + "{}.__index__ returned non-int(type {})", + self.class(), + ret_class + ))) + } + }) } #[inline] pub fn float(self, vm: &VirtualMachine) -> Option>> { - self.class() - .mro_find_map(|x| x.slots.as_number.float.load()) - .map(|f| { - let ret = f(self, vm)?; - - if let Some(ret) = ret.downcast_ref_if_exact::(vm) { - return Ok(ret.to_owned()); - } - - let ret_class = ret.class().to_owned(); - if let Some(ret) = ret.downcast_ref::() { - warnings::warn( - vm.ctx.exceptions.deprecation_warning, - format!( - "__float__ returned non-float (type {ret_class}). \ + self.class().slots.as_number.float.load().map(|f| { + let ret = f(self, vm)?; + + if let Some(ret) = ret.downcast_ref_if_exact::(vm) { + return Ok(ret.to_owned()); + } + + let ret_class = ret.class().to_owned(); + if let Some(ret) = ret.downcast_ref::() { + warnings::warn( + vm.ctx.exceptions.deprecation_warning, + format!( + "__float__ returned non-float (type {ret_class}). \ The ability to return an instance of a strict subclass of float \ is deprecated, and may be removed in a future version of Python." 
- ), - 1, - vm, - )?; - - Ok(ret.to_owned()) - } else { - Err(vm.new_type_error(format!( - "{}.__float__ returned non-float(type {})", - self.class(), - ret_class - ))) - } - }) + ), + 1, + vm, + )?; + + Ok(ret.to_owned()) + } else { + Err(vm.new_type_error(format!( + "{}.__float__ returned non-float(type {})", + self.class(), + ret_class + ))) + } + }) } } diff --git a/crates/vm/src/protocol/object.rs b/crates/vm/src/protocol/object.rs index f2e52a94004..ec1a6f55969 100644 --- a/crates/vm/src/protocol/object.rs +++ b/crates/vm/src/protocol/object.rs @@ -12,7 +12,7 @@ use crate::{ dict_inner::DictKey, function::{Either, FuncArgs, PyArithmeticValue, PySetterValue}, object::PyPayload, - protocol::{PyIter, PyMapping, PySequence}, + protocol::PyIter, types::{Constructor, PyComparisonOp}, }; @@ -136,10 +136,7 @@ impl PyObject { #[inline] pub(crate) fn get_attr_inner(&self, attr_name: &Py, vm: &VirtualMachine) -> PyResult { vm_trace!("object.__getattribute__: {:?} {:?}", self, attr_name); - let getattro = self - .class() - .mro_find_map(|cls| cls.slots.getattro.load()) - .unwrap(); + let getattro = self.class().slots.getattro.load().unwrap(); getattro(self, attr_name, vm).inspect_err(|exc| { vm.set_attribute_error_context(exc, self.to_owned(), attr_name.to_owned()); }) @@ -153,21 +150,20 @@ impl PyObject { ) -> PyResult<()> { let setattro = { let cls = self.class(); - cls.mro_find_map(|cls| cls.slots.setattro.load()) - .ok_or_else(|| { - let has_getattr = cls.mro_find_map(|cls| cls.slots.getattro.load()).is_some(); - vm.new_type_error(format!( - "'{}' object has {} attributes ({} {})", - cls.name(), - if has_getattr { "only read-only" } else { "no" }, - if attr_value.is_assign() { - "assign to" - } else { - "del" - }, - attr_name - )) - })? 
+ cls.slots.setattro.load().ok_or_else(|| { + let has_getattr = cls.slots.getattro.load().is_some(); + vm.new_type_error(format!( + "'{}' object has {} attributes ({} {})", + cls.name(), + if has_getattr { "only read-only" } else { "no" }, + if attr_value.is_assign() { + "assign to" + } else { + "del" + }, + attr_name + )) + })? }; setattro(self, attr_name, attr_value, vm) } @@ -197,7 +193,7 @@ impl PyObject { .interned_str(attr_name) .and_then(|attr_name| self.get_class_attr(attr_name)) { - let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { return descriptor(&attr, self.to_owned(), value, vm); } @@ -239,11 +235,9 @@ impl PyObject { let cls_attr = match cls_attr_name.and_then(|name| obj_cls.get_attr(name)) { Some(descr) => { let descr_cls = descr.class(); - let descr_get = descr_cls.mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = descr_cls.slots.descr_get.load(); if let Some(descr_get) = descr_get - && descr_cls - .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some() + && descr_cls.slots.descr_set.load().is_some() { let cls = obj_cls.to_owned().into(); return descr_get(descr, Some(self.to_owned()), Some(cls), vm).map(Some); @@ -293,10 +287,7 @@ impl PyObject { ) -> PyResult> { let swapped = op.swapped(); let call_cmp = |obj: &Self, other: &Self, op| { - let cmp = obj - .class() - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); + let cmp = obj.class().slots.richcompare.load().unwrap(); let r = match cmp(obj, other, op, vm)? 
{ Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A), Either::B(arithmetic) => arithmetic.map(Either::B), @@ -353,18 +344,15 @@ impl PyObject { pub fn repr(&self, vm: &VirtualMachine) -> PyResult> { vm.with_recursion("while getting the repr of an object", || { - // TODO: RustPython does not implement type slots inheritance yet - self.class() - .mro_find_map(|cls| cls.slots.repr.load()) - .map_or_else( - || { - Err(vm.new_runtime_error(format!( + self.class().slots.repr.load().map_or_else( + || { + Err(vm.new_runtime_error(format!( "BUG: object of type '{}' has no __repr__ method. This is a bug in RustPython.", self.class().name() ))) - }, - |repr| repr(self, vm), - ) + }, + |repr| repr(self, vm), + ) }) } @@ -659,7 +647,7 @@ impl PyObject { } pub fn hash(&self, vm: &VirtualMachine) -> PyResult { - if let Some(hash) = self.class().mro_find_map(|cls| cls.slots.hash.load()) { + if let Some(hash) = self.class().slots.hash.load() { return hash(self, vm); } @@ -681,9 +669,9 @@ impl PyObject { } pub fn length_opt(&self, vm: &VirtualMachine) -> Option> { - self.to_sequence() + self.sequence_unchecked() .length_opt(vm) - .or_else(|| self.to_mapping().length_opt(vm)) + .or_else(|| self.mapping_unchecked().length_opt(vm)) } pub fn length(&self, vm: &VirtualMachine) -> PyResult { @@ -702,9 +690,9 @@ impl PyObject { let needle = needle.to_pyobject(vm); - if let Ok(mapping) = PyMapping::try_protocol(self, vm) { + if let Ok(mapping) = self.try_mapping(vm) { mapping.subscript(&needle, vm) - } else if let Ok(seq) = PySequence::try_protocol(self, vm) { + } else if let Ok(seq) = self.try_sequence(vm) { let i = needle.key_as_isize(vm)?; seq.get_item(i, vm) } else { @@ -734,14 +722,14 @@ impl PyObject { return dict.set_item(needle, value, vm); } - let mapping = self.to_mapping(); - if let Some(f) = mapping.methods.ass_subscript.load() { + let mapping = self.mapping_unchecked(); + if let Some(f) = mapping.slots().ass_subscript.load() { let needle = 
needle.to_pyobject(vm); return f(mapping, &needle, Some(value), vm); } - let seq = self.to_sequence(); - if let Some(f) = seq.methods.ass_item.load() { + let seq = self.sequence_unchecked(); + if let Some(f) = seq.slots().ass_item.load() { let i = needle.key_as_isize(vm)?; return f(seq, i, Some(value), vm); } @@ -757,13 +745,13 @@ impl PyObject { return dict.del_item(needle, vm); } - let mapping = self.to_mapping(); - if let Some(f) = mapping.methods.ass_subscript.load() { + let mapping = self.mapping_unchecked(); + if let Some(f) = mapping.slots().ass_subscript.load() { let needle = needle.to_pyobject(vm); return f(mapping, &needle, None, vm); } - let seq = self.to_sequence(); - if let Some(f) = seq.methods.ass_item.load() { + let seq = self.sequence_unchecked(); + if let Some(f) = seq.slots().ass_item.load() { let i = needle.key_as_isize(vm)?; return f(seq, i, None, vm); } @@ -781,7 +769,7 @@ impl PyObject { let res = obj_cls.lookup_ref(attr, vm)?; // If it's a descriptor, call its __get__ method - let descr_get = res.class().mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = res.class().slots.descr_get.load(); if let Some(descr_get) = descr_get { let obj_cls = obj_cls.to_owned().into(); // CPython ignores exceptions in _PyObject_LookupSpecial and returns NULL diff --git a/crates/vm/src/protocol/sequence.rs b/crates/vm/src/protocol/sequence.rs index fb71446a5a4..a7576e63efb 100644 --- a/crates/vm/src/protocol/sequence.rs +++ b/crates/vm/src/protocol/sequence.rs @@ -1,26 +1,70 @@ use crate::{ PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyList, PyListRef, PySlice, PyTuple, PyTupleRef, type_::PointerSlot}, + builtins::{PyList, PyListRef, PySlice, PyTuple, PyTupleRef}, convert::ToPyObject, function::PyArithmeticValue, object::{Traverse, TraverseFn}, - protocol::{PyMapping, PyNumberBinaryOp}, + protocol::PyNumberBinaryOp, }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; -use std::fmt::Debug; // Sequence 
Protocol // https://docs.python.org/3/c-api/sequence.html -impl PyObject { - #[inline] - pub fn to_sequence(&self) -> PySequence<'_> { - static GLOBAL_NOT_IMPLEMENTED: PySequenceMethods = PySequenceMethods::NOT_IMPLEMENTED; - PySequence { - obj: self, - methods: PySequence::find_methods(self) - .map_or(&GLOBAL_NOT_IMPLEMENTED, |x| unsafe { x.borrow_static() }), +#[allow(clippy::type_complexity)] +#[derive(Default)] +pub struct PySequenceSlots { + pub length: AtomicCell, &VirtualMachine) -> PyResult>>, + pub concat: AtomicCell, &PyObject, &VirtualMachine) -> PyResult>>, + pub repeat: AtomicCell, isize, &VirtualMachine) -> PyResult>>, + pub item: AtomicCell, isize, &VirtualMachine) -> PyResult>>, + pub ass_item: AtomicCell< + Option, isize, Option, &VirtualMachine) -> PyResult<()>>, + >, + pub contains: + AtomicCell, &PyObject, &VirtualMachine) -> PyResult>>, + pub inplace_concat: + AtomicCell, &PyObject, &VirtualMachine) -> PyResult>>, + pub inplace_repeat: AtomicCell, isize, &VirtualMachine) -> PyResult>>, +} + +impl std::fmt::Debug for PySequenceSlots { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PySequenceSlots") + } +} + +impl PySequenceSlots { + pub fn has_item(&self) -> bool { + self.item.load().is_some() + } + + /// Copy from static PySequenceMethods + pub fn copy_from(&self, methods: &PySequenceMethods) { + if let Some(f) = methods.length.load() { + self.length.store(Some(f)); + } + if let Some(f) = methods.concat.load() { + self.concat.store(Some(f)); + } + if let Some(f) = methods.repeat.load() { + self.repeat.store(Some(f)); + } + if let Some(f) = methods.item.load() { + self.item.store(Some(f)); + } + if let Some(f) = methods.ass_item.load() { + self.ass_item.store(Some(f)); + } + if let Some(f) = methods.contains.load() { + self.contains.store(Some(f)); + } + if let Some(f) = methods.inplace_concat.load() { + self.inplace_concat.store(Some(f)); + } + if let Some(f) = methods.inplace_repeat.load() { + 
self.inplace_repeat.store(Some(f)); } } } @@ -42,9 +86,9 @@ pub struct PySequenceMethods { pub inplace_repeat: AtomicCell, isize, &VirtualMachine) -> PyResult>>, } -impl Debug for PySequenceMethods { +impl std::fmt::Debug for PySequenceMethods { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Sequence Methods") + f.write_str("PySequenceMethods") } } @@ -62,10 +106,25 @@ impl PySequenceMethods { }; } +impl PyObject { + #[inline] + pub fn sequence_unchecked(&self) -> PySequence<'_> { + PySequence { obj: self } + } + + pub fn try_sequence(&self, vm: &VirtualMachine) -> PyResult> { + let seq = self.sequence_unchecked(); + if seq.check() { + Ok(seq) + } else { + Err(vm.new_type_error(format!("'{}' is not a sequence", self.class()))) + } + } +} + #[derive(Copy, Clone)] pub struct PySequence<'a> { pub obj: &'a PyObject, - pub methods: &'static PySequenceMethods, } unsafe impl Traverse for PySequence<'_> { @@ -74,34 +133,18 @@ unsafe impl Traverse for PySequence<'_> { } } -impl<'a> PySequence<'a> { +impl PySequence<'_> { #[inline] - pub const fn with_methods(obj: &'a PyObject, methods: &'static PySequenceMethods) -> Self { - Self { obj, methods } + pub fn slots(&self) -> &PySequenceSlots { + &self.obj.class().slots.as_sequence } - pub fn try_protocol(obj: &'a PyObject, vm: &VirtualMachine) -> PyResult { - let seq = obj.to_sequence(); - if seq.check() { - Ok(seq) - } else { - Err(vm.new_type_error(format!("'{}' is not a sequence", obj.class()))) - } - } -} - -impl PySequence<'_> { pub fn check(&self) -> bool { - self.methods.item.load().is_some() - } - - pub fn find_methods(obj: &PyObject) -> Option> { - let cls = obj.class(); - cls.mro_find_map(|x| x.slots.as_sequence.load()) + self.slots().has_item() } pub fn length_opt(self, vm: &VirtualMachine) -> Option> { - self.methods.length.load().map(|f| f(self, vm)) + self.slots().length.load().map(|f| f(self, vm)) } pub fn length(self, vm: &VirtualMachine) -> PyResult { @@ -114,12 +157,12 @@ 
impl PySequence<'_> { } pub fn concat(self, other: &PyObject, vm: &VirtualMachine) -> PyResult { - if let Some(f) = self.methods.concat.load() { + if let Some(f) = self.slots().concat.load() { return f(self, other, vm); } // if both arguments appear to be sequences, try fallback to __add__ - if self.check() && other.to_sequence().check() { + if self.check() && other.sequence_unchecked().check() { let ret = vm.binary_op1(self.obj, other, PyNumberBinaryOp::Add)?; if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) { return Ok(ret); @@ -133,7 +176,7 @@ impl PySequence<'_> { } pub fn repeat(self, n: isize, vm: &VirtualMachine) -> PyResult { - if let Some(f) = self.methods.repeat.load() { + if let Some(f) = self.slots().repeat.load() { return f(self, n, vm); } @@ -149,15 +192,15 @@ impl PySequence<'_> { } pub fn inplace_concat(self, other: &PyObject, vm: &VirtualMachine) -> PyResult { - if let Some(f) = self.methods.inplace_concat.load() { + if let Some(f) = self.slots().inplace_concat.load() { return f(self, other, vm); } - if let Some(f) = self.methods.concat.load() { + if let Some(f) = self.slots().concat.load() { return f(self, other, vm); } // if both arguments appear to be sequences, try fallback to __iadd__ - if self.check() && other.to_sequence().check() { + if self.check() && other.sequence_unchecked().check() { let ret = vm._iadd(self.obj, other)?; if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) { return Ok(ret); @@ -171,10 +214,10 @@ impl PySequence<'_> { } pub fn inplace_repeat(self, n: isize, vm: &VirtualMachine) -> PyResult { - if let Some(f) = self.methods.inplace_repeat.load() { + if let Some(f) = self.slots().inplace_repeat.load() { return f(self, n, vm); } - if let Some(f) = self.methods.repeat.load() { + if let Some(f) = self.slots().repeat.load() { return f(self, n, vm); } @@ -189,7 +232,7 @@ impl PySequence<'_> { } pub fn get_item(self, i: isize, vm: &VirtualMachine) -> PyResult { 
- if let Some(f) = self.methods.item.load() { + if let Some(f) = self.slots().item.load() { return f(self, i, vm); } Err(vm.new_type_error(format!( @@ -199,7 +242,7 @@ impl PySequence<'_> { } fn _ass_item(self, i: isize, value: Option, vm: &VirtualMachine) -> PyResult<()> { - if let Some(f) = self.methods.ass_item.load() { + if let Some(f) = self.slots().ass_item.load() { return f(self, i, value, vm); } Err(vm.new_type_error(format!( @@ -222,7 +265,7 @@ impl PySequence<'_> { } pub fn get_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult { - if let Ok(mapping) = PyMapping::try_protocol(self.obj, vm) { + if let Ok(mapping) = self.obj.try_mapping(vm) { let slice = PySlice { start: Some(start.to_pyobject(vm)), stop: stop.to_pyobject(vm), @@ -241,8 +284,8 @@ impl PySequence<'_> { value: Option, vm: &VirtualMachine, ) -> PyResult<()> { - let mapping = self.obj.to_mapping(); - if let Some(f) = mapping.methods.ass_subscript.load() { + let mapping = self.obj.mapping_unchecked(); + if let Some(f) = mapping.slots().ass_subscript.load() { let slice = PySlice { start: Some(start.to_pyobject(vm)), stop: stop.to_pyobject(vm), @@ -355,7 +398,7 @@ impl PySequence<'_> { } pub fn contains(self, target: &PyObject, vm: &VirtualMachine) -> PyResult { - if let Some(f) = self.methods.contains.load() { + if let Some(f) = self.slots().contains.load() { return f(self, target, vm); } diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 72d2c724159..19f85af55da 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -268,7 +268,7 @@ mod builtins { if !globals.fast_isinstance(vm.ctx.types.dict_type) { return Err(match func_name { "eval" => { - let is_mapping = crate::protocol::PyMapping::check(globals); + let is_mapping = globals.mapping_unchecked().check(); vm.new_type_error(if is_mapping { "globals must be a real dict; try eval(expr, {}, mapping)" .to_owned() diff --git 
a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/ctypes/structure.rs index 1ca428669d4..d5aca392c52 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ b/crates/vm/src/stdlib/ctypes/structure.rs @@ -467,7 +467,7 @@ impl SetAttr for PyCStructType { // Check for data descriptor first if let Some(attr) = pytype.get_class_attr(attr_name_interned) { - let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { return descriptor(&attr, pytype.to_owned().into(), value, vm); } diff --git a/crates/vm/src/stdlib/ctypes/union.rs b/crates/vm/src/stdlib/ctypes/union.rs index 500aa8e6244..41bc7492a25 100644 --- a/crates/vm/src/stdlib/ctypes/union.rs +++ b/crates/vm/src/stdlib/ctypes/union.rs @@ -369,7 +369,7 @@ impl SetAttr for PyCUnionType { // 1. First, do PyType's setattro (PyType_Type.tp_setattro first) // Check for data descriptor first if let Some(attr) = pytype.get_class_attr(attr_name_interned) { - let descr_set = attr.class().mro_find_map(|cls| cls.slots.descr_set.load()); + let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { descriptor(&attr, pytype.to_owned().into(), value.clone(), vm)?; // After successful setattro, check if _fields_ and call process_fields diff --git a/crates/vm/src/stdlib/typing.rs b/crates/vm/src/stdlib/typing.rs index afa8bd6eb90..469a75b010f 100644 --- a/crates/vm/src/stdlib/typing.rs +++ b/crates/vm/src/stdlib/typing.rs @@ -1,5 +1,5 @@ // spell-checker:ignore typevarobject funcobj -use crate::{PyPayload, PyRef, VirtualMachine, class::PyClassImpl, stdlib::PyModule}; +use crate::{Context, PyPayload, PyRef, VirtualMachine, class::PyClassImpl, stdlib::PyModule}; pub use crate::stdlib::typevar::{ Generic, ParamSpec, ParamSpecArgs, ParamSpecKwargs, TypeVar, TypeVarTuple, @@ -7,6 +7,11 @@ pub use crate::stdlib::typevar::{ }; pub use decl::*; +/// Initialize typing types (call 
extend_class) +pub fn init(ctx: &Context) { + NoDefault::extend_class(ctx, ctx.types.typing_no_default_type); +} + pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let module = decl::make_module(vm); TypeVar::make_class(&vm.ctx); diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 57ac24f461a..bfe6e047622 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -3,7 +3,7 @@ use crate::common::lock::{ }; use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyTypeRef, type_::PointerSlot}, + builtins::{PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyTypeRef}, bytecode::ComparisonOperator, common::hash::PyHash, convert::ToPyObject, @@ -11,8 +11,8 @@ use crate::{ Either, FromArgs, FuncArgs, OptionalArg, PyComparisonValue, PyMethodDef, PySetterValue, }, protocol::{ - PyBuffer, PyIterReturn, PyMapping, PyMappingMethods, PyNumber, PyNumberMethods, - PyNumberSlots, PySequence, PySequenceMethods, + PyBuffer, PyIterReturn, PyMapping, PyMappingMethods, PyMappingSlots, PyNumber, + PyNumberMethods, PyNumberSlots, PySequence, PySequenceMethods, PySequenceSlots, }, vm::Context, }; @@ -134,8 +134,8 @@ pub struct PyTypeSlots { // Method suites for standard classes pub as_number: PyNumberSlots, - pub as_sequence: AtomicCell>>, - pub as_mapping: AtomicCell>>, + pub as_sequence: PySequenceSlots, + pub as_mapping: PyMappingSlots, // More standard operations (here for binary compatibility) pub hash: AtomicCell>, @@ -358,6 +358,16 @@ fn repr_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult> }) } +fn str_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(zelf, identifier!(vm, __str__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "__str__ returned non-string (type {})", + obj.class() + )) + }) +} + fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> 
PyResult { let hash_obj = vm.call_special_method(zelf, identifier!(vm, __hash__), ())?; let py_int = hash_obj @@ -426,6 +436,18 @@ fn iter_wrapper(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { vm.call_special_method(&zelf, identifier!(vm, __iter__), ()) } +fn bool_wrapper(num: PyNumber<'_>, vm: &VirtualMachine) -> PyResult { + let result = vm.call_special_method(num.obj, identifier!(vm, __bool__), ())?; + // __bool__ must return exactly bool, not int subclass + if !result.class().is(vm.ctx.types.bool_type) { + return Err(vm.new_type_error(format!( + "__bool__ should return bool, returned {}", + result.class().name() + ))); + } + Ok(crate::builtins::bool_::get_value(&result)) +} + // PyObject_SelfIter in CPython const fn self_iter(zelf: PyObjectRef, _vm: &VirtualMachine) -> PyResult { Ok(zelf) @@ -491,17 +513,36 @@ impl PyType { macro_rules! toggle_slot { ($name:ident, $func:expr) => {{ - self.slots.$name.store(if ADD { Some($func) } else { None }); + if ADD { + self.slots.$name.store(Some($func)); + } else { + // When deleting, re-inherit from MRO (skip self) + let inherited = self + .mro + .read() + .iter() + .skip(1) + .find_map(|cls| cls.slots.$name.load()); + self.slots.$name.store(inherited); + } }}; } macro_rules! toggle_sub_slot { - ($group:ident, $name:ident, $func:expr) => { - self.slots - .$group - .$name - .store(if ADD { Some($func) } else { None }); - }; + ($group:ident, $name:ident, $func:expr) => {{ + if ADD { + self.slots.$group.$name.store(Some($func)); + } else { + // When deleting, re-inherit from MRO (skip self) + let inherited = self + .mro + .read() + .iter() + .skip(1) + .find_map(|cls| cls.slots.$group.$name.load()); + self.slots.$group.$name.store(inherited); + } + }}; } macro_rules! update_slot { @@ -510,66 +551,41 @@ impl PyType { }}; } - macro_rules! 
update_pointer_slot { - ($name:ident, $pointed:ident) => {{ - self.slots - .$name - .store(unsafe { PointerSlot::from_heaptype(self, |ext| &ext.$pointed) }); - }}; - } - - macro_rules! toggle_ext_func { - ($n1:ident, $n2:ident, $func:expr) => {{ - self.heaptype_ext.as_ref().unwrap().$n1.$n2.store(if ADD { - Some($func) - } else { - None - }); - }}; - } - match name { _ if name == identifier!(ctx, __len__) => { - // update_slot!(as_mapping, slot_as_mapping); - toggle_ext_func!(sequence_methods, length, |seq, vm| len_wrapper(seq.obj, vm)); - update_pointer_slot!(as_sequence, sequence_methods); - toggle_ext_func!(mapping_methods, length, |mapping, vm| len_wrapper( + toggle_sub_slot!(as_sequence, length, |seq, vm| len_wrapper(seq.obj, vm)); + toggle_sub_slot!(as_mapping, length, |mapping, vm| len_wrapper( mapping.obj, vm )); - update_pointer_slot!(as_mapping, mapping_methods); } _ if name == identifier!(ctx, __getitem__) => { - // update_slot!(as_mapping, slot_as_mapping); - toggle_ext_func!(sequence_methods, item, |seq, i, vm| getitem_wrapper( + toggle_sub_slot!(as_sequence, item, |seq, i, vm| getitem_wrapper( seq.obj, i, vm )); - update_pointer_slot!(as_sequence, sequence_methods); - toggle_ext_func!(mapping_methods, subscript, |mapping, key, vm| { + toggle_sub_slot!(as_mapping, subscript, |mapping, key, vm| { getitem_wrapper(mapping.obj, key, vm) }); - update_pointer_slot!(as_mapping, mapping_methods); } _ if name == identifier!(ctx, __setitem__) || name == identifier!(ctx, __delitem__) => { - // update_slot!(as_mapping, slot_as_mapping); - toggle_ext_func!(sequence_methods, ass_item, |seq, i, value, vm| { + toggle_sub_slot!(as_sequence, ass_item, |seq, i, value, vm| { setitem_wrapper(seq.obj, i, value, vm) }); - update_pointer_slot!(as_sequence, sequence_methods); - toggle_ext_func!(mapping_methods, ass_subscript, |mapping, key, value, vm| { + toggle_sub_slot!(as_mapping, ass_subscript, |mapping, key, value, vm| { setitem_wrapper(mapping.obj, key, value, vm) }); - 
update_pointer_slot!(as_mapping, mapping_methods); } _ if name == identifier!(ctx, __contains__) => { - toggle_ext_func!(sequence_methods, contains, |seq, needle, vm| { + toggle_sub_slot!(as_sequence, contains, |seq, needle, vm| { contains_wrapper(seq.obj, needle, vm) }); - update_pointer_slot!(as_sequence, sequence_methods); } _ if name == identifier!(ctx, __repr__) => { update_slot!(repr, repr_wrapper); } + _ if name == identifier!(ctx, __str__) => { + update_slot!(str, str_wrapper); + } _ if name == identifier!(ctx, __hash__) => { let is_unhashable = self .attributes @@ -624,6 +640,9 @@ impl PyType { _ if name == identifier!(ctx, __del__) => { toggle_slot!(del, del_wrapper); } + _ if name == identifier!(ctx, __bool__) => { + toggle_sub_slot!(as_number, boolean, bool_wrapper); + } _ if name == identifier!(ctx, __int__) => { toggle_sub_slot!(as_number, int, number_unary_op_wrapper!(__int__)); } @@ -1380,24 +1399,30 @@ pub trait AsBuffer: PyPayload { #[pyclass] pub trait AsMapping: PyPayload { - #[pyslot] fn as_mapping() -> &'static PyMappingMethods; #[inline] fn mapping_downcast(mapping: PyMapping<'_>) -> &Py { unsafe { mapping.obj.downcast_unchecked_ref() } } + + fn extend_slots(slots: &mut PyTypeSlots) { + slots.as_mapping.copy_from(Self::as_mapping()); + } } #[pyclass] pub trait AsSequence: PyPayload { - #[pyslot] fn as_sequence() -> &'static PySequenceMethods; #[inline] fn sequence_downcast(seq: PySequence<'_>) -> &Py { unsafe { seq.obj.downcast_unchecked_ref() } } + + fn extend_slots(slots: &mut PyTypeSlots) { + slots.as_sequence.copy_from(Self::as_sequence()); + } } #[pyclass] @@ -1412,7 +1437,7 @@ pub trait AsNumber: PyPayload { #[inline] fn number_downcast(num: PyNumber<'_>) -> &Py { - unsafe { num.obj().downcast_unchecked_ref() } + unsafe { num.obj.downcast_unchecked_ref() } } #[inline] diff --git a/crates/vm/src/types/structseq.rs b/crates/vm/src/types/structseq.rs index 2b6a2530b02..27315749e06 100644 --- a/crates/vm/src/types/structseq.rs +++ 
b/crates/vm/src/types/structseq.rs @@ -1,8 +1,6 @@ use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, - builtins::{ - PyBaseExceptionRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, type_::PointerSlot, - }, + builtins::{PyBaseExceptionRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef}, class::{PyClassImpl, StaticType}, function::{Either, PyComparisonValue}, iter::PyExactSizeIterator, @@ -87,7 +85,10 @@ static STRUCT_SEQUENCE_AS_SEQUENCE: LazyLock = let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect(); let visible_tuple = PyTuple::new_ref(visible, &vm.ctx); // Use tuple's concat implementation - visible_tuple.as_object().to_sequence().concat(other, vm) + visible_tuple + .as_object() + .sequence_unchecked() + .concat(other, vm) }), repeat: atomic_func!(|seq, n, vm| { // Convert to visible-only tuple, then use regular tuple repeat @@ -96,7 +97,7 @@ static STRUCT_SEQUENCE_AS_SEQUENCE: LazyLock = let visible: Vec<_> = tuple.iter().take(n_seq).cloned().collect(); let visible_tuple = PyTuple::new_ref(visible, &vm.ctx); // Use tuple's repeat implementation - visible_tuple.as_object().to_sequence().repeat(n, vm) + visible_tuple.as_object().sequence_unchecked().repeat(n, vm) }), item: atomic_func!(|seq, i, vm| { let n_seq = get_visible_len(seq.obj, vm)?; @@ -306,12 +307,14 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { ); // Override as_sequence and as_mapping slots to use visible length - class.slots.as_sequence.store(Some(PointerSlot::from( - &*STRUCT_SEQUENCE_AS_SEQUENCE as &'static PySequenceMethods, - ))); - class.slots.as_mapping.store(Some(PointerSlot::from( - &*STRUCT_SEQUENCE_AS_MAPPING as &'static PyMappingMethods, - ))); + class + .slots + .as_sequence + .copy_from(&STRUCT_SEQUENCE_AS_SEQUENCE); + class + .slots + .as_mapping + .copy_from(&STRUCT_SEQUENCE_AS_MAPPING); // Override iter slot to return only visible elements 
class.slots.iter.store(Some(struct_sequence_iter)); diff --git a/crates/vm/src/types/zoo.rs b/crates/vm/src/types/zoo.rs index dd4631bc767..0cd04a0ac17 100644 --- a/crates/vm/src/types/zoo.rs +++ b/crates/vm/src/types/zoo.rs @@ -242,5 +242,6 @@ impl TypeZoo { genericalias::init(context); union_::init(context); descriptor::init(context); + crate::stdlib::typing::init(context); } } diff --git a/crates/vm/src/vm/method.rs b/crates/vm/src/vm/method.rs index 5df01c556ea..ba323391488 100644 --- a/crates/vm/src/vm/method.rs +++ b/crates/vm/src/vm/method.rs @@ -21,7 +21,7 @@ pub enum PyMethod { impl PyMethod { pub fn get(obj: PyObjectRef, name: &Py, vm: &VirtualMachine) -> PyResult { let cls = obj.class(); - let getattro = cls.mro_find_map(|cls| cls.slots.getattro.load()).unwrap(); + let getattro = cls.slots.getattro.load().unwrap(); if getattro as usize != PyBaseObject::getattro as usize { return obj.get_attr(name, vm).map(Self::Attribute); } @@ -41,11 +41,9 @@ impl PyMethod { is_method = true; None } else { - let descr_get = descr_cls.mro_find_map(|cls| cls.slots.descr_get.load()); + let descr_get = descr_cls.slots.descr_get.load(); if let Some(descr_get) = descr_get - && descr_cls - .mro_find_map(|cls| cls.slots.descr_set.load()) - .is_some() + && descr_cls.slots.descr_set.load().is_some() { let cls = cls.to_owned().into(); return descr_get(descr, Some(obj), Some(cls), vm).map(Self::Attribute); diff --git a/crates/vm/src/vm/vm_object.rs b/crates/vm/src/vm/vm_object.rs index e69301820d6..0d5b286148c 100644 --- a/crates/vm/src/vm/vm_object.rs +++ b/crates/vm/src/vm/vm_object.rs @@ -96,9 +96,7 @@ impl VirtualMachine { obj: Option, cls: Option, ) -> Option { - let descr_get = descr - .class() - .mro_find_map(|cls| cls.slots.descr_get.load())?; + let descr_get = descr.class().slots.descr_get.load()?; Some(descr_get(descr.to_owned(), obj, cls, self)) } diff --git a/crates/vm/src/vm/vm_ops.rs b/crates/vm/src/vm/vm_ops.rs index e30e19981a9..635fa10e630 100644 --- 
a/crates/vm/src/vm/vm_ops.rs +++ b/crates/vm/src/vm/vm_ops.rs @@ -4,7 +4,7 @@ use crate::{ PyRef, builtins::{PyInt, PyStr, PyStrRef, PyUtf8Str}, object::{AsObject, PyObject, PyObjectRef, PyResult}, - protocol::{PyNumberBinaryOp, PyNumberTernaryOp, PySequence}, + protocol::{PyNumberBinaryOp, PyNumberTernaryOp}, types::PyComparisonOp, }; use num_traits::ToPrimitive; @@ -160,12 +160,12 @@ impl VirtualMachine { let class_a = a.class(); let class_b = b.class(); - // Look up number slots across MRO for inheritance - let slot_a = class_a.mro_find_map(|x| x.slots.as_number.left_binary_op(op_slot)); + // Number slots are inherited, direct access is O(1) + let slot_a = class_a.slots.as_number.left_binary_op(op_slot); let mut slot_b = None; if !class_a.is(class_b) { - let slot_bb = class_b.mro_find_map(|x| x.slots.as_number.right_binary_op(op_slot)); + let slot_bb = class_b.slots.as_number.right_binary_op(op_slot); if slot_bb.map(|x| x as usize) != slot_a.map(|x| x as usize) { slot_b = slot_bb; } @@ -231,10 +231,7 @@ impl VirtualMachine { iop_slot: PyNumberBinaryOp, op_slot: PyNumberBinaryOp, ) -> PyResult { - if let Some(slot) = a - .class() - .mro_find_map(|x| x.slots.as_number.left_binary_op(iop_slot)) - { + if let Some(slot) = a.class().slots.as_number.left_binary_op(iop_slot) { let x = slot(a, b, self)?; if !x.is(&self.ctx.not_implemented) { return Ok(x); @@ -270,12 +267,12 @@ impl VirtualMachine { let class_b = b.class(); let class_c = c.class(); - // Look up number slots across MRO for inheritance - let slot_a = class_a.mro_find_map(|x| x.slots.as_number.left_ternary_op(op_slot)); + // Number slots are inherited, direct access is O(1) + let slot_a = class_a.slots.as_number.left_ternary_op(op_slot); let mut slot_b = None; if !class_a.is(class_b) { - let slot_bb = class_b.mro_find_map(|x| x.slots.as_number.right_ternary_op(op_slot)); + let slot_bb = class_b.slots.as_number.right_ternary_op(op_slot); if slot_bb.map(|x| x as usize) != slot_a.map(|x| x as usize) { slot_b = 
slot_bb; } @@ -304,7 +301,7 @@ impl VirtualMachine { } } - if let Some(slot_c) = class_c.mro_find_map(|x| x.slots.as_number.left_ternary_op(op_slot)) + if let Some(slot_c) = class_c.slots.as_number.left_ternary_op(op_slot) && slot_a.is_some_and(|slot_a| !std::ptr::fn_addr_eq(slot_a, slot_c)) && slot_b.is_some_and(|slot_b| !std::ptr::fn_addr_eq(slot_b, slot_c)) { @@ -343,10 +340,7 @@ impl VirtualMachine { op_slot: PyNumberTernaryOp, op_str: &str, ) -> PyResult { - if let Some(slot) = a - .class() - .mro_find_map(|x| x.slots.as_number.left_ternary_op(iop_slot)) - { + if let Some(slot) = a.class().slots.as_number.left_ternary_op(iop_slot) { let x = slot(a, b, c, self)?; if !x.is(&self.ctx.not_implemented) { return Ok(x); @@ -386,7 +380,7 @@ impl VirtualMachine { if !result.is(&self.ctx.not_implemented) { return Ok(result); } - if let Ok(seq_a) = PySequence::try_protocol(a, self) { + if let Ok(seq_a) = a.try_sequence(self) { let result = seq_a.concat(b, self)?; if !result.is(&self.ctx.not_implemented) { return Ok(result); @@ -400,7 +394,7 @@ impl VirtualMachine { if !result.is(&self.ctx.not_implemented) { return Ok(result); } - if let Ok(seq_a) = PySequence::try_protocol(a, self) { + if let Ok(seq_a) = a.try_sequence(self) { let result = seq_a.inplace_concat(b, self)?; if !result.is(&self.ctx.not_implemented) { return Ok(result); @@ -414,14 +408,14 @@ impl VirtualMachine { if !result.is(&self.ctx.not_implemented) { return Ok(result); } - if let Ok(seq_a) = PySequence::try_protocol(a, self) { + if let Ok(seq_a) = a.try_sequence(self) { let n = b .try_index(self)? .as_bigint() .to_isize() .ok_or_else(|| self.new_overflow_error("repeated bytes are too long"))?; return seq_a.repeat(n, self); - } else if let Ok(seq_b) = PySequence::try_protocol(b, self) { + } else if let Ok(seq_b) = b.try_sequence(self) { let n = a .try_index(self)? 
.as_bigint() @@ -442,14 +436,14 @@ impl VirtualMachine { if !result.is(&self.ctx.not_implemented) { return Ok(result); } - if let Ok(seq_a) = PySequence::try_protocol(a, self) { + if let Ok(seq_a) = a.try_sequence(self) { let n = b .try_index(self)? .as_bigint() .to_isize() .ok_or_else(|| self.new_overflow_error("repeated bytes are too long"))?; return seq_a.inplace_repeat(n, self); - } else if let Ok(seq_b) = PySequence::try_protocol(b, self) { + } else if let Ok(seq_b) = b.try_sequence(self) { let n = a .try_index(self)? .as_bigint() @@ -530,7 +524,7 @@ impl VirtualMachine { } pub fn _contains(&self, haystack: &PyObject, needle: &PyObject) -> PyResult { - let seq = haystack.to_sequence(); + let seq = haystack.sequence_unchecked(); seq.contains(needle, self) } } From a5a1173c3f2aac724d0c48ce592670d221b1aa30 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 18:27:19 +0900 Subject: [PATCH 066/418] Disallow rebinding __module__ on immutable builtins and add regression snippet (#6513) * Initial plan * Fix __module__ setter on immutable builtin types Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> * Remove unused value parameter from immutability helper Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> * Auto-format: cargo fmt --all --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> Co-authored-by: github-actions[bot] --- crates/vm/src/builtins/type.rs | 11 ++++++----- extra_tests/snippets/builtin_type.py | 9 +++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index c2373f26faf..ffaad324687 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -1050,10 +1050,12 @@ impl PyType { } #[pygetset(setter)] - fn set___module__(&self, value: PyObjectRef, 
vm: &VirtualMachine) { + fn set___module__(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + self.check_set_special_type_attr(identifier!(vm, __module__), vm)?; self.attributes .write() .insert(identifier!(vm, __module__), value); + Ok(()) } #[pyclassmethod] @@ -1103,7 +1105,6 @@ impl PyType { fn check_set_special_type_attr( &self, - _value: &PyObject, name: &PyStrInterned, vm: &VirtualMachine, ) -> PyResult<()> { @@ -1119,7 +1120,7 @@ impl PyType { #[pygetset(setter)] fn set___name__(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.check_set_special_type_attr(&value, identifier!(vm, __name__), vm)?; + self.check_set_special_type_attr(identifier!(vm, __name__), vm)?; let name = value.downcast::().map_err(|value| { vm.new_type_error(format!( "can only assign string to {}.__name__, not '{}'", @@ -1172,7 +1173,7 @@ impl PyType { match value { PySetterValue::Assign(ref val) => { let key = identifier!(vm, __type_params__); - self.check_set_special_type_attr(val.as_ref(), key, vm)?; + self.check_set_special_type_attr(key, vm)?; let mut attrs = self.attributes.write(); attrs.insert(key, val.clone().into()); } @@ -1628,7 +1629,7 @@ impl Py { })?; // Check if we can set this special type attribute - self.check_set_special_type_attr(&value, identifier!(vm, __doc__), vm)?; + self.check_set_special_type_attr(identifier!(vm, __doc__), vm)?; // Set the __doc__ in the type's dict self.attributes diff --git a/extra_tests/snippets/builtin_type.py b/extra_tests/snippets/builtin_type.py index abb68f812be..7a8e4840e13 100644 --- a/extra_tests/snippets/builtin_type.py +++ b/extra_tests/snippets/builtin_type.py @@ -72,6 +72,15 @@ assert object.__qualname__ == "object" assert int.__qualname__ == "int" +with assert_raises(TypeError): + type.__module__ = "nope" + +with assert_raises(TypeError): + object.__module__ = "nope" + +with assert_raises(TypeError): + map.__module__ = "nope" + class A(type): pass From 61ddd98b893072d9b067a953318d93a22a68d3ca 
Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 25 Dec 2025 19:08:23 +0900 Subject: [PATCH 067/418] Handle missing type_params on AST Function/Class/TypeAlias nodes (#6512) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com> Co-authored-by: github-actions[bot] --- crates/vm/src/stdlib/ast/statement.rs | 6 +++--- extra_tests/snippets/stdlib_types.py | 30 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/crates/vm/src/stdlib/ast/statement.rs b/crates/vm/src/stdlib/ast/statement.rs index 6d9b35bee79..f1d36c52e2e 100644 --- a/crates/vm/src/stdlib/ast/statement.rs +++ b/crates/vm/src/stdlib/ast/statement.rs @@ -257,7 +257,7 @@ impl Node for ruff::StmtFunctionDef { type_params: Node::ast_from_object( _vm, source_file, - get_node_field(_vm, &_object, "type_params", "FunctionDef")?, + get_node_field_opt(_vm, &_object, "type_params")?.unwrap_or_else(|| _vm.ctx.none()), )?, range: range_from_object(_vm, source_file, _object, "FunctionDef")?, is_async, @@ -341,7 +341,7 @@ impl Node for ruff::StmtClassDef { type_params: Node::ast_from_object( _vm, source_file, - get_node_field(_vm, &_object, "type_params", "ClassDef")?, + get_node_field_opt(_vm, &_object, "type_params")?.unwrap_or_else(|| _vm.ctx.none()), )?, range: range_from_object(_vm, source_file, _object, "ClassDef")?, }) @@ -503,7 +503,7 @@ impl Node for ruff::StmtTypeAlias { type_params: Node::ast_from_object( _vm, source_file, - get_node_field(_vm, &_object, "type_params", "TypeAlias")?, + get_node_field_opt(_vm, &_object, "type_params")?.unwrap_or_else(|| _vm.ctx.none()), )?, value: Node::ast_from_object( _vm, diff --git a/extra_tests/snippets/stdlib_types.py b/extra_tests/snippets/stdlib_types.py index 3a3872d2f4e..14028268f0e 100644 --- a/extra_tests/snippets/stdlib_types.py +++ b/extra_tests/snippets/stdlib_types.py @@ -1,3 
+1,5 @@ +import _ast +import platform import types from testutils import assert_raises @@ -8,3 +10,31 @@ assert ns.b == "Rust" with assert_raises(AttributeError): _ = ns.c + + +def _run_missing_type_params_regression(): + args = _ast.arguments( + posonlyargs=[], + args=[], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + ) + fn = _ast.FunctionDef("f", args, [], [], None, None) + fn.lineno = 1 + fn.col_offset = 0 + fn.end_lineno = 1 + fn.end_col_offset = 0 + mod = _ast.Module([fn], []) + mod.lineno = 1 + mod.col_offset = 0 + mod.end_lineno = 1 + mod.end_col_offset = 0 + compiled = compile(mod, "", "exec") + exec(compiled, {}) + + +if platform.python_implementation() == "RustPython": + _run_missing_type_params_regression() From bcdf37bef17beacdaa41e430eb3e7e2d0bf5cb6c Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Thu, 25 Dec 2025 14:17:31 +0100 Subject: [PATCH 068/418] Align opcode names in `dis` (#6526) * opcodes dis repr like cpython. 
POP-> POP_TOP * Adjust frame.rs * Fix codegen * snapshots * Add doc for `PopTop` * fix jit --- crates/codegen/src/compile.rs | 46 ++-- ...thon_codegen__compile__tests__if_ands.snap | 16 +- ...hon_codegen__compile__tests__if_mixed.snap | 20 +- ...ython_codegen__compile__tests__if_ors.snap | 16 +- ...pile__tests__nested_double_async_with.snap | 142 ++++++------ crates/compiler-core/src/bytecode.rs | 206 +++++++++--------- crates/jit/src/instructions.rs | 8 +- crates/vm/src/frame.rs | 10 +- 8 files changed, 234 insertions(+), 230 deletions(-) diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 4f7174c8bad..0131b3008e3 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -1063,7 +1063,7 @@ impl Compiler { } ); - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } else { self.compile_statement(statement)?; } @@ -1079,7 +1079,7 @@ impl Compiler { } ); - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } else { self.compile_statement(last)?; self.emit_load_const(ConstantData::None); @@ -1104,7 +1104,7 @@ impl Compiler { if let Some(last_statement) = body.last() { match last_statement { Stmt::Expr(_) => { - self.current_block().instructions.pop(); // pop Instruction::Pop + self.current_block().instructions.pop(); // pop Instruction::PopTop } Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { let pop_instructions = self.current_block().instructions.pop(); @@ -1401,14 +1401,14 @@ impl Compiler { } // Pop module from stack: - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } } Stmt::Expr(StmtExpr { value, .. }) => { self.compile_expression(value)?; // Pop result of stack, since we not use it: - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } Stmt::Global(_) | Stmt::Nonlocal(_) => { // Handled during symbol table construction. @@ -2051,12 +2051,12 @@ impl Compiler { self.store_name(alias.as_str())? 
} else { // Drop exception from top of stack: - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } } else { // Catch all! // Drop exception from top of stack: - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } // Handler code: @@ -2950,7 +2950,7 @@ impl Compiler { self.compile_store(var)?; } None => { - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } } final_block @@ -3143,7 +3143,7 @@ impl Compiler { for &label in pc.fail_pop.iter().skip(1).rev() { self.switch_to_block(label); // Emit the POP instruction. - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } // Finally, use the first label. self.switch_to_block(pc.fail_pop[0]); @@ -3187,7 +3187,7 @@ impl Compiler { match n { // If no name is provided, simply pop the top of the stack. None => { - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); Ok(()) } Some(name) => { @@ -3313,7 +3313,7 @@ impl Compiler { } // Pop the subject off the stack. pc.on_top -= 1; - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); Ok(()) } @@ -3497,7 +3497,7 @@ impl Compiler { pc.on_top -= 1; if is_true_wildcard { - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); continue; // Don't compile wildcard patterns } @@ -3548,7 +3548,7 @@ impl Compiler { if size == 0 && star_target.is_none() { // If the pattern is just "{}", we're done! Pop the subject pc.on_top -= 1; - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); return Ok(()); } @@ -3703,8 +3703,8 @@ impl Compiler { // Non-rest pattern: just clean up the stack // Pop them as we're not using them - emit!(self, Instruction::Pop); // Pop keys_tuple - emit!(self, Instruction::Pop); // Pop subject + emit!(self, Instruction::PopTop); // Pop keys_tuple + emit!(self, Instruction::PopTop); // Pop subject } Ok(()) @@ -3792,7 +3792,7 @@ impl Compiler { // In Rust, old_pc is a local clone, so we need not worry about that. 
// No alternative matched: pop the subject and fail. - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); self.jump_to_fail_pop(pc, JumpOp::Jump)?; // Use the label "end". @@ -3814,7 +3814,7 @@ impl Compiler { // Old context and control will be dropped automatically. // Finally, pop the copy of the subject. - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); Ok(()) } @@ -3888,7 +3888,7 @@ impl Compiler { pc.on_top -= 1; if only_wildcard { // Patterns like: [] / [_] / [_, _] / [*_] / [_, *_] / [_, _, *_] / etc. - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } else if star_wildcard { self.pattern_helper_sequence_subscr(patterns, star.unwrap(), pc)?; } else { @@ -4012,7 +4012,7 @@ impl Compiler { } if i != case_count - 1 { - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } self.compile_statements(&m.body)?; @@ -4023,7 +4023,7 @@ impl Compiler { if has_default { let m = &cases[num_cases - 1]; if num_cases == 1 { - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } else { emit!(self, Instruction::Nop); } @@ -4125,7 +4125,7 @@ impl Compiler { // early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs. self.switch_to_block(break_block); emit!(self, Instruction::Swap { index: 2 }); - emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); self.switch_to_block(after_block); } @@ -4192,7 +4192,7 @@ impl Compiler { emit!(self, Instruction::StoreSubscript); } else { // Drop annotation if not assigned to simple identifier. 
- emit!(self, Instruction::Pop); + emit!(self, Instruction::PopTop); } Ok(()) @@ -4823,7 +4823,7 @@ impl Compiler { arg: bytecode::ResumeType::AfterYield as u32 } ); - emit!(compiler, Instruction::Pop); + emit!(compiler, Instruction::PopTop); Ok(()) }, diff --git a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ands.snap b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ands.snap index 8b2907ef6ff..4c9c29887ee 100644 --- a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ands.snap +++ b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ands.snap @@ -1,12 +1,12 @@ --- -source: compiler/codegen/src/compile.rs +source: crates/codegen/src/compile.rs expression: "compile_exec(\"\\\nif True and False and False:\n pass\n\")" --- - 1 0 LoadConst (True) - 1 PopJumpIfFalse (6) - 2 LoadConst (False) - 3 PopJumpIfFalse (6) - 4 LoadConst (False) - 5 PopJumpIfFalse (6) + 1 0 LOAD_CONST (True) + 1 POP_JUMP_IF_FALSE (6) + 2 LOAD_CONST (False) + 3 POP_JUMP_IF_FALSE (6) + 4 LOAD_CONST (False) + 5 POP_JUMP_IF_FALSE (6) - 2 >> 6 ReturnConst (None) + 2 >> 6 RETURN_CONST (None) diff --git a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_mixed.snap b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_mixed.snap index fc91a74283b..a93479df96e 100644 --- a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_mixed.snap +++ b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_mixed.snap @@ -1,14 +1,14 @@ --- -source: compiler/codegen/src/compile.rs +source: crates/codegen/src/compile.rs expression: "compile_exec(\"\\\nif (True and False) or (False and True):\n pass\n\")" --- - 1 0 LoadConst (True) - 1 PopJumpIfFalse (4) - 2 LoadConst (False) - 3 PopJumpIfTrue (8) - >> 4 LoadConst (False) - 5 PopJumpIfFalse (8) - 6 LoadConst (True) - 7 PopJumpIfFalse (8) + 1 0 LOAD_CONST (True) + 1 POP_JUMP_IF_FALSE (4) + 2 LOAD_CONST (False) + 3 POP_JUMP_IF_TRUE 
(8) + >> 4 LOAD_CONST (False) + 5 POP_JUMP_IF_FALSE (8) + 6 LOAD_CONST (True) + 7 POP_JUMP_IF_FALSE (8) - 2 >> 8 ReturnConst (None) + 2 >> 8 RETURN_CONST (None) diff --git a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ors.snap b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ors.snap index 9be7c2af7bd..37a862cc65a 100644 --- a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ors.snap +++ b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__if_ors.snap @@ -1,12 +1,12 @@ --- -source: compiler/codegen/src/compile.rs +source: crates/codegen/src/compile.rs expression: "compile_exec(\"\\\nif True or False or False:\n pass\n\")" --- - 1 0 LoadConst (True) - 1 PopJumpIfTrue (6) - 2 LoadConst (False) - 3 PopJumpIfTrue (6) - 4 LoadConst (False) - 5 PopJumpIfFalse (6) + 1 0 LOAD_CONST (True) + 1 POP_JUMP_IF_TRUE (6) + 2 LOAD_CONST (False) + 3 POP_JUMP_IF_TRUE (6) + 4 LOAD_CONST (False) + 5 POP_JUMP_IF_FALSE (6) - 2 >> 6 ReturnConst (None) + 2 >> 6 RETURN_CONST (None) diff --git a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__nested_double_async_with.snap b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__nested_double_async_with.snap index 435b73a14de..6cd7d4d523f 100644 --- a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__nested_double_async_with.snap +++ b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__nested_double_async_with.snap @@ -2,85 +2,85 @@ source: crates/codegen/src/compile.rs expression: "compile_exec(\"\\\nfor stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):\n with self.subTest(type=type(stop_exc)):\n try:\n async with egg():\n raise stop_exc\n except Exception as ex:\n self.assertIs(ex, stop_exc)\n else:\n self.fail(f'{stop_exc} was suppressed')\n\")" --- - 1 0 SetupLoop - 1 LoadNameAny (0, StopIteration) - 2 LoadConst ("spam") - 3 CallFunctionPositional(1) - 4 LoadNameAny (1, StopAsyncIteration) - 5 
LoadConst ("ham") - 6 CallFunctionPositional(1) - 7 BuildTuple (2) - 8 GetIter - >> 9 ForIter (71) - 10 StoreLocal (2, stop_exc) + 1 0 SETUP_LOOP + 1 LOAD_NAME_ANY (0, StopIteration) + 2 LOAD_CONST ("spam") + 3 CALL_FUNCTION_POSITIONAL(1) + 4 LOAD_NAME_ANY (1, StopAsyncIteration) + 5 LOAD_CONST ("ham") + 6 CALL_FUNCTION_POSITIONAL(1) + 7 BUILD_TUPLE (2) + 8 GET_ITER + >> 9 FOR_ITER (71) + 10 STORE_LOCAL (2, stop_exc) - 2 11 LoadNameAny (3, self) - 12 LoadMethod (4, subTest) - 13 LoadNameAny (5, type) - 14 LoadNameAny (2, stop_exc) - 15 CallFunctionPositional(1) - 16 LoadConst (("type")) - 17 CallMethodKeyword (1) - 18 SetupWith (68) - 19 Pop + 2 11 LOAD_NAME_ANY (3, self) + 12 LOAD_METHOD (4, subTest) + 13 LOAD_NAME_ANY (5, type) + 14 LOAD_NAME_ANY (2, stop_exc) + 15 CALL_FUNCTION_POSITIONAL(1) + 16 LOAD_CONST (("type")) + 17 CALL_METHOD_KEYWORD (1) + 18 SETUP_WITH (68) + 19 POP_TOP - 3 20 SetupExcept (42) + 3 20 SETUP_EXCEPT (42) - 4 21 LoadNameAny (6, egg) - 22 CallFunctionPositional(0) - 23 BeforeAsyncWith - 24 GetAwaitable - 25 LoadConst (None) - 26 YieldFrom - 27 Resume (3) - 28 SetupAsyncWith (34) - 29 Pop + 4 21 LOAD_NAME_ANY (6, egg) + 22 CALL_FUNCTION_POSITIONAL(0) + 23 BEFORE_ASYNC_WITH + 24 GET_AWAITABLE + 25 LOAD_CONST (None) + 26 YIELD_FROM + 27 RESUME (3) + 28 SETUP_ASYNC_WITH (34) + 29 POP_TOP - 5 30 LoadNameAny (2, stop_exc) - 31 Raise (Raise) + 5 30 LOAD_NAME_ANY (2, stop_exc) + 31 RAISE (Raise) - 4 32 PopBlock - 33 EnterFinally - >> 34 WithCleanupStart - 35 GetAwaitable - 36 LoadConst (None) - 37 YieldFrom - 38 Resume (3) - 39 WithCleanupFinish - 40 PopBlock - 41 Jump (58) - >> 42 CopyItem (1) + 4 32 POP_BLOCK + 33 ENTER_FINALLY + >> 34 WITH_CLEANUP_START + 35 GET_AWAITABLE + 36 LOAD_CONST (None) + 37 YIELD_FROM + 38 RESUME (3) + 39 WITH_CLEANUP_FINISH + 40 POP_BLOCK + 41 JUMP (58) + >> 42 COPY (1) - 6 43 LoadNameAny (7, Exception) + 6 43 LOAD_NAME_ANY (7, Exception) 44 JUMP_IF_NOT_EXC_MATCH(57) - 45 StoreLocal (8, ex) + 45 STORE_LOCAL (8, ex) - 7 
46 LoadNameAny (3, self) - 47 LoadMethod (9, assertIs) - 48 LoadNameAny (8, ex) - 49 LoadNameAny (2, stop_exc) - 50 CallMethodPositional (2) - 51 Pop - 52 PopException - 53 LoadConst (None) - 54 StoreLocal (8, ex) - 55 DeleteLocal (8, ex) - 56 Jump (66) - >> 57 Raise (Reraise) + 7 46 LOAD_NAME_ANY (3, self) + 47 LOAD_METHOD (9, assertIs) + 48 LOAD_NAME_ANY (8, ex) + 49 LOAD_NAME_ANY (2, stop_exc) + 50 CALL_METHOD_POSITIONAL(2) + 51 POP_TOP + 52 POP_EXCEPTION + 53 LOAD_CONST (None) + 54 STORE_LOCAL (8, ex) + 55 DELETE_LOCAL (8, ex) + 56 JUMP (66) + >> 57 RAISE (Reraise) - 9 >> 58 LoadNameAny (3, self) - 59 LoadMethod (10, fail) - 60 LoadNameAny (2, stop_exc) + 9 >> 58 LOAD_NAME_ANY (3, self) + 59 LOAD_METHOD (10, fail) + 60 LOAD_NAME_ANY (2, stop_exc) 61 FORMAT_SIMPLE - 62 LoadConst (" was suppressed") - 63 BuildString (2) - 64 CallMethodPositional (1) - 65 Pop + 62 LOAD_CONST (" was suppressed") + 63 BUILD_STRING (2) + 64 CALL_METHOD_POSITIONAL(1) + 65 POP_TOP - 2 >> 66 PopBlock - 67 EnterFinally - >> 68 WithCleanupStart - 69 WithCleanupFinish - 70 Jump (9) - >> 71 PopBlock - 72 ReturnConst (None) + 2 >> 66 POP_BLOCK + 67 ENTER_FINALLY + >> 68 WITH_CLEANUP_START + 69 WITH_CLEANUP_FINISH + 70 JUMP (9) + >> 71 POP_BLOCK + 72 RETURN_CONST (None) diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index 11d2a7b5f1b..dd49e679f27 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -782,7 +782,6 @@ pub enum Instruction { MatchMapping, MatchSequence, Nop, - Pop, PopBlock, PopException, /// Pop the top of the stack, and jump if this value is false. @@ -793,6 +792,11 @@ pub enum Instruction { PopJumpIfTrue { target: Arg