Skip to content

Commit

Permalink
Merge pull request #9 from kuanweih/dev
Browse files Browse the repository at this point in the history
remove Variable
  • Loading branch information
kuanweih authored May 10, 2022
2 parents f41c676 + 0f01a1b commit 9a93293
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions model_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
Modified by Kuan-Wei Huang
"""


from tqdm import tqdm
from pathlib import Path

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

from torch.autograd import Variable
from tqdm import tqdm
from pathlib import Path
from transformers import ViTForImageClassification

from src.data_utils import (
Expand All @@ -38,17 +35,17 @@ def prepare_data_and_target(data, target_dict, device):
""" Prepare data (X) and target (Y) for a given batch.
Args:
data (np.array): image (X)
data (torch.Tensor): image (X)
target_dict (dict): targets (Y)
device (torch.device): cpu or gpu
Returns:
data (torch.Tensor): image (X)
target (torch.Tensor): targets(Y)
"""
data = Variable(data.float()).to(device)
data = data.float().to(device)
for key, val in target_dict.items():
target_dict[key] = Variable(val.float()).to(device)
target_dict[key] = val.float().to(device)
target = torch.cat([val for _, val in target_dict.items()], dim=1)
return data, target

Expand Down

0 comments on commit 9a93293

Please sign in to comment.