diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 70ced6eb2be6..318df0b0bfcf 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -525,9 +525,6 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): tol = 1e-6 elif name == "exp2": tol = 1e-6 - elif jtu.test_device_matches(["tpu"]): - if not jtu.is_device_tpu_at_least(version=5) and False: - self.skipTest("TODO: not implemented on TPU v{3,4}") def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) @@ -1413,7 +1410,7 @@ def test_dot(self, size, dtype, trans_x, trans_y): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - if jtu.test_device_matches(["tpu"]): + if jtu.test_device_matches(["tpu"]) and trans_x: self.skipTest("Not implemented: Transposed LHS") @functools.partial(