Skip to content

Commit

Permalink
Add arrow converter
Browse files Browse the repository at this point in the history
  • Loading branch information
gene-db committed Apr 3, 2024
1 parent 56ca807 commit eed91a4
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
NullType,
DataType,
UserDefinedType,
VariantType,
VariantVal,
_create_row,
)
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
Expand Down Expand Up @@ -108,6 +110,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
arrow_type = pa.null()
elif isinstance(dt, UserDefinedType):
arrow_type = to_arrow_type(dt.sqlType())
elif type(dt) == VariantType:
fields = [
pa.field("value", pa.binary(), nullable=False),
pa.field("metadata", pa.binary(), nullable=False),
]
arrow_type = pa.struct(fields)
else:
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
Expand Down Expand Up @@ -763,6 +771,18 @@ def convert_udt(value: Any) -> Any:

return convert_udt

elif isinstance(dt, VariantType):

def convert_variant(value: Any) -> Any:
    """Convert one Arrow-derived variant struct into a ``VariantVal``.

    Parameters
    ----------
    value : Any
        Expected to be a ``dict`` holding ``bytes`` under both the
        ``"value"`` and ``"metadata"`` keys, or ``None`` for a null entry.

    Returns
    -------
    Any
        ``VariantVal(value["value"], value["metadata"])`` for a well-formed
        dict, or ``None`` when ``value`` is ``None``.

    Raises
    ------
    PySparkValueError
        With error class ``MALFORMED_VARIANT`` when ``value`` is neither
        ``None`` nor a dict carrying ``bytes`` for both required keys.
    """
    # A null variant is not a malformed one: pass it through instead of
    # raising, so nullable variant columns can be converted.
    # NOTE(review): assumes null cells reach this converter as None rather
    # than being filtered out by the caller -- confirm against the call site.
    if value is None:
        return None

    variant_keys = ("value", "metadata")
    # dict.get yields None for a missing key, which fails the bytes check,
    # so key presence and payload type are validated in a single pass.
    if isinstance(value, dict) and all(
        isinstance(value.get(key), bytes) for key in variant_keys
    ):
        return VariantVal(value["value"], value["metadata"])

    raise PySparkValueError(error_class="MALFORMED_VARIANT")

return convert_variant

else:
return None

Expand Down

0 comments on commit eed91a4

Please sign in to comment.