diff --git a/blackjax/base.py b/blackjax/base.py index 8ea24cd70..1ddfff9cc 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Optional +from typing import Callable, NamedTuple, Optional, Any from typing_extensions import Protocol @@ -64,7 +64,7 @@ class UpdateFn(Protocol): """ - def __call__(self, rng_key: PRNGKey, state: State) -> tuple[State, Info]: + def __call__(self, rng_key: PRNGKey, state: State, *args: Any) -> tuple[State, Info]: """Update the current state using the sampling algorithm. Parameters