From 86d3439be9af7e13c16f6b5065f89e9f1448963a Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Fri, 10 Jan 2025 20:27:54 +0800 Subject: [PATCH] refactor(typing): add Literal type hints for alpha_mode parameters in io.py and metric.py to enforce valid input values and improve type safety --- camtools/io.py | 4 ++-- camtools/metric.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/camtools/io.py b/camtools/io.py index 46a52263..722f0154 100644 --- a/camtools/io.py +++ b/camtools/io.py @@ -5,7 +5,7 @@ import cv2 import numpy as np from pathlib import Path -from typing import Union, Optional +from typing import Union, Optional, Literal from jaxtyping import UInt8, Float from . import sanity @@ -185,7 +185,7 @@ def imwrite_depth( def imread( im_path: Union[str, Path], - alpha_mode: Optional[str] = None, + alpha_mode: Optional[Literal["keep", "ignore", "white", "black"]] = None, ) -> Union[ Float[np.ndarray, "h w"], Float[np.ndarray, "h w 3"], diff --git a/camtools/metric.py b/camtools/metric.py index 86febfea..0bbf1784 100644 --- a/camtools/metric.py +++ b/camtools/metric.py @@ -6,7 +6,7 @@ from skimage.metrics import peak_signal_noise_ratio from skimage.metrics import structural_similarity from pathlib import Path -from typing import Tuple, Optional, Union +from typing import Tuple, Optional, Union, Literal from jaxtyping import Float from . import image @@ -200,7 +200,7 @@ def load_im_pd_im_gt_im_mask_for_eval( im_pd_path: Union[str, Path], im_gt_path: Union[str, Path], im_mask_path: Optional[Union[str, Path]] = None, - alpha_mode: str = "white", + alpha_mode: Literal["white", "keep"] = "white", ) -> Tuple[ Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w"] ]: