From 807d57424dd2bdbe2000b0634b54cfd684dfd29d Mon Sep 17 00:00:00 2001 From: Gene Pang Date: Sun, 7 Apr 2024 21:58:05 -0700 Subject: [PATCH] mypy annotations --- python/pyspark/sql/types.py | 7 +++---- python/pyspark/sql/variant_utils.py | 22 ++++++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 50243255bf145..a84111725035b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1349,9 +1349,8 @@ def needConversion(self) -> bool: return True def fromInternal(self, obj: Dict) -> "VariantVal": - if obj is None or not all(key in obj for key in ["value", "metadata"]): - return - return VariantVal(obj["value"], obj["metadata"]) + if obj is not None and all(key in obj for key in ["value", "metadata"]): + return VariantVal(obj["value"], obj["metadata"]) class UserDefinedType(DataType): @@ -1510,7 +1509,7 @@ def __str__(self) -> str: return VariantUtils.to_json(self.value, self.metadata) def __repr__(self) -> str: - return "VariantVal(%s, %s)" % (self.value, self.metadata) + return "VariantVal(%r, %r)" % (self.value, self.metadata) def toPython(self) -> Any: """ diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py index 9f6eba098d125..9ca70365316d1 100644 --- a/python/pyspark/sql/variant_utils.py +++ b/python/pyspark/sql/variant_utils.py @@ -19,7 +19,7 @@ import json import struct from array import array -from typing import Any +from typing import Any, Callable, Dict, List, Tuple from pyspark.errors import PySparkValueError @@ -120,7 +120,7 @@ def _check_index(cls, pos: int, length: int) -> None: raise PySparkValueError(error_class="MALFORMED_VARIANT") @classmethod - def _get_type_info(cls, value: bytes, pos: int): + def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]: """ Returns the (basic_type, type_info) pair from the given position in the value. """ @@ -221,7 +221,7 @@ def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal: return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale)) @classmethod - def _get_type(cls, value: bytes, pos: int): + def _get_type(cls, value: bytes, pos: int) -> Any: """ Returns the Python type of the Variant at the given position. """ @@ -261,7 +261,7 @@ def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any: variant_type = cls._get_type(value, pos) if variant_type == dict: - def handle_object(key_value_pos_list): + def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> str: key_value_list = [ json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos) for (key, value_pos) in key_value_pos_list @@ -271,7 +271,7 @@ def handle_object(key_value_pos_list): return cls._handle_object(value, metadata, pos, handle_object) elif variant_type == array: - def handle_array(value_pos_list): + def handle_array(value_pos_list: list[int]) -> str: value_list = [ cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list ] @@ -293,7 +293,7 @@ def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any: variant_type = cls._get_type(value, pos) if variant_type == dict: - def handle_object(key_value_pos_list): + def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> Dict[str, Any]: key_value_list = [ (key, cls._to_python(value, metadata, value_pos)) for (key, value_pos) in key_value_pos_list @@ -303,7 +303,7 @@ def handle_object(key_value_pos_list): return cls._handle_object(value, metadata, pos, handle_object) elif variant_type == array: - def handle_array(value_pos_list): + def handle_array(value_pos_list: list[int]) -> List[Any]: value_list = [ cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list ] @@ -314,7 +314,7 @@ def handle_array(value_pos_list): return cls._get_scalar(variant_type, value, metadata, pos) @classmethod - def _get_scalar(cls, variant_type, value: bytes, metadata: bytes, pos: int) -> Any: + def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any: if isinstance(None, variant_type): return None elif variant_type == bool: @@ -331,7 +331,9 @@ def _get_scalar(cls, variant_type, value: bytes, metadata: bytes, pos: int) -> A raise PySparkValueError(error_class="MALFORMED_VARIANT") @classmethod - def _handle_object(cls, value: bytes, metadata: bytes, pos: int, func): + def _handle_object( + cls, value: bytes, metadata: bytes, pos: int, func: Callable[[list[Tuple[str, int]]], Any] + ) -> Any: """ Parses the variant object at position `pos`. Calls `func` with a list of (key, value position) pairs of the object. @@ -360,7 +362,7 @@ def _handle_object(cls, value: bytes, metadata: bytes, pos: int, func): return func(key_value_pos_list) @classmethod - def _handle_array(cls, value: bytes, pos: int, func): + def _handle_array(cls, value: bytes, pos: int, func: Callable[[list[int]], Any]) -> Any: """ Parses the variant array at position `pos`. Calls `func` with a list of element positions of the array.