Skip to content

Commit

Permalink
libsql: port row deserializer from rust client
Browse files Browse the repository at this point in the history
  • Loading branch information
pjhades committed Nov 21, 2023
1 parent 8045a45 commit a38e6ef
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 0 deletions.
99 changes: 99 additions & 0 deletions libsql/src/deserialize_row.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use crate::{Row, Value};
use serde::de::{
value::{Error as DeError, SeqDeserializer},
Error, IntoDeserializer, MapAccess, Visitor,
};
use serde::{Deserialize, Deserializer};

struct RowDeserializer<'de> {
row: &'de Row,
}

impl<'de> Deserializer<'de> for RowDeserializer<'de> {
type Error = DeError;

fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(DeError::custom("Expects a struct"))
}

fn deserialize_struct<V>(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
struct RowMapAccess<'a> {
row: &'a Row,
idx: std::ops::Range<usize>,
value: Option<Value>,
}

impl<'de> MapAccess<'de> for RowMapAccess<'de> {
type Error = DeError;

fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: serde::de::DeserializeSeed<'de>,
{
match self.idx.next() {
None => Ok(None),
Some(i) => {
let value = self
.row
.get_value(i as i32)
.map_err(|e| DeError::custom(e))?;
self.value = Some(value);
self.row
.column_name(i as i32)
.map(|name| seed.deserialize(name.into_deserializer()))
.transpose()
}
}
}

fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where
V: serde::de::DeserializeSeed<'de>,
{
let value = self
.value
.take()
.ok_or(DeError::custom("Expects a value but row is exhausted"))?;

match value {
Value::Text(value) => seed.deserialize(value.into_deserializer()),
Value::Null => seed.deserialize(().into_deserializer()),
Value::Integer(value) => seed.deserialize(value.into_deserializer()),
Value::Real(value) => seed.deserialize(value.into_deserializer()),
Value::Blob(value) => {
let seq = SeqDeserializer::new(value.iter().cloned());
seed.deserialize(seq)
}
}
}
}

visitor.visit_map(RowMapAccess {
row: self.row,
idx: 0..self.row.inner.column_count(),
value: None,
})
}

serde::forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map enum identifier ignored_any
}
}

pub fn from_row<'de, T: Deserialize<'de>>(row: &'de Row) -> Result<T, DeError> {
let de = RowDeserializer { row };
T::deserialize(de)
}
4 changes: 4 additions & 0 deletions libsql/src/hrana/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ impl RowInner for Row {
Err(crate::Error::ColumnNotFound(idx))
}
}

fn column_count(&self) -> usize {
self.cols.len()
}
}

fn bind_params(params: Params, stmt: &mut Stmt) {
Expand Down
3 changes: 3 additions & 0 deletions libsql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ mod statement;
mod transaction;
mod value;

#[cfg(feature = "serde")]
pub mod deserialize_row;

pub use value::{Value, ValueRef, ValueType};

cfg_hrana! {
Expand Down
4 changes: 4 additions & 0 deletions libsql/src/local/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ impl RowInner for LibsqlRow {
fn column_type(&self, idx: i32) -> Result<ValueType> {
self.0.column_type(idx).map(ValueType::from)
}

fn column_count(&self) -> usize {
self.0.stmt.column_count()
}
}

impl fmt::Debug for LibsqlRow {
Expand Down
4 changes: 4 additions & 0 deletions libsql/src/replication/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,10 @@ impl RowInner for RemoteRow {
.map(ValueType::from)
.ok_or(Error::InvalidColumnType)
}

fn column_count(&self) -> usize {
self.1.len()
}
}

pub(super) struct RemoteTx(pub(super) Option<RemoteConnection>);
Expand Down
1 change: 1 addition & 0 deletions libsql/src/rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,5 @@ pub(crate) trait RowInner: fmt::Debug {
fn column_str(&self, idx: i32) -> Result<&str>;
fn column_name(&self, idx: i32) -> Option<&str>;
fn column_type(&self, idx: i32) -> Result<ValueType>;
fn column_count(&self) -> usize;
}
41 changes: 41 additions & 0 deletions libsql/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,44 @@ async fn debug_print_row() {
"{Some(\"id\"): (Integer, 123), Some(\"name\"): (Text, \"potato\"), Some(\"score\"): (Real, 3.14), Some(\"data\"): (Blob, 4), Some(\"age\"): (Null, ())}"
);
}

#[cfg(feature = "serde")]
#[tokio::test]
async fn deserialize_row() {
let db = Database::open(":memory:").unwrap();
let conn = db.connect().unwrap();
let _ = conn
.execute(
"CREATE TABLE users (id INTEGER, name TEXT, score REAL, data BLOB, age INTEGER)",
(),
)
.await;
conn.execute("INSERT INTO users (id, name, score, data, age) VALUES (123, \"potato\", 3.14, X'deadbeef', NULL)", ())
.await
.unwrap();

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Data {
id: i64,
name: String,
score: f64,
data: Vec<u8>,
age: (),
}

let row = conn
.query("SELECT * FROM users", ())
.await
.unwrap()
.next()
.unwrap()
.unwrap();
let data: Data = libsql::deserialize_row::from_row(&row).unwrap();
assert_eq!(data.id, 123);
assert_eq!(data.name, "potato".to_string());
assert_eq!(data.score, 3.14);
assert_eq!(data.data, vec![0xde, 0xad, 0xbe, 0xef]);
assert_eq!(data.age, ());
}

0 comments on commit a38e6ef

Please sign in to comment.