From 12b90b5044c98e8e84325ffb4de10528fa9f84f4 Mon Sep 17 00:00:00 2001 From: Guotai Wang Date: Sat, 13 Jan 2018 17:46:21 +0000 Subject: [PATCH] update test and evaluation --- test.py | 44 ++++++++----------- util/data_loader.py | 27 ++++++------ util/data_process.py | 92 +++++++++++++++++++++++++++++++--------- util/evaluation.py | 64 ++++++++++++++++++++++------ util/rename_variables.py | 5 ++- 5 files changed, 158 insertions(+), 74 deletions(-) diff --git a/test.py b/test.py index 2719be0..2adceeb 100644 --- a/test.py +++ b/test.py @@ -298,13 +298,8 @@ def test(config_file): margin = config_test.get('roi_patch_margin', 5) for i in range(image_num): - [imgs, weight, temp_name] = dataloader.get_image_data_with_name(i) + [temp_imgs, temp_weight, temp_name, temp_bbox, temp_size] = dataloader.get_image_data_with_name(i) t0 = time.time() - groi = get_roi(weight > 0, margin) - temp_imgs = [x[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] \ - for x in imgs] - temp_weight = weight[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] - # 5.1, test of 1st network if(config_net1): data_shapes = [ data_shape1[:-1], data_shape1[:-1], data_shape1[:-1]] @@ -328,20 +323,20 @@ def test(config_file): wt_threshold = 2000 if(config_test.get('whole_tumor_only', False) is True): pred1_lc = ndimage.morphology.binary_closing(pred1, structure = struct) - pred1_lc = get_largest_two_component(pred1_lc, True, wt_threshold) + pred1_lc = get_largest_two_component(pred1_lc, False, wt_threshold) out_label = pred1_lc else: # 5.2, test of 2nd network if(pred1.sum() == 0): print('net1 output is null', temp_name) - roi2 = get_roi(temp_imgs[0] > 0, margin) + bbox1 = get_ND_bounding_box(temp_imgs[0] > 0, margin) else: pred1_lc = ndimage.morphology.binary_closing(pred1, structure = struct) - pred1_lc = get_largest_two_component(pred1_lc, True, wt_threshold) - roi2 = get_roi(pred1_lc, margin) - sub_imgs = [x[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] \ - for x in temp_imgs] - sub_weight = temp_weight[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] + pred1_lc = get_largest_two_component(pred1_lc, False, wt_threshold) + bbox1 = get_ND_bounding_box(pred1_lc, margin) + sub_imgs = [crop_ND_volume_with_bounding_box(one_img, bbox1[0], bbox1[1]) for one_img in temp_imgs] + sub_weight = crop_ND_volume_with_bounding_box(temp_weight, bbox1[0], bbox1[1]) + if(config_net2): data_shapes = [ data_shape2[:-1], data_shape2[:-1], data_shape2[:-1]] label_shapes = [label_shape2[:-1], label_shape2[:-1], label_shape2[:-1]] @@ -364,17 +359,16 @@ def test(config_file): # 5.3, test of 3rd network if(pred2.sum() == 0): [roid, roih, roiw] = sub_imgs[0].shape - roi3 = [0, roid, 0, roih, 0, roiw] + bbox2 = [[0,0,0], [roid-1, roih-1, roiw-1]] subsub_imgs = sub_imgs subsub_weight = sub_weight else: pred2_lc = ndimage.morphology.binary_closing(pred2, structure = struct) pred2_lc = get_largest_two_component(pred2_lc) - roi3 = get_roi(pred2_lc, margin) - subsub_imgs = [x[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] \ - for x in sub_imgs] - subsub_weight = sub_weight[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] - + bbox2 = get_ND_bounding_box(pred2_lc, margin) + subsub_imgs = [crop_ND_volume_with_bounding_box(one_img, bbox2[0], bbox2[1]) for one_img in sub_imgs] + subsub_weight = crop_ND_volume_with_bounding_box(sub_weight, bbox2[0], bbox2[1]) + if(config_net3): data_shapes = [ data_shape3[:-1], data_shape3[:-1], data_shape3[:-1]] label_shapes = [label_shape3[:-1], label_shape3[:-1], label_shape3[:-1]] @@ -399,12 +393,12 @@ def test(config_file): # 5.4, fuse results at 3 levels # convert subsub_label to full size (non-enhanced) label3_roi = np.zeros_like(pred2) - label3_roi[np.ix_(range(roi3[0], roi3[1]), range(roi3[2], roi3[3]), range(roi3[4], roi3[5]))] = pred3 + label3_roi = set_ND_volume_roi_with_bounding_box_range(label3_roi, bbox2[0], bbox2[1], pred3) label3 = np.zeros_like(pred1) - label3[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] = label3_roi + label3 = set_ND_volume_roi_with_bounding_box_range(label3, bbox1[0], bbox1[1], label3_roi) label2 = np.zeros_like(pred1) - label2[np.ix_(range(roi2[0], roi2[1]), range(roi2[2], roi2[3]), range(roi2[4], roi2[5]))] = pred2 + lalbe2 = set_ND_volume_roi_with_bounding_box_range(label2, bbox1[0], bbox1[1], pred2) label1_mask = (pred1 + label2 + label3) > 0 label1_mask = ndimage.morphology.binary_closing(label1_mask, structure = struct) @@ -422,7 +416,6 @@ def test(config_file): label3 = label2 * label3 vox_3 = np.asarray(label3 > 0, np.float32).sum() if(0 < vox_3 and vox_3 < 30): - print('ignored voxel number ', vox_3) label3 = np.zeros_like(label2) # 5.5, convert label and save output @@ -438,9 +431,8 @@ def test(config_file): test_time.append(time.time() - t0) - final_label = np.zeros_like(weight, np.int16) - final_label[np.ix_(range(groi[0], groi[1]), range(groi[2], groi[3]), range(groi[4], groi[5]))] = out_label - + final_label = np.zeros(temp_size, np.int16) + final_label = set_ND_volume_roi_with_bounding_box_range(final_label, temp_bbox[0], temp_bbox[1], out_label) save_array_as_nifty_volume(final_label, save_folder+"/{0:}.nii.gz".format(temp_name)) print(temp_name) test_time = np.asarray(test_time) diff --git a/util/data_loader.py b/util/data_loader.py index 3d8a7b8..b0e4a5f 100644 --- a/util/data_loader.py +++ b/util/data_loader.py @@ -72,7 +72,7 @@ def __load_one_volume(self, patient_name, mod): volume = load_3d_volume_as_array(volume_name) return volume - def load_data(self, stage='train'): + def load_data(self): """ load all the training/testing data """ @@ -81,22 +81,20 @@ def load_data(self, stage='train'): X = [] W = [] Y = [] + bbox = [] + in_size = [] data_num = self.data_num if (self.data_num is not None) else len(self.patient_names) for i in range(data_num): volume_list = [] for mod_idx in range(len(self.modality_postfix)): volume = self.__load_one_volume(self.patient_names[i], self.modality_postfix[mod_idx]) - if(self.data_resize): - volume = resize_3D_volume_to_given_shape(volume, self.data_resize, 1) if(mod_idx == 0): margin = 5 - [d_idxes, h_idxes, w_idxes] = np.nonzero(volume) - mind = d_idxes.min() - margin; maxd = d_idxes.max() + margin - minh = h_idxes.min() - margin; maxh = h_idxes.max() + margin - minw = w_idxes.min() - margin; maxw = w_idxes.max() + margin - - if(stage == 'train'): - volume = volume[np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))] + bbmin, bbmax = get_ND_bounding_box(volume, margin) + volume_size = volume.shape + volume = crop_ND_volume_with_bounding_box(volume, bbmin, bbmax) + if(self.data_resize): + volume = resize_3D_volume_to_given_shape(volume, self.data_resize, 1) if(mod_idx ==0): weight = np.asarray(volume > 0, np.float32) if(self.intensity_normalize[mod_idx]): @@ -104,18 +102,21 @@ def load_data(self, stage='train'): volume_list.append(volume) X.append(volume_list) W.append(weight) + bbox.append([bbmin, bbmax]) + in_size.append(volume_size) if(self.with_ground_truth): label = self.__load_one_volume(self.patient_names[i], self.label_postfix) + label = crop_ND_volume_with_bounding_box(label, bbmin, bbmax) if(self.data_resize): label = resize_3D_volume_to_given_shape(label, self.data_resize, 0) - if(stage == 'train'): - label = label[np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))] Y.append(label) if((i+1)%50 == 0 or (i+1) == data_num): print('Data load, {0:}% finished'.format((i+1)*100.0/data_num)) self.data = X self.weight = W self.label = Y + self.bbox = bbox + self.in_size= in_size def get_subimage_batch(self): """ @@ -256,4 +257,4 @@ def get_image_data_with_name(self, i): """ Used for testing, get one image data and patient name """ - return [self.data[i], self.weight[i], self.patient_names[i]] + return [self.data[i], self.weight[i], self.patient_names[i], self.bbox[i], self.in_size[i]] diff --git a/util/data_process.py b/util/data_process.py index 295c68e..dc61463 100644 --- a/util/data_process.py +++ b/util/data_process.py @@ -79,6 +79,76 @@ def itensity_normalize_one_volume(volume): out[volume == 0] = out_random[volume == 0] return out +def get_ND_bounding_box(label, margin): + """ + get the bounding box of the non-zero region of an ND volume + """ + input_shape = label.shape + if(type(margin) is int ): + margin = [margin]*len(input_shape) + assert(len(input_shape) == len(margin)) + indxes = np.nonzero(label) + idx_min = [] + idx_max = [] + for i in range(len(input_shape)): + idx_min.append(indxes[i].min()) + idx_max.append(indxes[i].max()) + + for i in range(len(input_shape)): + idx_min[i] = max(idx_min[i] - margin[i], 0) + idx_max[i] = min(idx_max[i] + margin[i], input_shape[i] - 1) + return idx_min, idx_max + +def crop_ND_volume_with_bounding_box(volume, min_idx, max_idx): + """ + crop/extract a subregion form an nd image. + """ + dim = len(volume.shape) + assert(dim >= 2 and dim <= 5) + if(dim == 2): + output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), + range(min_idx[1], max_idx[1] + 1))] + elif(dim == 3): + output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), + range(min_idx[1], max_idx[1] + 1), + range(min_idx[2], max_idx[2] + 1))] + elif(dim == 4): + output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), + range(min_idx[1], max_idx[1] + 1), + range(min_idx[2], max_idx[2] + 1), + range(min_idx[3], max_idx[3] + 1))] + elif(dim == 5): + output = volume[np.ix_(range(min_idx[0], max_idx[0] + 1), + range(min_idx[1], max_idx[1] + 1), + range(min_idx[2], max_idx[2] + 1), + range(min_idx[3], max_idx[3] + 1), + range(min_idx[4], max_idx[4] + 1))] + else: + raise ValueError("the dimension number shoud be 2 to 5") + return output + +def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume): + """ + set a subregion to an nd image. + """ + dim = len(bb_min) + out = volume + if(dim == 2): + out[np.ix_(range(bb_min[0], bb_max[0] + 1), + range(bb_min[1], bb_max[1] + 1))] = sub_volume + elif(dim == 3): + out[np.ix_(range(bb_min[0], bb_max[0] + 1), + range(bb_min[1], bb_max[1] + 1), + range(bb_min[2], bb_max[2] + 1))] = sub_volume + elif(dim == 4): + out[np.ix_(range(bb_min[0], bb_max[0] + 1), + range(bb_min[1], bb_max[1] + 1), + range(bb_min[2], bb_max[2] + 1), + range(bb_min[3], bb_max[3] + 1))] = sub_volume + else: + raise ValueError("array dimension should be 2, 3 or 4") + return out + def convert_label(in_volume, label_convert_source, label_convert_target): """ convert the label value in a volume @@ -245,26 +315,6 @@ def set_roi_to_volume(volume, center, sub_volume): raise ValueError("array dimension should be 3 or 4") return output_volume - -def get_roi(volume, margin): - """ - get the roi bounding box of a 3D volume - inputs: - volume: the input 3D volume - margin: an integer margin along each axis - output: - [mind, maxd, minh, maxh, minw, maxw]: a list of lower and upper bound along each dimension - """ - [d_idxes, h_idxes, w_idxes] = np.nonzero(volume) - [D, H, W] = volume.shape - mind = max(d_idxes.min() - margin, 0) - maxd = min(d_idxes.max() + margin, D) - minh = max(h_idxes.min() - margin, 0) - maxh = min(h_idxes.max() + margin, H) - minw = max(w_idxes.min() - margin, 0) - maxw = min(w_idxes.max() + margin, W) - return [mind, maxd, minh, maxh, minw, maxw] - def get_largest_two_component(img, print_info = False, threshold = None): """ Get the largest two components of a binary volume @@ -357,5 +407,5 @@ def binary_dice3d(s,g): s0 = prod.sum() s1 = s.sum() s2 = g.sum() - dice = 2.0*s0/(s1 + s2 + 1e-10) + dice = (2.0*s0 + 1e-10)/(s1 + s2 + 1e-10) return dice diff --git a/util/evaluation.py b/util/evaluation.py index ca3e6b3..548f66d 100644 --- a/util/evaluation.py +++ b/util/evaluation.py @@ -2,19 +2,49 @@ from __future__ import absolute_import, print_function import os import sys +sys.path.append('./') import numpy as np -from util.data_process import load_nifty_volume_as_array, binary_dice3d +from util.data_process import load_3d_volume_as_array, binary_dice3d -def dice_of_brats_data_set(s_folder, g_folder, patient_names_file, type_idx): +def get_ground_truth_names(g_folder, patient_names_file, year = 15): + assert(year==15 or year == 17) with open(patient_names_file) as f: content = f.readlines() - patient_names = [x.strip() for x in content] + patient_names = [x.strip() for x in content] + full_gt_names = [] + for patient_name in patient_names: + patient_dir = os.path.join(g_folder, patient_name) + img_names = os.listdir(patient_dir) + gt_name = None + for img_name in img_names: + if(year == 15): + if 'OT.' in img_name: + gt_name = img_name + '/' + img_name + '.mha' + break + else: + if 'seg.' in img_name: + gt_name = img_name + break + gt_name = os.path.join(patient_dir, gt_name) + full_gt_names.append(gt_name) + return full_gt_names + +def get_segmentation_names(seg_folder, patient_names_file): + with open(patient_names_file) as f: + content = f.readlines() + patient_names = [x.strip() for x in content] + full_seg_names = [] + for patient_name in patient_names: + seg_name = os.path.join(seg_folder, patient_name + '.nii.gz') + full_seg_names.append(seg_name) + return full_seg_names + +def dice_of_brats_data_set(gt_names, seg_names, type_idx): + assert(len(gt_names) == len(seg_names)) dice_all_data = [] - for i in range(len(patient_names)): - s_name = os.path.join(s_folder, patient_names[i] + '.nii.gz') - g_name = os.path.join(g_folder, patient_names[i] + '.nii.gz') - s_volume = load_nifty_volume_as_array(s_name) - g_volume = load_nifty_volume_as_array(g_name) + for i in range(len(gt_names)): + g_volume = load_3d_volume_as_array(gt_names[i]) + s_volume = load_3d_volume_as_array(seg_names[i]) dice_one_volume = [] if(type_idx ==0): # whole tumor temp_dice = binary_dice3d(s_volume > 0, g_volume > 0) @@ -32,12 +62,22 @@ def dice_of_brats_data_set(s_folder, g_folder, patient_names_file, type_idx): return dice_all_data if __name__ == '__main__': - s_folder = 'result' - g_folder = '/home/guotai/data/brats_docker_data/pre_process' - patient_names_file = 'config_part/test_names.txt' + year = 15 # or 17 + + if(year == 15): + s_folder = 'result15' + g_folder = '/home/guotwang/data/BRATS2015_Training' + patient_names_file = 'config15/test_names.txt' + else: + s_folder = 'result17' + g_folder = '/home/guotwang/data/Brats17TrainingData' + patient_names_file = 'config15/test_names.txt' + test_types = ['whole','core', 'all'] + gt_names = get_ground_truth_names(g_folder, patient_names_file, year) + seg_names = get_segmentation_names(s_folder, patient_names_file) for type_idx in range(3): - dice = dice_of_brats_data_set(s_folder, g_folder, patient_names_file, type_idx) + dice = dice_of_brats_data_set(gt_names, seg_names, type_idx) dice = np.asarray(dice) dice_mean = dice.mean(axis = 0) dice_std = dice.std(axis = 0) diff --git a/util/rename_variables.py b/util/rename_variables.py index bf59970..f6eae19 100644 --- a/util/rename_variables.py +++ b/util/rename_variables.py @@ -20,13 +20,14 @@ def rename(checkpoint_from, checkpoint_to, replace_from, replace_to): saver.save(sess, checkpoint_to) if __name__ == '__main__': + year = 15 net_name = ['wt', 'tc', 'en'] net_name_c = ['WT', 'TC', 'EN'] num_pretrain = [10000, 20000, 20000] for i in range(3): for view in ['sg', 'cr']: - checkpoint_from = "model15/msnet_{0:}32_{1:}.ckpt".format(net_name[i], num_pretrain[i]) - checkpoint_to = "model15/msnet_{0:}32{1:}_init".format(net_name[i], view) + checkpoint_from = "model{0:}/msnet_{1:}32_{2:}.ckpt".format(year, net_name[i], num_pretrain[i]) + checkpoint_to = "model{0:}/msnet_{1:}32{2:}_init".format(year, net_name[i], view) replace_from = "MSNet_{0:}32".format(net_name_c[i]) replace_to = "MSNet_{0:}32{1:}".format(net_name_c[i], view) rename(checkpoint_from, checkpoint_to, replace_from, replace_to)