Skip to content

Commit

Permalink
update test and evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
taigw committed Jan 13, 2018
1 parent cf5231a commit 12b90b5
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 74 deletions.
44 changes: 18 additions & 26 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,8 @@ def test(config_file):
margin = config_test.get('roi_patch_margin', 5)

for i in range(image_num):
[imgs, weight, temp_name] = dataloader.get_image_data_with_name(i)
[temp_imgs, temp_weight, temp_name, temp_bbox, temp_size] = dataloader.get_image_data_with_name(i)
t0 = time.time()
groi = get_roi(weight > 0, margin)
temp_imgs = [x[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] \
for x in imgs]
temp_weight = weight[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))]

# 5.1, test of 1st network
if(config_net1):
data_shapes = [ data_shape1[:-1], data_shape1[:-1], data_shape1[:-1]]
Expand All @@ -328,20 +323,20 @@ def test(config_file):
wt_threshold = 2000
if(config_test.get('whole_tumor_only', False) is True):
pred1_lc = ndimage.morphology.binary_closing(pred1, structure = struct)
pred1_lc = get_largest_two_component(pred1_lc, True, wt_threshold)
pred1_lc = get_largest_two_component(pred1_lc, False, wt_threshold)
out_label = pred1_lc
else:
# 5.2, test of 2nd network
if(pred1.sum() == 0):
print('net1 output is null', temp_name)
roi2 = get_roi(temp_imgs[0] > 0, margin)
bbox1 = get_ND_bounding_box(temp_imgs[0] > 0, margin)
else:
pred1_lc = ndimage.morphology.binary_closing(pred1, structure = struct)
pred1_lc = get_largest_two_component(pred1_lc, True, wt_threshold)
roi2 = get_roi(pred1_lc, margin)
sub_imgs = [x[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] \
for x in temp_imgs]
sub_weight = temp_weight[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))]
pred1_lc = get_largest_two_component(pred1_lc, False, wt_threshold)
bbox1 = get_ND_bounding_box(pred1_lc, margin)
sub_imgs = [crop_ND_volume_with_bounding_box(one_img, bbox1[0], bbox1[1]) for one_img in temp_imgs]
sub_weight = crop_ND_volume_with_bounding_box(temp_weight, bbox1[0], bbox1[1])

if(config_net2):
data_shapes = [ data_shape2[:-1], data_shape2[:-1], data_shape2[:-1]]
label_shapes = [label_shape2[:-1], label_shape2[:-1], label_shape2[:-1]]
Expand All @@ -364,17 +359,16 @@ def test(config_file):
# 5.3, test of 3rd network
if(pred2.sum() == 0):
[roid, roih, roiw] = sub_imgs[0].shape
roi3 = [0, roid, 0, roih, 0, roiw]
bbox2 = [[0,0,0], [roid-1, roih-1, roiw-1]]
subsub_imgs = sub_imgs
subsub_weight = sub_weight
else:
pred2_lc = ndimage.morphology.binary_closing(pred2, structure = struct)
pred2_lc = get_largest_two_component(pred2_lc)
roi3 = get_roi(pred2_lc, margin)
subsub_imgs = [x[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] \
for x in sub_imgs]
subsub_weight = sub_weight[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))]

bbox2 = get_ND_bounding_box(pred2_lc, margin)
subsub_imgs = [crop_ND_volume_with_bounding_box(one_img, bbox2[0], bbox2[1]) for one_img in sub_imgs]
subsub_weight = crop_ND_volume_with_bounding_box(sub_weight, bbox2[0], bbox2[1])

if(config_net3):
data_shapes = [ data_shape3[:-1], data_shape3[:-1], data_shape3[:-1]]
label_shapes = [label_shape3[:-1], label_shape3[:-1], label_shape3[:-1]]
Expand All @@ -399,12 +393,12 @@ def test(config_file):
# 5.4, fuse results at 3 levels
# convert subsub_label to full size (non-enhanced)
label3_roi = np.zeros_like(pred2)
label3_roi[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] = pred3
label3_roi = set_ND_volume_roi_with_bounding_box_range(label3_roi, bbox2[0], bbox2[1], pred3)
label3 = np.zeros_like(pred1)
label3[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] = label3_roi
label3 = set_ND_volume_roi_with_bounding_box_range(label3, bbox1[0], bbox1[1], label3_roi)

