Skip to content

Commit

Permalink
[Pallas:TPU] Fix some stale/wrong skip conditions.
Browse files Browse the repository at this point in the history
Surprised that we didn't test f32 dot_general on TPU (?) Even tpu_ops_test doesn't exercise it.

PiperOrigin-RevId: 693777426
  • Loading branch information
WindQAQ authored and Google-ML-Automation committed Nov 6, 2024
1 parent 3df204a commit b6f5c95
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,6 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
tol = 1e-6
elif name == "exp2":
tol = 1e-6
elif jtu.test_device_matches(["tpu"]):
if not jtu.is_device_tpu_at_least(version=5) and False:
self.skipTest("TODO: not implemented on TPU v{3,4}")

def kernel(x_ref, y_ref):
y_ref[...] = func(x_ref[...])
Expand Down Expand Up @@ -1413,7 +1410,7 @@ def test_dot(self, size, dtype, trans_x, trans_y):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

if jtu.test_device_matches(["tpu"]):
if jtu.test_device_matches(["tpu"]) and trans_x:
self.skipTest("Not implemented: Transposed LHS")

@functools.partial(
Expand Down

0 comments on commit b6f5c95

Please sign in to comment.