diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5538ce38baa3..5ace4b5ecf18 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1450,14 +1450,6 @@ def testLuBatching(self, shape, dtype): self.assertAllClose(ls, actual_ls, rtol=5e-6) self.assertAllClose(us, actual_us) - @jtu.skip_on_devices("cpu", "tpu") - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def testBatchedLuOverflow(self): - # see https://github.com/jax-ml/jax/issues/24843 - x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32) - lu, _, _ = lax.linalg.lu(x) - self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9)) - @jtu.skip_on_devices("cpu", "tpu") @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument")