diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index 0f89125859..7c71965262 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -17,6 +17,7 @@ from __future__ import print_function import os +import gc import sys import datetime import six @@ -179,6 +180,7 @@ def __init__(self, model): self.best_ap = -1000. self.save_dir = self.model.cfg.save_dir self.uniform_output_enabled = self.model.cfg.get("uniform_output_enabled", False) + self.export_during_train = self.model.cfg.get("export_during_train", False) if hasattr(self.model.model, 'student_model'): self.weight = self.model.model.student_model else: @@ -261,8 +263,9 @@ def on_epoch_end(self, status): save_name, epoch_id + 1, ema_model=weight) - if self.uniform_output_enabled: + if self.export_during_train: self.model.export(output_dir=os.path.join(self.save_dir, save_name, "inference"), for_fd=True) + gc.collect() else: # save model(student model) and ema_model(teacher model) # in DenseTeacher SSOD, the teacher model will be higher, @@ -281,8 +284,9 @@ def on_epoch_end(self, status): else: save_model(weight, self.model.optimizer, os.path.join(self.save_dir, save_name) if self.uniform_output_enabled else self.save_dir, save_name, epoch_id + 1) - if self.uniform_output_enabled: + if self.export_during_train: self.model.export(output_dir=os.path.join(self.save_dir, save_name, "inference"), for_fd=True) + gc.collect() class WiferFaceEval(Callback): diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 90addeb7d5..4b07c8dc28 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -77,7 +77,8 @@ def __init__(self, cfg, mode='train'): self.custom_black_list = self.cfg.get('custom_black_list', None) self.use_master_grad = self.cfg.get('master_grad', False) self.uniform_output_enabled = self.cfg.get('uniform_output_enabled', False) - if ('slim' in cfg and cfg['slim_type'] == 'PTQ') or self.uniform_output_enabled: + self.export_during_train = self.cfg.get('export_during_train', False) + if ('slim' in cfg and cfg['slim_type'] == 'PTQ') or self.export_during_train: self.cfg['TestDataset'] = create('TestDataset')() log_ranks = cfg.get('log_ranks', '0') if isinstance(log_ranks, str):