-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MCLMC sampler #586
Add MCLMC sampler #586
Conversation
After a bit of a mammoth effort, the PR is ready for review. One possible point of objection is that there are still some poorly named variables, now confined to certain functions in the adaptation code. These I have not renamed because I don't fully understand what all of them are doing, so some reading/talking to Jakob is in order. But there's now a test of both the tuning and the kernel, that uses pytrees, showing that the code is pytree generic. Moreover, the implementation reproduces the original MCLMC implementation perfectly (with appropriate choices of random seeds), so other than a few details (I removed the preconditioning for now, for simplicity), it's hopefully looking good. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job! I indeed have some comments re naming.
blackjax/mcmc/integrators.py
Outdated
@@ -1,3 +1,4 @@ | |||
# @title `integrators.py` from https://github.com/blackjax-devs/blackjax/pull/589 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reminder to remove this line.
|
||
gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) | ||
|
||
def predictor(state_old, params, adaptive_state, rng_key): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reminder to change the naming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work!! Thank you
* initial draft of mclmc * refactor * wip * wip * wip * wip * wip * fix pre-commit * remove dim from class * add docstrings * add mclmc to init * move minimal_norm to integrators * move update pos and momentum * remove params * Infer the shape from inverse_mass_matrix outside the function step * use tree_map * integration now aligned with mclmc repo * dE and logdensity align too (fixed sign error) * make L and step size arguments to kernel * rough draft of tuning: works * remove inv mass matrix * almost correct * almost correct * move tuning to adaptation * tuning works in this commit * clean up 1 * remove sigma from tuning * wip * fix linting * rename T and V * uniformity wip * make uniform implementation of integrators * make uniform implementation of integrators * fix minimal norm integrator * add warning to tune3 * Refactor integrators.py to make it more general. Also add momentum update based on Esh dynamics Co-authored-by: Reuben Cohn-Gordon <[email protected]> * temp: explore * Refactor to use integrator generation functions * Additional refactoring Also add test for esh momentum update. Co-authored-by: Reuben Cohn-Gordon <[email protected]> * Minor clean up. * Use standard JAX ops * new integrator * add references * flake * temporarily add 'explore' * temporarily add 'explore' * Adding a test for energy preservation. Co-authored-by: Reuben Cohn-Gordon <[email protected]> * fix formatting * wip: tests * use pytrees for partially_refresh_momentum, and add test * update docstring * remove 'explore' * fix pre-commit * adding randomized MCHMC * wip checkpoint on tuning * align blackjax and mclmc repos, for tuning * use effective_sample_size * patial rename * rename * clean up tuning * clean up tuning * RANDOMIZE KEYS * ADD TEST * ADD TEST * MERGE MAIN * INCREASE CODE COVERAGE * REMOVE REDUNDANT LINE * ADD NAME 'mclmc' * SPLIT KEYS AND FIX DOCSTRING * FIX MINOR ERRORS * FIX MINOR ERRORS * RANDOMIZE KEYS (reversion) * PRECOMMIT CLEAN UP * ADD KWARGS FOR DEFAULT HYPERPARAMS * UPDATE ESS * NAME CHANGES * NAME CHANGES * MINOR FIXES --------- Co-authored-by: Junpeng Lao <[email protected]> Co-authored-by: jakob.robnik <[email protected]>
Todos:
main
commit;pre-commit
is installed and configured on your machine, and you ran it before opening the PR;explore.py
scriptAs per #530, this PR adds MCLMC (Microcanonical Langevin Monte Carlo) as a sampling method to BlackJax. It is an adaptation of the original repo: https://github.com/JakobRobnik/MicroCanonicalHMC/tree/master.
In particular, this is an implementation of the single chain partially momentum updating algorithm (as opposed to full updates every n steps). It also will implement (once this PR is ready) the MCLMC tuning algorithm.