diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index 9034fbd63a..208d449065 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -53,13 +53,14 @@ pub(super) async fn maybe_upgrade( } } + let hostname = options.tls_server_name.as_deref().unwrap_or(&options.host); let tls_config = TlsConfig { accept_invalid_certs: !matches!( options.ssl_mode, MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity ), accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity), - hostname: &options.host, + hostname, root_cert_path: options.ssl_ca.as_ref(), client_cert_path: options.ssl_client_cert.as_ref(), client_key_path: options.ssl_client_key.as_ref(), diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 421bfb700e..155b165db0 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -68,6 +68,7 @@ pub struct MySqlConnectOptions { pub(crate) password: Option, pub(crate) database: Option, pub(crate) ssl_mode: MySqlSslMode, + pub(crate) tls_server_name: Option, pub(crate) ssl_ca: Option, pub(crate) ssl_client_cert: Option, pub(crate) ssl_client_key: Option, @@ -101,6 +102,7 @@ impl MySqlConnectOptions { charset: String::from("utf8mb4"), collation: None, ssl_mode: MySqlSslMode::Preferred, + tls_server_name: None, ssl_ca: None, ssl_client_cert: None, ssl_client_key: None, @@ -123,6 +125,23 @@ impl MySqlConnectOptions { self } + /// Overrides the TLS server name used for SNI and hostname verification. + /// + /// By default, the host from `MySqlConnectOptions` is used. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_mysql::MySqlConnectOptions; + /// let _options = MySqlConnectOptions::new() + /// .host("haproxy.example.com") + /// .tls_server_name("mysql.example.com"); + /// ``` + pub fn tls_server_name(mut self, server_name: &str) -> Self { + self.tls_server_name = Some(server_name.to_owned()); + self + } + /// Sets the port to connect to at the server host. /// /// The default port for MySQL is `3306`. @@ -527,3 +546,14 @@ impl MySqlConnectOptions { self.collation.as_deref() } } + +#[cfg(test)] +mod tests { + use super::MySqlConnectOptions; + + #[test] + fn tls_server_name_is_stored() { + let opts = MySqlConnectOptions::new().tls_server_name("sni.example.com"); + assert_eq!(opts.tls_server_name.as_deref(), Some("sni.example.com")); + } +}