Loss is negative when fine-tuning pretrained model on MegaDepth #90

Open
skill-diver opened this issue Nov 30, 2024 · 1 comment

skill-diver commented Nov 30, 2024

Hi, author. First of all, thank you for the great work.
I ran into an issue when fine-tuning your pretrained model on MegaDepth: the loss is negative.

image

Here is the training script I used:

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.run --nproc_per_node=4 train.py \
    --train_dataset "68_400 @ MegaDepth(ROOT='megadepth', split='train', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], aug_crop='auto', aug_monocular=0.005, transform=ColorJitter, n_corres=8192, nneg=0.5)" \
    --test_dataset "1_000 @ MegaDepth(ROOT='megadepth', split='val', resolution=(512,336), seed=777, n_corres=1024)" \
    --model "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True, desc_conf_mode=('exp', 0, inf))" \
    --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion "Regr3D(L21, norm_mode='?avg_dis', gt_scale=True, sky_loss_value=0) + -1.*MatchingLoss(APLoss(nq='torch', fp=torch.float16), negatives_padding=12288)" \
    --pretrained "/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" \
    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 8 --epochs 50 --batch_size 2 --accum_iter 2 \
    --save_freq 1 --keep_freq 5 --eval_freq 1 --print_freq=10 --disable_cudnn_benchmark \
    --output_dir "checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"

Is the negative loss caused by my training setup, or is it normal?

Contributor

yocabon commented Dec 2, 2024

1: Yes, a negative confidence loss is normal.
2: Note that you are not training with the matching loss; it only appears in the validation criterion.
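
For readers wondering why point 1 holds, here is a minimal sketch. It assumes the confidence-weighted objective has the form conf * regr_loss - alpha * log(conf), as described in the DUSt3R paper, and that conf_mode=('exp', 1, inf) keeps the confidence at or above 1. The function name and the sample numbers are hypothetical and are not taken from the repository's code. Under these assumptions the log term is never positive, so once the regression error is small and the confidence is high, the total drops below zero.

```python
import math

def conf_weighted_loss(regr_loss, conf, alpha=0.2):
    """Hypothetical per-pixel confidence loss: conf * regr_loss - alpha * log(conf).

    Assumes conf >= 1 (as implied by conf_mode=('exp', 1, inf)),
    so the -alpha * log(conf) term is always <= 0.
    """
    return conf * regr_loss - alpha * math.log(conf)

# Large regression error, confidence at its floor of 1: the log term vanishes.
early = conf_weighted_loss(regr_loss=0.5, conf=1.0)

# Small regression error, high confidence: the log term dominates.
late = conf_weighted_loss(regr_loss=0.01, conf=10.0)

print(f"early: {early:.4f}, late: {late:.4f}")
```

A pretrained checkpoint already predicts accurate points with high confidence, so seeing negative values from the first fine-tuning iterations would be consistent with this sketch.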
