Skip to content

Commit

Permalink
fix: arrow vtab panic (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mause authored Apr 17, 2024
1 parent 2e638f1 commit d110924
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 37 deletions.
2 changes: 1 addition & 1 deletion libduckdb-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ serde_json = { version = "1.0" }
tar = "0.4.38"

[dev-dependencies]
arrow = { version = "49", default-features = false, features = ["ffi"] }
arrow = { version = "51", default-features = false, features = ["ffi"] }
112 changes: 76 additions & 36 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{
vector::{FlatVector, ListVector, Vector},
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
};
use std::ptr::null_mut;

use crate::vtab::vector::Inserter;
use arrow::array::{
Expand Down Expand Up @@ -74,8 +75,11 @@ impl VTab for ArrowVTab {
type InitData = ArrowInitData;

unsafe fn bind(bind: &BindInfo, data: *mut ArrowBindData) -> Result<(), Box<dyn std::error::Error>> {
(*data).rb = null_mut();
let param_count = bind.get_parameter_count();
assert!(param_count == 2);
if param_count != 2 {
return Err(format!("Bad param count: {param_count}, expected 2").into());
}
let array = bind.get_parameter(0).to_int64();
let schema = bind.get_parameter(1).to_int64();
unsafe {
Expand Down Expand Up @@ -106,6 +110,7 @@ impl VTab for ArrowVTab {
output.set_len(0);
} else {
let rb = Box::from_raw((*bind_info).rb);
(*bind_info).rb = null_mut(); // erase ref in case of failure in record_batch_to_duckdb_data_chunk
record_batch_to_duckdb_data_chunk(&rb, output)?;
(*bind_info).rb = Box::into_raw(rb);
(*init_info).done = true;
Expand Down Expand Up @@ -156,7 +161,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List,
DataType::Struct(_) => Struct,
DataType::Union(_, _) => Union,
DataType::Dictionary(_, _) => todo!(),
// DataType::Dictionary(_, _) => todo!(),
// duckdb/src/main/capi/helper-c.cpp does not support decimal
// DataType::Decimal128(_, _) => Decimal,
// DataType::Decimal256(_, _) => Decimal,
Expand Down Expand Up @@ -194,8 +199,9 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
} else if let DataType::FixedSizeList(child, _) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else {
unimplemented!(
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
Err(
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
.into(),
)
}
}
Expand All @@ -216,30 +222,31 @@ pub fn record_batch_to_duckdb_data_chunk(
let col = batch.column(i);
match col.data_type() {
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
primitive_array_to_vector(col, &mut chunk.flat_vector(i));
primitive_array_to_vector(col, &mut chunk.flat_vector(i))?;
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i));
}
DataType::List(_) => {
list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i));
list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
}
DataType::LargeList(_) => {
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i));
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
}
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i));
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
}
DataType::Struct(_) => {
let struct_array = as_struct_array(col.as_ref());
let mut struct_vector = chunk.struct_vector(i);
struct_array_to_vector(struct_array, &mut struct_vector);
struct_array_to_vector(struct_array, &mut struct_vector)?;
}
_ => {
unimplemented!(
return Err(format!(
"column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs",
batch.schema().field(i)
);
)
.into());
}
}
}
Expand All @@ -262,7 +269,7 @@ fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box<dyn std::error::Error>> {
match array.data_type() {
DataType::Boolean => {
boolean_array_to_vector(as_boolean_array(array), out.as_mut_any().downcast_mut().unwrap());
Expand Down Expand Up @@ -315,7 +322,6 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Float16 => todo!("Float16 is not supported yet"),
DataType::Float32 => {
primitive_array_to_flat_vector::<Float32Type>(
as_primitive_array(array),
Expand All @@ -337,7 +343,6 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),

// DuckDB Only supports timetamp_tz in microsecond precision
DataType::Timestamp(_, Some(tz)) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
Expand Down Expand Up @@ -376,11 +381,9 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
DataType::Time64(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
_ => todo!(
"Converting '{dtype:#?}' to primitive flat vector is not supported",
dtype = array.data_type()
),
datatype => return Err(format!("Data type \"{datatype}\" not yet supported by ArrowVTab").into()),
}
Ok(())
}

/// Convert Arrow [Decimal128Array] to a duckdb vector.
Expand Down Expand Up @@ -410,31 +413,38 @@ fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) {
}
}

fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(array: &GenericListArray<O>, out: &mut ListVector) {
fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
array: &GenericListArray<O>,
out: &mut ListVector,
) -> Result<(), Box<dyn std::error::Error>> {
let value_array = array.values();
let mut child = out.child(value_array.len());
match value_array.data_type() {
dt if dt.is_primitive() => {
primitive_array_to_vector(value_array.as_ref(), &mut child);
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
for i in 0..array.len() {
let offset = array.value_offsets()[i];
let length = array.value_length(i);
out.set_entry(i, offset.as_(), length.as_());
}
}
_ => {
println!("Nested list is not supported yet.");
todo!()
return Err("Nested list is not supported yet.".into());
}
}

Ok(())
}

fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVector) {
fn fixed_size_list_array_to_vector(
array: &FixedSizeListArray,
out: &mut ListVector,
) -> Result<(), Box<dyn std::error::Error>> {
let value_array = array.values();
let mut child = out.child(value_array.len());
match value_array.data_type() {
dt if dt.is_primitive() => {
primitive_array_to_vector(value_array.as_ref(), &mut child);
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
for i in 0..array.len() {
let offset = array.value_offset(i);
let length = array.value_length();
Expand All @@ -443,10 +453,11 @@ fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVec
out.set_len(value_array.len());
}
_ => {
println!("Nested list is not supported yet.");
todo!()
return Err("Nested list is not supported yet.".into());
}
}

Ok(())
}

/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
Expand All @@ -455,32 +466,32 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
}

fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result<(), Box<dyn std::error::Error>> {
for i in 0..array.num_columns() {
let column = array.column(i);
match column.data_type() {
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
primitive_array_to_vector(column, &mut out.child(i));
primitive_array_to_vector(column, &mut out.child(i))?;
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
}
DataType::List(_) => {
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i));
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i))?;
}
DataType::LargeList(_) => {
list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i));
list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i))?;
}
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(
as_fixed_size_list_array(column.as_ref()),
&mut out.list_vector_child(i),
);
)?;
}
DataType::Struct(_) => {
let struct_array = as_struct_array(column.as_ref());
let mut struct_vector = out.struct_vector_child(i);
struct_array_to_vector(struct_array, &mut struct_vector);
struct_array_to_vector(struct_array, &mut struct_vector)?;
}
_ => {
unimplemented!(
Expand All @@ -490,6 +501,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
}
}
}
Ok(())
}

