diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 6c6f2b7243..0593112668 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -18,8 +18,18 @@ pub use scylla_macros::ValueList; /// /// At the moment, only structs with named fields are supported. /// -/// Serialization will fail if there are some fields in the UDT that don't match -/// to any of the Rust struct fields, _or vice versa_. +/// Serialization will fail if there are some fields in the Rust struct that don't match +/// to any of the UDT fields. +/// +/// If there are fields in UDT that are not present in Rust definition: +/// - serialization will succeed in "match_by_name" flavor (default). Missing +/// fields in the middle of UDT will be sent as NULLs, missing fields at the end will not be sent +/// at all. +/// - serialization will succed if suffix of UDT fields is missing. If there are missing fields in the +/// middle it will fail. Note that if "skip_name_checks" is enabled, and the types happen to match, +/// it is possible for serialization to succeed with unexpected result. +/// This behavior is the default to support ALTERing UDTs by adding new fields. +/// You can require exact match of fields using `force_exact_match` attribute. /// /// In case of failure, either [`BuiltinTypeCheckError`](crate::types::serialize::value::BuiltinTypeCheckError) /// or [`BuiltinSerializationError`](crate::types::serialize::value::BuiltinSerializationError) @@ -42,7 +52,7 @@ pub use scylla_macros::ValueList; /// struct MyUdt { /// a: i32, /// b: Option, -/// c: Vec, +/// // No "c" field - it is not mandatory by default for all fields to be present /// } /// ``` /// @@ -87,7 +97,7 @@ pub use scylla_macros::ValueList; /// macro itself, so in those cases the user must provide an alternative path /// to either the `scylla` or `scylla-cql` crate. /// -/// `#[scylla(skip_name_checks)] +/// `#[scylla(skip_name_checks)]` /// /// _Specific only to the `enforce_order` flavor._ /// @@ -96,6 +106,11 @@ pub use scylla_macros::ValueList; /// struct field names and UDT field names, i.e. it's OK if i-th field has a /// different name in Rust and in the UDT. Fields are still being type-checked. /// +/// `#[scylla(force_exact_match)]` +/// +/// Forces Rust struct to have all the fields present in UDT, otherwise +/// serialization fails. +/// /// # Field attributes /// /// `#[scylla(rename = "name_in_the_udt")]` diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 4b5f9aae27..d337022623 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1490,7 +1490,7 @@ mod tests { use std::collections::BTreeMap; use crate::frame::response::result::{ColumnType, CqlValue}; - use crate::frame::value::{MaybeUnset, Unset, Value, ValueTooBig}; + use crate::frame::value::{Counter, MaybeUnset, Unset, Value, ValueTooBig}; use crate::types::serialize::value::{ BuiltinSerializationError, BuiltinSerializationErrorKind, BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, MapSerializationErrorKind, MapTypeCheckErrorKind, @@ -2109,6 +2109,95 @@ mod tests { assert_eq!(reference, udt); } + #[test] + fn test_udt_serialization_with_missing_rust_fields_at_end() { + let udt = TestUdtWithFieldSorting::default(); + + let typ_normal = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected fields + ("d".to_string(), ColumnType::Counter), + ("e".to_string(), ColumnType::Counter), + ], + }; + + let result_normal = do_serialize(&udt, &typ_normal); + let result_additional_field = do_serialize(&udt, &typ_unexpected_field); + + assert_eq!(result_normal, result_additional_field); + } + + #[derive(SerializeCql, Debug, PartialEq, Default)] + #[scylla(crate = crate)] + struct TestUdtWithFieldSorting2 { + a: String, + b: i32, + d: Option, + c: Vec, + } + + #[derive(SerializeCql, Debug, PartialEq, Default)] + #[scylla(crate = crate)] + struct TestUdtWithFieldSorting3 { + a: String, + b: i32, + d: Option, + e: Option, + c: Vec, + } + + #[test] + fn test_udt_serialization_with_missing_rust_field_in_middle() { + let udt = TestUdtWithFieldSorting::default(); + let udt2 = TestUdtWithFieldSorting2::default(); + let udt3 = TestUdtWithFieldSorting3::default(); + + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Unexpected fields + ("d".to_string(), ColumnType::Counter), + ("e".to_string(), ColumnType::Float), + // Remaining normal field + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let result_1 = do_serialize(udt, &typ); + let result_2 = do_serialize(udt2, &typ); + let result_3 = do_serialize(udt3, &typ); + + assert_eq!(result_1, result_2); + assert_eq!(result_2, result_3); + } + #[test] fn test_udt_serialization_failing_type_check() { let typ_not_udt = ColumnType::Ascii; @@ -2145,30 +2234,6 @@ mod tests { ) )); - let typ_unexpected_field = ColumnType::UserDefinedType { - type_name: "typ".to_string(), - keyspace: "ks".to_string(), - field_types: vec![ - ("a".to_string(), ColumnType::Text), - ("b".to_string(), ColumnType::Int), - ( - "c".to_string(), - ColumnType::List(Box::new(ColumnType::BigInt)), - ), - // Unexpected field - ("d".to_string(), ColumnType::Counter), - ], - }; - - let err = udt - .serialize(&typ_unexpected_field, CellWriter::new(&mut data)) - .unwrap_err(); - let err = err.0.downcast_ref::().unwrap(); - assert!(matches!( - err.kind, - BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) - )); - let typ_wrong_type = ColumnType::UserDefinedType { type_name: "typ".to_string(), keyspace: "ks".to_string(), @@ -2292,6 +2357,44 @@ mod tests { assert_eq!(reference, udt); } + #[test] + fn test_udt_serialization_with_enforced_order_additional_field() { + let udt = TestUdtWithEnforcedOrder::default(); + + let typ_normal = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let result_normal = do_serialize(&udt, &typ_normal); + let result_additional_field = do_serialize(&udt, &typ_unexpected_field); + + assert_eq!(result_normal, result_additional_field); + } + #[test] fn test_udt_serialization_with_enforced_order_failing_type_check() { let typ_not_udt = ColumnType::Ascii; @@ -2349,30 +2452,6 @@ mod tests { ) )); - let typ_unexpected_field = ColumnType::UserDefinedType { - type_name: "typ".to_string(), - keyspace: "ks".to_string(), - field_types: vec![ - ("a".to_string(), ColumnType::Text), - ("b".to_string(), ColumnType::Int), - ( - "c".to_string(), - ColumnType::List(Box::new(ColumnType::BigInt)), - ), - // Unexpected field - ("d".to_string(), ColumnType::Counter), - ], - }; - - let err = - <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) - .unwrap_err(); - let err = err.0.downcast_ref::().unwrap(); - assert!(matches!( - err.kind, - BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) - )); - let typ_unexpected_field = ColumnType::UserDefinedType { type_name: "typ".to_string(), keyspace: "ks".to_string(), @@ -2513,4 +2592,104 @@ mod tests { assert_eq!(reference, udt); } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, force_exact_match)] + struct TestStrictUdtWithFieldSorting { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_strict_udt_with_field_sorting_rejects_additional_field() { + let udt = TestStrictUdtWithFieldSorting::default(); + let mut data = Vec::new(); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = udt + .serialize(&typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) + )); + + let typ_unexpected_field_middle = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Unexpected field + ("b_c".to_string(), ColumnType::Counter), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let err = udt + .serialize(&typ_unexpected_field_middle, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) + )); + } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, flavor = "enforce_order", force_exact_match)] + struct TestStrictUdtWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_strict_udt_with_enforced_order_rejects_additional_field() { + let udt = TestStrictUdtWithEnforcedOrder::default(); + let mut data = Vec::new(); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) + )); + } } diff --git a/scylla-macros/src/serialize/cql.rs b/scylla-macros/src/serialize/cql.rs index 1aa9d05835..3ba74e671e 100644 --- a/scylla-macros/src/serialize/cql.rs +++ b/scylla-macros/src/serialize/cql.rs @@ -18,6 +18,9 @@ struct Attributes { #[darling(default)] skip_name_checks: bool, + + #[darling(default)] + force_exact_match: bool, } impl Attributes { @@ -216,6 +219,37 @@ impl<'a> Generator for FieldSortingGenerator<'a> { let udt_field_names = rust_field_names.clone(); // For now, it's the same let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + let missing_rust_field_expression: syn::Expr = if self.ctx.attributes.force_exact_match { + parse_quote! { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::NoSuchFieldInUdt { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )) + } + } else { + parse_quote! { + skipped_fields += 1 + } + }; + + let serialize_missing_nulls_statement: syn::Stmt = if self.ctx.attributes.force_exact_match + { + // Not sure if there is better way to create no-op statement + // parse_quote!{} / parse_quote!{ ; } doesn't work + parse_quote! { + (); + } + } else { + parse_quote! { + while skipped_fields > 0 { + let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder); + sub_builder.set_null(); + skipped_fields -= 1; + } + } + }; + // Declare helper lambdas for creating errors statements.push(self.ctx.generate_mk_typck_err()); statements.push(self.ctx.generate_mk_ser_err()); @@ -241,6 +275,16 @@ impl<'a> Generator for FieldSortingGenerator<'a> { let mut remaining_count = #field_count; }); + // We want to send nulls for missing rust fields in the middle, but send + // nothing for those fields at the end of UDT. While executing the loop + // we don't know if there will be any more present fields. The solution is + // to count how many fields we missed and send them when we find any present field. + if !self.ctx.attributes.force_exact_match { + statements.push(parse_quote! { + let mut skipped_fields = 0; + }); + } + // Turn the cell writer into a value builder statements.push(parse_quote! { let mut builder = #crate_path::CellWriter::into_value_builder(writer); @@ -253,6 +297,7 @@ impl<'a> Generator for FieldSortingGenerator<'a> { match ::std::string::String::as_str(field_name) { #( #udt_field_names => { + #serialize_missing_nulls_statement let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder); match <#field_types as #crate_path::SerializeCql>::serialize(&self.#rust_field_idents, field_type, sub_builder) { ::std::result::Result::Ok(_proof) => {} @@ -271,11 +316,7 @@ impl<'a> Generator for FieldSortingGenerator<'a> { } } )* - _ => return ::std::result::Result::Err(mk_typck_err( - #crate_path::UdtTypeCheckErrorKind::NoSuchFieldInUdt { - field_name: <_ as ::std::clone::Clone>::clone(field_name), - } - )), + _ => #missing_rust_field_expression, } } }); @@ -396,16 +437,18 @@ impl<'a> Generator for FieldOrderedGenerator<'a> { }); } - // Check whether there are some fields remaining - statements.push(parse_quote! { - if let Some((field_name, typ)) = field_iter.next() { - return ::std::result::Result::Err(mk_typck_err( - #crate_path::UdtTypeCheckErrorKind::NoSuchFieldInUdt { - field_name: <_ as ::std::clone::Clone>::clone(field_name), - } - )); - } - }); + if self.ctx.attributes.force_exact_match { + // Check whether there are some fields remaining + statements.push(parse_quote! { + if let Some((field_name, typ)) = field_iter.next() { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::NoSuchFieldInUdt { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + }); + } parse_quote! { fn serialize<'b>( diff --git a/scylla/src/transport/cql_types_test.rs b/scylla/src/transport/cql_types_test.rs index 6c05fc90f2..1ab0997728 100644 --- a/scylla/src/transport/cql_types_test.rs +++ b/scylla/src/transport/cql_types_test.rs @@ -1491,3 +1491,207 @@ async fn test_empty() { assert_eq!(empty, CqlValue::Empty); } + +#[tokio::test] +async fn test_udt_with_missing_field() { + let table_name = "udt_tests"; + let type_name = "usertype1"; + + let session: Session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + + session + .query( + format!( + "CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = \ + {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", + ks + ), + &[], + ) + .await + .unwrap(); + session.use_keyspace(ks, false).await.unwrap(); + + session + .query(format!("DROP TABLE IF EXISTS {}", table_name), &[]) + .await + .unwrap(); + + session + .query(format!("DROP TYPE IF EXISTS {}", type_name), &[]) + .await + .unwrap(); + + session + .query( + format!( + "CREATE TYPE IF NOT EXISTS {} (first int, second boolean, third float, fourth blob)", + type_name + ), + &[], + ) + .await + .unwrap(); + + session + .query( + format!( + "CREATE TABLE IF NOT EXISTS {} (id int PRIMARY KEY, val {})", + table_name, type_name + ), + &[], + ) + .await + .unwrap(); + + let mut id = 0; + + async fn verify_insert_select_identity( + session: &Session, + table_name: &str, + id: i32, + element: TQ, + expected: TR, + ) where + TQ: SerializeCql, + TR: FromCqlVal + PartialEq + Debug, + { + session + .query( + format!("INSERT INTO {}(id,val) VALUES (?,?)", table_name), + &(id, &element), + ) + .await + .unwrap(); + let result = session + .query( + format!("SELECT val from {} WHERE id = ?", table_name), + &(id,), + ) + .await + .unwrap() + .rows + .unwrap() + .into_typed::<(TR,)>() + .next() + .unwrap() + .unwrap() + .0; + assert_eq!(expected, result); + } + + #[derive(FromUserType, Debug, PartialEq)] + struct UdtFull { + pub first: i32, + pub second: bool, + pub third: Option, + pub fourth: Option>, + } + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct UdtV1 { + pub first: i32, + pub second: bool, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV1 { + first: 3, + second: true, + }, + UdtFull { + first: 3, + second: true, + third: None, + fourth: None, + }, + ) + .await; + + id += 1; + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct UdtV2 { + pub first: i32, + pub second: bool, + pub third: Option, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV2 { + first: 3, + second: true, + third: Some(123.45), + }, + UdtFull { + first: 3, + second: true, + third: Some(123.45), + fourth: None, + }, + ) + .await; + + id += 1; + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct UdtV3 { + pub first: i32, + pub second: bool, + pub fourth: Option>, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV3 { + first: 3, + second: true, + fourth: Some(vec![3, 6, 9]), + }, + UdtFull { + first: 3, + second: true, + third: None, + fourth: Some(vec![3, 6, 9]), + }, + ) + .await; + + id += 1; + + #[derive(SerializeCql)] + #[scylla(crate = crate, flavor="enforce_order")] + struct UdtV4 { + pub first: i32, + pub second: bool, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV4 { + first: 3, + second: true, + }, + UdtFull { + first: 3, + second: true, + third: None, + fourth: None, + }, + ) + .await; +}