Skip to content

Commit

Permalink
Improves #694 (#712)
Browse files Browse the repository at this point in the history
* Made all loggers public attributes
- Added a method to set up a logger

* Fixed flake8
  • Loading branch information
vfdev-5 authored Jan 21, 2020
1 parent a32edb9 commit ebd1876
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 25 deletions.
29 changes: 14 additions & 15 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,7 @@ def compute_mean_std(engine, batch):

def __init__(self, process_function):
self._event_handlers = defaultdict(list)
self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._logger.addHandler(logging.NullHandler())
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._process_function = process_function
self.last_event_name = None
self.should_terminate = False
Expand Down Expand Up @@ -489,14 +488,14 @@ def print_epoch(engine):
handler = Engine._handler_wrapper(handler, event_name, event_filter)

if event_name not in self._allowed_events:
self._logger.error("attempt to add event handler to an invalid event %s.", event_name)
self.logger.error("attempt to add event handler to an invalid event %s.", event_name)
raise ValueError("Event {} is not a valid event for this Engine.".format(event_name))

event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else ()
Engine._check_signature(self, handler, 'handler', *(event_args + args), **kwargs)

self._event_handlers[event_name].append((handler, args, kwargs))
self._logger.debug("added handler for event %s.", event_name)
self.logger.debug("added handler for event %s.", event_name)

return RemovableEventHandle(event_name, handler, self)

Expand Down Expand Up @@ -601,7 +600,7 @@ def _fire_event(self, event_name, *event_args, **event_kwargs):
"""
if event_name in self._allowed_events:
self._logger.debug("firing handlers for event %s ", event_name)
self.logger.debug("firing handlers for event %s ", event_name)
self.last_event_name = event_name
for func, args, kwargs in self._event_handlers[event_name]:
kwargs.update(event_kwargs)
Expand Down Expand Up @@ -633,14 +632,14 @@ def fire_event(self, event_name):
def terminate(self):
"""Sends terminate signal to the engine, so that it terminates completely the run after the current iteration.
"""
self._logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True

def terminate_epoch(self):
"""Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration.
"""
self._logger.info("Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished.")
self.logger.info("Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished.")
self.should_terminate_single_epoch = True

def _run_once_on_dataset(self):
Expand Down Expand Up @@ -702,7 +701,7 @@ def _run_once_on_dataset(self):
break

except BaseException as e:
self._logger.error("Current run is terminating due to exception: %s.", str(e))
self.logger.error("Current run is terminating due to exception: %s.", str(e))
self._handle_exception(e)

time_taken = time.time() - start_time
Expand Down Expand Up @@ -835,7 +834,7 @@ def switch_batch(engine):
else:
raise ValueError("Argument `epoch_length` should be defined if `data` is an iterator")
self.state = State(seed=seed, iteration=0, epoch=0, max_epochs=max_epochs, epoch_length=epoch_length)
self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
else:
# Keep actual state and override it if input args provided
if max_epochs is not None:
Expand All @@ -844,8 +843,8 @@ def switch_batch(engine):
self.state.seed = seed
if epoch_length is not None:
self.state.epoch_length = epoch_length
self._logger.info("Engine run resuming from iteration {}, epoch {} until {} epochs"
.format(self.state.iteration, self.state.epoch, self.state.max_epochs))
self.logger.info("Engine run resuming from iteration {}, epoch {} until {} epochs"
.format(self.state.iteration, self.state.epoch, self.state.max_epochs))

self.state.dataloader = data
return self._internal_run()
Expand Down Expand Up @@ -937,19 +936,19 @@ def _internal_run(self):

hours, mins, secs = self._run_once_on_dataset()

self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs)
self.logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs)
if self.should_terminate:
break
self._fire_event(Events.EPOCH_COMPLETED)

self._fire_event(Events.COMPLETED)
time_taken = time.time() - start_time
hours, mins, secs = _to_hours_mins_secs(time_taken)
self._logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs))
self.logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs))

except BaseException as e:
self._dataloader_iter = self._dataloader_len = None
self._logger.error("Engine run is terminating due to exception: %s.", str(e))
self.logger.error("Engine run is terminating due to exception: %s.", str(e))
self._handle_exception(e)

self._dataloader_iter = self._dataloader_len = None
Expand Down
7 changes: 3 additions & 4 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def __init__(self, patience, score_function, trainer, min_delta=0., cumulative_d
self.trainer = trainer
self.counter = 0
self.best_score = None
self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._logger.addHandler(logging.NullHandler())
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)

def __call__(self, engine):
score = self.score_function(engine)
Expand All @@ -71,9 +70,9 @@ def __call__(self, engine):
if not self.cumulative_delta and score > self.best_score:
self.best_score = score
self.counter += 1
self._logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
if self.counter >= self.patience:
self._logger.info("EarlyStopping: Stop training")
self.logger.info("EarlyStopping: Stop training")
self.trainer.terminate()
else:
self.best_score = score
Expand Down
8 changes: 4 additions & 4 deletions ignite/handlers/terminate_on_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class TerminateOnNan:
"""

