Skip to content

Commit

Permalink
Merge pull request #855 from piodul/adjust-interface
Browse files Browse the repository at this point in the history
types: serialize: constrain the new serialization traits to make them easier and safer to use
  • Loading branch information
piodul authored Nov 24, 2023
2 parents 156ee60 + 29a37b4 commit 46e33c9
Show file tree
Hide file tree
Showing 9 changed files with 758 additions and 37 deletions.
33 changes: 33 additions & 0 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ impl From<std::array::TryFromSliceError> for ParseError {
}
}

/// A raw, type-erased value as it appears on the wire in a CQL frame,
/// before any type-aware (de)serialization takes place.
///
/// Distinguishes the three wire-level shapes of a value: an explicit null,
/// an "unset" marker, and a non-empty payload of serialized bytes.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RawValue<'a> {
    /// Explicit null (encoded with length -1, see `read_value`).
    Null,
    /// Value deliberately left unset (encoded with length -2, see `read_value`).
    Unset,
    /// A present value, borrowed as its raw serialized bytes.
    Value(&'a [u8]),
}

impl<'a> RawValue<'a> {
    /// Returns the underlying byte slice if this is a `Value`,
    /// or `None` for `Null` and `Unset`.
    #[inline]
    pub fn as_value(&self) -> Option<&'a [u8]> {
        if let RawValue::Value(bytes) = self {
            Some(bytes)
        } else {
            None
        }
    }
}

fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
if buf.len() < count {
return Err(ParseError::BadIncomingData(format!(
Expand Down Expand Up @@ -218,6 +235,22 @@ pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
Ok(v)
}

/// Reads a single `[value]` from the buffer: a 4-byte signed length
/// followed by that many bytes. Length -1 denotes null, -2 denotes unset;
/// any other negative length is rejected as malformed.
pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result<RawValue<'a>, ParseError> {
    let length = read_int(buf)?;
    match length {
        -2 => Ok(RawValue::Unset),
        -1 => Ok(RawValue::Null),
        _ => {
            if length < 0 {
                return Err(ParseError::BadIncomingData(format!(
                    "invalid value length: {}",
                    length,
                )));
            }
            let bytes = read_raw_bytes(length as usize, buf)?;
            Ok(RawValue::Value(bytes))
        }
    }
}

pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
let len = read_short_length(buf)?;
let v = read_raw_bytes(len, buf)?;
Expand Down
11 changes: 6 additions & 5 deletions scylla-cql/src/frame/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use chrono::{DateTime, NaiveDate, NaiveTime, TimeZone, Utc};

use super::response::result::CqlValue;
use super::types::vint_encode;
use super::types::RawValue;

#[cfg(feature = "secret")]
use secrecy::{ExposeSecret, Secret, Zeroize};
Expand Down Expand Up @@ -366,7 +367,7 @@ impl SerializedValues {
Ok(())
}

pub fn iter(&self) -> impl Iterator<Item = Option<&[u8]>> {
pub fn iter(&self) -> impl Iterator<Item = RawValue> {
SerializedValuesIterator {
serialized_values: &self.serialized_values,
contains_names: self.contains_names,
Expand Down Expand Up @@ -410,15 +411,15 @@ impl SerializedValues {
})
}

pub fn iter_name_value_pairs(&self) -> impl Iterator<Item = (Option<&str>, &[u8])> {
pub fn iter_name_value_pairs(&self) -> impl Iterator<Item = (Option<&str>, RawValue)> {
let mut buf = &self.serialized_values[..];
(0..self.values_num).map(move |_| {
// `unwrap()`s here are safe, as we assume type-safety: if `SerializedValues` exits,
// we have a guarantee that the layout of the serialized values is valid.
let name = self
.contains_names
.then(|| types::read_string(&mut buf).unwrap());
let serialized = types::read_bytes(&mut buf).unwrap();
let serialized = types::read_value(&mut buf).unwrap();
(name, serialized)
})
}
Expand All @@ -431,7 +432,7 @@ pub struct SerializedValuesIterator<'a> {
}

impl<'a> Iterator for SerializedValuesIterator<'a> {
type Item = Option<&'a [u8]>;
type Item = RawValue<'a>;

fn next(&mut self) -> Option<Self::Item> {
if self.serialized_values.is_empty() {
Expand All @@ -443,7 +444,7 @@ impl<'a> Iterator for SerializedValuesIterator<'a> {
types::read_short_bytes(&mut self.serialized_values).expect("badly encoded value name");
}

Some(types::read_bytes_opt(&mut self.serialized_values).expect("badly encoded value"))
Some(types::read_value(&mut self.serialized_values).expect("badly encoded value"))
}
}

Expand Down
37 changes: 23 additions & 14 deletions scylla-cql/src/frame/value_tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::frame::value::BatchValuesIterator;
use crate::frame::{types::RawValue, value::BatchValuesIterator};

use super::value::{
BatchValues, CqlDate, CqlTime, CqlTimestamp, MaybeUnset, SerializeValuesError,
Expand Down Expand Up @@ -421,7 +421,10 @@ fn serialized_values() {
values.write_to_request(&mut request);
assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 8]);

assert_eq!(values.iter().collect::<Vec<_>>(), vec![Some([8].as_ref())]);
assert_eq!(
values.iter().collect::<Vec<_>>(),
vec![RawValue::Value([8].as_ref())]
);
}

// Add second value
Expand All @@ -436,7 +439,10 @@ fn serialized_values() {

assert_eq!(
values.iter().collect::<Vec<_>>(),
vec![Some([8].as_ref()), Some([0, 16].as_ref())]
vec![
RawValue::Value([8].as_ref()),
RawValue::Value([0, 16].as_ref())
]
);
}

Expand Down Expand Up @@ -468,7 +474,10 @@ fn serialized_values() {

assert_eq!(
values.iter().collect::<Vec<_>>(),
vec![Some([8].as_ref()), Some([0, 16].as_ref())]
vec![
RawValue::Value([8].as_ref()),
RawValue::Value([0, 16].as_ref())
]
);
}
}
Expand Down Expand Up @@ -498,9 +507,9 @@ fn slice_value_list() {
assert_eq!(
serialized.iter().collect::<Vec<_>>(),
vec![
Some([0, 0, 0, 1].as_ref()),
Some([0, 0, 0, 2].as_ref()),
Some([0, 0, 0, 3].as_ref())
RawValue::Value([0, 0, 0, 1].as_ref()),
RawValue::Value([0, 0, 0, 2].as_ref()),
RawValue::Value([0, 0, 0, 3].as_ref())
]
);
}
Expand All @@ -515,9 +524,9 @@ fn vec_value_list() {
assert_eq!(
serialized.iter().collect::<Vec<_>>(),
vec![
Some([0, 0, 0, 1].as_ref()),
Some([0, 0, 0, 2].as_ref()),
Some([0, 0, 0, 3].as_ref())
RawValue::Value([0, 0, 0, 1].as_ref()),
RawValue::Value([0, 0, 0, 2].as_ref()),
RawValue::Value([0, 0, 0, 3].as_ref())
]
);
}
Expand All @@ -530,7 +539,7 @@ fn tuple_value_list() {

let serialized_vals: Vec<u8> = serialized
.iter()
.map(|o: Option<&[u8]>| o.unwrap()[0])
.map(|o: RawValue| o.as_value().unwrap()[0])
.collect();

let expected: Vec<u8> = expected.collect();
Expand Down Expand Up @@ -604,9 +613,9 @@ fn ref_value_list() {
assert_eq!(
serialized.iter().collect::<Vec<_>>(),
vec![
Some([0, 0, 0, 1].as_ref()),
Some([0, 0, 0, 2].as_ref()),
Some([0, 0, 0, 3].as_ref())
RawValue::Value([0, 0, 0, 1].as_ref()),
RawValue::Value([0, 0, 0, 2].as_ref()),
RawValue::Value([0, 0, 0, 3].as_ref())
]
);
}
Expand Down
6 changes: 6 additions & 0 deletions scylla-cql/src/types/serialize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,11 @@ use std::{any::Any, sync::Arc};

pub mod row;
pub mod value;
pub mod writers;

// Re-export the writer traits and their buffer-backed implementations at the
// module root so users don't need to reach into `writers` directly.
pub use writers::{
    BufBackedCellValueBuilder, BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder,
    CellWriter, CountingWriter, RowWriter,
};

// Type-erased serialization error shared across the new serialization traits.
// NOTE(review): presumably callers downcast this to a concrete error type —
// confirm before documenting as a stable contract.
type SerializationError = Arc<dyn Any + Send + Sync>;
161 changes: 151 additions & 10 deletions scylla-cql/src/types/serialize/row.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use thiserror::Error;

use crate::frame::response::result::ColumnSpec;
use crate::frame::value::ValueList;
use crate::frame::{response::result::ColumnSpec, types::RawValue};

use super::SerializationError;
use super::{CellWriter, RowWriter, SerializationError};

/// Contains information needed to serialize a row.
pub struct RowSerializationContext<'a> {
columns: &'a [ColumnSpec],
}

impl<'a> RowSerializationContext<'a> {
/// Returns column/bind marker specifications for given query.
#[inline]
pub fn columns(&self) -> &'a [ColumnSpec] {
self.columns
}

/// Looks up and returns a column/bind marker by name.
// TODO: change RowSerializationContext to make this faster
#[inline]
pub fn column_by_name(&self, target: &str) -> Option<&ColumnSpec> {
Expand All @@ -23,11 +28,25 @@ impl<'a> RowSerializationContext<'a> {
}

pub trait SerializeRow {
/// Checks if it _might_ be possible to serialize the row according to the
/// information in the context.
///
/// This function is intended to serve as an optimization in the future,
/// if we were ever to introduce prepared statements parametrized by types.
///
/// Sometimes, a row cannot be fully type checked right away without knowing
/// the exact values of the columns (e.g. when deserializing to `CqlValue`),
/// but it's fine to do full type checking later in `serialize`.
fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>;
fn serialize(

/// Serializes the row according to the information in the given context.
///
/// The function may assume that `preliminary_type_check` was called,
/// though it must not do anything unsafe if this assumption does not hold.
fn serialize<W: RowWriter>(
&self,
ctx: &RowSerializationContext<'_>,
out: &mut Vec<u8>,
writer: &mut W,
) -> Result<(), SerializationError>;
}

Expand All @@ -38,12 +57,134 @@ impl<T: ValueList> SerializeRow for T {
Ok(())
}

fn serialize(
fn serialize<W: RowWriter>(
&self,
_ctx: &RowSerializationContext<'_>,
out: &mut Vec<u8>,
ctx: &RowSerializationContext<'_>,
writer: &mut W,
) -> Result<(), SerializationError> {
self.write_to_request(out)
.map_err(|err| Arc::new(err) as SerializationError)
serialize_legacy_row(self, ctx, writer)
}
}

/// Adapts a legacy `ValueList` implementation to the new, `RowWriter`-based
/// serialization interface.
///
/// The row is first serialized with the legacy mechanism
/// (`ValueList::serialized`) and the resulting raw values are then re-emitted
/// through `writer`. For positionally-serialized values the original order is
/// preserved; for name/value pairs the values are matched to the columns in
/// `ctx` by name and written in column order.
///
/// # Errors
/// Returns the legacy serializer's error (type-erased), or
/// `ValueListToSerializeRowAdapterError::NoBindMarkerWithName` if a column in
/// `ctx` has no corresponding named value.
pub fn serialize_legacy_row<T: ValueList>(
    r: &T,
    ctx: &RowSerializationContext<'_>,
    writer: &mut impl RowWriter,
) -> Result<(), SerializationError> {
    // Run the legacy serializer first; its error is surfaced as an opaque
    // `SerializationError` (an `Arc<dyn Any>`).
    let serialized =
        <T as ValueList>::serialized(r).map_err(|err| Arc::new(err) as SerializationError)?;

    // Forwards one raw value to the writer, preserving the
    // null / unset / value distinction. The writer's return value (bound as
    // `_proof`) is intentionally discarded here.
    let mut append_value = |value: RawValue| {
        let cell_writer = writer.make_cell_writer();
        let _proof = match value {
            RawValue::Null => cell_writer.set_null(),
            RawValue::Unset => cell_writer.set_unset(),
            RawValue::Value(v) => cell_writer.set_value(v),
        };
    };

    if !serialized.has_names() {
        // Positional values: emit in their original serialization order.
        serialized.iter().for_each(append_value);
    } else {
        // Named values: index them by name, then emit in the column order
        // dictated by the context. The `unwrap()` assumes every pair carries
        // a name when `has_names()` is true — presumably guaranteed by
        // `SerializedValues`' invariants; confirm against its implementation.
        let values_by_name = serialized
            .iter_name_value_pairs()
            .map(|(k, v)| (k.unwrap(), v))
            .collect::<HashMap<_, _>>();

        for col in ctx.columns() {
            let val = values_by_name.get(col.name.as_str()).ok_or_else(|| {
                Arc::new(ValueListToSerializeRowAdapterError::NoBindMarkerWithName {
                    name: col.name.clone(),
                }) as SerializationError
            })?;
            append_value(*val);
        }
    }

    Ok(())
}

/// Errors that can occur while adapting a legacy `ValueList` to the new
/// row serialization interface (`serialize_legacy_row`).
#[derive(Error, Debug)]
pub enum ValueListToSerializeRowAdapterError {
    /// A named value was supplied whose name matches no column/bind marker
    /// in the row serialization context.
    #[error("There is no bind marker with name {name}, but a value for it was provided")]
    NoBindMarkerWithName { name: String },
}

#[cfg(test)]
mod tests {
    use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec};
    use crate::frame::value::{MaybeUnset, SerializedValues, ValueList};
    use crate::types::serialize::BufBackedRowWriter;

    use super::{RowSerializationContext, SerializeRow};

    /// Builds a `ColumnSpec` for a fixed `ks.tbl` table with the given
    /// column name and type.
    fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec {
        ColumnSpec {
            table_spec: TableSpec {
                ks_name: "ks".to_string(),
                table_name: "tbl".to_string(),
            },
            name: name.to_string(),
            typ,
        }
    }

    /// The `SerializeRow` fallback for `ValueList` types must produce exactly
    /// the bytes the legacy `write_to_request` produces, minus the leading
    /// 2-byte value count (which the new writer does not emit).
    #[test]
    fn test_legacy_fallback() {
        // Covers all raw value shapes: a plain value, a string, a null, and
        // an unset value.
        let row = (
            1i32,
            "Ala ma kota",
            None::<i64>,
            MaybeUnset::Unset::<String>,
        );

        let mut legacy_data = Vec::new();
        <_ as ValueList>::write_to_request(&row, &mut legacy_data).unwrap();

        let mut new_data = Vec::new();
        let mut new_data_writer = BufBackedRowWriter::new(&mut new_data);
        let ctx = RowSerializationContext { columns: &[] };
        <_ as SerializeRow>::serialize(&row, &ctx, &mut new_data_writer).unwrap();
        assert_eq!(new_data_writer.value_count(), 4);

        // Skip the value count
        assert_eq!(&legacy_data[2..], new_data);
    }

    /// Named values serialized in an arbitrary order must be re-emitted in
    /// the context's column order, matching the bytes the legacy mechanism
    /// produces for the already-sorted positional row.
    #[test]
    fn test_legacy_fallback_with_names() {
        let sorted_row = (
            1i32,
            "Ala ma kota",
            None::<i64>,
            MaybeUnset::Unset::<String>,
        );

        let mut sorted_row_data = Vec::new();
        <_ as ValueList>::write_to_request(&sorted_row, &mut sorted_row_data).unwrap();

        // Same values, added under names in a deliberately different order
        // ("d" before "c") than the context's column order below.
        let mut unsorted_row = SerializedValues::new();
        unsorted_row.add_named_value("a", &1i32).unwrap();
        unsorted_row.add_named_value("b", &"Ala ma kota").unwrap();
        unsorted_row
            .add_named_value("d", &MaybeUnset::Unset::<String>)
            .unwrap();
        unsorted_row.add_named_value("c", &None::<i64>).unwrap();

        let mut unsorted_row_data = Vec::new();
        let mut unsorted_row_data_writer = BufBackedRowWriter::new(&mut unsorted_row_data);
        let ctx = RowSerializationContext {
            columns: &[
                col_spec("a", ColumnType::Int),
                col_spec("b", ColumnType::Text),
                col_spec("c", ColumnType::BigInt),
                col_spec("d", ColumnType::Ascii),
            ],
        };
        <_ as SerializeRow>::serialize(&unsorted_row, &ctx, &mut unsorted_row_data_writer).unwrap();
        assert_eq!(unsorted_row_data_writer.value_count(), 4);

        // Skip the value count
        assert_eq!(&sorted_row_data[2..], unsorted_row_data);
    }
}
Loading

0 comments on commit 46e33c9

Please sign in to comment.