CI: 01/08/25 upstream sync #196

Merged
merged 34 commits into from
Jan 8, 2025

Conversation

github-actions[bot]

@github-actions github-actions bot commented Jan 8, 2025

Daily sync with upstream

zhenying-liu and others added 30 commits December 19, 2024 14:47
Previously with 4 heads the reference function `ref` would allocate 32 GiB since it materializes large intermediate tensors. That causes CI on an 80GB H100 to run out of memory when 4 tests run in parallel. `num_q_heads=2` allows us to test multiple heads while cutting memory in half.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).
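The lookup rule described above can be sketched roughly as follows. This is only an illustration: `FakeBackend`, `_DEFAULT_BACKENDS`, and `resolve_backend` are hypothetical names invented here, not JAX APIs, and the real lowering code may organize this differently.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class FakeBackend:
    """Hypothetical stand-in for an XLA backend object."""
    platform: str


# Hypothetical registry of default backends, keyed by platform name.
_DEFAULT_BACKENDS = {"cpu": FakeBackend("cpu"), "gpu": FakeBackend("gpu")}


def resolve_backend(platforms: Sequence[str],
                    backend: Optional[FakeBackend] = None) -> FakeBackend:
    """Returns the backend to use for lowering (illustrative only)."""
    # A backend is only meaningful in single-platform lowering mode.
    if len(platforms) != 1:
        raise ValueError("backend lookup requires len(platforms) == 1")
    if backend is not None:
        # A custom backend must agree with the single lowering platform.
        if backend.platform != platforms[0]:
            raise ValueError("custom backend does not match platforms[0]")
        return backend
    # Otherwise fall back to the default backend for platforms[0].
    return _DEFAULT_BACKENDS[platforms[0]]
```

Note that the optional argument is a backend object, never a platform string, which mirrors the type restriction in the commit message.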

PiperOrigin-RevId: 712926140
PiperOrigin-RevId: 712930537
…ton lowering.

Addresses: jax-ml#25714
PiperOrigin-RevId: 712930760
…tion`, which ensures the determinism in the generated `sdy.manual_computation`.

PiperOrigin-RevId: 712973327
(1.11.0 was yanked from PyPI because of licensing problems, so 1.11.1 is the oldest 1.11 release.)

PiperOrigin-RevId: 713073731
PiperOrigin-RevId: 713075852
… None and not UNCONSTRAINED because axis_types + pspec give the full picture.

PiperOrigin-RevId: 713105375
Log such events for log_elapsed_time.

The rationale for not replacing durations with it is that it appears that
record_event_duration_secs() is widely used outside of the code of JAX itself.

PiperOrigin-RevId: 713167192
It was cmp + iota before.

PiperOrigin-RevId: 713240888
…outs for 16-bit types on v6

This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
…ify the number of CPU directly.

This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
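As an illustration of that pattern, the sketch below moves shared work out of `setUpClass` into a globally cached helper and keeps per-test state in `setUp()`. The names `shared_fixture` and `ExampleTest` are hypothetical and do not come from the JAX test suite.

```python
import functools
import unittest


@functools.lru_cache(maxsize=None)
def shared_fixture():
    """Read-only setup computed lazily and cached globally."""
    return tuple(range(4))


class ExampleTest(unittest.TestCase):

    def setUp(self):
        super().setUp()
        # Work that used to live in setUpClass: reuse the cached global value.
        self.data = shared_fixture()
        # Mutable state is rebuilt for every test, so tests stay independent.
        self.scratch = []

    def test_sum(self):
        self.assertEqual(sum(self.data), 6)

    def test_scratch_is_fresh(self):
        self.scratch.append(1)
        self.assertEqual(self.scratch, [1])
```

Each test now depends only on its own `setUp()`, so the order and the thread in which tests run no longer matter for correctness.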

PiperOrigin-RevId: 713296722
hawkinsp and others added 3 commits January 8, 2025 09:11
…threads.

This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.

This approach is broadly inspired by https://github.com/testing-cabal/testtools/blob/a6d205dd4cac51f3cf9267978d39fc877103aacb/testtools/testsuite.py#L113 and by unittest-ft.

We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.

We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.
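To show why a reader-writer lock fits here, the sketch below uses a pure-Python lock built on `threading.Condition` rather than the absl::Mutex bindings the commit adds. Ordinary tests take the lock in shared mode, while a thread-hostile test takes it exclusively. `ReaderWriterLock`, `_test_lock`, and `thread_friendly_test` are hypothetical names for this example, and this simple version does not prevent writer starvation.

```python
import functools
import threading


class ReaderWriterLock:
    """Minimal reader-writer lock: many readers or exactly one writer."""

    def __init__(self):
        self._cond = threading.Condition()
        self._readers = 0
        self._writer = False

    def acquire_read(self):
        with self._cond:
            while self._writer:
                self._cond.wait()
            self._readers += 1

    def release_read(self):
        with self._cond:
            self._readers -= 1
            if self._readers == 0:
                self._cond.notify_all()

    def acquire_write(self):
        with self._cond:
            while self._writer or self._readers > 0:
                self._cond.wait()
            self._writer = True

    def release_write(self):
        with self._cond:
            self._writer = False
            self._cond.notify_all()


_test_lock = ReaderWriterLock()


def thread_hostile_test(fn):
    """Runs `fn` while no other wrapped test is executing."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _test_lock.acquire_write()
        try:
            return fn(*args, **kwargs)
        finally:
            _test_lock.release_write()
    return wrapper


def thread_friendly_test(fn):
    """Runs `fn` concurrently with other thread-friendly tests."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _test_lock.acquire_read()
        try:
            return fn(*args, **kwargs)
        finally:
            _test_lock.release_read()
    return wrapper
```

A plain mutex would serialize every test; the shared mode is what lets ordinary tests keep running in parallel while still giving hostile tests exclusive access.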

PiperOrigin-RevId: 713312937
Previously, if something went wrong during JAX lowering, the pass would fail instead of verification catching the problem, making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example:
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
    pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
@github-actions github-actions bot enabled auto-merge January 8, 2025 17:46
@github-actions github-actions bot merged commit 90eab82 into rocm-main Jan 8, 2025
8 checks passed
@charleshofer charleshofer deleted the ci-upstream-sync-82_1 branch January 8, 2025 21:20