Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions sqlx-core/src/any/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Default for AnyArguments<'_> {

impl<'q> AnyArguments<'q> {
#[doc(hidden)]
pub fn convert_to<'a, A: Arguments<'a>>(&'a self) -> Result<A, BoxDynError>
pub fn convert_into<'a, A: Arguments<'a>>(self) -> Result<A, BoxDynError>
where
'q: 'a,
Option<i32>: Type<A::Database> + Encode<'a, A::Database>,
Expand All @@ -60,12 +60,12 @@ impl<'q> AnyArguments<'q> {
i64: Type<A::Database> + Encode<'a, A::Database>,
f32: Type<A::Database> + Encode<'a, A::Database>,
f64: Type<A::Database> + Encode<'a, A::Database>,
&'a str: Type<A::Database> + Encode<'a, A::Database>,
&'a [u8]: Type<A::Database> + Encode<'a, A::Database>,
String: Type<A::Database> + Encode<'a, A::Database>,
Vec<u8>: Type<A::Database> + Encode<'a, A::Database>,
{
let mut out = A::default();

for arg in &self.values.0 {
for arg in self.values.0 {
match arg {
AnyValueKind::Null(AnyTypeInfoKind::Null) => out.add(Option::<i32>::None),
AnyValueKind::Null(AnyTypeInfoKind::Bool) => out.add(Option::<bool>::None),
Expand All @@ -82,8 +82,8 @@ impl<'q> AnyArguments<'q> {
AnyValueKind::BigInt(i) => out.add(i),
AnyValueKind::Real(r) => out.add(r),
AnyValueKind::Double(d) => out.add(d),
AnyValueKind::Text(t) => out.add(&**t),
AnyValueKind::Blob(b) => out.add(&**b),
AnyValueKind::Text(t) => out.add(String::from(t)),
AnyValueKind::Blob(b) => out.add(Vec::from(b)),
}?
}
Ok(out)
Expand Down
5 changes: 2 additions & 3 deletions sqlx-mysql/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl AnyConnectionBackend for MySqlConnection {
arguments: Option<AnyArguments<'q>>,
) -> BoxStream<'q, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() {
let arguments = match arguments.map(AnyArguments::convert_into).transpose() {
Ok(arguments) => arguments,
Err(error) => {
return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed()
Expand All @@ -111,8 +111,7 @@ impl AnyConnectionBackend for MySqlConnection {
) -> BoxFuture<'q, sqlx_core::Result<Option<AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = arguments
.as_ref()
.map(AnyArguments::convert_to)
.map(AnyArguments::convert_into)
.transpose()
.map_err(sqlx_core::Error::Encode);

Expand Down
5 changes: 2 additions & 3 deletions sqlx-postgres/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl AnyConnectionBackend for PgConnection {
arguments: Option<AnyArguments<'q>>,
) -> BoxStream<'q, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() {
let arguments = match arguments.map(AnyArguments::convert_into).transpose() {
Ok(arguments) => arguments,
Err(error) => {
return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed()
Expand All @@ -113,8 +113,7 @@ impl AnyConnectionBackend for PgConnection {
) -> BoxFuture<'q, sqlx_core::Result<Option<AnyRow>>> {
let persistent = persistent && arguments.is_some();
let arguments = arguments
.as_ref()
.map(AnyArguments::convert_to)
.map(AnyArguments::convert_into)
.transpose()
.map_err(sqlx_core::Error::Encode);

Expand Down
44 changes: 24 additions & 20 deletions sqlx-sqlite/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ use sqlx_core::any::{
};
use sqlx_core::sql_str::SqlStr;

use crate::arguments::SqliteArgumentsBuffer;
use crate::type_info::DataType;
use sqlx_core::connection::{ConnectOptions, Connection};
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
use sqlx_core::transaction::TransactionManager;
use std::pin::pin;
use std::sync::Arc;

sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite);

Expand Down Expand Up @@ -203,27 +205,29 @@ impl<'a> TryFrom<&'a AnyConnectOptions> for SqliteConnectOptions {
}
}

/// Instead of `AnyArguments::convert_into()`, we can do a direct mapping and preserve the lifetime.
fn map_arguments(args: AnyArguments<'_>) -> SqliteArguments<'_> {
// Infallible alternative to AnyArguments::convert_into()
fn map_arguments(args: AnyArguments<'_>) -> SqliteArguments {
let values = args
.values
.0
.into_iter()
.map(|val| match val {
AnyValueKind::Null(_) => SqliteArgumentValue::Null,
AnyValueKind::Bool(b) => SqliteArgumentValue::Int(b as i32),
AnyValueKind::SmallInt(i) => SqliteArgumentValue::Int(i as i32),
AnyValueKind::Integer(i) => SqliteArgumentValue::Int(i),
AnyValueKind::BigInt(i) => SqliteArgumentValue::Int64(i),
AnyValueKind::Real(r) => SqliteArgumentValue::Double(r as f64),
AnyValueKind::Double(d) => SqliteArgumentValue::Double(d),
AnyValueKind::Text(t) => SqliteArgumentValue::Text(Arc::new(t.to_string())),
AnyValueKind::Blob(b) => SqliteArgumentValue::Blob(Arc::new(b.to_vec())),
// AnyValueKind is `#[non_exhaustive]` but we should have covered everything
_ => unreachable!("BUG: missing mapping for {val:?}"),
})
.collect();

SqliteArguments {
values: args
.values
.0
.into_iter()
.map(|val| match val {
AnyValueKind::Null(_) => SqliteArgumentValue::Null,
AnyValueKind::Bool(b) => SqliteArgumentValue::Int(b as i32),
AnyValueKind::SmallInt(i) => SqliteArgumentValue::Int(i as i32),
AnyValueKind::Integer(i) => SqliteArgumentValue::Int(i),
AnyValueKind::BigInt(i) => SqliteArgumentValue::Int64(i),
AnyValueKind::Real(r) => SqliteArgumentValue::Double(r as f64),
AnyValueKind::Double(d) => SqliteArgumentValue::Double(d),
AnyValueKind::Text(t) => SqliteArgumentValue::Text(t),
AnyValueKind::Blob(b) => SqliteArgumentValue::Blob(b),
// AnyValueKind is `#[non_exhaustive]` but we should have covered everything
_ => unreachable!("BUG: missing mapping for {val:?}"),
})
.collect(),
values: SqliteArgumentsBuffer::new(values),
}
}

Expand Down
67 changes: 30 additions & 37 deletions sqlx-sqlite/src/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,62 +4,56 @@ use crate::statement::StatementHandle;
use crate::Sqlite;
use atoi::atoi;
use libsqlite3_sys::SQLITE_OK;
use std::borrow::Cow;
use std::sync::Arc;

pub(crate) use sqlx_core::arguments::*;
use sqlx_core::error::BoxDynError;

#[derive(Debug, Clone)]
pub enum SqliteArgumentValue<'q> {
pub enum SqliteArgumentValue {
Null,
Text(Cow<'q, str>),
Blob(Cow<'q, [u8]>),
Text(Arc<String>),
TextSlice(Arc<str>),
Blob(Arc<Vec<u8>>),
Comment on lines +15 to +17
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're copying at this level, we could actually avoid a copy later if we tell SQLite the pointer is going to remain valid: https://sqlite.org/c3ref/bind_blob.html

The fifth argument to the BLOB and string binding interfaces controls or indicates the lifetime of the object referenced by the third parameter. These three options exist: (1) A destructor to dispose of the BLOB or string after SQLite has finished with it may be passed. It is called to dispose of the BLOB or string even if the call to the bind API fails, except the destructor is not called if the third parameter is a NULL pointer or the fourth parameter is negative. (2) The special constant, SQLITE_STATIC, may be passed to indicate that the application remains responsible for disposing of the object. In this case, the object and the provided pointer to it must remain valid until either the prepared statement is finalized or the same SQL parameter is bound to something else, whichever occurs sooner. (3) The constant, SQLITE_TRANSIENT, may be passed to indicate that the object is to be copied prior to the return from sqlite3_bind_*(). The object and pointer to it must remain valid until then. SQLite will then manage the lifetime of its private copy.

We currently pass SQLITE_TRANSIENT. Theoretically SQLITE_STATIC should be perfectly fine, though I'm leaning toward passing a destructor and letting SQLite manage the lifetime as that should be less to worry about.

I'm not saying we have to add that here, I just wanted to note this somewhere for posterity.

Double(f64),
Int(i32),
Int64(i64),
}

#[derive(Default, Debug, Clone)]
pub struct SqliteArguments<'q> {
pub(crate) values: Vec<SqliteArgumentValue<'q>>,
pub struct SqliteArguments {
pub(crate) values: SqliteArgumentsBuffer,
}

impl<'q> SqliteArguments<'q> {
#[derive(Default, Debug, Clone)]
pub struct SqliteArgumentsBuffer(Vec<SqliteArgumentValue>);

impl<'q> SqliteArguments {
pub(crate) fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
where
T: Encode<'q, Sqlite>,
{
let value_length_before_encoding = self.values.len();
let value_length_before_encoding = self.values.0.len();

match value.encode(&mut self.values) {
Ok(IsNull::Yes) => self.values.push(SqliteArgumentValue::Null),
Ok(IsNull::Yes) => self.values.0.push(SqliteArgumentValue::Null),
Ok(IsNull::No) => {}
Err(error) => {
// reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind
self.values.truncate(value_length_before_encoding);
self.values.0.truncate(value_length_before_encoding);
return Err(error);
}
};

Ok(())
}

pub(crate) fn into_static(self) -> SqliteArguments<'static> {
SqliteArguments {
values: self
.values
.into_iter()
.map(SqliteArgumentValue::into_static)
.collect(),
}
}
}

impl<'q> Arguments<'q> for SqliteArguments<'q> {
impl<'q> Arguments<'q> for SqliteArguments {
type Database = Sqlite;

fn reserve(&mut self, len: usize, _size_hint: usize) {
self.values.reserve(len);
self.values.0.reserve(len);
}

fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
Expand All @@ -70,11 +64,11 @@ impl<'q> Arguments<'q> for SqliteArguments<'q> {
}

fn len(&self) -> usize {
self.values.len()
self.values.0.len()
}
}

impl SqliteArguments<'_> {
impl SqliteArguments {
pub(super) fn bind(&self, handle: &mut StatementHandle, offset: usize) -> Result<usize, Error> {
let mut arg_i = offset;
// for handle in &statement.handles {
Expand Down Expand Up @@ -103,7 +97,7 @@ impl SqliteArguments<'_> {
arg_i
};

if n > self.values.len() {
if n > self.values.0.len() {
// SQLite treats unbound variables as NULL
// we reproduce this here
// If you are reading this and think this should be an error, open an issue and we can
Expand All @@ -113,32 +107,31 @@ impl SqliteArguments<'_> {
break;
}

self.values[n - 1].bind(handle, param_i)?;
self.values.0[n - 1].bind(handle, param_i)?;
}

Ok(arg_i - offset)
}
}

impl SqliteArgumentValue<'_> {
fn into_static(self) -> SqliteArgumentValue<'static> {
use SqliteArgumentValue::*;
impl SqliteArgumentsBuffer {
#[allow(dead_code)] // clippy incorrectly reports this as unused
pub(crate) fn new(values: Vec<SqliteArgumentValue>) -> SqliteArgumentsBuffer {
Self(values)
}

match self {
Null => Null,
Text(text) => Text(text.into_owned().into()),
Blob(blob) => Blob(blob.into_owned().into()),
Int(v) => Int(v),
Int64(v) => Int64(v),
Double(v) => Double(v),
}
pub(crate) fn push(&mut self, value: SqliteArgumentValue) {
self.0.push(value);
}
}

impl SqliteArgumentValue {
fn bind(&self, handle: &mut StatementHandle, i: usize) -> Result<(), Error> {
use SqliteArgumentValue::*;

let status = match self {
Text(v) => handle.bind_text(i, v),
TextSlice(v) => handle.bind_text(i, v),
Blob(v) => handle.bind_blob(i, v),
Int(v) => handle.bind_int(i, *v),
Int64(v) => handle.bind_int64(i, *v),
Expand Down
14 changes: 7 additions & 7 deletions sqlx-sqlite/src/connection/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct ExecuteIter<'a> {
handle: &'a mut ConnectionHandle,
statement: &'a mut VirtualStatement,
logger: QueryLogger,
args: Option<SqliteArguments<'a>>,
args: Option<SqliteArguments>,

/// since a `VirtualStatement` can encompass multiple actual statements,
/// this keeps track of the number of arguments so far
Expand All @@ -19,12 +19,12 @@ pub struct ExecuteIter<'a> {
goto_next: bool,
}

pub(crate) fn iter<'a>(
conn: &'a mut ConnectionState,
pub(crate) fn iter(
conn: &mut ConnectionState,
query: impl SqlSafeStr,
args: Option<SqliteArguments<'a>>,
args: Option<SqliteArguments>,
persistent: bool,
) -> Result<ExecuteIter<'a>, Error> {
) -> Result<ExecuteIter<'_>, Error> {
let query = query.into_sql_str();
// fetch the cached statement or allocate a new one
let statement = conn.statements.get(query.as_str(), persistent)?;
Expand All @@ -43,7 +43,7 @@ pub(crate) fn iter<'a>(

fn bind(
statement: &mut StatementHandle,
arguments: &Option<SqliteArguments<'_>>,
arguments: &Option<SqliteArguments>,
offset: usize,
) -> Result<usize, Error> {
let mut n = 0;
Expand All @@ -56,7 +56,7 @@ fn bind(
}

impl ExecuteIter<'_> {
pub fn finish(&mut self) -> Result<(), Error> {
pub fn finish(self) -> Result<(), Error> {
for res in self {
let _ = res?;
}
Expand Down
6 changes: 3 additions & 3 deletions sqlx-sqlite/src/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ enum Command {
},
Execute {
query: SqlStr,
arguments: Option<SqliteArguments<'static>>,
arguments: Option<SqliteArguments>,
persistent: bool,
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
limit: Option<usize>,
Expand Down Expand Up @@ -360,7 +360,7 @@ impl ConnectionWorker {
pub(crate) async fn execute(
&mut self,
query: SqlStr,
args: Option<SqliteArguments<'_>>,
args: Option<SqliteArguments>,
chan_size: usize,
persistent: bool,
limit: Option<usize>,
Expand All @@ -371,7 +371,7 @@ impl ConnectionWorker {
.send_async((
Command::Execute {
query,
arguments: args.map(SqliteArguments::into_static),
arguments: args,
persistent,
tx,
limit,
Expand Down
10 changes: 5 additions & 5 deletions sqlx-sqlite/src/database.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub(crate) use sqlx_core::database::{Database, HasStatementCache};

use crate::arguments::SqliteArgumentsBuffer;
use crate::{
SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnection, SqliteQueryResult,
SqliteRow, SqliteStatement, SqliteTransactionManager, SqliteTypeInfo, SqliteValue,
SqliteValueRef,
SqliteArguments, SqliteColumn, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement,
SqliteTransactionManager, SqliteTypeInfo, SqliteValue, SqliteValueRef,
};

/// Sqlite database driver.
Expand All @@ -26,8 +26,8 @@ impl Database for Sqlite {
type Value = SqliteValue;
type ValueRef<'r> = SqliteValueRef<'r>;

type Arguments<'q> = SqliteArguments<'q>;
type ArgumentBuffer<'q> = Vec<SqliteArgumentValue<'q>>;
type Arguments<'q> = SqliteArguments;
type ArgumentBuffer<'q> = SqliteArgumentsBuffer;

type Statement = SqliteStatement;

Expand Down
4 changes: 2 additions & 2 deletions sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extern crate sqlx_core;

use std::sync::atomic::AtomicBool;

pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use arguments::{SqliteArgumentValue, SqliteArguments, SqliteArgumentsBuffer};
pub use column::SqliteColumn;
#[cfg(feature = "deserialize")]
#[cfg_attr(docsrs, doc(cfg(feature = "deserialize")))]
Expand Down Expand Up @@ -147,7 +147,7 @@ impl<'c, T: Executor<'c, Database = Sqlite>> SqliteExecutor<'c> for T {}
pub type SqliteTransaction<'c> = sqlx_core::transaction::Transaction<'c, Sqlite>;

// NOTE: required due to the lack of lazy normalization
impl_into_arguments_for_arguments!(SqliteArguments<'q>);
impl_into_arguments_for_arguments!(SqliteArguments);
impl_column_index_for_row!(SqliteRow);
impl_column_index_for_statement!(SqliteStatement);
impl_acquire!(Sqlite, SqliteConnection);
Expand Down
Loading
Loading