-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
110 lines (99 loc) · 2.81 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from tqdm import tqdm
from functools import partial
from torch.utils.data import DataLoader
from modelscope.msdatasets import MsDataset
from torchvision.transforms import Compose, Resize, RandomAffine, ToTensor, Normalize
# Cache of preprocessing pipelines keyed by image size. The original code
# rebuilt the Compose chain for every batch even though it only depends on
# img_size; building it once per size is free and avoids the per-batch cost.
_PIPELINES = {}


def transform(example_batch, data_column: str, label_column: str, img_size: int):
    """Preprocess one batch in place for model input.

    Every image in ``example_batch[data_column]`` is converted to RGB,
    resized to ``img_size`` x ``img_size``, turned into a tensor and
    normalized from [0, 1] to [-1, 1]. All columns other than
    ``data_column`` and ``label_column`` are dropped from the batch.

    Args:
        example_batch: mapping of column name -> list of values; the
            ``data_column`` entries must support ``.convert("RGB")``
            (PIL images — TODO confirm against the dataset schema).
        data_column: name of the image column to transform and keep.
        label_column: name of the label column to keep untouched.
        img_size: target square edge length in pixels.

    Returns:
        The same (mutated) ``example_batch`` containing only the two
        kept columns.
    """
    pipeline = _PIPELINES.get(img_size)
    if pipeline is None:
        pipeline = Compose(
            [
                Resize([img_size, img_size]),
                # RandomAffine(5),  # augmentation intentionally disabled
                ToTensor(),
                # Per-channel (x - 0.5) / 0.5: maps [0, 1] -> [-1, 1].
                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        _PIPELINES[img_size] = pipeline
    example_batch[data_column] = [
        pipeline(img.convert("RGB")) for img in example_batch[data_column]
    ]
    # Materialize the key list first — deleting while iterating a live
    # dict view would raise RuntimeError.
    for key in list(example_batch.keys()):
        if key not in (data_column, label_column):
            del example_batch[key]
    return example_batch
def prepare_data(dataset: str, subset: str, label_col: str, use_wce: bool):
    """Load a ModelScope dataset and optionally tally per-class counts.

    Args:
        dataset: ModelScope dataset identifier.
        subset: subset (configuration) name passed to ``MsDataset.load``.
        label_col: name of the label column.
        use_wce: when True, count training samples per class so a
            weighted cross-entropy loss can be derived from the counts.

    Returns:
        Tuple ``(ds, classes, num_samples)`` where ``classes`` is the
        label-name list from the test split and ``num_samples`` is empty
        unless ``use_wce`` is True, in which case it holds the training
        counts in class order.
    """
    print("Preparing & loading data...")
    ds = MsDataset.load(
        dataset,
        subset_name=subset,
        cache_dir="./__pycache__",
    )
    classes = ds["test"].features[label_col].names
    num_samples = []
    if use_wce:
        counts = dict.fromkeys(classes, 0)
        for sample in tqdm(ds["train"], desc="Statistics by category for WCE loss"):
            counts[classes[sample[label_col]]] += 1
        num_samples = [counts[name] for name in classes]
    return ds, classes, num_samples
def load_data(
    ds: MsDataset,
    data_col: str,
    label_col: str,
    input_size: int,
    has_bn: bool,
    shuffle=True,
    batch_size=4,
):
    """Wrap the train/validation/test splits in DataLoaders.

    Each split gets the shared ``transform`` preprocessing attached via
    ``with_transform`` and identical DataLoader settings.

    Args:
        ds: dataset exposing "train", "validation" and "test" splits.
        data_col: image column name forwarded to ``transform``.
        label_col: label column name forwarded to ``transform``.
        input_size: square image edge length forwarded to ``transform``.
        has_bn: True when the model contains BatchNorm layers; forces
            ``batch_size >= 2`` and drops the last incomplete batch,
            because a size-1 batch breaks BatchNorm statistics in
            training mode.
        shuffle: whether loaders shuffle. NOTE(review): this is applied
            to validation and test as well, matching the original
            behavior — confirm this is intended.
        batch_size: requested batch size (may be raised to 2, see
            ``has_bn``).

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``.
    """
    bs = batch_size
    if has_bn:
        print("The model has bn layer")
        if bs < 2:
            print("Switch batch_size >= 2")
            bs = 2

    # Build the per-batch preprocessing once and share it across splits
    # instead of constructing three identical partials.
    preprocess = partial(
        transform,
        data_column=data_col,
        label_column=label_col,
        img_size=input_size,
    )

    # os.cpu_count() can return None when the count is undeterminable;
    # the original `os.cpu_count() // 2` would then raise TypeError.
    # Fall back to 0 workers (load in the main process) in that case.
    num_workers = (os.cpu_count() or 0) // 2

    def _loader(split: str) -> DataLoader:
        # One DataLoader per split, all with identical settings.
        return DataLoader(
            ds[split].with_transform(preprocess),
            batch_size=bs,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=has_bn,
        )

    return _loader("train"), _loader("validation"), _loader("test")