Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Apr 17, 2024
2 parents a033f66 + 3f1ea0a commit 1bd167f
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 15 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ extensions-full = ["httpfs", "json", "parquet", "vtab-full"]
buildtime_bindgen = ["libduckdb-sys/buildtime_bindgen"]
modern-full = ["chrono", "serde_json", "url", "r2d2", "uuid", "polars"]
polars = ["dep:polars"]
chrono = ["dep:chrono", "num-integer"]

[dependencies]
# time = { version = "0.3.2", features = ["formatting", "parsing"], optional = true }
Expand All @@ -52,14 +53,15 @@ memchr = "2.3"
uuid = { version = "1.0", optional = true }
smallvec = "1.6.1"
cast = { version = "0.3", features = ["std"] }
arrow = { version = "50", default-features = false, features = ["prettyprint", "ffi"] }
arrow = { version = "51", default-features = false, features = ["prettyprint", "ffi"] }
rust_decimal = "1.14"
strum = { version = "0.25", features = ["derive"] }
r2d2 = { version = "0.8.9", optional = true }
calamine = { version = "0.22.0", optional = true }
num = { version = "0.4", optional = true, default-features = false, features = ["std"] }
duckdb-loadable-macros = { version = "0.1.1", path="./duckdb-loadable-macros", optional = true }
polars = { version = "0.35.4", features = ["dtype-full"], optional = true}
num-integer = {version = "0.1.46", optional = true}

[dev-dependencies]
doc-comment = "0.3"
Expand Down
22 changes: 22 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ doc_comment::doctest!("../README.md");

#[cfg(test)]
mod test {
use crate::types::Value;

use super::*;
use std::{error::Error as StdError, fmt};

Expand Down Expand Up @@ -1297,6 +1299,26 @@ mod test {
Ok(())
}

#[test]
fn round_trip_interval() -> Result<()> {
let db = checked_memory_handle();
db.execute_batch("CREATE TABLE foo (t INTERVAL);")?;

let d = Value::Interval {
months: 1,
days: 2,
nanos: 3,
};
db.execute("INSERT INTO foo VALUES (?)", [d])?;

let mut stmt = db.prepare("SELECT t FROM foo")?;
let mut rows = stmt.query([])?;
let row = rows.next()?.unwrap();
let d: Value = row.get_unwrap(0);
assert_eq!(d, d);
Ok(())
}

#[test]
fn test_database_name_to_string() -> Result<()> {
assert_eq!(DatabaseName::Main.to_string(), "main");
Expand Down
33 changes: 23 additions & 10 deletions src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use std::{convert, sync::Arc};
use super::{Error, Result, Statement};
use crate::types::{self, FromSql, FromSqlError, ValueRef};

use arrow::array::{ArrayRef, ListArray};
use arrow::{
array::{self, Array, StructArray},
array::{self, Array, ArrayRef, ListArray, StructArray},
datatypes::*,
};
use fallible_iterator::FallibleIterator;
Expand Down Expand Up @@ -547,15 +546,29 @@ impl<'stmt> Row<'stmt> {
}
ValueRef::Time64(types::TimeUnit::Microsecond, array.value(row))
}
DataType::Interval(unit) => match unit {
IntervalUnit::MonthDayNano => {
let array = column
.as_any()
.downcast_ref::<array::IntervalMonthDayNanoArray>()
.unwrap();

if array.is_null(row) {
return ValueRef::Null;
}

let value = array.value(row);

// TODO: remove this manual conversion once arrow-rs bug is fixed
let months = (value) as i32;
let days = (value >> 32) as i32;
let nanos = (value >> 64) as i64;

ValueRef::Interval { months, days, nanos }
}
_ => unimplemented!("{:?}", unit),
},
// TODO: support more data types
// DataType::Interval(unit) => match unit {
// IntervalUnit::DayTime => {
// make_string_interval_day_time!(column, row)
// }
// IntervalUnit::YearMonth => {
// make_string_interval_year_month!(column, row)
// }
// },
// DataType::List(_) => make_string_from_list!(column, row),
// DataType::Dictionary(index_type, _value_type) => match **index_type {
// DataType::Int8 => dict_array_value_to_string::<Int8Type>(column, row),
Expand Down
4 changes: 4 additions & 0 deletions src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,10 @@ impl Statement<'_> {
};
ffi::duckdb_bind_timestamp(ptr, col as u64, ffi::duckdb_timestamp { micros })
},
ValueRef::Interval { months, days, nanos } => unsafe {
let micros = nanos / 1_000;
ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros })
},
_ => unreachable!("not supported: {}", value.data_type()),
};
result_from_duckdb_prepare(rc, ptr)
Expand Down
20 changes: 19 additions & 1 deletion src/test_all_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ fn test_all_types() -> crate::Result<()> {
// union is currently blocked by https://github.com/duckdb/duckdb/pull/11326
"union",
// these remaining types are not yet supported by duckdb-rs
"interval",
"small_enum",
"medium_enum",
"large_enum",
Expand Down Expand Up @@ -331,6 +330,25 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
1 => assert_eq!(value, ValueRef::Blob(&[3, 245])),
_ => assert_eq!(value, ValueRef::Null),
},
"interval" => match idx {
0 => assert_eq!(
value,
ValueRef::Interval {
months: 0,
days: 0,
nanos: 0
}
),
1 => assert_eq!(
value,
ValueRef::Interval {
months: 999,
days: 999,
nanos: 999999999000
}
),
_ => assert_eq!(value, ValueRef::Null),
},
_ => todo!("{column:?}"),
}
}
Expand Down
80 changes: 77 additions & 3 deletions src/types/chrono.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types.
use chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use num_integer::Integer;

