diff --git a/examples/advection-esweno-conservation.py b/examples/advection-esweno-conservation.py index 4e7b366..3589c40 100644 --- a/examples/advection-esweno-conservation.py +++ b/examples/advection-esweno-conservation.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: MIT import pathlib -from typing import cast import jax import jax.numpy as jnp @@ -32,7 +31,7 @@ def cosine_pulse( r = (0.5 + 0.5 * jnp.cos(w * (x - xc))) ** p mask = (jnp.abs(x - xc) < sigma).astype(x.dtype) - return cast(Array, r * mask) + return r * mask def main( diff --git a/pyshocks/burgers/schemes.py b/pyshocks/burgers/schemes.py index 9716480..6682b4c 100644 --- a/pyshocks/burgers/schemes.py +++ b/pyshocks/burgers/schemes.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: MIT from dataclasses import dataclass, field, replace -from typing import ClassVar, cast +from typing import ClassVar import jax.numpy as jnp @@ -324,7 +324,7 @@ def hesthaven_limiter(u: Array, *, variant: int = 1) -> Array: else: raise ValueError(f"Unknown variant: {variant!r}") - return cast(Array, phi) + return phi @numerical_flux.register(SSMUSCL) diff --git a/pyshocks/convolve.py b/pyshocks/convolve.py index f5ca06d..6687678 100644 --- a/pyshocks/convolve.py +++ b/pyshocks/convolve.py @@ -18,7 +18,6 @@ from __future__ import annotations import enum -from typing import cast import jax.numpy as jnp @@ -123,4 +122,4 @@ def convolve1d( result = u[n:-n] assert result.shape == ary.shape - return cast(Array, result) + return result diff --git a/pyshocks/sbp.py b/pyshocks/sbp.py index e9ffd89..add5c56 100644 --- a/pyshocks/sbp.py +++ b/pyshocks/sbp.py @@ -105,7 +105,7 @@ import enum from dataclasses import dataclass, replace from functools import singledispatch -from typing import Any, cast +from typing import Any import jax import jax.numpy as jnp @@ -363,7 +363,7 @@ def make_sbp_mattsson2012_second_derivative( assert jnp.linalg.norm(M - M.T) < 1.0e-8 assert jnp.linalg.norm(jnp.sum(M, axis=1)) < 1.0e-8 - return cast(Array, invP @ (-M + BS)) + return invP @ (-M + BS) @singledispatch @@ -562,7 +562,7 @@ def make_boundary(b1: Array, b2: Array, b3: Array, b4: Array) -> Array: n, m = mb_r.shape M = M.at[-n:, -m:].set(mb_r) - return cast(Array, M / dx) + return M / dx def make_sbp_21_norm_stencil(dtype: Any = None) -> Stencil: @@ -780,10 +780,8 @@ def make_sbp_42_second_derivative_r_matrix( C34 = jnp.diag(make_sbp_matrix_from_stencil(bc, n, c34)) C44 = jnp.diag(make_sbp_matrix_from_stencil(bc, n, c44)) - return cast( - Array, - dx**5 / 18 * D34.T @ C34 @ B34 @ D34 - + dx**7 / 144 * D44.T @ C44 @ B44 @ D44, + return ( + dx**5 / 18 * D34.T @ C34 @ B34 @ D34 + dx**7 / 144 * D44.T @ C44 @ B44 @ D44 ) @@ -1042,7 +1040,7 @@ def make_boundary( n, m = mb_r.shape M = M.at[-n:, -m:].set(mb_r) - return cast(Array, -M / dx) + return -M / dx def make_sbp_42_norm_stencil(dtype: Any = None) -> Stencil: diff --git a/pyshocks/timestepping.py b/pyshocks/timestepping.py index 35717fb..235fe40 100644 --- a/pyshocks/timestepping.py +++ b/pyshocks/timestepping.py @@ -107,7 +107,7 @@ def step( """ if tfinal is None: # NOTE: can't set this to jnp.inf because we have a debug check for it - tfinal = jnp.finfo(u0.dtype).max + tfinal = jnp.finfo(u0.dtype).max # type: ignore[no-untyped-call] m = 0 t = jnp.array(tstart, dtype=u0.dtype) diff --git a/pyshocks/tools.py b/pyshocks/tools.py index bad1344..3d48157 100644 --- a/pyshocks/tools.py +++ b/pyshocks/tools.py @@ -173,7 +173,7 @@ def estimate_order_of_convergence(x: Array, y: Array) -> tuple[Scalar, Scalar]: if x.size <= 1: raise RuntimeError("Need at least two values to estimate order.") - eps = jnp.finfo(x.dtype).eps + eps = jnp.finfo(x.dtype).eps # type: ignore[no-untyped-call] c = jnp.polyfit(jnp.log10(x + eps), jnp.log10(y + eps), 1) return 10 ** c[-1], c[-2] @@ -272,7 +272,7 @@ def satisfied( _, error = self._history if atol is None: - atol = 1.0e2 * jnp.finfo(error.dtype).eps + atol = 1.0e2 * jnp.finfo(error.dtype).eps # type: ignore[no-untyped-call] return bool(self.estimated_order >= (order - slack) or jnp.max(error) < atol) @@ -877,7 +877,7 @@ def _anim_func(n: int) -> tuple[Any, ...]: if legends is not None: ax.set_ylabel(ylabel) ax.set_xlim((float(x[0]), float(x[-1]))) - ax.set_ylim((ymin - 0.1 * jnp.abs(ymin), ymax + 0.1 * jnp.abs(ymax))) + ax.set_ylim((float(ymin - 0.1 * jnp.abs(ymin)), float(ymax + 0.1 * jnp.abs(ymax)))) ax.margins(0.05) if legends: diff --git a/tests/test_finite_difference.py b/tests/test_finite_difference.py index 4d69e9f..16ba857 100644 --- a/tests/test_finite_difference.py +++ b/tests/test_finite_difference.py @@ -333,7 +333,7 @@ def test_finite_difference_taylor_stencil(*, visualize: bool = False) -> None: ax.set_xlabel("$k h$") ax.set_ylabel(r"$\tilde{k} h$") ax.set_xlim((0.0, jnp.pi)) - ax.set_ylim((0.0, sign * jnp.pi**s.derivative)) + ax.set_ylim((0.0, float(sign * jnp.pi**s.derivative))) fig.savefig(f"finite_difference_wavenumber_{s.derivative}_{s.order}") fig.clf() diff --git a/tests/test_sbp.py b/tests/test_sbp.py index 68c32a9..1f3d656 100644 --- a/tests/test_sbp.py +++ b/tests/test_sbp.py @@ -85,7 +85,7 @@ def test_sbp_matrices(name: str, bc: BoundaryType, *, visualize: bool = False) - fig.clf() # NOTE: allow negative values larger than eps because floating point.. - mask = jnp.real(s) > -jnp.finfo(dtype).eps + mask = jnp.real(s) > -jnp.finfo(dtype).eps # type: ignore[no-untyped-call] assert jnp.all(mask), jnp.real(s[~mask]) # }}}