/// Pass RecordBatch to duckdb.
Expand Down Expand Up @@ -531,11 +543,11 @@ mod test {
use crate::{Connection, Result};
use arrow::{
array::{
Array, ArrayRef, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, Int32Array,
PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{ArrowPrimitiveType, DataType, Field, Fields, Schema},
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -749,4 +761,32 @@ mod test {
assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE");
Ok(())
}

#[test]
fn test_arrow_error() {
let arc: ArrayRef = Arc::new(Decimal256Array::from(vec![i256::from(1), i256::from(2), i256::from(3)]));
let batch = RecordBatch::try_from_iter(vec![("x", arc)]).unwrap();

let db = Connection::open_in_memory().unwrap();
db.register_table_function::<ArrowVTab>("arrow").unwrap();

let mut stmt = db.prepare("SELECT * FROM arrow(?, ?)").unwrap();

let res = match stmt.execute(arrow_recordbatch_to_query_params(batch)) {
Ok(..) => None,
Err(e) => Some(e),
}
.unwrap();

assert_eq!(
res,
crate::error::Error::DuckDBFailure(
crate::ffi::Error {
code: crate::ffi::ErrorCode::Unknown,
extended_code: 1
},
Some("Invalid Input Error: Data type \"Decimal256(76, 10)\" not yet supported by ArrowVTab".to_owned())
)
);
}
}

0 comments on commit d110924

Please sign in to comment.