diff --git a/blackjax/util.py b/blackjax/util.py index 608183cc9..2b52ca5ae 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -9,7 +9,7 @@ from jax.random import normal, split from jax.tree_util import tree_leaves -from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm +from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import progress_bar_scan from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey