diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 84469249db..c9bc8d061b 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -91,9 +91,7 @@ pub use scylla_macros::SerializeCql; /// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait /// which serializes given Rust structure into bind markers for a CQL statement. /// -/// At the moment, only structs with named fields are supported. The generated -/// implementation of the trait will match the struct fields to bind markers/columns -/// by name automatically. +/// At the moment, only structs with named fields are supported. /// /// Serialization will fail if there are some bind markers/columns in the statement /// that don't match to any of the Rust struct fields, _or vice versa_. @@ -125,6 +123,21 @@ pub use scylla_macros::SerializeCql; /// /// # Attributes /// +/// `#[scylla(flavor = "flavor_name")]` +/// +/// Allows to choose one of the possible "flavors", i.e. the way how the +/// generated code will approach serialization. Possible flavors are: +/// +/// - `"match_by_name"` (default) - the generated implementation _does not +/// require_ the fields in the Rust struct to be in the same order as the +/// columns/bind markers. During serialization, the implementation will take +/// care to serialize the fields in the order which the database expects. +/// - `"enforce_order"` - the generated implementation _requires_ the fields +/// in the Rust struct to be in the same order as the columns/bind markers. +/// If the order is incorrect, type checking/serialization will fail. +/// This is a less robust flavor than `"match_by_name"`, but should be +/// slightly more performant as it doesn't need to perform lookups by name. +/// /// `#[scylla(crate = crate_name)]` /// /// By default, the code generated by the derive macro will refer to the items diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 7e5c14d0f4..fd3b307567 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -463,6 +463,12 @@ pub enum BuiltinTypeCheckErrorKind { /// A value required by the statement is not provided by the Rust type. ColumnMissingForValue { name: String }, + /// A different column name was expected at given position. + ColumnNameMismatch { + rust_column_name: String, + db_column_name: String, + }, + /// One of the columns failed to type check. ColumnTypeCheckFailed { name: String, @@ -488,6 +494,10 @@ impl Display for BuiltinTypeCheckErrorKind { "value for column {name} was provided, but there is no bind marker for this column in the query" ) } + BuiltinTypeCheckErrorKind::ColumnNameMismatch { rust_column_name, db_column_name } => write!( + f, + "expected column with name {db_column_name} at given position, but the Rust field name is {rust_column_name}" + ), BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { name, err } => { write!(f, "failed to check column {name}: {err}") } @@ -741,4 +751,100 @@ mod tests { BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { .. } )); } + + #[derive(SerializeRow, Debug, PartialEq, Eq)] + #[scylla(crate = crate, flavor = "enforce_order")] + struct TestRowWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_row_serialization_with_enforced_order_correct_order() { + let spec = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + ]; + + let reference = do_serialize(("Ala ma kota", 42i32, vec![1i64, 2i64, 3i64]), &spec); + let row = do_serialize( + TestRowWithEnforcedOrder { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_with_enforced_order_failing_type_check() { + // The order of two last columns is swapped + let spec = [ + col("a", ColumnType::Text), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + col("b", ColumnType::Int), + ]; + let ctx = RowSerializationContext { columns: &spec }; + let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err(); + let err = err.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnNameMismatch { .. } + )); + + let spec_without_c = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + // Missing column c + ]; + + let ctx = RowSerializationContext { + columns: &spec_without_c, + }; + let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err(); + let err = err.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnMissingForValue { .. } + )); + + let spec_duplicate_column = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + // Unexpected last column + col("d", ColumnType::Counter), + ]; + + let ctx = RowSerializationContext { + columns: &spec_duplicate_column, + }; + let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err(); + let err = err.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::MissingValueForColumn { .. } + )); + + let spec_wrong_type = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::TinyInt), // Wrong type + ]; + + let ctx = RowSerializationContext { + columns: &spec_wrong_type, + }; + let err = TestRowWithEnforcedOrder::preliminary_type_check(&ctx).unwrap_err(); + let err = err.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { .. } + )); + } } diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index 7816397377..ecb901d172 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -3,11 +3,15 @@ use proc_macro::TokenStream; use proc_macro2::Span; use syn::parse_quote; +use super::Flavor; + #[derive(FromAttributes)] #[darling(attributes(scylla))] struct Attributes { #[darling(rename = "crate")] crate_path: Option, + + flavor: Option, } impl Attributes { @@ -36,7 +40,11 @@ pub fn derive_serialize_row(tokens_input: TokenStream) -> Result = match ctx.attributes.flavor { + Some(Flavor::MatchByName) | None => Box::new(ColumnSortingGenerator { ctx: &ctx }), + Some(Flavor::EnforceOrder) => Box::new(ColumnOrderedGenerator { ctx: &ctx }), + }; let preliminary_type_check_item = gen.generate_preliminary_type_check(); let serialize_item = gen.generate_serialize(); @@ -80,13 +88,18 @@ impl Context { } } +trait Generator { + fn generate_preliminary_type_check(&self) -> syn::TraitItemFn; + fn generate_serialize(&self) -> syn::TraitItemFn; +} + // Generates an implementation of the trait which sorts the columns according // to how they are defined in prepared statement metadata. struct ColumnSortingGenerator<'a> { ctx: &'a Context, } -impl<'a> ColumnSortingGenerator<'a> { +impl<'a> Generator for ColumnSortingGenerator<'a> { fn generate_preliminary_type_check(&self) -> syn::TraitItemFn { // Need to: // - Check that all required columns are there and no more @@ -245,3 +258,131 @@ impl<'a> ColumnSortingGenerator<'a> { } } } + +// Generates an implementation of the trait which requires the columns +// to be placed in the same order as they are defined in the struct. +struct ColumnOrderedGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> Generator for ColumnOrderedGenerator<'a> { + fn generate_preliminary_type_check(&self) -> syn::TraitItemFn { + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + statements.push(self.ctx.generate_mk_typck_err()); + + // Create an iterator over fields + statements.push(parse_quote! { + let mut column_iter = ctx.columns().iter(); + }); + + // Go over all fields, check their names and then type check + for field in self.ctx.fields.iter() { + let name = field.ident.as_ref().unwrap().to_string(); + let typ = &field.ty; + statements.push(parse_quote! { + match column_iter.next() { + Some(spec) => { + if spec.name == #name { + match <#typ as #crate_path::SerializeCql>::preliminary_type_check(&spec.typ) { + Ok(()) => {} + Err(err) => { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + } + } else { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnNameMismatch { + rust_column_name: <_ as ::std::string::ToString>::to_string(#name), + db_column_name: <_ as ::std::clone::Clone>::clone(&spec.name), + } + )); + } + } + None => { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnMissingForValue { + name: <_ as ::std::string::ToString>::to_string(#name), + } + )); + } + } + }); + } + + // Check whether there are some columns remaining + statements.push(parse_quote! { + if let Some(spec) = column_iter.next() { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::MissingValueForColumn { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + } + )); + } + }); + + // Concatenate generated code and return + parse_quote! { + fn preliminary_type_check( + ctx: &#crate_path::RowSerializationContext, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } + + fn generate_serialize(&self) -> syn::TraitItemFn { + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_ser_err()); + + // Create an iterator over fields + statements.push(parse_quote! { + let mut column_iter = ctx.columns().iter(); + }); + + // Serialize each field + for field in self.ctx.fields.iter() { + let name = &field.ident; + let typ = &field.ty; + statements.push(parse_quote! { + if let Some(spec) = column_iter.next() { + let cell_writer = <_ as #crate_path::RowWriter>::make_cell_writer(writer); + match <#typ as #crate_path::SerializeCql>::serialize(&self.#name, &spec.typ, cell_writer) { + Ok(_proof) => {}, + Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + } + } + }); + } + + parse_quote! { + fn serialize( + &self, + ctx: &#crate_path::RowSerializationContext, + writer: &mut W, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } +}