diff --git a/camtools/io.py b/camtools/io.py index 46a52263..722f0154 100644 --- a/camtools/io.py +++ b/camtools/io.py @@ -5,7 +5,7 @@ import cv2 import numpy as np from pathlib import Path -from typing import Union, Optional +from typing import Union, Optional, Literal from jaxtyping import UInt8, Float from . import sanity @@ -185,7 +185,7 @@ def imwrite_depth( def imread( im_path: Union[str, Path], - alpha_mode: Optional[str] = None, + alpha_mode: Optional[Literal["keep", "ignore", "white", "black"]] = None, ) -> Union[ Float[np.ndarray, "h w"], Float[np.ndarray, "h w 3"], diff --git a/camtools/metric.py b/camtools/metric.py index 86febfea..0bbf1784 100644 --- a/camtools/metric.py +++ b/camtools/metric.py @@ -6,7 +6,7 @@ from skimage.metrics import peak_signal_noise_ratio from skimage.metrics import structural_similarity from pathlib import Path -from typing import Tuple, Optional, Union +from typing import Tuple, Optional, Union, Literal from jaxtyping import Float from . import image @@ -200,7 +200,7 @@ def load_im_pd_im_gt_im_mask_for_eval( im_pd_path: Union[str, Path], im_gt_path: Union[str, Path], im_mask_path: Optional[Union[str, Path]] = None, - alpha_mode: str = "white", + alpha_mode: Literal["white", "keep"] = "white", ) -> Tuple[ Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w 3"], Float[np.ndarray, "h w"] ]: