diff --git a/Lib/test/test_sqlite3/test_regression.py b/Lib/test/test_sqlite3/test_regression.py index dfcf3b11f5..870958ceee 100644 --- a/Lib/test/test_sqlite3/test_regression.py +++ b/Lib/test/test_sqlite3/test_regression.py @@ -343,8 +343,6 @@ def test_null_character(self): self.assertRaisesRegex(sqlite.ProgrammingError, "null char", cur.execute, query) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates(self): con = sqlite.connect(":memory:") self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index ce84ac2988..4e9620eeab 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -844,7 +844,7 @@ mod _sqlite { type Args = (PyStrRef,); fn call(zelf: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { - if let Some(stmt) = Statement::new(zelf, &args.0, vm)? { + if let Some(stmt) = Statement::new(zelf, args.0, vm)? { Ok(stmt.into_ref(&vm.ctx).into()) } else { Ok(vm.ctx.none()) @@ -1480,7 +1480,7 @@ mod _sqlite { stmt.lock().reset(); } - let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else { + let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else { drop(inner); return Ok(zelf); }; @@ -1552,7 +1552,7 @@ mod _sqlite { stmt.lock().reset(); } - let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else { + let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else { drop(inner); return Ok(zelf); }; @@ -2291,9 +2291,10 @@ mod _sqlite { impl Statement { fn new( connection: &Connection, - sql: &PyStr, + sql: PyStrRef, vm: &VirtualMachine, ) -> PyResult> { + let sql = sql.try_into_utf8(vm)?; let sql_cstr = sql.to_cstring(vm)?; let sql_len = sql.byte_len() + 1; diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 9f86da3da0..73349c6141 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -37,8 +37,8 @@ use rustpython_common::{ str::DeduceStrKind, wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk}, }; -use std::sync::LazyLock; use std::{borrow::Cow, char, fmt, ops::Range}; +use std::{mem, sync::LazyLock}; use unic_ucd_bidi::BidiClass; use unic_ucd_category::GeneralCategory; use unic_ucd_ident::{is_xid_continue, is_xid_start}; @@ -80,6 +80,30 @@ impl fmt::Debug for PyStr { } } +#[repr(transparent)] +#[derive(Debug)] +pub struct PyUtf8Str(PyStr); + +// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str +impl std::ops::Deref for PyUtf8Str { + type Target = PyStr; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PyUtf8Str { + /// Returns the underlying string slice. + pub fn as_str(&self) -> &str { + debug_assert!( + self.0.is_utf8(), + "PyUtf8Str invariant violated: inner string is not valid UTF-8" + ); + // Safety: This is safe because the type invariant guarantees UTF-8 validity. + unsafe { self.0.to_str().unwrap_unchecked() } + } +} + impl AsRef for PyStr { #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { @@ -433,21 +457,29 @@ impl PyStr { self.data.as_str() } - pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { - self.to_str().ok_or_else(|| { + fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.is_utf8() { + Ok(()) + } else { let start = self .as_wtf8() .code_points() .position(|c| c.to_char().is_none()) .unwrap(); - vm.new_unicode_encode_error_real( + Err(vm.new_unicode_encode_error_real( identifier!(vm, utf_8).to_owned(), vm.ctx.new_str(self.data.clone()), start, start + 1, vm.ctx.new_str("surrogates not allowed"), - ) - }) + )) + } + } + + pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { + self.ensure_valid_utf8(vm)?; + // SAFETY: ensure_valid_utf8 passed, so unwrap is safe. + Ok(unsafe { self.to_str().unwrap_unchecked() }) } pub fn to_string_lossy(&self) -> Cow<'_, str> { @@ -1486,6 +1518,11 @@ impl PyStrRef { s.push_wtf8(other); *self = PyStr::from(s).into_ref(&vm.ctx); } + + pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult> { + self.ensure_valid_utf8(vm)?; + Ok(unsafe { mem::transmute::, PyRef>(self) }) + } } impl Representable for PyStr {