Skip to content

Commit

Permalink
add basic enum support
Browse files Browse the repository at this point in the history
  • Loading branch information
Mause committed Apr 18, 2024
1 parent d110924 commit 528cf3a
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 3 deletions.
21 changes: 20 additions & 1 deletion src/row.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::{convert, sync::Arc};

use super::{Error, Result, Statement};
use crate::types::{self, FromSql, FromSqlError, ValueRef};
use crate::types::{self, EnumType, FromSql, FromSqlError, ValueRef};

use arrow::array::DictionaryArray;
use arrow::{
array::{self, Array, ArrayRef, ListArray, StructArray},
datatypes::*,
Expand Down Expand Up @@ -601,6 +602,24 @@ impl<'stmt> Row<'stmt> {

ValueRef::List(arr, row)
}
DataType::Dictionary(key_type, ..) => {
let column = column.as_any();
ValueRef::Enum(
match key_type.as_ref() {
DataType::UInt8 => {
EnumType::UInt8(column.downcast_ref::<DictionaryArray<UInt8Type>>().unwrap())
}
DataType::UInt16 => {
EnumType::UInt16(column.downcast_ref::<DictionaryArray<UInt16Type>>().unwrap())
}
DataType::UInt32 => {
EnumType::UInt32(column.downcast_ref::<DictionaryArray<UInt32Type>>().unwrap())
}
typ => panic!("Unsupported key type: {typ:?}"),
},
row,
)
}
_ => unreachable!("invalid value: {} {}", col, column.data_type()),
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub use self::{
from_sql::{FromSql, FromSqlError, FromSqlResult},
to_sql::{ToSql, ToSqlOutput},
value::Value,
value_ref::{TimeUnit, ValueRef},
value_ref::{EnumType, TimeUnit, ValueRef},
};

use arrow::datatypes::DataType;
Expand Down Expand Up @@ -149,6 +149,8 @@ pub enum Type {
Interval,
/// LIST
List(Box<Type>),
/// ENUM
Enum,
/// Any
Any,
}
Expand Down Expand Up @@ -219,6 +221,7 @@ impl fmt::Display for Type {
Type::Time64 => f.pad("Time64"),
Type::Interval => f.pad("Interval"),
Type::List(..) => f.pad("List"),
Type::Enum => f.pad("Enum"),
Type::Any => f.pad("Any"),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub enum Value {
},
/// The value is a list
List(Vec<Value>),
/// The value is an enum
Enum(String),
}

impl From<Null> for Value {
Expand Down Expand Up @@ -225,6 +227,7 @@ impl Value {
Value::Time64(..) => Type::Time64,
Value::Interval { .. } => Type::Interval,
Value::List(_) => todo!(),
Value::Enum(..) => Type::Enum,
}
}
}
35 changes: 34 additions & 1 deletion src/types/value_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::types::{FromSqlError, FromSqlResult};
use crate::Row;
use rust_decimal::prelude::*;

use arrow::array::{Array, ListArray};
use arrow::array::{Array, DictionaryArray, ListArray};
use arrow::datatypes::{UInt16Type, UInt32Type, UInt8Type};

/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
/// Copy from arrow::datatypes::TimeUnit
Expand Down Expand Up @@ -75,6 +76,19 @@ pub enum ValueRef<'a> {
},
/// The value is a list
List(&'a ListArray, usize),
/// The value is an enum
Enum(EnumType<'a>, usize),
}

/// Wrapper type for different enum sizes
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum EnumType<'a> {
/// The underlying enum type is u8
UInt8(&'a DictionaryArray<UInt8Type>),
/// The underlying enum type is u16
UInt16(&'a DictionaryArray<UInt16Type>),
/// The underlying enum type is u32
UInt32(&'a DictionaryArray<UInt32Type>),
}

impl ValueRef<'_> {
Expand Down Expand Up @@ -103,6 +117,7 @@ impl ValueRef<'_> {
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
ValueRef::List(arr, _) => arr.data_type().into(),
ValueRef::Enum(..) => Type::Enum,
}
}

Expand Down Expand Up @@ -170,6 +185,23 @@ impl From<ValueRef<'_>> for Value {
.collect();
Value::List(map)
}
ValueRef::Enum(items, idx) => {
let value = Row::value_ref_internal(
idx,
0,
match items {
EnumType::UInt8(res) => res.values(),
EnumType::UInt16(res) => res.values(),
EnumType::UInt32(res) => res.values(),
},
).to_owned();

if let Value::Text(s) = value {
Value::Enum(s)
} else {
panic!("Enum value is not a string")
}
}
}
}
}
Expand Down Expand Up @@ -213,6 +245,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
Value::List(..) => unimplemented!(),
Value::Enum(..) => todo!(),
}
}
}
Expand Down

0 comments on commit 528cf3a

Please sign in to comment.