label2 = np.zeros_like(pred1)
label2[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] = pred2
lalbe2 = set_ND_volume_roi_with_bounding_box_range(label2, bbox1[0], bbox1[1], pred2)

label1_mask = (pred1 + label2 + label3) > 0
label1_mask = ndimage.morphology.binary_closing(label1_mask, structure = struct)
Expand All @@ -422,7 +416,6 @@ def test(config_file):
label3 = label2 * label3
vox_3 = np.asarray(label3 > 0, np.float32).sum()
if(0 < vox_3 and vox_3 < 30):
print('ignored voxel number ', vox_3)
label3 = np.zeros_like(label2)

# 5.5, convert label and save output
Expand All @@ -438,9 +431,8 @@ def test(config_file):


test_time.append(time.time() - t0)
final_label = np.zeros_like(weight, np.int16)
final_label[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] = out_label

final_label = np.zeros(temp_size, np.int16)
final_label = set_ND_volume_roi_with_bounding_box_range(final_label, temp_bbox[0], temp_bbox[1], out_label)
save_array_as_nifty_volume(final_label, save_folder+"/{0:}.nii.gz".format(temp_name))
print(temp_name)
test_time = np.asarray(test_time)
Expand Down
27 changes: 14 additions & 13 deletions util/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __load_one_volume(self, patient_name, mod):
volume = load_3d_volume_as_array(volume_name)
return volume

def load_data(self, stage='train'):
def load_data(self):
"""
load all the training/testing data
"""
Expand All @@ -81,41 +81,42 @@ def load_data(self, stage='train'):
X = []
W = []
Y = []
bbox = []
in_size = []
data_num = self.data_num if (self.data_num is not None) else len(self.patient_names)
for i in range(data_num):
volume_list = []
for mod_idx in range(len(self.modality_postfix)):
volume = self.__load_one_volume(self.patient_names[i], self.modality_postfix[mod_idx])
if(self.data_resize):
volume = resize_3D_volume_to_given_shape(volume, self.data_resize, 1)
if(mod_idx == 0):
margin = 5
[d_idxes, h_idxes, w_idxes] = np.nonzero(volume)
mind = d_idxes.min() - margin; maxd = d_idxes.max() + margin
minh = h_idxes.min() - margin; maxh = h_idxes.max() + margin
minw = w_idxes.min() - margin; maxw = w_idxes.max() + margin

if(stage == 'train'):
volume = volume[np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))]
bbmin, bbmax = get_ND_bounding_box(volume, margin)
volume_size = volume.shape
volume = crop_ND_volume_with_bounding_box(volume, bbmin, bbmax)
if(self.data_resize):
volume = resize_3D_volume_to_given_shape(volume, self.data_resize, 1)
if(mod_idx ==0):
weight = np.asarray(volume > 0, np.float32)
if(self.intensity_normalize[mod_idx]):
volume = itensity_normalize_one_volume(volume)
volume_list.append(volume)
X.append(volume_list)
W.append(weight)
bbox.append([bbmin, bbmax])
in_size.append(volume_size)
if(self.with_ground_truth):
label = self.__load_one_volume(self.patient_names[i], self.label_postfix)
label = crop_ND_volume_with_bounding_box(label, bbmin, bbmax)
if(self.data_resize):
label = resize_3D_volume_to_given_shape(label, self.data_resize, 0)
if(stage == 'train'):
label = label[np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))]
Y.append(label)
if((i+1)%50 == 0 or (i+1) == data_num):
print('Data load, {0:}% finished'.format((i+1)*100.0/data_num))
self.data = X
self.weight = W
self.label = Y
self.bbox = bbox
self.in_size= in_size