def __init__(self, output_transform=lambda x: x):
self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._logger.addHandler(logging.StreamHandler())
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self.logger.addHandler(logging.StreamHandler())
self._output_transform = output_transform

def __call__(self, engine):
Expand All @@ -47,6 +47,6 @@ def raise_error(x):
try:
apply_to_type(output, (numbers.Number, torch.Tensor), raise_error)
except RuntimeError:
self._logger.warning("{}: Output '{}' contains NaN or Inf. Stop training"
.format(self.__class__.__name__, output))
self.logger.warning("{}: Output '{}' contains NaN or Inf. Stop training"
.format(self.__class__.__name__, output))
engine.terminate()
69 changes: 68 additions & 1 deletion ignite/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import torch
import os
import collections.abc as collections
import logging

import torch


def convert_tensor(input_, device=None, non_blocking=False):
Expand Down Expand Up @@ -41,3 +44,67 @@ def to_onehot(indices, num_classes):
dtype=torch.uint8,
device=indices.device)
return onehot.scatter_(1, indices.unsqueeze(1), 1)


def setup_logger(name, level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s",
                 filepath=None, distributed_rank=0):
    """Set up and return a logger with the requested name, level, format and optional file output.

    Args:
        name (str): name of the logger to configure.
        level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
        format (str): logging format. Default: `%(asctime)s %(name)s %(levelname)s: %(message)s`.
        filepath (str, optional): if given, log records are additionally written to this file.
        distributed_rank (int, optional): rank in a distributed configuration; on non-zero
            ranks the logger is returned untouched so workers stay quiet.
    Returns:
        logging.Logger
    For example, to improve logs readability when training with a trainer and evaluator:
    .. code-block:: python
        from ignite.utils import setup_logger
        trainer = ...
        evaluator = ...
        trainer.logger = setup_logger("trainer")
        evaluator.logger = setup_logger("evaluator")
        trainer.run(data, max_epochs=10)
        # Logs will look like
        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
        # ...
    """
    logger = logging.getLogger(name)

    # Distributed workers keep the bare logger: no level, no handlers.
    if distributed_rank > 0:
        return logger

    logger.setLevel(level)

    # Drop handlers attached by a previous call so repeated setup
    # does not duplicate output (iterate over a copy while removing).
    for existing_handler in list(logger.handlers):
        logger.removeHandler(existing_handler)

    formatter = logging.Formatter(format)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    # Optional second sink: a plain file handler sharing level and format.
    if filepath is not None:
        file_handler = logging.FileHandler(filepath)
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
47 changes: 46 additions & 1 deletion tests/ignite/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import logging
import pytest
import torch
from ignite.utils import convert_tensor, to_onehot

from ignite.utils import convert_tensor, to_onehot, setup_logger


def test_convert_tensor():
Expand Down Expand Up @@ -54,3 +57,45 @@ def test_to_onehot():
y_ohe = to_onehot(y, num_classes=21)
y2 = torch.argmax(y_ohe, dim=1)
assert y.equal(y2)


def test_dist_setup_logger():
    # On a non-zero rank setup_logger must return the logger untouched,
    # i.e. the requested level is never applied.
    worker_logger = setup_logger("trainer", level=logging.CRITICAL, distributed_rank=1)
    assert worker_logger.level != logging.CRITICAL


def test_setup_logger(capsys, dirname):

    from ignite.engine import Engine, Events

    train_engine = Engine(lambda e, b: None)
    eval_engine = Engine(lambda e, b: None)

    fp = os.path.join(dirname, "log")

    # A fresh engine logger carries no handlers; pile a few on so we can
    # verify that setup_logger clears them out again.
    assert len(train_engine.logger.handlers) == 0
    for _ in range(3):
        train_engine.logger.addHandler(logging.NullHandler())

    train_engine.logger = setup_logger("trainer", filepath=fp)
    eval_engine.logger = setup_logger("evaluator", filepath=fp)

    # Exactly one stream handler plus one file handler each.
    assert len(train_engine.logger.handlers) == 2
    assert len(eval_engine.logger.handlers) == 2

    @train_engine.on(Events.EPOCH_COMPLETED)
    def _(_):
        eval_engine.run([0, 1, 2])

    train_engine.run([0, 1, 2, 3, 4, 5], max_epochs=5)

    captured = capsys.readouterr()
    err = captured.err.split('\n')

    with open(fp, "r") as handle:
        data = handle.readlines()

    # Both sinks (stderr and the shared log file) must show the interleaved
    # trainer/evaluator records in the same order.
    for source in [err, data]:
        assert "trainer INFO: Engine run starting with max_epochs=5." in source[0]
        assert "evaluator INFO: Engine run starting with max_epochs=1." in source[2]

0 comments on commit ebd1876

Please sign in to comment.