From b78f4db52e37c7a805186f285ac8785d09123792 Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Mon, 14 Jul 2025 19:02:46 +0900 Subject: [PATCH 1/5] fix(sqlite): validate surrogates in SQL statements --- Lib/test/test_sqlite3/test_regression.py | 2 -- stdlib/src/sqlite.rs | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Lib/test/test_sqlite3/test_regression.py b/Lib/test/test_sqlite3/test_regression.py index dfcf3b11f5..870958ceee 100644 --- a/Lib/test/test_sqlite3/test_regression.py +++ b/Lib/test/test_sqlite3/test_regression.py @@ -343,8 +343,6 @@ def test_null_character(self): self.assertRaisesRegex(sqlite.ProgrammingError, "null char", cur.execute, query) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates(self): con = sqlite.connect(":memory:") self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index ce84ac2988..472035565c 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -2294,6 +2294,7 @@ mod _sqlite { sql: &PyStr, vm: &VirtualMachine, ) -> PyResult> { + let _ = sql.try_to_str(vm)?; let sql_cstr = sql.to_cstring(vm)?; let sql_len = sql.byte_len() + 1; From 716efc47cc40574eb5feb5a466d437f1b1b4a31a Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Mon, 14 Jul 2025 23:45:54 +0900 Subject: [PATCH 2/5] Add `PyUtf8Str` wrapper for safe conversion --- stdlib/src/sqlite.rs | 4 ++-- vm/src/builtins/str.rs | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index 472035565c..df90fadd8d 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -2291,10 +2291,10 @@ mod _sqlite { impl Statement { fn new( connection: &Connection, - sql: &PyStr, + sql: PyRef, vm: &VirtualMachine, ) -> PyResult> { - let _ = sql.try_to_str(vm)?; + let sql = sql.try_into_utf8(vm)?; let sql_cstr = sql.to_cstring(vm)?; let sql_len = sql.byte_len() + 1; diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 9f86da3da0..26da163884 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -37,8 +37,8 @@ use rustpython_common::{ str::DeduceStrKind, wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk}, }; -use std::sync::LazyLock; use std::{borrow::Cow, char, fmt, ops::Range}; +use std::{mem, sync::LazyLock}; use unic_ucd_bidi::BidiClass; use unic_ucd_category::GeneralCategory; use unic_ucd_ident::{is_xid_continue, is_xid_start}; @@ -80,6 +80,25 @@ impl fmt::Debug for PyStr { } } +#[repr(transparent)] +#[derive(Debug)] +pub struct PyUtf8Str(PyStr); + +impl std::ops::Deref for PyUtf8Str { + type Target = PyStr; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PyUtf8Str { + /// Returns the underlying string slice. This is safe because the + /// type invariant guarantees UTF-8 validity. + pub fn as_str(&self) -> &str { + self.0.to_str().expect("PyUtf8Str invariant was violated") + } +} + impl AsRef for PyStr { #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { @@ -1486,6 +1505,11 @@ impl PyStrRef { s.push_wtf8(other); *self = PyStr::from(s).into_ref(&vm.ctx); } + + pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult> { + let _ = self.try_to_str(vm)?; + Ok(unsafe { mem::transmute::, PyRef>(self) }) + } } impl Representable for PyStr { From 62b8a568d5030cca0ebb04352d782e750f333069 Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Mon, 14 Jul 2025 23:52:25 +0900 Subject: [PATCH 3/5] discord review, --- vm/src/builtins/str.rs | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 26da163884..87bd5a5f39 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -95,7 +95,11 @@ impl PyUtf8Str { /// Returns the underlying string slice. This is safe because the /// type invariant guarantees UTF-8 validity. pub fn as_str(&self) -> &str { - self.0.to_str().expect("PyUtf8Str invariant was violated") + debug_assert!( + self.0.is_utf8(), + "PyUtf8Str invariant violated: inner string is not valid UTF-8" + ); + unsafe { self.0.to_str().unwrap_unchecked() } } } @@ -452,21 +456,29 @@ impl PyStr { self.data.as_str() } - pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { - self.to_str().ok_or_else(|| { + fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.is_utf8() { + Ok(()) + } else { let start = self .as_wtf8() .code_points() .position(|c| c.to_char().is_none()) .unwrap(); - vm.new_unicode_encode_error_real( + Err(vm.new_unicode_encode_error_real( identifier!(vm, utf_8).to_owned(), vm.ctx.new_str(self.data.clone()), start, start + 1, vm.ctx.new_str("surrogates not allowed"), - ) - }) + )) + } + } + + pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { + self.ensure_valid_utf8(vm)?; + // SAFETY: ensure_valid_utf8 passed, so unwrap is safe. + Ok(unsafe { self.to_str().unwrap_unchecked() }) } pub fn to_string_lossy(&self) -> Cow<'_, str> { @@ -1507,7 +1519,7 @@ impl PyStrRef { } pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult> { - let _ = self.try_to_str(vm)?; + self.ensure_valid_utf8(vm)?; Ok(unsafe { mem::transmute::, PyRef>(self) }) } } From a50a7bc8f0372e49efb2bd3aabded4a6481924cb Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Mon, 14 Jul 2025 23:56:31 +0900 Subject: [PATCH 4/5] cargo check --- stdlib/src/sqlite.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index df90fadd8d..4e9620eeab 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -844,7 +844,7 @@ mod _sqlite { type Args = (PyStrRef,); fn call(zelf: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { - if let Some(stmt) = Statement::new(zelf, &args.0, vm)? { + if let Some(stmt) = Statement::new(zelf, args.0, vm)? { Ok(stmt.into_ref(&vm.ctx).into()) } else { Ok(vm.ctx.none()) @@ -1480,7 +1480,7 @@ mod _sqlite { stmt.lock().reset(); } - let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else { + let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else { drop(inner); return Ok(zelf); }; @@ -1552,7 +1552,7 @@ mod _sqlite { stmt.lock().reset(); } - let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else { + let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else { drop(inner); return Ok(zelf); }; @@ -2291,7 +2291,7 @@ mod _sqlite { impl Statement { fn new( connection: &Connection, - sql: PyRef, + sql: PyStrRef, vm: &VirtualMachine, ) -> PyResult> { let sql = sql.try_into_utf8(vm)?; From 3899fbe03db686f8386754c459c2fe13fc850e0d Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Tue, 15 Jul 2025 00:13:14 +0900 Subject: [PATCH 5/5] fix review --- vm/src/builtins/str.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 87bd5a5f39..73349c6141 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -84,6 +84,7 @@ impl fmt::Debug for PyStr { #[derive(Debug)] pub struct PyUtf8Str(PyStr); +// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str impl std::ops::Deref for PyUtf8Str { type Target = PyStr; fn deref(&self) -> &Self::Target { @@ -92,13 +93,13 @@ impl std::ops::Deref for PyUtf8Str { } impl PyUtf8Str { - /// Returns the underlying string slice. This is safe because the - /// type invariant guarantees UTF-8 validity. + /// Returns the underlying string slice. pub fn as_str(&self) -> &str { debug_assert!( self.0.is_utf8(), "PyUtf8Str invariant violated: inner string is not valid UTF-8" ); + // Safety: This is safe because the type invariant guarantees UTF-8 validity. unsafe { self.0.to_str().unwrap_unchecked() } } }