RnnTest.test_struct_encoding_determinism failing #25825

Open
belitskiy opened this issue Jan 10, 2025 · 2 comments
Labels: bug

@belitskiy
Contributor

Description

FAIL: //tests:experimental_rnn_test_gpu (shard 6 of 15) (see /root/.cache/bazel/_bazel_root/0fe1f5b2bc0694265667ca527ea660d2/execroot/__main__/bazel-out/k8-opt/testlogs/tests/experimental_rnn_test_gpu/shard_6_of_15/test.log)
INFO: From Testing //tests:experimental_rnn_test_gpu (shard 6 of 15):
==================== Test output for //tests:experimental_rnn_test_gpu (shard 6 of 15):
Running test /root/.cache/bazel/_bazel_root/0fe1f5b2bc0694265667ca527ea660d2/execroot/__main__/bazel-out/k8-opt/bin/tests/experimental_rnn_test_gpu.runfiles/__main__/tests/experimental_rnn_test_gpu --jax_test_dut=gpu --jax_platform_name=gpu on accelerator 2
Running tests under Python 3.10.15: /root/.cache/bazel/_bazel_root/0fe1f5b2bc0694265667ca527ea660d2/execroot/__main__/bazel-out/k8-opt/bin/tests/experimental_rnn_test_gpu.runfiles/python_x86_64-unknown-linux-gnu/bin/python3
[ RUN      ] RnnTest.test_lstm5 (batch_size=1, seq_len=1, input_size=2, hidden_size=1, num_layers=1, bidirectional=False)
INFO:2025-01-10 12:34:51,786:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0110 12:34:51.786977 140250971793216 xla_bridge.py:945] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-10 12:34:51,802:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
I0110 12:34:51.802578 140250971793216 xla_bridge.py:945] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
[       OK ] RnnTest.test_lstm5 (batch_size=1, seq_len=1, input_size=2, hidden_size=1, num_layers=1, bidirectional=False)
[ RUN      ] RnnTest.test_struct_encoding_determinism
[  FAILED  ] RnnTest.test_struct_encoding_determinism
======================================================================
FAIL: test_struct_encoding_determinism (__main__.RnnTest)
RnnTest.test_struct_encoding_determinism
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/root/.cache/bazel/_bazel_root/0fe1f5b2bc0694265667ca527ea660d2/execroot/__main__/bazel-out/k8-opt/bin/tests/experimental_rnn_test_gpu.runfiles/__main__/jax/_src/test_util.py", line 489, in test_method_wrapper
    return test_method(self, *args, **kwargs)
  File "/root/.cache/bazel/_bazel_root/0fe1f5b2bc0694265667ca527ea660d2/execroot/__main__/bazel-out/k8-opt/bin/tests/experimental_rnn_test_gpu.runfiles/__main__/tests/experimental_rnn_test.py", line 215, in test_struct_encoding_determinism
    self.assertIn('stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) '
AssertionError: 'stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) {api_version = 2 : i32, backend_config = "\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"}' not found in 'module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n  func.func public @main(%arg0: tensor<2xui32>, %arg1: tensor<2xui32>, %arg2: tensor<2xui32>, %arg3: tensor<2xui32>) -> (tensor<1x1x1xf32> {jax.result_info = "[0]"}, tensor<1x1x1xf32> {jax.result_info = "[1]"}, tensor<1x1x1xf32> {jax.result_info = "[2]"}) {\n    %0 = call @_normal(%arg0) : (tensor<2xui32>) -> tensor<1x1x1xf32>\n    %1 = call @_normal(%arg1) : (tensor<2xui32>) -> tensor<1x1x1xf32>\n    %2 = call @_normal(%arg2) : (tensor<2xui32>) -> tensor<1x1x1xf32>\n    %c = stablehlo.constant dense<1> : tensor<i32>\n    %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<1xi32>\n    %4 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<1xi32>\n    %5 = stablehlo.multiply %3, %4 : tensor<1xi32>\n    %cst = stablehlo.constant dense<-1.000000e+00> : tensor<f32>\n    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>\n    %6 = call @_uniform_0(%arg3, %cst, %cst_0) : (tensor<2xui32>, tensor<f32>, tensor<f32>) -> tensor<16xf32>\n    %7:5 = stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) {api_version = 2 : i32, backend_config = "\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\01\\DD\\BA@\\03\\80\\00@\\01\\00\\00"} : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<16xf32>, tensor<1xi32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<8389440xf32>, tensor<320xf32>)\n    return %7#0, %7#1, %7#2 : tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>\n  }\n  func.func private @_normal(%arg0: tensor<2xui32>) -> tensor<1x1x1xf32> {\n    %0 = call 
@_normal_real(%arg0) : (tensor<2xui32>) -> tensor<1x1x1xf32>\n    return %0 : tensor<1x1x1xf32>\n  }\n  func.func private @_normal_real(%arg0: tensor<2xui32>) -> tensor<1x1x1xf32> {\n    %cst = stablehlo.constant dense<-0.99999994> : tensor<f32>\n    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>\n    %0 = call @_uniform(%arg0, %cst, %cst_0) : (tensor<2xui32>, tensor<f32>, tensor<f32>) -> tensor<1x1x1xf32>\n    %1 = chlo.erf_inv %0 : tensor<1x1x1xf32> -> tensor<1x1x1xf32>\n    %cst_1 = stablehlo.constant dense<1.41421354> : tensor<f32>\n    %2 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<1x1x1xf32>\n    %3 = stablehlo.multiply %2, %1 : tensor<1x1x1xf32>\n    return %3 : tensor<1x1x1xf32>\n  }\n  func.func private @_uniform(%arg0: tensor<2xui32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x1x1xf32> {\n    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<1x1x1xf32>\n    %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor<f32>) -> tensor<1x1x1xf32>\n    %c = stablehlo.constant dense<0> : tensor<1xui32>\n    %2 = stablehlo.iota dim = 0 : tensor<1xui32>\n    %3 = stablehlo.slice %arg0 [0:1] : (tensor<2xui32>) -> tensor<1xui32>\n    %4 = stablehlo.reshape %3 : (tensor<1xui32>) -> tensor<ui32>\n    %5 = stablehlo.slice %arg0 [1:2] : (tensor<2xui32>) -> tensor<1xui32>\n    %6 = stablehlo.reshape %5 : (tensor<1xui32>) -> tensor<ui32>\n    %7 = stablehlo.concatenate %2, %c, dim = 0 : (tensor<1xui32>, tensor<1xui32>) -> tensor<2xui32>\n    %8 = stablehlo.slice %7 [0:1] : (tensor<2xui32>) -> tensor<1xui32>\n    %9 = stablehlo.slice %7 [1:2] : (tensor<2xui32>) -> tensor<1xui32>\n    %10:2 = call @threefry2x32(%4, %6, %8, %9) : (tensor<ui32>, tensor<ui32>, tensor<1xui32>, tensor<1xui32>) -> (tensor<1xui32>, tensor<1xui32>)\n    %11 = stablehlo.concatenate %10#0, %10#1, dim = 0 : (tensor<1xui32>, tensor<1xui32>) -> tensor<2xui32>\n    %12 = stablehlo.slice %11 [0:1] : (tensor<2xui32>) -> 
tensor<1xui32>\n    %13 = stablehlo.reshape %12 : (tensor<1xui32>) -> tensor<1x1x1xui32>\n    %c_0 = stablehlo.constant dense<9> : tensor<ui32>\n    %14 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<1x1x1xui32>\n    %15 = stablehlo.shift_right_logical %13, %14 : tensor<1x1x1xui32>\n    %c_1 = stablehlo.constant dense<1065353216> : tensor<ui32>\n    %16 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<1x1x1xui32>\n    %17 = stablehlo.or %15, %16 : tensor<1x1x1xui32>\n    %18 = stablehlo.bitcast_convert %17 : (tensor<1x1x1xui32>) -> tensor<1x1x1xf32>\n    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>\n    %19 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x1x1xf32>\n    %20 = stablehlo.subtract %18, %19 : tensor<1x1x1xf32>\n    %21 = stablehlo.subtract %1, %0 : tensor<1x1x1xf32>\n    %22 = stablehlo.multiply %20, %21 : tensor<1x1x1xf32>\n    %23 = stablehlo.add %22, %0 : tensor<1x1x1xf32>\n    %24 = stablehlo.maximum %0, %23 : tensor<1x1x1xf32>\n    return %24 : tensor<1x1x1xf32>\n  }\n  func.func private @threefry2x32(%arg0: tensor<ui32>, %arg1: tensor<ui32>, %arg2: tensor<1xui32>, %arg3: tensor<1xui32>) -> (tensor<1xui32>, tensor<1xui32>) {\n    %0 = stablehlo.xor %arg0, %arg1 : tensor<ui32>\n    %c = stablehlo.constant dense<466688986> : tensor<ui32>\n    %1 = stablehlo.xor %0, %c : tensor<ui32>\n    %2 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %3 = stablehlo.add %arg2, %2 : tensor<1xui32>\n    %4 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %5 = stablehlo.add %arg3, %4 : tensor<1xui32>\n    %6 = stablehlo.add %3, %5 : tensor<1xui32>\n    %c_0 = stablehlo.constant dense<13> : tensor<ui32>\n    %7 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %8 = stablehlo.shift_left %5, %7 : tensor<1xui32>\n    %c_1 = stablehlo.constant dense<19> : tensor<ui32>\n    %9 = 
stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %10 = stablehlo.shift_right_logical %5, %9 : tensor<1xui32>\n    %11 = stablehlo.or %8, %10 : tensor<1xui32>\n    %12 = stablehlo.xor %6, %11 : tensor<1xui32>\n    %13 = stablehlo.add %6, %12 : tensor<1xui32>\n    %c_2 = stablehlo.constant dense<15> : tensor<ui32>\n    %14 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %15 = stablehlo.shift_left %12, %14 : tensor<1xui32>\n    %c_3 = stablehlo.constant dense<17> : tensor<ui32>\n    %16 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %17 = stablehlo.shift_right_logical %12, %16 : tensor<1xui32>\n    %18 = stablehlo.or %15, %17 : tensor<1xui32>\n    %19 = stablehlo.xor %13, %18 : tensor<1xui32>\n    %20 = stablehlo.add %13, %19 : tensor<1xui32>\n    %c_4 = stablehlo.constant dense<26> : tensor<ui32>\n    %21 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %22 = stablehlo.shift_left %19, %21 : tensor<1xui32>\n    %c_5 = stablehlo.constant dense<6> : tensor<ui32>\n    %23 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %24 = stablehlo.shift_right_logical %19, %23 : tensor<1xui32>\n    %25 = stablehlo.or %22, %24 : tensor<1xui32>\n    %26 = stablehlo.xor %20, %25 : tensor<1xui32>\n    %27 = stablehlo.add %20, %26 : tensor<1xui32>\n    %28 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %29 = stablehlo.shift_left %26, %28 : tensor<1xui32>\n    %30 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %31 = stablehlo.shift_right_logical %26, %30 : tensor<1xui32>\n    %32 = stablehlo.or %29, %31 : tensor<1xui32>\n    %33 = stablehlo.xor %27, %32 : tensor<1xui32>\n    %34 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %35 = stablehlo.add %27, %34 : tensor<1xui32>\n    %36 = 
stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %37 = stablehlo.add %33, %36 : tensor<1xui32>\n    %c_6 = stablehlo.constant dense<1> : tensor<ui32>\n    %38 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %39 = stablehlo.add %37, %38 : tensor<1xui32>\n    %40 = stablehlo.add %35, %39 : tensor<1xui32>\n    %41 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %42 = stablehlo.shift_left %39, %41 : tensor<1xui32>\n    %43 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %44 = stablehlo.shift_right_logical %39, %43 : tensor<1xui32>\n    %45 = stablehlo.or %42, %44 : tensor<1xui32>\n    %46 = stablehlo.xor %40, %45 : tensor<1xui32>\n    %47 = stablehlo.add %40, %46 : tensor<1xui32>\n    %c_7 = stablehlo.constant dense<29> : tensor<ui32>\n    %48 = stablehlo.broadcast_in_dim %c_7, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %49 = stablehlo.shift_left %46, %48 : tensor<1xui32>\n    %c_8 = stablehlo.constant dense<3> : tensor<ui32>\n    %50 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %51 = stablehlo.shift_right_logical %46, %50 : tensor<1xui32>\n    %52 = stablehlo.or %49, %51 : tensor<1xui32>\n    %53 = stablehlo.xor %47, %52 : tensor<1xui32>\n    %54 = stablehlo.add %47, %53 : tensor<1xui32>\n    %c_9 = stablehlo.constant dense<16> : tensor<ui32>\n    %55 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %56 = stablehlo.shift_left %53, %55 : tensor<1xui32>\n    %57 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %58 = stablehlo.shift_right_logical %53, %57 : tensor<1xui32>\n    %59 = stablehlo.or %56, %58 : tensor<1xui32>\n    %60 = stablehlo.xor %54, %59 : tensor<1xui32>\n    %61 = stablehlo.add %54, %60 : tensor<1xui32>\n    %c_10 = stablehlo.constant dense<24> : tensor<ui32>\n    %62 = stablehlo.broadcast_in_dim 
%c_10, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %63 = stablehlo.shift_left %60, %62 : tensor<1xui32>\n    %c_11 = stablehlo.constant dense<8> : tensor<ui32>\n    %64 = stablehlo.broadcast_in_dim %c_11, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %65 = stablehlo.shift_right_logical %60, %64 : tensor<1xui32>\n    %66 = stablehlo.or %63, %65 : tensor<1xui32>\n    %67 = stablehlo.xor %61, %66 : tensor<1xui32>\n    %68 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %69 = stablehlo.add %61, %68 : tensor<1xui32>\n    %70 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %71 = stablehlo.add %67, %70 : tensor<1xui32>\n    %c_12 = stablehlo.constant dense<2> : tensor<ui32>\n    %72 = stablehlo.broadcast_in_dim %c_12, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %73 = stablehlo.add %71, %72 : tensor<1xui32>\n    %74 = stablehlo.add %69, %73 : tensor<1xui32>\n    %75 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %76 = stablehlo.shift_left %73, %75 : tensor<1xui32>\n    %77 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %78 = stablehlo.shift_right_logical %73, %77 : tensor<1xui32>\n    %79 = stablehlo.or %76, %78 : tensor<1xui32>\n    %80 = stablehlo.xor %74, %79 : tensor<1xui32>\n    %81 = stablehlo.add %74, %80 : tensor<1xui32>\n    %82 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %83 = stablehlo.shift_left %80, %82 : tensor<1xui32>\n    %84 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %85 = stablehlo.shift_right_logical %80, %84 : tensor<1xui32>\n    %86 = stablehlo.or %83, %85 : tensor<1xui32>\n    %87 = stablehlo.xor %81, %86 : tensor<1xui32>\n    %88 = stablehlo.add %81, %87 : tensor<1xui32>\n    %89 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %90 = stablehlo.shift_left %87, %89 : 
tensor<1xui32>\n    %91 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %92 = stablehlo.shift_right_logical %87, %91 : tensor<1xui32>\n    %93 = stablehlo.or %90, %92 : tensor<1xui32>\n    %94 = stablehlo.xor %88, %93 : tensor<1xui32>\n    %95 = stablehlo.add %88, %94 : tensor<1xui32>\n    %96 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %97 = stablehlo.shift_left %94, %96 : tensor<1xui32>\n    %98 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %99 = stablehlo.shift_right_logical %94, %98 : tensor<1xui32>\n    %100 = stablehlo.or %97, %99 : tensor<1xui32>\n    %101 = stablehlo.xor %95, %100 : tensor<1xui32>\n    %102 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %103 = stablehlo.add %95, %102 : tensor<1xui32>\n    %104 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %105 = stablehlo.add %101, %104 : tensor<1xui32>\n    %106 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %107 = stablehlo.add %105, %106 : tensor<1xui32>\n    %108 = stablehlo.add %103, %107 : tensor<1xui32>\n    %109 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %110 = stablehlo.shift_left %107, %109 : tensor<1xui32>\n    %111 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %112 = stablehlo.shift_right_logical %107, %111 : tensor<1xui32>\n    %113 = stablehlo.or %110, %112 : tensor<1xui32>\n    %114 = stablehlo.xor %108, %113 : tensor<1xui32>\n    %115 = stablehlo.add %108, %114 : tensor<1xui32>\n    %116 = stablehlo.broadcast_in_dim %c_7, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %117 = stablehlo.shift_left %114, %116 : tensor<1xui32>\n    %118 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %119 = stablehlo.shift_right_logical %114, %118 : 
tensor<1xui32>\n    %120 = stablehlo.or %117, %119 : tensor<1xui32>\n    %121 = stablehlo.xor %115, %120 : tensor<1xui32>\n    %122 = stablehlo.add %115, %121 : tensor<1xui32>\n    %123 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %124 = stablehlo.shift_left %121, %123 : tensor<1xui32>\n    %125 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %126 = stablehlo.shift_right_logical %121, %125 : tensor<1xui32>\n    %127 = stablehlo.or %124, %126 : tensor<1xui32>\n    %128 = stablehlo.xor %122, %127 : tensor<1xui32>\n    %129 = stablehlo.add %122, %128 : tensor<1xui32>\n    %130 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %131 = stablehlo.shift_left %128, %130 : tensor<1xui32>\n    %132 = stablehlo.broadcast_in_dim %c_11, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %133 = stablehlo.shift_right_logical %128, %132 : tensor<1xui32>\n    %134 = stablehlo.or %131, %133 : tensor<1xui32>\n    %135 = stablehlo.xor %129, %134 : tensor<1xui32>\n    %136 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %137 = stablehlo.add %129, %136 : tensor<1xui32>\n    %138 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %139 = stablehlo.add %135, %138 : tensor<1xui32>\n    %c_13 = stablehlo.constant dense<4> : tensor<ui32>\n    %140 = stablehlo.broadcast_in_dim %c_13, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %141 = stablehlo.add %139, %140 : tensor<1xui32>\n    %142 = stablehlo.add %137, %141 : tensor<1xui32>\n    %143 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %144 = stablehlo.shift_left %141, %143 : tensor<1xui32>\n    %145 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %146 = stablehlo.shift_right_logical %141, %145 : tensor<1xui32>\n    %147 = stablehlo.or %144, %146 : tensor<1xui32>\n    %148 = stablehlo.xor 
%142, %147 : tensor<1xui32>\n    %149 = stablehlo.add %142, %148 : tensor<1xui32>\n    %150 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %151 = stablehlo.shift_left %148, %150 : tensor<1xui32>\n    %152 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %153 = stablehlo.shift_right_logical %148, %152 : tensor<1xui32>\n    %154 = stablehlo.or %151, %153 : tensor<1xui32>\n    %155 = stablehlo.xor %149, %154 : tensor<1xui32>\n    %156 = stablehlo.add %149, %155 : tensor<1xui32>\n    %157 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %158 = stablehlo.shift_left %155, %157 : tensor<1xui32>\n    %159 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %160 = stablehlo.shift_right_logical %155, %159 : tensor<1xui32>\n    %161 = stablehlo.or %158, %160 : tensor<1xui32>\n    %162 = stablehlo.xor %156, %161 : tensor<1xui32>\n    %163 = stablehlo.add %156, %162 : tensor<1xui32>\n    %164 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %165 = stablehlo.shift_left %162, %164 : tensor<1xui32>\n    %166 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %167 = stablehlo.shift_right_logical %162, %166 : tensor<1xui32>\n    %168 = stablehlo.or %165, %167 : tensor<1xui32>\n    %169 = stablehlo.xor %163, %168 : tensor<1xui32>\n    %170 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %171 = stablehlo.add %163, %170 : tensor<1xui32>\n    %172 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %173 = stablehlo.add %169, %172 : tensor<1xui32>\n    %c_14 = stablehlo.constant dense<5> : tensor<ui32>\n    %174 = stablehlo.broadcast_in_dim %c_14, dims = [] : (tensor<ui32>) -> tensor<1xui32>\n    %175 = stablehlo.add %173, %174 : tensor<1xui32>\n    return %171, %175 : tensor<1xui32>, tensor<1xui32>\n  }\n  
func.func private @_uniform_0(%arg0: tensor<2xui32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<16xf32> {\n    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<1xf32>\n    %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor<f32>) -> tensor<1xf32>\n    %2 = stablehlo.iota dim = 0 : tensor<16xui32>\n    %3 = stablehlo.slice %arg0 [0:1] : (tensor<2xui32>) -> tensor<1xui32>\n    %4 = stablehlo.reshape %3 : (tensor<1xui32>) -> tensor<ui32>\n    %5 = stablehlo.slice %arg0 [1:2] : (tensor<2xui32>) -> tensor<1xui32>\n    %6 = stablehlo.reshape %5 : (tensor<1xui32>) -> tensor<ui32>\n    %7 = stablehlo.slice %2 [0:8] : (tensor<16xui32>) -> tensor<8xui32>\n    %8 = stablehlo.slice %2 [8:16] : (tensor<16xui32>) -> tensor<8xui32>\n    %9:2 = call @threefry2x32_1(%4, %6, %7, %8) : (tensor<ui32>, tensor<ui32>, tensor<8xui32>, tensor<8xui32>) -> (tensor<8xui32>, tensor<8xui32>)\n    %10 = stablehlo.concatenate %9#0, %9#1, dim = 0 : (tensor<8xui32>, tensor<8xui32>) -> tensor<16xui32>\n    %c = stablehlo.constant dense<9> : tensor<ui32>\n    %11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<ui32>) -> tensor<16xui32>\n    %12 = stablehlo.shift_right_logical %10, %11 : tensor<16xui32>\n    %c_0 = stablehlo.constant dense<1065353216> : tensor<ui32>\n    %13 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<16xui32>\n    %14 = stablehlo.or %12, %13 : tensor<16xui32>\n    %15 = stablehlo.bitcast_convert %14 : (tensor<16xui32>) -> tensor<16xf32>\n    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>\n    %16 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>\n    %17 = stablehlo.subtract %15, %16 : tensor<16xf32>\n    %18 = stablehlo.subtract %1, %0 : tensor<1xf32>\n    %19 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xf32>) -> tensor<16xf32>\n    %20 = stablehlo.multiply %17, %19 : tensor<16xf32>\n    %21 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<1xf32>) -> 
tensor<16xf32>\n    %22 = stablehlo.add %20, %21 : tensor<16xf32>\n    %23 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<1xf32>) -> tensor<16xf32>\n    %24 = stablehlo.maximum %23, %22 : tensor<16xf32>\n    return %24 : tensor<16xf32>\n  }\n  func.func private @threefry2x32_1(%arg0: tensor<ui32>, %arg1: tensor<ui32>, %arg2: tensor<8xui32>, %arg3: tensor<8xui32>) -> (tensor<8xui32>, tensor<8xui32>) {\n    %0 = stablehlo.xor %arg0, %arg1 : tensor<ui32>\n    %c = stablehlo.constant dense<466688986> : tensor<ui32>\n    %1 = stablehlo.xor %0, %c : tensor<ui32>\n    %2 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %3 = stablehlo.add %arg2, %2 : tensor<8xui32>\n    %4 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %5 = stablehlo.add %arg3, %4 : tensor<8xui32>\n    %6 = stablehlo.add %3, %5 : tensor<8xui32>\n    %c_0 = stablehlo.constant dense<13> : tensor<ui32>\n    %7 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %8 = stablehlo.shift_left %5, %7 : tensor<8xui32>\n    %c_1 = stablehlo.constant dense<19> : tensor<ui32>\n    %9 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %10 = stablehlo.shift_right_logical %5, %9 : tensor<8xui32>\n    %11 = stablehlo.or %8, %10 : tensor<8xui32>\n    %12 = stablehlo.xor %6, %11 : tensor<8xui32>\n    %13 = stablehlo.add %6, %12 : tensor<8xui32>\n    %c_2 = stablehlo.constant dense<15> : tensor<ui32>\n    %14 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %15 = stablehlo.shift_left %12, %14 : tensor<8xui32>\n    %c_3 = stablehlo.constant dense<17> : tensor<ui32>\n    %16 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %17 = stablehlo.shift_right_logical %12, %16 : tensor<8xui32>\n    %18 = stablehlo.or %15, %17 : tensor<8xui32>\n    %19 = stablehlo.xor %13, %18 : tensor<8xui32>\n    %20 = stablehlo.add 
%13, %19 : tensor<8xui32>\n    %c_4 = stablehlo.constant dense<26> : tensor<ui32>\n    %21 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %22 = stablehlo.shift_left %19, %21 : tensor<8xui32>\n    %c_5 = stablehlo.constant dense<6> : tensor<ui32>\n    %23 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %24 = stablehlo.shift_right_logical %19, %23 : tensor<8xui32>\n    %25 = stablehlo.or %22, %24 : tensor<8xui32>\n    %26 = stablehlo.xor %20, %25 : tensor<8xui32>\n    %27 = stablehlo.add %20, %26 : tensor<8xui32>\n    %28 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %29 = stablehlo.shift_left %26, %28 : tensor<8xui32>\n    %30 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %31 = stablehlo.shift_right_logical %26, %30 : tensor<8xui32>\n    %32 = stablehlo.or %29, %31 : tensor<8xui32>\n    %33 = stablehlo.xor %27, %32 : tensor<8xui32>\n    %34 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %35 = stablehlo.add %27, %34 : tensor<8xui32>\n    %36 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %37 = stablehlo.add %33, %36 : tensor<8xui32>\n    %c_6 = stablehlo.constant dense<1> : tensor<ui32>\n    %38 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %39 = stablehlo.add %37, %38 : tensor<8xui32>\n    %40 = stablehlo.add %35, %39 : tensor<8xui32>\n    %41 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %42 = stablehlo.shift_left %39, %41 : tensor<8xui32>\n    %43 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %44 = stablehlo.shift_right_logical %39, %43 : tensor<8xui32>\n    %45 = stablehlo.or %42, %44 : tensor<8xui32>\n    %46 = stablehlo.xor %40, %45 : tensor<8xui32>\n    %47 = stablehlo.add %40, %46 : tensor<8xui32>\n    %c_7 = 
stablehlo.constant dense<29> : tensor<ui32>\n    %48 = stablehlo.broadcast_in_dim %c_7, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %49 = stablehlo.shift_left %46, %48 : tensor<8xui32>\n    %c_8 = stablehlo.constant dense<3> : tensor<ui32>\n    %50 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %51 = stablehlo.shift_right_logical %46, %50 : tensor<8xui32>\n    %52 = stablehlo.or %49, %51 : tensor<8xui32>\n    %53 = stablehlo.xor %47, %52 : tensor<8xui32>\n    %54 = stablehlo.add %47, %53 : tensor<8xui32>\n    %c_9 = stablehlo.constant dense<16> : tensor<ui32>\n    %55 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %56 = stablehlo.shift_left %53, %55 : tensor<8xui32>\n    %57 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %58 = stablehlo.shift_right_logical %53, %57 : tensor<8xui32>\n    %59 = stablehlo.or %56, %58 : tensor<8xui32>\n    %60 = stablehlo.xor %54, %59 : tensor<8xui32>\n    %61 = stablehlo.add %54, %60 : tensor<8xui32>\n    %c_10 = stablehlo.constant dense<24> : tensor<ui32>\n    %62 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %63 = stablehlo.shift_left %60, %62 : tensor<8xui32>\n    %c_11 = stablehlo.constant dense<8> : tensor<ui32>\n    %64 = stablehlo.broadcast_in_dim %c_11, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %65 = stablehlo.shift_right_logical %60, %64 : tensor<8xui32>\n    %66 = stablehlo.or %63, %65 : tensor<8xui32>\n    %67 = stablehlo.xor %61, %66 : tensor<8xui32>\n    %68 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %69 = stablehlo.add %61, %68 : tensor<8xui32>\n    %70 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %71 = stablehlo.add %67, %70 : tensor<8xui32>\n    %c_12 = stablehlo.constant dense<2> : tensor<ui32>\n    %72 = stablehlo.broadcast_in_dim %c_12, dims = [] : (tensor<ui32>) -> 
tensor<8xui32>\n    %73 = stablehlo.add %71, %72 : tensor<8xui32>\n    %74 = stablehlo.add %69, %73 : tensor<8xui32>\n    %75 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %76 = stablehlo.shift_left %73, %75 : tensor<8xui32>\n    %77 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %78 = stablehlo.shift_right_logical %73, %77 : tensor<8xui32>\n    %79 = stablehlo.or %76, %78 : tensor<8xui32>\n    %80 = stablehlo.xor %74, %79 : tensor<8xui32>\n    %81 = stablehlo.add %74, %80 : tensor<8xui32>\n    %82 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %83 = stablehlo.shift_left %80, %82 : tensor<8xui32>\n    %84 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %85 = stablehlo.shift_right_logical %80, %84 : tensor<8xui32>\n    %86 = stablehlo.or %83, %85 : tensor<8xui32>\n    %87 = stablehlo.xor %81, %86 : tensor<8xui32>\n    %88 = stablehlo.add %81, %87 : tensor<8xui32>\n    %89 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %90 = stablehlo.shift_left %87, %89 : tensor<8xui32>\n    %91 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %92 = stablehlo.shift_right_logical %87, %91 : tensor<8xui32>\n    %93 = stablehlo.or %90, %92 : tensor<8xui32>\n    %94 = stablehlo.xor %88, %93 : tensor<8xui32>\n    %95 = stablehlo.add %88, %94 : tensor<8xui32>\n    %96 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %97 = stablehlo.shift_left %94, %96 : tensor<8xui32>\n    %98 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %99 = stablehlo.shift_right_logical %94, %98 : tensor<8xui32>\n    %100 = stablehlo.or %97, %99 : tensor<8xui32>\n    %101 = stablehlo.xor %95, %100 : tensor<8xui32>\n    %102 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %103 = 
stablehlo.add %95, %102 : tensor<8xui32>\n    %104 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %105 = stablehlo.add %101, %104 : tensor<8xui32>\n    %106 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %107 = stablehlo.add %105, %106 : tensor<8xui32>\n    %108 = stablehlo.add %103, %107 : tensor<8xui32>\n    %109 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %110 = stablehlo.shift_left %107, %109 : tensor<8xui32>\n    %111 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %112 = stablehlo.shift_right_logical %107, %111 : tensor<8xui32>\n    %113 = stablehlo.or %110, %112 : tensor<8xui32>\n    %114 = stablehlo.xor %108, %113 : tensor<8xui32>\n    %115 = stablehlo.add %108, %114 : tensor<8xui32>\n    %116 = stablehlo.broadcast_in_dim %c_7, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %117 = stablehlo.shift_left %114, %116 : tensor<8xui32>\n    %118 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %119 = stablehlo.shift_right_logical %114, %118 : tensor<8xui32>\n    %120 = stablehlo.or %117, %119 : tensor<8xui32>\n    %121 = stablehlo.xor %115, %120 : tensor<8xui32>\n    %122 = stablehlo.add %115, %121 : tensor<8xui32>\n    %123 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %124 = stablehlo.shift_left %121, %123 : tensor<8xui32>\n    %125 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %126 = stablehlo.shift_right_logical %121, %125 : tensor<8xui32>\n    %127 = stablehlo.or %124, %126 : tensor<8xui32>\n    %128 = stablehlo.xor %122, %127 : tensor<8xui32>\n    %129 = stablehlo.add %122, %128 : tensor<8xui32>\n    %130 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %131 = stablehlo.shift_left %128, %130 : tensor<8xui32>\n    %132 = stablehlo.broadcast_in_dim %c_11, 
dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %133 = stablehlo.shift_right_logical %128, %132 : tensor<8xui32>\n    %134 = stablehlo.or %131, %133 : tensor<8xui32>\n    %135 = stablehlo.xor %129, %134 : tensor<8xui32>\n    %136 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %137 = stablehlo.add %129, %136 : tensor<8xui32>\n    %138 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %139 = stablehlo.add %135, %138 : tensor<8xui32>\n    %c_13 = stablehlo.constant dense<4> : tensor<ui32>\n    %140 = stablehlo.broadcast_in_dim %c_13, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %141 = stablehlo.add %139, %140 : tensor<8xui32>\n    %142 = stablehlo.add %137, %141 : tensor<8xui32>\n    %143 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %144 = stablehlo.shift_left %141, %143 : tensor<8xui32>\n    %145 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %146 = stablehlo.shift_right_logical %141, %145 : tensor<8xui32>\n    %147 = stablehlo.or %144, %146 : tensor<8xui32>\n    %148 = stablehlo.xor %142, %147 : tensor<8xui32>\n    %149 = stablehlo.add %142, %148 : tensor<8xui32>\n    %150 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %151 = stablehlo.shift_left %148, %150 : tensor<8xui32>\n    %152 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %153 = stablehlo.shift_right_logical %148, %152 : tensor<8xui32>\n    %154 = stablehlo.or %151, %153 : tensor<8xui32>\n    %155 = stablehlo.xor %149, %154 : tensor<8xui32>\n    %156 = stablehlo.add %149, %155 : tensor<8xui32>\n    %157 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %158 = stablehlo.shift_left %155, %157 : tensor<8xui32>\n    %159 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %160 = stablehlo.shift_right_logical %155, 
%159 : tensor<8xui32>\n    %161 = stablehlo.or %158, %160 : tensor<8xui32>\n    %162 = stablehlo.xor %156, %161 : tensor<8xui32>\n    %163 = stablehlo.add %156, %162 : tensor<8xui32>\n    %164 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %165 = stablehlo.shift_left %162, %164 : tensor<8xui32>\n    %166 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %167 = stablehlo.shift_right_logical %162, %166 : tensor<8xui32>\n    %168 = stablehlo.or %165, %167 : tensor<8xui32>\n    %169 = stablehlo.xor %163, %168 : tensor<8xui32>\n    %170 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %171 = stablehlo.add %163, %170 : tensor<8xui32>\n    %172 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %173 = stablehlo.add %169, %172 : tensor<8xui32>\n    %c_14 = stablehlo.constant dense<5> : tensor<ui32>\n    %174 = stablehlo.broadcast_in_dim %c_14, dims = [] : (tensor<ui32>) -> tensor<8xui32>\n    %175 = stablehlo.add %173, %174 : tensor<8xui32>\n    return %171, %175 : tensor<8xui32>, tensor<8xui32>\n  }\n}\n'

----------------------------------------------------------------------
Ran 2 tests in 16.874s

FAILED (failures=1)

System info (python version, jaxlib version, accelerator, etc.)

jax: from head
jaxlib: pypi (latest)

command:

bazel test --config=ci_linux_x86_64_cuda --repo_env=HERMETIC_PYTHON_VERSION=3.10 --//jax:build_jaxlib=false //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform --test_output=errors --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=JAX_ACCELERATOR_COUNT=4 --test_env=JAX_TESTS_PER_ACCELERATOR=12 --local_test_jobs=48 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --test_tag_filters=-multiaccelerator
```
@belitskiy added the bug label ("Something isn't working") on Jan 10, 2025

belitskiy commented Jan 10, 2025

@sergachev please take a look.

The test was skipped in 86643a1.

sergachev commented Jan 10, 2025

> jax: from head
> jaxlib: pypi (latest)

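For this test to pass, jaxlib must include the change to `rnn_kernels.h` from https://github.com/jax-ml/jax/pull/25803/files. The latest jaxlib on PyPI does not include that change yet, so the test fails when jax from head runs against it.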
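One way to handle a test in jax at head that depends on a C++ change shipping only in a newer jaxlib is to gate the test on the installed jaxlib version. The following is a minimal, stdlib-only sketch of such a gate. `MIN_JAXLIB_FOR_RNN_FIX` is a hypothetical placeholder, because this thread does not say which jaxlib release includes the PR. The helper and class names are also illustrative: they are not JAX's own test utilities, and this is not necessarily how 86643a1 skipped the test.

```python
import unittest
from importlib import metadata

# HYPOTHETICAL: the first jaxlib release assumed to contain the rnn_kernels.h
# change from https://github.com/jax-ml/jax/pull/25803. This thread does not
# name that release; replace the tuple with the real version before use.
MIN_JAXLIB_FOR_RNN_FIX = (0, 5, 0)


def parse_version(version: str) -> tuple:
    """Return the leading numeric (major, minor, patch) of a version string.

    Suffixes such as 'rc1' or '.dev20250110' are ignored, and missing
    components are padded with zeros, so '0.5' becomes (0, 5, 0).
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # a purely non-numeric component like 'dev20250110'
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # a component with a suffix like '0rc1' ends the numeric prefix
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def installed_jaxlib_version() -> str:
    """Version of the installed jaxlib distribution, or '0' if it is absent."""
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return "0"


def jaxlib_has_rnn_fix(installed: str, minimum: tuple = MIN_JAXLIB_FOR_RNN_FIX) -> bool:
    """True if `installed` is at least `minimum` (lexicographic tuple comparison)."""
    return parse_version(installed) >= tuple(minimum)


@unittest.skipUnless(
    jaxlib_has_rnn_fix(installed_jaxlib_version()),
    "requires a jaxlib that includes the rnn_kernels.h change from jax-ml/jax#25803",
)
class StructEncodingDeterminismSketch(unittest.TestCase):
    def test_struct_encoding_determinism(self):
        # Placeholder: the real assertions from RnnTest.test_struct_encoding_determinism
        # would go here; this sketch only demonstrates the version gate.
        pass
```

With a gate like this, a CI job that pairs jax from head with an older PyPI jaxlib reports the test as skipped instead of failed. Once a jaxlib release containing the fix is published and the minimum is set to it, the test runs again automatically on that release and newer ones.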
