Commit
Added Polyaxon logger code with tests (#482)
* Added Polyaxon logger code with tests
* Improved code and updated docs
* Added no package tests
* Added PolyaxonLogger in handlers module
* Fixed docs warning
Showing 9 changed files with 384 additions and 3 deletions.
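In brief, the new module mirrors the other ignite.contrib loggers: create a PolyaxonLogger, then attach OutputHandler instances (or custom callables) to engine events. Below is a minimal sketch condensed from the docstring examples in the diff; the dummy update function, the "training" tag and the POLYAXON_NO_OP switch are illustrative assumptions, not part of the commit.

import os
os.environ["POLYAXON_NO_OP"] = "1"  # let polyaxon-client run without a Polyaxon deployment, as the tests below do

from ignite.engine import Engine, Events
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger, OutputHandler

# Dummy trainer whose output is a dict with a fake "loss" value.
trainer = Engine(lambda engine, batch: {"loss": batch * 0.01})

plx_logger = PolyaxonLogger()

# Forwarded to polyaxon_client's Experiment.log_params via PolyaxonLogger.__getattr__.
plx_logger.log_params(seed=12, batch_size=64)

# Log the trainer output under the "training" tag after every iteration.
plx_logger.attach(trainer,
                  log_handler=OutputHandler(tag="training",
                                            output_transform=lambda out: out),
                  event_name=Events.ITERATION_COMPLETED)

trainer.run(list(range(100)), max_epochs=2)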
ignite/contrib/handlers/polyaxon_logger.py
@@ -0,0 +1,133 @@
import numbers

import warnings
import torch

from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler


__all__ = ['PolyaxonLogger', 'OutputHandler']


class OutputHandler(BaseOutputHandler):
    """Helper handler to log engine's output and/or metrics.

    Examples:

        .. code-block:: python

            from ignite.contrib.handlers.polyaxon_logger import *

            # Create a logger
            plx_logger = PolyaxonLogger()

            # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
            # each epoch. We setup `another_engine=trainer` to take the epoch of the `trainer`.
            plx_logger.attach(evaluator,
                              log_handler=OutputHandler(tag="validation",
                                                        metric_names=["nll", "accuracy"],
                                                        another_engine=trainer),
                              event_name=Events.EPOCH_COMPLETED)

    Args:
        tag (str): common title for all produced plots. For example, 'training'.
        metric_names (list of str, optional): list of metric names to plot.
        output_transform (callable, optional): output transform function to prepare `engine.state.output` as a number.
            For example, `output_transform = lambda output: output`.
            This function can also return a dictionary, e.g. `{'loss': loss1, 'another_loss': loss2}`, to label the
            plot with the corresponding keys.
        another_engine (Engine): another engine to use to provide the value of the event. Typically, the user can
            provide the trainer if this handler is attached to an evaluator, so that the proper trainer
            epoch/iteration value is logged.
    """

    def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None):
        super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine)

    def __call__(self, engine, logger, event_name):

        if not isinstance(logger, PolyaxonLogger):
            raise RuntimeError("Handler 'OutputHandler' works only with PolyaxonLogger")

        metrics = self._setup_output_metrics(engine)

        state = engine.state if self.another_engine is None else self.another_engine.state
        global_step = state.get_event_attrib_value(event_name)

        rendered_metrics = {"step": global_step}
        for key, value in metrics.items():
            if isinstance(value, numbers.Number):
                rendered_metrics["{}/{}".format(self.tag, key)] = value
            elif isinstance(value, torch.Tensor) and value.ndimension() == 0:
                rendered_metrics["{}/{}".format(self.tag, key)] = value.item()
            elif isinstance(value, torch.Tensor) and value.ndimension() == 1:
                for i, v in enumerate(value):
                    rendered_metrics["{}/{}/{}".format(self.tag, key, i)] = v.item()
            else:
                warnings.warn("PolyaxonLogger output_handler can not log "
                              "metrics value type {}".format(type(value)))

        logger.log_metrics(**rendered_metrics)


class PolyaxonLogger(BaseLogger):
    """
    `Polyaxon <https://polyaxon.com/>`_ tracking client handler to log parameters and metrics during training
    and validation.

    This class requires the `polyaxon-client <https://github.com/polyaxon/polyaxon-client/>`_ package to be installed:

    .. code-block:: bash

        pip install polyaxon-client

    Examples:

        .. code-block:: python

            from ignite.contrib.handlers.polyaxon_logger import *

            # Create a logger
            plx_logger = PolyaxonLogger()

            # Log experiment parameters:
            plx_logger.log_params(**{
                "seed": seed,
                "batch_size": batch_size,
                "model": model.__class__.__name__,
                "pytorch version": torch.__version__,
                "ignite version": ignite.__version__,
                "cuda version": torch.version.cuda,
                "device name": torch.cuda.get_device_name(0)
            })

            # Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each
            # epoch. We setup `another_engine=trainer` to take the epoch of the `trainer` instead of `train_evaluator`.
            plx_logger.attach(train_evaluator,
                              log_handler=OutputHandler(tag="training",
                                                        metric_names=["nll", "accuracy"],
                                                        another_engine=trainer),
                              event_name=Events.EPOCH_COMPLETED)

            # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
            # each epoch. We setup `another_engine=trainer` to take the epoch of the `trainer`.
            plx_logger.attach(evaluator,
                              log_handler=OutputHandler(tag="validation",
                                                        metric_names=["nll", "accuracy"],
                                                        another_engine=trainer),
                              event_name=Events.EPOCH_COMPLETED)
    """

    def __init__(self):
        try:
            from polyaxon_client.tracking import Experiment
        except ImportError:
            raise RuntimeError("This contrib module requires polyaxon-client to be installed. "
                               "Please install it with command: \n pip install polyaxon-client")

        self.experiment = Experiment()

    def __getattr__(self, attr):

        def wrapper(*args, **kwargs):
            return getattr(self.experiment, attr)(*args, **kwargs)

        return wrapper
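The PolyaxonLogger class itself stays tiny: its __getattr__ returns a wrapper that forwards any method call to the underlying polyaxon_client.tracking.Experiment, which is how log_params and log_metrics become available without explicit definitions. The following self-contained sketch shows that delegation pattern in isolation; FakeExperiment and Delegator are illustrative stand-ins, not part of the commit.

class FakeExperiment:
    def log_params(self, **params):
        print("params:", params)

    def log_metrics(self, **metrics):
        print("metrics:", metrics)


class Delegator:
    def __init__(self, experiment):
        self.experiment = experiment

    def __getattr__(self, attr):
        # Called only for attributes not found through normal lookup;
        # returns a function that forwards the call to the wrapped experiment.
        def wrapper(*args, **kwargs):
            return getattr(self.experiment, attr)(*args, **kwargs)
        return wrapper


logger = Delegator(FakeExperiment())
logger.log_params(seed=12, batch_size=64)            # forwarded to FakeExperiment.log_params
logger.log_metrics(step=1, **{"train/loss": 0.25})   # forwarded to FakeExperiment.log_metrics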
Tests for the Polyaxon logger (new test file)
@@ -0,0 +1,204 @@
import os
import tempfile
import shutil

import pytest

from mock import MagicMock, call

import torch

from ignite.engine import Engine, Events, State
from ignite.contrib.handlers.polyaxon_logger import *

os.environ['POLYAXON_NO_OP'] = "1"


def test_output_handler_with_wrong_logger_type():

    wrapper = OutputHandler("tag", output_transform=lambda x: x)

    mock_logger = MagicMock()
    mock_engine = MagicMock()
    with pytest.raises(RuntimeError, match="Handler 'OutputHandler' works only with PolyaxonLogger"):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)


def test_output_handler_output_transform():

    wrapper = OutputHandler("tag", output_transform=lambda x: x)
    mock_logger = MagicMock(spec=PolyaxonLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.output = 12345
    mock_engine.state.iteration = 123

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    mock_logger.log_metrics.assert_called_once_with(step=123, **{"tag/output": 12345})

    wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
    mock_logger = MagicMock(spec=PolyaxonLogger)
    mock_logger.log_metrics = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    mock_logger.log_metrics.assert_called_once_with(step=123, **{"another_tag/loss": 12345})


def test_output_handler_metric_names():

    wrapper = OutputHandler("tag", metric_names=["a", "b", "c"])
    mock_logger = MagicMock(spec=PolyaxonLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45, "c": torch.tensor(10.0)})
    mock_engine.state.iteration = 5

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_called_once_with(
        step=5,
        **{"tag/a": 12.23,
           "tag/b": 23.45,
           "tag/c": 10.0}
    )

    wrapper = OutputHandler("tag", metric_names=["a", ])

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])})
    mock_engine.state.iteration = 5

    mock_logger = MagicMock(spec=PolyaxonLogger)
    mock_logger.log_metrics = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls([
        call(step=5,
             **{"tag/a/0": 0.0,
                "tag/a/1": 1.0,
                "tag/a/2": 2.0,
                "tag/a/3": 3.0}),
    ], any_order=True)

    wrapper = OutputHandler("tag", metric_names=["a", "c"])

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 55.56, "c": "Some text"})
    mock_engine.state.iteration = 7

    mock_logger = MagicMock(spec=PolyaxonLogger)
    mock_logger.log_metrics = MagicMock()

    with pytest.warns(UserWarning):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls([
        call(step=7, **{"tag/a": 55.56})
    ], any_order=True)


def test_output_handler_both():

    wrapper = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    mock_logger = MagicMock(spec=PolyaxonLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_called_once_with(
        step=5,
        **{"tag/a": 12.23,
           "tag/b": 23.45,
           "tag/loss": 12345}
    )


def test_integration():

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    plx_logger = PolyaxonLogger()

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        logger.log_metrics(step=global_step, **{"{}".format("test_value"): global_step})

    plx_logger.attach(trainer,
                      log_handler=dummy_handler,
                      event_name=Events.EPOCH_COMPLETED)

    trainer.run(data, max_epochs=n_epochs)


def test_integration_as_context_manager():

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with PolyaxonLogger() as plx_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            logger.log_metrics(step=global_step, **{"{}".format("test_value"): global_step})

        plx_logger.attach(trainer,
                          log_handler=dummy_handler,
                          event_name=Events.EPOCH_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)


@pytest.fixture
def no_site_packages():
    import sys

    polyaxon_client_modules = {}
    for k in sys.modules:
        if "polyaxon" in k:
            polyaxon_client_modules[k] = sys.modules[k]
    for k in polyaxon_client_modules:
        del sys.modules[k]

    prev_path = list(sys.path)
    sys.path = [p for p in sys.path if "site-packages" not in p]
    yield "no_site_packages"
    sys.path = prev_path
    for k in polyaxon_client_modules:
        sys.modules[k] = polyaxon_client_modules[k]


def test_no_polyaxon_client(no_site_packages):

    with pytest.raises(RuntimeError, match=r"This contrib module requires polyaxon-client to be installed"):
        PolyaxonLogger()
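Two details in these tests are worth noting: POLYAXON_NO_OP=1 is set at import time so the Experiment client runs as a no-op without a real Polyaxon deployment, and the no_site_packages fixture simulates a missing dependency by dropping cached polyaxon modules and filtering site-packages out of sys.path. An equivalent, more surgical way to exercise the same ImportError branch (not what this commit does) is to stub the package out of sys.modules so the import inside PolyaxonLogger.__init__ fails; a sketch using pytest's monkeypatch, offered only as an assumption about an alternative approach:

import sys

import pytest


def test_no_polyaxon_client_alternative(monkeypatch):
    # Hide any cached polyaxon modules, then map the package name to None so that
    # `from polyaxon_client.tracking import Experiment` raises ModuleNotFoundError.
    for name in list(sys.modules):
        if name.startswith("polyaxon"):
            monkeypatch.delitem(sys.modules, name)
    monkeypatch.setitem(sys.modules, "polyaxon_client", None)

    from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger
    with pytest.raises(RuntimeError, match=r"This contrib module requires polyaxon-client to be installed"):
        PolyaxonLogger()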