def get_subimage_batch(self):
"""
Expand Down Expand Up @@ -256,4 +257,4 @@ def get_image_data_with_name(self, i):
"""
Used for testing, get one image data and patient name
"""
return [self.data[i], self.weight[i], self.patient_names[i]]
return [self.data[i], self.weight[i], self.patient_names[i], self.bbox[i], self.in_size[i]]
92 changes: 71 additions & 21 deletions util/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,76 @@ def itensity_normalize_one_volume(volume):
out[volume == 0] = out_random[volume == 0]
return out

def get_ND_bounding_box(label, margin):
"""
get the bounding box of the non-zero region of an ND volume
"""
input_shape = label.shape
if(type(margin) is int ):
margin = [margin]*len(input_shape)
assert(len(input_shape) == len(margin))
indxes = np.nonzero(label)
idx_min = []
idx_max = []
for i in range(len(input_shape)):
idx_min.append(indxes[i].min())
idx_max.append(indxes[i].max())

for i in range(len(input_shape)):
idx_min[i] = max(idx_min[i] - margin[i], 0)
idx_max[i] = min(idx_max[i] + margin[i], input_shape[i] - 1)
return idx_min, idx_max

def crop_ND_volume_with_bounding_box(volume, min_idx, max_idx):
"""
crop/extract a subregion form an nd image.
"""
dim = len(volume.shape)
assert(dim >= 2 and dim <= 5)
if(dim == 2):
output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1),
range(min_idx[1], max_idx[1] + 1))]
elif(dim == 3):
output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1),
range(min_idx[1], max_idx[1] + 1),
range(min_idx[2], max_idx[2] + 1))]
elif(dim == 4):
output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1),
range(min_idx[1], max_idx[1] + 1),
range(min_idx[2], max_idx[2] + 1),
range(min_idx[3], max_idx[3] + 1))]
elif(dim == 5):
output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1),
range(min_idx[1], max_idx[1] + 1),
range(min_idx[2], max_idx[2] + 1),
range(min_idx[3], max_idx[3] + 1),
range(min_idx[4], max_idx[4] + 1))]
else:
raise ValueError("the dimension number shoud be 2 to 5")
return output

def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume):
"""
set a subregion to an nd image.
"""
dim = len(bb_min)
out = volume
if(dim == 2):
out[np.ix_(range(bb_min[0], bb_max[0] + 1),
range(bb_min[1], bb_max[1] + 1))] = sub_volume
elif(dim == 3):
out[np.ix_(range(bb_min[0], bb_max[0] + 1),
range(bb_min[1], bb_max[1] + 1),
range(bb_min[2], bb_max[2] + 1))] = sub_volume
elif(dim == 4):
out[np.ix_(range(bb_min[0], bb_max[0] + 1),
range(bb_min[1], bb_max[1] + 1),
range(bb_min[2], bb_max[2] + 1),
range(bb_min[3], bb_max[3] + 1))] = sub_volume
else:
raise ValueError("array dimension should be 2, 3 or 4")
return out

def convert_label(in_volume, label_convert_source, label_convert_target):
"""
convert the label value in a volume
Expand Down Expand Up @@ -245,26 +315,6 @@ def set_roi_to_volume(volume, center, sub_volume):
raise ValueError("array dimension should be 3 or 4")
return output_volume


def get_roi(volume, margin):
"""
get the roi bounding box of a 3D volume
inputs:
volume: the input 3D volume
margin: an integer margin along each axis
output:
[mind, maxd, minh, maxh, minw, maxw]: a list of lower and upper bound along each dimension
"""
[d_idxes, h_idxes, w_idxes] = np.nonzero(volume)
[D, H, W] = volume.shape
mind = max(d_idxes.min() - margin, 0)
maxd = min(d_idxes.max() + margin, D)
minh = max(h_idxes.min() - margin, 0)
maxh = min(h_idxes.max() + margin, H)
minw = max(w_idxes.min() - margin, 0)
maxw = min(w_idxes.max() + margin, W)
return [mind, maxd, minh, maxh, minw, maxw]

