Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,6 @@ def _test_gen():

class TestOneTrickPonyABCs(ABCTestCase):

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_Awaitable(self):
def gen():
yield
Expand Down Expand Up @@ -840,7 +839,6 @@ class CoroLike: pass
CoroLike = None
support.gc_collect() # Kill CoroLike to clean-up ABCMeta cache

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_Coroutine(self):
def gen():
yield
Expand Down
2 changes: 0 additions & 2 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2174,7 +2174,6 @@ def foo():
foo = types.coroutine(foo)
self.assertIs(aw, foo())

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_async_def(self):
# Test that types.coroutine passes 'async def' coroutines
# without modification
Expand Down Expand Up @@ -2431,7 +2430,6 @@ def foo():
foo = types.coroutine(foo)
self.assertIs(foo(), gencoro)

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_genfunc(self):
def gen(): yield
self.assertIs(types.coroutine(gen), gen)
Expand Down
128 changes: 63 additions & 65 deletions Lib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,78 @@
Define names for built-in types that aren't directly accessible as a builtin.
"""

import sys

# Iterators in Python aren't a matter of type but of protocol. A large
# and changing number of builtin types implement *some* flavor of
# iterator. Don't check the type! Use hasattr to check for both
# "__iter__" and "__next__" attributes instead.

def _f(): pass
FunctionType = type(_f)
LambdaType = type(lambda: None) # Same as FunctionType
CodeType = type(_f.__code__)
MappingProxyType = type(type.__dict__)
SimpleNamespace = type(sys.implementation)

def _cell_factory():
a = 1
def f():
nonlocal a
return f.__closure__[0]
CellType = type(_cell_factory())

def _g():
yield 1
GeneratorType = type(_g())

async def _c(): pass
_c = _c()
CoroutineType = type(_c)
_c.close() # Prevent ResourceWarning

async def _ag():
yield
_ag = _ag()
AsyncGeneratorType = type(_ag)

class _C:
def _m(self): pass
MethodType = type(_C()._m)
try:
from _types import *
except ImportError:
import sys

def _f(): pass
FunctionType = type(_f)
LambdaType = type(lambda: None) # Same as FunctionType
CodeType = type(_f.__code__)
MappingProxyType = type(type.__dict__)
SimpleNamespace = type(sys.implementation)

def _cell_factory():
a = 1
def f():
nonlocal a
return f.__closure__[0]
CellType = type(_cell_factory())

def _g():
yield 1
GeneratorType = type(_g())

async def _c(): pass
_c = _c()
CoroutineType = type(_c)
_c.close() # Prevent ResourceWarning

async def _ag():
yield
_ag = _ag()
AsyncGeneratorType = type(_ag)

class _C:
def _m(self): pass
MethodType = type(_C()._m)

BuiltinFunctionType = type(len)
BuiltinMethodType = type([].append) # Same as BuiltinFunctionType

WrapperDescriptorType = type(object.__init__)
MethodWrapperType = type(object().__str__)
MethodDescriptorType = type(str.join)
ClassMethodDescriptorType = type(dict.__dict__['fromkeys'])

ModuleType = type(sys)

BuiltinFunctionType = type(len)
BuiltinMethodType = type([].append) # Same as BuiltinFunctionType
try:
raise TypeError
except TypeError as exc:
TracebackType = type(exc.__traceback__)
FrameType = type(exc.__traceback__.tb_frame)

WrapperDescriptorType = type(object.__init__)
MethodWrapperType = type(object().__str__)
MethodDescriptorType = type(str.join)
ClassMethodDescriptorType = type(dict.__dict__['fromkeys'])
GetSetDescriptorType = type(FunctionType.__code__)
MemberDescriptorType = type(FunctionType.__globals__)

ModuleType = type(sys)
GenericAlias = type(list[int])
UnionType = type(int | str)

try:
raise TypeError
except TypeError as exc:
TracebackType = type(exc.__traceback__)
FrameType = type(exc.__traceback__.tb_frame)
EllipsisType = type(Ellipsis)
NoneType = type(None)
NotImplementedType = type(NotImplemented)

GetSetDescriptorType = type(FunctionType.__code__)
MemberDescriptorType = type(FunctionType.__globals__)
# CapsuleType cannot be accessed from pure Python,
# so there is no fallback definition.

del sys, _f, _g, _C, _c, _ag, _cell_factory # Not for export
del sys, _f, _g, _C, _c, _ag, _cell_factory # Not for export


# Provide a PEP 3115 compliant mechanism for class creation
Expand Down Expand Up @@ -279,8 +292,7 @@ def coroutine(func):
if not callable(func):
raise TypeError('types.coroutine() expects a callable')

# XXX RUSTPYTHON TODO: iterable coroutine
if (False and func.__class__ is FunctionType and
if (func.__class__ is FunctionType and
getattr(func, '__code__', None).__class__ is CodeType):

co_flags = func.__code__.co_flags
Expand Down Expand Up @@ -325,18 +337,4 @@ def wrapped(*args, **kwargs):

return wrapped

GenericAlias = type(list[int])
UnionType = type(int | str)

EllipsisType = type(Ellipsis)
NoneType = type(None)
NotImplementedType = type(NotImplemented)

def __getattr__(name):
if name == 'CapsuleType':
import _socket
return type(_socket.CAPI)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

__all__ = [n for n in globals() if n[:1] != '_']
__all__ += ['CapsuleType']
__all__ = [n for n in globals() if not n.startswith('_')] # for pydoc
1 change: 1 addition & 0 deletions crates/compiler-core/src/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ bitflags! {
const VARKEYWORDS = 0x0008;
const GENERATOR = 0x0020;
const COROUTINE = 0x0080;
const ITERABLE_COROUTINE = 0x0100;
/// If a code object represents a function and has a docstring,
/// this bit is set and the first item in co_consts is the docstring.
const HAS_DOCSTRING = 0x4000000;
Expand Down
24 changes: 23 additions & 1 deletion crates/vm/src/builtins/asyncgenerator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{PyCode, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
use super::{PyCode, PyGenerator, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
use crate::{
AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::PyBaseExceptionRef,
Expand Down Expand Up @@ -592,6 +592,28 @@ impl PyAnextAwaitable {
let awaitable = if wrapped.class().is(vm.ctx.types.coroutine_type) {
// Coroutine - get __await__ later
wrapped.clone()
} else if let Some(generator) = wrapped.downcast_ref::<PyGenerator>() {
// Generator with CO_ITERABLE_COROUTINE flag can be awaited
// (e.g., generators decorated with @types.coroutine)
if generator
.as_coro()
.frame()
.code
.flags
.contains(crate::bytecode::CodeFlags::ITERABLE_COROUTINE)
{
// Return the generator itself as the iterator
return Ok(wrapped.clone());
}
// Fall through: try to get __await__ method for generator subclasses
if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) {
await_method?.call((), vm)?
} else {
return Err(vm.new_type_error(format!(
"object {} can't be used in 'await' expression",
wrapped.class().name()
)));
}
} else {
// Try to get __await__ method
if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) {
Expand Down
21 changes: 20 additions & 1 deletion crates/vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,23 @@ impl ExecutingFrame<'_> {
);
}
awaited_obj
} else if let Some(generator) = awaited_obj.downcast_ref::<PyGenerator>() {
// Generator with CO_ITERABLE_COROUTINE flag can be awaited
// (e.g., generators decorated with @types.coroutine)
if generator
.as_coro()
.frame()
.code
.flags
.contains(bytecode::CodeFlags::ITERABLE_COROUTINE)
{
awaited_obj
} else {
return Err(vm.new_type_error(format!(
"object {} can't be used in 'await' expression",
awaited_obj.class().name(),
)));
}
} else {
let await_method = vm.get_method_or_type_error(
awaited_obj.clone(),
Expand Down Expand Up @@ -1051,7 +1068,9 @@ impl ExecutingFrame<'_> {
let iterable = self.pop_value();
let iter = if iterable.class().is(vm.ctx.types.coroutine_type) {
// Coroutine requires CO_COROUTINE or CO_ITERABLE_COROUTINE flag
if !self.code.flags.intersects(bytecode::CodeFlags::COROUTINE) {
if !self.code.flags.intersects(
bytecode::CodeFlags::COROUTINE | bytecode::CodeFlags::ITERABLE_COROUTINE,
) {
return Err(vm.new_type_error(
"cannot 'yield from' a coroutine object in a non-coroutine generator"
.to_owned(),
Expand Down
Loading