Skip to content

Commit

Permalink
Fixed type annotations to also work on MacOS
Browse files Browse the repository at this point in the history
  • Loading branch information
Sm00thix committed Aug 4, 2024
1 parent a613c5f commit c3bf004
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
54 changes: 27 additions & 27 deletions ikpls/fast_cross_validation/numpy_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
scale_X: bool = True,
scale_Y: bool = True,
algorithm: int = 1,
dtype: Union[np.float16, np.float32, np.float64, np.float128] = np.float64,
dtype: np.floating = np.float64,
) -> None:
self.center_X = center_X
self.center_Y = center_Y
Expand Down Expand Up @@ -139,27 +139,27 @@ def _stateless_fit(
validation_indices: npt.NDArray[np.int_],
) -> Union[
tuple[
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
],
tuple[
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
],
]:
"""
Expand Down Expand Up @@ -433,13 +433,13 @@ def _stateless_fit(
def _stateless_predict(
self,
indices: npt.NDArray[np.int_],
B: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
training_X_mean: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
training_Y_mean: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
training_X_std: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
training_Y_std: npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]],
B: npt.NDArray[np.floating],
training_X_mean: npt.NDArray[np.floating],
training_Y_mean: npt.NDArray[np.floating],
training_X_std: npt.NDArray[np.floating],
training_Y_std: npt.NDArray[np.floating],
n_components: Union[None, int] = None,
) -> npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]:
) -> npt.NDArray[np.floating]:
"""
Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using
`n_components` components. If `n_components` is None, then predictions are
Expand Down Expand Up @@ -503,7 +503,7 @@ def _stateless_fit_predict_eval(
self,
validation_indices: npt.NDArray[np.int_],
metric_function: Callable[
[npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]], npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]], Any
[npt.NDArray[np.floating], npt.NDArray[np.floating]], Any
],
) -> Any:
"""
Expand Down
2 changes: 1 addition & 1 deletion ikpls/jax_ikpls_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self.X_std = None
self.Y_std = None

def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]]):
def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[np.floating]]):
"""
Display a warning message if the weight is close to zero.
Expand Down
4 changes: 2 additions & 2 deletions ikpls/numpy_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
scale_X: bool = True,
scale_Y: bool = True,
copy: bool = True,
dtype: Union[np.float16, np.float32, np.float64, np.float128] = np.float64,
dtype: np.floating = np.float64,
) -> None:
self.algorithm = algorithm
self.center_X = center_X
Expand Down Expand Up @@ -300,7 +300,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None:

def predict(
self, X: npt.ArrayLike, n_components: Union[None, int] = None
) -> npt.NDArray[Union[np.float16, np.float32, np.float64, np.float128]]:
) -> npt.NDArray[np.floating]:
"""
Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using
`n_components` components. If `n_components` is None, then predictions are
Expand Down

0 comments on commit c3bf004

Please sign in to comment.