Skip to content

Commit

Permalink
Disable flaky python callback test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575893965
  • Loading branch information
gnecula authored and jax authors committed Oct 23, 2023
1 parent 9b1a656 commit 9bc0439
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/jaxpr_effects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,8 @@ def f(x):
jax.effects_barrier()
self.assertListEqual(log, [2., 3.])

# TODO(b/307211483): Investigate failure
@jtu.skip_on_devices("tpu")
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
Expand Down Expand Up @@ -632,8 +634,8 @@ def g(x):
f(jnp.ones((500, 500)))
g(3.)
jax.effects_barrier()
x_, y_ = float(jnp.log(1.25e8)), 3.
expected_log = [x_, y_, x_, y_, x_, y_]
f_, g_ = float(jnp.log(1.25e8)), 3.
expected_log = [f_, g_, f_, g_, f_, g_]
self.assertListEqual(log, expected_log)

def test_different_threads_get_different_tokens(self):
Expand Down

0 comments on commit 9bc0439

Please sign in to comment.