diff --git a/model_vit.py b/model_vit.py index b08d336..68a02e0 100644 --- a/model_vit.py +++ b/model_vit.py @@ -3,16 +3,13 @@ Modified by Kuan-Wei Huang """ - -from tqdm import tqdm -from pathlib import Path - import torch import torchvision import torch.nn as nn import torch.optim as optim -from torch.autograd import Variable +from tqdm import tqdm +from pathlib import Path from transformers import ViTForImageClassification from src.data_utils import ( @@ -38,7 +35,7 @@ def prepare_data_and_target(data, target_dict, device): """ Prepare data (X) and target (Y) for a given batch. Args: - data (np.array): image (X) + data (torch.Tensor): image (X) target_dict (dict): targets (Y) device (torch.device): cpu or gpu @@ -46,9 +43,9 @@ def prepare_data_and_target(data, target_dict, device): data (torch.Tensor): image (X) target (torch.Tensor): targets(Y) """ - data = Variable(data.float()).to(device) + data = data.float().to(device) for key, val in target_dict.items(): - target_dict[key] = Variable(val.float()).to(device) + target_dict[key] = val.float().to(device) target = torch.cat([val for _, val in target_dict.items()], dim=1) return data, target