Results are as reported by this notebook. To re-run the experiments, open Google Colab, upload the notebook, and run the cells one by one.
Hardware at the time of writing (19 Dec 2019):
- Intel(R) Xeon(R) CPU @ 2.30GHz (1 core, 2 threads)
- 12.6GB of RAM
- NVIDIA Tesla K80 GPU with 12GB of memory
Caveat: JAX does not (yet) support 64-bit floating-point precision on TPUs. The JAX + TPU results are therefore not bit-identical to those of the other backends and devices, so the comparison is not quite apples-to-apples.
Caveat²: I didn't manage to get Bohrium to work on Colab, so it's missing from these results.
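Regarding the first caveat: on CPU and GPU, double precision has to be switched on explicitly in JAX. A minimal sketch (the `jax_enable_x64` flag is a real JAX option, but the exact import style has changed between JAX versions, so treat this as illustrative):

```python
# Enable 64-bit floats in JAX before creating any arrays.
# On TPU this flag does not give true float64 for most operations,
# which is why the JAX + TPU numbers above are not bit-identical
# to the other backends.
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

print(jnp.ones(3).dtype)  # float64 instead of the default float32
```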
An equation consisting of more than 100 terms, with no data dependencies between elements and only elementary math. This benchmark should represent a best-case scenario for vector instructions and GPU performance.
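To illustrate what "no data dependencies and only elementary math" means, here is a toy element-wise kernel in NumPy. This is a hypothetical stand-in, not the actual equation of state from `benchmarks/equation_of_state/`:

```python
import numpy as np

def toy_kernel(t, s, p):
    """Element-wise polynomial with made-up coefficients.

    Each output element depends only on the matching input elements,
    so the whole expression maps cleanly onto vector instructions or
    a single fused GPU kernel -- the same access pattern as the
    benchmark's (much longer) equation of state.
    """
    return (999.8 + 0.07 * t - 0.009 * t**2
            + 0.8 * s - 0.002 * s * t
            + 4.5e-3 * p)

# toy inputs mimicking temperature, salinity, and pressure arrays
t = np.linspace(-2.0, 30.0, 4096)
s = np.linspace(30.0, 38.0, 4096)
p = np.linspace(0.0, 5000.0, 4096)
rho = toy_kernel(t, s, p)
```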
$ taskset -c 0 python run.py benchmarks/equation_of_state/
Estimating repetitions...
Running 100116 benchmarks... [####################################] 100%
benchmarks.equation_of_state
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numba 10,000 0.001 0.000 0.000 0.000 0.000 0.001 0.003 3.402
4,096 theano 10,000 0.001 0.000 0.000 0.001 0.001 0.001 0.004 2.668
4,096 tensorflow 10,000 0.001 0.000 0.000 0.001 0.001 0.001 0.004 2.645
4,096 jax 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.004 2.500
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.006 1.000
4,096 pytorch 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.008 0.742
16,384 tensorflow 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.006 4.418
16,384 jax 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.006 4.244
16,384 numba 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.007 3.994
16,384 theano 10,000 0.003 0.000 0.002 0.002 0.002 0.003 0.007 3.443
16,384 pytorch 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.015 1.037
16,384 numpy 1,000 0.009 0.001 0.007 0.008 0.009 0.009 0.014 1.000
65,536 tensorflow 1,000 0.006 0.001 0.006 0.006 0.006 0.006 0.010 7.932
65,536 jax 1,000 0.007 0.001 0.006 0.007 0.007 0.007 0.012 6.998
65,536 numba 1,000 0.009 0.001 0.008 0.008 0.008 0.009 0.014 5.852
65,536 theano 1,000 0.010 0.001 0.009 0.009 0.010 0.010 0.016 5.123
65,536 pytorch 100 0.048 0.002 0.045 0.046 0.047 0.048 0.058 1.045
65,536 numpy 100 0.050 0.002 0.047 0.049 0.049 0.050 0.061 1.000
262,144 tensorflow 1,000 0.020 0.002 0.018 0.019 0.020 0.020 0.034 12.017
262,144 jax 1,000 0.023 0.002 0.021 0.022 0.023 0.023 0.034 10.489
262,144 numba 1,000 0.031 0.002 0.029 0.031 0.031 0.031 0.044 7.725
262,144 theano 100 0.035 0.002 0.034 0.034 0.035 0.035 0.043 6.887
262,144 pytorch 100 0.201 0.006 0.180 0.197 0.201 0.205 0.231 1.206
262,144 numpy 100 0.242 0.007 0.210 0.239 0.242 0.246 0.265 1.000
1,048,576 tensorflow 100 0.090 0.004 0.084 0.087 0.088 0.092 0.102 8.535
1,048,576 jax 100 0.099 0.005 0.093 0.095 0.097 0.101 0.115 7.756
1,048,576 numba 100 0.133 0.005 0.126 0.129 0.131 0.136 0.146 5.778
1,048,576 theano 100 0.148 0.004 0.143 0.145 0.147 0.151 0.160 5.171
1,048,576 numpy 10 0.766 0.005 0.761 0.763 0.764 0.768 0.780 1.000
1,048,576 pytorch 10 0.807 0.007 0.795 0.803 0.806 0.809 0.822 0.950
4,194,304 tensorflow 10 0.355 0.007 0.347 0.348 0.354 0.358 0.370 9.285
4,194,304 jax 10 0.396 0.005 0.387 0.392 0.396 0.400 0.406 8.316
4,194,304 numba 10 0.514 0.006 0.504 0.511 0.515 0.518 0.522 6.412
4,194,304 theano 10 0.590 0.007 0.577 0.587 0.589 0.595 0.604 5.581
4,194,304 numpy 10 3.295 0.017 3.262 3.284 3.293 3.302 3.327 1.000
4,194,304 pytorch 10 3.837 0.016 3.819 3.827 3.833 3.841 3.876 0.859
(time in wall seconds, less is better)
$ for backend in jax tensorflow pytorch cupy; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/equation_of_state/ --device gpu -b $backend -b numpy; done
Estimating repetitions...
Running 71232 benchmarks... [####################################] 100%
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 5.083
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.005 1.000
16,384 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 25.571
16,384 numpy 1,000 0.009 0.001 0.007 0.008 0.008 0.009 0.013 1.000
65,536 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 352.617
65,536 numpy 100 0.140 0.006 0.122 0.137 0.140 0.144 0.154 1.000
262,144 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 533.222
262,144 numpy 100 0.251 0.035 0.228 0.243 0.247 0.251 0.595 1.000
1,048,576 jax 10,000 0.002 0.000 0.000 0.002 0.002 0.003 0.006 290.491
1,048,576 numpy 10 0.707 0.007 0.699 0.703 0.707 0.709 0.723 1.000
4,194,304 jax 10,000 0.003 0.000 0.001 0.003 0.003 0.003 0.005 1140.959
4,194,304 numpy 10 3.060 0.012 3.041 3.049 3.063 3.071 3.073 1.000
(time in wall seconds, less is better)
Estimating repetitions...
Running 71232 benchmarks... [####################################] 100%
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 tensorflow 10,000 0.001 0.000 0.000 0.001 0.001 0.001 0.007 2.552
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.005 1.000
16,384 tensorflow 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.006 11.811
16,384 numpy 1,000 0.009 0.001 0.007 0.008 0.008 0.009 0.013 1.000
65,536 tensorflow 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.010 48.165
65,536 numpy 100 0.047 0.002 0.043 0.045 0.046 0.047 0.058 1.000
262,144 tensorflow 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.012 188.762
262,144 numpy 100 0.251 0.008 0.220 0.247 0.252 0.255 0.272 1.000
1,048,576 tensorflow 10,000 0.003 0.000 0.002 0.003 0.003 0.003 0.009 261.178
1,048,576 numpy 10 0.720 0.017 0.680 0.715 0.721 0.729 0.741 1.000
4,194,304 tensorflow 10,000 0.011 0.001 0.010 0.011 0.011 0.011 0.025 279.396
4,194,304 numpy 10 3.069 0.014 3.053 3.059 3.067 3.074 3.104 1.000
(time in wall seconds, less is better)
Estimating repetitions...
Running 16332 benchmarks... [####################################] 100%
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 10,000 0.002 0.000 0.001 0.001 0.002 0.002 0.009 1.000
4,096 pytorch 1,000 0.004 0.001 0.003 0.004 0.004 0.004 0.008 0.389
16,384 pytorch 1,000 0.004 0.001 0.003 0.004 0.004 0.004 0.009 1.987
16,384 numpy 1,000 0.008 0.001 0.007 0.007 0.008 0.008 0.013 1.000
65,536 pytorch 1,000 0.004 0.001 0.003 0.004 0.004 0.004 0.008 17.421
65,536 numpy 100 0.073 0.005 0.048 0.071 0.073 0.075 0.084 1.000
262,144 pytorch 1,000 0.004 0.001 0.004 0.004 0.004 0.004 0.008 40.591
262,144 numpy 100 0.174 0.005 0.163 0.171 0.174 0.177 0.191 1.000
1,048,576 pytorch 1,000 0.017 0.000 0.016 0.017 0.017 0.017 0.017 50.141
1,048,576 numpy 10 0.832 0.008 0.821 0.829 0.830 0.836 0.849 1.000
4,194,304 pytorch 100 0.062 0.000 0.062 0.062 0.062 0.062 0.063 50.867
4,194,304 numpy 10 3.156 0.021 3.130 3.139 3.151 3.161 3.197 1.000
(time in wall seconds, less is better)
Estimating repetitions...
Running 16332 benchmarks... [####################################] 100%
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 10,000 0.002 0.000 0.001 0.001 0.002 0.002 0.006 1.000
4,096 cupy 1,000 0.011 0.001 0.010 0.011 0.011 0.011 0.019 0.146
16,384 numpy 1,000 0.008 0.001 0.007 0.007 0.008 0.009 0.014 1.000
16,384 cupy 1,000 0.011 0.001 0.010 0.011 0.011 0.011 0.021 0.719
65,536 cupy 1,000 0.011 0.002 0.010 0.011 0.011 0.011 0.023 13.935
65,536 numpy 100 0.159 0.018 0.063 0.158 0.162 0.166 0.182 1.000
262,144 cupy 1,000 0.011 0.002 0.010 0.011 0.011 0.011 0.023 23.526
262,144 numpy 100 0.268 0.008 0.235 0.264 0.269 0.274 0.284 1.000
1,048,576 cupy 1,000 0.017 0.001 0.016 0.016 0.016 0.016 0.022 44.047
1,048,576 numpy 10 0.730 0.010 0.707 0.723 0.735 0.738 0.741 1.000
4,194,304 cupy 100 0.061 0.000 0.060 0.061 0.061 0.061 0.061 50.960
4,194,304 numpy 10 3.089 0.017 3.052 3.083 3.089 3.100 3.114 1.000
(time in wall seconds, less is better)
$ JAX_BACKEND_TARGET="grpc://$COLAB_TPU_ADDR" python run.py benchmarks/equation_of_state -b jax -b numpy --device tpu
benchmarks.equation_of_state
============================
Running on TPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.002 0.001 0.001 0.002 0.002 0.002 0.024 1.694
4,096 numpy 10,000 0.004 0.001 0.002 0.003 0.004 0.004 0.016 1.000
16,384 jax 10,000 0.002 0.001 0.001 0.002 0.002 0.002 0.012 6.065
16,384 numpy 1,000 0.013 0.002 0.008 0.012 0.013 0.014 0.032 1.000
65,536 jax 10,000 0.002 0.001 0.001 0.002 0.002 0.002 0.017 42.891
65,536 numpy 100 0.095 0.014 0.063 0.091 0.095 0.103 0.150 1.000
262,144 jax 10,000 0.002 0.001 0.001 0.002 0.002 0.003 0.018 170.553
262,144 numpy 100 0.419 0.050 0.302 0.382 0.425 0.453 0.532 1.000
1,048,576 jax 10,000 0.009 0.001 0.003 0.008 0.009 0.010 0.019 124.731
1,048,576 numpy 10 1.129 0.085 0.922 1.106 1.157 1.180 1.211 1.000
4,194,304 jax 1,000 0.047 0.008 0.029 0.038 0.047 0.054 0.065 79.958
4,194,304 numpy 10 3.724 0.256 3.255 3.675 3.786 3.903 4.068 1.000
(time in wall seconds, less is better)
A more balanced routine with many data dependencies (stencil operations) and tensor shapes of up to 5 dimensions. This is the most expensive part of Veros, so in a way this is the benchmark that interests me the most.
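As a rough sketch of what "data dependencies (stencil operations)" looks like in NumPy, here is a classic five-point stencil on a 2D array. This is a toy example, not the actual Veros routine:

```python
import numpy as np

def laplacian_2d(u):
    """Five-point Laplacian stencil on the interior of a 2D array.

    Each interior point reads its four neighbours via shifted slices,
    which creates the kind of cross-element data dependencies the
    isoneutral_mixing benchmark is full of (there in up to 5D).
    """
    out = np.zeros_like(u)
    out[1:-1, 1:-1] = (u[:-2, 1:-1] + u[2:, 1:-1]
                       + u[1:-1, :-2] + u[1:-1, 2:]
                       - 4.0 * u[1:-1, 1:-1])
    return out

u = np.random.rand(64, 64)
du = laplacian_2d(u)
```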
$ taskset -c 0 python run.py benchmarks/isoneutral_mixing/
Estimating repetitions...
Running 39930 benchmarks... [####################################] 100%
benchmarks.isoneutral_mixing
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numba 10,000 0.001 0.002 0.001 0.001 0.001 0.001 0.060 3.366
4,096 jax 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.037 1.977
4,096 theano 10,000 0.003 0.001 0.002 0.003 0.003 0.003 0.039 1.427
4,096 numpy 1,000 0.004 0.002 0.004 0.004 0.004 0.004 0.040 1.000
4,096 pytorch 1,000 0.007 0.001 0.006 0.006 0.006 0.006 0.025 0.677
16,384 numba 1,000 0.006 0.001 0.005 0.006 0.006 0.006 0.031 2.699
16,384 jax 1,000 0.007 0.001 0.006 0.007 0.007 0.007 0.043 2.275
16,384 theano 1,000 0.012 0.003 0.011 0.011 0.012 0.012 0.054 1.347
16,384 numpy 1,000 0.016 0.001 0.015 0.016 0.016 0.017 0.035 1.000
16,384 pytorch 1,000 0.016 0.002 0.015 0.016 0.016 0.016 0.052 0.998
65,536 numba 1,000 0.027 0.002 0.025 0.026 0.026 0.027 0.046 2.727
65,536 jax 1,000 0.029 0.002 0.027 0.028 0.029 0.029 0.064 2.531
65,536 theano 100 0.052 0.002 0.048 0.051 0.052 0.053 0.059 1.421
65,536 pytorch 100 0.071 0.004 0.064 0.067 0.070 0.074 0.082 1.040
65,536 numpy 100 0.074 0.002 0.068 0.073 0.074 0.075 0.082 1.000
262,144 numba 100 0.106 0.005 0.097 0.103 0.104 0.109 0.136 2.553
262,144 jax 100 0.116 0.005 0.108 0.113 0.115 0.119 0.133 2.329
262,144 theano 100 0.193 0.006 0.180 0.191 0.194 0.197 0.216 1.398
262,144 pytorch 100 0.258 0.009 0.227 0.252 0.258 0.264 0.285 1.048
262,144 numpy 100 0.270 0.031 0.250 0.264 0.268 0.272 0.569 1.000
1,048,576 numba 10 0.468 0.005 0.457 0.465 0.471 0.472 0.474 3.303
1,048,576 jax 10 0.570 0.015 0.535 0.563 0.575 0.578 0.595 2.713
1,048,576 theano 10 0.896 0.023 0.842 0.891 0.901 0.907 0.924 1.727
1,048,576 numpy 10 1.547 0.016 1.505 1.544 1.552 1.558 1.566 1.000
1,048,576 pytorch 10 2.033 0.036 1.993 2.011 2.025 2.037 2.129 0.761
4,194,304 numba 10 1.824 0.025 1.772 1.815 1.825 1.832 1.872 3.099
4,194,304 jax 10 2.330 0.046 2.226 2.340 2.350 2.357 2.369 2.425
4,194,304 theano 10 3.680 0.081 3.486 3.668 3.715 3.733 3.747 1.536
4,194,304 numpy 10 5.652 0.104 5.407 5.667 5.696 5.714 5.729 1.000
4,194,304 pytorch 10 6.151 0.054 6.065 6.102 6.160 6.185 6.248 0.919
(time in wall seconds, less is better)
$ for backend in jax pytorch cupy; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/isoneutral_mixing/ --device gpu -b $backend -b numpy; done
Estimating repetitions...
Running 34332 benchmarks... [####################################] 100%
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.002 0.001 0.001 0.001 0.002 0.002 0.019 2.424
4,096 numpy 1,000 0.004 0.001 0.003 0.004 0.004 0.004 0.018 1.000
16,384 jax 10,000 0.003 0.001 0.001 0.002 0.002 0.004 0.023 5.881
16,384 numpy 1,000 0.016 0.001 0.014 0.015 0.015 0.016 0.028 1.000
65,536 jax 10,000 0.004 0.001 0.002 0.004 0.004 0.004 0.024 37.624
65,536 numpy 100 0.161 0.014 0.068 0.156 0.162 0.168 0.183 1.000
262,144 jax 1,000 0.005 0.001 0.005 0.005 0.005 0.005 0.018 63.256
262,144 numpy 100 0.344 0.014 0.269 0.339 0.346 0.352 0.373 1.000
1,048,576 jax 1,000 0.018 0.001 0.017 0.017 0.017 0.018 0.031 70.403
1,048,576 numpy 10 1.239 0.012 1.217 1.231 1.240 1.247 1.258 1.000
4,194,304 jax 100 0.063 0.000 0.063 0.063 0.063 0.063 0.067 77.945
4,194,304 numpy 10 4.920 0.024 4.868 4.911 4.923 4.936 4.952 1.000
(time in wall seconds, less is better)
Estimating repetitions...
Running 7332 benchmarks... [####################################] 100%
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.004 0.001 0.003 0.004 0.004 0.004 0.013 1.000
4,096 pytorch 1,000 0.008 0.001 0.007 0.008 0.008 0.008 0.018 0.502
16,384 pytorch 1,000 0.009 0.001 0.007 0.008 0.008 0.008 0.018 1.854
16,384 numpy 1,000 0.016 0.001 0.014 0.015 0.016 0.016 0.023 1.000
65,536 pytorch 1,000 0.009 0.001 0.007 0.008 0.008 0.009 0.018 10.181
65,536 numpy 100 0.089 0.009 0.071 0.085 0.087 0.090 0.132 1.000
262,144 pytorch 1,000 0.009 0.002 0.008 0.008 0.008 0.009 0.021 29.526
262,144 numpy 100 0.262 0.014 0.244 0.253 0.259 0.267 0.331 1.000
1,048,576 pytorch 1,000 0.022 0.001 0.022 0.022 0.022 0.022 0.032 55.302
1,048,576 numpy 10 1.218 0.018 1.190 1.206 1.221 1.232 1.244 1.000
4,194,304 pytorch 100 0.082 0.000 0.082 0.082 0.082 0.082 0.084 61.422
4,194,304 numpy 10 5.023 0.044 4.979 4.998 5.003 5.023 5.111 1.000
(time in wall seconds, less is better)
Estimating repetitions...
Running 7242 benchmarks... [####################################] 100%
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.005 0.003 0.004 0.004 0.004 0.004 0.024 1.000
4,096 cupy 1,000 0.018 0.004 0.014 0.015 0.016 0.018 0.034 0.313
16,384 numpy 1,000 0.016 0.002 0.014 0.015 0.016 0.016 0.029 1.000
16,384 cupy 1,000 0.017 0.003 0.014 0.015 0.016 0.017 0.035 0.928
65,536 cupy 1,000 0.018 0.003 0.015 0.016 0.016 0.017 0.039 9.358
65,536 numpy 100 0.165 0.012 0.073 0.161 0.165 0.170 0.187 1.000
262,144 cupy 1,000 0.018 0.004 0.015 0.016 0.016 0.018 0.043 19.340
262,144 numpy 10 0.348 0.011 0.330 0.342 0.345 0.350 0.367 1.000
1,048,576 cupy 1,000 0.024 0.003 0.022 0.023 0.023 0.024 0.041 51.576
1,048,576 numpy 10 1.254 0.021 1.227 1.242 1.253 1.259 1.305 1.000
4,194,304 cupy 100 0.087 0.003 0.085 0.085 0.085 0.087 0.098 56.541
4,194,304 numpy 10 4.903 0.030 4.858 4.884 4.892 4.927 4.951 1.000
(time in wall seconds, less is better)
$ JAX_BACKEND_TARGET="grpc://$COLAB_TPU_ADDR" python run.py benchmarks/isoneutral_mixing -b jax -b numpy --device tpu
benchmarks.isoneutral_mixing
============================
Running on TPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.005 0.003 0.002 0.004 0.004 0.005 0.041 1.933
4,096 numpy 1,000 0.009 0.001 0.005 0.008 0.009 0.010 0.020 1.000
16,384 jax 10,000 0.005 0.003 0.003 0.004 0.004 0.005 0.040 6.148
16,384 numpy 1,000 0.030 0.004 0.021 0.028 0.030 0.032 0.055 1.000
65,536 jax 1,000 0.017 0.006 0.006 0.012 0.016 0.021 0.051 7.817
65,536 numpy 100 0.131 0.016 0.096 0.125 0.132 0.139 0.166 1.000
262,144 jax 1,000 0.014 0.004 0.007 0.011 0.015 0.017 0.040 29.070
262,144 numpy 10 0.419 0.053 0.354 0.357 0.441 0.461 0.493 1.000
1,048,576 jax 1,000 0.063 0.007 0.047 0.058 0.061 0.065 0.106 27.709
1,048,576 numpy 10 1.739 0.221 1.493 1.529 1.715 1.928 2.037 1.000
4,194,304 jax 100 0.248 0.017 0.223 0.235 0.243 0.261 0.288 26.421
4,194,304 numpy 10 6.541 0.493 5.874 6.170 6.413 6.752 7.428 1.000
(time in wall seconds, less is better)
This routine consists of some stencil operations and some linear algebra: a tridiagonal matrix solver, which cannot be vectorized along the solve dimension.
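The reason the tridiagonal solve resists vectorization is the sequential dependency in its forward sweep. A minimal sketch of the Thomas algorithm (a standard tridiagonal solver, not the benchmark's actual implementation):

```python
import numpy as np

def thomas_solve(a, b, c, d):
    """Solve a tridiagonal system with the Thomas algorithm.

    a: sub-diagonal (a[0] unused), b: main diagonal,
    c: super-diagonal (c[-1] unused), d: right-hand side.

    The forward sweep carries state from row i-1 to row i, so the
    loop over rows is inherently sequential and cannot be replaced
    by array-wide operations along this dimension.
    """
    n = len(d)
    cp = np.empty(n)
    dp = np.empty(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):                      # sequential dependency
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.empty(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):             # backward substitution
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x

# toy diagonally dominant 4x4 system
a = np.array([0.0, 1.0, 1.0, 1.0])
b = np.array([4.0, 4.0, 4.0, 4.0])
c = np.array([1.0, 1.0, 1.0, 0.0])
d = np.array([1.0, 2.0, 3.0, 4.0])
x = thomas_solve(a, b, c, d)
```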
$ taskset -c 0 python run.py benchmarks/turbulent_kinetic_energy/
Estimating repetitions...
Running 44658 benchmarks... [####################################] 100%
benchmarks.turbulent_kinetic_energy
===================================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numba 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.008 2.323
4,096 jax 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.011 2.012
4,096 numpy 10,000 0.003 0.000 0.002 0.002 0.003 0.003 0.009 1.000
16,384 jax 10,000 0.003 0.000 0.003 0.003 0.003 0.003 0.016 2.715
16,384 numba 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.007 2.220
16,384 numpy 1,000 0.009 0.001 0.008 0.008 0.009 0.009 0.016 1.000
65,536 jax 1,000 0.012 0.001 0.011 0.011 0.012 0.012 0.018 3.212
65,536 numba 1,000 0.014 0.001 0.012 0.013 0.014 0.014 0.024 2.754
65,536 numpy 100 0.038 0.002 0.035 0.037 0.038 0.039 0.046 1.000
262,144 jax 100 0.046 0.002 0.042 0.045 0.045 0.046 0.055 2.900
262,144 numba 100 0.047 0.004 0.043 0.045 0.046 0.047 0.062 2.794
262,144 numpy 100 0.132 0.004 0.122 0.130 0.132 0.134 0.151 1.000
1,048,576 numba 100 0.189 0.006 0.173 0.186 0.188 0.192 0.202 3.095
1,048,576 jax 100 0.270 0.009 0.249 0.264 0.270 0.276 0.297 2.160
1,048,576 numpy 10 0.584 0.014 0.558 0.578 0.588 0.594 0.602 1.000
4,194,304 numba 10 0.750 0.016 0.725 0.737 0.754 0.763 0.772 3.395
4,194,304 jax 10 1.398 0.022 1.350 1.389 1.398 1.416 1.427 1.820
4,194,304 numpy 10 2.545 0.011 2.525 2.539 2.545 2.554 2.559 1.000
(time in wall seconds, less is better)
$ CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/turbulent_kinetic_energy/ --device gpu -b jax -b numpy
Estimating repetitions...
Running 43332 benchmarks... [####################################] 100%
benchmarks.turbulent_kinetic_energy
===================================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.013 1.266
4,096 numpy 10,000 0.003 0.000 0.002 0.002 0.003 0.003 0.007 1.000
16,384 jax 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.014 3.783
16,384 numpy 1,000 0.009 0.001 0.008 0.008 0.009 0.009 0.014 1.000
65,536 jax 10,000 0.003 0.001 0.002 0.003 0.003 0.003 0.012 17.423
65,536 numpy 100 0.048 0.003 0.044 0.046 0.047 0.049 0.061 1.000
262,144 jax 1,000 0.005 0.001 0.004 0.005 0.005 0.005 0.012 29.930
262,144 numpy 100 0.149 0.004 0.139 0.147 0.148 0.151 0.160 1.000
1,048,576 jax 1,000 0.013 0.001 0.013 0.013 0.013 0.013 0.020 44.627
1,048,576 numpy 10 0.592 0.008 0.578 0.586 0.593 0.598 0.603 1.000
4,194,304 jax 100 0.049 0.001 0.048 0.048 0.049 0.049 0.056 50.415
4,194,304 numpy 10 2.451 0.012 2.432 2.441 2.452 2.462 2.468 1.000
(time in wall seconds, less is better)
$ JAX_BACKEND_TARGET="grpc://$COLAB_TPU_ADDR" python run.py benchmarks/turbulent_kinetic_energy -b jax -b numpy --device tpu
benchmarks.turbulent_kinetic_energy
===================================
Running on TPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.004 0.002 0.002 0.003 0.003 0.004 0.020 1.425
4,096 numpy 10,000 0.005 0.001 0.003 0.004 0.005 0.006 0.018 1.000
16,384 jax 10,000 0.004 0.002 0.002 0.003 0.003 0.004 0.028 3.843
16,384 numpy 1,000 0.015 0.003 0.010 0.013 0.015 0.016 0.032 1.000
65,536 jax 10,000 0.005 0.002 0.003 0.004 0.004 0.005 0.027 13.077
65,536 numpy 100 0.060 0.009 0.044 0.057 0.060 0.066 0.082 1.000
262,144 jax 1,000 0.010 0.003 0.004 0.008 0.011 0.012 0.026 19.048
262,144 numpy 100 0.198 0.028 0.157 0.170 0.203 0.214 0.280 1.000
1,048,576 jax 1,000 0.076 0.007 0.056 0.070 0.075 0.080 0.110 10.219
1,048,576 numpy 10 0.772 0.083 0.673 0.684 0.788 0.850 0.871 1.000
4,194,304 jax 1,000 0.339 0.027 0.288 0.317 0.340 0.361 0.403 7.577
4,194,304 numpy 10 2.569 0.244 2.373 2.408 2.439 2.621 3.035 1.000
(time in wall seconds, less is better)