From 8c5bb34a2b7583c88d942ef0041462354132cbf8 Mon Sep 17 00:00:00 2001 From: Jannis Bolik Date: Wed, 24 Jul 2024 21:43:05 +0200 Subject: [PATCH 1/2] Add *args to UpdateFn.__call__ --- blackjax/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/base.py b/blackjax/base.py index 8ea24cd70..e94ea2697 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -64,7 +64,7 @@ class UpdateFn(Protocol): """ - def __call__(self, rng_key: PRNGKey, state: State) -> tuple[State, Info]: + def __call__(self, rng_key: PRNGKey, state: State, *args) -> tuple[State, Info]: """Update the current state using the sampling algorithm. Parameters From 3149cf0fbf15e7c6f2f5a54e3da6560aaa5078a5 Mon Sep 17 00:00:00 2001 From: Jannis Bolik Date: Wed, 24 Jul 2024 22:15:33 +0000 Subject: [PATCH 2/2] Try adding Any --- blackjax/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/base.py b/blackjax/base.py index e94ea2697..1ddfff9cc 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Optional +from typing import Callable, NamedTuple, Optional, Any from typing_extensions import Protocol @@ -64,7 +64,7 @@ class UpdateFn(Protocol): """ - def __call__(self, rng_key: PRNGKey, state: State, *args) -> tuple[State, Info]: + def __call__(self, rng_key: PRNGKey, state: State, *args: Any) -> tuple[State, Info]: """Update the current state using the sampling algorithm. Parameters