Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions crates/vm/src/builtins/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyStrInterned, pystr::AsPyStr},
class::PyClassImpl,
common::lock::PyRwLock,
convert::ToPyObject,
function::{FuncArgs, PyMethodDef},
types::{GetAttr, Initializer, Representable},
Expand Down Expand Up @@ -48,6 +49,7 @@ pub struct PyModule {
// PyObject *md_dict;
pub def: Option<&'static PyModuleDef>,
// state: Any
state: PyRwLock<Option<PyObjectRef>>,
// weaklist
// for logging purposes after md_dict is cleared
pub name: Option<&'static PyStrInterned>,
Expand All @@ -72,13 +74,15 @@ impl PyModule {
pub const fn new() -> Self {
Self {
def: None,
state: PyRwLock::new(None),
name: None,
}
}

pub const fn from_def(def: &'static PyModuleDef) -> Self {
Self {
def: Some(def),
state: PyRwLock::new(None),
name: Some(def.name),
}
}
Expand All @@ -87,6 +91,20 @@ impl PyModule {
let doc = module.def.unwrap().doc.map(|doc| doc.to_owned());
module.init_dict(module.name.unwrap(), doc, vm);
}

/// Return the stored module state if it exists and matches `T`. Returns `None` when no state
/// has been set or when the stored state is of a different type.
pub fn get_state<T: PyPayload>(&self) -> Option<PyRef<T>> {
self.state
.read()
.as_ref()
.and_then(|obj| obj.clone().downcast().ok())
}

/// Set the module state.
pub fn set_state(&self, state: PyObjectRef) {
*self.state.write() = Some(state);
}
}

impl Py<PyModule> {
Expand Down Expand Up @@ -224,3 +242,33 @@ impl Representable for PyModule {
pub(crate) fn init(context: &Context) {
PyModule::extend_class(context, context.types.module_type);
}

#[cfg(test)]
mod tests {
use crate::{AsObject, builtins::PyInt, vm::Interpreter};
use malachite_bigint::BigInt;

#[test]
fn module_state_is_per_module_and_typed() {
Interpreter::without_stdlib(Default::default()).enter(|vm| {
let m1 = vm.new_module("m1", vm.ctx.new_dict(), None);
let m2 = vm.new_module("m2", vm.ctx.new_dict(), None);

assert!(m1.get_state::<PyInt>().is_none());

let s1 = vm.ctx.new_int(1);
let s2 = vm.ctx.new_int(2);
m1.set_state(s1.as_object().to_owned());
m2.set_state(s2.as_object().to_owned());

assert_eq!(
m1.get_state::<PyInt>().unwrap().as_bigint(),
&BigInt::from(1)
);
assert_eq!(
m2.get_state::<PyInt>().unwrap().as_bigint(),
&BigInt::from(2)
);
});
}
}
Loading