diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 052248a583..d151c2948c 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -94,9 +94,6 @@ mod _ssl { SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, - // X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF, - // sys::X509_V_FLAG_CRL_CHECK|sys::X509_V_FLAG_CRL_CHECK_ALL as VERIFY_CRL_CHECK_CHAIN - // X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, SSL_ERROR_ZERO_RETURN, SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE, SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT, @@ -114,6 +111,11 @@ mod _ssl { X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT, }; + // CRL verification constants + #[pyattr] + const VERIFY_CRL_CHECK_CHAIN: libc::c_ulong = + sys::X509_V_FLAG_CRL_CHECK | sys::X509_V_FLAG_CRL_CHECK_ALL; + // taken from CPython, should probably be kept up to date with their version if it ever changes #[pyattr] const _DEFAULT_CIPHERS: &str = @@ -631,6 +633,12 @@ mod _ssl { Ok(()) } + #[cfg(ossl110)] + #[pygetset] + fn security_level(&self) -> i32 { + unsafe { SSL_CTX_get_security_level(self.ctx().as_ptr()) } + } + #[pymethod] fn set_ciphers(&self, cipherlist: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { let ciphers = cipherlist.as_str(); @@ -677,19 +685,29 @@ mod _ssl { } #[pymethod] - fn set_ecdh_curve(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { + fn set_ecdh_curve( + &self, + name: Either, + vm: &VirtualMachine, + ) -> PyResult<()> { use openssl::ec::{EcGroup, EcKey}; - let curve_name = name.as_str(); - if curve_name.contains('\0') { - return Err(exceptions::cstring_error(vm)); - } + // Convert name to CString, supporting both str and bytes + let name_cstr = match name { + Either::A(s) => { + if s.as_str().contains('\0') { + return Err(exceptions::cstring_error(vm)); + } + s.to_cstring(vm)? + } + Either::B(b) => std::ffi::CString::new(b.borrow_buf().to_vec()) + .map_err(|_| exceptions::cstring_error(vm))?, + }; // Find the NID for the curve name using OBJ_sn2nid - let name_cstr = name.to_cstring(vm)?; let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) }; if nid_raw == 0 { - return Err(vm.new_value_error(format!("unknown curve name: {}", curve_name))); + return Err(vm.new_value_error("unknown curve name")); } let nid = Nid::from_raw(nid_raw); @@ -794,6 +812,47 @@ mod _ssl { self.check_hostname.store(ch); } + // PY_PROTO_MINIMUM_SUPPORTED = -2, PY_PROTO_MAXIMUM_SUPPORTED = -1 + #[pygetset] + fn minimum_version(&self) -> i32 { + let ctx = self.ctx(); + let version = unsafe { sys::SSL_CTX_get_min_proto_version(ctx.as_ptr()) }; + if version == 0 { + -2 // PY_PROTO_MINIMUM_SUPPORTED + } else { + version + } + } + #[pygetset(setter)] + fn set_minimum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { + let ctx = self.builder(); + let result = unsafe { sys::SSL_CTX_set_min_proto_version(ctx.as_ptr(), value) }; + if result == 0 { + return Err(vm.new_value_error("invalid protocol version")); + } + Ok(()) + } + + #[pygetset] + fn maximum_version(&self) -> i32 { + let ctx = self.ctx(); + let version = unsafe { sys::SSL_CTX_get_max_proto_version(ctx.as_ptr()) }; + if version == 0 { + -1 // PY_PROTO_MAXIMUM_SUPPORTED + } else { + version + } + } + #[pygetset(setter)] + fn set_maximum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { + let ctx = self.builder(); + let result = unsafe { sys::SSL_CTX_set_max_proto_version(ctx.as_ptr(), value) }; + if result == 0 { + return Err(vm.new_value_error("invalid protocol version")); + } + Ok(()) + } + #[pymethod] fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> { cfg_if::cfg_if! { @@ -852,12 +911,6 @@ mod _ssl { if let (None, None, None) = (&args.cafile, &args.capath, &args.cadata) { return Err(vm.new_type_error("cafile, capath and cadata cannot be all omitted")); } - if let Some(cafile) = &args.cafile { - cafile.ensure_no_nul(vm)? - } - if let Some(capath) = &args.capath { - capath.ensure_no_nul(vm)? - } #[cold] fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef { @@ -887,11 +940,10 @@ mod _ssl { } if args.cafile.is_some() || args.capath.is_some() { - ctx.load_verify_locations( - args.cafile.as_ref().map(|s| s.as_str().as_ref()), - args.capath.as_ref().map(|s| s.as_str().as_ref()), - ) - .map_err(|e| convert_openssl_error(vm, e))?; + let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?; + let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?; + ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref()) + .map_err(|e| convert_openssl_error(vm, e))?; } Ok(()) @@ -1064,9 +1116,9 @@ mod _ssl { #[derive(FromArgs)] struct LoadVerifyLocationsArgs { #[pyarg(any, default)] - cafile: Option, + cafile: Option, #[pyarg(any, default)] - capath: Option, + capath: Option, #[pyarg(any, default)] cadata: Option>, } @@ -1794,6 +1846,11 @@ mod _ssl { fn SSL_verify_client_post_handshake(ssl: *const sys::SSL) -> libc::c_int; } + #[cfg(ossl110)] + unsafe extern "C" { + fn SSL_CTX_get_security_level(ctx: *const sys::SSL_CTX) -> libc::c_int; + } + // OpenSSL BIO helper functions // These are typically macros in OpenSSL, implemented via BIO_ctrl const BIO_CTRL_PENDING: libc::c_int = 10; @@ -2082,7 +2139,7 @@ mod _ssl { let lib = sys::ERR_GET_LIB(err_code); if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING { return vm.new_exception( - vm.class("_ssl", "SSLEOFError"), + PySslEOFError::class(&vm.ctx).to_owned(), vec![ vm.ctx.new_int(SSL_ERROR_EOF).into(), vm.ctx diff --git a/stdlib/src/ssl/cert.rs b/stdlib/src/ssl/cert.rs index 9a77dee1ea..19dd09f337 100644 --- a/stdlib/src/ssl/cert.rs +++ b/stdlib/src/ssl/cert.rs @@ -164,17 +164,15 @@ pub(crate) mod ssl_cert { // IPv4 format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]) } else if ip.len() == 16 { - // IPv6 - format like: "X:X:X:X:X:X:X:X" (not compressed) + // IPv6 - format with all zeros visible (not compressed) + let ip_addr = std::net::Ipv6Addr::from([ + ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8], + ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15], + ]); + let s = ip_addr.segments(); format!( "{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}", - (ip[0] as u16) << 8 | ip[1] as u16, - (ip[2] as u16) << 8 | ip[3] as u16, - (ip[4] as u16) << 8 | ip[5] as u16, - (ip[6] as u16) << 8 | ip[7] as u16, - (ip[8] as u16) << 8 | ip[9] as u16, - (ip[10] as u16) << 8 | ip[11] as u16, - (ip[12] as u16) << 8 | ip[13] as u16, - (ip[14] as u16) << 8 | ip[15] as u16 + s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7] ) } else { // Fallback for unexpected length