diff --git a/crates/vm/src/builtins/module.rs b/crates/vm/src/builtins/module.rs index f8e42b28e0..d318070816 100644 --- a/crates/vm/src/builtins/module.rs +++ b/crates/vm/src/builtins/module.rs @@ -3,6 +3,7 @@ use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyStrInterned, pystr::AsPyStr}, class::PyClassImpl, + common::lock::PyRwLock, convert::ToPyObject, function::{FuncArgs, PyMethodDef}, types::{GetAttr, Initializer, Representable}, @@ -48,6 +49,7 @@ pub struct PyModule { // PyObject *md_dict; pub def: Option<&'static PyModuleDef>, // state: Any + state: PyRwLock>, // weaklist // for logging purposes after md_dict is cleared pub name: Option<&'static PyStrInterned>, @@ -72,6 +74,7 @@ impl PyModule { pub const fn new() -> Self { Self { def: None, + state: PyRwLock::new(None), name: None, } } @@ -79,6 +82,7 @@ impl PyModule { pub const fn from_def(def: &'static PyModuleDef) -> Self { Self { def: Some(def), + state: PyRwLock::new(None), name: Some(def.name), } } @@ -87,6 +91,20 @@ impl PyModule { let doc = module.def.unwrap().doc.map(|doc| doc.to_owned()); module.init_dict(module.name.unwrap(), doc, vm); } + + /// Return the stored module state if it exists and matches `T`. Returns `None` when no state + /// has been set or when the stored state is of a different type. + pub fn get_state(&self) -> Option> { + self.state + .read() + .as_ref() + .and_then(|obj| obj.clone().downcast().ok()) + } + + /// Set the module state. + pub fn set_state(&self, state: PyObjectRef) { + *self.state.write() = Some(state); + } } impl Py { @@ -224,3 +242,33 @@ impl Representable for PyModule { pub(crate) fn init(context: &Context) { PyModule::extend_class(context, context.types.module_type); } + +#[cfg(test)] +mod tests { + use crate::{AsObject, builtins::PyInt, vm::Interpreter}; + use malachite_bigint::BigInt; + + #[test] + fn module_state_is_per_module_and_typed() { + Interpreter::without_stdlib(Default::default()).enter(|vm| { + let m1 = vm.new_module("m1", vm.ctx.new_dict(), None); + let m2 = vm.new_module("m2", vm.ctx.new_dict(), None); + + assert!(m1.get_state::().is_none()); + + let s1 = vm.ctx.new_int(1); + let s2 = vm.ctx.new_int(2); + m1.set_state(s1.as_object().to_owned()); + m2.set_state(s2.as_object().to_owned()); + + assert_eq!( + m1.get_state::().unwrap().as_bigint(), + &BigInt::from(1) + ); + assert_eq!( + m2.get_state::().unwrap().as_bigint(), + &BigInt::from(2) + ); + }); + } +}