use crate::{
types::{FromSql, FromSqlError, FromSqlResult, TimeUnit, ToSql, ToSqlOutput, ValueRef},
Result,
};

use super::Value;

/// ISO 8601 calendar date without timezone => "YYYY-MM-DD"
impl ToSql for NaiveDate {
#[inline]
Expand Down Expand Up @@ -126,13 +129,55 @@ impl FromSql for DateTime<Local> {
}
}

impl FromSql for Duration {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
match value {
ValueRef::Interval { months, days, nanos } => {
let days = days + (months * 30);
let (additional_seconds, nanos) = nanos.div_mod_floor(&NANOS_PER_SECOND);
let seconds = additional_seconds + (i64::from(days) * 24 * 3600);

match nanos.try_into() {
Ok(nanos) => {
if let Some(duration) = Duration::new(seconds, nanos) {
Ok(duration)
} else {
Err(FromSqlError::Other("Invalid duration".into()))
}
}
Err(err) => Err(FromSqlError::Other(format!("Invalid duration: {err}").into())),
}
}
_ => Err(FromSqlError::InvalidType),
}
}
}

const DAYS_PER_MONTH: i64 = 30;
const SECONDS_PER_DAY: i64 = 24 * 3600;
const NANOS_PER_SECOND: i64 = 1_000_000_000;
const NANOS_PER_DAY: i64 = SECONDS_PER_DAY * NANOS_PER_SECOND;

impl ToSql for Duration {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
let nanos = self.num_nanoseconds().unwrap();
let (days, nanos) = nanos.div_mod_floor(&NANOS_PER_DAY);
let (months, days) = days.div_mod_floor(&DAYS_PER_MONTH);
Ok(ToSqlOutput::Owned(Value::Interval {
months: months.try_into().unwrap(),
days: days.try_into().unwrap(),
nanos,
}))
}
}

#[cfg(test)]
mod test {
use crate::{
types::{FromSql, ValueRef},
types::{FromSql, ToSql, ToSqlOutput, ValueRef},
Connection, Result,
};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, TimeZone, Utc};