def get_largest_two_component(img, print_info = False, threshold = None):
"""
Get the largest two components of a binary volume
Expand Down Expand Up @@ -357,5 +407,5 @@ def binary_dice3d(s,g):
s0 = prod.sum()
s1 = s.sum()
s2 = g.sum()
dice = 2.0*s0/(s1 + s2 + 1e-10)
dice = (2.0*s0 + 1e-10)/(s1 + s2 + 1e-10)
return dice
64 changes: 52 additions & 12 deletions util/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,49 @@
from __future__ import absolute_import, print_function
import os
import sys
sys.path.append('./')
import numpy as np
from util.data_process import load_nifty_volume_as_array, binary_dice3d
from util.data_process import load_3d_volume_as_array, binary_dice3d

def dice_of_brats_data_set(s_folder, g_folder, patient_names_file, type_idx):
def get_ground_truth_names(g_folder, patient_names_file, year = 15):
assert(year==15 or year == 17)
with open(patient_names_file) as f:
content = f.readlines()
patient_names = [x.strip() for x in content]
patient_names = [x.strip() for x in content]
full_gt_names = []
for patient_name in patient_names:
patient_dir = os.path.join(g_folder, patient_name)
img_names = os.listdir(patient_dir)
gt_name = None
for img_name in img_names:
if(year == 15):
if 'OT.' in img_name:
gt_name = img_name + '/' + img_name + '.mha'
break
else:
if 'seg.' in img_name:
gt_name = img_name
break
gt_name = os.path.join(patient_dir, gt_name)
full_gt_names.append(gt_name)
return full_gt_names

def get_segmentation_names(seg_folder, patient_names_file):
with open(patient_names_file) as f:
content = f.readlines()
patient_names = [x.strip() for x in content]
full_seg_names = []
for patient_name in patient_names:
seg_name = os.path.join(seg_folder, patient_name + '.nii.gz')
full_seg_names.append(seg_name)
return full_seg_names

def dice_of_brats_data_set(gt_names, seg_names, type_idx):
assert(len(gt_names) == len(seg_names))
dice_all_data = []
for i in range(len(patient_names)):
s_name = os.path.join(s_folder, patient_names[i] + '.nii.gz')
g_name = os.path.join(g_folder, patient_names[i] + '.nii.gz')
s_volume = load_nifty_volume_as_array(s_name)
g_volume = load_nifty_volume_as_array(g_name)
for i in range(len(gt_names)):
g_volume = load_3d_volume_as_array(gt_names[i])
s_volume = load_3d_volume_as_array(seg_names[i])
dice_one_volume = []
if(type_idx ==0): # whole tumor
temp_dice = binary_dice3d(s_volume > 0, g_volume > 0)
Expand All @@ -32,12 +62,22 @@ def dice_of_brats_data_set(s_folder, g_folder, patient_names_file, type_idx):
return dice_all_data

if __name__ == '__main__':
s_folder = 'result'
g_folder = '/home/guotai/data/brats_docker_data/pre_process'
patient_names_file = 'config_part/test_names.txt'
year = 15 # or 17

if(year == 15):
s_folder = 'result15'
g_folder = '/home/guotwang/data/BRATS2015_Training'
patient_names_file = 'config15/test_names.txt'
else:
s_folder = 'result17'
g_folder = '/home/guotwang/data/Brats17TrainingData'
patient_names_file = 'config15/test_names.txt'

test_types = ['whole','core', 'all']
gt_names = get_ground_truth_names(g_folder, patient_names_file, year)
seg_names = get_segmentation_names(s_folder, patient_names_file)
for type_idx in range(3):
dice = dice_of_brats_data_set(s_folder, g_folder, patient_names_file, type_idx)
dice = dice_of_brats_data_set(gt_names, seg_names, type_idx)
dice = np.asarray(dice)
dice_mean = dice.mean(axis = 0)
dice_std = dice.std(axis = 0)
Expand Down
5 changes: 3 additions & 2 deletions util/rename_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ def rename(checkpoint_from, checkpoint_to, replace_from, replace_to):
saver.save(sess, checkpoint_to)

if __name__ == '__main__':
year = 15
net_name = ['wt', 'tc', 'en']
net_name_c = ['WT', 'TC', 'EN']
num_pretrain = [10000, 20000, 20000]
for i in range(3):
for view in ['sg', 'cr']:
checkpoint_from = "model15/msnet_{0:}32_{1:}.ckpt".format(net_name[i], num_pretrain[i])
checkpoint_to = "model15/msnet_{0:}32{1:}_init".format(net_name[i], view)
checkpoint_from = "model{0:}/msnet_{1:}32_{2:}.ckpt".format(year, net_name[i], num_pretrain[i])
checkpoint_to = "model{0:}/msnet_{1:}32{2:}_init".format(year, net_name[i], view)
replace_from = "MSNet_{0:}32".format(net_name_c[i])
replace_to = "MSNet_{0:}32{1:}".format(net_name_c[i], view)
rename(checkpoint_from, checkpoint_to, replace_from, replace_to)
Expand Down

0 comments on commit 12b90b5

Please sign in to comment.