diff --git a/ikpls/fast_cross_validation/numpy_ikpls.py b/ikpls/fast_cross_validation/numpy_ikpls.py index 8c0d241..d5b651c 100644 --- a/ikpls/fast_cross_validation/numpy_ikpls.py +++ b/ikpls/fast_cross_validation/numpy_ikpls.py @@ -84,7 +84,7 @@ def __init__( scale_X: bool = True, scale_Y: bool = True, algorithm: int = 1, - dtype: Union[np.float16, np.float32, np.float64, np.float128] = np.float64, + dtype: np.floating = np.float64, ) -> None: self.center_X = center_X self.center_Y = center_Y @@ -139,27 +139,27 @@ def _stateless_fit( validation_indices: npt.NDArray[np.int_], ) -> Union[ tuple[ - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], ], tuple[ - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], ], ]: """ @@ -433,13 +433,13 @@ def _stateless_fit( def _stateless_predict( self, indices: npt.NDArray[np.int_], - B: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - training_X_mean: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - training_Y_mean: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - training_X_std: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], - training_Y_std: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], + B: npt.NDArray[np.floating], + training_X_mean: npt.NDArray[np.floating], + training_Y_mean: npt.NDArray[np.floating], + training_X_std: npt.NDArray[np.floating], + training_Y_std: npt.NDArray[np.floating], n_components: Union[None, int] = None, - ) -> npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]: + ) -> npt.NDArray[np.floating]: """ Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using `n_components` components. If `n_components` is None, then predictions are @@ -503,7 +503,7 @@ def _stateless_fit_predict_eval( self, validation_indices: npt.NDArray[np.int_], metric_function: Callable[ - [npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]], Any + [npt.NDArray[np.floating], npt.NDArray[np.floating]], Any ], ) -> Any: """ diff --git a/ikpls/jax_ikpls_base.py b/ikpls/jax_ikpls_base.py index 12cefef..39ded28 100644 --- a/ikpls/jax_ikpls_base.py +++ b/ikpls/jax_ikpls_base.py @@ -115,7 +115,7 @@ def __init__( self.X_std = None self.Y_std = None - def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]]): + def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[np.floating]]): """ Display a warning message if the weight is close to zero. diff --git a/ikpls/numpy_ikpls.py b/ikpls/numpy_ikpls.py index 0d895c2..679577a 100644 --- a/ikpls/numpy_ikpls.py +++ b/ikpls/numpy_ikpls.py @@ -79,7 +79,7 @@ def __init__( scale_X: bool = True, scale_Y: bool = True, copy: bool = True, - dtype: Union[np.float16, np.float32, np.float64, np.float128] = np.float64, + dtype: np.floating = np.float64, ) -> None: self.algorithm = algorithm self.center_X = center_X @@ -300,7 +300,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None: def predict( self, X: npt.ArrayLike, n_components: Union[None, int] = None - ) -> npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]: + ) -> npt.NDArray[np.floating]: """ Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using `n_components` components. If `n_components` is None, then predictions are