Example results on Google Colab

Results are as reported by this notebook. To re-run these experiments, head over to Google Colab, upload the notebook, and run the cells one by one. In the tables below, the Δ column gives the speedup relative to the NumPy baseline (higher is better).

Hardware at the time of writing (19 Dec 2019):

  • Intel(R) Xeon(R) CPU @ 2.30 GHz (1 core, 2 threads)
  • 12.6 GB of RAM
  • NVIDIA Tesla K80 GPU with 12 GB of memory

Caveat: Jax does not (yet) support 64-bit floating-point precision on TPU architectures. Therefore, the Jax + TPU results are not bit-identical to those of the other backends and devices, so it's not really an apples-to-apples comparison.
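
For reference, double precision in Jax is opt-in; a minimal sketch using the standard x64 flag (not part of the benchmark code):

```python
import jax
import jax.numpy as jnp

# Jax defaults to 32-bit floats; this flag requests 64-bit precision.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)
print(x.dtype)  # float64 on CPU/GPU; on TPU the computation may still run in 32-bit
```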

Caveat²: I didn't manage to get Bohrium to work on Colab, so it's missing from these results.

Contents

  • Equation of state
  • Isoneutral mixing
  • Turbulent kinetic energy

Equation of state

An equation consisting of >100 terms with no data dependencies and only elementary math. This benchmark should represent a best-case scenario for vector instructions and GPU performance.
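
To make the structure concrete, here is a toy sketch (not the actual benchmark kernel, and with made-up coefficients) of such a purely elementwise computation, where every term can be evaluated independently:

```python
import numpy as np

def toy_equation_of_state(salt, temp, pressure):
    # A handful of polynomial terms (arbitrary coefficients) standing in for the
    # >100 terms of the real equation; everything is elementwise, so there are
    # no data dependencies between grid points.
    return (
        999.84
        + 0.068 * temp
        - 0.009 * temp ** 2
        + 0.82 * salt
        - 0.004 * salt * temp
        + 1.0e-4 * pressure
    )

salt, temp, pressure = (np.random.rand(4096) for _ in range(3))
rho = toy_equation_of_state(salt, temp, pressure)
```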

CPU

$ taskset -c 0 python run.py benchmarks/equation_of_state/

Estimating repetitions...
Running 100116 benchmarks...  [####################################]  100%

benchmarks.equation_of_state
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numba         10,000     0.001     0.000     0.000     0.000     0.000     0.001     0.003     3.402
       4,096  theano        10,000     0.001     0.000     0.000     0.001     0.001     0.001     0.004     2.668
       4,096  tensorflow    10,000     0.001     0.000     0.000     0.001     0.001     0.001     0.004     2.645
       4,096  jax           10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.004     2.500
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.006     1.000
       4,096  pytorch       10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.008     0.742

      16,384  tensorflow    10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.006     4.418
      16,384  jax           10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.006     4.244
      16,384  numba         10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.007     3.994
      16,384  theano        10,000     0.003     0.000     0.002     0.002     0.002     0.003     0.007     3.443
      16,384  pytorch        1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.015     1.037
      16,384  numpy          1,000     0.009     0.001     0.007     0.008     0.009     0.009     0.014     1.000

      65,536  tensorflow     1,000     0.006     0.001     0.006     0.006     0.006     0.006     0.010     7.932
      65,536  jax            1,000     0.007     0.001     0.006     0.007     0.007     0.007     0.012     6.998
      65,536  numba          1,000     0.009     0.001     0.008     0.008     0.008     0.009     0.014     5.852
      65,536  theano         1,000     0.010     0.001     0.009     0.009     0.010     0.010     0.016     5.123
      65,536  pytorch          100     0.048     0.002     0.045     0.046     0.047     0.048     0.058     1.045
      65,536  numpy            100     0.050     0.002     0.047     0.049     0.049     0.050     0.061     1.000

     262,144  tensorflow     1,000     0.020     0.002     0.018     0.019     0.020     0.020     0.034    12.017
     262,144  jax            1,000     0.023     0.002     0.021     0.022     0.023     0.023     0.034    10.489
     262,144  numba          1,000     0.031     0.002     0.029     0.031     0.031     0.031     0.044     7.725
     262,144  theano           100     0.035     0.002     0.034     0.034     0.035     0.035     0.043     6.887
     262,144  pytorch          100     0.201     0.006     0.180     0.197     0.201     0.205     0.231     1.206
     262,144  numpy            100     0.242     0.007     0.210     0.239     0.242     0.246     0.265     1.000

   1,048,576  tensorflow       100     0.090     0.004     0.084     0.087     0.088     0.092     0.102     8.535
   1,048,576  jax              100     0.099     0.005     0.093     0.095     0.097     0.101     0.115     7.756
   1,048,576  numba            100     0.133     0.005     0.126     0.129     0.131     0.136     0.146     5.778
   1,048,576  theano           100     0.148     0.004     0.143     0.145     0.147     0.151     0.160     5.171
   1,048,576  numpy             10     0.766     0.005     0.761     0.763     0.764     0.768     0.780     1.000
   1,048,576  pytorch           10     0.807     0.007     0.795     0.803     0.806     0.809     0.822     0.950

   4,194,304  tensorflow        10     0.355     0.007     0.347     0.348     0.354     0.358     0.370     9.285
   4,194,304  jax               10     0.396     0.005     0.387     0.392     0.396     0.400     0.406     8.316
   4,194,304  numba             10     0.514     0.006     0.504     0.511     0.515     0.518     0.522     6.412
   4,194,304  theano            10     0.590     0.007     0.577     0.587     0.589     0.595     0.604     5.581
   4,194,304  numpy             10     3.295     0.017     3.262     3.284     3.293     3.302     3.327     1.000
   4,194,304  pytorch           10     3.837     0.016     3.819     3.827     3.833     3.841     3.876     0.859

(time in wall seconds, less is better)

GPU

$ for backend in jax tensorflow pytorch cupy; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/equation_of_state/ --device gpu -b $backend -b numpy; done

Estimating repetitions...
Running 71232 benchmarks...  [####################################]  100%

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004     5.083
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.005     1.000

      16,384  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004    25.571
      16,384  numpy          1,000     0.009     0.001     0.007     0.008     0.008     0.009     0.013     1.000

      65,536  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003   352.617
      65,536  numpy            100     0.140     0.006     0.122     0.137     0.140     0.144     0.154     1.000

     262,144  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003   533.222
     262,144  numpy            100     0.251     0.035     0.228     0.243     0.247     0.251     0.595     1.000

   1,048,576  jax           10,000     0.002     0.000     0.000     0.002     0.002     0.003     0.006   290.491
   1,048,576  numpy             10     0.707     0.007     0.699     0.703     0.707     0.709     0.723     1.000

   4,194,304  jax           10,000     0.003     0.000     0.001     0.003     0.003     0.003     0.005  1140.959
   4,194,304  numpy             10     3.060     0.012     3.041     3.049     3.063     3.071     3.073     1.000

(time in wall seconds, less is better)

Estimating repetitions...
Running 71232 benchmarks...  [####################################]  100%

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  tensorflow    10,000     0.001     0.000     0.000     0.001     0.001     0.001     0.007     2.552
       4,096  numpy         10,000     0.002     0.000     0.001     0.002     0.002     0.002     0.005     1.000

      16,384  tensorflow    10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.006    11.811
      16,384  numpy          1,000     0.009     0.001     0.007     0.008     0.008     0.009     0.013     1.000

      65,536  tensorflow    10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.010    48.165
      65,536  numpy            100     0.047     0.002     0.043     0.045     0.046     0.047     0.058     1.000

     262,144  tensorflow    10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.012   188.762
     262,144  numpy            100     0.251     0.008     0.220     0.247     0.252     0.255     0.272     1.000

   1,048,576  tensorflow    10,000     0.003     0.000     0.002     0.003     0.003     0.003     0.009   261.178
   1,048,576  numpy             10     0.720     0.017     0.680     0.715     0.721     0.729     0.741     1.000

   4,194,304  tensorflow    10,000     0.011     0.001     0.010     0.011     0.011     0.011     0.025   279.396
   4,194,304  numpy             10     3.069     0.014     3.053     3.059     3.067     3.074     3.104     1.000

(time in wall seconds, less is better)

Estimating repetitions...
Running 16332 benchmarks...  [####################################]  100%

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy         10,000     0.002     0.000     0.001     0.001     0.002     0.002     0.009     1.000
       4,096  pytorch        1,000     0.004     0.001     0.003     0.004     0.004     0.004     0.008     0.389

      16,384  pytorch        1,000     0.004     0.001     0.003     0.004     0.004     0.004     0.009     1.987
      16,384  numpy          1,000     0.008     0.001     0.007     0.007     0.008     0.008     0.013     1.000

      65,536  pytorch        1,000     0.004     0.001     0.003     0.004     0.004     0.004     0.008    17.421
      65,536  numpy            100     0.073     0.005     0.048     0.071     0.073     0.075     0.084     1.000

     262,144  pytorch        1,000     0.004     0.001     0.004     0.004     0.004     0.004     0.008    40.591
     262,144  numpy            100     0.174     0.005     0.163     0.171     0.174     0.177     0.191     1.000

   1,048,576  pytorch        1,000     0.017     0.000     0.016     0.017     0.017     0.017     0.017    50.141
   1,048,576  numpy             10     0.832     0.008     0.821     0.829     0.830     0.836     0.849     1.000

   4,194,304  pytorch          100     0.062     0.000     0.062     0.062     0.062     0.062     0.063    50.867
   4,194,304  numpy             10     3.156     0.021     3.130     3.139     3.151     3.161     3.197     1.000

(time in wall seconds, less is better)

Estimating repetitions...
Running 16332 benchmarks...  [####################################]  100%

benchmarks.equation_of_state
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy         10,000     0.002     0.000     0.001     0.001     0.002     0.002     0.006     1.000
       4,096  cupy           1,000     0.011     0.001     0.010     0.011     0.011     0.011     0.019     0.146

      16,384  numpy          1,000     0.008     0.001     0.007     0.007     0.008     0.009     0.014     1.000
      16,384  cupy           1,000     0.011     0.001     0.010     0.011     0.011     0.011     0.021     0.719

      65,536  cupy           1,000     0.011     0.002     0.010     0.011     0.011     0.011     0.023    13.935
      65,536  numpy            100     0.159     0.018     0.063     0.158     0.162     0.166     0.182     1.000

     262,144  cupy           1,000     0.011     0.002     0.010     0.011     0.011     0.011     0.023    23.526
     262,144  numpy            100     0.268     0.008     0.235     0.264     0.269     0.274     0.284     1.000

   1,048,576  cupy           1,000     0.017     0.001     0.016     0.016     0.016     0.016     0.022    44.047
   1,048,576  numpy             10     0.730     0.010     0.707     0.723     0.735     0.738     0.741     1.000

   4,194,304  cupy             100     0.061     0.000     0.060     0.061     0.061     0.061     0.061    50.960
   4,194,304  numpy             10     3.089     0.017     3.052     3.083     3.089     3.100     3.114     1.000

(time in wall seconds, less is better)

TPU

$ JAX_BACKEND_TARGET="grpc://$COLAB_TPU_ADDR" python run.py benchmarks/equation_of_state -b jax -b numpy --device tpu

benchmarks.equation_of_state
============================
Running on TPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.002     0.001     0.001     0.002     0.002     0.002     0.024     1.694
       4,096  numpy         10,000     0.004     0.001     0.002     0.003     0.004     0.004     0.016     1.000

      16,384  jax           10,000     0.002     0.001     0.001     0.002     0.002     0.002     0.012     6.065
      16,384  numpy          1,000     0.013     0.002     0.008     0.012     0.013     0.014     0.032     1.000

      65,536  jax           10,000     0.002     0.001     0.001     0.002     0.002     0.002     0.017    42.891
      65,536  numpy            100     0.095     0.014     0.063     0.091     0.095     0.103     0.150     1.000

     262,144  jax           10,000     0.002     0.001     0.001     0.002     0.002     0.003     0.018   170.553
     262,144  numpy            100     0.419     0.050     0.302     0.382     0.425     0.453     0.532     1.000

   1,048,576  jax           10,000     0.009     0.001     0.003     0.008     0.009     0.010     0.019   124.731
   1,048,576  numpy             10     1.129     0.085     0.922     1.106     1.157     1.180     1.211     1.000

   4,194,304  jax            1,000     0.047     0.008     0.029     0.038     0.047     0.054     0.065    79.958
   4,194,304  numpy             10     3.724     0.256     3.255     3.675     3.786     3.903     4.068     1.000

(time in wall seconds, less is better)
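
The JAX_BACKEND_TARGET variable in the command above points Jax at the Colab TPU; a hedged sketch of doing the same from a notebook cell and checking that the TPU is actually picked up:

```python
import os

# COLAB_TPU_ADDR is set by the Colab TPU runtime; the backend target must be
# configured before Jax initializes its backend, hence the late import below.
os.environ["JAX_BACKEND_TARGET"] = "grpc://" + os.environ["COLAB_TPU_ADDR"]

import jax

print(jax.devices())  # should list TPU devices rather than the host CPU
```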

Isoneutral mixing

A more balanced routine with many data dependencies (stencil operations) and tensor shapes of up to 5 dimensions. This is the most expensive part of Veros, so in a way this is the benchmark that interests me the most.
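
As a rough illustration (not the Veros code), a stencil operation couples neighbouring grid cells, which introduces the data dependencies mentioned above:

```python
import numpy as np

def centered_x_gradient(field, dx):
    # Each interior point depends on both of its neighbours along the first axis,
    # so this is a shifted-slice (stencil) operation rather than a purely
    # independent elementwise one.
    grad = np.zeros_like(field)
    grad[1:-1, :, :] = (field[2:, :, :] - field[:-2, :, :]) / (2.0 * dx)
    return grad

field = np.random.rand(64, 64, 16)
gx = centered_x_gradient(field, dx=1.0)
```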

CPU

$ taskset -c 0 python run.py benchmarks/isoneutral_mixing/

Estimating repetitions...
Running 39930 benchmarks...  [####################################]  100%

benchmarks.isoneutral_mixing
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numba         10,000     0.001     0.002     0.001     0.001     0.001     0.001     0.060     3.366
       4,096  jax           10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.037     1.977
       4,096  theano        10,000     0.003     0.001     0.002     0.003     0.003     0.003     0.039     1.427
       4,096  numpy          1,000     0.004     0.002     0.004     0.004     0.004     0.004     0.040     1.000
       4,096  pytorch        1,000     0.007     0.001     0.006     0.006     0.006     0.006     0.025     0.677

      16,384  numba          1,000     0.006     0.001     0.005     0.006     0.006     0.006     0.031     2.699
      16,384  jax            1,000     0.007     0.001     0.006     0.007     0.007     0.007     0.043     2.275
      16,384  theano         1,000     0.012     0.003     0.011     0.011     0.012     0.012     0.054     1.347
      16,384  numpy          1,000     0.016     0.001     0.015     0.016     0.016     0.017     0.035     1.000
      16,384  pytorch        1,000     0.016     0.002     0.015     0.016     0.016     0.016     0.052     0.998

      65,536  numba          1,000     0.027     0.002     0.025     0.026     0.026     0.027     0.046     2.727
      65,536  jax            1,000     0.029     0.002     0.027     0.028     0.029     0.029     0.064     2.531
      65,536  theano           100     0.052     0.002     0.048     0.051     0.052     0.053     0.059     1.421
      65,536  pytorch          100     0.071     0.004     0.064     0.067     0.070     0.074     0.082     1.040
      65,536  numpy            100     0.074     0.002     0.068     0.073     0.074     0.075     0.082     1.000

     262,144  numba            100     0.106     0.005     0.097     0.103     0.104     0.109     0.136     2.553
     262,144  jax              100     0.116     0.005     0.108     0.113     0.115     0.119     0.133     2.329
     262,144  theano           100     0.193     0.006     0.180     0.191     0.194     0.197     0.216     1.398
     262,144  pytorch          100     0.258     0.009     0.227     0.252     0.258     0.264     0.285     1.048
     262,144  numpy            100     0.270     0.031     0.250     0.264     0.268     0.272     0.569     1.000

   1,048,576  numba             10     0.468     0.005     0.457     0.465     0.471     0.472     0.474     3.303
   1,048,576  jax               10     0.570     0.015     0.535     0.563     0.575     0.578     0.595     2.713
   1,048,576  theano            10     0.896     0.023     0.842     0.891     0.901     0.907     0.924     1.727
   1,048,576  numpy             10     1.547     0.016     1.505     1.544     1.552     1.558     1.566     1.000
   1,048,576  pytorch           10     2.033     0.036     1.993     2.011     2.025     2.037     2.129     0.761

   4,194,304  numba             10     1.824     0.025     1.772     1.815     1.825     1.832     1.872     3.099
   4,194,304  jax               10     2.330     0.046     2.226     2.340     2.350     2.357     2.369     2.425
   4,194,304  theano            10     3.680     0.081     3.486     3.668     3.715     3.733     3.747     1.536
   4,194,304  numpy             10     5.652     0.104     5.407     5.667     5.696     5.714     5.729     1.000
   4,194,304  pytorch           10     6.151     0.054     6.065     6.102     6.160     6.185     6.248     0.919

(time in wall seconds, less is better)

GPU

$ for backend in jax pytorch cupy; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/isoneutral_mixing/ --device gpu -b $backend -b numpy; done

Estimating repetitions...
Running 34332 benchmarks...  [####################################]  100%

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.002     0.001     0.001     0.001     0.002     0.002     0.019     2.424
       4,096  numpy          1,000     0.004     0.001     0.003     0.004     0.004     0.004     0.018     1.000

      16,384  jax           10,000     0.003     0.001     0.001     0.002     0.002     0.004     0.023     5.881
      16,384  numpy          1,000     0.016     0.001     0.014     0.015     0.015     0.016     0.028     1.000

      65,536  jax           10,000     0.004     0.001     0.002     0.004     0.004     0.004     0.024    37.624
      65,536  numpy            100     0.161     0.014     0.068     0.156     0.162     0.168     0.183     1.000

     262,144  jax            1,000     0.005     0.001     0.005     0.005     0.005     0.005     0.018    63.256
     262,144  numpy            100     0.344     0.014     0.269     0.339     0.346     0.352     0.373     1.000

   1,048,576  jax            1,000     0.018     0.001     0.017     0.017     0.017     0.018     0.031    70.403
   1,048,576  numpy             10     1.239     0.012     1.217     1.231     1.240     1.247     1.258     1.000

   4,194,304  jax              100     0.063     0.000     0.063     0.063     0.063     0.063     0.067    77.945
   4,194,304  numpy             10     4.920     0.024     4.868     4.911     4.923     4.936     4.952     1.000

(time in wall seconds, less is better)

Estimating repetitions...
Running 7332 benchmarks...  [####################################]  100%

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.004     0.001     0.003     0.004     0.004     0.004     0.013     1.000
       4,096  pytorch        1,000     0.008     0.001     0.007     0.008     0.008     0.008     0.018     0.502

      16,384  pytorch        1,000     0.009     0.001     0.007     0.008     0.008     0.008     0.018     1.854
      16,384  numpy          1,000     0.016     0.001     0.014     0.015     0.016     0.016     0.023     1.000

      65,536  pytorch        1,000     0.009     0.001     0.007     0.008     0.008     0.009     0.018    10.181
      65,536  numpy            100     0.089     0.009     0.071     0.085     0.087     0.090     0.132     1.000

     262,144  pytorch        1,000     0.009     0.002     0.008     0.008     0.008     0.009     0.021    29.526
     262,144  numpy            100     0.262     0.014     0.244     0.253     0.259     0.267     0.331     1.000

   1,048,576  pytorch        1,000     0.022     0.001     0.022     0.022     0.022     0.022     0.032    55.302
   1,048,576  numpy             10     1.218     0.018     1.190     1.206     1.221     1.232     1.244     1.000

   4,194,304  pytorch          100     0.082     0.000     0.082     0.082     0.082     0.082     0.084    61.422
   4,194,304  numpy             10     5.023     0.044     4.979     4.998     5.003     5.023     5.111     1.000

(time in wall seconds, less is better)

Estimating repetitions...
Running 7242 benchmarks...  [####################################]  100%

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy          1,000     0.005     0.003     0.004     0.004     0.004     0.004     0.024     1.000
       4,096  cupy           1,000     0.018     0.004     0.014     0.015     0.016     0.018     0.034     0.313

      16,384  numpy          1,000     0.016     0.002     0.014     0.015     0.016     0.016     0.029     1.000
      16,384  cupy           1,000     0.017     0.003     0.014     0.015     0.016     0.017     0.035     0.928

      65,536  cupy           1,000     0.018     0.003     0.015     0.016     0.016     0.017     0.039     9.358
      65,536  numpy            100     0.165     0.012     0.073     0.161     0.165     0.170     0.187     1.000

     262,144  cupy           1,000     0.018     0.004     0.015     0.016     0.016     0.018     0.043    19.340
     262,144  numpy             10     0.348     0.011     0.330     0.342     0.345     0.350     0.367     1.000

   1,048,576  cupy           1,000     0.024     0.003     0.022     0.023     0.023     0.024     0.041    51.576
   1,048,576  numpy             10     1.254     0.021     1.227     1.242     1.253     1.259     1.305     1.000

   4,194,304  cupy             100     0.087     0.003     0.085     0.085     0.085     0.087     0.098    56.541
   4,194,304  numpy             10     4.903     0.030     4.858     4.884     4.892     4.927     4.951     1.000

(time in wall seconds, less is better)

TPU

$ JAX_BACKEND_TARGET="grpc://$COLAB_TPU_ADDR" python run.py benchmarks/isoneutral_mixing -b jax -b numpy --device tpu

benchmarks.isoneutral_mixing
============================
Running on TPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.005     0.003     0.002     0.004     0.004     0.005     0.041     1.933
       4,096  numpy          1,000     0.009     0.001     0.005     0.008     0.009     0.010     0.020     1.000

      16,384  jax           10,000     0.005     0.003     0.003     0.004     0.004     0.005     0.040     6.148
      16,384  numpy          1,000     0.030     0.004     0.021     0.028     0.030     0.032     0.055     1.000

      65,536  jax            1,000     0.017     0.006     0.006     0.012     0.016     0.021     0.051     7.817
      65,536  numpy            100     0.131     0.016     0.096     0.125     0.132     0.139     0.166     1.000

     262,144  jax            1,000     0.014     0.004     0.007     0.011     0.015     0.017     0.040    29.070
     262,144  numpy             10     0.419     0.053     0.354     0.357     0.441     0.461     0.493     1.000

   1,048,576  jax            1,000     0.063     0.007     0.047     0.058     0.061     0.065     0.106    27.709
   1,048,576  numpy             10     1.739     0.221     1.493     1.529     1.715     1.928     2.037     1.000

   4,194,304  jax              100     0.248     0.017     0.223     0.235     0.243     0.261     0.288    26.421
   4,194,304  numpy             10     6.541     0.493     5.874     6.170     6.413     6.752     7.428     1.000

(time in wall seconds, less is better)

Turbulent kinetic energy

This routine consists of some stencil operations and some linear algebra (a tridiagonal matrix solver), the latter of which cannot be vectorized.
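
For illustration, here is a hedged sketch of a scalar Thomas (tridiagonal) solve — not the benchmark implementation — showing why this part resists vectorization: the forward sweep carries a dependency from one row to the next.

```python
import numpy as np

def thomas_solve(a, b, c, d):
    """Solve a tridiagonal system with sub-, main and super-diagonals a, b, c and right-hand side d."""
    n = len(d)
    cp = np.empty(n)
    dp = np.empty(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):  # sequential forward sweep: step i needs step i - 1
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.empty(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):  # back substitution, also sequential
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x
```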

CPU

$ taskset -c 0 python run.py benchmarks/turbulent_kinetic_energy/

Estimating repetitions...
Running 44658 benchmarks...  [####################################]  100%

benchmarks.turbulent_kinetic_energy
===================================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numba         10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.008     2.323
       4,096  jax           10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.011     2.012
       4,096  numpy         10,000     0.003     0.000     0.002     0.002     0.003     0.003     0.009     1.000

      16,384  jax           10,000     0.003     0.000     0.003     0.003     0.003     0.003     0.016     2.715
      16,384  numba          1,000     0.004     0.000     0.004     0.004     0.004     0.004     0.007     2.220
      16,384  numpy          1,000     0.009     0.001     0.008     0.008     0.009     0.009     0.016     1.000

      65,536  jax            1,000     0.012     0.001     0.011     0.011     0.012     0.012     0.018     3.212
      65,536  numba          1,000     0.014     0.001     0.012     0.013     0.014     0.014     0.024     2.754
      65,536  numpy            100     0.038     0.002     0.035     0.037     0.038     0.039     0.046     1.000

     262,144  jax              100     0.046     0.002     0.042     0.045     0.045     0.046     0.055     2.900
     262,144  numba            100     0.047     0.004     0.043     0.045     0.046     0.047     0.062     2.794
     262,144  numpy            100     0.132     0.004     0.122     0.130     0.132     0.134     0.151     1.000

   1,048,576  numba            100     0.189     0.006     0.173     0.186     0.188     0.192     0.202     3.095
   1,048,576  jax              100     0.270     0.009     0.249     0.264     0.270     0.276     0.297     2.160
   1,048,576  numpy             10     0.584     0.014     0.558     0.578     0.588     0.594     0.602     1.000

   4,194,304  numba             10     0.750     0.016     0.725     0.737     0.754     0.763     0.772     3.395
   4,194,304  jax               10     1.398     0.022     1.350     1.389     1.398     1.416     1.427     1.820
   4,194,304  numpy             10     2.545     0.011     2.525     2.539     2.545     2.554     2.559     1.000

(time in wall seconds, less is better)

GPU

$ CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/turbulent_kinetic_energy/ --device gpu -b jax -b numpy

Estimating repetitions...
Running 43332 benchmarks...  [####################################]  100%

benchmarks.turbulent_kinetic_energy
===================================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.013     1.266
       4,096  numpy         10,000     0.003     0.000     0.002     0.002     0.003     0.003     0.007     1.000

      16,384  jax           10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.014     3.783
      16,384  numpy          1,000     0.009     0.001     0.008     0.008     0.009     0.009     0.014     1.000

      65,536  jax           10,000     0.003     0.001     0.002     0.003     0.003     0.003     0.012    17.423
      65,536  numpy            100     0.048     0.003     0.044     0.046     0.047     0.049     0.061     1.000

     262,144  jax            1,000     0.005     0.001     0.004     0.005     0.005     0.005     0.012    29.930
     262,144  numpy            100     0.149     0.004     0.139     0.147     0.148     0.151     0.160     1.000

   1,048,576  jax            1,000     0.013     0.001     0.013     0.013     0.013     0.013     0.020    44.627
   1,048,576  numpy             10     0.592     0.008     0.578     0.586     0.593     0.598     0.603     1.000

   4,194,304  jax              100     0.049     0.001     0.048     0.048     0.049     0.049     0.056    50.415
   4,194,304  numpy             10     2.451     0.012     2.432     2.441     2.452     2.462     2.468     1.000

(time in wall seconds, less is better)

TPU

$ JAX_BACKEND_TARGET="grpc://$COLAB_TPU_ADDR" python run.py benchmarks/turbulent_kinetic_energy -b jax -b numpy --device tpu

benchmarks.turbulent_kinetic_energy
===================================
Running on TPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.004     0.002     0.002     0.003     0.003     0.004     0.020     1.425
       4,096  numpy         10,000     0.005     0.001     0.003     0.004     0.005     0.006     0.018     1.000

      16,384  jax           10,000     0.004     0.002     0.002     0.003     0.003     0.004     0.028     3.843
      16,384  numpy          1,000     0.015     0.003     0.010     0.013     0.015     0.016     0.032     1.000

      65,536  jax           10,000     0.005     0.002     0.003     0.004     0.004     0.005     0.027    13.077
      65,536  numpy            100     0.060     0.009     0.044     0.057     0.060     0.066     0.082     1.000

     262,144  jax            1,000     0.010     0.003     0.004     0.008     0.011     0.012     0.026    19.048
     262,144  numpy            100     0.198     0.028     0.157     0.170     0.203     0.214     0.280     1.000

   1,048,576  jax            1,000     0.076     0.007     0.056     0.070     0.075     0.080     0.110    10.219
   1,048,576  numpy             10     0.772     0.083     0.673     0.684     0.788     0.850     0.871     1.000

   4,194,304  jax            1,000     0.339     0.027     0.288     0.317     0.340     0.361     0.403     7.577
   4,194,304  numpy             10     2.569     0.244     2.373     2.408     2.439     2.621     3.035     1.000

(time in wall seconds, less is better)