preprocess_LiTS.py (forked from lee-zq/3DUNet-Pytorch)
import numpy as np
import os
import SimpleITK as sitk
import random
from scipy import ndimage
from os.path import join
import config


class LITS_preprocess:
    def __init__(self, raw_dataset_path, fixed_dataset_path, args):
        self.raw_root_path = raw_dataset_path
        self.fixed_path = fixed_dataset_path
        self.classes = args.n_labels  # number of segmentation classes (2 = liver only, 3 = liver and tumor)
        self.upper = args.upper
        self.lower = args.lower
        self.expand_slice = args.expand_slice  # number of slices to expand outward along the axial direction
        self.size = args.min_slices  # minimum number of slices required per sample
        self.xy_down_scale = args.xy_down_scale
        self.slice_down_scale = args.slice_down_scale
        self.valid_rate = args.valid_rate
    def fix_data(self):
        # Create the output directories (exist_ok covers the case where fixed_path
        # already exists but the ct/label subfolders do not).
        os.makedirs(join(self.fixed_path, 'ct'), exist_ok=True)
        os.makedirs(join(self.fixed_path, 'label'), exist_ok=True)

        file_list = os.listdir(join(self.raw_root_path, 'ct'))
        numbers = len(file_list)
        print('Total number of samples is:', numbers)
        for i, ct_file in enumerate(file_list):
            print("==== {} | {}/{} ====".format(ct_file, i + 1, numbers))
            ct_path = os.path.join(self.raw_root_path, 'ct', ct_file)
            seg_path = os.path.join(self.raw_root_path, 'label', ct_file.replace('volume', 'segmentation'))
            new_ct, new_seg = self.process(ct_path, seg_path, classes=self.classes)
            if new_ct is not None and new_seg is not None:
                sitk.WriteImage(new_ct, os.path.join(self.fixed_path, 'ct', ct_file))
                sitk.WriteImage(new_seg, os.path.join(self.fixed_path, 'label',
                                                      ct_file.replace('volume', 'segmentation').replace('.nii', '.nii.gz')))
    def process(self, ct_path, seg_path, classes=None):
        ct = sitk.ReadImage(ct_path, sitk.sitkInt16)
        ct_array = sitk.GetArrayFromImage(ct)
        seg = sitk.ReadImage(seg_path, sitk.sitkInt8)
        seg_array = sitk.GetArrayFromImage(seg)
        print("Ori shape:", ct_array.shape, seg_array.shape)

        if classes == 2:
            # Merge the liver and tumor labels of the ground truth into a single label.
            seg_array[seg_array > 0] = 1

        # Clip gray values outside the [lower, upper] window.
        ct_array[ct_array > self.upper] = self.upper
        ct_array[ct_array < self.lower] = self.lower

        # Downsample: scale the x and y axes by xy_down_scale and resample the slice axis
        # so that its spacing becomes slice_down_scale.
        ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / self.slice_down_scale, self.xy_down_scale, self.xy_down_scale), order=3)
        seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / self.slice_down_scale, self.xy_down_scale, self.xy_down_scale), order=0)

        # Find the first and last slices containing liver and expand the range in both directions.
        z = np.any(seg_array, axis=(1, 2))
        start_slice, end_slice = np.where(z)[0][[0, -1]]
        if start_slice - self.expand_slice < 0:
            start_slice = 0
        else:
            start_slice -= self.expand_slice
        if end_slice + self.expand_slice >= seg_array.shape[0]:
            end_slice = seg_array.shape[0] - 1
        else:
            end_slice += self.expand_slice
        print("Cut out range:", str(start_slice) + '--' + str(end_slice))

        # If too few slices remain, discard the sample; such cases are rare.
        if end_slice - start_slice + 1 < self.size:
            print('Too few slices, giving up the sample:', ct_path)
            return None, None

        # Crop to the retained slice range.
        ct_array = ct_array[start_slice:end_slice + 1, :, :]
        seg_array = seg_array[start_slice:end_slice + 1, :, :]
        print("Preprocessed shape:", ct_array.shape, seg_array.shape)

        # Convert back to SimpleITK images and carry over the (updated) metadata.
        new_ct = sitk.GetImageFromArray(ct_array)
        new_ct.SetDirection(ct.GetDirection())
        new_ct.SetOrigin(ct.GetOrigin())
        new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / self.xy_down_scale), ct.GetSpacing()[1] * int(1 / self.xy_down_scale), self.slice_down_scale))
        new_seg = sitk.GetImageFromArray(seg_array)
        new_seg.SetDirection(ct.GetDirection())
        new_seg.SetOrigin(ct.GetOrigin())
        new_seg.SetSpacing((ct.GetSpacing()[0] * int(1 / self.xy_down_scale), ct.GetSpacing()[1] * int(1 / self.xy_down_scale), self.slice_down_scale))
        return new_ct, new_seg
    def write_train_val_name_list(self):
        data_name_list = os.listdir(join(self.fixed_path, "ct"))
        data_num = len(data_name_list)
        print('The fixed dataset contains', data_num, 'samples in total.')
        random.shuffle(data_name_list)
        assert self.valid_rate < 1.0
        # Split the shuffled list: the first (1 - valid_rate) fraction for training, the rest for validation.
        train_name_list = data_name_list[0:int(data_num * (1 - self.valid_rate))]
        val_name_list = data_name_list[int(data_num * (1 - self.valid_rate)):]
        self.write_name_list(train_name_list, "train_path_list.txt")
        self.write_name_list(val_name_list, "val_path_list.txt")
    def write_name_list(self, name_list, file_name):
        with open(join(self.fixed_path, file_name), 'w') as f:
            for name in name_list:
                ct_path = os.path.join(self.fixed_path, 'ct', name)
                seg_path = os.path.join(self.fixed_path, 'label', name.replace('volume', 'segmentation'))
                f.write(ct_path + ' ' + seg_path + "\n")


if __name__ == '__main__':
    raw_dataset_path = '/ssd/lzq/dataset/LiTS/train'
    fixed_dataset_path = '/ssd/lzq/dataset/fixed_lits'
    args = config.args

    tool = LITS_preprocess(raw_dataset_path, fixed_dataset_path, args)
    tool.fix_data()                       # crop the raw volumes and save the fixed dataset
    tool.write_train_val_name_list()      # create the train/val index txt files
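
# Expected data layout, inferred from the path handling above (LiTS-style naming; adjust to your setup):
#   <raw_dataset_path>/ct/volume-0.nii     paired with   <raw_dataset_path>/label/segmentation-0.nii
# After preprocessing, <fixed_dataset_path> holds the cropped ct/ and label/ volumes plus
# train_path_list.txt and val_path_list.txt, where each line is "<ct_path> <label_path>".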