diff --git a/src/cellmap_models/pytorch/cosem/load_model.py b/src/cellmap_models/pytorch/cosem/load_model.py index 6aff03b..1294f95 100755 --- a/src/cellmap_models/pytorch/cosem/load_model.py +++ b/src/cellmap_models/pytorch/cosem/load_model.py @@ -78,12 +78,27 @@ def load_model(checkpoint_name: str) -> torch.nn.Module: new_checkpoint = deepcopy(checkpoint) for key in checkpoint["model"].keys(): if "chain" in key: + # Verify that the chain tensors are redundant and can be removed + if key.replace("chain.0", "architecture") in checkpoint["model"].keys(): + assert ( + checkpoint["model"][key] + == checkpoint["model"][key.replace("chain.0", "architecture")] + ).all(), f"Chain {key} does not match architecture {key.replace('chain.0', 'architecture')}" + elif ( + key.replace("chain.1", "prediction_head") in checkpoint["model"].keys() + ): + assert ( + checkpoint["model"][key] + == checkpoint["model"][key.replace("chain.1", "prediction_head")] + ).all(), f"Chain {key} does not match prediction_head {key.replace('chain.1', 'prediction_head')}" + else: + raise ValueError(f"No match found for key {key} in checkpoint") new_checkpoint["model"].pop(key) continue new_key = key.replace("architecture.", "") new_key = new_key.replace("unet.", "backbone.") new_checkpoint["model"][new_key] = new_checkpoint["model"].pop(key) - model.load_state_dict(new_checkpoint["model"]) + model.load_state_dict(new_checkpoint["model"], strict=True) model.eval() return model