Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MCLMC sampler #586

Merged
merged 91 commits into from
Dec 5, 2023
Merged

Add MCLMC sampler #586

merged 91 commits into from
Dec 5, 2023

Conversation

reubenharry
Copy link
Contributor

@reubenharry reubenharry commented Nov 11, 2023

Todos:

  • add tuning: without the tuning algorithm, MCLMC is next to useless
  • If I add a new sampler, there is an issue discussing it already;
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date; (add references to bibliography, check what you need to do for API docs)
  • If I add a new sampler* I added/updated related examples
  • Remove the explore.py script
  • Test that the dE and logdensity reported are correct

As per #530, this PR adds MCLMC (Microcanonical Langevin Monte Carlo) as a sampling method to BlackJax. It is an adaptation of the original repo: https://github.com/JakobRobnik/MicroCanonicalHMC/tree/master.

In particular, this is an implementation of the single chain partially momentum updating algorithm (as opposed to full updates every n steps). It also will implement (once this PR is ready) the MCLMC tuning algorithm.

blackjax/mcmc/mclmc.py Outdated Show resolved Hide resolved
@reubenharry reubenharry marked this pull request as ready for review December 2, 2023 01:31
@reubenharry
Copy link
Contributor Author

After a bit of a mammoth effort, the PR is ready for review.

One possible point of objection is that there are still some poorly named variables, now confined to certain functions in the adaptation code. These I have not renamed because I don't fully understand what all of them are doing, so some reading/talking to Jakob is in order.

But there's now a test of both the tuning and the kernel, that uses pytrees, showing that the code is pytree generic. Moreover, the implementation reproduces the original MCLMC implementation perfectly (with appropriate choices of random seeds), so other than a few details (I removed the preconditioning for now, for simplicity), it's hopefully looking good.

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job! I indeed have some comments re naming.

@@ -1,3 +1,4 @@
# @title `integrators.py` from https://github.com/blackjax-devs/blackjax/pull/589
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to remove this line.

blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/mcmc/mclmc.py Outdated Show resolved Hide resolved
blackjax/util.py Outdated Show resolved Hide resolved
blackjax/mcmc/mclmc.py Outdated Show resolved Hide resolved
blackjax/__init__.py Show resolved Hide resolved
blackjax/__init__.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved
blackjax/adaptation/mclmc_adaptation.py Outdated Show resolved Hide resolved

gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

def predictor(state_old, params, adaptive_state, rng_key):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to change the naming.

blackjax/adaptation/mclmc_adaptation.py Show resolved Hide resolved
blackjax/mcmc/mclmc.py Outdated Show resolved Hide resolved
Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!! Thank you

@junpenglao junpenglao merged commit 039b277 into blackjax-devs:main Dec 5, 2023
7 checks passed
junpenglao added a commit that referenced this pull request Mar 12, 2024
* initial draft of mclmc

* refactor

* wip

* wip

* wip

* wip

* wip

* fix pre-commit

* remove dim from class

* add docstrings

* add mclmc to init

* move minimal_norm to integrators

* move update pos and momentum

* remove params

* Infer the shape from inverse_mass_matrix outside the function step

* use tree_map

* integration now aligned with mclmc repo

* dE and logdensity align too (fixed sign error)

* make L and step size arguments to kernel

* rough draft of tuning: works

* remove inv mass matrix

* almost correct

* almost correct

* move tuning to adaptation

* tuning works in this commit

* clean up 1

* remove sigma from tuning

* wip

* fix linting

* rename T and V

* uniformity wip

* make uniform implementation of integrators

* make uniform implementation of integrators

* fix minimal norm integrator

* add warning to tune3

* Refactor integrators.py to make it more general.
Also add momentum update based on Esh dynamics

Co-authored-by: Reuben Cohn-Gordon <[email protected]>

* temp: explore

* Refactor to use integrator generation functions

* Additional refactoring

Also add test for esh momentum update.

Co-authored-by: Reuben Cohn-Gordon <[email protected]>

* Minor clean up.

* Use standard JAX ops

* new integrator

* add references

* flake

* temporarily add 'explore'

* temporarily add 'explore'

* Adding a test for energy preservation.

Co-authored-by: Reuben Cohn-Gordon <[email protected]>

* fix formatting

* wip: tests

* use pytrees for partially_refresh_momentum, and add test

* update docstring

* remove 'explore'

* fix pre-commit

* adding randomized MCHMC

* wip checkpoint on tuning

* align blackjax and mclmc repos, for tuning

* use effective_sample_size

* patial rename

* rename

* clean up tuning

* clean up tuning

* RANDOMIZE KEYS

* ADD TEST

* ADD TEST

* MERGE MAIN

* INCREASE CODE COVERAGE

* REMOVE REDUNDANT LINE

* ADD NAME 'mclmc'

* SPLIT KEYS AND FIX DOCSTRING

* FIX MINOR ERRORS

* FIX MINOR ERRORS

* RANDOMIZE KEYS (reversion)

* PRECOMMIT CLEAN UP

* ADD KWARGS FOR DEFAULT HYPERPARAMS

* UPDATE ESS

* NAME CHANGES

* NAME CHANGES

* MINOR FIXES

---------

Co-authored-by: Junpeng Lao <[email protected]>
Co-authored-by: jakob.robnik <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants