diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 062f369e7a..d337022623 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1490,7 +1490,7 @@ mod tests { use std::collections::BTreeMap; use crate::frame::response::result::{ColumnType, CqlValue}; - use crate::frame::value::{MaybeUnset, Unset, Value, ValueTooBig}; + use crate::frame::value::{Counter, MaybeUnset, Unset, Value, ValueTooBig}; use crate::types::serialize::value::{ BuiltinSerializationError, BuiltinSerializationErrorKind, BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, MapSerializationErrorKind, MapTypeCheckErrorKind, @@ -2109,6 +2109,95 @@ mod tests { assert_eq!(reference, udt); } + #[test] + fn test_udt_serialization_with_missing_rust_fields_at_end() { + let udt = TestUdtWithFieldSorting::default(); + + let typ_normal = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected fields + ("d".to_string(), ColumnType::Counter), + ("e".to_string(), ColumnType::Counter), + ], + }; + + let result_normal = do_serialize(&udt, &typ_normal); + let result_additional_field = do_serialize(&udt, &typ_unexpected_field); + + assert_eq!(result_normal, result_additional_field); + } + + #[derive(SerializeCql, Debug, PartialEq, Default)] + #[scylla(crate = crate)] + struct TestUdtWithFieldSorting2 { + a: String, + b: i32, + d: Option, + c: Vec, + } + + #[derive(SerializeCql, Debug, PartialEq, Default)] + #[scylla(crate = crate)] + struct TestUdtWithFieldSorting3 { + a: String, + b: i32, + d: Option, + e: Option, + c: Vec, + } + + #[test] + fn test_udt_serialization_with_missing_rust_field_in_middle() { + let udt = TestUdtWithFieldSorting::default(); + let udt2 = TestUdtWithFieldSorting2::default(); + let udt3 = TestUdtWithFieldSorting3::default(); + + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Unexpected fields + ("d".to_string(), ColumnType::Counter), + ("e".to_string(), ColumnType::Float), + // Remaining normal field + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let result_1 = do_serialize(udt, &typ); + let result_2 = do_serialize(udt2, &typ); + let result_3 = do_serialize(udt3, &typ); + + assert_eq!(result_1, result_2); + assert_eq!(result_2, result_3); + } + #[test] fn test_udt_serialization_failing_type_check() { let typ_not_udt = ColumnType::Ascii; @@ -2268,6 +2357,44 @@ mod tests { assert_eq!(reference, udt); } + #[test] + fn test_udt_serialization_with_enforced_order_additional_field() { + let udt = TestUdtWithEnforcedOrder::default(); + + let typ_normal = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let result_normal = do_serialize(&udt, &typ_normal); + let result_additional_field = do_serialize(&udt, &typ_unexpected_field); + + assert_eq!(result_normal, result_additional_field); + } + #[test] fn test_udt_serialization_with_enforced_order_failing_type_check() { let typ_not_udt = ColumnType::Ascii; @@ -2465,4 +2592,104 @@ mod tests { assert_eq!(reference, udt); } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, force_exact_match)] + struct TestStrictUdtWithFieldSorting { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_strict_udt_with_field_sorting_rejects_additional_field() { + let udt = TestStrictUdtWithFieldSorting::default(); + let mut data = Vec::new(); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = udt + .serialize(&typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) + )); + + let typ_unexpected_field_middle = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Unexpected field + ("b_c".to_string(), ColumnType::Counter), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let err = udt + .serialize(&typ_unexpected_field_middle, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) + )); + } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, flavor = "enforce_order", force_exact_match)] + struct TestStrictUdtWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_strict_udt_with_enforced_order_rejects_additional_field() { + let udt = TestStrictUdtWithEnforcedOrder::default(); + let mut data = Vec::new(); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NoSuchFieldInUdt { .. }) + )); + } } diff --git a/scylla/src/transport/cql_types_test.rs b/scylla/src/transport/cql_types_test.rs index 6c05fc90f2..1ab0997728 100644 --- a/scylla/src/transport/cql_types_test.rs +++ b/scylla/src/transport/cql_types_test.rs @@ -1491,3 +1491,207 @@ async fn test_empty() { assert_eq!(empty, CqlValue::Empty); } + +#[tokio::test] +async fn test_udt_with_missing_field() { + let table_name = "udt_tests"; + let type_name = "usertype1"; + + let session: Session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + + session + .query( + format!( + "CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = \ + {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", + ks + ), + &[], + ) + .await + .unwrap(); + session.use_keyspace(ks, false).await.unwrap(); + + session + .query(format!("DROP TABLE IF EXISTS {}", table_name), &[]) + .await + .unwrap(); + + session + .query(format!("DROP TYPE IF EXISTS {}", type_name), &[]) + .await + .unwrap(); + + session + .query( + format!( + "CREATE TYPE IF NOT EXISTS {} (first int, second boolean, third float, fourth blob)", + type_name + ), + &[], + ) + .await + .unwrap(); + + session + .query( + format!( + "CREATE TABLE IF NOT EXISTS {} (id int PRIMARY KEY, val {})", + table_name, type_name + ), + &[], + ) + .await + .unwrap(); + + let mut id = 0; + + async fn verify_insert_select_identity( + session: &Session, + table_name: &str, + id: i32, + element: TQ, + expected: TR, + ) where + TQ: SerializeCql, + TR: FromCqlVal + PartialEq + Debug, + { + session + .query( + format!("INSERT INTO {}(id,val) VALUES (?,?)", table_name), + &(id, &element), + ) + .await + .unwrap(); + let result = session + .query( + format!("SELECT val from {} WHERE id = ?", table_name), + &(id,), + ) + .await + .unwrap() + .rows + .unwrap() + .into_typed::<(TR,)>() + .next() + .unwrap() + .unwrap() + .0; + assert_eq!(expected, result); + } + + #[derive(FromUserType, Debug, PartialEq)] + struct UdtFull { + pub first: i32, + pub second: bool, + pub third: Option, + pub fourth: Option>, + } + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct UdtV1 { + pub first: i32, + pub second: bool, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV1 { + first: 3, + second: true, + }, + UdtFull { + first: 3, + second: true, + third: None, + fourth: None, + }, + ) + .await; + + id += 1; + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct UdtV2 { + pub first: i32, + pub second: bool, + pub third: Option, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV2 { + first: 3, + second: true, + third: Some(123.45), + }, + UdtFull { + first: 3, + second: true, + third: Some(123.45), + fourth: None, + }, + ) + .await; + + id += 1; + + #[derive(SerializeCql)] + #[scylla(crate = crate)] + struct UdtV3 { + pub first: i32, + pub second: bool, + pub fourth: Option>, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV3 { + first: 3, + second: true, + fourth: Some(vec![3, 6, 9]), + }, + UdtFull { + first: 3, + second: true, + third: None, + fourth: Some(vec![3, 6, 9]), + }, + ) + .await; + + id += 1; + + #[derive(SerializeCql)] + #[scylla(crate = crate, flavor="enforce_order")] + struct UdtV4 { + pub first: i32, + pub second: bool, + } + + verify_insert_select_identity( + &session, + table_name, + id, + UdtV4 { + first: 3, + second: true, + }, + UdtFull { + first: 3, + second: true, + third: None, + fourth: None, + }, + ) + .await; +}