Skip to content

Commit

Permalink
Add VariantVal for PySpark
Browse files Browse the repository at this point in the history
  • Loading branch information
gene-db committed Apr 2, 2024
1 parent db0975c commit 734af23
Show file tree
Hide file tree
Showing 3 changed files with 477 additions and 1 deletion.
63 changes: 63 additions & 0 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
BooleanType,
NullType,
VariantType,
VariantVal,
)
from pyspark.sql.types import (
_array_signed_int_typecode_ctype_mappings,
Expand Down Expand Up @@ -1406,6 +1407,68 @@ def test_calendar_interval_type_with_sf(self):
schema1 = self.spark.range(1).select(F.make_interval(F.lit(1))).schema
self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType())

def test_variant_type(self):
    from decimal import Decimal

    self.assertEqual(VariantType().simpleString(), "variant")

    # Tuples of (field name, JSON text, equivalent Python value) covering
    # every scalar encoding width plus nested arrays and objects.
    expected_values = [
        ("str", '"%s"' % ("0123456789" * 10), "0123456789" * 10),
        ("short_str", '"abc"', "abc"),
        ("null", "null", None),
        ("true", "true", True),
        ("false", "false", False),
        ("int1", "1", 1),
        ("-int1", "-5", -5),
        ("int2", "257", 257),
        ("-int2", "-124", -124),
        ("int4", "65793", 65793),
        ("-int4", "-69633", -69633),
        ("int8", "4295033089", 4295033089),
        ("-int8", "-4294967297", -4294967297),
        ("float4", "1.23456789e-30", 1.23456789e-30),
        ("-float4", "-4.56789e+29", -4.56789e+29),
        ("dec4", "123.456", Decimal("123.456")),
        ("-dec4", "-321.654", Decimal("-321.654")),
        ("dec8", "429.4967297", Decimal("429.4967297")),
        ("-dec8", "-5.678373902", Decimal("-5.678373902")),
        ("dec16", "467440737095.51617", Decimal("467440737095.51617")),
        ("-dec16", "-67.849438003827263", Decimal("-67.849438003827263")),
        ("arr", '[1.1,"2",[3],{"4":5}]', [Decimal("1.1"), "2", [3], {"4": 5}]),
        ("obj", '{"a":["123",{"b":2}],"c":3}', {"a": ["123", {"b": 2}], "c": 3}),
    ]
    json_str = "{%s}" % ",".join(
        '"%s": %s' % (name, json_text) for name, json_text, _ in expected_values
    )

    # Produce variants at top level and nested inside array, struct and map
    # columns, so collection covers every container conversion path.
    df = self.spark.createDataFrame([{"json": json_str}])
    row = df.select(
        F.parse_json(df.json).alias("v"),
        F.array([F.parse_json(F.lit('{"a": 1}'))]).alias("a"),
        F.struct([F.parse_json(F.lit('{"b": "2"}'))]).alias("s"),
        F.create_map([F.lit("k"), F.parse_json(F.lit('{"c": true}'))]).alias("m"),
    ).collect()[0]
    variants = [row["v"], row["a"][0], row["s"]["col1"], row["m"]["k"]]
    for variant in variants:
        self.assertEqual(type(variant), VariantVal)

    # str() must render every field as compact JSON.
    as_string = str(variants[0])
    for name, json_text, _ in expected_values:
        self.assertIn('"%s":%s' % (name, json_text), as_string)
    self.assertEqual(str(variants[1]), '{"a":1}')
    self.assertEqual(str(variants[2]), '{"b":"2"}')
    self.assertEqual(str(variants[3]), '{"c":true}')

    # toPython() must decode each field to its native Python value.
    as_python = variants[0].toPython()
    for name, _, python_value in expected_values:
        self.assertEqual(as_python[name], python_value)
    self.assertEqual(variants[1].toPython(), {"a": 1})
    self.assertEqual(variants[2].toPython(), {"b": "2"})
    self.assertEqual(variants[3].toPython(), {"c": True})

    # repr() must evaluate back into an equivalent VariantVal.
    self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))


def test_from_ddl(self):
self.assertEqual(DataType.fromDDL("long"), LongType())
self.assertEqual(
Expand Down
39 changes: 38 additions & 1 deletion python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import has_numpy, get_active_spark_context
from pyspark.sql.variant_utils import VariantUtils
from pyspark.errors import (
PySparkNotImplementedError,
PySparkTypeError,
Expand Down Expand Up @@ -1344,7 +1345,13 @@ class VariantType(AtomicType):
.. versionadded:: 4.0.0
"""

pass
def needConversion(self) -> bool:
    # Variant rows need a Python-side conversion step: fromInternal turns the
    # internal representation into a VariantVal (see fromInternal below).
    return True

def fromInternal(self, obj: Dict) -> Optional["VariantVal"]:
    """
    Convert an internal SQL object into a :class:`VariantVal`.

    :param obj: internal representation of a variant; expected to be a dict
        containing "value" and "metadata" entries.
    :return: a ``VariantVal`` built from ``obj["value"]`` and
        ``obj["metadata"]``, or ``None`` when ``obj`` is ``None`` or is
        missing either key.
    """
    # Bug fix: the return annotation previously claimed "VariantVal" while a
    # bare `return` produced None for malformed input; make both explicit.
    if obj is None or not all(key in obj for key in ["value", "metadata"]):
        return None
    return VariantVal(obj["value"], obj["metadata"])


class UserDefinedType(DataType):
Expand Down Expand Up @@ -1468,6 +1475,36 @@ def __eq__(self, other: Any) -> bool:
return type(self) == type(other)


class VariantVal:
    """
    A class to represent a Variant value in Python.

    :param value: the encoded binary content of the Variant
    :param metadata: the encoded binary metadata of the Variant
    """

    def __init__(self, value: bytes, metadata: bytes):
        self.value = value
        self.metadata = metadata

    def __str__(self) -> str:
        return self.toString()

    def __repr__(self) -> str:
        # %s on bytes yields their b'...' literal form, so the repr is an
        # evaluable constructor expression (relied upon by tests).
        return "VariantVal(%s, %s)" % (self.value, self.metadata)

    def toString(self) -> str:
        """
        Convert the VariantVal to a string.

        :return: a JSON string representation of the Variant
        """
        return VariantUtils.to_json(self.value, self.metadata)

    def toPython(self) -> Any:
        """
        Convert the VariantVal to a Python data structure.

        :return: a Python object (e.g. dict, list, str, int, bool, Decimal
            or None) decoded from the Variant
        """
        # Bug fix: this was annotated `-> str`, but to_python returns an
        # arbitrary decoded Python value, not a string.
        return VariantUtils.to_python(self.value, self.metadata)


_atomic_types: List[Type[DataType]] = [
StringType,
CharType,
Expand Down
Loading

0 comments on commit 734af23

Please sign in to comment.