forked from jax-ml/jax
CI: 01/08/25 upstream sync #196
Merged
Previously with 4 heads the reference function `ref` would allocate 32 GiB since it materializes large intermediate tensors. That causes CI on an 80GB H100 to run out of memory when 4 tests run in parallel. `num_q_heads=2` allows us to test multiple heads while cutting memory in half.
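The arithmetic behind this change can be sketched as follows. This is a rough illustration only: it assumes the reference function's memory scales linearly with the number of heads (as "cutting memory in half" implies), reads "80GB" as decimal gigabytes, and ignores every other allocation on the device. The helper names are hypothetical.

```python
GIB = 2**30
H100_BYTES = 80 * 10**9  # assumption: "80GB" means decimal gigabytes


def ref_peak_bytes(num_q_heads: int, bytes_at_4_heads: int = 32 * GIB) -> int:
    """Estimated allocation of `ref`, assuming linear scaling with head count."""
    return bytes_at_4_heads * num_q_heads // 4


def fits_on_device(num_q_heads: int, parallel_tests: int = 4) -> bool:
    """True if all parallel tests' reference allocations fit at the same time."""
    return parallel_tests * ref_peak_bytes(num_q_heads) <= H100_BYTES
```

Under these assumptions, four parallel tests with 4 heads need 128 GiB and do not fit, while the same tests with 2 heads need 64 GiB and do.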
It turns out that the backend is rarely needed when lowering, e.g., for lowering callbacks. Whenever we need the backend for lowering, we must be in single-platform lowering mode (`len(platforms) == 1`) and we can look up the backend from `platforms[0]`. However, in some rare cases we can have a custom `XlaBackend` whose platform matches `platforms[0]`. We rename `backend_or_name` to just `backend` and we restrict its type to be an optional `XlaBackend` (not a platform string). PiperOrigin-RevId: 712926140
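The lookup rule described in this commit message might look roughly like the sketch below. It is not JAX's implementation: `FakeBackend`, `_DEFAULT_BACKENDS`, and `backend_for_lowering` are hypothetical stand-ins written only to illustrate the single-platform requirement and the optional custom backend.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class FakeBackend:
    """Hypothetical stand-in for an XlaBackend; it only carries a platform."""
    platform: str


# Hypothetical registry of default backends, keyed by platform name.
_DEFAULT_BACKENDS = {"cpu": FakeBackend("cpu"), "gpu": FakeBackend("gpu")}


def backend_for_lowering(platforms: Sequence[str],
                         backend: Optional[FakeBackend] = None) -> FakeBackend:
    """Returns the backend to use when lowering genuinely needs one."""
    if len(platforms) != 1:
        raise ValueError("a backend is only available in single-platform lowering")
    if backend is not None:
        # A custom backend is accepted only if its platform matches platforms[0].
        if backend.platform != platforms[0]:
            raise ValueError("custom backend does not match the lowering platform")
        return backend
    # Otherwise look the backend up from the single lowering platform.
    return _DEFAULT_BACKENDS[platforms[0]]
```

Passing a backend object rather than a platform string keeps the argument's meaning unambiguous, which is the point of restricting the type.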
PiperOrigin-RevId: 712929769
PiperOrigin-RevId: 712930537
…ton lowering. Addresses: jax-ml#25714 PiperOrigin-RevId: 712930760
http://github.com/openxla/xla/commit/9b8f679bd219079ede47398e01c0d9863ef0d6e3. PiperOrigin-RevId: 712940327
…tion`, which ensures the determinism in the generated `sdy.manual_computation`. PiperOrigin-RevId: 712973327
PiperOrigin-RevId: 712979906
(1.11.0 was yanked from PyPI because of licensing problems, so 1.11.1 is the oldest 1.11 release.) PiperOrigin-RevId: 713073731
PiperOrigin-RevId: 713075852
… None and not UNCONSTRAINED because axis_types + pspec give the full picture. PiperOrigin-RevId: 713105375
PiperOrigin-RevId: 713110147
Log such events for log_elapsed_time. Durations are not replaced with it because record_event_duration_secs() appears to be widely used outside of JAX's own code. PiperOrigin-RevId: 713167192
…ng-doc PiperOrigin-RevId: 713170813
PiperOrigin-RevId: 713202387
PiperOrigin-RevId: 713222111
It was cmp + iota before. PiperOrigin-RevId: 713240888
PiperOrigin-RevId: 713247480
PiperOrigin-RevId: 713259707
…outs for 16-bit types on v6. This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling, by biasing the layout inference to prefer the latter. PiperOrigin-RevId: 713270421
…ify the number of CPU devices directly. This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS. In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack and prepares for removing the remaining uses of setUpModule as part of the work to parallelize the test suite with threads. PiperOrigin-RevId: 713272197
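The "sets or increases" behavior can be modeled with a few lines of plain Python. This is a toy model, not the code behind jtu.request_cpu_devices: the module-level variable stands in for the real flag, and the function body is an assumption about what "sets or increases" means.

```python
# Hypothetical stand-in for the flag that holds the number of CPU devices.
_requested_cpu_devices = 1


def request_cpu_devices(n: int) -> int:
    """Raises the requested device count to at least `n`; never lowers it."""
    global _requested_cpu_devices
    if n < 1:
        raise ValueError("device count must be positive")
    _requested_cpu_devices = max(_requested_cpu_devices, n)
    return _requested_cpu_devices
```

Because a request can only raise the value, test modules can each ask for what they need in any order without a context stack to restore earlier values.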
PiperOrigin-RevId: 713283207
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism. If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally. PiperOrigin-RevId: 713296722
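A minimal sketch of the pattern, using only the standard library: per-test work goes in setUp(), and shared work is done globally behind a cache. The fixture and test names are hypothetical and are not taken from the JAX test suite.

```python
import functools
import unittest


@functools.lru_cache(maxsize=None)
def shared_table():
    """Hypothetical expensive fixture, cached after its first call."""
    return tuple(i * i for i in range(8))


class SquaresTest(unittest.TestCase):

    def setUp(self):
        super().setUp()
        # Work that might previously have lived in setUpClass or setUpModule.
        # It runs before every test, so test order and grouping do not matter.
        self.table = shared_table()
        self.scratch = []
        self.addCleanup(self.scratch.clear)

    def test_first_entries(self):
        self.assertEqual(self.table[:3], (0, 1, 4))

    def test_scratch_starts_empty(self):
        self.scratch.append(1)
        self.assertEqual(self.scratch, [1])
```

Each test builds its own state, so no test depends on a class-level or module-level hook having run first.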
…threads. This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes. This approach is broadly inspired by https://github.com/testing-cabal/testtools/blob/a6d205dd4cac51f3cf9267978d39fc877103aacb/testtools/testsuite.py#L113 and by unittest-ft. We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool. We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib. PiperOrigin-RevId: 713312937
Previously, if something went wrong during JAX lowering, verification did not catch it. Instead the pass failed, which made the error message difficult to read and incorrectly pointed to the pass as the source of the error. For example:
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
    pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error: ... 'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
http://github.com/openxla/xla/commit/1b9969348628c9a435fc1d6973acab2e5d55849a. PiperOrigin-RevId: 713320368
charleshofer
approved these changes
Jan 8, 2025
Daily sync with upstream