fn checked_memory_handle() -> Result<Connection> {
let db = Connection::open_in_memory()?;
Expand Down Expand Up @@ -216,6 +261,35 @@ mod test {
Ok(())
}

#[test]
fn test_time_delta_roundtrip() {
roundtrip_type(TimeDelta::new(3600, 0).unwrap());
roundtrip_type(TimeDelta::new(3600, 1000).unwrap());
}

#[test]
fn test_time_delta() -> Result<()> {
let db = checked_memory_handle()?;
let td = TimeDelta::new(3600, 0).unwrap();

let row: Result<TimeDelta> = db.query_row("SELECT ?", [td], |row| Ok(row.get(0)))?;

assert_eq!(row.unwrap(), td);

Ok(())
}

fn roundtrip_type<T: FromSql + ToSql + Eq + std::fmt::Debug>(td: T) {
let sqled = td.to_sql().unwrap();
let value = match sqled {
ToSqlOutput::Borrowed(v) => v,
ToSqlOutput::Owned(ref v) => ValueRef::from(v),
};
let reversed = FromSql::column_result(value).unwrap();

assert_eq!(td, reversed);
}

#[test]
fn test_date_time_local() -> Result<()> {
let db = checked_memory_handle()?;
Expand Down
3 changes: 3 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ pub enum Type {
Date32,
/// TIME64
Time64,
/// INTERVAL
Interval,
/// LIST
List(Box<Type>),
/// Any
Expand Down Expand Up @@ -215,6 +217,7 @@ impl fmt::Display for Type {
Type::Blob => f.pad("Blob"),
Type::Date32 => f.pad("Date32"),
Type::Time64 => f.pad("Time64"),
Type::Interval => f.pad("Interval"),
Type::List(..) => f.pad("List"),
Type::Any => f.pad("Any"),
}
Expand Down
10 changes: 10 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ pub enum Value {
Date32(i32),
/// The value is a time64
Time64(TimeUnit, i64),
/// The value is an interval (month, day, nano)
Interval {
/// months
months: i32,
/// days
days: i32,
/// nanos
nanos: i64,
},
/// The value is a list
List(Vec<Value>),
}
Expand Down Expand Up @@ -214,6 +223,7 @@ impl Value {
Value::Blob(_) => Type::Blob,
Value::Date32(_) => Type::Date32,
Value::Time64(..) => Type::Time64,
Value::Interval { .. } => Type::Interval,
Value::List(_) => todo!(),
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/types/value_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ pub enum ValueRef<'a> {
Date32(i32),
/// The value is a time64
Time64(TimeUnit, i64),
/// The value is an interval (month, day, nano)
Interval {
/// months
months: i32,
/// days
days: i32,
/// nanos
nanos: i64,
},
/// The value is a list
List(&'a ListArray, usize),
}
Expand Down Expand Up @@ -92,6 +101,7 @@ impl ValueRef<'_> {
ValueRef::Blob(_) => Type::Blob,
ValueRef::Date32(_) => Type::Date32,
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
ValueRef::List(arr, _) => arr.data_type().into(),
}
}
Expand Down Expand Up @@ -151,6 +161,7 @@ impl From<ValueRef<'_>> for Value {
ValueRef::Blob(b) => Value::Blob(b.to_vec()),
ValueRef::Date32(d) => Value::Date32(d),
ValueRef::Time64(t, d) => Value::Time64(t, d),
ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos },
ValueRef::List(items, idx) => {
let offsets = items.offsets();
let range = offsets[idx]..offsets[idx + 1];
Expand Down Expand Up @@ -200,6 +211,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Blob(ref b) => ValueRef::Blob(b),
Value::Date32(d) => ValueRef::Date32(d),
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
Value::List(..) => unimplemented!(),
}
}
Expand Down

0 comments on commit 1bd167f

Please sign in to comment.