diff --git a/crates/duckdb-loadable-macros/src/lib.rs b/crates/duckdb-loadable-macros/src/lib.rs index d82cde71..136f27f5 100644 --- a/crates/duckdb-loadable-macros/src/lib.rs +++ b/crates/duckdb-loadable-macros/src/lib.rs @@ -133,8 +133,10 @@ pub fn duckdb_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream { /// Will be called by duckdb #[no_mangle] pub unsafe extern "C" fn #c_entrypoint(db: *mut c_void) { - let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection"); - #prefixed_original_function(connection).expect("init failed"); + unsafe { + let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection"); + #prefixed_original_function(connection).expect("init failed"); + } } /// # Safety @@ -142,7 +144,9 @@ pub fn duckdb_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream { /// Predefined function, don't need to change unless you are sure #[no_mangle] pub unsafe extern "C" fn #c_entrypoint_version() -> *const c_char { - ffi::duckdb_library_version() + unsafe { + ffi::duckdb_library_version() + } } diff --git a/crates/duckdb/examples/hello-ext/main.rs b/crates/duckdb/examples/hello-ext/main.rs index 6f159e9a..8a6edad3 100644 --- a/crates/duckdb/examples/hello-ext/main.rs +++ b/crates/duckdb/examples/hello-ext/main.rs @@ -1,10 +1,13 @@ +#![warn(unsafe_op_in_unsafe_fn)] +#![warn(unsafe)] // extensions can be safe + extern crate duckdb; extern crate duckdb_loadable_macros; extern crate libduckdb_sys; use duckdb::{ core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId}, - vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab}, + vtab::{BindInfo, FunctionInfo, InitInfo, VTab}, Connection, Result, }; use duckdb_loadable_macros::duckdb_entrypoint; @@ -12,70 +15,46 @@ use libduckdb_sys as ffi; use std::{ error::Error, ffi::{c_char, c_void, CString}, + sync::atomic::{AtomicBool, Ordering}, }; -#[repr(C)] struct HelloBindData { - name: *mut c_char, -} - -impl Free for HelloBindData { - fn free(&mut self) { - unsafe { - if self.name.is_null() { - return; - } - drop(CString::from_raw(self.name)); - } - } + name: String, } -#[repr(C)] struct HelloInitData { - done: bool, + done: AtomicBool, } struct HelloVTab; -impl Free for HelloInitData {} - impl VTab for HelloVTab { type InitData = HelloInitData; type BindData = HelloBindData; - unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box> { + fn bind(bind: &BindInfo) -> Result> { bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar)); - let param = bind.get_parameter(0).to_string(); - unsafe { - (*data).name = CString::new(param).unwrap().into_raw(); - } - Ok(()) + let name = bind.get_parameter(0).to_string(); + Ok(HelloBindData { name }) } - unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box> { - unsafe { - (*data).done = false; - } - Ok(()) + fn init(_: &InitInfo) -> Result> { + Ok(HelloInitData { + done: AtomicBool::new(false), + }) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { - let init_info = func.get_init_data::(); - let bind_info = func.get_bind_data::(); + fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + let init_data = func.get_init_data(); + let bind_data = func.get_bind_data(); - unsafe { - if (*init_info).done { - output.set_len(0); - } else { - (*init_info).done = true; - let vector = output.flat_vector(0); - let name = CString::from_raw((*bind_info).name); - let result = CString::new(format!("Hello {}", name.to_str()?))?; - // Can't consume the CString - (*bind_info).name = CString::into_raw(name); - vector.insert(0, result); - output.set_len(1); - } + if init_data.done.swap(true, Ordering::Relaxed) { + output.set_len(0); + } else { + let vector = output.flat_vector(0); + let result = CString::new(format!("Hello {}", bind_data.name))?; + vector.insert(0, result); + output.set_len(1); } Ok(()) } diff --git a/crates/duckdb/src/vtab/function.rs b/crates/duckdb/src/vtab/function.rs index 9d14b510..cb32ee77 100644 --- a/crates/duckdb/src/vtab/function.rs +++ b/crates/duckdb/src/vtab/function.rs @@ -9,10 +9,11 @@ use super::{ duckdb_table_function_set_init, duckdb_table_function_set_local_init, duckdb_table_function_set_name, duckdb_table_function_supports_projection_pushdown, idx_t, }, - LogicalTypeHandle, Value, + LogicalTypeHandle, VTab, Value, }; use std::{ ffi::{c_void, CString}, + marker::PhantomData, os::raw::c_char, }; @@ -138,7 +139,9 @@ impl From for InitInfo { impl InitInfo { /// # Safety pub unsafe fn set_init_data(&self, data: *mut c_void, freeer: Option) { - duckdb_init_set_init_data(self.0, data, freeer); + unsafe { + duckdb_init_set_init_data(self.0, data, freeer); + } } /// Returns the column indices of the projected columns at the specified positions. @@ -188,7 +191,7 @@ impl InitInfo { /// * `error`: The error message pub fn set_error(&self, error: &str) { let c_str = CString::new(error).unwrap(); - unsafe { duckdb_init_set_error(self.0, c_str.as_ptr() as *const c_char) } + unsafe { duckdb_init_set_error(self.0, c_str.as_ptr()) } } } @@ -309,7 +312,9 @@ impl TableFunction { /// /// # Safety pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) { - duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy); + unsafe { + duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy); + } } /// Sets the thread-local init function of the table function @@ -334,9 +339,12 @@ use super::ffi::{ /// An interface to store and retrieve data during the function execution stage #[derive(Debug)] -pub struct FunctionInfo(duckdb_function_info); +pub struct FunctionInfo { + ptr: duckdb_function_info, + _vtab: PhantomData, +} -impl FunctionInfo { +impl FunctionInfo { /// Report that an error has occurred while executing the function. /// /// # Arguments @@ -344,44 +352,57 @@ impl FunctionInfo { pub fn set_error(&self, error: &str) { let c_str = CString::new(error).unwrap(); unsafe { - duckdb_function_set_error(self.0, c_str.as_ptr() as *const c_char); + duckdb_function_set_error(self.ptr, c_str.as_ptr()); } } + /// Gets the bind data set by [`BindInfo::set_bind_data`] during the bind. /// /// Note that the bind data should be considered as read-only. /// For tracking state, use the init data instead. - /// - /// # Arguments - /// * `returns`: The bind data object - pub fn get_bind_data(&self) -> *mut T { - unsafe { duckdb_function_get_bind_data(self.0).cast() } + pub fn get_bind_data(&self) -> &V::BindData { + unsafe { + let bind_data: *const V::BindData = duckdb_function_get_bind_data(self.ptr).cast(); + bind_data.as_ref().unwrap() + } } - /// Gets the init data set by [`InitInfo::set_init_data`] during the init. + + /// Get a reference to the init data set by [`InitInfo::set_init_data`] during the init. + /// + /// This returns a shared reference because the init data is shared between multiple threads. + /// It may internally be mutable. /// /// # Arguments /// * `returns`: The init data object - pub fn get_init_data(&self) -> *mut T { - unsafe { duckdb_function_get_init_data(self.0).cast() } + pub fn get_init_data(&self) -> &V::InitData { + // Safety: A pointer to a box of the init data is stored during vtab init. + unsafe { + let init_data: *const V::InitData = duckdb_function_get_init_data(self.ptr).cast(); + init_data.as_ref().unwrap() + } } + /// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`] /// /// # Arguments /// * `returns`: The extra info pub fn get_extra_info(&self) -> *mut T { - unsafe { duckdb_function_get_extra_info(self.0).cast() } + unsafe { duckdb_function_get_extra_info(self.ptr).cast() } } /// Gets the thread-local init data set by [`InitInfo::set_init_data`] during the local_init. /// /// # Arguments /// * `returns`: The init data object pub fn get_local_init_data(&self) -> *mut T { - unsafe { duckdb_function_get_local_init_data(self.0).cast() } + unsafe { duckdb_function_get_local_init_data(self.ptr).cast() } } } -impl From for FunctionInfo { +impl From for FunctionInfo { fn from(ptr: duckdb_function_info) -> Self { - Self(ptr) + Self { + ptr, + _vtab: PhantomData, + } } } diff --git a/crates/duckdb/src/vtab/mod.rs b/crates/duckdb/src/vtab/mod.rs index 9249fb1e..32bd479d 100644 --- a/crates/duckdb/src/vtab/mod.rs +++ b/crates/duckdb/src/vtab/mod.rs @@ -1,8 +1,11 @@ -use crate::{error::Error, inner_connection::InnerConnection, Connection, Result}; +// #![warn(unsafe_op_in_unsafe_fn)] -use super::{ffi, ffi::duckdb_free}; use std::ffi::c_void; +use crate::{error::Error, inner_connection::InnerConnection, Connection, Result}; + +use super::ffi; + mod function; mod value; @@ -20,35 +23,15 @@ mod excel; pub use function::{BindInfo, FunctionInfo, InitInfo, TableFunction}; pub use value::Value; -use crate::core::{DataChunkHandle, LogicalTypeHandle, LogicalTypeId}; +use crate::core::{DataChunkHandle, LogicalTypeHandle}; use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info}; -use ffi::duckdb_malloc; -use std::mem::size_of; - -/// duckdb_malloc a struct of type T -/// used for the bind_info and init_info -/// # Safety -/// This function is obviously unsafe -unsafe fn malloc_data_c() -> *mut T { - duckdb_malloc(size_of::()).cast() -} - -/// free bind or info data +/// Given a raw pointer to a box, free the box and the data contained within it. /// /// # Safety -/// This function is obviously unsafe -/// TODO: maybe we should use a Free trait here -unsafe extern "C" fn drop_data_c(v: *mut c_void) { - let actual = v.cast::(); - (*actual).free(); - duckdb_free(v); -} - -/// Free trait for the bind and init data -pub trait Free { - /// Free the data - fn free(&mut self) {} +/// The pointer must be a valid pointer to a `Box` created by `Box::into_raw`. +unsafe extern "C" fn drop_boxed(v: *mut c_void) { + drop(unsafe { Box::from_raw(v.cast::()) }); } /// Duckdb table function trait @@ -56,51 +39,33 @@ pub trait Free { /// See to the HelloVTab example for more details /// pub trait VTab: Sized { - /// The data type of the bind data - type InitData: Sized + Free; - /// The data type of the init data - type BindData: Sized + Free; - - /// Bind data to the table function + /// The data type of the init data. /// - /// # Safety + /// The init data tracks the state of the table function and is global across threads. /// - /// This function is unsafe because it dereferences raw pointers (`data`) and manipulates the memory directly. - /// The caller must ensure that: - /// - /// - The `data` pointer is valid and points to a properly initialized `BindData` instance. - /// - The lifetime of `data` must outlive the execution of `bind` to avoid dangling pointers, especially since - /// `bind` does not take ownership of `data`. - /// - Concurrent access to `data` (if applicable) must be properly synchronized. - /// - The `bind` object must be valid and correctly initialized. - unsafe fn bind(bind: &BindInfo, data: *mut Self::BindData) -> Result<(), Box>; - /// Initialize the table function - /// - /// # Safety - /// - /// This function is unsafe because it performs raw pointer dereferencing on the `data` argument. - /// The caller is responsible for ensuring that: - /// - /// - The `data` pointer is non-null and points to a valid `InitData` instance. - /// - There is no data race when accessing `data`, meaning if `data` is accessed from multiple threads, - /// proper synchronization is required. - /// - The lifetime of `data` extends beyond the scope of this call to avoid use-after-free errors. - unsafe fn init(init: &InitInfo, data: *mut Self::InitData) -> Result<(), Box>; - /// The actual function - /// - /// # Safety + /// The init data is shared across threads so must be `Send + Sync`. + type InitData: Sized + Send + Sync; + + /// The data type of the bind data. /// - /// This function is unsafe because it: + /// The bind data is shared across threads so must be `Send + Sync`. + type BindData: Sized + Send + Sync; + + /// Bind data to the table function /// - /// - Dereferences multiple raw pointers (`func` to access `init_info` and `bind_info`). + /// This function is used for determining the return type of a table producing function and returning bind data + fn bind(bind: &BindInfo) -> Result>; + + /// Initialize the table function + fn init(init: &InitInfo) -> Result>; + + /// Generate rows from the table function. /// - /// The caller must ensure that: + /// The implementation should populate the `output` parameter with the rows to be returned. /// - /// - All pointers (`func`, `output`, internal `init_info`, and `bind_info`) are valid and point to the expected types of data structures. - /// - The `init_info` and `bind_info` data pointed to remains valid and is not freed until after this function completes. - /// - No other threads are concurrently mutating the data pointed to by `init_info` and `bind_info` without proper synchronization. - /// - The `output` parameter is correctly initialized and can safely be written to. - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box>; + /// When the table function is done, the implementation should set the length of the output to 0. + fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box>; + /// Does the table function support pushdown /// default is false fn supports_pushdown() -> bool { @@ -122,7 +87,7 @@ unsafe extern "C" fn func(info: duckdb_function_info, output: duckdb_data_chu where T: VTab, { - let info = FunctionInfo::from(info); + let info = FunctionInfo::::from(info); let mut data_chunk_handle = DataChunkHandle::new_unowned(output); let result = T::func(&info, &mut data_chunk_handle); if result.is_err() { @@ -135,11 +100,16 @@ where T: VTab, { let info = InitInfo::from(info); - let data = malloc_data_c::(); - let result = T::init(&info, data); - info.set_init_data(data.cast(), Some(drop_data_c::)); - if result.is_err() { - info.set_error(&result.err().unwrap().to_string()); + match T::init(&info) { + Ok(init_data) => { + info.set_init_data( + Box::into_raw(Box::new(init_data)) as *mut c_void, + Some(drop_boxed::), + ); + } + Err(e) => { + info.set_error(&e.to_string()); + } } } @@ -148,11 +118,16 @@ where T: VTab, { let info = BindInfo::from(info); - let data = malloc_data_c::(); - let result = T::bind(&info, data); - info.set_bind_data(data.cast(), Some(drop_data_c::)); - if result.is_err() { - info.set_error(&result.err().unwrap().to_string()); + match T::bind(&info) { + Ok(bind_data) => { + info.set_bind_data( + Box::into_raw(Box::new(bind_data)) as *mut c_void, + Some(drop_boxed::), + ); + } + Err(e) => { + info.set_error(&e.to_string()); + } } } @@ -194,73 +169,51 @@ impl InnerConnection { mod test { use super::*; use crate::core::Inserter; + use crate::core::LogicalTypeId; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering; use std::{ error::Error, ffi::{c_char, CString}, }; - #[repr(C)] struct HelloBindData { - name: *mut c_char, + name: String, } - impl Free for HelloBindData { - fn free(&mut self) { - unsafe { - if self.name.is_null() { - return; - } - drop(CString::from_raw(self.name)); - } - } - } - - #[repr(C)] struct HelloInitData { - done: bool, + done: AtomicBool, } struct HelloVTab; - impl Free for HelloInitData {} - impl VTab for HelloVTab { type InitData = HelloInitData; type BindData = HelloBindData; - unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box> { + fn bind(bind: &BindInfo) -> Result> { bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar)); - let param = bind.get_parameter(0).to_string(); - unsafe { - (*data).name = CString::new(param).unwrap().into_raw(); - } - Ok(()) + let name = bind.get_parameter(0).to_string(); + Ok(HelloBindData { name }) } - unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box> { - unsafe { - (*data).done = false; - } - Ok(()) + fn init(_: &InitInfo) -> Result> { + Ok(HelloInitData { + done: AtomicBool::new(false), + }) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { - let init_info = func.get_init_data::(); - let bind_info = func.get_bind_data::(); - - unsafe { - if (*init_info).done { - output.set_len(0); - } else { - (*init_info).done = true; - let vector = output.flat_vector(0); - let name = CString::from_raw((*bind_info).name); - let result = CString::new(format!("Hello {}", name.to_str()?))?; - // Can't consume the CString - (*bind_info).name = CString::into_raw(name); - vector.insert(0, result); - output.set_len(1); - } + fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + let init_data = func.get_init_data(); + let bind_data = func.get_bind_data(); + + if init_data.done.swap(true, Ordering::Relaxed) { + output.set_len(0); + } else { + let vector = output.flat_vector(0); + let result = CString::new(format!("Hello {}", bind_data.name))?; + vector.insert(0, result); + output.set_len(1); } Ok(()) } @@ -275,22 +228,30 @@ mod test { type InitData = HelloInitData; type BindData = HelloBindData; - unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box> { + fn bind(bind: &BindInfo) -> Result> { bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar)); - let param = bind.get_named_parameter("name").unwrap().to_string(); + let name = bind.get_named_parameter("name").unwrap().to_string(); assert!(bind.get_named_parameter("unknown_name").is_none()); - unsafe { - (*data).name = CString::new(param).unwrap().into_raw(); - } - Ok(()) + Ok(HelloBindData { name }) } - unsafe fn init(init_info: &InitInfo, data: *mut HelloInitData) -> Result<(), Box> { - HelloVTab::init(init_info, data) + fn init(init_info: &InitInfo) -> Result> { + HelloVTab::init(init_info) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { - HelloVTab::func(func, output) + fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + let init_data = func.get_init_data(); + let bind_data = func.get_bind_data(); + + if init_data.done.swap(true, Ordering::Relaxed) { + output.set_len(0); + } else { + let vector = output.flat_vector(0); + let result = CString::new(format!("Hello {}", bind_data.name))?; + vector.insert(0, result); + output.set_len(1); + } + Ok(()) } fn named_parameters() -> Option> {