Skip to content

Commit

Permalink
fix: Validate chain tensor redundancy in model loading and enforce st…
Browse files Browse the repository at this point in the history
…rict state dict loading
  • Loading branch information
rhoadesScholar committed Jan 15, 2025
1 parent b00a5ea commit 8c6e831
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/cellmap_models/pytorch/cosem/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,27 @@ def load_model(checkpoint_name: str) -> torch.nn.Module:
new_checkpoint = deepcopy(checkpoint)
for key in checkpoint["model"].keys():
if "chain" in key:
# Verify that the chain tensors are redundant and can be removed
if key.replace("chain.0", "architecture") in checkpoint["model"].keys():
assert (
checkpoint["model"][key]
== checkpoint["model"][key.replace("chain.0", "architecture")]
).all(), f"Chain {key} does not match architecture {key.replace('chain.0', 'architecture')}"
elif (
key.replace("chain.1", "prediction_head") in checkpoint["model"].keys()
):
assert (
checkpoint["model"][key]
== checkpoint["model"][key.replace("chain.1", "prediction_head")]
).all(), f"Chain {key} does not match prediction_head {key.replace('chain.1', 'prediction_head')}"
else:
raise ValueError(f"No match found for key {key} in checkpoint")
new_checkpoint["model"].pop(key)
continue
new_key = key.replace("architecture.", "")
new_key = new_key.replace("unet.", "backbone.")
new_checkpoint["model"][new_key] = new_checkpoint["model"].pop(key)
model.load_state_dict(new_checkpoint["model"])
model.load_state_dict(new_checkpoint["model"], strict=True)
model.eval()

return model
Expand Down

0 comments on commit 8c6e831

Please sign in to comment.