Project import generated by Copybara.
GitOrigin-RevId: a2973be
Authored by Default email; committed by jaxbara on Oct 11, 2024
Commit 673c14d (1 parent: 88bb402)
Showing 5 changed files with 12 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci-build.yaml

@@ -39,6 +39,8 @@ jobs:
   build:
     name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
     runs-on: ROCM-Ubuntu
+    container:
+      image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
     timeout-minutes: 60
     strategy:
       matrix:
@@ -56,6 +58,10 @@ jobs:
       num_generated_cases: 1
     steps:
       - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+      - name: Image Setup
+        run: |
+          apt update
+          apt install -y libssl-dev
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
         with:
4 changes: 2 additions & 2 deletions docs/aot.md

@@ -46,7 +46,7 @@ way. An example:
 >>> # Print lowered HLO
 >>> print(lowered.as_text())
 module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
     %c = stablehlo.constant dense<2> : tensor<i32>
     %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
     %1 = stablehlo.add %0, %arg1 : tensor<i32>
@@ -137,7 +137,7 @@ to invoke the resulting compiled function. Continuing with our example above:
 >>> # Lowered HLO, specialized to the *value* of the first argument (7)
 >>> print(lowered_with_x.as_text())
 module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
     %c = stablehlo.constant dense<14> : tensor<i32>
     %0 = stablehlo.add %c, %arg0 : tensor<i32>
     return %0 : tensor<i32>
4 changes: 2 additions & 2 deletions docs/export/export.md

@@ -44,7 +44,7 @@ Here is an example:
 (ShapedArray(float32[]),)
 
 >>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
-  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+  func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = ""}) {
 
 >>> # And you can serialize the Exported to a bytearray.
 >>> serialized: bytearray = exported.serialize()
@@ -206,7 +206,7 @@ as in the following example:
 >>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results)
 >>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
 module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
     %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
     return %0 : tensor<f32>
 }
4 changes: 2 additions & 2 deletions jax/_src/interpreters/mlir.py

@@ -1014,11 +1014,11 @@ def _to_physical_op_sharding(
 def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
                    aval: core.AbstractValue) -> str | None:
   if layout is None:
-    return "default"
+    return None
   if isinstance(layout, AutoLayout):
     return "auto"
   if aval is core.abstract_token:
-    return "default"
+    return None
   return str(layout._to_xla_layout(aval.dtype))  # type: ignore
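The change to `_to_xla_layout` can be sketched in isolation: unspecified layouts (and token avals) now map to `None` instead of the string `"default"`, so no layout attribute gets attached downstream. This is a simplified stand-alone sketch; `AutoLayout` and `to_xla_layout` here are illustrative stand-ins, not JAX's internal types:

```python
# Simplified sketch of the new _to_xla_layout behavior; AutoLayout is a
# stand-in marker class, not the real jax._src internal.
from dataclasses import dataclass

@dataclass
class AutoLayout:
    """Marker requesting that the compiler choose the layout."""

def to_xla_layout(layout, is_token=False):
    # Unspecified layouts now lower to None (previously the string
    # "default"), so no mhlo.layout_mode attribute is emitted.
    if layout is None:
        return None
    if isinstance(layout, AutoLayout):
        return "auto"
    if is_token:
        return None
    # A concrete layout is stringified, e.g. "{1,0}".
    return str(layout)
```

The `None` return lets callers skip attribute emission entirely, rather than writing out a redundant `"default"` marker on every argument and result.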
1 change: 0 additions & 1 deletion tests/layout_test.py

@@ -121,7 +121,6 @@ def f(x):
       return x.T
 
     lowered = jax.jit(f, in_shardings=None, out_shardings=None).lower(sds)
-    self.assertIn("default", lowered.as_text())
     compiled = lowered.compile()
     out = compiled(arr)
