Skip to content

Commit

Permalink
mypy annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
gene-db committed Apr 8, 2024
1 parent d15cdad commit 807d574
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
7 changes: 3 additions & 4 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,9 +1349,8 @@ def needConversion(self) -> bool:
return True

def fromInternal(self, obj: Optional[Dict]) -> Optional["VariantVal"]:
    """
    Converts the internal SQL dict representation into a :class:`VariantVal`.

    Parameters
    ----------
    obj : dict or None
        Internal representation; expected to carry "value" and "metadata" keys.

    Returns
    -------
    :class:`VariantVal` or None
        ``None`` when ``obj`` is ``None`` or lacks the required keys,
        mirroring SQL NULL semantics for variant columns.
    """
    if obj is not None and all(key in obj for key in ["value", "metadata"]):
        return VariantVal(obj["value"], obj["metadata"])
    # Missing or malformed internal value maps to SQL NULL (explicit for mypy:
    # the annotated return type is Optional, so make the None path visible).
    return None


class UserDefinedType(DataType):
Expand Down Expand Up @@ -1510,7 +1509,7 @@ def __str__(self) -> str:
return VariantUtils.to_json(self.value, self.metadata)

def __repr__(self) -> str:
    """
    Return a developer-oriented representation of this variant.

    Uses ``%r`` so the underlying ``value``/``metadata`` buffers are rendered
    unambiguously (e.g. as bytes literals) rather than via their lossy
    ``str()`` form.
    """
    return "VariantVal(%r, %r)" % (self.value, self.metadata)

def toPython(self) -> Any:
"""
Expand Down
22 changes: 12 additions & 10 deletions python/pyspark/sql/variant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import struct
from array import array
from typing import Any
from typing import Any, Callable, Dict, List, Tuple
from pyspark.errors import PySparkValueError


Expand Down Expand Up @@ -120,7 +120,7 @@ def _check_index(cls, pos: int, length: int) -> None:
raise PySparkValueError(error_class="MALFORMED_VARIANT")

@classmethod
def _get_type_info(cls, value: bytes, pos: int):
def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
"""
Returns the (basic_type, type_info) pair from the given position in the value.
"""
Expand Down Expand Up @@ -221,7 +221,7 @@ def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))

@classmethod
def _get_type(cls, value: bytes, pos: int):
def _get_type(cls, value: bytes, pos: int) -> Any:
"""
Returns the Python type of the Variant at the given position.
"""
Expand Down Expand Up @@ -261,7 +261,7 @@ def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any:
variant_type = cls._get_type(value, pos)
if variant_type == dict:

def handle_object(key_value_pos_list):
def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> str:
key_value_list = [
json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos)
for (key, value_pos) in key_value_pos_list
Expand All @@ -271,7 +271,7 @@ def handle_object(key_value_pos_list):
return cls._handle_object(value, metadata, pos, handle_object)
elif variant_type == array:

def handle_array(value_pos_list):
def handle_array(value_pos_list: list[int]) -> str:
value_list = [
cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list
]
Expand All @@ -293,7 +293,7 @@ def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any:
variant_type = cls._get_type(value, pos)
if variant_type == dict:

def handle_object(key_value_pos_list):
def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> Dict[str, Any]:
key_value_list = [
(key, cls._to_python(value, metadata, value_pos))
for (key, value_pos) in key_value_pos_list
Expand All @@ -303,7 +303,7 @@ def handle_object(key_value_pos_list):
return cls._handle_object(value, metadata, pos, handle_object)
elif variant_type == array:

def handle_array(value_pos_list):
def handle_array(value_pos_list: list[int]) -> List[Any]:
value_list = [
cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list
]
Expand All @@ -314,7 +314,7 @@ def handle_array(value_pos_list):
return cls._get_scalar(variant_type, value, metadata, pos)

@classmethod
def _get_scalar(cls, variant_type, value: bytes, metadata: bytes, pos: int) -> Any:
def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any:
if isinstance(None, variant_type):
return None
elif variant_type == bool:
Expand All @@ -331,7 +331,9 @@ def _get_scalar(cls, variant_type, value: bytes, metadata: bytes, pos: int) -> A
raise PySparkValueError(error_class="MALFORMED_VARIANT")

@classmethod
def _handle_object(cls, value: bytes, metadata: bytes, pos: int, func):
def _handle_object(
cls, value: bytes, metadata: bytes, pos: int, func: Callable[[list[Tuple[str, int]]], Any]
) -> Any:
"""
Parses the variant object at position `pos`.
Calls `func` with a list of (key, value position) pairs of the object.
Expand Down Expand Up @@ -360,7 +362,7 @@ def _handle_object(cls, value: bytes, metadata: bytes, pos: int, func):
return func(key_value_pos_list)

@classmethod
def _handle_array(cls, value: bytes, pos: int, func):
def _handle_array(cls, value: bytes, pos: int, func: Callable[[list[int]], Any]) -> Any:
"""
Parses the variant array at position `pos`.
Calls `func` with a list of element positions of the array.
Expand Down

0 comments on commit 807d574

Please sign in to comment.