
Failed to load the model in two_tower_retrieval.py script #2679

Open
wei-m-teh opened this issue Jan 11, 2025 · 0 comments


wei-m-teh commented Jan 11, 2025

I was able to train the model using the following command:

torchx run -s local_cwd dist.ddp -j 1x1 --gpu 1 --script two_tower_train.py -- --save_dir model

However, when running the retrieval script:

CUDA_VISIBLE_DEVICES=0 python two_tower_retrieval.py --load_dir model

I ran into the following error:

File "/home/sagemaker-user/.conda/envs/torchrec/lib/python3.11/site-packages/torch/serialization.py", line 1359, in load
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with weights_only=True please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL torch.distributed._shard.sharded_tensor.api.ShardedTensor was not an allowed global by default. Please use torch.serialization.add_safe_globals([ShardedTensor]) to allowlist this global if you trust this class/function.
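For context, `weights_only=True` makes `torch.load` behave like an allowlisting unpickler: any class or function referenced by the checkpoint that is not on the allowlist is rejected, which is why `ShardedTensor` triggered the error above. Below is a rough standard-library analogue that follows the "Restricting Globals" pattern from the Python `pickle` documentation. The names `AllowlistUnpickler` and `restricted_loads` are hypothetical, and this is a conceptual sketch, not PyTorch's implementation.

```python
import io
import pickle
from collections import OrderedDict


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that refuses every global not explicitly allowlisted.

    A stdlib analogue of the idea behind weights_only=True, not PyTorch's code.
    """

    def __init__(self, file, allowed=()):
        super().__init__(file)
        # "module.qualname" -> object, loosely analogous to add_safe_globals([...]).
        self.allowed = {f"{obj.__module__}.{obj.__qualname__}": obj for obj in allowed}

    def find_class(self, module, name):
        # Called for every class/function reference found in the pickle stream.
        key = f"{module}.{name}"
        if key in self.allowed:
            return self.allowed[key]
        raise pickle.UnpicklingError(f"Unsupported global: {key} is not allowlisted")


def restricted_loads(data, allowed=()):
    """Unpickle `data`, permitting only the classes/functions listed in `allowed`."""
    return AllowlistUnpickler(io.BytesIO(data), allowed).load()


if __name__ == "__main__":
    payload = pickle.dumps(OrderedDict(lr=0.1))
    try:
        restricted_loads(payload)
    except pickle.UnpicklingError as err:
        print("rejected:", err)
    print("allowed:", restricted_loads(payload, allowed=[OrderedDict]))
```

Allowlisting a class here plays roughly the role of `torch.serialization.add_safe_globals` or `safe_globals`: it only permits the class to be looked up, and does not by itself guarantee that the object's saved state can be rebuilt afterwards.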

Following the suggestion in the stack trace, I allowlisted the relevant classes as safe globals:

with torch.serialization.safe_globals([ShardedTensor, Shard, ShardMetadata, _remote_device]):
        two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True)

This time, I ran into this error:

two_tower_retrieval/0 [0]: File "/home/sagemaker-user/.conda/envs/torchrec/lib/python3.11/site-packages/torch/_weights_only_unpickler.py", line 292, in load
two_tower_retrieval/0 [0]: inst.__dict__.update(state)
two_tower_retrieval/0 [0]: ^^^^^^^^^^^^^
two_tower_retrieval/0 [0]:AttributeError: 'ShardMetadata' object has no attribute '__dict__'. Did you mean: '__dir__'?

Any idea why this is happening?
I ran both scripts on a g5.2xlarge instance with a single GPU.
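The `AttributeError` in the second traceback is consistent with a class that uses `__slots__`, assuming `ShardMetadata` declares `__slots__` (recent PyTorch versions appear to define it as a dataclass with explicit slots). Instances of such a class have no `__dict__`, and when it defines no `__setstate__`, a protocol-2 pickle stores its state as a `(dict_state, slot_state)` pair. A BUILD step that writes that state straight into `inst.__dict__` therefore fails, whereas CPython's own pickle module splits the pair and restores slots with `setattr`. The sketch below reproduces the difference with the standard library only; `SlottedMetadata`, `pickled_state`, `naive_build`, and `slot_aware_build` are hypothetical names, and none of this is PyTorch's code.

```python
class SlottedMetadata:
    """Hypothetical stand-in for a __slots__ class such as ShardMetadata is assumed to be."""

    __slots__ = ("offsets", "sizes")

    def __init__(self, offsets, sizes):
        self.offsets = offsets
        self.sizes = sizes


def pickled_state(obj):
    """Return the state a protocol-2 pickle would hand to the BUILD opcode."""
    # object.__reduce_ex__(2) returns (func, args, state, listitems, dictitems).
    return obj.__reduce_ex__(2)[2]


def naive_build(inst, state):
    """Mimics the failing line: assumes every instance has a __dict__."""
    inst.__dict__.update(state)


def slot_aware_build(inst, state):
    """Modelled on BUILD in CPython's pure-Python pickle: split (dict_state, slot_state)."""
    slot_state = None
    if isinstance(state, tuple) and len(state) == 2:
        state, slot_state = state
    if state:
        inst.__dict__.update(state)
    if slot_state:
        for name, value in slot_state.items():
            setattr(inst, name, value)


if __name__ == "__main__":
    state = pickled_state(SlottedMetadata([0, 0], [4, 8]))
    print(state)  # a (None, {...}) pair: no __dict__ state, only slot state

    try:
        naive_build(SlottedMetadata.__new__(SlottedMetadata), state)
    except AttributeError as err:
        print("naive BUILD failed:", err)

    restored = SlottedMetadata.__new__(SlottedMetadata)
    slot_aware_build(restored, state)
    print(restored.offsets, restored.sizes)
```

If this is the cause, allowlisting more classes will not help on its own: the unpickler itself has to handle slot state for `ShardMetadata`, or the checkpoint has to be loaded by a path that does not go through the weights-only unpickler.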
