Commit

update eval when train
bubbliiiing committed May 11, 2022
1 parent 1cbbe32 commit ac181a6
Showing 6 changed files with 277 additions and 40 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ Medical_Datasets/
lfw/
logs/
model_data/
.temp_map_out/

# Byte-compiled / optimized / DLL files
__pycache__/
49 changes: 37 additions & 12 deletions get_map.py
@@ -4,17 +4,17 @@
from PIL import Image
from tqdm import tqdm

from yolo import YOLO
from utils.utils import get_classes
from utils.utils_map import get_coco_map, get_map
from yolo import YOLO

if __name__ == "__main__":
'''
Unlike AP, Recall and Precision are not area-based quantities; at different thresholds, the network's Recall and Precision values are different.
The Recall and Precision in the mAP results correspond to a confidence threshold of 0.5 at prediction time.
Unlike AP, Recall and Precision are not area-based quantities, so the network's Recall and Precision values differ as the confidence threshold changes.
By default, the Recall and Precision computed by this script correspond to a confidence threshold of 0.5.
The txt files produced here in ./map_out/detection-results/ contain more boxes than a direct predict because the threshold here is low;
the goal is to compute Recall and Precision at different thresholds so that mAP can be calculated.
Because of how mAP is computed, the network must keep nearly all of its predicted boxes so that Recall and Precision can be evaluated at different thresholds.
Therefore, the txt files produced in map_out/detection-results/ generally contain more boxes than a direct predict; the intent is to list every possible predicted box.
'''
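The following is an illustrative sketch, not part of this commit, of the idea the docstring describes: Recall and Precision are single points that move with the confidence threshold, while AP is the area under the whole precision-recall curve, which is why nearly every predicted box has to be dumped first. Detections are assumed to be already matched against ground truth, and all names and numbers below are invented for the example.

import numpy as np

def precision_recall_at(scores, is_correct, num_gt, threshold):
    # One point on the PR curve: keep only detections above the threshold.
    keep = scores >= threshold
    tp = int(np.sum(is_correct[keep]))
    fp = int(np.sum(~is_correct[keep]))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall    = tp / num_gt if num_gt > 0 else 0.0
    return precision, recall

def average_precision(scores, is_correct, num_gt):
    # AP needs the whole curve: sort all detections by confidence,
    # accumulate TP/FP counts, then integrate precision over recall.
    order     = np.argsort(-scores)
    tp_cum    = np.cumsum(is_correct[order])
    fp_cum    = np.cumsum(~is_correct[order])
    recall    = tp_cum / num_gt
    precision = tp_cum / (tp_cum + fp_cum)
    # Plain trapezoidal area; VOC/COCO use interpolated variants,
    # so the exact numbers from get_map/get_coco_map will differ slightly.
    return float(np.trapz(precision, recall))

scores     = np.array([0.95, 0.80, 0.60, 0.30, 0.05])
is_correct = np.array([True, True, False, True, False])
print(precision_recall_at(scores, is_correct, num_gt=4, threshold=0.5))  # a single PR point
print(average_precision(scores, is_correct, num_gt=4))                   # needs all boxes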
#------------------------------------------------------------------------------------------------------------------#
#   map_mode specifies what this file computes when it is run
@@ -25,16 +25,41 @@
#   map_mode = 4 uses the COCO toolbox to compute the 0.50:0.95 mAP of the current dataset. This requires the prediction results, the ground-truth boxes, and an installed pycocotools
#-------------------------------------------------------------------------------------------------------------------#
map_mode = 0
#-------------------------------------------------------#
#--------------------------------------------------------------------------------------#
#   classes_path here specifies the classes whose VOC mAP should be measured
#   Generally it should match the classes_path used for training and prediction
#-------------------------------------------------------#
#--------------------------------------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
#-------------------------------------------------------#
#   MINOVERLAP specifies the mAP0.x you want to obtain
#--------------------------------------------------------------------------------------#
#   MINOVERLAP specifies the mAP0.x you want to obtain; look up what mAP0.x means if you are not sure.
#   For example, to compute mAP0.75, set MINOVERLAP = 0.75.
#-------------------------------------------------------#
#
#   A predicted box counts as a positive sample when its overlap (IoU) with a ground-truth box is greater than MINOVERLAP; otherwise it is a negative sample.
#   Therefore, the larger MINOVERLAP is, the more accurately a box must be predicted to count as a positive, and the lower the resulting mAP.
#--------------------------------------------------------------------------------------#
MINOVERLAP = 0.5
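As a rough illustration of the MINOVERLAP comment above (not code from this repository): a predicted box counts as a positive only when its IoU with a ground-truth box exceeds MINOVERLAP, so raising MINOVERLAP demands tighter boxes and generally lowers the resulting mAP. The boxes below are invented for the example.

def box_iou(box_a, box_b):
    # Boxes are (left, top, right, bottom) in pixels.
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter   = inter_w * inter_h
    area_a  = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b  = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-9)

pred = (60, 60, 160, 160)
gt   = (50, 50, 150, 150)
iou  = box_iou(pred, gt)            # roughly 0.68 for these boxes
print(iou > 0.5)                    # True : positive when MINOVERLAP = 0.5
print(iou > 0.75)                   # False: negative when MINOVERLAP = 0.75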
#--------------------------------------------------------------------------------------#
#   Because of how mAP is computed, the network must keep nearly all of its predicted boxes when computing mAP.
#   Therefore, confidence should be set as low as possible to obtain every possible predicted box.
#
#   This value is generally left unchanged. Since computing mAP requires nearly all predicted boxes, the confidence here must not be changed casually.
#   To obtain Recall and Precision at other thresholds, modify score_threhold below instead.
#--------------------------------------------------------------------------------------#
confidence = 0.001
#--------------------------------------------------------------------------------------#
#   The IoU used for non-maximum suppression during prediction; larger values mean less strict NMS.
#
#   This value is generally left unchanged.
#--------------------------------------------------------------------------------------#
nms_iou = 0.5
#---------------------------------------------------------------------------------------------------------------#
#   Unlike AP, Recall and Precision are not area-based quantities, so their values differ at different thresholds.
#
#   By default, the Recall and Precision computed here correspond to a threshold of 0.5 (defined here as score_threhold).
#   Since computing mAP requires nearly all predicted boxes, the confidence defined above must not be changed casually.
#   A separate score_threhold is defined here to represent that threshold, so the Recall and Precision corresponding to it can be found when computing mAP.
#---------------------------------------------------------------------------------------------------------------#
score_threhold = 0.5
#-------------------------------------------------------#
#   map_vis specifies whether to enable visualization of the VOC mAP computation
#-------------------------------------------------------#
@@ -64,7 +89,7 @@

if map_mode == 0 or map_mode == 1:
print("Load model.")
yolo = YOLO(confidence = 0.001, nms_iou = 0.5)
yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
print("Load model done.")

print("Get predict result.")
@@ -104,7 +129,7 @@

if map_mode == 0 or map_mode == 3:
print("Get map.")
get_map(MINOVERLAP, True, path = map_out_path)
get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
print("Get map done.")

if map_mode == 4:
38 changes: 32 additions & 6 deletions train.py
@@ -1,6 +1,7 @@
#-------------------------------------#
#   Train the model on the dataset
#-------------------------------------#
import datetime
import os

import numpy as np
@@ -15,7 +16,7 @@
from nets.yolo import YoloBody
from nets.yolo_training import (YOLOLoss, get_lr_scheduler, set_optimizer_lr,
weights_init)
from utils.callbacks import LossHistory
from utils.callbacks import LossHistory, EvalCallback
from utils.dataloader import YoloDataset, yolo_dataset_collate
from utils.utils import get_anchors, get_classes, show_config
from utils.utils_fit import fit_one_epoch
@@ -198,6 +199,17 @@
#------------------------------------------------------------------#
save_dir = 'logs'
#------------------------------------------------------------------#
#   eval_flag       whether to evaluate during training; the evaluation target is the validation set
#                   evaluation works better once the pycocotools library is installed.
#   eval_period     how many epochs between evaluations; frequent evaluation is not recommended
#                   evaluation takes considerable time, and evaluating frequently makes training very slow
#   The mAP obtained here will differ from the one produced by get_map.py, for two reasons:
#   (1) The mAP obtained here is the mAP of the validation set.
#   (2) The evaluation settings here are kept conservative in order to speed up evaluation.
#------------------------------------------------------------------#
eval_flag = True
eval_period = 10
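A minimal sketch (illustrative only, not code from the commit) of how these two switches gate the extra work; the real check is `if epoch % self.period == 0 and self.eval_flag` inside EvalCallback.on_epoch_end, which fit_one_epoch calls with epoch + 1:

UnFreeze_Epoch, eval_flag, eval_period = 100, True, 10   # stand-in values for the sketch
for epoch in range(UnFreeze_Epoch):
    # ... normal training and validation for this epoch ...
    if eval_flag and (epoch + 1) % eval_period == 0:
        pass  # run the (slow) validation-set mAP evaluation only on these epochs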
#------------------------------------------------------------------#
#   num_workers     sets whether to use multithreaded data loading
#                   enabling it speeds up data loading but uses more memory
#                   machines with little memory can set it to 2 or 0
@@ -275,9 +287,11 @@
#   Record the loss
#----------------------#
if local_rank == 0:
loss_history = LossHistory(save_dir, model, input_shape=input_shape)
time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
log_dir = os.path.join(save_dir, "loss_" + str(time_str))
loss_history = LossHistory(log_dir, model, input_shape=input_shape)
else:
loss_history = None
loss_history = None

#------------------------------------------------------------------#
#   torch 1.2 does not support amp; torch 1.7.1 or above is recommended to use fp16 correctly
@@ -425,6 +439,15 @@
gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)

#----------------------#
#   Record the eval mAP curve
#----------------------#
if local_rank == 0:
eval_callback = EvalCallback(model, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, Cuda, \
eval_flag=eval_flag, period=eval_period)
else:
eval_callback = None

#---------------------------------------#
#   Start model training
#---------------------------------------#
@@ -472,7 +495,10 @@
train_sampler.set_epoch(epoch)
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)

fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)

if distributed:
dist.barrier()

if local_rank == 0:
loss_history.writer.close()
loss_history.writer.close()
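One practical effect of the train.py/LossHistory change above: the timestamped run directory is now created in train.py itself and shared by the loss logger and the new EvalCallback (which writes epoch_map.txt there). A small runnable sketch mirroring the two added lines; the printed name depends on when it runs:

import datetime
import os

save_dir = 'logs'
time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
log_dir  = os.path.join(save_dir, "loss_" + str(time_str))
print(log_dir)   # e.g. logs/loss_2022_05_11_20_00_00 -- LossHistory and EvalCallback both log here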
166 changes: 164 additions & 2 deletions utils/callbacks.py
@@ -8,11 +8,19 @@
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter

import shutil
import numpy as np

from PIL import Image
from tqdm import tqdm
from .utils import cvtColor, preprocess_input, resize_image
from .utils_bbox import DecodeBox
from .utils_map import get_coco_map, get_map


class LossHistory():
def __init__(self, log_dir, model, input_shape):
time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
self.log_dir = log_dir
self.losses = []
self.val_loss = []

@@ -68,3 +76,157 @@ def loss_plot(self):

plt.cla()
plt.close("all")

class EvalCallback():
def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda, \
map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
super(EvalCallback, self).__init__()

self.net = net
self.input_shape = input_shape
self.anchors = anchors
self.anchors_mask = anchors_mask
self.class_names = class_names
self.num_classes = num_classes
self.val_lines = val_lines
self.log_dir = log_dir
self.cuda = cuda
self.map_out_path = map_out_path
self.max_boxes = max_boxes
self.confidence = confidence
self.nms_iou = nms_iou
self.letterbox_image = letterbox_image
self.MINOVERLAP = MINOVERLAP
self.eval_flag = eval_flag
self.period = period

self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

self.maps = [0]
self.epoches = [0]
if self.eval_flag:
with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
f.write(str(0))
f.write("\n")

def get_map_txt(self, image_id, image, class_names, map_out_path):
f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8')
image_shape = np.array(np.shape(image)[0:2])
#---------------------------------------------------------#
#   Convert the image to RGB here, to prevent grayscale images from raising errors during prediction.
#   The code only supports prediction on RGB images; all other image types are converted to RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
#   Add gray bars to the image to achieve a distortion-free resize
#   alternatively, the image can be resized directly for detection
#---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
#---------------------------------------------------------#
#   Add the batch_size dimension
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------------#
#   Feed the image into the network for prediction!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = self.bbox_util.decode_box(outputs)
#---------------------------------------------------------#
#   Stack the predicted boxes, then apply non-maximum suppression
#---------------------------------------------------------#
results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

if results[0] is None:
return

top_label = np.array(results[0][:, 6], dtype = 'int32')
top_conf = results[0][:, 4] * results[0][:, 5]
top_boxes = results[0][:, :4]

top_100 = np.argsort(top_label)[::-1][:self.max_boxes]
top_boxes = top_boxes[top_100]
top_conf = top_conf[top_100]
top_label = top_label[top_100]

for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)]
box = top_boxes[i]
score = str(top_conf[i])

top, left, bottom, right = box
if predicted_class not in class_names:
continue

f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

f.close()
return

def on_epoch_end(self, epoch, model_eval):
if epoch % self.period == 0 and self.eval_flag:
self.net = model_eval
if not os.path.exists(self.map_out_path):
os.makedirs(self.map_out_path)
if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
os.makedirs(os.path.join(self.map_out_path, "detection-results"))
print("Get map.")
for annotation_line in tqdm(self.val_lines):
line = annotation_line.split()
image_id = os.path.basename(line[0]).split('.')[0]
#------------------------------#
#   Read the image and convert it to RGB
#------------------------------#
image = Image.open(line[0])
#------------------------------#
#   Obtain the ground-truth boxes
#------------------------------#
gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
#------------------------------#
#   Produce the prediction txt
#------------------------------#
self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

#------------------------------#
#   Produce the ground-truth txt
#------------------------------#
with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
for box in gt_boxes:
left, top, right, bottom, obj = box
obj_name = self.class_names[obj]
new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

print("Calculate Map.")
try:
temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
except:
temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
self.maps.append(temp_map)
self.epoches.append(epoch)

with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
f.write(str(temp_map))
f.write("\n")

plt.figure()
plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Map %s'%str(self.MINOVERLAP))
plt.title('A Map Curve')
plt.legend(loc="upper right")

plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
plt.cla()
plt.close("all")

print("Get map done.")
shutil.rmtree(self.map_out_path)
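A brief note on what the new callback produces, plus a sketch of reading its output back (the sketch is not part of the commit). Every eval_period epochs, EvalCallback writes detection results ("class score left top right bottom" per line) and ground truth ("class left top right bottom" per line) into the temporary .temp_map_out/ directory, hence the new .gitignore entry, computes the validation mAP with get_coco_map when pycocotools is available and falls back to get_map otherwise, appends the value to epoch_map.txt in log_dir, redraws epoch_map.png, and finally removes the temporary directory. Assuming that layout, the recorded curve can be re-read like this (the directory name below is hypothetical):

import os
import matplotlib.pyplot as plt

log_dir = "logs/loss_2022_05_11_20_00_00"   # hypothetical run directory
with open(os.path.join(log_dir, "epoch_map.txt")) as f:
    maps = [float(line) for line in f if line.strip()]   # first entry is the initial 0

plt.plot(range(len(maps)), maps)
plt.xlabel("Evaluation index")
plt.ylabel("mAP")
plt.savefig("epoch_map_replot.png")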
3 changes: 2 additions & 1 deletion utils/utils_fit.py
@@ -6,7 +6,7 @@
from utils.utils import get_lr


def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
loss = 0
val_loss = 0

@@ -120,6 +120,7 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
pbar.close()
print('Finish Validation')
loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
eval_callback.on_epoch_end(epoch + 1, model_train)
print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
