diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 672fe2f97e..5de8124111 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -104,6 +104,23 @@ impl From for ParseError { } } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum RawValue<'a> { + Null, + Unset, + Value(&'a [u8]), +} + +impl<'a> RawValue<'a> { + #[inline] + pub fn as_value(&self) -> Option<&'a [u8]> { + match self { + RawValue::Value(v) => Some(v), + RawValue::Null | RawValue::Unset => None, + } + } +} + fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { if buf.len() < count { return Err(ParseError::BadIncomingData(format!( @@ -218,6 +235,22 @@ pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { Ok(v) } +pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { + let len = read_int(buf)?; + match len { + -2 => Ok(RawValue::Unset), + -1 => Ok(RawValue::Null), + len if len >= 0 => { + let v = read_raw_bytes(len as usize, buf)?; + Ok(RawValue::Value(v)) + } + len => Err(ParseError::BadIncomingData(format!( + "invalid value length: {}", + len, + ))), + } +} + pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { let len = read_short_length(buf)?; let v = read_raw_bytes(len, buf)?; diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index db67b4fab8..4faa8df501 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -16,6 +16,7 @@ use chrono::{DateTime, NaiveDate, NaiveTime, TimeZone, Utc}; use super::response::result::CqlValue; use super::types::vint_encode; +use super::types::RawValue; #[cfg(feature = "secret")] use secrecy::{ExposeSecret, Secret, Zeroize}; @@ -366,7 +367,7 @@ impl SerializedValues { Ok(()) } - pub fn iter(&self) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator { SerializedValuesIterator { serialized_values: &self.serialized_values, contains_names: self.contains_names, @@ -410,7 +411,7 @@ impl SerializedValues { }) } - pub fn iter_name_value_pairs(&self) -> impl Iterator, &[u8])> { + pub fn iter_name_value_pairs(&self) -> impl Iterator, RawValue)> { let mut buf = &self.serialized_values[..]; (0..self.values_num).map(move |_| { // `unwrap()`s here are safe, as we assume type-safety: if `SerializedValues` exits, @@ -418,7 +419,7 @@ impl SerializedValues { let name = self .contains_names .then(|| types::read_string(&mut buf).unwrap()); - let serialized = types::read_bytes(&mut buf).unwrap(); + let serialized = types::read_value(&mut buf).unwrap(); (name, serialized) }) } @@ -431,7 +432,7 @@ pub struct SerializedValuesIterator<'a> { } impl<'a> Iterator for SerializedValuesIterator<'a> { - type Item = Option<&'a [u8]>; + type Item = RawValue<'a>; fn next(&mut self) -> Option { if self.serialized_values.is_empty() { @@ -443,7 +444,7 @@ impl<'a> Iterator for SerializedValuesIterator<'a> { types::read_short_bytes(&mut self.serialized_values).expect("badly encoded value name"); } - Some(types::read_bytes_opt(&mut self.serialized_values).expect("badly encoded value")) + Some(types::read_value(&mut self.serialized_values).expect("badly encoded value")) } } diff --git a/scylla-cql/src/frame/value_tests.rs b/scylla-cql/src/frame/value_tests.rs index 633e3ce7b8..003ff0116a 100644 --- a/scylla-cql/src/frame/value_tests.rs +++ b/scylla-cql/src/frame/value_tests.rs @@ -1,4 +1,4 @@ -use crate::frame::value::BatchValuesIterator; +use crate::frame::{types::RawValue, value::BatchValuesIterator}; use super::value::{ BatchValues, CqlDate, CqlTime, CqlTimestamp, MaybeUnset, SerializeValuesError, @@ -421,7 +421,10 @@ fn serialized_values() { values.write_to_request(&mut request); assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 8]); - assert_eq!(values.iter().collect::>(), vec![Some([8].as_ref())]); + assert_eq!( + values.iter().collect::>(), + vec![RawValue::Value([8].as_ref())] + ); } // Add second value @@ -436,7 +439,10 @@ fn serialized_values() { assert_eq!( values.iter().collect::>(), - vec![Some([8].as_ref()), Some([0, 16].as_ref())] + vec![ + RawValue::Value([8].as_ref()), + RawValue::Value([0, 16].as_ref()) + ] ); } @@ -468,7 +474,10 @@ fn serialized_values() { assert_eq!( values.iter().collect::>(), - vec![Some([8].as_ref()), Some([0, 16].as_ref())] + vec![ + RawValue::Value([8].as_ref()), + RawValue::Value([0, 16].as_ref()) + ] ); } } @@ -498,9 +507,9 @@ fn slice_value_list() { assert_eq!( serialized.iter().collect::>(), vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) + RawValue::Value([0, 0, 0, 1].as_ref()), + RawValue::Value([0, 0, 0, 2].as_ref()), + RawValue::Value([0, 0, 0, 3].as_ref()) ] ); } @@ -515,9 +524,9 @@ fn vec_value_list() { assert_eq!( serialized.iter().collect::>(), vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) + RawValue::Value([0, 0, 0, 1].as_ref()), + RawValue::Value([0, 0, 0, 2].as_ref()), + RawValue::Value([0, 0, 0, 3].as_ref()) ] ); } @@ -530,7 +539,7 @@ fn tuple_value_list() { let serialized_vals: Vec = serialized .iter() - .map(|o: Option<&[u8]>| o.unwrap()[0]) + .map(|o: RawValue| o.as_value().unwrap()[0]) .collect(); let expected: Vec = expected.collect(); @@ -604,9 +613,9 @@ fn ref_value_list() { assert_eq!( serialized.iter().collect::>(), vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) + RawValue::Value([0, 0, 0, 1].as_ref()), + RawValue::Value([0, 0, 0, 2].as_ref()), + RawValue::Value([0, 0, 0, 3].as_ref()) ] ); } diff --git a/scylla-cql/src/types/serialize/mod.rs b/scylla-cql/src/types/serialize/mod.rs index 0cda84e252..511a7104b1 100644 --- a/scylla-cql/src/types/serialize/mod.rs +++ b/scylla-cql/src/types/serialize/mod.rs @@ -2,5 +2,11 @@ use std::{any::Any, sync::Arc}; pub mod row; pub mod value; +pub mod writers; + +pub use writers::{ + BufBackedCellValueBuilder, BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, + CellWriter, CountingWriter, RowWriter, +}; type SerializationError = Arc; diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 2e9832412d..fe91585c8c 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -1,20 +1,25 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; + +use thiserror::Error; -use crate::frame::response::result::ColumnSpec; use crate::frame::value::ValueList; +use crate::frame::{response::result::ColumnSpec, types::RawValue}; -use super::SerializationError; +use super::{CellWriter, RowWriter, SerializationError}; +/// Contains information needed to serialize a row. pub struct RowSerializationContext<'a> { columns: &'a [ColumnSpec], } impl<'a> RowSerializationContext<'a> { + /// Returns column/bind marker specifications for given query. #[inline] pub fn columns(&self) -> &'a [ColumnSpec] { self.columns } + /// Looks up and returns a column/bind marker by name. // TODO: change RowSerializationContext to make this faster #[inline] pub fn column_by_name(&self, target: &str) -> Option<&ColumnSpec> { @@ -23,11 +28,25 @@ impl<'a> RowSerializationContext<'a> { } pub trait SerializeRow { + /// Checks if it _might_ be possible to serialize the row according to the + /// information in the context. + /// + /// This function is intended to serve as an optimization in the future, + /// if we were ever to introduce prepared statements parametrized by types. + /// + /// Sometimes, a row cannot be fully type checked right away without knowing + /// the exact values of the columns (e.g. when deserializing to `CqlValue`), + /// but it's fine to do full type checking later in `serialize`. fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>; - fn serialize( + + /// Serializes the row according to the information in the given context. + /// + /// The function may assume that `preliminary_type_check` was called, + /// though it must not do anything unsafe if this assumption does not hold. + fn serialize( &self, ctx: &RowSerializationContext<'_>, - out: &mut Vec, + writer: &mut W, ) -> Result<(), SerializationError>; } @@ -38,12 +57,134 @@ impl SerializeRow for T { Ok(()) } - fn serialize( + fn serialize( &self, - _ctx: &RowSerializationContext<'_>, - out: &mut Vec, + ctx: &RowSerializationContext<'_>, + writer: &mut W, ) -> Result<(), SerializationError> { - self.write_to_request(out) - .map_err(|err| Arc::new(err) as SerializationError) + serialize_legacy_row(self, ctx, writer) + } +} + +pub fn serialize_legacy_row( + r: &T, + ctx: &RowSerializationContext<'_>, + writer: &mut impl RowWriter, +) -> Result<(), SerializationError> { + let serialized = + ::serialized(r).map_err(|err| Arc::new(err) as SerializationError)?; + + let mut append_value = |value: RawValue| { + let cell_writer = writer.make_cell_writer(); + let _proof = match value { + RawValue::Null => cell_writer.set_null(), + RawValue::Unset => cell_writer.set_unset(), + RawValue::Value(v) => cell_writer.set_value(v), + }; + }; + + if !serialized.has_names() { + serialized.iter().for_each(append_value); + } else { + let values_by_name = serialized + .iter_name_value_pairs() + .map(|(k, v)| (k.unwrap(), v)) + .collect::>(); + + for col in ctx.columns() { + let val = values_by_name.get(col.name.as_str()).ok_or_else(|| { + Arc::new(ValueListToSerializeRowAdapterError::NoBindMarkerWithName { + name: col.name.clone(), + }) as SerializationError + })?; + append_value(*val); + } + } + + Ok(()) +} + +#[derive(Error, Debug)] +pub enum ValueListToSerializeRowAdapterError { + #[error("There is no bind marker with name {name}, but a value for it was provided")] + NoBindMarkerWithName { name: String }, +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + use crate::frame::value::{MaybeUnset, SerializedValues, ValueList}; + use crate::types::serialize::BufBackedRowWriter; + + use super::{RowSerializationContext, SerializeRow}; + + fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } + } + + #[test] + fn test_legacy_fallback() { + let row = ( + 1i32, + "Ala ma kota", + None::, + MaybeUnset::Unset::, + ); + + let mut legacy_data = Vec::new(); + <_ as ValueList>::write_to_request(&row, &mut legacy_data).unwrap(); + + let mut new_data = Vec::new(); + let mut new_data_writer = BufBackedRowWriter::new(&mut new_data); + let ctx = RowSerializationContext { columns: &[] }; + <_ as SerializeRow>::serialize(&row, &ctx, &mut new_data_writer).unwrap(); + assert_eq!(new_data_writer.value_count(), 4); + + // Skip the value count + assert_eq!(&legacy_data[2..], new_data); + } + + #[test] + fn test_legacy_fallback_with_names() { + let sorted_row = ( + 1i32, + "Ala ma kota", + None::, + MaybeUnset::Unset::, + ); + + let mut sorted_row_data = Vec::new(); + <_ as ValueList>::write_to_request(&sorted_row, &mut sorted_row_data).unwrap(); + + let mut unsorted_row = SerializedValues::new(); + unsorted_row.add_named_value("a", &1i32).unwrap(); + unsorted_row.add_named_value("b", &"Ala ma kota").unwrap(); + unsorted_row + .add_named_value("d", &MaybeUnset::Unset::) + .unwrap(); + unsorted_row.add_named_value("c", &None::).unwrap(); + + let mut unsorted_row_data = Vec::new(); + let mut unsorted_row_data_writer = BufBackedRowWriter::new(&mut unsorted_row_data); + let ctx = RowSerializationContext { + columns: &[ + col_spec("a", ColumnType::Int), + col_spec("b", ColumnType::Text), + col_spec("c", ColumnType::BigInt), + col_spec("d", ColumnType::Ascii), + ], + }; + <_ as SerializeRow>::serialize(&unsorted_row, &ctx, &mut unsorted_row_data_writer).unwrap(); + assert_eq!(unsorted_row_data_writer.value_count(), 4); + + // Skip the value count + assert_eq!(&sorted_row_data[2..], unsorted_row_data); } } diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 43eb9ef738..25d605d13d 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1,13 +1,32 @@ use std::sync::Arc; +use thiserror::Error; + use crate::frame::response::result::ColumnType; use crate::frame::value::Value; -use super::SerializationError; +use super::{CellWriter, SerializationError}; pub trait SerializeCql { + /// Given a CQL type, checks if it _might_ be possible to serialize to that type. + /// + /// This function is intended to serve as an optimization in the future, + /// if we were ever to introduce prepared statements parametrized by types. + /// + /// Some types cannot be type checked without knowing the exact value, + /// this is the case e.g. for `CqlValue`. It's also fine to do it later in + /// `serialize`. fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError>; - fn serialize(&self, typ: &ColumnType, buf: &mut Vec) -> Result<(), SerializationError>; + + /// Serializes the value to given CQL type. + /// + /// The function may assume that `preliminary_type_check` was called, + /// though it must not do anything unsafe if this assumption does not hold. + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result; } impl SerializeCql for T { @@ -15,8 +34,89 @@ impl SerializeCql for T { Ok(()) } - fn serialize(&self, _typ: &ColumnType, buf: &mut Vec) -> Result<(), SerializationError> { - self.serialize(buf) - .map_err(|err| Arc::new(err) as SerializationError) + fn serialize( + &self, + _typ: &ColumnType, + writer: W, + ) -> Result { + serialize_legacy_value(self, writer) + } +} + +pub fn serialize_legacy_value( + v: &T, + writer: W, +) -> Result { + // It's an inefficient and slightly tricky but correct implementation. + let mut buf = Vec::new(); + ::serialize(v, &mut buf).map_err(|err| Arc::new(err) as SerializationError)?; + + // Analyze the output. + // All this dance shows how unsafe our previous interface was... + if buf.len() < 4 { + return Err(Arc::new(ValueToSerializeCqlAdapterError::TooShort { + size: buf.len(), + })); + } + + let (len_bytes, contents) = buf.split_at(4); + let len = i32::from_be_bytes(len_bytes.try_into().unwrap()); + match len { + -2 => Ok(writer.set_unset()), + -1 => Ok(writer.set_null()), + len if len >= 0 => { + if contents.len() != len as usize { + Err(Arc::new( + ValueToSerializeCqlAdapterError::DeclaredVsActualSizeMismatch { + declared: len as usize, + actual: contents.len(), + }, + )) + } else { + Ok(writer.set_value(contents)) + } + } + _ => Err(Arc::new( + ValueToSerializeCqlAdapterError::InvalidDeclaredSize { size: len }, + )), + } +} + +#[derive(Error, Debug)] +pub enum ValueToSerializeCqlAdapterError { + #[error("Output produced by the Value trait is too short to be considered a value: {size} < 4 minimum bytes")] + TooShort { size: usize }, + + #[error("Mismatch between the declared value size vs. actual size: {declared} != {actual}")] + DeclaredVsActualSizeMismatch { declared: usize, actual: usize }, + + #[error("Invalid declared value size: {size}")] + InvalidDeclaredSize { size: i32 }, +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::ColumnType; + use crate::frame::value::{MaybeUnset, Value}; + use crate::types::serialize::BufBackedCellWriter; + + use super::SerializeCql; + + fn check_compat(v: V) { + let mut legacy_data = Vec::new(); + ::serialize(&v, &mut legacy_data).unwrap(); + + let mut new_data = Vec::new(); + let new_data_writer = BufBackedCellWriter::new(&mut new_data); + ::serialize(&v, &ColumnType::Int, new_data_writer).unwrap(); + + assert_eq!(legacy_data, new_data); + } + + #[test] + fn test_legacy_fallback() { + check_compat(123i32); + check_compat(None::); + check_compat(MaybeUnset::Unset::); } } diff --git a/scylla-cql/src/types/serialize/writers.rs b/scylla-cql/src/types/serialize/writers.rs new file mode 100644 index 0000000000..cafd5442fc --- /dev/null +++ b/scylla-cql/src/types/serialize/writers.rs @@ -0,0 +1,426 @@ +//! Contains types and traits used for safe serialization of values for a CQL statement. + +/// An interface that facilitates writing values for a CQL query. +pub trait RowWriter { + type CellWriter<'a>: CellWriter + where + Self: 'a; + + /// Appends a new value to the sequence and returns an object that allows + /// to fill it in. + fn make_cell_writer(&mut self) -> Self::CellWriter<'_>; +} + +/// Represents a handle to a CQL value that needs to be written into. +/// +/// The writer can either be transformed into a ready value right away +/// (via [`set_null`](CellWriter::set_null), +/// [`set_unset`](CellWriter::set_unset) +/// or [`set_value`](CellWriter::set_value) or transformed into +/// the [`CellWriter::ValueBuilder`] in order to gradually initialize +/// the value when the contents are not available straight away. +/// +/// After the value is fully initialized, the handle is consumed and +/// a [`WrittenCellProof`](CellWriter::WrittenCellProof) object is returned +/// in its stead. This is a type-level proof that the value was fully initialized +/// and is used in [`SerializeCql::serialize`](`super::value::SerializeCql::serialize`) +/// in order to enforce the implementor to fully initialize the provided handle +/// to CQL value. +/// +/// Dropping this type without calling any of its methods will result +/// in nothing being written. +pub trait CellWriter { + /// The type of the value builder, returned by the [`CellWriter::set_value`] + /// method. + type ValueBuilder: CellValueBuilder; + + /// An object that serves as a proof that the cell was fully initialized. + /// + /// This type is returned by [`set_null`](CellWriter::set_null), + /// [`set_unset`](CellWriter::set_unset), + /// [`set_value`](CellWriter::set_value) + /// and also [`CellValueBuilder::finish`] - generally speaking, after + /// the value is fully initialized and the `CellWriter` is destroyed. + /// + /// The purpose of this type is to enforce the contract of + /// [`SerializeCql::serialize`](super::value::SerializeCql::serialize): either + /// the method succeeds and returns a proof that it serialized itself + /// into the given value, or it fails and returns an error or panics. + /// The exact type of [`WrittenCellProof`](CellWriter::WrittenCellProof) + /// is not important as the value is not used at all - it's only + /// a compile-time check. + type WrittenCellProof; + + /// Sets this value to be null, consuming this object. + fn set_null(self) -> Self::WrittenCellProof; + + /// Sets this value to represent an unset value, consuming this object. + fn set_unset(self) -> Self::WrittenCellProof; + + /// Sets this value to a non-zero, non-unset value with given contents. + /// + /// Prefer this to [`into_value_builder`](CellWriter::into_value_builder) + /// if you have all of the contents of the value ready up front (e.g. for + /// fixed size types). + fn set_value(self, contents: &[u8]) -> Self::WrittenCellProof; + + /// Turns this writter into a [`CellValueBuilder`] which can be used + /// to gradually initialize the CQL value. + /// + /// This method should be used if you don't have all of the data + /// up front, e.g. when serializing compound types such as collections + /// or UDTs. + fn into_value_builder(self) -> Self::ValueBuilder; +} + +/// Allows appending bytes to a non-null, non-unset cell. +/// +/// This object needs to be dropped in order for the value to be correctly +/// serialized. Failing to drop this value will result in a payload that will +/// not be parsed by the database correctly, but otherwise should not cause +/// data to be misinterpreted. +pub trait CellValueBuilder { + type SubCellWriter<'a>: CellWriter + where + Self: 'a; + + type WrittenCellProof; + + /// Appends raw bytes to this cell. + fn append_bytes(&mut self, bytes: &[u8]); + + /// Appends a sub-value to the end of the current contents of the cell + /// and returns an object that allows to fill it in. + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_>; + + /// Finishes serializing the value. + fn finish(self) -> Self::WrittenCellProof; +} + +/// A row writer backed by a buffer (vec). +pub struct BufBackedRowWriter<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, + + // Number of values written so far. + value_count: u16, +} + +impl<'buf> BufBackedRowWriter<'buf> { + /// Creates a new row writer based on an existing Vec. + /// + /// The newly created row writer will append data to the end of the vec. + #[inline] + pub fn new(buf: &'buf mut Vec) -> Self { + Self { + buf, + value_count: 0, + } + } + + /// Returns the number of values that were written so far. + #[inline] + pub fn value_count(&self) -> u16 { + self.value_count + } +} + +impl<'buf> RowWriter for BufBackedRowWriter<'buf> { + type CellWriter<'a> = BufBackedCellWriter<'a> where Self: 'a; + + #[inline] + fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { + self.value_count = self + .value_count + .checked_add(1) + .expect("tried to serialize too many values for a query (more than u16::MAX)"); + BufBackedCellWriter::new(self.buf) + } +} + +/// A cell writer backed by a buffer (vec). +pub struct BufBackedCellWriter<'buf> { + buf: &'buf mut Vec, +} + +impl<'buf> BufBackedCellWriter<'buf> { + /// Creates a new cell writer based on an existing Vec. + /// + /// The newly created row writer will append data to the end of the vec. + #[inline] + pub fn new(buf: &'buf mut Vec) -> Self { + BufBackedCellWriter { buf } + } +} + +impl<'buf> CellWriter for BufBackedCellWriter<'buf> { + type ValueBuilder = BufBackedCellValueBuilder<'buf>; + + type WrittenCellProof = (); + + #[inline] + fn set_null(self) { + self.buf.extend_from_slice(&(-1i32).to_be_bytes()); + } + + #[inline] + fn set_unset(self) { + self.buf.extend_from_slice(&(-2i32).to_be_bytes()); + } + + #[inline] + fn set_value(self, bytes: &[u8]) { + let value_len: i32 = bytes + .len() + .try_into() + .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + self.buf.extend_from_slice(&value_len.to_be_bytes()); + self.buf.extend_from_slice(bytes); + } + + #[inline] + fn into_value_builder(self) -> Self::ValueBuilder { + BufBackedCellValueBuilder::new(self.buf) + } +} + +/// A cell value builder backed by a buffer (vec). +pub struct BufBackedCellValueBuilder<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, + + // Starting position of the value in the buffer. + starting_pos: usize, +} + +impl<'buf> BufBackedCellValueBuilder<'buf> { + #[inline] + fn new(buf: &'buf mut Vec) -> Self { + // "Length" of a [bytes] frame can either be a non-negative i32, + // -1 (null) or -1 (not set). Push an invalid value here. It will be + // overwritten eventually either by set_null, set_unset or Drop. + // If the CellSerializer is not dropped as it should, this will trigger + // an error on the DB side and the serialized data + // won't be misinterpreted. + let starting_pos = buf.len(); + buf.extend_from_slice(&(-3i32).to_be_bytes()); + BufBackedCellValueBuilder { buf, starting_pos } + } +} + +impl<'buf> CellValueBuilder for BufBackedCellValueBuilder<'buf> { + type SubCellWriter<'a> = BufBackedCellWriter<'a> + where + Self: 'a; + + type WrittenCellProof = (); + + #[inline] + fn append_bytes(&mut self, bytes: &[u8]) { + self.buf.extend_from_slice(bytes); + } + + #[inline] + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { + BufBackedCellWriter::new(self.buf) + } + + #[inline] + fn finish(self) { + // TODO: Should this panic, or should we catch this error earlier? + // Vec will panic anyway if we overflow isize, so at least this + // behavior is consistent with what the stdlib does. + let value_len: i32 = (self.buf.len() - self.starting_pos - 4) + .try_into() + .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + self.buf[self.starting_pos..self.starting_pos + 4] + .copy_from_slice(&value_len.to_be_bytes()); + } +} + +/// A writer that does not actually write anything, just counts the bytes. +/// +/// It can serve as a: +/// +/// - [`RowWriter`] +/// - [`CellWriter`] +/// - [`CellValueBuilder`] +pub struct CountingWriter<'buf> { + buf: &'buf mut usize, +} + +impl<'buf> CountingWriter<'buf> { + /// Creates a new writer which increments the counter under given reference + /// when bytes are appended. + #[inline] + fn new(buf: &'buf mut usize) -> Self { + CountingWriter { buf } + } +} + +impl<'buf> RowWriter for CountingWriter<'buf> { + type CellWriter<'a> = CountingWriter<'a> where Self: 'a; + + #[inline] + fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { + CountingWriter::new(self.buf) + } +} + +impl<'buf> CellWriter for CountingWriter<'buf> { + type ValueBuilder = CountingWriter<'buf>; + + type WrittenCellProof = (); + + #[inline] + fn set_null(self) { + *self.buf += 4; + } + + #[inline] + fn set_unset(self) { + *self.buf += 4; + } + + #[inline] + fn set_value(self, contents: &[u8]) { + *self.buf += 4 + contents.len(); + } + + #[inline] + fn into_value_builder(self) -> Self::ValueBuilder { + *self.buf += 4; + CountingWriter::new(self.buf) + } +} + +impl<'buf> CellValueBuilder for CountingWriter<'buf> { + type SubCellWriter<'a> = CountingWriter<'a> + where + Self: 'a; + + type WrittenCellProof = (); + + #[inline] + fn append_bytes(&mut self, bytes: &[u8]) { + *self.buf += bytes.len(); + } + + #[inline] + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { + CountingWriter::new(self.buf) + } + + #[inline] + fn finish(self) -> Self::WrittenCellProof {} +} + +#[cfg(test)] +mod tests { + use super::{ + BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, CellWriter, CountingWriter, + RowWriter, + }; + + // We want to perform the same computation for both buf backed writer + // and counting writer, but Rust does not support generic closures. + // This trait comes to the rescue. + trait CellSerializeCheck { + fn check(&self, writer: W); + } + + fn check_cell_serialize(c: C) -> Vec { + let mut data = Vec::new(); + let writer = BufBackedCellWriter::new(&mut data); + c.check(writer); + + let mut byte_count = 0usize; + let counting_writer = CountingWriter::new(&mut byte_count); + c.check(counting_writer); + + assert_eq!(data.len(), byte_count); + data + } + + #[test] + fn test_cell_writer() { + struct Check; + impl CellSerializeCheck for Check { + fn check(&self, writer: W) { + let mut sub_writer = writer.into_value_builder(); + sub_writer.make_sub_writer().set_null(); + sub_writer.make_sub_writer().set_value(&[1, 2, 3, 4]); + sub_writer.make_sub_writer().set_unset(); + sub_writer.finish(); + } + } + + let data = check_cell_serialize(Check); + assert_eq!( + data, + [ + 0, 0, 0, 16, // Length of inner data is 16 + 255, 255, 255, 255, // Null (encoded as -1) + 0, 0, 0, 4, 1, 2, 3, 4, // Four byte value + 255, 255, 255, 254, // Unset (encoded as -2) + ] + ); + } + + #[test] + fn test_poisoned_appender() { + struct Check; + impl CellSerializeCheck for Check { + fn check(&self, writer: W) { + let _ = writer.into_value_builder(); + } + } + + let data = check_cell_serialize(Check); + assert_eq!( + data, + [ + 255, 255, 255, 253, // Invalid value + ] + ); + } + + trait RowSerializeCheck { + fn check(&self, writer: &mut W); + } + + fn check_row_serialize(c: C) -> Vec { + let mut data = Vec::new(); + let mut writer = BufBackedRowWriter::new(&mut data); + c.check(&mut writer); + + let mut byte_count = 0usize; + let mut counting_writer = CountingWriter::new(&mut byte_count); + c.check(&mut counting_writer); + + assert_eq!(data.len(), byte_count); + data + } + + #[test] + fn test_row_writer() { + struct Check; + impl RowSerializeCheck for Check { + fn check(&self, writer: &mut W) { + writer.make_cell_writer().set_null(); + writer.make_cell_writer().set_value(&[1, 2, 3, 4]); + writer.make_cell_writer().set_unset(); + } + } + + let data = check_row_serialize(Check); + assert_eq!( + data, + [ + 255, 255, 255, 255, // Null (encoded as -1) + 0, 0, 0, 4, 1, 2, 3, 4, // Four byte value + 255, 255, 255, 254, // Unset (encoded as -2) + ] + ) + } +} diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index 22f34e60a2..58d8b9ea3d 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -1,5 +1,6 @@ use bytes::{Bytes, BytesMut}; use scylla_cql::errors::{BadQuery, QueryError}; +use scylla_cql::frame::types::RawValue; use smallvec::{smallvec, SmallVec}; use std::convert::TryInto; use std::sync::Arc; @@ -399,7 +400,7 @@ impl<'ps> PartitionKey<'ps> { PartitionKeyExtractionError::NoPkIndexValue(pk_index.index, bound_values.len()) })?; // Add it in sequence order to pk_values - if let Some(v) = next_val { + if let RawValue::Value(v) = next_val { let spec = &prepared_metadata.col_specs[pk_index.index as usize]; pk_values[pk_index.sequence as usize] = Some((v, spec)); } diff --git a/scylla/src/transport/partitioner.rs b/scylla/src/transport/partitioner.rs index 9c8a542325..4526715ab2 100644 --- a/scylla/src/transport/partitioner.rs +++ b/scylla/src/transport/partitioner.rs @@ -1,4 +1,5 @@ use bytes::Buf; +use scylla_cql::frame::types::RawValue; use std::num::Wrapping; use crate::{ @@ -343,11 +344,14 @@ pub fn calculate_token_for_partition_key( if serialized_partition_key_values.len() == 1 { let val = serialized_partition_key_values.iter().next().unwrap(); - if let Some(val) = val { + if let RawValue::Value(val) = val { partitioner_hasher.write(val); } } else { - for val in serialized_partition_key_values.iter().flatten() { + for val in serialized_partition_key_values + .iter() + .filter_map(|rv| rv.as_value()) + { let val_len_u16: u16 = val .len() .try_into()