Skip to content

Commit

Permalink
Merge pull request #8 from kuanweih/dev
Browse files Browse the repository at this point in the history
bug fixed: use calc_pred for different model types
  • Loading branch information
kuanweih authored May 10, 2022
2 parents b567ec3 + e81f746 commit f41c676
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions model_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

Expand Down Expand Up @@ -136,6 +137,34 @@ def save_model(CONFIG, model, epoch, test_loss):
print(f"\nSave model to {model_save_path}\n")


def calc_pred(model, data):
""" Calculate prediction of input data using model.
pred = model(data)
Different model objects have different pred shapes by default such as
ViT and ResNet.
Args:
model (model object): ViT or ResNet
data (torch.Tensor): batch data parsed in the model
Raises:
TypeError: type(model) has to be checked
Returns:
[torch.Tensor]: prediction
"""
if isinstance(model, ViTForImageClassification):
pred = model(data)[0]
elif isinstance(model, torchvision.models.resnet.ResNet):
pred = model(data)
else:
raise TypeError(f"{type(model)} not implemented for correct pred shape.")

return pred


def train_model(CONFIG):
""" Train models based on parameters in CONFIG.
Expand All @@ -155,7 +184,7 @@ def train_model(CONFIG):

# load model and cast to 'device'
model = load_model(CONFIG)
model.to(device)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=CONFIG['init_learning_rate'])

Expand All @@ -173,14 +202,15 @@ def train_model(CONFIG):
for data, target_dict in tqdm(train_loader, total=len(train_loader)):
data, target = prepare_data_and_target(data, target_dict, device)
optimizer.zero_grad()
output = model(data)[0]
loss = calc_loss(output, target, CONFIG, device)

cache_train.update_cache(output, target, loss)
pred = calc_pred(model, data)
loss = calc_loss(pred, target, CONFIG, device)
cache_train.update_cache(pred, target, loss)

loss.backward()
optimizer.step()


cache_train.calc_avg_across_batches()
cache_train.print_cache(epoch)

Expand All @@ -191,7 +221,7 @@ def train_model(CONFIG):

for data, target_dict in test_loader:
data, target = prepare_data_and_target(data, target_dict, device)
pred = model(data)[0]
pred = calc_pred(model, data)
loss = calc_loss(pred, target, CONFIG, device)

cache_test.update_cache(pred, target, loss)
Expand Down Expand Up @@ -220,8 +250,8 @@ def train_model(CONFIG):
'epoch': 4,
'batch_size': 30,
'load_new_model': True,
'new_model_name': "google/vit-base-patch16-224", # for 'load_new_model' = True
# 'new_model_name': "resnet18", # for 'load_new_model' = True
# 'new_model_name': "google/vit-base-patch16-224", # for 'load_new_model' = True
'new_model_name': "resnet18", # for 'load_new_model' = True
'resumed_model_path': Path(""), # for 'load_new_model' = False
'output_folder': Path("C:/Users/abcd2/Downloads/tmp_dev_outputs"), # needs to be non-existing
'dataset_folder': Path("C:/Users/abcd2/Datasets/2022_icml_lens_sim/dev_256"),
Expand Down

0 comments on commit f41c676

Please sign in to comment.