diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 6c638f07eba..d0cdd9fe7a8 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -181,8 +181,6 @@ def test_chain_from_iterable(self): self.assertEqual(take(4, chain.from_iterable(['abc', 'def'])), list('abcd')) self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_chain_reducible(self): for oper in [copy.deepcopy] + picklecopiers: it = chain('abc', 'def') @@ -195,8 +193,7 @@ def test_chain_reducible(self): self.assertRaises(TypeError, list, oper(chain(2, 3))) for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef')) - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_chain_setstate(self): self.assertRaises(TypeError, chain().__setstate__, ()) self.assertRaises(TypeError, chain().__setstate__, []) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index a8acdb11687..09a3f263440 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -29,7 +29,7 @@ mod decl { active: PyRwLock>, } - #[pyclass(with(IterNext), flags(BASETYPE))] + #[pyclass(with(IterNext), flags(BASETYPE, HAS_DICT))] impl PyItertoolsChain { #[pyslot] fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -59,6 +59,53 @@ mod decl { fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias { PyGenericAlias::new(cls, args, vm) } + + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult { + let source = zelf.source.read().clone(); + let active = zelf.active.read().clone(); + let cls = zelf.class().to_owned(); + let empty_tuple = vm.ctx.empty_tuple.clone(); + let reduced = match source { + Some(source) => match active { + Some(active) => vm.new_tuple((cls, empty_tuple, (source, active))), + None => vm.new_tuple((cls, empty_tuple, (source,))), + }, + None => vm.new_tuple((cls, empty_tuple)), + }; + Ok(reduced) + } + + #[pymethod(magic)] + fn setstate(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { + let args = state.as_slice(); + if args.is_empty() { + let msg = String::from("function takes at leat 1 arguments (0 given)"); + return Err(vm.new_type_error(msg)); + } + if args.len() > 2 { + let msg = format!("function takes at most 2 arguments ({} given)", args.len()); + return Err(vm.new_type_error(msg)); + } + let source = &args[0]; + if args.len() == 1 { + if !PyIter::check(source.as_ref()) { + return Err(vm.new_type_error(String::from("Arguments must be iterators."))); + } + *zelf.source.write() = source.to_owned().try_into_value(vm)?; + return Ok(()); + } + let active = &args[1]; + + if !PyIter::check(source.as_ref()) || !PyIter::check(active.as_ref()) { + return Err(vm.new_type_error(String::from("Arguments must be iterators."))); + } + let mut source_lock = zelf.source.write(); + let mut active_lock = zelf.active.write(); + *source_lock = source.to_owned().try_into_value(vm)?; + *active_lock = active.to_owned().try_into_value(vm)?; + Ok(()) + } } impl IterNextIterable for PyItertoolsChain {} impl IterNext for PyItertoolsChain {