diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 3530fc6c535..0a9c237ecd0 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -195,8 +195,7 @@ def test_chain_reducible(self): self.assertRaises(TypeError, list, oper(chain(2, 3))) for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef')) - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_chain_setstate(self): self.assertRaises(TypeError, chain().__setstate__, ()) self.assertRaises(TypeError, chain().__setstate__, []) @@ -209,6 +208,7 @@ def test_chain_setstate(self): it = chain() it.__setstate__((iter(['abc', 'def']), iter(['ghi']))) self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f']) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_combinations(self): diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 7dba8364685..29d8b14d99c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -58,7 +58,36 @@ mod decl { fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias { PyGenericAlias::new(cls, args, vm) } + + #[pymethod(magic)] + fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { + let args = state.as_slice(); + if args.is_empty() { + let msg = String::from("function takes at leat 1 arguments (0 given)"); + return Err(vm.new_type_error(msg)); + } + if args.len() > 2 { + let msg = format!("function takes at most 2 arguments ({} given)", args.len()); + return Err(vm.new_type_error(msg)); + } + let source = &args[0]; + let active = args.get(1); + if !PyIter::check(source.as_ref()) + || !active.map_or(true, |active| PyIter::check(active.as_ref())) + { + return Err(vm.new_type_error(String::from("Arguments must be iterators."))); + } + *self.iterables.write() = source.try_to_value(vm)?; + self.cur_idx.store(0); + *self.cached_iter.write() = if let Some(active) = active { + Some(active.clone().get_iter(vm)?) + } else { + None + }; + Ok(()) + } } + impl IterNextIterable for PyItertoolsChain {} impl IterNext for PyItertoolsChain { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult {