Skip to content

Commit

Permalink
Support Variant in arrow writer
Browse files: browse the repository at this point in the history
  • Loading branch information
gene-db committed Apr 2, 2024
1 parent 734af23 commit 56ca807
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
10 changes: 10 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -125,6 +125,12 @@ private[sql] object ArrowUtils {
largeVarTypes)).asJava)
case udt: UserDefinedType[_] =>
toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
case _: VariantType =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null,
Map("variant" -> "true").asJava)
new Field(name, fieldType,
Seq(toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes),
toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava)
case dataType =>
val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
largeVarTypes), null)
Expand All @@ -143,6 +149,10 @@ private[sql] object ArrowUtils {
val elementField = field.getChildren().get(0)
val elementType = fromArrowField(elementField)
ArrayType(elementType, containsNull = elementField.isNullable)
case ArrowType.Struct.INSTANCE if field.getMetadata.getOrDefault("variant", "") == "true"
&& field.getChildren.asScala.map(_.getName).asJava
.containsAll(Seq("value", "metadata").asJava) =>
VariantType
case ArrowType.Struct.INSTANCE =>
val fields = field.getChildren().asScala.map { child =>
val dt = fromArrowField(child)
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,6 +84,11 @@ object ArrowWriter {
case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
new IntervalMonthDayNanoWriter(vector)
case (VariantType, vector: StructVector) =>
val children = (0 until vector.size()).map { ordinal =>
createFieldWriter(vector.getChildByOrdinal(ordinal))
}
new StructWriter(vector, children.toArray)
case (dt, _) =>
throw ExecutionErrors.unsupportedDataTypeError(dt)
}
Expand Down Expand Up @@ -368,6 +373,8 @@ private[arrow] class StructWriter(
val valueVector: StructVector,
children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {

lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true"

override def setNull(): Unit = {
var i = 0
while (i < children.length) {
Expand All @@ -379,12 +386,20 @@ private[arrow] class StructWriter(
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val struct = input.getStruct(ordinal, children.length)
var i = 0
valueVector.setIndexDefined(count)
while (i < struct.numFields) {
children(i).write(struct, i)
i += 1
if (isVariant) {
valueVector.setIndexDefined(count)
val v = input.getVariant(ordinal)
val row = InternalRow(v.getValue, v.getMetadata)
children(0).write(row, 0)
children(1).write(row, 1)
} else {
val struct = input.getStruct(ordinal, children.length)
var i = 0
valueVector.setIndexDefined(count)
while (i < struct.numFields) {
children(i).write(struct, i)
i += 1
}
}
}

Expand Down

0 comments on commit 56ca807

Please sign in to comment.