-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprepare_customed_dataset.py
76 lines (43 loc) · 1.53 KB
/
prepare_customed_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#-*-coding:utf-8-*-
import os
import random
# Fraction of each class's directory entries that prepare_data() writes to
# train.txt; the remainder goes to val.txt. Reassigned from --ratio when the
# file is run as a script.
ratio=0.8
# Dataset root directory; prepare_data() treats every sub-directory of it as
# one class. The script entry point reassigns it after parsing its arguments.
data_set_dir='./IMAGENET'
def prepare_data(data_dir=None, split_ratio=None, output_dir='.'):
    """Split a one-folder-per-class dataset into train/val list files.

    Every sub-directory of the dataset root is treated as one class. Classes
    are sorted by name and numbered from 0. Three text files are written
    into ``output_dir``:

    * ``label.txt`` -- one ``<class name> <class index>`` line per class.
    * ``train.txt`` -- one ``<entry path>|<class index>`` line per training
      entry.
    * ``val.txt``   -- same format as ``train.txt``, for validation entries.

    Within each class the directory entries are shuffled with the ``random``
    module; the first ``int(split_ratio * count)`` entries go to the training
    list and the rest to the validation list.

    Args:
        data_dir: Dataset root directory. When ``None`` (the default) the
            module-level ``data_set_dir`` is used.
        split_ratio: Fraction of each class assigned to the training list.
            When ``None`` (the default) the module-level ``ratio`` is used.
        output_dir: Directory that receives the three list files. Defaults
            to the current working directory.

    Raises:
        OSError: If the dataset directory cannot be listed or an output file
            cannot be opened for writing.
    """
    # Fall back to the module-level settings so existing zero-argument
    # callers keep working unchanged.
    if data_dir is None:
        data_dir = data_set_dir
    if split_ratio is None:
        split_ratio = ratio

    # Only directories are classes; stray files in the root are ignored.
    labels = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    # Context managers guarantee the files are closed even if listing or
    # writing fails part-way through.
    with open(os.path.join(output_dir, 'label.txt'), 'w') as label_f:
        for index, label in enumerate(labels):
            label_f.write(label + ' ' + str(index) + '\n')

    train_path = os.path.join(output_dir, 'train.txt')
    val_path = os.path.join(output_dir, 'val.txt')
    with open(train_path, 'w') as train_f, open(val_path, 'w') as val_f:
        # enumerate() yields the class index directly instead of searching
        # the label list again for every entry written.
        for index, label in enumerate(labels):
            cur_dir = os.path.join(data_dir, label)
            # os.listdir() order is arbitrary; sorting first makes the split
            # reproducible when the caller seeds ``random``.
            pic_list = sorted(os.listdir(cur_dir))
            random.shuffle(pic_list)
            split_at = int(split_ratio * len(pic_list))
            suffix = '|' + str(index) + '\n'
            for pic in pic_list[:split_at]:
                train_f.write(os.path.join(cur_dir, pic) + suffix)
            for pic in pic_list[split_at:]:
                val_f.write(os.path.join(cur_dir, pic) + suffix)
if __name__ == '__main__':
    # Command-line entry point: parse the split ratio and dataset directory,
    # store them in the module-level settings, then build the list files.
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--ratio", required=False, default=0.8, type=float,
                    help="train val split ratio")
    ap.add_argument("--data_dir", required=False, default="IMAGENET",
                    help="dataset root directory with one sub-directory per class")
    args = ap.parse_args()
    # prepare_data() picks these module-level names up when called with no
    # arguments.
    ratio = args.ratio
    # argparse derives the attribute name from the option string, so
    # "--data_dir" is stored as args.data_dir (args.datadir does not exist).
    data_set_dir = args.data_dir
    prepare_data()