From 76daad31c5adab5d3b85db9e043ad51868bc6b8f Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Fri, 21 Nov 2025 14:42:48 -0300 Subject: [PATCH 1/6] compression --- .github/workflows/sqlx.yml | 24 +- Cargo.lock | 68 +++++ Cargo.toml | 5 +- README.md | 2 + sqlx-mysql/Cargo.toml | 5 + sqlx-mysql/src/connection/compression.rs | 245 ++++++++++++++++++ sqlx-mysql/src/connection/establish.rs | 8 +- sqlx-mysql/src/connection/mod.rs | 1 + sqlx-mysql/src/connection/stream.rs | 51 +++- sqlx-mysql/src/connection/tls.rs | 7 +- sqlx-mysql/src/lib.rs | 2 +- sqlx-mysql/src/options/mod.rs | 124 +++++++++ sqlx-mysql/src/options/parse.rs | 51 +++- sqlx-mysql/src/protocol/compressed_packet.rs | 108 ++++++++ .../protocol/connect/handshake_response.rs | 15 +- sqlx-mysql/src/protocol/mod.rs | 4 + tests/mysql/mysql.rs | 145 ++++++++++- tests/x.py | 8 +- 18 files changed, 827 insertions(+), 46 deletions(-) create mode 100644 sqlx-mysql/src/connection/compression.rs create mode 100644 sqlx-mysql/src/protocol/compressed_packet.rs diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index b2f81b75ad..58d449a128 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -343,7 +343,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + - run: cargo build --features mysql,mysql-compression,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mysql_${{ matrix.mysql }} mysql_${{ matrix.mysql }} - run: sleep 60 @@ -354,7 +354,7 @@ jobs: - run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled SQLX_OFFLINE_DIR: .sqlx @@ -365,7 +365,7 @@ jobs: cargo test --test mysql-test-attr --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled SQLX_OFFLINE_DIR: .sqlx @@ -376,7 +376,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -390,7 +390,7 @@ jobs: cargo build --no-default-features --test mysql-macros - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: SQLX_OFFLINE: true SQLX_OFFLINE_DIR: .sqlx @@ -402,7 +402,7 @@ jobs: cargo test --no-default-features --test mysql-macros - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE: true @@ -421,7 +421,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-ompression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} @@ -444,7 +444,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + - run: cargo build --features mysql,mysql-compression,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mariadb_${{ matrix.mariadb }} mariadb_${{ matrix.mariadb }} - run: sleep 30 @@ -455,7 +455,7 @@ jobs: - run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -466,7 +466,7 @@ jobs: cargo test --test mysql-test-attr --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -491,7 +491,7 @@ jobs: cargo test --no-default-features --test mysql-macros - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE: true @@ -510,7 +510,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" diff --git a/Cargo.lock b/Cargo.lock index 78e40f0c12..cc33848e1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1020,6 +1020,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -1373,6 +1382,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "libz-sys", + "miniz_oxide", +] + [[package]] name = "float-cmp" version = "0.9.0" @@ -2160,6 +2180,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15d118bbf3771060e7311cc7bb0545b01d08a8b4a7de949198dec1fa0ca1c0f7" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -2266,6 +2297,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -3431,6 +3463,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "simdutf8" version = "0.1.5" @@ -3903,6 +3941,7 @@ dependencies = [ "digest", "dotenvy", "either", + "flate2", "futures-channel", "futures-core", "futures-io", @@ -3931,6 +3970,7 @@ dependencies = [ "tracing", "uuid", "whoami", + "zstd", ] [[package]] @@ -5290,3 +5330,31 @@ dependencies = [ "quote", "syn 2.0.104", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 00d5d656c1..6d5ec3cc4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -161,6 +161,9 @@ uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgre regexp = ["sqlx-sqlite?/regexp"] bstr = ["sqlx-core/bstr"] +# compression +mysql-compression = ["sqlx-mysql/compression"] + [workspace.dependencies] # Core Crates sqlx-core = { version = "=0.9.0-alpha.1", path = "sqlx-core" } @@ -359,7 +362,7 @@ required-features = ["sqlite"] [[test]] name = "mysql" path = "tests/mysql/mysql.rs" -required-features = ["mysql"] +required-features = ["mysql", "compression"] [[test]] name = "mysql-types" diff --git a/README.md b/README.md index f1e53cdced..2700b9aef9 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,8 @@ be removed in the future. - `mysql`: Add support for the MySQL/MariaDB database server. +- `mysql-compression`: Add compression support for MySQL/MariaDB database server. + - `mssql`: Add support for the MSSQL database server. - `sqlite`: Add support for the self-contained [SQLite](https://sqlite.org/) database engine with SQLite bundled and statically-linked. diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index ee9512b61e..d9eb8eea64 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -14,6 +14,7 @@ json = ["sqlx-core/json", "serde"] any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive"] migrate = ["sqlx-core/migrate"] +compression = ["zstd", "flate2"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] @@ -67,6 +68,10 @@ stringprep = "0.1.2" tracing = { version = "0.1.37", features = ["log"] } whoami = { version = "1.2.1", default-features = false } +# Compression +zstd = { version = "0.13.3", optional = true, default-features = false, features = ["zdict_builder"] } +flate2 = { version = "1.1.5", optional = true, default-features = false, features = ["rust_backend", "zlib"] } + dotenvy.workspace = true thiserror.workspace = true diff --git a/sqlx-mysql/src/connection/compression.rs b/sqlx-mysql/src/connection/compression.rs new file mode 100644 index 0000000000..2fdac04874 --- /dev/null +++ b/sqlx-mysql/src/connection/compression.rs @@ -0,0 +1,245 @@ +use crate::protocol::Capabilities; +#[cfg(feature = "compression")] +use crate::Compression; +use crate::CompressionConfig; +#[cfg(feature = "compression")] +use compressed_stream::CompressedStream; +use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; +use sqlx_core::net::{BufferedSocket, Socket}; +use sqlx_core::Error; + +pub(crate) struct CompressionMySqlStream> { + stream: CompressionStream, + pub(crate) socket: BufferedSocket, +} + +impl CompressionMySqlStream { + pub(crate) fn not_compressed(socket: BufferedSocket) -> Self { + let stream = CompressionStream::NotCompressed(NoCompressionStream {}); + Self { stream, socket } + } + + #[cfg(feature = "compression")] + fn compressed(socket: BufferedSocket, compression: CompressionConfig) -> Self { + let stream = CompressionStream::Compressed(CompressedStream::new(compression)); + Self { stream, socket } + } + + pub(crate) fn create( + socket: BufferedSocket, + #[cfg_attr(not(feature = "compression"), allow(unused_variables))] + capabilities: &Capabilities, + compression: Option, + ) -> Self { + match compression { + #[cfg(feature = "compression")] + Some(c) if c.is_supported(&capabilities) => { + CompressionMySqlStream::compressed(socket, c) + } + _ => CompressionMySqlStream::not_compressed(socket), + } + } + + pub(crate) fn boxed(self) -> CompressionMySqlStream> { + CompressionMySqlStream { + socket: self.socket.boxed(), + stream: self.stream, + } + } + + pub(crate) async fn read_with<'de, T, C>( + &mut self, + byte_len: usize, + context: C, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + match self.stream { + CompressionStream::NotCompressed(ref mut s) => { + s.read_with(byte_len, context, &mut self.socket).await + } + #[cfg(feature = "compression")] + CompressionStream::Compressed(ref mut s) => { + s.read_with(byte_len, context, &mut self.socket).await + } + } + } + + pub(crate) fn write_with<'en, 'stream, T>( + &mut self, + value: T, + context: (Capabilities, &'stream mut u8), + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, + { + match self.stream { + CompressionStream::NotCompressed(ref mut s) => { + s.write_with(value, context, &mut self.socket) + } + #[cfg(feature = "compression")] + CompressionStream::Compressed(ref mut s) => { + s.write_with(value, context, &mut self.socket) + } + } + } +} + +enum CompressionStream { + NotCompressed(NoCompressionStream), + #[cfg(feature = "compression")] + Compressed(CompressedStream), +} + +struct NoCompressionStream {} +impl NoCompressionStream { + async fn read_with<'de, T, C, S: Socket>( + &mut self, + byte_len: usize, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + buffered_socket.read_with(byte_len, context).await + } + + fn write_with<'en, 'stream, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + buffered_socket.write_with(packet, context) + } +} + +#[cfg(feature = "compression")] +mod compressed_stream { + use crate::protocol::{CompressedPacket, CompressedPacketContext}; + use crate::CompressionConfig; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; + use sqlx_core::net::{BufferedSocket, Socket}; + use sqlx_core::Error; + use std::cmp::min; + + pub(crate) struct CompressedStream { + compression: CompressionConfig, + sequence_id: u8, + last_read_packet: Option, + } + + impl CompressedStream { + pub(crate) fn new(compression: CompressionConfig) -> Self { + Self { + sequence_id: 0, + last_read_packet: None, + compression, + } + } + + async fn receive_packet( + &mut self, + buffered_socket: &mut BufferedSocket, + ) -> Result { + let mut header: Bytes = buffered_socket.read(7).await?; + #[allow(clippy::cast_possible_truncation)] + let compressed_payload_length = header.get_uint_le(3) as usize; + let sequence_id = header.get_u8(); + let uncompressed_payload_length = header.get_uint_le(3); + + self.sequence_id = sequence_id.wrapping_add(1); + + let packet = if uncompressed_payload_length > 0 { + let compressed_context = CompressedPacketContext { + nested_context: (), + sequence_id: &mut self.sequence_id, + compression: self.compression, + }; + let compressed_payload: CompressedPacket = buffered_socket + .read_with(compressed_payload_length, compressed_context) + .await?; + + compressed_payload.0 + } else { + let uncompressed_payload: Bytes = buffered_socket + .read_with(compressed_payload_length, ()) + .await?; + + uncompressed_payload + }; + + Ok(packet) + } + + pub(crate) async fn read_with<'de, T, C, S: Socket>( + &mut self, + byte_len: usize, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + let mut result_buffer = BytesMut::with_capacity(byte_len); + while result_buffer.len() != byte_len { + let current_packet = match self.last_read_packet.as_mut() { + None => { + let received_packet = self.receive_packet(buffered_socket).await?; + self.last_read_packet = Some(received_packet); + self.last_read_packet.as_mut().unwrap() + } + Some(p) => p, + }; + + let remaining_bytes_count = byte_len.saturating_sub(result_buffer.len()); + let available_bytes_count = min(current_packet.len(), remaining_bytes_count); + let chunk = current_packet.split_to(available_bytes_count); + result_buffer.put_slice(chunk.chunk()); + + if current_packet.is_empty() { + self.last_read_packet = None + } + } + + T::decode_with(result_buffer.freeze(), context) + } + + pub(crate) fn write_with<'en, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + self.sequence_id = 0; + let compressed_packet = CompressedPacket(packet); + buffered_socket.write_with( + compressed_packet, + CompressedPacketContext { + nested_context: context, + sequence_id: &mut self.sequence_id, + compression: self.compression, + }, + ) + } + } +} + +#[cfg(feature = "compression")] +impl CompressionConfig { + fn is_supported(&self, capabilities: &Capabilities) -> bool { + match self.0 { + Compression::Zlib => capabilities.contains(Capabilities::COMPRESS), + Compression::Zstd => capabilities.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM), + } + } +} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..1ca62c4571 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -1,6 +1,3 @@ -use bytes::buf::Buf; -use bytes::Bytes; - use crate::common::StatementCache; use crate::connection::{tls, MySqlConnectionInner, MySqlStream, MAX_PACKET_SIZE}; use crate::error::Error; @@ -10,6 +7,8 @@ use crate::protocol::connect::{ }; use crate::protocol::Capabilities; use crate::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; +use bytes::buf::Buf; +use bytes::Bytes; impl MySqlConnection { pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result { @@ -112,6 +111,7 @@ impl<'a> DoHandshake<'a> { database: options.database.as_deref(), auth_plugin: plugin, auth_response: auth_response.as_deref(), + compression: options.compression, })?; stream.flush().await?; @@ -121,7 +121,7 @@ impl<'a> DoHandshake<'a> { match packet[0] { 0x00 => { let _ok = packet.ok()?; - + stream = stream.maybe_enable_compression(options); break; } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index 569ad32722..8d4a69db34 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -16,6 +16,7 @@ use crate::transaction::Transaction; use crate::{MySql, MySqlConnectOptions}; mod auth; +mod compression; mod establish; mod executor; mod stream; diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index ff931b2f46..7f72a85cd7 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -1,19 +1,21 @@ use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; -use bytes::{Buf, Bytes, BytesMut}; - +use crate::connection::compression::CompressionMySqlStream; use crate::error::Error; use crate::io::MySqlBufExt; use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; +#[cfg(feature = "compression")] +use crate::options::Compression; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; use crate::{MySqlConnectOptions, MySqlDatabaseError}; +use bytes::{Buf, Bytes, BytesMut}; pub struct MySqlStream> { // Wrapping the socket in `Box` allows us to unsize in-place. - pub(crate) socket: BufferedSocket, + pub(crate) compression_stream: CompressionMySqlStream, pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, @@ -49,19 +51,27 @@ impl MySqlStream { capabilities |= Capabilities::CONNECT_WITH_DB; } + #[cfg(feature = "compression")] + if let Some(compression) = options.compression { + match compression.0 { + Compression::Zlib => capabilities |= Capabilities::COMPRESS, + Compression::Zstd => capabilities |= Capabilities::ZSTD_COMPRESSION_ALGORITHM, + } + } + Self { waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, - socket: BufferedSocket::new(socket), + compression_stream: CompressionMySqlStream::not_compressed(BufferedSocket::new(socket)), is_tls: false, } } pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.socket.write_buffer().is_empty() { - self.socket.flush().await?; + if !self.write_buffer().is_empty() { + self.flush().await?; } while !self.waiting.is_empty() { @@ -112,7 +122,7 @@ impl MySqlStream { where T: ProtocolEncode<'en, Capabilities>, { - self.socket + self.compression_stream .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) } @@ -120,7 +130,7 @@ impl MySqlStream { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header: Bytes = self.socket.read(4).await?; + let mut header: Bytes = self.compression_stream.read_with(4, ()).await?; // cannot overflow #[allow(clippy::cast_possible_truncation)] @@ -129,9 +139,7 @@ impl MySqlStream { self.sequence_id = sequence_id.wrapping_add(1); - let payload: Bytes = self.socket.read(packet_size).await?; - - // TODO: packet compression + let payload: Bytes = self.compression_stream.read_with(packet_size, ()).await?; Ok(payload) } @@ -207,7 +215,22 @@ impl MySqlStream { pub fn boxed_socket(self) -> MySqlStream { MySqlStream { - socket: self.socket.boxed(), + compression_stream: self.compression_stream.boxed(), + server_version: self.server_version, + capabilities: self.capabilities, + sequence_id: self.sequence_id, + waiting: self.waiting, + is_tls: self.is_tls, + } + } + + pub fn maybe_enable_compression(self, options: &MySqlConnectOptions) -> Self { + MySqlStream { + compression_stream: CompressionMySqlStream::create( + self.compression_stream.socket, + &self.capabilities, + options.compression, + ), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, @@ -221,12 +244,12 @@ impl Deref for MySqlStream { type Target = BufferedSocket; fn deref(&self) -> &Self::Target { - &self.socket + &self.compression_stream.socket } } impl DerefMut for MySqlStream { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.socket + &mut self.compression_stream.socket } } diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index 9034fbd63a..b363b19c32 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -1,3 +1,4 @@ +use crate::connection::compression::CompressionMySqlStream; use crate::connection::{MySqlStream, Waiting}; use crate::error::Error; use crate::net::tls::TlsConfig; @@ -74,7 +75,7 @@ pub(super) async fn maybe_upgrade( stream.flush().await?; tls::handshake( - stream.socket.into_inner(), + stream.compression_stream.socket.into_inner(), tls_config, MapStream { server_version: stream.server_version, @@ -91,7 +92,9 @@ impl WithSocket for MapStream { async fn with_socket(self, socket: S) -> Self::Output { MySqlStream { - socket: BufferedSocket::new(Box::new(socket)), + compression_stream: CompressionMySqlStream::not_compressed(BufferedSocket::new( + Box::new(socket), + )), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index 7aa14256f3..da4b7ae715 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -42,7 +42,7 @@ pub use column::MySqlColumn; pub use connection::MySqlConnection; pub use database::MySql; pub use error::MySqlDatabaseError; -pub use options::{MySqlConnectOptions, MySqlSslMode}; +pub use options::{Compression, CompressionConfig, MySqlConnectOptions, MySqlSslMode}; pub use query_result::MySqlQueryResult; pub use row::MySqlRow; pub use statement::MySqlStatement; diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 421bfb700e..6f1cd61ef1 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "compression")] +use sqlx_core::Error; use std::path::{Path, PathBuf}; mod connect; @@ -80,6 +82,93 @@ pub struct MySqlConnectOptions { pub(crate) no_engine_substitution: bool, pub(crate) timezone: Option, pub(crate) set_names: bool, + pub(crate) compression: Option, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct CompressionConfig( + pub(crate) Compression, + #[cfg_attr(not(feature = "compression"), allow(dead_code))] pub(crate) u8, +); + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum Compression { + #[cfg(feature = "compression")] + Zlib, + #[cfg(feature = "compression")] + Zstd, +} + +#[cfg(feature = "compression")] +impl Compression { + /// Selects a default compression level optimized for both encoding speed and output size. + pub fn default(self) -> CompressionConfig { + match self { + Compression::Zlib => CompressionConfig(self, 5), + Compression::Zstd => CompressionConfig(self, 11), + } + } + + /// Optimize for the best speed of encoding. + pub fn fast(self) -> CompressionConfig { + CompressionConfig(self, 1) + } + + /// Optimize for the size of data being encoded. + pub fn best(self) -> CompressionConfig { + match self { + Compression::Zlib => CompressionConfig(self, 9), + Compression::Zstd => CompressionConfig(self, 22), + } + } + + /// Sets the compression level for the current algorithm. + /// + /// Each compression method supports its own valid range of levels: + /// + /// - **Zstd:** `1` to `22` + /// - **Zlib:** `1` to `9` + /// + /// If the provided level is valid for the selected algorithm, a new + /// [`CompressionConfig`] is returned. + /// If the level is out of range, an [`Error::Configuration`] is returned. + /// + /// # Returns + /// + /// - `Ok(CompressionConfig)` if the level is valid + /// - `Err(Error)` if the level is invalid + /// + /// # Examples + /// + /// ```rust + /// # use sqlx_mysql::Compression; + /// + /// let ok = Compression::Zstd.level(5); + /// assert!(ok.is_ok()); + /// + /// let bad = Compression::Zlib.level(42); + /// assert!(bad.is_err()); + /// ``` + pub fn level(self, value: u8) -> Result { + let range = match self { + Compression::Zstd => 1..=22, + Compression::Zlib => 1..=9, + }; + + range + .contains(&value) + .then_some(CompressionConfig(self, value)) + .ok_or_else(|| { + Error::Configuration( + format!( + "Illegal compression level for {self:?}: expected {}..={}, got {value}", + range.start(), + range.end() + ) + .into(), + ) + }) + } } impl Default for MySqlConnectOptions { @@ -111,6 +200,7 @@ impl MySqlConnectOptions { no_engine_substitution: true, timezone: Some(String::from("+00:00")), set_names: true, + compression: None, } } @@ -414,6 +504,24 @@ impl MySqlConnectOptions { self.set_names = flag_val; self } + + /// Sets the compression mode for the connection. + /// + /// Data is uncompressed by default. + /// Ensure that the server supports the selected compression algorithm; + /// if it does not, the client will fall back to uncompressed mode. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_mysql::{MySqlConnectOptions, Compression}; + /// let options = MySqlConnectOptions::new() + /// .compression(Compression::Zlib.fast()); + /// ``` + pub fn compression(mut self, compression: CompressionConfig) -> Self { + self.compression = Some(compression); + self + } } impl MySqlConnectOptions { @@ -526,4 +634,20 @@ impl MySqlConnectOptions { pub fn get_collation(&self) -> Option<&str> { self.collation.as_deref() } + + /// Get compression + /// + /// # Example + /// + /// ```rust + /// #![cfg(feature = "compression")] + /// # use sqlx_mysql::{Compression, CompressionConfig, MySqlConnectOptions}; + /// let options = MySqlConnectOptions::new() + /// .compression(Compression::Zlib.fast()); + /// + /// assert!(options.get_compression().is_some()); + /// ``` + pub fn get_compression(&self) -> Option { + self.compression + } } diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index e31ddc46d4..68ccabfd19 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,11 +1,11 @@ -use std::str::FromStr; - +use super::MySqlConnectOptions; +use crate::error::Error; +#[cfg(feature = "compression")] +use crate::Compression; +use crate::MySqlSslMode; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; - -use crate::{error::Error, MySqlSslMode}; - -use super::MySqlConnectOptions; +use std::str::FromStr; impl MySqlConnectOptions { pub(crate) fn parse_from_url(url: &Url) -> Result { @@ -80,6 +80,29 @@ impl MySqlConnectOptions { options = options.timezone(Some(value.to_string())); } + #[cfg(feature = "compression")] + "compression" => { + let (algorithm, level) = value.split_once(":").ok_or_else(|| { + Error::Configuration( + format!( + "Invalid compression parameter. Expected algorithm:level, but got '{}'", + value + ) + .into(), + ) + })?; + let compression = match algorithm { + "zlib" => Ok(Compression::Zlib), + "zstd" => Ok(Compression::Zstd), + _ => Err(Error::Configuration( + format!("Unknown compression algorithm: {}", algorithm).into(), + )), + }?; + let compression_config = + compression.level(level.parse().map_err(Error::config)?)?; + options = options.compression(compression_config); + } + _ => {} } } @@ -197,3 +220,19 @@ fn it_parses_timezone() { .unwrap(); assert_eq!(opts.timezone.as_deref(), Some("+08:00")); } + +#[test] +#[cfg(feature = "compression")] +fn it_parses_compression() { + let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zstd:10" + .parse() + .unwrap(); + + assert_eq!(opts.compression, Compression::Zstd.level(10).ok()); + + let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zlib:2" + .parse() + .unwrap(); + + assert_eq!(opts.compression, Compression::Zlib.level(2).ok()); +} diff --git a/sqlx-mysql/src/protocol/compressed_packet.rs b/sqlx-mysql/src/protocol/compressed_packet.rs new file mode 100644 index 0000000000..0dbc3d36cd --- /dev/null +++ b/sqlx-mysql/src/protocol/compressed_packet.rs @@ -0,0 +1,108 @@ +use crate::error::Error; +use crate::io::ProtocolEncode; +use crate::options::Compression; +use crate::CompressionConfig; +use bytes::{BufMut, Bytes}; +use flate2::read::ZlibDecoder; +use flate2::{write::ZlibEncoder, Compression as ZlibCompression}; +use sqlx_core::io::ProtocolDecode; +use std::io::{Cursor, Read, Write}; + +#[derive(Debug)] +pub(crate) struct CompressedPacket(pub(crate) T); + +pub(crate) struct CompressedPacketContext<'cs, C> { + pub(crate) nested_context: C, + pub(crate) sequence_id: &'cs mut u8, + pub(crate) compression: CompressionConfig, +} + +impl<'en, 'compressed_stream, T, C> + ProtocolEncode<'en, CompressedPacketContext<'compressed_stream, C>> for CompressedPacket +where + T: ProtocolEncode<'en, C>, +{ + fn encode_with( + &self, + buf: &mut Vec, + context: CompressedPacketContext<'compressed_stream, C>, + ) -> Result<(), Error> { + let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); + self.0 + .encode_with(&mut uncompressed_payload, context.nested_context)?; + + let mut chunks = uncompressed_payload.chunks(0xFF_FF_FF); + for chunk in chunks.by_ref() { + add_packet(buf, *context.sequence_id, &context.compression, chunk)?; + *context.sequence_id = context.sequence_id.wrapping_add(1); + } + + Ok(()) + } +} + +fn add_packet( + buf: &mut Vec, + sequence_id: u8, + compression: &CompressionConfig, + uncompressed_chunk: &[u8], +) -> Result<(), Error> { + let offset = buf.len(); + buf.extend_from_slice(&[0; 7]); + + let compressed_payload_length = compress(compression, uncompressed_chunk, buf)?; + + let mut header = Vec::with_capacity(7); + header.put_uint_le(compressed_payload_length as u64, 3); + header.put_u8(sequence_id); + header.put_uint_le(uncompressed_chunk.len() as u64, 3); + buf[offset..offset + 7].copy_from_slice(&header); + + Ok(()) +} + +impl<'compressed_stream, C> ProtocolDecode<'_, CompressedPacketContext<'compressed_stream, C>> + for CompressedPacket +{ + fn decode_with( + buf: Bytes, + context: CompressedPacketContext<'compressed_stream, C>, + ) -> Result { + decompress(&context.compression, buf.as_ref()).map(|d| CompressedPacket(Bytes::from(d))) + } +} + +fn compress( + compression: &CompressionConfig, + input: &[u8], + output: &mut Vec, +) -> Result { + let offset = output.len(); + let mut cursor = Cursor::new(output); + cursor.set_position(offset as u64); + + let cursor = match compression { + CompressionConfig(Compression::Zlib, level) => { + let mut encoder = ZlibEncoder::new(cursor, ZlibCompression::new(*level as u32)); + let _ = encoder.write(input)?; + encoder.finish()? + } + CompressionConfig(Compression::Zstd, level) => { + zstd::stream::copy_encode(input, &mut cursor, *level as i32)?; + cursor + } + }; + + Ok(cursor.get_ref().len().saturating_sub(offset)) +} + +fn decompress(compression: &CompressionConfig, bytes: &[u8]) -> Result, Error> { + match compression.0 { + Compression::Zlib => { + let mut out = Vec::with_capacity(bytes.len() * 2); + ZlibDecoder::new(bytes).read_to_end(&mut out)?; + Ok(out) + } + Compression::Zstd => Ok(zstd::stream::decode_all(bytes)?), + } +} diff --git a/sqlx-mysql/src/protocol/connect/handshake_response.rs b/sqlx-mysql/src/protocol/connect/handshake_response.rs index 6911419d98..c5d1bcc3d9 100644 --- a/sqlx-mysql/src/protocol/connect/handshake_response.rs +++ b/sqlx-mysql/src/protocol/connect/handshake_response.rs @@ -1,9 +1,11 @@ use crate::io::MySqlBufMutExt; use crate::io::{BufMutExt, ProtocolEncode}; +#[cfg(feature = "compression")] +use crate::options::Compression; use crate::protocol::auth::AuthPlugin; use crate::protocol::connect::ssl_request::SslRequest; use crate::protocol::Capabilities; - +use crate::CompressionConfig; // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse // https://mariadb.com/kb/en/connection/#client-handshake-response @@ -25,6 +27,10 @@ pub struct HandshakeResponse<'a> { /// Opaque authentication response pub auth_response: Option<&'a [u8]>, + + /// compression algorithm + #[cfg_attr(not(feature = "compression"), allow(dead_code))] + pub compression: Option, } impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { @@ -77,6 +83,13 @@ impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { } } + #[cfg(feature = "compression")] + if context.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM) { + if let Some(CompressionConfig(Compression::Zstd, level)) = self.compression { + buf.push(level) + } + } + Ok(()) } } diff --git a/sqlx-mysql/src/protocol/mod.rs b/sqlx-mysql/src/protocol/mod.rs index d1860f5c65..325ce456f4 100644 --- a/sqlx-mysql/src/protocol/mod.rs +++ b/sqlx-mysql/src/protocol/mod.rs @@ -1,5 +1,7 @@ pub(crate) mod auth; mod capabilities; +#[cfg(feature = "compression")] +mod compressed_packet; pub(crate) mod connect; mod packet; pub(crate) mod response; @@ -8,5 +10,7 @@ pub(crate) mod statement; pub(crate) mod text; pub(crate) use capabilities::Capabilities; +#[cfg(feature = "compression")] +pub(crate) use compressed_packet::{CompressedPacket, CompressedPacketContext}; pub(crate) use packet::Packet; pub(crate) use row::Row; diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 5d6a5ef233..cc5d2b5eab 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -3,7 +3,7 @@ use futures_util::TryStreamExt; use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow}; use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; use sqlx_core::connection::ConnectOptions; -use sqlx_mysql::MySqlConnectOptions; +use sqlx_mysql::{Compression, MySqlConnectOptions}; use sqlx_test::{new, setup_if_needed}; use std::env; use url::Url; @@ -39,6 +39,64 @@ async fn it_connects_without_password() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_connects_with_zlib_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zlib.default()) + .connect() + .await?; + + let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) + .fetch_all(&mut conn) + .await?; + + let result = rows + .first() + .map(|r| r.try_get::(1).unwrap_or_default()) + .unwrap_or_default(); + + assert!(!rows.is_empty()); + assert_eq!(result, "ON"); + + Ok(()) +} + +#[sqlx_macros::test] +#[cfg(all( + not(any( + mariadb = "verylatest", + mariadb = "10_6", + mariadb = "10_11", + mariadb = "11_4", + mariadb = "11_8", + )), + feature = "mysql" +))] +async fn it_connects_with_zstd_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zstd.default()) + .connect() + .await?; + + let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) + .fetch_all(&mut conn) + .await?; + + let result = rows + .first() + .map(|r| r.try_get::(1).unwrap_or_default()) + .unwrap_or_default(); + + assert!(!rows.is_empty()); + assert_eq!(result, "ON"); + + Ok(()) +} + #[sqlx_macros::test] async fn it_maths() -> anyhow::Result<()> { let mut conn = new::().await?; @@ -560,6 +618,91 @@ CREATE TEMPORARY TABLE large_table (data LONGBLOB); Ok(()) } +#[sqlx_macros::test] +#[cfg(all( + not(any( + mariadb = "verylatest", + mariadb = "10_6", + mariadb = "10_11", + mariadb = "11_4", + mariadb = "11_8", + )), + feature = "mysql" +))] +async fn it_can_handle_split_packets_with_zstd_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + + let options = MySqlConnectOptions::from_url(&url)?.compression(Compression::Zstd.best()); + + // This will only take effect on new connections + options + .connect() + .await? + .execute("SET GLOBAL max_allowed_packet = 4294967297") + .await?; + + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zstd.best()) + .connect() + .await?; + conn.execute(r#" CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) + .await?; + + let data = vec![0x41; 0xFF_FF_FF * 2]; + + sqlx::query("INSERT INTO large_table (data) VALUES (?)") + .bind(&data) + .execute(&mut conn) + .await?; + + let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_handle_split_packets_with_zlib_compression() -> anyhow::Result<()> { + let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) + .context("error parsing DATABASE_URL")?; + + let options = MySqlConnectOptions::from_url(&url)?.compression(Compression::Zlib.best()); + + // This will only take effect on new connections + options + .connect() + .await? + .execute("SET GLOBAL max_allowed_packet = 4294967297") + .await?; + + let mut conn = MySqlConnectOptions::from_url(&url)? + .compression(Compression::Zstd.best()) + .connect() + .await?; + + conn.execute(r#"CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) + .await?; + + let data = vec![0x41; 0xFF_FF_FF * 2]; + + sqlx::query("INSERT INTO large_table (data) VALUES (?)") + .bind(&data) + .execute(&mut conn) + .await?; + + let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") + .fetch_one(&mut conn) + .await?; + + assert_eq!(ret, data); + + Ok(()) +} + #[sqlx_macros::test] async fn test_shrink_buffers() -> anyhow::Result<()> { // We don't really have a good way to test that `.shrink_buffers()` functions as expected diff --git a/tests/x.py b/tests/x.py index e1308f2fa4..7b01ce0f54 100755 --- a/tests/x.py +++ b/tests/x.py @@ -211,7 +211,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data # https://github.com/docker-library/mysql/issues/567 if not(version == "5_7" and tls == "rustls"): run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}", service=f"mysql_{version}", tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", @@ -220,7 +220,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data ## +client-ssl if tls != "none" and not(version == "5_7" and tls == "rustls"): run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}_client_ssl no-password", database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt", service=f"mysql_{version}_client_ssl", @@ -233,7 +233,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data for version in ["verylatest", "10_11", "10_6", "10_5", "10_4"]: run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mariadb {version}", service=f"mariadb_{version}", tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}", @@ -242,7 +242,7 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data ## +client-ssl if tls != "none": run( - f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mariadb {version}_client_ssl no-password", database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt", service=f"mariadb_{version}_client_ssl", From 12a719f8c081f874a598a053bc99e6cdadf512e1 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Fri, 21 Nov 2025 16:06:09 -0300 Subject: [PATCH 2/6] fix typo --- .github/workflows/sqlx.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 58d449a128..69a93897a0 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -421,7 +421,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,mysql-ompression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} From 5579f5ffcc762419d16d82c1b8ca7da8173020ba Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Sat, 22 Nov 2025 08:20:09 -0300 Subject: [PATCH 3/6] add test --- sqlx-mysql/src/options/parse.rs | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index 68ccabfd19..56cf7aa6be 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,8 +1,8 @@ use super::MySqlConnectOptions; use crate::error::Error; -#[cfg(feature = "compression")] -use crate::Compression; use crate::MySqlSslMode; +#[cfg(feature = "compression")] +use crate::{Compression, CompressionConfig}; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; use std::str::FromStr; @@ -166,6 +166,15 @@ impl MySqlConnectOptions { .append_pair("socket", &socket.to_string_lossy()); } + #[cfg(feature = "compression")] + if let Some(compression_config) = &self.compression { + let value = match compression_config { + CompressionConfig(Compression::Zstd, level) => format!("zstd:{}", level), + CompressionConfig(Compression::Zlib, level) => format!("zlib:{}", level), + }; + url.query_pairs_mut().append_pair("compression", &value); + } + url } } @@ -208,6 +217,25 @@ fn it_returns_the_parsed_url() { assert_eq!(expected_url, opts.build_url()); } +#[test] +#[cfg(feature = "compression")] +fn it_returns_the_build_url_with_compression_param() { + let url = "mysql://username:p@ssw0rd@hostname:3306/database"; + let opts = MySqlConnectOptions::from_str(url) + .unwrap() + .compression(Compression::Zstd.fast()); + + let mut expected_url = Url::parse(url).unwrap(); + let mut query_string = String::new(); + // MySqlConnectOptions defaults + query_string += "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; + query_string += "&compression=zstd%3A1"; + + expected_url.set_query(Some(&query_string)); + + assert_eq!(expected_url, opts.build_url()); +} + #[test] fn it_parses_timezone() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?timezone=%2B08:00" From 6fb8b2c154b1ad136c45e4580a499917960524e9 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Sat, 22 Nov 2025 08:26:05 -0300 Subject: [PATCH 4/6] remove code duplication from tests --- tests/mysql/mysql.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index cc5d2b5eab..0c05195ae5 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -642,10 +642,8 @@ async fn it_can_handle_split_packets_with_zstd_compression() -> anyhow::Result<( .execute("SET GLOBAL max_allowed_packet = 4294967297") .await?; - let mut conn = MySqlConnectOptions::from_url(&url)? - .compression(Compression::Zstd.best()) - .connect() - .await?; + let mut conn = options.await?; + conn.execute(r#" CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) .await?; @@ -679,10 +677,7 @@ async fn it_can_handle_split_packets_with_zlib_compression() -> anyhow::Result<( .execute("SET GLOBAL max_allowed_packet = 4294967297") .await?; - let mut conn = MySqlConnectOptions::from_url(&url)? - .compression(Compression::Zstd.best()) - .connect() - .await?; + let mut conn = options.await?; conn.execute(r#"CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) .await?; From da30e8dfced90f57a37b75270198395d421a02c2 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Sat, 22 Nov 2025 10:48:52 -0300 Subject: [PATCH 5/6] simplify logic of CompressionMySqlStream --- sqlx-mysql/src/connection/compression.rs | 39 +++--------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/sqlx-mysql/src/connection/compression.rs b/sqlx-mysql/src/connection/compression.rs index 2fdac04874..8e7c03f0be 100644 --- a/sqlx-mysql/src/connection/compression.rs +++ b/sqlx-mysql/src/connection/compression.rs @@ -15,7 +15,7 @@ pub(crate) struct CompressionMySqlStream> { impl CompressionMySqlStream { pub(crate) fn not_compressed(socket: BufferedSocket) -> Self { - let stream = CompressionStream::NotCompressed(NoCompressionStream {}); + let stream = CompressionStream::NotCompressed; Self { stream, socket } } @@ -56,9 +56,7 @@ impl CompressionMySqlStream { T: ProtocolDecode<'de, C>, { match self.stream { - CompressionStream::NotCompressed(ref mut s) => { - s.read_with(byte_len, context, &mut self.socket).await - } + CompressionStream::NotCompressed => self.socket.read_with(byte_len, context).await, #[cfg(feature = "compression")] CompressionStream::Compressed(ref mut s) => { s.read_with(byte_len, context, &mut self.socket).await @@ -75,9 +73,7 @@ impl CompressionMySqlStream { T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, { match self.stream { - CompressionStream::NotCompressed(ref mut s) => { - s.write_with(value, context, &mut self.socket) - } + CompressionStream::NotCompressed => self.socket.write_with(value, context), #[cfg(feature = "compression")] CompressionStream::Compressed(ref mut s) => { s.write_with(value, context, &mut self.socket) @@ -87,38 +83,11 @@ impl CompressionMySqlStream { } enum CompressionStream { - NotCompressed(NoCompressionStream), + NotCompressed, #[cfg(feature = "compression")] Compressed(CompressedStream), } -struct NoCompressionStream {} -impl NoCompressionStream { - async fn read_with<'de, T, C, S: Socket>( - &mut self, - byte_len: usize, - context: C, - buffered_socket: &mut BufferedSocket, - ) -> Result - where - T: ProtocolDecode<'de, C>, - { - buffered_socket.read_with(byte_len, context).await - } - - fn write_with<'en, 'stream, T, C, S: Socket>( - &mut self, - packet: T, - context: C, - buffered_socket: &mut BufferedSocket, - ) -> Result<(), Error> - where - T: ProtocolEncode<'en, C>, - { - buffered_socket.write_with(packet, context) - } -} - #[cfg(feature = "compression")] mod compressed_stream { use crate::protocol::{CompressedPacket, CompressedPacketContext}; From f5ca325264ca795db9589dc9f8a54132b5a5a7a4 Mon Sep 17 00:00:00 2001 From: Artem Konovalov Date: Tue, 23 Dec 2025 20:35:17 -0300 Subject: [PATCH 6/6] update --- .github/workflows/sqlx.yml | 86 ++- Cargo.toml | 11 +- README.md | 4 +- sqlx-mysql/Cargo.toml | 3 +- sqlx-mysql/src/connection/auth.rs | 4 +- sqlx-mysql/src/connection/compression.rs | 574 +++++++++++++++--- sqlx-mysql/src/connection/establish.rs | 22 +- sqlx-mysql/src/connection/stream.rs | 31 +- sqlx-mysql/src/connection/tls.rs | 10 +- sqlx-mysql/src/options/mod.rs | 88 ++- sqlx-mysql/src/options/parse.rs | 127 +++- sqlx-mysql/src/protocol/compressed_packet.rs | 108 ---- .../protocol/connect/handshake_response.rs | 19 +- sqlx-mysql/src/protocol/mod.rs | 4 - sqlx-mysql/src/transaction.rs | 2 +- tests/mysql/compression.rs | 25 + tests/mysql/mysql.rs | 140 +---- tests/mysql/rustsec.rs | 27 +- tests/x.py | 62 +- 19 files changed, 869 insertions(+), 478 deletions(-) delete mode 100644 sqlx-mysql/src/protocol/compressed_packet.rs create mode 100644 tests/mysql/compression.rs diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 69a93897a0..55700a46d0 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -343,7 +343,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - - run: cargo build --features mysql,mysql-compression,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mysql_${{ matrix.mysql }} mysql_${{ matrix.mysql }} - run: sleep 60 @@ -354,18 +354,38 @@ jobs: - run: > cargo test --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled SQLX_OFFLINE_DIR: .sqlx RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + # Run tests to validate zstd compression for traffic + - run: > + cargo test + --no-default-features + --features any,mysql,mysql-zstd-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled&compression=zstd:1 + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + + # Run tests to validate zlib compression for traffic + - run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled&compression=zlib:1 + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + # Run the `test-attr` test again to cover cleanup. - run: > cargo test --test mysql-test-attr --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled SQLX_OFFLINE_DIR: .sqlx @@ -376,7 +396,7 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -390,7 +410,7 @@ jobs: cargo build --no-default-features --test mysql-macros - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: SQLX_OFFLINE: true SQLX_OFFLINE_DIR: .sqlx @@ -402,7 +422,7 @@ jobs: cargo test --no-default-features --test mysql-macros - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE: true @@ -421,11 +441,32 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + # Run tests to validate zstd compression for traffic with tls + - if: ${{ matrix.tls != 'none' }} + run: > + cargo test + --no-default-features + --features any,mysql,mysql-zstd-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zstd:1 + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + + # Run tests to validate zlib compression for traffic with tls + - if: ${{ matrix.tls != 'none' }} + run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1 + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + + mariadb: name: MariaDB runs-on: ubuntu-24.04 @@ -444,7 +485,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - - run: cargo build --features mysql,mysql-compression,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker compose -f tests/docker-compose.yml run -d -p 3306:3306 --name mariadb_${{ matrix.mariadb }} mariadb_${{ matrix.mariadb }} - run: sleep 30 @@ -455,18 +496,28 @@ jobs: - run: > cargo test --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + # Run tests to validate zlib compression for traffic + - run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?compression=zlib:1 + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + # Run the `test-attr` test again to cover cleanup. - run: > cargo test --test mysql-test-attr --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE_DIR: .sqlx @@ -491,7 +542,7 @@ jobs: cargo test --no-default-features --test mysql-macros - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx SQLX_OFFLINE: true @@ -510,7 +561,18 @@ jobs: run: > cargo test --no-default-features - --features any,mysql,macros,mysql-compression,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + --features any,mysql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + + + # Run tests to validate zlib compression for traffic with tls + - if: ${{ matrix.tls != 'none' }} + run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1 + RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" diff --git a/Cargo.toml b/Cargo.toml index 6d5ec3cc4c..92bd9cee77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -162,7 +162,8 @@ regexp = ["sqlx-sqlite?/regexp"] bstr = ["sqlx-core/bstr"] # compression -mysql-compression = ["sqlx-mysql/compression"] +mysql-zstd-compression = ["sqlx-mysql/zstd-compression"] +mysql-zlib-compression = ["sqlx-mysql/zlib-compression"] [workspace.dependencies] # Core Crates @@ -362,7 +363,7 @@ required-features = ["sqlite"] [[test]] name = "mysql" path = "tests/mysql/mysql.rs" -required-features = ["mysql", "compression"] +required-features = ["mysql"] [[test]] name = "mysql-types" @@ -404,6 +405,12 @@ name = "mysql-rustsec" path = "tests/mysql/rustsec.rs" required-features = ["mysql"] +[[test]] +name = "mysql-compression" +path = "tests/mysql/compression.rs" +required-features = ["mysql"] + + # # PostgreSQL # diff --git a/README.md b/README.md index 2700b9aef9..099fa0093b 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,9 @@ be removed in the future. - `mysql`: Add support for the MySQL/MariaDB database server. -- `mysql-compression`: Add compression support for MySQL/MariaDB database server. +- `mysql-zlib-compression`: Add zlib compression support for MySQL/MariaDB database server. + +- `mysql-zstd-compression`: Add std compression support for MySQL database server. - `mssql`: Add support for the MSSQL database server. diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index d9eb8eea64..57eb7db9cd 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -14,7 +14,8 @@ json = ["sqlx-core/json", "serde"] any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive"] migrate = ["sqlx-core/migrate"] -compression = ["zstd", "flate2"] +zstd-compression = ["zstd"] +zlib-compression = ["flate2"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] diff --git a/sqlx-mysql/src/connection/auth.rs b/sqlx-mysql/src/connection/auth.rs index 613f8e702f..9739bd57a1 100644 --- a/sqlx-mysql/src/connection/auth.rs +++ b/sqlx-mysql/src/connection/auth.rs @@ -53,7 +53,7 @@ impl AuthPlugin { 0x04 => { let payload = encrypt_rsa(stream, 0x02, password, nonce).await?; - stream.write_packet(&*payload)?; + stream.write_packet(&*payload).await?; stream.flush().await?; Ok(false) @@ -143,7 +143,7 @@ async fn encrypt_rsa<'s>( } // client sends a public key request - stream.write_packet(&[public_key_request_id][..])?; + stream.write_packet(&[public_key_request_id][..]).await?; stream.flush().await?; // server sends a public key response diff --git a/sqlx-mysql/src/connection/compression.rs b/sqlx-mysql/src/connection/compression.rs index 8e7c03f0be..fd547b35e7 100644 --- a/sqlx-mysql/src/connection/compression.rs +++ b/sqlx-mysql/src/connection/compression.rs @@ -1,41 +1,56 @@ use crate::protocol::Capabilities; -#[cfg(feature = "compression")] -use crate::Compression; use crate::CompressionConfig; -#[cfg(feature = "compression")] -use compressed_stream::CompressedStream; use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; use sqlx_core::net::{BufferedSocket, Socket}; use sqlx_core::Error; +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] +use {crate::Compression, compressed_stream::CompressedStream}; pub(crate) struct CompressionMySqlStream> { - stream: CompressionStream, + mode: CompressionMode, pub(crate) socket: BufferedSocket, } impl CompressionMySqlStream { pub(crate) fn not_compressed(socket: BufferedSocket) -> Self { - let stream = CompressionStream::NotCompressed; - Self { stream, socket } + let mode = CompressionMode::NotCompressed; + Self { mode, socket } } - #[cfg(feature = "compression")] + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] fn compressed(socket: BufferedSocket, compression: CompressionConfig) -> Self { - let stream = CompressionStream::Compressed(CompressedStream::new(compression)); - Self { stream, socket } + let mode = CompressionMode::Compressed(CompressedStream::new(compression)); + Self { mode, socket } } pub(crate) fn create( socket: BufferedSocket, - #[cfg_attr(not(feature = "compression"), allow(unused_variables))] + #[cfg_attr( + not(all(feature = "zstd-compression", feature = "zlib-compression")), + allow(unused_variables) + )] capabilities: &Capabilities, - compression: Option, + compression_configs: &[CompressionConfig], ) -> Self { - match compression { - #[cfg(feature = "compression")] - Some(c) if c.is_supported(&capabilities) => { - CompressionMySqlStream::compressed(socket, c) + let supported_compression = compression_configs.iter().find(|c| { + let is_supported = match c.0 { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => capabilities.contains(Capabilities::COMPRESS), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => { + capabilities.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM) + } + #[cfg(not(any(feature = "zstd-compression", feature = "zlib-compression")))] + _ => false, + }; + if !is_supported { + tracing::warn!("server doesn't support '{:?}' compression", c.0); } + is_supported + }); + match supported_compression { + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + Some(c) => CompressionMySqlStream::compressed(socket, *c), _ => CompressionMySqlStream::not_compressed(socket), } } @@ -43,7 +58,7 @@ impl CompressionMySqlStream { pub(crate) fn boxed(self) -> CompressionMySqlStream> { CompressionMySqlStream { socket: self.socket.boxed(), - stream: self.stream, + mode: self.mode, } } @@ -55,16 +70,33 @@ impl CompressionMySqlStream { where T: ProtocolDecode<'de, C>, { - match self.stream { - CompressionStream::NotCompressed => self.socket.read_with(byte_len, context).await, - #[cfg(feature = "compression")] - CompressionStream::Compressed(ref mut s) => { + match self.mode { + CompressionMode::NotCompressed => self.socket.read_with(byte_len, context).await, + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + CompressionMode::Compressed(ref mut s) => { s.read_with(byte_len, context, &mut self.socket).await } } } - pub(crate) fn write_with<'en, 'stream, T>( + pub(crate) async fn write_with<'en, 'stream, T>( + &mut self, + value: T, + context: (Capabilities, &'stream mut u8), + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, + { + match self.mode { + CompressionMode::NotCompressed => self.socket.write_with(value, context), + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + CompressionMode::Compressed(ref mut s) => { + s.write_with(value, context, &mut self.socket).await + } + } + } + + pub(crate) fn uncompressed_write_with<'en, 'stream, T>( &mut self, value: T, context: (Capabilities, &'stream mut u8), @@ -72,81 +104,58 @@ impl CompressionMySqlStream { where T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, { - match self.stream { - CompressionStream::NotCompressed => self.socket.write_with(value, context), - #[cfg(feature = "compression")] - CompressionStream::Compressed(ref mut s) => { - s.write_with(value, context, &mut self.socket) + match self.mode { + CompressionMode::NotCompressed => self.socket.write_with(value, context), + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + CompressionMode::Compressed(ref mut s) => { + s.uncompressed_write_with(value, context, &mut self.socket) } } } } -enum CompressionStream { +enum CompressionMode { NotCompressed, - #[cfg(feature = "compression")] + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] Compressed(CompressedStream), } -#[cfg(feature = "compression")] +#[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] mod compressed_stream { - use crate::protocol::{CompressedPacket, CompressedPacketContext}; - use crate::CompressionConfig; + use crate::{Compression, CompressionConfig}; use bytes::{Buf, BufMut, Bytes, BytesMut}; + #[cfg(feature = "zlib-compression")] + use flate2::{ + write::ZlibEncoder, Compression as ZlibCompression, Decompress as ZlibDecompressor, + FlushDecompress, Status, + }; use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; use sqlx_core::net::{BufferedSocket, Socket}; + use sqlx_core::rt::yield_now; use sqlx_core::Error; use std::cmp::min; + use std::io::{Cursor, Write}; + #[cfg(feature = "zstd-compression")] + use zstd::stream::{ + raw::{Decoder as ZstdDecoder, InBuffer, Operation, OutBuffer}, + Encoder as ZstdEncoder, + }; pub(crate) struct CompressedStream { - compression: CompressionConfig, + compression_config: CompressionConfig, sequence_id: u8, - last_read_packet: Option, + packet_reader: Option, } impl CompressedStream { - pub(crate) fn new(compression: CompressionConfig) -> Self { + pub(crate) fn new(compression_config: CompressionConfig) -> Self { Self { sequence_id: 0, - last_read_packet: None, - compression, + packet_reader: None, + compression_config, } } - async fn receive_packet( - &mut self, - buffered_socket: &mut BufferedSocket, - ) -> Result { - let mut header: Bytes = buffered_socket.read(7).await?; - #[allow(clippy::cast_possible_truncation)] - let compressed_payload_length = header.get_uint_le(3) as usize; - let sequence_id = header.get_u8(); - let uncompressed_payload_length = header.get_uint_le(3); - - self.sequence_id = sequence_id.wrapping_add(1); - - let packet = if uncompressed_payload_length > 0 { - let compressed_context = CompressedPacketContext { - nested_context: (), - sequence_id: &mut self.sequence_id, - compression: self.compression, - }; - let compressed_payload: CompressedPacket = buffered_socket - .read_with(compressed_payload_length, compressed_context) - .await?; - - compressed_payload.0 - } else { - let uncompressed_payload: Bytes = buffered_socket - .read_with(compressed_payload_length, ()) - .await?; - - uncompressed_payload - }; - - Ok(packet) - } - pub(crate) async fn read_with<'de, T, C, S: Socket>( &mut self, byte_len: usize, @@ -158,29 +167,65 @@ mod compressed_stream { { let mut result_buffer = BytesMut::with_capacity(byte_len); while result_buffer.len() != byte_len { - let current_packet = match self.last_read_packet.as_mut() { + let compressed_packet_reader = match self.packet_reader.as_mut() { None => { - let received_packet = self.receive_packet(buffered_socket).await?; - self.last_read_packet = Some(received_packet); - self.last_read_packet.as_mut().unwrap() + let packet_reader = + CompressedPacketReader::new(buffered_socket, &self.compression_config) + .await?; + self.sequence_id = packet_reader.sequence_id.wrapping_add(1); + self.packet_reader = Some(packet_reader); + self.packet_reader.as_mut().unwrap() } Some(p) => p, }; - let remaining_bytes_count = byte_len.saturating_sub(result_buffer.len()); - let available_bytes_count = min(current_packet.len(), remaining_bytes_count); - let chunk = current_packet.split_to(available_bytes_count); - result_buffer.put_slice(chunk.chunk()); + let required_bytes_count = byte_len.saturating_sub(result_buffer.len()); + let chunk = compressed_packet_reader + .read(buffered_socket, required_bytes_count) + .await?; + result_buffer.put_slice(&chunk); - if current_packet.is_empty() { - self.last_read_packet = None + if !compressed_packet_reader.is_available() { + self.packet_reader = None } } T::decode_with(result_buffer.freeze(), context) } - pub(crate) fn write_with<'en, T, C, S: Socket>( + pub(crate) async fn write_with<'en, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + self.sequence_id = 0; + let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); + packet.encode_with(&mut uncompressed_payload, context)?; + + let mut uncompressed_chunks = uncompressed_payload.chunks(0xFF_FF_FF); + for uncompressed_chunk in uncompressed_chunks.by_ref() { + let mut compressed_payload = Vec::with_capacity(uncompressed_chunk.len() + 7); + Self::add_compressed_packet( + self.sequence_id, + &self.compression_config, + &mut compressed_payload, + uncompressed_chunk, + ) + .await?; + + buffered_socket.write_with(compressed_payload.as_slice(), ())?; + + self.sequence_id = self.sequence_id.wrapping_add(1); + } + + Ok(()) + } + + pub(crate) fn uncompressed_write_with<'en, T, C, S: Socket>( &mut self, packet: T, context: C, @@ -190,25 +235,360 @@ mod compressed_stream { T: ProtocolEncode<'en, C>, { self.sequence_id = 0; - let compressed_packet = CompressedPacket(packet); - buffered_socket.write_with( - compressed_packet, - CompressedPacketContext { - nested_context: context, - sequence_id: &mut self.sequence_id, - compression: self.compression, + let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); + packet.encode_with(&mut uncompressed_payload, context)?; + + let mut uncompressed_chunks = uncompressed_payload.chunks(0xFF_FF_FF); + for uncompressed_chunk in uncompressed_chunks.by_ref() { + let mut header = Vec::with_capacity(7); + header.put_uint_le(uncompressed_chunk.len() as u64, 3); + header.put_u8(self.sequence_id); + header.put_uint_le(0, 3); + + buffered_socket.write_with(header.as_slice(), ())?; + buffered_socket.write_with(uncompressed_chunk, ())?; + + self.sequence_id = self.sequence_id.wrapping_add(1); + } + + Ok(()) + } + + async fn add_compressed_packet( + sequence_id: u8, + compression: &CompressionConfig, + compressed_chunk: &mut Vec, + uncompressed_chunk: &[u8], + ) -> Result<(), Error> { + compressed_chunk.extend_from_slice(&[0; 7]); + + let compressed_payload_length = + Self::compress_chunk(compression, compressed_chunk, uncompressed_chunk).await?; + + let mut header = &mut compressed_chunk[0..7]; + header.put_uint_le(compressed_payload_length as u64, 3); + header.put_u8(sequence_id); + header.put_uint_le(uncompressed_chunk.len() as u64, 3); + + Ok(()) + } + + async fn compress_chunk( + compression: &CompressionConfig, + output: &mut Vec, + uncompressed_chunk: &[u8], + ) -> Result { + let offset = output.len(); + let mut cursor = Cursor::new(output); + cursor.set_position(offset as u64); + + let mut encoder = Encoder::new(compression, cursor)?; + + for chunk in uncompressed_chunk.chunks(encoder.get_chunk_size()) { + encoder.write_all(chunk)?; + yield_now().await; + } + let cursor = encoder.finish()?; + Ok(cursor.get_ref().len().saturating_sub(offset)) + } + } + + enum Encoder<'en> { + #[cfg(feature = "zlib-compression")] + Zlib(ZlibEncoder>>, u8), + #[cfg(feature = "zstd-compression")] + Zstd(ZstdEncoder<'en, Cursor<&'en mut Vec>>, u8), + } + + impl<'en> Encoder<'en> { + fn new( + compression_config: &CompressionConfig, + cursor: Cursor<&'en mut Vec>, + ) -> Result, Error> { + let encoder = match compression_config { + #[cfg(feature = "zlib-compression")] + CompressionConfig(Compression::Zlib, level) => Encoder::Zlib( + ZlibEncoder::new(cursor, ZlibCompression::new(*level as u32)), + *level, + ), + #[cfg(feature = "zstd-compression")] + CompressionConfig(Compression::Zstd, level) => { + Encoder::Zstd(ZstdEncoder::new(cursor, *level as i32)?, *level) + } + }; + Ok(encoder) + } + + fn write_all(&mut self, buf: &'en [u8]) -> Result<(), Error> { + match self { + #[cfg(feature = "zlib-compression")] + Encoder::Zlib(encoder, _) => encoder.write_all(buf)?, + #[cfg(feature = "zstd-compression")] + Encoder::Zstd(encoder, _) => encoder.write_all(buf)?, + } + Ok(()) + } + + fn finish(self) -> Result>, Error> { + let cursor = match self { + #[cfg(feature = "zlib-compression")] + Encoder::Zlib(encoder, _) => encoder.finish()?, + #[cfg(feature = "zstd-compression")] + Encoder::Zstd(encoder, _) => encoder.finish()?, + }; + Ok(cursor) + } + + // Chunk size is chosen based on lzbench benchmarks: + // https://github.com/inikep/lzbench?tab=readme-ov-file#benchmarks + // The target is to keep runtime under 50 ms. + fn get_chunk_size(&self) -> usize { + match self { + #[cfg(feature = "zlib-compression")] + Encoder::Zlib(_, level) => match level { + 1 => 4 * 1024, + 2..=4 => 2 * 1024, + 5..=6 => 1024, + _ => 512, + }, + #[cfg(feature = "zstd-compression")] + Encoder::Zstd(_, level) => match level { + 1..=2 => 16 * 1024, + 3..=4 => 8 * 1024, + 5..=6 => 4 * 1024, + 7..=10 => 2 * 1024, + 11..=12 => 1024, + 13..=14 => 512, + 15..=16 => 256, + 17..=20 => 128, + _ => 64, }, - ) + } } } -} -#[cfg(feature = "compression")] -impl CompressionConfig { - fn is_supported(&self, capabilities: &Capabilities) -> bool { - match self.0 { - Compression::Zlib => capabilities.contains(Capabilities::COMPRESS), - Compression::Zstd => capabilities.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM), + struct CompressedPacketReader { + sequence_id: u8, + remaining_bytes: usize, + is_compressed: bool, + + decoder: Decoder, + input_buffer: Bytes, + input_buffer_pos: usize, + output_buffer: BytesMut, + } + + impl CompressedPacketReader { + async fn new( + buffered_socket: &mut BufferedSocket, + compression_config: &CompressionConfig, + ) -> Result { + let mut header: Bytes = buffered_socket.read(7).await?; + #[allow(clippy::cast_possible_truncation)] + let compressed_payload_length = header.get_uint_le(3) as usize; + let sequence_id = header.get_u8(); + #[allow(clippy::cast_possible_truncation)] + let uncompressed_payload_length = header.get_uint_le(3) as usize; + let decoder = Decoder::new(compression_config)?; + + Ok(CompressedPacketReader { + sequence_id, + remaining_bytes: compressed_payload_length, + is_compressed: uncompressed_payload_length > 0, + decoder, + + input_buffer: Bytes::new(), + input_buffer_pos: 0, + output_buffer: BytesMut::with_capacity(uncompressed_payload_length), + }) + } + + fn is_available(&self) -> bool { + !self.output_buffer.is_empty() + || self.input_buffer_pos < self.input_buffer.len() + || self.remaining_bytes > 0 + } + + async fn read( + &mut self, + buffered_socket: &mut BufferedSocket, + bytes_count: usize, + ) -> Result { + let chunk = if self.is_compressed { + self.decompress(buffered_socket, bytes_count).await? + } else { + let available_bytes_count = min(self.remaining_bytes, bytes_count); + let result: Bytes = buffered_socket.read(available_bytes_count).await?; + self.remaining_bytes = self.remaining_bytes.saturating_sub(result.len()); + result + }; + + Ok(chunk) + } + + async fn decompress( + &mut self, + buffered_socket: &mut BufferedSocket, + output_bytes_count: usize, + ) -> Result { + if self.output_buffer.len() >= output_bytes_count { + return Ok(self.output_buffer.split_to(output_bytes_count).freeze()); + } + + while self.output_buffer.len() < output_bytes_count { + let mut is_refill_required = self.input_buffer_pos >= self.input_buffer.len(); + + if !is_refill_required { + let input = &self.input_buffer[self.input_buffer_pos..]; + let (consumed_bytes_total_count, produced_bytes_total_count) = + self.decoder.decompress(input, &mut self.output_buffer)?; + + self.input_buffer_pos += consumed_bytes_total_count; + + if produced_bytes_total_count == 0 { + is_refill_required = true; + } + } + + if is_refill_required { + if self.remaining_bytes == 0 { + break; + } + let available_bytes = min(self.remaining_bytes, self.decoder.get_chunk_size()); + + self.input_buffer = buffered_socket.read(available_bytes).await?; + self.input_buffer_pos = 0; + self.remaining_bytes = + self.remaining_bytes.saturating_sub(self.input_buffer.len()); + + if self.input_buffer.is_empty() { + return Err(err_protocol!("Compressed input ended unexpectedly")); + } + } + } + + let available_bytes = min(self.output_buffer.len(), output_bytes_count); + Ok(self.output_buffer.split_to(available_bytes).freeze()) + } + } + + enum Decoder { + #[cfg(feature = "zlib-compression")] + Zlib(ZlibDecompressor), + #[cfg(feature = "zstd-compression")] + Zstd(ZstdDecoder<'static>), + } + impl Decoder { + // Chunk size is chosen based on lzbench benchmarks: + // https://github.com/inikep/lzbench?tab=readme-ov-file#benchmarks + // The target is to keep runtime under 50 ms. + fn get_chunk_size(&self) -> usize { + match self { + #[cfg(feature = "zlib-compression")] + Decoder::Zlib(_) => 16 * 1024, + #[cfg(feature = "zstd-compression")] + Decoder::Zstd(_) => 32 * 1024, + } + } + + fn new(compression_config: &CompressionConfig) -> Result { + let decoder = match compression_config.0 { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => Decoder::Zlib(ZlibDecompressor::new(true)), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => Decoder::Zstd(ZstdDecoder::new()?), + }; + Ok(decoder) + } + + fn decompress( + &mut self, + input: &[u8], + output: &mut BytesMut, + ) -> Result<(usize, usize), Error> { + let mut produced_bytes_total_count = 0; + let mut consumed_bytes_total_count = 0; + + match self { + #[cfg(feature = "zlib-compression")] + Decoder::Zlib(decoder) => { + let mut output_buffer = [0u8; 16 * 1024]; + while consumed_bytes_total_count < input.len() { + let consumed_bytes_count_before = decoder.total_in(); + let produced_bytes_count_before = decoder.total_out(); + + let status = decoder + .decompress( + &input[consumed_bytes_total_count..], + &mut output_buffer, + FlushDecompress::None, + ) + .map_err(|e| err_protocol!("Decompression error: {}", e))?; + + #[allow(clippy::cast_possible_truncation)] + let consumed_bytes_count = + (decoder.total_in() - consumed_bytes_count_before) as usize; + #[allow(clippy::cast_possible_truncation)] + let produced_bytes_count = + (decoder.total_out() - produced_bytes_count_before) as usize; + + if produced_bytes_count > 0 { + output.extend_from_slice(&output_buffer[..produced_bytes_count]); + } + + consumed_bytes_total_count += consumed_bytes_count; + produced_bytes_total_count += produced_bytes_count; + + match status { + // Not enough input data to continue decompression + Status::BufError => break, + Status::StreamEnd => { + if consumed_bytes_total_count < input.len() { + return Err(err_protocol!("Unexpected stream end")); + } else { + break; + } + } + Status::Ok => {} + } + } + } + #[cfg(feature = "zstd-compression")] + Decoder::Zstd(decoder) => { + let mut input_chunk = input; + let mut output_buffer = [0u8; 16 * 1024]; + + while !input_chunk.is_empty() { + let mut in_buf = InBuffer::around(input_chunk); + let mut out_buf = OutBuffer::around(&mut output_buffer[..]); + + let result = decoder.run(&mut in_buf, &mut out_buf)?; + + let consumed_bytes_count = in_buf.pos(); + let produced_bytes_count = out_buf.pos(); + + input_chunk = &input_chunk[consumed_bytes_count..]; + + if produced_bytes_count > 0 { + output.extend_from_slice(&output_buffer[..produced_bytes_count]); + } + + consumed_bytes_total_count += consumed_bytes_count; + produced_bytes_total_count += produced_bytes_count; + + // No progress made; waiting for the next input chunk + if consumed_bytes_count == 0 && produced_bytes_count == 0 { + break; + } + + if result == 0 && !input_chunk.is_empty() { + return Err(err_protocol!("Unexpected stream end")); + } + } + } + }; + + Ok((consumed_bytes_total_count, produced_bytes_total_count)) } } } diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 1ca62c4571..3fd6643a6b 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -104,15 +104,17 @@ impl<'a> DoHandshake<'a> { None }; - stream.write_packet(HandshakeResponse { - charset: super::INITIAL_CHARSET, - max_packet_size: MAX_PACKET_SIZE, - username: &options.username, - database: options.database.as_deref(), - auth_plugin: plugin, - auth_response: auth_response.as_deref(), - compression: options.compression, - })?; + stream + .write_packet(HandshakeResponse { + charset: super::INITIAL_CHARSET, + max_packet_size: MAX_PACKET_SIZE, + username: &options.username, + database: options.database.as_deref(), + auth_plugin: plugin, + auth_response: auth_response.as_deref(), + compression_configs: options.get_compression(), + }) + .await?; stream.flush().await?; @@ -141,7 +143,7 @@ impl<'a> DoHandshake<'a> { ) .await?; - stream.write_packet(AuthSwitchResponse(response))?; + stream.write_packet(AuthSwitchResponse(response)).await?; stream.flush().await?; } diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index 7f72a85cd7..d7cd0074f1 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -6,7 +6,7 @@ use crate::error::Error; use crate::io::MySqlBufExt; use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; -#[cfg(feature = "compression")] +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] use crate::options::Compression; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; @@ -51,13 +51,13 @@ impl MySqlStream { capabilities |= Capabilities::CONNECT_WITH_DB; } - #[cfg(feature = "compression")] - if let Some(compression) = options.compression { - match compression.0 { - Compression::Zlib => capabilities |= Capabilities::COMPRESS, - Compression::Zstd => capabilities |= Capabilities::ZSTD_COMPRESSION_ALGORITHM, - } - } + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + options.compression_configs.iter().for_each(|c| match c.0 { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => capabilities |= Capabilities::COMPRESS, + #[cfg(feature = "zstd-compression")] + Compression::Zstd => capabilities |= Capabilities::ZSTD_COMPRESSION_ALGORITHM, + }); Self { waiting: VecDeque::new(), @@ -113,17 +113,26 @@ impl MySqlStream { T: ProtocolEncode<'en, Capabilities>, { self.sequence_id = 0; - self.write_packet(payload)?; + self.write_packet(payload).await?; self.flush().await?; Ok(()) } - pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> + pub(crate) async fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where T: ProtocolEncode<'en, Capabilities>, { self.compression_stream .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) + .await + } + + pub(crate) fn write_uncompressed_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> + where + T: ProtocolEncode<'en, Capabilities>, + { + self.compression_stream + .uncompressed_write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) } async fn recv_packet_part(&mut self) -> Result { @@ -229,7 +238,7 @@ impl MySqlStream { compression_stream: CompressionMySqlStream::create( self.compression_stream.socket, &self.capabilities, - options.compression, + options.get_compression(), ), server_version: self.server_version, capabilities: self.capabilities, diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index b363b19c32..582d7676bf 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -67,10 +67,12 @@ pub(super) async fn maybe_upgrade( }; // Request TLS upgrade - stream.write_packet(SslRequest { - max_packet_size: super::MAX_PACKET_SIZE, - charset: super::INITIAL_CHARSET, - })?; + stream + .write_packet(SslRequest { + max_packet_size: super::MAX_PACKET_SIZE, + charset: super::INITIAL_CHARSET, + }) + .await?; stream.flush().await?; diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 6f1cd61ef1..83f655c942 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "compression")] +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] use sqlx_core::Error; use std::path::{Path, PathBuf}; @@ -82,30 +82,36 @@ pub struct MySqlConnectOptions { pub(crate) no_engine_substitution: bool, pub(crate) timezone: Option, pub(crate) set_names: bool, - pub(crate) compression: Option, + pub(crate) compression_configs: Vec, } #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct CompressionConfig( pub(crate) Compression, - #[cfg_attr(not(feature = "compression"), allow(dead_code))] pub(crate) u8, + #[cfg_attr( + not(all(feature = "zlib-compression", feature = "zstd-compression")), + allow(dead_code) + )] + pub(crate) u8, ); #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum Compression { - #[cfg(feature = "compression")] + #[cfg(feature = "zlib-compression")] Zlib, - #[cfg(feature = "compression")] + #[cfg(feature = "zstd-compression")] Zstd, } -#[cfg(feature = "compression")] +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] impl Compression { /// Selects a default compression level optimized for both encoding speed and output size. pub fn default(self) -> CompressionConfig { match self { - Compression::Zlib => CompressionConfig(self, 5), - Compression::Zstd => CompressionConfig(self, 11), + #[cfg(feature = "zlib-compression")] + Compression::Zlib => CompressionConfig(self, 6), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => CompressionConfig(self, 3), } } @@ -114,10 +120,18 @@ impl Compression { CompressionConfig(self, 1) } - /// Optimize for the size of data being encoded. + /// Optimize for maximum compression ratio. + /// + /// This mode favors smaller output size at the cost of significantly slower + /// compression speed. At high levels, compression itself may become the main + /// bottleneck rather than I/O or network transfer. + /// + /// Recommended only for offline or non-latency-sensitive workloads. pub fn best(self) -> CompressionConfig { match self { + #[cfg(feature = "zlib-compression")] Compression::Zlib => CompressionConfig(self, 9), + #[cfg(feature = "zstd-compression")] Compression::Zstd => CompressionConfig(self, 22), } } @@ -129,6 +143,12 @@ impl Compression { /// - **Zstd:** `1` to `22` /// - **Zlib:** `1` to `9` /// + /// For **Zstd**, the configured level is applied on the server side. + /// + /// For **Zlib**, this setting affects only outgoing packets. Incoming data is + /// always decompressed using the server-defined compression level, which is + /// fixed at `6` and cannot be changed. + /// /// If the provided level is valid for the selected algorithm, a new /// [`CompressionConfig`] is returned. /// If the level is out of range, an [`Error::Configuration`] is returned. @@ -142,16 +162,22 @@ impl Compression { /// /// ```rust /// # use sqlx_mysql::Compression; - /// - /// let ok = Compression::Zstd.level(5); - /// assert!(ok.is_ok()); - /// + /// # #[cfg(feature = "zstd-compression")] + /// # { + /// let good = Compression::Zstd.level(5); + /// assert!(good.is_ok()); + /// # } + /// # #[cfg(feature = "zlib-compression")] + /// # { /// let bad = Compression::Zlib.level(42); /// assert!(bad.is_err()); + /// # } /// ``` pub fn level(self, value: u8) -> Result { let range = match self { + #[cfg(feature = "zstd-compression")] Compression::Zstd => 1..=22, + #[cfg(feature = "zlib-compression")] Compression::Zlib => 1..=9, }; @@ -200,7 +226,7 @@ impl MySqlConnectOptions { no_engine_substitution: true, timezone: Some(String::from("+00:00")), set_names: true, - compression: None, + compression_configs: vec![], } } @@ -505,21 +531,29 @@ impl MySqlConnectOptions { self } - /// Sets the compression mode for the connection. + /// Sets the compression configuration for the connection. + /// + /// Compression is disabled by default. /// - /// Data is uncompressed by default. - /// Ensure that the server supports the selected compression algorithm; - /// if it does not, the client will fall back to uncompressed mode. + /// The client will negotiate compression with the server using the provided + /// configurations, in the given order. The first compression algorithm + /// supported by the server will be selected. If none of the specified + /// algorithms are supported, the connection falls back to uncompressed mode. /// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlConnectOptions, Compression}; /// let options = MySqlConnectOptions::new() - /// .compression(Compression::Zlib.fast()); + /// .compression(vec![ + ///# #[cfg(feature = "zlib-compression")] + /// Compression::Zlib.fast(), + ///# #[cfg(feature = "zstd-compression")] + /// Compression::Zstd.default(), + /// ]); /// ``` - pub fn compression(mut self, compression: CompressionConfig) -> Self { - self.compression = Some(compression); + pub fn compression(mut self, compression: Vec) -> Self { + self.compression_configs = compression; self } } @@ -640,14 +674,16 @@ impl MySqlConnectOptions { /// # Example /// /// ```rust - /// #![cfg(feature = "compression")] + /// # #[cfg(feature = "zlib-compression")] + /// # { /// # use sqlx_mysql::{Compression, CompressionConfig, MySqlConnectOptions}; /// let options = MySqlConnectOptions::new() - /// .compression(Compression::Zlib.fast()); + /// .compression(vec![Compression::Zlib.fast()]); /// - /// assert!(options.get_compression().is_some()); + /// assert_eq!(options.get_compression(), &[Compression::Zlib.fast()]); + /// # } /// ``` - pub fn get_compression(&self) -> Option { - self.compression + pub fn get_compression(&self) -> &[CompressionConfig] { + &self.compression_configs } } diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index 56cf7aa6be..44f8e27f22 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,7 +1,7 @@ use super::MySqlConnectOptions; use crate::error::Error; use crate::MySqlSslMode; -#[cfg(feature = "compression")] +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] use crate::{Compression, CompressionConfig}; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; @@ -80,27 +80,34 @@ impl MySqlConnectOptions { options = options.timezone(Some(value.to_string())); } - #[cfg(feature = "compression")] + #[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] "compression" => { - let (algorithm, level) = value.split_once(":").ok_or_else(|| { - Error::Configuration( - format!( - "Invalid compression parameter. Expected algorithm:level, but got '{}'", - value + let mut configs: Vec = vec![]; + for c in value.split(",") { + let (algorithm, level) = c + .split_once(":") + .ok_or_else(|| { + Error::Configuration( + format!( + "Invalid compression parameter. Expected algorithm:level, but got '{}'", + value + ).into(), ) - .into(), - ) - })?; - let compression = match algorithm { - "zlib" => Ok(Compression::Zlib), - "zstd" => Ok(Compression::Zstd), - _ => Err(Error::Configuration( - format!("Unknown compression algorithm: {}", algorithm).into(), - )), - }?; - let compression_config = - compression.level(level.parse().map_err(Error::config)?)?; - options = options.compression(compression_config); + })?; + let compression = match algorithm { + #[cfg(feature = "zlib-compression")] + "zlib" => Ok(Compression::Zlib), + #[cfg(feature = "zstd-compression")] + "zstd" => Ok(Compression::Zstd), + _ => Err(Error::Configuration( + format!("Unknown compression algorithm: {}", algorithm).into(), + )), + }?; + let compression_config = + compression.level(level.parse().map_err(Error::config)?)?; + configs.push(compression_config); + } + options = options.compression(configs); } _ => {} @@ -166,13 +173,21 @@ impl MySqlConnectOptions { .append_pair("socket", &socket.to_string_lossy()); } - #[cfg(feature = "compression")] - if let Some(compression_config) = &self.compression { - let value = match compression_config { - CompressionConfig(Compression::Zstd, level) => format!("zstd:{}", level), - CompressionConfig(Compression::Zlib, level) => format!("zlib:{}", level), - }; - url.query_pairs_mut().append_pair("compression", &value); + #[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] + if !&self.compression_configs.is_empty() { + let values = self + .compression_configs + .iter() + .map(|c| match c { + #[cfg(feature = "zstd-compression")] + CompressionConfig(Compression::Zstd, level) => format!("zstd:{}", level), + #[cfg(feature = "zlib-compression")] + CompressionConfig(Compression::Zlib, level) => format!("zlib:{}", level), + }) + .collect::>() + .join(","); + + url.query_pairs_mut().append_pair("compression", &values); } url @@ -218,12 +233,12 @@ fn it_returns_the_parsed_url() { } #[test] -#[cfg(feature = "compression")] -fn it_returns_the_build_url_with_compression_param() { +#[cfg(feature = "zstd-compression")] +fn it_returns_the_build_url_with_zstd_compression_param() { let url = "mysql://username:p@ssw0rd@hostname:3306/database"; let opts = MySqlConnectOptions::from_str(url) .unwrap() - .compression(Compression::Zstd.fast()); + .compression(vec![Compression::Zstd.fast()]); let mut expected_url = Url::parse(url).unwrap(); let mut query_string = String::new(); @@ -236,6 +251,25 @@ fn it_returns_the_build_url_with_compression_param() { assert_eq!(expected_url, opts.build_url()); } +#[test] +#[cfg(feature = "zlib-compression")] +fn it_returns_the_build_url_with_compression_params() { + let url = "mysql://username:p@ssw0rd@hostname:3306/database"; + let opts = MySqlConnectOptions::from_str(url) + .unwrap() + .compression(vec![Compression::Zlib.best()]); + + let mut expected_url = Url::parse(url).unwrap(); + let mut query_string = String::new(); + // MySqlConnectOptions defaults + query_string += "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; + query_string += "&compression=zlib%3A9"; + + expected_url.set_query(Some(&query_string)); + + assert_eq!(expected_url, opts.build_url()); +} + #[test] fn it_parses_timezone() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?timezone=%2B08:00" @@ -250,17 +284,44 @@ fn it_parses_timezone() { } #[test] -#[cfg(feature = "compression")] +#[cfg(feature = "zstd-compression")] fn it_parses_compression() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zstd:10" .parse() .unwrap(); - assert_eq!(opts.compression, Compression::Zstd.level(10).ok()); + assert_eq!( + opts.get_compression(), + &[Compression::Zstd.level(10).unwrap()] + ); +} +#[test] +#[cfg(feature = "zlib-compression")] +fn it_parses_zlib_compression() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zlib:2" .parse() .unwrap(); - assert_eq!(opts.compression, Compression::Zlib.level(2).ok()); + assert_eq!( + opts.get_compression(), + &[Compression::Zlib.level(2).unwrap()] + ); +} + +#[test] +#[cfg(all(feature = "zlib-compression", feature = "zstd-compression"))] +fn it_parses_list_of_compression_algorithms() { + let opts: MySqlConnectOptions = + "mysql://user:password@hostname/database?compression=zlib:1,zstd:2" + .parse() + .unwrap(); + + assert_eq!( + opts.get_compression(), + &[ + Compression::Zlib.level(1).unwrap(), + Compression::Zstd.level(2).unwrap() + ] + ); } diff --git a/sqlx-mysql/src/protocol/compressed_packet.rs b/sqlx-mysql/src/protocol/compressed_packet.rs deleted file mode 100644 index 0dbc3d36cd..0000000000 --- a/sqlx-mysql/src/protocol/compressed_packet.rs +++ /dev/null @@ -1,108 +0,0 @@ -use crate::error::Error; -use crate::io::ProtocolEncode; -use crate::options::Compression; -use crate::CompressionConfig; -use bytes::{BufMut, Bytes}; -use flate2::read::ZlibDecoder; -use flate2::{write::ZlibEncoder, Compression as ZlibCompression}; -use sqlx_core::io::ProtocolDecode; -use std::io::{Cursor, Read, Write}; - -#[derive(Debug)] -pub(crate) struct CompressedPacket(pub(crate) T); - -pub(crate) struct CompressedPacketContext<'cs, C> { - pub(crate) nested_context: C, - pub(crate) sequence_id: &'cs mut u8, - pub(crate) compression: CompressionConfig, -} - -impl<'en, 'compressed_stream, T, C> - ProtocolEncode<'en, CompressedPacketContext<'compressed_stream, C>> for CompressedPacket -where - T: ProtocolEncode<'en, C>, -{ - fn encode_with( - &self, - buf: &mut Vec, - context: CompressedPacketContext<'compressed_stream, C>, - ) -> Result<(), Error> { - let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); - self.0 - .encode_with(&mut uncompressed_payload, context.nested_context)?; - - let mut chunks = uncompressed_payload.chunks(0xFF_FF_FF); - for chunk in chunks.by_ref() { - add_packet(buf, *context.sequence_id, &context.compression, chunk)?; - *context.sequence_id = context.sequence_id.wrapping_add(1); - } - - Ok(()) - } -} - -fn add_packet( - buf: &mut Vec, - sequence_id: u8, - compression: &CompressionConfig, - uncompressed_chunk: &[u8], -) -> Result<(), Error> { - let offset = buf.len(); - buf.extend_from_slice(&[0; 7]); - - let compressed_payload_length = compress(compression, uncompressed_chunk, buf)?; - - let mut header = Vec::with_capacity(7); - header.put_uint_le(compressed_payload_length as u64, 3); - header.put_u8(sequence_id); - header.put_uint_le(uncompressed_chunk.len() as u64, 3); - buf[offset..offset + 7].copy_from_slice(&header); - - Ok(()) -} - -impl<'compressed_stream, C> ProtocolDecode<'_, CompressedPacketContext<'compressed_stream, C>> - for CompressedPacket -{ - fn decode_with( - buf: Bytes, - context: CompressedPacketContext<'compressed_stream, C>, - ) -> Result { - decompress(&context.compression, buf.as_ref()).map(|d| CompressedPacket(Bytes::from(d))) - } -} - -fn compress( - compression: &CompressionConfig, - input: &[u8], - output: &mut Vec, -) -> Result { - let offset = output.len(); - let mut cursor = Cursor::new(output); - cursor.set_position(offset as u64); - - let cursor = match compression { - CompressionConfig(Compression::Zlib, level) => { - let mut encoder = ZlibEncoder::new(cursor, ZlibCompression::new(*level as u32)); - let _ = encoder.write(input)?; - encoder.finish()? - } - CompressionConfig(Compression::Zstd, level) => { - zstd::stream::copy_encode(input, &mut cursor, *level as i32)?; - cursor - } - }; - - Ok(cursor.get_ref().len().saturating_sub(offset)) -} - -fn decompress(compression: &CompressionConfig, bytes: &[u8]) -> Result, Error> { - match compression.0 { - Compression::Zlib => { - let mut out = Vec::with_capacity(bytes.len() * 2); - ZlibDecoder::new(bytes).read_to_end(&mut out)?; - Ok(out) - } - Compression::Zstd => Ok(zstd::stream::decode_all(bytes)?), - } -} diff --git a/sqlx-mysql/src/protocol/connect/handshake_response.rs b/sqlx-mysql/src/protocol/connect/handshake_response.rs index c5d1bcc3d9..a1999ab852 100644 --- a/sqlx-mysql/src/protocol/connect/handshake_response.rs +++ b/sqlx-mysql/src/protocol/connect/handshake_response.rs @@ -1,6 +1,6 @@ use crate::io::MySqlBufMutExt; use crate::io::{BufMutExt, ProtocolEncode}; -#[cfg(feature = "compression")] +#[cfg(feature = "zstd-compression")] use crate::options::Compression; use crate::protocol::auth::AuthPlugin; use crate::protocol::connect::ssl_request::SslRequest; @@ -28,9 +28,9 @@ pub struct HandshakeResponse<'a> { /// Opaque authentication response pub auth_response: Option<&'a [u8]>, - /// compression algorithm - #[cfg_attr(not(feature = "compression"), allow(dead_code))] - pub compression: Option, + /// compression configurations + #[cfg_attr(not(feature = "zstd-compression"), allow(dead_code))] + pub compression_configs: &'a [CompressionConfig], } impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { @@ -83,10 +83,15 @@ impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { } } - #[cfg(feature = "compression")] + #[cfg(feature = "zstd-compression")] if context.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM) { - if let Some(CompressionConfig(Compression::Zstd, level)) = self.compression { - buf.push(level) + let compression_config = self + .compression_configs + .iter() + .find(|c| c.0 == Compression::Zstd); + + if let Some(CompressionConfig(Compression::Zstd, level)) = compression_config { + buf.push(*level) } } diff --git a/sqlx-mysql/src/protocol/mod.rs b/sqlx-mysql/src/protocol/mod.rs index 325ce456f4..d1860f5c65 100644 --- a/sqlx-mysql/src/protocol/mod.rs +++ b/sqlx-mysql/src/protocol/mod.rs @@ -1,7 +1,5 @@ pub(crate) mod auth; mod capabilities; -#[cfg(feature = "compression")] -mod compressed_packet; pub(crate) mod connect; mod packet; pub(crate) mod response; @@ -10,7 +8,5 @@ pub(crate) mod statement; pub(crate) mod text; pub(crate) use capabilities::Capabilities; -#[cfg(feature = "compression")] -pub(crate) use compressed_packet::{CompressedPacket, CompressedPacketContext}; pub(crate) use packet::Packet; pub(crate) use row::Row; diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 18db30b183..37c1e7ef42 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -63,7 +63,7 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.stream.sequence_id = 0; conn.inner .stream - .write_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) + .write_uncompressed_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) .expect("BUG: unexpected error queueing ROLLBACK"); conn.inner.transaction_depth = depth - 1; diff --git a/tests/mysql/compression.rs b/tests/mysql/compression.rs new file mode 100644 index 0000000000..f8a218aaeb --- /dev/null +++ b/tests/mysql/compression.rs @@ -0,0 +1,25 @@ +#[cfg(any(feature = "mysql-zstd-compression", feature = "mysql-zlib-compression"))] +mod compression_tests { + use sqlx::Row; + use sqlx_mysql::MySql; + use sqlx_test::new; + + #[sqlx_macros::test] + async fn it_connects_with_compression() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) + .fetch_all(&mut conn) + .await?; + + let result = rows + .first() + .map(|r| r.try_get::(1).unwrap_or_default()) + .unwrap_or_default(); + + assert!(!rows.is_empty()); + assert_eq!(result, "ON"); + + Ok(()) + } +} diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 0c05195ae5..5d6a5ef233 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -3,7 +3,7 @@ use futures_util::TryStreamExt; use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow}; use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; use sqlx_core::connection::ConnectOptions; -use sqlx_mysql::{Compression, MySqlConnectOptions}; +use sqlx_mysql::MySqlConnectOptions; use sqlx_test::{new, setup_if_needed}; use std::env; use url::Url; @@ -39,64 +39,6 @@ async fn it_connects_without_password() -> anyhow::Result<()> { Ok(()) } -#[sqlx_macros::test] -async fn it_connects_with_zlib_compression() -> anyhow::Result<()> { - let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) - .context("error parsing DATABASE_URL")?; - let mut conn = MySqlConnectOptions::from_url(&url)? - .compression(Compression::Zlib.default()) - .connect() - .await?; - - let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) - .fetch_all(&mut conn) - .await?; - - let result = rows - .first() - .map(|r| r.try_get::(1).unwrap_or_default()) - .unwrap_or_default(); - - assert!(!rows.is_empty()); - assert_eq!(result, "ON"); - - Ok(()) -} - -#[sqlx_macros::test] -#[cfg(all( - not(any( - mariadb = "verylatest", - mariadb = "10_6", - mariadb = "10_11", - mariadb = "11_4", - mariadb = "11_8", - )), - feature = "mysql" -))] -async fn it_connects_with_zstd_compression() -> anyhow::Result<()> { - let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) - .context("error parsing DATABASE_URL")?; - let mut conn = MySqlConnectOptions::from_url(&url)? - .compression(Compression::Zstd.default()) - .connect() - .await?; - - let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) - .fetch_all(&mut conn) - .await?; - - let result = rows - .first() - .map(|r| r.try_get::(1).unwrap_or_default()) - .unwrap_or_default(); - - assert!(!rows.is_empty()); - assert_eq!(result, "ON"); - - Ok(()) -} - #[sqlx_macros::test] async fn it_maths() -> anyhow::Result<()> { let mut conn = new::().await?; @@ -618,86 +560,6 @@ CREATE TEMPORARY TABLE large_table (data LONGBLOB); Ok(()) } -#[sqlx_macros::test] -#[cfg(all( - not(any( - mariadb = "verylatest", - mariadb = "10_6", - mariadb = "10_11", - mariadb = "11_4", - mariadb = "11_8", - )), - feature = "mysql" -))] -async fn it_can_handle_split_packets_with_zstd_compression() -> anyhow::Result<()> { - let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) - .context("error parsing DATABASE_URL")?; - - let options = MySqlConnectOptions::from_url(&url)?.compression(Compression::Zstd.best()); - - // This will only take effect on new connections - options - .connect() - .await? - .execute("SET GLOBAL max_allowed_packet = 4294967297") - .await?; - - let mut conn = options.await?; - - conn.execute(r#" CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) - .await?; - - let data = vec![0x41; 0xFF_FF_FF * 2]; - - sqlx::query("INSERT INTO large_table (data) VALUES (?)") - .bind(&data) - .execute(&mut conn) - .await?; - - let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") - .fetch_one(&mut conn) - .await?; - - assert_eq!(ret, data); - - Ok(()) -} - -#[sqlx_macros::test] -async fn it_can_handle_split_packets_with_zlib_compression() -> anyhow::Result<()> { - let url = Url::parse(&env::var("DATABASE_URL").context("expected DATABASE_URL")?) - .context("error parsing DATABASE_URL")?; - - let options = MySqlConnectOptions::from_url(&url)?.compression(Compression::Zlib.best()); - - // This will only take effect on new connections - options - .connect() - .await? - .execute("SET GLOBAL max_allowed_packet = 4294967297") - .await?; - - let mut conn = options.await?; - - conn.execute(r#"CREATE TEMPORARY TABLE large_table (data LONGBLOB);"#) - .await?; - - let data = vec![0x41; 0xFF_FF_FF * 2]; - - sqlx::query("INSERT INTO large_table (data) VALUES (?)") - .bind(&data) - .execute(&mut conn) - .await?; - - let ret: Vec = sqlx::query_scalar("SELECT * FROM large_table") - .fetch_one(&mut conn) - .await?; - - assert_eq!(ret, data); - - Ok(()) -} - #[sqlx_macros::test] async fn test_shrink_buffers() -> anyhow::Result<()> { // We don't really have a good way to test that `.shrink_buffers()` functions as expected diff --git a/tests/mysql/rustsec.rs b/tests/mysql/rustsec.rs index 8d8db0c250..41ad56753c 100644 --- a/tests/mysql/rustsec.rs +++ b/tests/mysql/rustsec.rs @@ -1,4 +1,5 @@ use sqlx::{Error, MySql}; +use sqlx_mysql::MySqlDatabaseError; use std::io; use sqlx_test::new; @@ -29,8 +30,8 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { "CREATE TEMPORARY TABLE injection_target(id INTEGER PRIMARY KEY AUTO_INCREMENT, message TEXT);\n\ INSERT INTO injection_target(message) VALUES ('existing message');", ) - .execute(&mut conn) - .await?; + .execute(&mut conn) + .await?; // We can't concatenate a query string together like the other tests // because it would just demonstrate a regular old SQL injection. @@ -42,16 +43,22 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { if let Err(e) = res { // Connection rejected the query; we're happy. // - // Current observed behavior is that `mysqld` closes the connection before we're even done - // sending the message, giving us a "Broken pipe" error. + // If a packet exceeds `max_allowed_packet`, MySQL returns ER_NET_PACKET_TOO_LARGE + // and closes the connection. Depending on timing, the client may instead observe + // "Lost connection to MySQL server during query" or a local "Broken pipe" error. // - // As it turns out, MySQL has a tight limit on packet sizes (even after splitting) - // by default: https://dev.mysql.com/doc/refman/8.4/en/packet-too-large.html - if matches!(e, Error::Io(ref ioe) if ioe.kind() == io::ErrorKind::BrokenPipe) { - return Ok(()); + // See: https://dev.mysql.com/doc/refman/8.4/en/packet-too-large.html + match e { + Error::Database(ref dbe) => { + let err_net_packet_too_large = 1153; + return match dbe.try_downcast_ref::() { + Some(error) if error.number() == err_net_packet_too_large => Ok(()), + _ => panic!("unexpected error: {e:?}"), + }; + } + Error::Io(ref ioe) if ioe.kind() == io::ErrorKind::BrokenPipe => return Ok(()), + _ => panic!("unexpected error: {e:?}"), } - - panic!("unexpected error: {e:?}"); } let messages: Vec = diff --git a/tests/x.py b/tests/x.py index 7b01ce0f54..34e06c474d 100755 --- a/tests/x.py +++ b/tests/x.py @@ -81,13 +81,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr" if platform.system() == "Linux": if os.environ.get("LD_LIBRARY_PATH"): - environ["LD_LIBRARY_PATH"]= os.environ.get("LD_LIBRARY_PATH") + ":"+ os.getcwd() + environ["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH") + ":" + os.getcwd() else: - environ["LD_LIBRARY_PATH"]=os.getcwd() - + environ["LD_LIBRARY_PATH"] = os.getcwd() if service is not None: - database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests) + database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", + cwd=dir_tests) if database_url_args: database_url += "?" + database_url_args @@ -209,23 +209,51 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data for version in ["8", "5_7"]: # Since docker mysql 5.7 using yaSSL(It only supports TLSv1.1), avoid running when using rustls. # https://github.com/docker-library/mysql/issues/567 - if not(version == "5_7" and tls == "rustls"): + if not (version == "5_7" and tls == "rustls"): run( - f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}", service=f"mysql_{version}", tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version} zlib-compression", + database_url_args="compression=zlib:1", + service=f"mysql_{version}", + tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", + ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zstd-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version} zstd-compression", + database_url_args="compression=zstd:1", + service=f"mysql_{version}", + tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", + ) ## +client-ssl - if tls != "none" and not(version == "5_7" and tls == "rustls"): + if tls != "none" and not (version == "5_7" and tls == "rustls"): run( - f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}_client_ssl no-password", database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt", service=f"mysql_{version}_client_ssl", tag=f"mysql_{version}_client_ssl_no_password" if runtime == "async-std" else f"mysql_{version}_client_ssl_no_password_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version}_client_ssl no-password zlib-compression", + database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1", + service=f"mysql_{version}_client_ssl", + tag=f"mysql_{version}_client_ssl_no_password" if runtime == "async-std" else f"mysql_{version}_client_ssl_no_password_{runtime}", + ) + run( + f"cargo test --no-default-features --features any,mysql,mysql,mysql-zstd-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version}_client_ssl no-password zstd-compression", + database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zstd:1", + service=f"mysql_{version}_client_ssl", + tag=f"mysql_{version}_client_ssl_no_password" if runtime == "async-std" else f"mysql_{version}_client_ssl_no_password_{runtime}", + ) # # mariadb @@ -233,21 +261,35 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data for version in ["verylatest", "10_11", "10_6", "10_5", "10_4"]: run( - f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mariadb {version}", service=f"mariadb_{version}", tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mariadb {version} zlib-compression", + database_url_args="compression=zlib:1", + service=f"mariadb_{version}", + tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}", + ) ## +client-ssl if tls != "none": run( - f"cargo test --no-default-features --features any,mysql,macros,mysql-compression,_unstable-all-types,runtime-{runtime},tls-{tls}", + f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mariadb {version}_client_ssl no-password", database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt", service=f"mariadb_{version}_client_ssl", tag=f"mariadb_{version}_client_ssl_no_password" if runtime == "async-std" else f"mariadb_{version}_client_ssl_no_password_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mariadb {version}_client_ssl no-password zlib-compression", + database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1", + service=f"mariadb_{version}_client_ssl", + tag=f"mariadb_{version}_client_ssl_no_password" if runtime == "async-std" else f"mariadb_{version}_client_ssl_no_password_{runtime}", + ) # TODO: Use [grcov] if available # ~/.cargo/bin/grcov tests/.cache/target/debug -s sqlx-core/ -t html --llvm --branch -o ./target/debug/coverage