Cache initialization fails when a JAX Array is created before enabling local cache #25768

Open
flferretti opened this issue Jan 8, 2025 · 1 comment
Assignees
Labels
bug Something isn't working

Comments

Contributor

flferretti commented Jan 8, 2025

Description

The persistent compilation cache in JAX fails to initialize if a JAX array is created prior to activating the local cache using jax.config.update. Removing the array creation line allows the cache to initialize correctly.

MRE with array allocation:

import jax
import jax.numpy as jnp

# This line causes the persistent cache to remain uninitialized.
a = jnp.zeros(4)

# Configuration for JAX persistent cache
jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
    return x + 1

# Function invocation
x = jnp.zeros((2, 2))
f(x)
Full Log

DEBUG:2025-01-08 11:37:25,064:jax._src.dispatch:182: Finished tracing + transforming broadcast_in_dim for pjit in 0.000183344 sec
DEBUG:2025-01-08 11:37:25,065:jax._src.interpreters.pxla:1906: Compiling broadcast_in_dim with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2025-01-08 11:37:25,067:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001679897 sec
DEBUG:2025-01-08 11:37:25,067:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CudaDevice(id=0)]]
DEBUG:2025-01-08 11:37:25,067:jax._src.compiler:239: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:2025-01-08 11:37:25,076:jax._src.compiler:260: Enabling XLA kernel cache at '/tmp/jax_cache/xla_gpu_kernel_cache_file'
DEBUG:2025-01-08 11:37:25,076:jax._src.compiler:265: Enabling XLA autotuning cache at '/tmp/jax_cache/xla_gpu_per_fusion_autotune_cache_dir'
DEBUG:2025-01-08 11:37:25,076:jax._src.cache_key:152: get_cache_key hash of serialized computation: 7c595daa617132810980e2a78d5722364377c78aa62385b323a42352c06c0986
DEBUG:2025-01-08 11:37:25,076:jax._src.cache_key:158: get_cache_key hash after serializing computation: 7c595daa617132810980e2a78d5722364377c78aa62385b323a42352c06c0986
DEBUG:2025-01-08 11:37:25,076:jax._src.cache_key:152: get_cache_key hash of serialized jax_lib version: c8601d1831072872293c1f9c58282e40273dd0289eaea98e369c2037dc4231ae
DEBUG:2025-01-08 11:37:25,076:jax._src.cache_key:158: get_cache_key hash after serializing jax_lib version: 71bccc6c9a13fd5342b8a6453530356df28fc9bef6b8de1f013f37112c95aa3f
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:152: get_cache_key hash of serialized XLA flags: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:158: get_cache_key hash after serializing XLA flags: 71bccc6c9a13fd5342b8a6453530356df28fc9bef6b8de1f013f37112c95aa3f
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:152: get_cache_key hash of serialized compile_options: 9a2bca9520f15d649eb148fe1c967023e6c1abcbe23b5c18508ffd37fc2caa42
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:158: get_cache_key hash after serializing compile_options: dd3b57e335deee8784863afece324e434917d6cc3af26116758dda1b5d012223
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:152: get_cache_key hash of serialized accelerator_config: b58a62c4527e3728c60e269461bd03852cb6f48a6708b25e0307cd74663f17e9
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:158: get_cache_key hash after serializing accelerator_config: aefcf17c78285d8750a4153653d9e8622c90110f6194cfb1588f53bcd1ccb53e
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:152: get_cache_key hash of serialized compression: 0ea55c28f8014d8886b6248fe3da5d588f55c0823847a6b4579f1131b051b5e2
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:158: get_cache_key hash after serializing compression: d8606c1f0763d704c10857737957225e80346b1ea8e35098666c2fe6d93ff8cd
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:152: get_cache_key hash of serialized custom_hook: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
DEBUG:2025-01-08 11:37:25,077:jax._src.cache_key:158: get_cache_key hash after serializing custom_hook: d8606c1f0763d704c10857737957225e80346b1ea8e35098666c2fe6d93ff8cd
DEBUG:2025-01-08 11:37:25,077:jax._src.compilation_cache:215: get_executable_and_time: cache is disabled/not initialized
DEBUG:2025-01-08 11:37:25,077:jax._src.compiler:108: PERSISTENT COMPILATION CACHE MISS for 'jit_broadcast_in_dim' with key 'jit_broadcast_in_dim-d8606c1f0763d704c10857737957225e80346b1ea8e35098666c2fe6d93ff8cd'
DEBUG:2025-01-08 11:37:25,085:jax._src.compiler:730: 'jit_broadcast_in_dim' took at least 0.00 seconds to compile (0.01s)
DEBUG:2025-01-08 11:37:25,085:jax._src.compilation_cache:245: Not writing persistent cache entry with key 'jit_broadcast_in_dim-d8606c1f0763d704c10857737957225e80346b1ea8e35098666c2fe6d93ff8cd' since cache is disabled/not initialized
DEBUG:2025-01-08 11:37:25,085:jax._src.dispatch:182: Finished XLA compilation of jit(broadcast_in_dim) in 0.008747339 sec
DEBUG:2025-01-08 11:37:25,087:jax._src.dispatch:182: Finished tracing + transforming add for pjit in 0.000348568 sec
DEBUG:2025-01-08 11:37:25,087:jax._src.dispatch:182: Finished tracing + transforming f for pjit in 0.000899553 sec
DEBUG:2025-01-08 11:37:25,087:jax._src.interpreters.pxla:1906: Compiling f with global shapes and types [ShapedArray(float32[2,2])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2025-01-08 11:37:25,089:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(f) in 0.001835823 sec
DEBUG:2025-01-08 11:37:25,089:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CudaDevice(id=0)]]
DEBUG:2025-01-08 11:37:25,089:jax._src.compiler:239: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:2025-01-08 11:37:25,090:jax._src.compiler:260: Enabling XLA kernel cache at '/tmp/jax_cache/xla_gpu_kernel_cache_file'
DEBUG:2025-01-08 11:37:25,090:jax._src.compiler:265: Enabling XLA autotuning cache at '/tmp/jax_cache/xla_gpu_per_fusion_autotune_cache_dir'
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized computation: 380ddb18f25ed5a1aca7f087de5f4a4d07f46c3e8f4d42ad21937e931b05da57
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing computation: 380ddb18f25ed5a1aca7f087de5f4a4d07f46c3e8f4d42ad21937e931b05da57
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized jax_lib version: c8601d1831072872293c1f9c58282e40273dd0289eaea98e369c2037dc4231ae
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing jax_lib version: 9c38da7ed7dfd35462c73d36e5b770ea7bf5ad679c67ca4625afea9e628528cf
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized XLA flags: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing XLA flags: 9c38da7ed7dfd35462c73d36e5b770ea7bf5ad679c67ca4625afea9e628528cf
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized compile_options: 9a2bca9520f15d649eb148fe1c967023e6c1abcbe23b5c18508ffd37fc2caa42
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing compile_options: 04f7b5724db48e6c4554f1b80916cfd93d671f5944eeaf988b077f7002d2f4b4
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized accelerator_config: b58a62c4527e3728c60e269461bd03852cb6f48a6708b25e0307cd74663f17e9
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing accelerator_config: 7fadb1901a3564de3579b886a670a31028011157831263b0bd6a13aa7a64d9c9
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized compression: 0ea55c28f8014d8886b6248fe3da5d588f55c0823847a6b4579f1131b051b5e2
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing compression: 233097fb49ebef28937711d6252b53ea2f4caccf79727f493dab2ec62fb37bc9
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:152: get_cache_key hash of serialized custom_hook: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
DEBUG:2025-01-08 11:37:25,090:jax._src.cache_key:158: get_cache_key hash after serializing custom_hook: 233097fb49ebef28937711d6252b53ea2f4caccf79727f493dab2ec62fb37bc9
DEBUG:2025-01-08 11:37:25,090:jax._src.compilation_cache:215: get_executable_and_time: cache is disabled/not initialized
DEBUG:2025-01-08 11:37:25,090:jax._src.compiler:108: PERSISTENT COMPILATION CACHE MISS for 'jit_f' with key 'jit_f-233097fb49ebef28937711d6252b53ea2f4caccf79727f493dab2ec62fb37bc9'
DEBUG:2025-01-08 11:37:25,096:jax._src.compiler:730: 'jit_f' took at least 0.00 seconds to compile (0.01s)
DEBUG:2025-01-08 11:37:25,096:jax._src.compilation_cache:245: Not writing persistent cache entry with key 'jit_f-233097fb49ebef28937711d6252b53ea2f4caccf79727f493dab2ec62fb37bc9' since cache is disabled/not initialized
DEBUG:2025-01-08 11:37:25,096:jax._src.dispatch:182: Finished XLA compilation of jit(f) in 0.006633282 sec

This issue also occurs when a JAX NumPy array is used as a ClassVar default value or as a default argument of a function. (See also ami-iit/jaxsim#322 and ami-iit/jaxsim#329.)
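A likely reason these two variants behave like the first MRE: Python evaluates class bodies and default-argument expressions when the `class` or `def` statement runs (typically at import time), not when the function is later called. The sketch below shows this in plain Python. `make_array` is a hypothetical stand-in for `jnp.zeros` that only records its calls, and `lazy_fn` shows one common way to defer the allocation; neither name comes from the original report.

```python
calls = []  # records every "array creation" in order


def make_array(n):
    """Hypothetical stand-in for jnp.zeros: records the call, returns a list."""
    calls.append(n)
    return [0.0] * n


class MyClass:
    # Evaluated as soon as the class body runs, i.e. at definition time.
    default_array = make_array(3)


def eager_fn(array=make_array(3)):
    # The default above was already evaluated when `def` executed.
    return array


def lazy_fn(array=None):
    # Deferred variant: the array is only created when the function is called.
    if array is None:
        array = make_array(3)
    return array


# Two creations have happened although nothing has been called yet.
calls_at_definition = list(calls)

lazy_fn()
calls_after_lazy_call = list(calls)
```

With real JAX arrays, the two definition-time creations would run before any `jax.config.update` call placed further down the module, which matches the behavior reported in the MREs.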

MRE with ClassVar:

from typing import ClassVar

import jax
import jax.numpy as jnp

class MyClass:
    # This attribute causes the persistent cache to remain uninitialized.
    default_array: ClassVar = jnp.zeros((3,))

# Configuration for JAX persistent cache
jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
    return x + 1

# Function invocation
x = jnp.zeros((2, 2))
f(x)

MRE with default arguments:

import jax
import jax.numpy as jnp

# The `array` default value causes the persistent cache to remain uninitialized.
def test_fn(array: jax.Array = jnp.zeros(3)):
    return array

# Configuration for JAX persistent cache
jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
    return x + 1

# Function invocation
x = jnp.zeros((2, 2))
f(x)

This was quite hard for me to spot, so I would expect a clearer error message when the cache cannot be initialized.

Thank you for your help!

FYI @traversaro @xela-95 @CarlottaSartore @paLeziart

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.0
python: 3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4060 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='iitbmp014lw015u', release='6.11.0-13-generic', version='#14-Ubuntu SMP PREEMPT_DYNAMIC Sat Nov 30 23:51:51 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed Jan  8 11:54:03 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 ...    Off |   00000000:01:00.0 Off |                  N/A |
| N/A   44C    P3             10W /   40W |    6094MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3416      G   /usr/bin/gnome-shell                            2MiB |
|    0   N/A  N/A   1535899      C   ...iforge3/envs/jaxpypi/bin/python3.13       5980MiB |
|    0   N/A  N/A   1552631      C   ...iforge3/envs/jaxpypi/bin/python3.13         94MiB |
+-----------------------------------------------------------------------------------------+
pip list

Package                  Version     Editable project location
------------------------ ----------- --------------------------------
absl-py                  2.1.0
asttokens                3.0.0
chex                     0.1.87
colorama                 0.4.6
coloredlogs              15.0.1
decorator                5.1.1
docstring_parser         0.16
etils                    1.11.0
exceptiongroup           1.2.2
executing                2.1.0
gitdb                    4.0.11
GitPython                3.1.43
humanfriendly            10.0
importlib_resources      6.5.2
iniconfig                2.0.0
ipython                  8.30.0
jax                      0.4.38
jax-cuda12-pjrt          0.4.34
jax-cuda12-plugin        0.4.34
jax_dataclasses          1.6.1
jaxlib                   0.4.38
jaxlie                   1.4.2
jaxsim                   0.5.1.dev69 /home/fferretti-iit.local/jaxsim
jedi                     0.19.2
markdown-it-py           3.0.0
mashumaro                3.15
matplotlib-inline        0.1.7
mdurl                    0.1.2
ml_dtypes                0.5.0
numpy                    2.2.0
nvidia-cublas-cu12       12.6.4.1
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvcc-cu12    12.6.85
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.6.0.74
nvidia-cufft-cu12        11.3.0.4
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-nccl-cu12         2.23.4
nvidia-nvjitlink-cu12    12.6.85
opt_einsum               3.4.0
optax                    0.2.4
packaging                24.2
parso                    0.8.4
pexpect                  4.9.0
pip                      24.3.1
pluggy                   1.5.0
pptree                   3.1
prompt_toolkit           3.0.48
ptyprocess               0.7.0
pure_eval                0.2.3
py-cpuinfo               9.0.0
Pygments                 2.18.0
pytest                   8.3.4
pytest-benchmark         5.1.0
qpax                     0.0.9
resolve-robotics-uri-py  0.3.0
rich                     13.9.4
robot_descriptions       1.13.0
rod                      0.3.4
scipy                    1.14.1
setuptools               75.6.0
shtab                    1.7.1
smmap                    5.0.1
stack-data               0.6.3
tomli                    2.2.1
toolz                    1.0.0
tqdm                     4.67.1
traitlets                5.14.3
trimesh                  4.5.3
typeguard                4.4.1
typing_extensions        4.12.2
tyro                     0.9.2
wcwidth                  0.2.13
xmltodict                0.14.2

flferretti added the bug (Something isn't working) label Jan 8, 2025
Collaborator

jakevdp commented Jan 8, 2025

I think this is working as expected: the compilation cache state must be enabled before backend initialization. Many of the JAX configuration options have similar requirements.

I'm assigning @skye who knows more about this code path and may be able to confirm.
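If that reading is right, reordering the MRE so that every cache-related `jax.config.update` call runs before the first array is created should be enough. Below is a minimal sketch of that ordering, not a confirmed fix. It uses a temporary directory instead of `/tmp/jax_cache` (an assumption made for portability) and leaves out the logging and XLA-cache options from the original report for brevity.

```python
import tempfile

import jax

# Assumption: a fresh temporary directory stands in for /tmp/jax_cache.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

# Configure the persistent cache first, before this script creates any
# JAX array or runs any JAX computation.
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# Only now import jax.numpy and allocate arrays.
import jax.numpy as jnp


@jax.jit
def f(x):
    return x + 1


x = jnp.zeros((2, 2))
result = f(x)
```

In a larger project, the same ordering means the configuration has to run before importing any module that builds JAX arrays at import time, for example through class attributes or default arguments.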
