
Extend function signature for InitFn #627

Merged
junpenglao merged 3 commits into blackjax-devs:main from base_api on Dec 12, 2023

Conversation

junpenglao
Member

Closes #619.

This PR introduces rng_key as an optional input to the InitFn protocol.

For samplers like dynamic_hmc and ghmc, the init_fn of the top-level API does not follow the old pattern, as it needs an rng_key to generate part of the state. While it is possible to set a default rng_key in the class __new__, we actually want to keep rng_key as an input to init_fn, because we usually want to vmap init so that it takes a vector of PRNG keys to initialize parallel chains (see e.g. in meads_adaptation):

```python
batch_init = jax.vmap(lambda p, r: mcmc.ghmc.init(p, r, logdensity_fn))
rng_keys = jax.random.split(key_init, num_chains)
init_states = batch_init(positions, rng_keys)
```

Thus, in this PR we introduce rng_key to InitFn, with some minor refactoring so that the other top-level APIs follow the same contract and are easier to plug into the utility sampling function run_inference_algorithm.
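
For reference, a minimal sketch of what the extended protocol could look like; Position, PRNGKey, and State are illustrative stand-ins here, not Blackjax's actual type aliases:

```python
from typing import Any, Optional, Protocol

# Illustrative stand-ins; Blackjax defines its own type aliases for these.
Position = Any  # a JAX pytree holding the chain position
PRNGKey = Any   # a jax.random key
State = Any     # the algorithm's state


class InitFn(Protocol):
    """Initialize an algorithm's state from a position.

    The rng_key is optional: samplers such as dynamic_hmc and ghmc need it
    to generate part of the initial state, while other samplers ignore it.
    """

    def __call__(self, position: Position, rng_key: Optional[PRNGKey] = None) -> State:
        ...
```

With rng_key as an explicit argument, the batched initialization above works simply by vmapping init over a stack of keys.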

@junpenglao
Member Author

@AdrienCorenflos for changes in blackjax/mcmc/marginal_latent_gaussian.py
@reubenharry for changes in blackjax/mcmc/mclmc.py


codecov bot commented Dec 11, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (029b981) 99.00% compared to head (559fc6b) 99.05%.

Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main     #627      +/-   ##
==========================================
+ Coverage   99.00%   99.05%   +0.04%
==========================================
  Files          58       58
  Lines        2615     2632      +17
==========================================
+ Hits         2589     2607      +18
+ Misses         26       25       -1
```

☔ View full report in Codecov by Sentry.

@AdrienCorenflos
Contributor

Why was the delta parameter moved out of the step function? This seems like an unrelated change.

@junpenglao
Member Author

> Why was the delta parameter moved out of the step function? This seems like an unrelated change.

I am not familiar with marginal_latent_gaussian; from the tests, it does not seem that it needs to be in step. But if the algorithm is intended to take a different scale even in the top-level user API, I will revert those changes.

@AdrienCorenflos
Contributor

AdrienCorenflos commented Dec 11, 2023 via email

@junpenglao
Member Author

junpenglao commented Dec 11, 2023

Could the user do this instead (which is, in general, how Blackjax sees the top-level API):

```python
from blackjax.mcmc import marginal_latent_gaussian as mgrad

init, kernel = mgrad.init_and_kernel(...)

# do the tuning for delta
...

# Actual sampling; delta is now fixed.
algorithm = mgrad.mgrad_gaussian(logdensity_fn, cov, delta)
state = algorithm.init(...)
for ...:
    state, info = algorithm.step(rng_key, state)
```

@AdrienCorenflos
Contributor

> Could the user do this instead (which is, in general, how Blackjax sees the top-level API):
>
> ```python
> from blackjax.mcmc import marginal_latent_gaussian as mgrad
>
> init, kernel = mgrad.init_and_kernel(...)
>
> # do the tuning for delta
> ...
>
> # Actual sampling; delta is now fixed.
> algorithm = mgrad.mgrad_gaussian(logdensity_fn, cov, delta)
> state = algorithm.init(...)
> for ...:
>     state, info = algorithm.step(rng_key, state)
> ```

I suppose, but this would be a waste of computational resources. init_and_kernel does some handling of the covariance matrix (an SVD), which is easily the most expensive part. Why would you want to do it again post-calibration?

@junpenglao
Member Author

To maintain the same API for what we consider "top-level". This makes it easier to compare different samplers. I suppose your concern here is that the second call to init_and_kernel is wasteful? It sounds like it would be better to refactor the class altogether, so that the result of the SVD is an input to the top-level API mgrad_gaussian.
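
A hypothetical sketch of that refactor (precompute_svd and mgrad_gaussian_from_svd are illustrative names, not Blackjax's actual API): the SVD is computed once, and the top-level constructor consumes the precomputed factors, so rebuilding the sampler after tuning delta repeats no expensive work.

```python
import jax.numpy as jnp


def precompute_svd(cov):
    # The expensive step: factor the covariance matrix exactly once.
    U, Gamma, _ = jnp.linalg.svd(cov, hermitian=True)
    return U, Gamma


def mgrad_gaussian_from_svd(logdensity_fn, U, Gamma, delta):
    # Build the sampler from the precomputed factors; changing delta
    # only rebuilds this cheap closure, never the SVD.
    def step(rng_key, state):
        ...  # kernel body using U, Gamma, and the now-fixed delta
    return step


U, Gamma = precompute_svd(jnp.eye(10))
# Tune delta with one instance, then rebuild cheaply with the tuned value:
step = mgrad_gaussian_from_svd(lambda x: -0.5 * jnp.sum(x**2), U, Gamma, delta=0.5)
```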

@junpenglao junpenglao merged commit 4058971 into blackjax-devs:main Dec 12, 2023
7 checks passed
@junpenglao junpenglao deleted the base_api branch December 12, 2023 10:06
junpenglao added a commit that referenced this pull request Mar 12, 2024
* Extend function signature for InitFn

* Fix formatting
Successfully merging this pull request may close these issues:

Improve SamplingAlgorithm design for init_fn and step_fn (#619)