The JAX persistent compilation cache fails to initialize if a JAX array is created before the cache is configured with jax.config.update. Removing the array-creation line lets the cache initialize correctly.
MRE with array allocation:
import jax
import jax.numpy as jnp

# This line causes the persistent cache to remain uninitialized.
a = jnp.zeros(4)

# Configuration for JAX persistent cache
jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
    return x + 1

# Function invocation
x = jnp.zeros((2, 2))
f(x)
The same issue occurs when a JAX NumPy array is used as a ClassVar default value or as a default argument of a function. (See also ami-iit/jaxsim#322 and ami-iit/jaxsim#329)
MRE with ClassVar:
from typing import ClassVar

import jax
import jax.numpy as jnp

class MyClass:
    # This attribute causes the persistent cache to remain uninitialized.
    default_array: ClassVar = jnp.zeros((3,))

# Configuration for JAX persistent cache
jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
    return x + 1

# Function invocation
x = jnp.zeros((2, 2))
f(x)
MRE with default arguments:
import jax
import jax.numpy as jnp

# The `array` default value causes the persistent cache to remain uninitialized.
def test_fn(array: jax.Array = jnp.zeros(3)):
    return array

# Configuration for JAX persistent cache
jax.config.update("jax_logging_level", "DEBUG")
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

@jax.jit
def f(x):
    return x + 1

# Function invocation
x = jnp.zeros((2, 2))
f(x)
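A possible workaround sketch for the ClassVar and default-argument cases, assuming the only goal is to avoid creating a jax.Array at import time: keep plain Python data (here a shape tuple) at class or definition level and build the array lazily, after the cache options have been set. The names `default_shape` and `default_array()` are hypothetical and not taken from jaxsim.

```python
from typing import ClassVar, Optional

import jax
import jax.numpy as jnp


class MyClass:
    # Hypothetical: store only the shape at class level, so defining the
    # class does not create a jax.Array (and does not touch the backend).
    default_shape: ClassVar[tuple[int, ...]] = (3,)

    @classmethod
    def default_array(cls) -> jax.Array:
        # The array is created on first request, i.e. after jax.config.update
        # has been called for the persistent cache options.
        return jnp.zeros(cls.default_shape)


def test_fn(array: Optional[jax.Array] = None) -> jax.Array:
    # Use None as a sentinel and create the default inside the body instead
    # of evaluating jnp.zeros(3) when the function is defined.
    if array is None:
        array = jnp.zeros(3)
    return array
```

The trade-off is that a new default array is built on every call instead of once per import, which is usually negligible for small defaults.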
This was quite hard to spot, so I would expect a clearer error message when the cache cannot be initialized for some reason.
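Until such a message exists, one user-side check is to look at the cache directory after the first jitted call has run: if it is still empty, the cache was most likely not initialized. This is a hypothetical helper (`check_persistent_cache` is not a JAX API), and it assumes a fresh cache directory, since leftover entries from an earlier run would hide the problem.

```python
import os
import tempfile
import warnings


def check_persistent_cache(cache_dir: str) -> bool:
    """Hypothetical helper: True if the cache directory contains entries.

    Meant to be called after at least one jax.jit-compiled function has run
    with the persistent cache options configured. Warns if nothing was written.
    """
    has_entries = os.path.isdir(cache_dir) and len(os.listdir(cache_dir)) > 0
    if not has_entries:
        warnings.warn(
            f"No entries found in the JAX persistent compilation cache at {cache_dir!r}; "
            "the cache may not have been initialized (for example, an array was "
            "created before the cache options were set).",
            RuntimeWarning,
        )
    return has_entries


# Usage with an empty temporary directory: no entries, so a warning is emitted.
with tempfile.TemporaryDirectory() as empty_dir:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ok = check_persistent_cache(empty_dir)
```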
I think this is working as expected: the compilation cache state must be enabled before backend initialization. Many of the JAX configuration options have similar requirements.
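Under that explanation, a minimal sketch of the working order is to call jax.config.update for the cache options before any array is created, since creating the first array initializes the backend. The option names come from the MREs above; the temporary directory replaces /tmp/jax_cache only so the example is self-contained.

```python
import tempfile

import jax
import jax.numpy as jnp

# Assumption: a fresh temporary directory stands in for /tmp/jax_cache.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

# 1. Configure the persistent cache first, before any jax.Array exists.
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

# 2. Only now create arrays; this is what triggers backend initialization.
a = jnp.zeros(4)

@jax.jit
def f(x):
    return x + 1

y = f(jnp.zeros((2, 2)))
```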
I'm assigning @skye who knows more about this code path and may be able to confirm.
Thank you for your help!
FYI @traversaro @xela-95 @CarlottaSartore @paLeziart
System info (python version, jaxlib version, accelerator, etc.)
pip list