From 053cfeecce89148cded2106d9eb9fe9f71699139 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 30 Jul 2025 10:28:04 +0900 Subject: [PATCH 1/3] downcastable_from --- vm/src/builtins/str.rs | 264 ++++++++++++++++++++++++++++++------- vm/src/convert/try_from.rs | 10 +- vm/src/object/core.rs | 23 +++- vm/src/object/payload.rs | 27 +++- 4 files changed, 267 insertions(+), 57 deletions(-) diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 2e4678d557..57bc1428b6 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -15,7 +15,7 @@ use crate::{ format::{format, format_map}, function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue}, intern::PyInterned, - object::{Traverse, TraverseFn}, + object::{MaybeTraverse, Traverse, TraverseFn}, protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods}, sequence::SequenceExt, sliceable::{SequenceIndex, SliceableSequenceOp}, @@ -64,6 +64,9 @@ impl<'a> TryFromBorrowedObject<'a> for &'a Wtf8 { } } +pub type PyStrRef = PyRef; +pub type PyUtf8StrRef = PyRef; + #[pyclass(module = false, name = "str")] pub struct PyStr { data: StrData, @@ -80,30 +83,6 @@ impl fmt::Debug for PyStr { } } -#[repr(transparent)] -#[derive(Debug)] -pub struct PyUtf8Str(PyStr); - -// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str -impl std::ops::Deref for PyUtf8Str { - type Target = PyStr; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl PyUtf8Str { - /// Returns the underlying string slice. - pub fn as_str(&self) -> &str { - debug_assert!( - self.0.is_utf8(), - "PyUtf8Str invariant violated: inner string is not valid UTF-8" - ); - // Safety: This is safe because the type invariant guarantees UTF-8 validity. - unsafe { self.0.to_str().unwrap_unchecked() } - } -} - impl AsRef for PyStr { #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { @@ -241,8 +220,6 @@ impl Default for PyStr { } } -pub type PyStrRef = PyRef; - impl fmt::Display for PyStr { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -374,7 +351,7 @@ impl Constructor for PyStr { type Args = StrArgs; fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { - let string: PyStrRef = match args.object { + let string: PyRef = match args.object { OptionalArg::Present(input) => { if let OptionalArg::Present(enc) = args.encoding { vm.state.codec_registry.decode_text( @@ -458,7 +435,7 @@ impl PyStr { self.data.as_str() } - pub fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { + fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { if self.is_utf8() { Ok(()) } else { @@ -531,6 +508,22 @@ impl PyStr { .mul(vm, value) .map(|x| Self::from(unsafe { Wtf8Buf::from_bytes_unchecked(x) }).into_ref(&vm.ctx)) } + + pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a PyUtf8Str> { + // Check if the string contains surrogates + self.ensure_valid_utf8(vm)?; + // If no surrogates, we can safely cast to PyStr + Ok(unsafe { &*(self as *const _ as *const PyUtf8Str) }) + } +} + +impl Py { + pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a Py> { + // Check if the string contains surrogates + self.ensure_valid_utf8(vm)?; + // If no surrogates, we can safely cast to PyStr + Ok(unsafe { &*(self as *const _ as *const Py) }) + } } #[pyclass( @@ -980,7 +973,11 @@ impl PyStr { } #[pymethod(name = "__format__")] - fn __format__(zelf: PyRef, spec: PyStrRef, vm: &VirtualMachine) -> PyResult { + fn __format__( + zelf: PyRef, + spec: PyStrRef, + vm: &VirtualMachine, + ) -> PyResult> { let spec = spec.as_str(); if spec.is_empty() { return if zelf.class().is(vm.ctx.types.str_type) { @@ -989,7 +986,7 @@ impl PyStr { zelf.as_object().str(vm) }; } - + let zelf = zelf.try_into_utf8(vm)?; let s = FormatSpec::parse(spec) .and_then(|format_spec| { format_spec.format_string(&CharLenStr(zelf.as_str(), zelf.char_len())) @@ -1351,8 +1348,12 @@ impl PyStr { } #[pymethod] - fn expandtabs(&self, args: anystr::ExpandTabsArgs) -> String { - rustpython_common::str::expandtabs(self.as_str(), args.tabsize()) + fn expandtabs(&self, args: anystr::ExpandTabsArgs, vm: &VirtualMachine) -> PyResult { + // TODO: support WTF-8 + Ok(rustpython_common::str::expandtabs( + self.try_as_utf8(vm)?.as_str(), + args.tabsize(), + )) } #[pymethod] @@ -1480,20 +1481,6 @@ impl PyStr { } } -struct CharLenStr<'a>(&'a str, usize); -impl std::ops::Deref for CharLenStr<'_> { - type Target = str; - - fn deref(&self) -> &Self::Target { - self.0 - } -} -impl crate::common::format::CharLen for CharLenStr<'_> { - fn char_len(&self) -> usize { - self.1 - } -} - #[pyclass] impl PyRef { #[pymethod] @@ -1504,7 +1491,7 @@ impl PyRef { } } -impl PyStrRef { +impl PyRef { pub fn is_empty(&self) -> bool { (**self).is_empty() } @@ -1526,6 +1513,20 @@ impl PyStrRef { } } +struct CharLenStr<'a>(&'a str, usize); +impl std::ops::Deref for CharLenStr<'_> { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0 + } +} +impl crate::common::format::CharLen for CharLenStr<'_> { + fn char_len(&self) -> usize { + self.1 + } +} + impl Representable for PyStr { #[inline] fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -1941,6 +1942,170 @@ impl AnyStrWrapper for PyStrRef { } } +#[repr(transparent)] +#[derive(Debug)] +pub struct PyUtf8Str(PyStr); + +impl fmt::Display for PyUtf8Str { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl MaybeTraverse for PyUtf8Str { + fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.0.try_traverse(traverse_fn); + } +} + +impl PyPayload for PyUtf8Str { + #[inline] + fn class(ctx: &Context) -> &'static Py { + ctx.types.str_type + } + + fn payload_type_id() -> std::any::TypeId { + std::any::TypeId::of::() + } + + fn downcastable_from(obj: &PyObject) -> bool { + obj.typeid() == Self::payload_type_id() && { + // SAFETY: we know the object is a PyStr in this context + let wtf8 = unsafe { obj.downcast_unchecked_ref::() }; + wtf8.is_utf8() + } + } + + fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> { + let str = obj.try_downcast_ref::(vm)?; + str.ensure_valid_utf8(vm) + } +} + +impl<'a> From<&'a AsciiStr> for PyUtf8Str { + fn from(s: &'a AsciiStr) -> Self { + s.to_owned().into() + } +} + +impl From for PyUtf8Str { + fn from(s: AsciiString) -> Self { + s.into_boxed_ascii_str().into() + } +} + +impl From> for PyUtf8Str { + fn from(s: Box) -> Self { + let data = StrData::from(s); + unsafe { Self::from_str_data_unchecked(data) } + } +} + +impl From for PyUtf8Str { + fn from(ch: AsciiChar) -> Self { + AsciiString::from(ch).into() + } +} + +impl<'a> From<&'a str> for PyUtf8Str { + fn from(s: &'a str) -> Self { + s.to_owned().into() + } +} + +impl From for PyUtf8Str { + fn from(s: String) -> Self { + s.into_boxed_str().into() + } +} + +impl From for PyUtf8Str { + fn from(ch: char) -> Self { + let data = StrData::from(ch); + unsafe { Self::from_str_data_unchecked(data) } + } +} + +impl<'a> From> for PyUtf8Str { + fn from(s: std::borrow::Cow<'a, str>) -> Self { + s.into_owned().into() + } +} + +impl From> for PyUtf8Str { + #[inline] + fn from(value: Box) -> Self { + let data = StrData::from(value); + unsafe { Self::from_str_data_unchecked(data) } + } +} + +impl AsRef for PyUtf8Str { + #[inline] + fn as_ref(&self) -> &Wtf8 { + self.0.as_wtf8() + } +} + +impl AsRef for PyUtf8Str { + #[inline] + fn as_ref(&self) -> &str { + self.0.as_str() + } +} + +impl PyUtf8Str { + // Create a new `PyUtf8Str` from `StrData` without validation. + // This function must be only used in this module to create conversions. + // # Safety: must be called with a valid UTF-8 string data. + unsafe fn from_str_data_unchecked(data: StrData) -> Self { + Self(PyStr::from(data)) + } + + /// Returns the underlying string slice. + pub fn as_str(&self) -> &str { + debug_assert!( + self.0.is_utf8(), + "PyUtf8Str invariant violated: inner string is not valid UTF-8" + ); + // Safety: This is safe because the type invariant guarantees UTF-8 validity. + unsafe { self.0.to_str().unwrap_unchecked() } + } + + #[inline] + pub fn byte_len(&self) -> usize { + self.0.byte_len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + #[inline] + pub fn char_len(&self) -> usize { + self.0.char_len() + } +} + +impl Py { + /// Upcast to PyStr. + pub fn as_pystr(&self) -> &Py { + unsafe { + // Safety: PyUtf8Str is a wrapper around PyStr, so this cast is safe. + &*(self as *const Self as *const Py) + } + } +} + +impl PartialEq for PyUtf8Str { + fn eq(&self, other: &Self) -> bool { + self.as_str() == other.as_str() + } +} +impl Eq for PyUtf8Str {} + impl AnyStrContainer for String { fn new() -> Self { Self::new() @@ -2302,7 +2467,8 @@ impl std::fmt::Display for PyStrInterned { impl AsRef for PyStrInterned { #[inline(always)] fn as_ref(&self) -> &str { - self.as_str() + self.to_str() + .expect("Interned PyStr should always be valid UTF-8") } } diff --git a/vm/src/convert/try_from.rs b/vm/src/convert/try_from.rs index 3fda682d40..4f921e9c5d 100644 --- a/vm/src/convert/try_from.rs +++ b/vm/src/convert/try_from.rs @@ -78,12 +78,12 @@ where #[inline] fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let class = T::class(&vm.ctx); - let result = if obj.fast_isinstance(class) { - obj.downcast() + if obj.fast_isinstance(class) { + T::try_downcast_from(&obj, vm)?; + Ok(unsafe { obj.downcast_unchecked() }) } else { - Err(obj) - }; - result.map_err(|obj| vm.new_downcast_type_error(class, &obj)) + Err(vm.new_downcast_type_error(class, &obj)) + } } } diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index 54a6657a9f..57576ce703 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -448,7 +448,7 @@ impl PyInner { let member_count = typ.slots.member_count; Box::new(Self { ref_count: RefCount::new(), - typeid: TypeId::of::(), + typeid: T::payload_type_id(), vtable: PyObjVTable::of::(), typ: PyAtomicRef::from(typ), dict: dict.map(InstanceDict::new), @@ -541,6 +541,11 @@ impl PyObjectRef { } } + pub fn try_downcast(self, vm: &VirtualMachine) -> PyResult> { + T::try_downcast_from(&self, vm)?; + Ok(unsafe { self.downcast_unchecked() }) + } + /// Force to downcast this reference to a subclass. /// /// # Safety @@ -720,10 +725,24 @@ impl PyObject { } } + #[inline] + pub(crate) fn typeid(&self) -> TypeId { + self.0.typeid + } + /// Check if this object can be downcast to T. #[inline(always)] pub fn downcastable(&self) -> bool { - self.0.typeid == T::payload_type_id() + T::downcastable_from(self) + } + + /// Attempt to downcast this reference to a subclass. + pub fn try_downcast_ref<'a, T: PyObjectPayload>( + &'a self, + vm: &VirtualMachine, + ) -> PyResult<&'a Py> { + T::try_downcast_from(self, vm)?; + Ok(unsafe { self.downcast_unchecked_ref::() }) } /// Attempt to downcast this reference to a subclass. diff --git a/vm/src/object/payload.rs b/vm/src/object/payload.rs index f223af6e96..0b7bfe0dc1 100644 --- a/vm/src/object/payload.rs +++ b/vm/src/object/payload.rs @@ -1,6 +1,6 @@ use crate::object::{MaybeTraverse, Py, PyObjectRef, PyRef, PyResult}; use crate::{ - PyRefExact, + PyObject, PyRefExact, builtins::{PyBaseExceptionRef, PyType, PyTypeRef}, types::PyTypeFlags, vm::{Context, VirtualMachine}, @@ -23,6 +23,31 @@ pub trait PyPayload: fn payload_type_id() -> std::any::TypeId { std::any::TypeId::of::() } + + /// # Safety: this function should only be called if `payload_type_id` matches the type of `obj`. + #[inline] + fn downcastable_from(obj: &PyObject) -> bool { + obj.typeid() == Self::payload_type_id() + } + + fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> { + if Self::downcastable_from(obj) { + return Ok(()); + } + + #[cold] + fn raise_downcast_type_error( + vm: &VirtualMachine, + class: &Py, + obj: &PyObject, + ) -> PyBaseExceptionRef { + vm.new_downcast_type_error(class, obj) + } + + let class = Self::class(&vm.ctx); + Err(raise_downcast_type_error(vm, class, obj)) + } + fn class(ctx: &Context) -> &'static Py; #[inline] From d46c882347cb693f271c0d59eccfd5e96966ed53 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 30 Jul 2025 10:28:13 +0900 Subject: [PATCH 2/3] remove try_to_str Rewrite sqlite3 UTF8 validation --- stdlib/src/sqlite.rs | 29 ++++++++++++++--------------- vm/src/builtins/mod.rs | 2 +- vm/src/builtins/str.rs | 6 ------ vm/src/macros.rs | 9 +++++++++ vm/src/stdlib/codecs.rs | 6 +++--- vm/src/stdlib/io.rs | 26 +++++++++++++------------- vm/src/stdlib/time.rs | 16 ++++++++-------- 7 files changed, 48 insertions(+), 46 deletions(-) diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index 96c5bbebe3..0ee3f43f17 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -59,6 +59,7 @@ mod _sqlite { builtins::{ PyBaseException, PyBaseExceptionRef, PyByteArray, PyBytes, PyDict, PyDictRef, PyFloat, PyInt, PyIntRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, + PyUtf8Str, PyUtf8StrRef, }, convert::IntoObject, function::{ @@ -851,7 +852,7 @@ mod _sqlite { } impl Callable for Connection { - type Args = (PyStrRef,); + type Args = (PyUtf8StrRef,); fn call(zelf: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { if let Some(stmt) = Statement::new(zelf, args.0, vm)? { @@ -986,7 +987,7 @@ mod _sqlite { #[pymethod] fn execute( zelf: PyRef, - sql: PyStrRef, + sql: PyUtf8StrRef, parameters: OptionalArg, vm: &VirtualMachine, ) -> PyResult> { @@ -998,7 +999,7 @@ mod _sqlite { #[pymethod] fn executemany( zelf: PyRef, - sql: PyStrRef, + sql: PyUtf8StrRef, seq_of_params: ArgIterable, vm: &VirtualMachine, ) -> PyResult> { @@ -1010,7 +1011,7 @@ mod _sqlite { #[pymethod] fn executescript( zelf: PyRef, - script: PyStrRef, + script: PyUtf8StrRef, vm: &VirtualMachine, ) -> PyResult> { let row_factory = zelf.row_factory.to_owned(); @@ -1159,11 +1160,10 @@ mod _sqlite { #[pymethod] fn create_collation( &self, - name: PyStrRef, + name: PyUtf8StrRef, callable: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - name.ensure_valid_utf8(vm)?; let name = name.to_cstring(vm)?; let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(callable.clone(), vm) else { @@ -1491,7 +1491,7 @@ mod _sqlite { #[pymethod] fn execute( zelf: PyRef, - sql: PyStrRef, + sql: PyUtf8StrRef, parameters: OptionalArg, vm: &VirtualMachine, ) -> PyResult> { @@ -1563,7 +1563,7 @@ mod _sqlite { #[pymethod] fn executemany( zelf: PyRef, - sql: PyStrRef, + sql: PyUtf8StrRef, seq_of_params: ArgIterable, vm: &VirtualMachine, ) -> PyResult> { @@ -1637,11 +1637,9 @@ mod _sqlite { #[pymethod] fn executescript( zelf: PyRef, - script: PyStrRef, + script: PyUtf8StrRef, vm: &VirtualMachine, ) -> PyResult> { - script.ensure_valid_utf8(vm)?; - let db = zelf.connection.db_lock(vm)?; db.sql_limit(script.byte_len(), vm)?; @@ -2375,10 +2373,9 @@ mod _sqlite { impl Statement { fn new( connection: &Connection, - sql: PyStrRef, + sql: PyUtf8StrRef, vm: &VirtualMachine, ) -> PyResult> { - let sql = sql.try_into_utf8(vm)?; if sql.as_str().contains('\0') { return Err(new_programming_error( vm, @@ -2731,6 +2728,7 @@ mod _sqlite { let val = val.to_f64(); unsafe { sqlite3_bind_double(self.st, pos, val) } } else if let Some(val) = obj.downcast_ref::() { + let val = val.try_as_utf8(vm)?; let (ptr, len) = str_to_ptr_len(val, vm)?; unsafe { sqlite3_bind_text(self.st, pos, ptr, len, SQLITE_TRANSIENT()) } } else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, obj) { @@ -2990,6 +2988,7 @@ mod _sqlite { } else if let Some(val) = val.downcast_ref::() { sqlite3_result_double(self.ctx, val.to_f64()) } else if let Some(val) = val.downcast_ref::() { + let val = val.try_as_utf8(vm)?; let (ptr, len) = str_to_ptr_len(val, vm)?; sqlite3_result_text(self.ctx, ptr, len, SQLITE_TRANSIENT()) } else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, val) { @@ -3070,8 +3069,8 @@ mod _sqlite { } } - fn str_to_ptr_len(s: &PyStr, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> { - let s_str = s.try_to_str(vm)?; + fn str_to_ptr_len(s: &PyUtf8Str, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> { + let s_str = s.as_str(); let len = c_int::try_from(s_str.len()) .map_err(|_| vm.new_overflow_error("TEXT longer than INT_MAX bytes"))?; let ptr = s_str.as_ptr().cast(); diff --git a/vm/src/builtins/mod.rs b/vm/src/builtins/mod.rs index 8540e6887c..6f379a4bab 100644 --- a/vm/src/builtins/mod.rs +++ b/vm/src/builtins/mod.rs @@ -59,7 +59,7 @@ pub(crate) mod bool_; pub use bool_::PyBool; #[path = "str.rs"] pub(crate) mod pystr; -pub use pystr::{PyStr, PyStrInterned, PyStrRef}; +pub use pystr::{PyStr, PyStrInterned, PyStrRef, PyUtf8Str, PyUtf8StrRef}; #[path = "super.rs"] pub(crate) mod super_; pub use super_::PySuper; diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 57bc1428b6..a7809fc4c7 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -454,12 +454,6 @@ impl PyStr { } } - pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { - self.ensure_valid_utf8(vm)?; - // SAFETY: ensure_valid_utf8 passed, so unwrap is safe. - Ok(unsafe { self.to_str().unwrap_unchecked() }) - } - pub fn to_string_lossy(&self) -> Cow<'_, str> { self.to_str() .map(Cow::Borrowed) diff --git a/vm/src/macros.rs b/vm/src/macros.rs index ff9b28cf88..1284c20278 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -186,6 +186,15 @@ macro_rules! identifier( }; ); +#[macro_export] +macro_rules! identifier_utf8( + ($as_ctx:expr, $name:ident) => { + // Safety: All known identifiers are ascii strings. + #[allow(clippy::macro_metavars_in_unsafe)] + unsafe { $as_ctx.as_ref().names.$name.as_object().downcast_unchecked_ref::<$crate::builtins::PyUtf8Str>() } + }; +); + /// Super detailed logging. Might soon overflow your log buffers /// Default, this logging is discarded, except when a the `vm-tracing-logging` /// build feature is enabled. diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index c0a091bcf8..468b5dda6e 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -7,7 +7,7 @@ mod _codecs { use crate::common::wtf8::Wtf8Buf; use crate::{ AsObject, PyObjectRef, PyResult, VirtualMachine, - builtins::PyStrRef, + builtins::{PyStrRef, PyUtf8StrRef}, codecs, function::{ArgBytesLike, FuncArgs}, }; @@ -23,10 +23,10 @@ mod _codecs { } #[pyfunction] - fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult { + fn lookup(encoding: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult { vm.state .codec_registry - .lookup(encoding.try_to_str(vm)?, vm) + .lookup(encoding.as_str(), vm) .map(|codec| codec.into_tuple().into()) } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 67ac51615a..06363dd98b 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -120,7 +120,7 @@ mod _io { TryFromBorrowedObject, TryFromObject, builtins::{ PyBaseExceptionRef, PyByteArray, PyBytes, PyBytesRef, PyIntRef, PyMemoryView, PyStr, - PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, + PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef, }, class::StaticType, common::lock::{ @@ -1945,7 +1945,7 @@ mod _io { #[derive(FromArgs)] struct TextIOWrapperArgs { #[pyarg(any, default)] - encoding: Option, + encoding: Option, #[pyarg(any, default)] errors: Option, #[pyarg(any, default)] @@ -2108,7 +2108,7 @@ mod _io { buffer: PyObjectRef, encoder: Option<(PyObjectRef, Option)>, decoder: Option, - encoding: PyStrRef, + encoding: PyUtf8StrRef, errors: PyStrRef, newline: Newlines, line_buffering: bool, @@ -2294,8 +2294,8 @@ mod _io { *data = None; let encoding = match args.encoding { - None if vm.state.settings.utf8_mode > 0 => identifier!(vm, utf_8).to_owned(), - Some(enc) if enc.as_wtf8() != "locale" => enc, + None if vm.state.settings.utf8_mode > 0 => identifier_utf8!(vm, utf_8).to_owned(), + Some(enc) if enc.as_str() != "locale" => enc, _ => { // None without utf8_mode or "locale" encoding vm.import("locale", 0)? @@ -2314,7 +2314,7 @@ mod _io { let newline = args.newline.unwrap_or_default(); let (encoder, decoder) = - Self::find_coder(&buffer, encoding.try_to_str(vm)?, &errors, newline, vm)?; + Self::find_coder(&buffer, encoding.as_str(), &errors, newline, vm)?; *data = Some(TextIOData { buffer, @@ -2414,7 +2414,7 @@ mod _io { if let Some(encoding) = args.encoding { let (encoder, decoder) = Self::find_coder( &data.buffer, - encoding.try_to_str(vm)?, + encoding.as_str(), &data.errors, data.newline, vm, @@ -2739,7 +2739,7 @@ mod _io { } #[pygetset] - fn encoding(&self, vm: &VirtualMachine) -> PyResult { + fn encoding(&self, vm: &VirtualMachine) -> PyResult { Ok(self.lock(vm)?.encoding.clone()) } @@ -3892,7 +3892,7 @@ mod _io { struct IoOpenArgs { file: PyObjectRef, #[pyarg(any, optional)] - mode: OptionalArg, + mode: OptionalArg, #[pyarg(flatten)] opts: OpenArgs, } @@ -3918,7 +3918,7 @@ mod _io { #[pyarg(any, default = -1)] pub buffering: isize, #[pyarg(any, default)] - pub encoding: Option, + pub encoding: Option, #[pyarg(any, default)] pub errors: Option, #[pyarg(any, default)] @@ -4130,7 +4130,7 @@ mod fileio { use super::{_io::*, Offset}; use crate::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, - builtins::{PyBaseExceptionRef, PyStr, PyStrRef}, + builtins::{PyBaseExceptionRef, PyUtf8Str, PyUtf8StrRef}, common::crt_fd::Fd, convert::ToPyException, function::{ArgBytesLike, ArgMemoryBuffer, OptionalArg, OptionalOption}, @@ -4257,7 +4257,7 @@ mod fileio { #[pyarg(positional)] name: PyObjectRef, #[pyarg(any, default)] - mode: Option, + mode: Option, #[pyarg(any, default = true)] closefd: bool, #[pyarg(any, default)] @@ -4295,7 +4295,7 @@ mod fileio { let mode_obj = args .mode - .unwrap_or_else(|| PyStr::from("rb").into_ref(&vm.ctx)); + .unwrap_or_else(|| PyUtf8Str::from("rb").into_ref(&vm.ctx)); let mode_str = mode_obj.as_str(); let (mode, flags) = compute_mode(mode_str).map_err(|e| vm.new_value_error(e.error_msg(mode_str)))?; diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index 85e8f4569c..6a061b3454 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -35,7 +35,7 @@ unsafe extern "C" { mod decl { use crate::{ AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, - builtins::{PyStrRef, PyTypeRef}, + builtins::{PyStrRef, PyTypeRef, PyUtf8StrRef}, function::{Either, FuncArgs, OptionalArg}, types::PyStructSequence, }; @@ -344,7 +344,11 @@ mod decl { } #[pyfunction] - fn strftime(format: PyStrRef, t: OptionalArg, vm: &VirtualMachine) -> PyResult { + fn strftime( + format: PyUtf8StrRef, + t: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { use std::fmt::Write; let instant = t.naive_or_local(vm)?; @@ -355,12 +359,8 @@ mod decl { * raises an error if unsupported format is supplied. * If error happens, we set result as input arg. */ - write!( - &mut formatted_time, - "{}", - instant.format(format.try_to_str(vm)?) - ) - .unwrap_or_else(|_| formatted_time = format.to_string()); + write!(&mut formatted_time, "{}", instant.format(format.as_str())) + .unwrap_or_else(|_| formatted_time = format.to_string()); Ok(vm.ctx.new_str(formatted_time).into()) } From 9583af057b5789eea2e776b7fba29dd761ad9a18 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 30 Jul 2025 10:43:25 +0900 Subject: [PATCH 3/3] Apply PyUtf8Str --- vm/src/builtins/builtin_func.rs | 3 ++- vm/src/builtins/namespace.rs | 4 ++-- vm/src/builtins/slice.rs | 9 ++------- vm/src/builtins/super.rs | 4 ++-- vm/src/builtins/type.rs | 4 ++-- vm/src/exceptions.rs | 23 ++++++++++------------- vm/src/format.rs | 2 +- vm/src/protocol/object.rs | 20 +++++++++++++------- vm/src/stdlib/ctypes/structure.rs | 2 +- vm/src/stdlib/io.rs | 4 ++-- vm/src/stdlib/operator.rs | 11 ++++++----- vm/src/stdlib/posix.rs | 4 ++-- vm/src/stdlib/pwd.rs | 4 ++-- vm/src/stdlib/sys.rs | 3 +-- vm/src/types/slot.rs | 10 +++++----- vm/src/utils.rs | 3 ++- vm/src/vm/mod.rs | 2 +- vm/src/vm/vm_ops.rs | 6 +++++- vm/src/warn.rs | 14 +++++++------- 19 files changed, 68 insertions(+), 64 deletions(-) diff --git a/vm/src/builtins/builtin_func.rs b/vm/src/builtins/builtin_func.rs index 464f47f13c..3df93398dd 100644 --- a/vm/src/builtins/builtin_func.rs +++ b/vm/src/builtins/builtin_func.rs @@ -2,6 +2,7 @@ use super::{PyStrInterned, PyStrRef, PyType, type_}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, + common::wtf8::Wtf8, convert::TryFromObject, function::{FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags, PyNativeFn}, types::{Callable, Comparable, PyComparisonOp, Representable, Unconstructible}, @@ -27,7 +28,7 @@ impl fmt::Debug for PyNativeFunction { write!( f, "builtin function {}.{} ({:?}) self as instance of {:?}", - self.module.map_or("", |m| m.as_str()), + self.module.map_or(Wtf8::new(""), |m| m.as_wtf8()), self.value.name, self.value.flags, self.zelf.as_ref().map(|z| z.class().name().to_owned()) diff --git a/vm/src/builtins/namespace.rs b/vm/src/builtins/namespace.rs index ea430225a8..b62086e4b3 100644 --- a/vm/src/builtins/namespace.rs +++ b/vm/src/builtins/namespace.rs @@ -89,8 +89,8 @@ impl Representable for PyNamespace { let dict = zelf.as_object().dict().unwrap(); let mut parts = Vec::with_capacity(dict.__len__()); for (key, value) in dict { - let k = &key.repr(vm)?; - let key_str = k.as_str(); + let k = key.repr(vm)?; + let key_str = k.as_wtf8(); let value_repr = value.repr(vm)?; parts.push(format!("{}={}", &key_str[1..key_str.len() - 1], value_repr)); } diff --git a/vm/src/builtins/slice.rs b/vm/src/builtins/slice.rs index f77c8cb8e8..84546b2c64 100644 --- a/vm/src/builtins/slice.rs +++ b/vm/src/builtins/slice.rs @@ -292,15 +292,10 @@ impl Representable for PySlice { #[inline] fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { let start_repr = zelf.start_ref(vm).repr(vm)?; - let stop_repr = &zelf.stop.repr(vm)?; + let stop_repr = zelf.stop.repr(vm)?; let step_repr = zelf.step_ref(vm).repr(vm)?; - Ok(format!( - "slice({}, {}, {})", - start_repr.as_str(), - stop_repr.as_str(), - step_repr.as_str() - )) + Ok(format!("slice({start_repr}, {stop_repr}, {step_repr})")) } } diff --git a/vm/src/builtins/super.rs b/vm/src/builtins/super.rs index 2d7e48447f..8f5bc0e112 100644 --- a/vm/src/builtins/super.rs +++ b/vm/src/builtins/super.rs @@ -104,7 +104,7 @@ impl Initializer for PySuper { let mut typ = None; for (i, var) in frame.code.freevars.iter().enumerate() { - if var.as_str() == "__class__" { + if var.as_bytes() == b"__class__" { let i = frame.code.cellvars.len() + i; let class = frame.cells_frees[i] .get() @@ -162,7 +162,7 @@ impl GetAttr for PySuper { // We want __class__ to return the class of the super object // (i.e. super, or a subclass), not the class of su->obj. - if name.as_str() == "__class__" { + if name.as_bytes() == b"__class__" { return skip(zelf, name); } diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 41ed70d939..6a93751a93 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -340,13 +340,13 @@ impl PyType { if name == identifier!(ctx, __new__) { continue; } - if name.as_str().starts_with("__") && name.as_str().ends_with("__") { + if name.as_bytes().starts_with(b"__") && name.as_bytes().ends_with(b"__") { slot_name_set.insert(name); } } } for &name in self.attributes.read().keys() { - if name.as_str().starts_with("__") && name.as_str().ends_with("__") { + if name.as_bytes().starts_with(b"__") && name.as_bytes().ends_with(b"__") { slot_name_set.insert(name); } } diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index a443887bce..a1896a8247 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -198,15 +198,11 @@ impl VirtualMachine { let mut filename_suffix = String::new(); if let Some(lineno) = maybe_lineno { - writeln!( - output, - r##" File "{}", line {}"##, - maybe_filename - .as_ref() - .map(|s| s.as_str()) - .unwrap_or(""), - lineno - )?; + let filename = match maybe_filename { + Some(filename) => filename, + None => vm.ctx.new_str(""), + }; + writeln!(output, r##" File "{filename}", line {lineno}"##,)?; } else if let Some(filename) = maybe_filename { filename_suffix = format!(" ({filename})"); } @@ -1498,7 +1494,7 @@ pub(super) mod types { let args = exc.args(); let obj = exc.as_object().to_owned(); - if args.len() == 2 { + let str = if args.len() == 2 { // SAFETY: len() == 2 is checked so get_arg 1 or 2 won't panic let errno = exc.get_arg(0).unwrap().str(vm)?; let msg = exc.get_arg(1).unwrap().str(vm)?; @@ -1518,10 +1514,11 @@ pub(super) mod types { format!("[Errno {errno}] {msg}") } }; - Ok(vm.ctx.new_str(s)) + vm.ctx.new_str(s) } else { - Ok(exc.__str__(vm)) - } + exc.__str__(vm) + }; + Ok(str) } #[pymethod] diff --git a/vm/src/format.rs b/vm/src/format.rs index f95f161f7a..04d06e9be0 100644 --- a/vm/src/format.rs +++ b/vm/src/format.rs @@ -112,7 +112,7 @@ fn format_internal( // FIXME: compiler can intern specs using parser tree. Then this call can be interned_str pystr = vm.format(&argument, vm.ctx.new_str(format_spec))?; - pystr.as_ref() + pystr.as_wtf8() } FormatPart::Literal(literal) => literal, }; diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index aade5a18e4..d398da0237 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -2,10 +2,10 @@ //! use crate::{ - AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, + AsObject, Py, PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef, - PyTuple, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr, + PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyTuple, + PyTupleRef, PyType, PyTypeRef, PyUtf8Str, pystr::AsPyStr, }, bytes_inner::ByteInnerNewOptions, common::{hash::PyHash, str::to_ascii}, @@ -328,7 +328,11 @@ impl PyObject { } } - pub fn repr(&self, vm: &VirtualMachine) -> PyResult { + pub fn repr_utf8(&self, vm: &VirtualMachine) -> PyResult> { + self.repr(vm)?.try_into_utf8(vm) + } + + pub fn repr(&self, vm: &VirtualMachine) -> PyResult> { vm.with_recursion("while getting the repr of an object", || { // TODO: RustPython does not implement type slots inheritance yet self.class() @@ -346,13 +350,15 @@ impl PyObject { } pub fn ascii(&self, vm: &VirtualMachine) -> PyResult { - let repr = self.repr(vm)?; + let repr = self.repr_utf8(vm)?; let ascii = to_ascii(repr.as_str()); Ok(ascii) } - // Container of the virtual machine state: - pub fn str(&self, vm: &VirtualMachine) -> PyResult { + pub fn str_utf8(&self, vm: &VirtualMachine) -> PyResult> { + self.str(vm)?.try_into_utf8(vm) + } + pub fn str(&self, vm: &VirtualMachine) -> PyResult> { let obj = match self.to_owned().downcast_exact::(vm) { Ok(s) => return Ok(s.into_pyref()), Err(obj) => obj, diff --git a/vm/src/stdlib/ctypes/structure.rs b/vm/src/stdlib/ctypes/structure.rs index d675c3263d..8ca8bb51df 100644 --- a/vm/src/stdlib/ctypes/structure.rs +++ b/vm/src/stdlib/ctypes/structure.rs @@ -39,7 +39,7 @@ impl Constructor for PyCStructure { .downcast_ref::() .ok_or_else(|| vm.new_type_error("Field name must be a string"))?; let typ = field.get(1).unwrap().clone(); - field_data.insert(name.as_str().to_string(), typ); + field_data.insert(name.to_string(), typ); } todo!("Implement PyCStructure::py_new") } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 06363dd98b..4931b2f11a 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -1579,7 +1579,7 @@ mod _io { } #[pyslot] - fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { + fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult> { let name_repr = repr_file_obj_name(zelf, vm)?; let cls = zelf.class(); let slot_name = cls.slot_name(); @@ -1592,7 +1592,7 @@ mod _io { } #[pymethod] - fn __repr__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn __repr__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult> { Self::slot_repr(&zelf, vm) } diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs index 4fd74734ec..f13c00ac1d 100644 --- a/vm/src/stdlib/operator.rs +++ b/vm/src/stdlib/operator.rs @@ -5,8 +5,8 @@ mod _operator { use crate::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyInt, PyIntRef, PyStr, PyStrRef, PyTupleRef, PyTypeRef}, - function::Either, - function::{ArgBytesLike, FuncArgs, KwArgs, OptionalArg}, + common::wtf8::Wtf8, + function::{ArgBytesLike, Either, FuncArgs, KwArgs, OptionalArg}, identifier, protocol::PyIter, recursion::ReprGuard, @@ -324,7 +324,7 @@ mod _operator { ) -> PyResult { let res = match (a, b) { (Either::A(a), Either::A(b)) => { - if !a.as_str().is_ascii() || !b.as_str().is_ascii() { + if !a.is_ascii() || !b.is_ascii() { return Err(vm.new_type_error( "comparing strings with non-ASCII characters is not supported", )); @@ -371,13 +371,14 @@ mod _operator { attr: &Py, vm: &VirtualMachine, ) -> PyResult { - let attr_str = attr.as_str(); - let parts = attr_str.split('.').collect::>(); + let attr_str = attr.as_bytes(); + let parts = attr_str.split(|&b| b == b'.').collect::>(); if parts.len() == 1 { return obj.get_attr(attr, vm); } let mut obj = obj; for part in parts { + let part = Wtf8::from_bytes(part).expect("originally valid WTF-8"); obj = obj.get_attr(&vm.ctx.new_str(part), vm)?; } Ok(obj) diff --git a/vm/src/stdlib/posix.rs b/vm/src/stdlib/posix.rs index 6220a15b8a..7a99b96f46 100644 --- a/vm/src/stdlib/posix.rs +++ b/vm/src/stdlib/posix.rs @@ -24,7 +24,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { pub mod module { use crate::{ AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyDictRef, PyInt, PyListRef, PyStrRef, PyTupleRef, PyTypeRef}, + builtins::{PyDictRef, PyInt, PyListRef, PyStrRef, PyTupleRef, PyTypeRef, PyUtf8StrRef}, convert::{IntoPyException, ToPyObject, TryFromObject}, function::{Either, KwArgs, OptionalArg}, ospath::{IOErrorBuilder, OsPath, OsPathOrFd}, @@ -2242,7 +2242,7 @@ pub mod module { let i = match obj.downcast::() { Ok(int) => int.try_to_primitive(vm)?, Err(obj) => { - let s = PyStrRef::try_from_object(vm, obj)?; + let s = PyUtf8StrRef::try_from_object(vm, obj)?; s.as_str().parse::().or_else(|_| { if s.as_str() == "SC_PAGESIZE" { Ok(SysconfVar::SC_PAGESIZE) diff --git a/vm/src/stdlib/pwd.rs b/vm/src/stdlib/pwd.rs index 633a710030..16570ff6be 100644 --- a/vm/src/stdlib/pwd.rs +++ b/vm/src/stdlib/pwd.rs @@ -6,7 +6,7 @@ pub(crate) use pwd::make_module; mod pwd { use crate::{ PyObjectRef, PyResult, VirtualMachine, - builtins::{PyIntRef, PyStrRef}, + builtins::{PyIntRef, PyUtf8StrRef}, convert::{IntoPyException, ToPyObject}, exceptions, types::PyStructSequence, @@ -54,7 +54,7 @@ mod pwd { } #[pyfunction] - fn getpwnam(name: PyStrRef, vm: &VirtualMachine) -> PyResult { + fn getpwnam(name: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult { let pw_name = name.as_str(); if pw_name.contains('\0') { return Err(exceptions::cstring_error(vm)); diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 4a4d60ef2a..a8fb4031a9 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -712,7 +712,7 @@ mod sys { if !vm.is_none(&unraisable.exc_value) { write!(stderr, "{}: ", unraisable.exc_type); if let Ok(str) = unraisable.exc_value.str(vm) { - write!(stderr, "{}", str.as_str()); + write!(stderr, "{}", str.to_str().unwrap_or("")); } else { write!(stderr, ""); } @@ -734,7 +734,6 @@ mod sys { e.as_object() .repr(vm) .unwrap_or_else(|_| vm.ctx.empty_str.to_owned()) - .as_str() ); } } diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 8788b9f3ae..eec15d9631 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -169,7 +169,7 @@ impl Default for PyTypeFlags { pub(crate) type GenericMethod = fn(&PyObject, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type HashFunc = fn(&PyObject, &VirtualMachine) -> PyResult; // CallFunc = GenericMethod -pub(crate) type StringifyFunc = fn(&PyObject, &VirtualMachine) -> PyResult; +pub(crate) type StringifyFunc = fn(&PyObject, &VirtualMachine) -> PyResult>; pub(crate) type GetattroFunc = fn(&PyObject, &Py, &VirtualMachine) -> PyResult; pub(crate) type SetattroFunc = fn(&PyObject, &Py, PySetterValue, &VirtualMachine) -> PyResult<()>; @@ -250,7 +250,7 @@ fn setitem_wrapper( .map(drop) } -fn repr_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { +fn repr_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult> { let ret = vm.call_special_method(zelf, identifier!(vm, __repr__), ())?; ret.downcast::().map_err(|obj| { vm.new_type_error(format!( @@ -977,7 +977,7 @@ pub trait Hashable: PyPayload { pub trait Representable: PyPayload { #[inline] #[pyslot] - fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { + fn slot_repr(zelf: &PyObject, vm: &VirtualMachine) -> PyResult> { let zelf = zelf .downcast_ref() .ok_or_else(|| vm.new_type_error("unexpected payload for __repr__"))?; @@ -986,12 +986,12 @@ pub trait Representable: PyPayload { #[inline] #[pymethod] - fn __repr__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn __repr__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult> { Self::slot_repr(&zelf, vm) } #[inline] - fn repr(zelf: &Py, vm: &VirtualMachine) -> PyResult { + fn repr(zelf: &Py, vm: &VirtualMachine) -> PyResult> { let repr = Self::repr_str(zelf, vm)?; Ok(vm.ctx.new_str(repr)) } diff --git a/vm/src/utils.rs b/vm/src/utils.rs index 78edfb71cc..af34405c7b 100644 --- a/vm/src/utils.rs +++ b/vm/src/utils.rs @@ -2,7 +2,7 @@ use rustpython_common::wtf8::Wtf8; use crate::{ PyObjectRef, PyResult, VirtualMachine, - builtins::PyStr, + builtins::{PyStr, PyUtf8Str}, convert::{ToPyException, ToPyObject}, exceptions::cstring_error, }; @@ -35,6 +35,7 @@ pub trait ToCString: AsRef { impl ToCString for &str {} impl ToCString for PyStr {} +impl ToCString for PyUtf8Str {} pub(crate) fn collection_repr<'a, I>( class_name: Option<&str>, diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 78f5ec309c..e88153ecee 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -472,7 +472,7 @@ impl VirtualMachine { object, }; if let Err(e) = unraisablehook.call((args,), self) { - println!("{}", e.as_object().repr(self).unwrap().as_str()); + println!("{}", e.as_object().repr(self).unwrap()); } } diff --git a/vm/src/vm/vm_ops.rs b/vm/src/vm/vm_ops.rs index f0e495c8db..d4bf4563c8 100644 --- a/vm/src/vm/vm_ops.rs +++ b/vm/src/vm/vm_ops.rs @@ -1,7 +1,8 @@ use super::VirtualMachine; use crate::stdlib::warnings; use crate::{ - builtins::{PyInt, PyIntRef, PyStr, PyStrRef}, + PyRef, + builtins::{PyInt, PyIntRef, PyStr, PyStrRef, PyUtf8Str}, object::{AsObject, PyObject, PyObjectRef, PyResult}, protocol::{PyIterReturn, PyNumberBinaryOp, PyNumberTernaryOp, PySequence}, types::PyComparisonOp, @@ -517,6 +518,9 @@ impl VirtualMachine { )) }) } + pub fn format_utf8(&self, obj: &PyObject, format_spec: PyStrRef) -> PyResult> { + self.format(obj, format_spec)?.try_into_utf8(self) + } // https://docs.python.org/3/reference/expressions.html#membership-test-operations fn _membership_iter_search( diff --git a/vm/src/warn.rs b/vm/src/warn.rs index 943fbcee23..6480f77843 100644 --- a/vm/src/warn.rs +++ b/vm/src/warn.rs @@ -122,7 +122,7 @@ fn get_filter( /* Python code: action, msg, cat, mod, ln = item */ let action = if let Some(action) = tmp_item.first() { - action.str(vm).map(|action| action.into_object()) + action.str_utf8(vm).map(|action| action.into_object()) } else { Err(vm.new_type_error("action must be a string")) }; @@ -201,8 +201,8 @@ fn already_warned( fn normalize_module(filename: &Py, vm: &VirtualMachine) -> Option { let obj = match filename.char_len() { 0 => vm.new_pyobj(""), - len if len >= 3 && filename.as_str().ends_with(".py") => { - vm.new_pyobj(&filename.as_str()[..len - 3]) + len if len >= 3 && filename.as_bytes().ends_with(b".py") => { + vm.new_pyobj(&filename.as_wtf8()[..len - 3]) } _ => filename.as_object().to_owned(), }; @@ -232,7 +232,7 @@ fn warn_explicit( }; // Normalize message. - let text = message.as_str(); + let text = message.as_wtf8(); let category = if let Some(category) = category { if !category.fast_issubclass(vm.ctx.exceptions.warning) { @@ -278,11 +278,11 @@ fn warn_explicit( vm, )?; - if action.str(vm)?.as_str().eq("error") { + if action.str_utf8(vm)?.as_str().eq("error") { return Err(vm.new_type_error(message.to_string())); } - if action.str(vm)?.as_str().eq("ignore") { + if action.str_utf8(vm)?.as_str().eq("ignore") { return Ok(()); } @@ -345,7 +345,7 @@ fn show_warning( vm: &VirtualMachine, ) -> PyResult<()> { let stderr = crate::stdlib::sys::PyStderr(vm); - writeln!(stderr, "{}: {}", category.name(), text.as_str(),); + writeln!(stderr, "{}: {}", category.name(), text); Ok(()) }