Skip to content

Commit

Permalink
Support Arrow type LargeUtf8. (#341)
Browse files Browse the repository at this point in the history
* support LargeUtf8

* lint

* fix tests

* Fix tests check_generic_byte_roundtrip

* fix test

* fix clippy
  • Loading branch information
Jeadie authored Jul 10, 2024
1 parent 1c5e7cd commit 88dd455
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 13 deletions.
2 changes: 1 addition & 1 deletion crates/duckdb/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl From<::std::ffi::NulError> for Error {
}
}

const UNKNOWN_COLUMN: usize = std::usize::MAX;
const UNKNOWN_COLUMN: usize = usize::MAX;

/// The conversion isn't precise, but it's convenient to have it
/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`.
Expand Down
9 changes: 3 additions & 6 deletions crates/duckdb/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,7 @@ impl fmt::Display for Type {
mod test {
use super::Value;
use crate::{params, Connection, Error, Result, Statement};
use std::{
f64::EPSILON,
os::raw::{c_double, c_int},
};
use std::os::raw::{c_double, c_int};

fn checked_memory_handle() -> Result<Connection> {
let db = Connection::open_in_memory()?;
Expand Down Expand Up @@ -385,7 +382,7 @@ mod test {
assert_eq!(vec![1, 2], row.get::<_, Vec<u8>>(0)?);
assert_eq!("text", row.get::<_, String>(1)?);
assert_eq!(1, row.get::<_, c_int>(2)?);
assert!((1.5 - row.get::<_, c_double>(3)?).abs() < EPSILON);
assert!((1.5 - row.get::<_, c_double>(3)?).abs() < f64::EPSILON);
assert_eq!(row.get::<_, Option<c_int>>(4)?, None);
assert_eq!(row.get::<_, Option<c_double>>(4)?, None);
assert_eq!(row.get::<_, Option<String>>(4)?, None);
Expand Down Expand Up @@ -453,7 +450,7 @@ mod test {
assert_eq!(Value::Text(String::from("text")), row.get::<_, Value>(1)?);
assert_eq!(Value::Int(1), row.get::<_, Value>(2)?);
match row.get::<_, Value>(3)? {
Value::Float(val) => assert!((1.5 - val).abs() < EPSILON as f32),
Value::Float(val) => assert!((1.5 - val).abs() < f32::EPSILON),
x => panic!("Invalid Value {x:?}"),
}
assert_eq!(Value::Null, row.get::<_, Value>(4)?);
Expand Down
78 changes: 72 additions & 6 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array,
as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray,
GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -229,6 +229,15 @@ pub fn record_batch_to_duckdb_data_chunk(
DataType::Utf8 => {
string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i));
}
DataType::LargeUtf8 => {
string_array_to_vector(
col.as_ref()
.as_any()
.downcast_ref::<LargeStringArray>()
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to LargeStringArray"))?,
&mut chunk.flat_vector(i),
);
}
DataType::Binary => {
binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
}
Expand Down Expand Up @@ -453,7 +462,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
}
}

fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) {
fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());

// TODO: zero copy assignment
Expand Down Expand Up @@ -612,12 +621,12 @@ mod test {
use arrow::{
array::{
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, LargeStringArray, ListArray,
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Fields, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -784,6 +793,48 @@ mod test {
Ok(())
}

fn check_generic_byte_roundtrip<T1, T2>(
arry_in: GenericByteArray<T1>,
arry_out: GenericByteArray<T2>,
) -> Result<(), Box<dyn Error>>
where
T1: ByteArrayType,
T2: ByteArrayType,
{
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", arry_in.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry_in.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);

assert!(
output_any_array.data_type().equals_datatype(arry_out.data_type()),
"{} != {}",
output_any_array.data_type(),
arry_out.data_type()
);

match output_any_array.as_bytes_opt::<T2>() {
Some(output_array) => {
assert_eq!(output_array.len(), arry_out.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), arry_out.is_valid(i));
assert_eq!(output_array.value_data(), arry_out.value_data())
}
}
None => panic!("Expected GenericByteArray"),
}

Ok(())
}

#[test]
fn test_array_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_array_roundtrip(ListArray::new(
Expand Down Expand Up @@ -862,6 +913,21 @@ mod test {
Ok(())
}

#[test]
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_byte_roundtrip(
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
)?;

// [`LargeStringArray`] will be downcasted to [`StringArray`].
check_generic_byte_roundtrip(
LargeStringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
)?;
Ok(())
}

#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;
Expand Down

0 comments on commit 88dd455

Please sign in to comment.