
Commit

Metric sugar (#484)
* Added indexing to Metric, with tests

* Added confusion matrix to docs

* Update requirements.txt

* Update metrics.rst
anmolsjoshi authored and vfdev-5 committed Apr 9, 2019
1 parent 6b8b16b commit f098795
Showing 4 changed files with 81 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -1,2 +1,2 @@
-sphinx
+sphinx==1.8.5
-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme
18 changes: 17 additions & 1 deletion docs/source/metrics.rst
@@ -41,7 +41,7 @@ Metrics could be combined together to form new metrics. This could be done through arithmetics, such
as ``metric1 + metric2``, use PyTorch operators, such as ``(metric1 + metric2).pow(2).mean()``,
or use a lambda function, such as ``MetricsLambda(lambda a, b: torch.mean(a + b), metric1, metric2)``.

-for example:
+For example:

.. code-block:: python
@@ -54,6 +54,16 @@ for example:
that `average=False`, i.e. to use the unaveraged precision and recall,
otherwise we will not be computing F-beta metrics.
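
For illustration, here is a minimal sketch of the kind of composition described above (not the collapsed example from this diff): F1 built from unaveraged Precision and Recall via metric arithmetic. The small epsilon is an assumption to avoid division by zero.

    from ignite.metrics import Precision, Recall

    precision = Precision(average=False)   # per-class precision vector
    recall = Recall(average=False)         # per-class recall vector
    # Element-wise per-class F1, then averaged over classes.
    f1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()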

Metrics also support an indexing operation (if the metric's result is a vector/matrix/tensor). For example, this can be useful to compute the mean of a metric (e.g. precision, recall or IoU) while ignoring the background class:

.. code-block:: python

    cm = ConfusionMatrix(num_classes=10)
    iou_metric = IoU(cm)
    iou_no_bg_metric = iou_metric[:9]  # We assume that the background index is 9
    mean_iou_no_bg_metric = iou_no_bg_metric.mean()
    # mean_iou_no_bg_metric.compute() -> tensor(0.12345)
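
A hedged usage sketch (not part of this diff) of how the composed metric above would typically be consumed; the ``evaluator`` engine and ``val_loader`` below are assumed to be defined elsewhere:

    # Hypothetical usage: attach the composed metric to an evaluator Engine
    # whose update function returns (y_pred, y) for the ConfusionMatrix.
    mean_iou_no_bg_metric.attach(evaluator, 'mean_iou_no_bg')
    state = evaluator.run(val_loader)
    print(state.metrics['mean_iou_no_bg'])  # e.g. tensor(0.12345)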
.. currentmodule:: ignite.metrics

@@ -83,3 +93,9 @@ for example:
.. autoclass:: RunningAverage

.. autoclass:: MetricsLambda

.. autoclass:: ConfusionMatrix

.. autofunction:: IoU

.. autofunction:: mIoU
4 changes: 4 additions & 0 deletions ignite/metrics/metric.py
@@ -142,3 +142,7 @@ def fn(x, *args, **kwargs):
        def wrapper(*args, **kwargs):
            return MetricsLambda(fn, self, *args, **kwargs)
        return wrapper

    def __getitem__(self, index):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x: x[index], self)
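
A brief usage sketch of the new ``__getitem__`` (illustrative only): indexing a metric yields a ``MetricsLambda`` that applies the same index to the computed result. The ``evaluator`` engine below is assumed.

    # Illustrative only: keep classes 0 and 2 of a per-class recall vector.
    recall = Recall(average=False)        # computes a per-class tensor
    recall_0_2 = recall[[0, 2]]           # MetricsLambda indexing the result
    recall_0_2.attach(evaluator, 'recall_0_2')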
61 changes: 59 additions & 2 deletions tests/ignite/metrics/test_metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import sys
-from ignite.metrics import Metric, Precision, Recall
+from ignite.metrics import Metric, Precision, Recall, ConfusionMatrix
from ignite.engine import Engine, State
import torch
from mock import MagicMock

from pytest import approx, raises
import numpy as np
-from sklearn.metrics import precision_score, recall_score, f1_score
+from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


def test_no_transform():
@@ -388,3 +388,60 @@ def compute_f1(y_pred, y):
        return f1

    _test(f1, "f1", compute_true_value_fn=compute_f1)


def test_indexing_metric():
    def _test(ignite_metric, sklearn_metic, sklearn_args, index, num_classes=5):
        y_pred = torch.rand(15, 10, num_classes).float()
        y = torch.randint(0, num_classes, size=(15, 10)).long()

        def update_fn(engine, batch):
            y_pred, y = batch
            return y_pred, y

        metrics = {'metric': ignite_metric[index],
                   'metric_wo_index': ignite_metric}

        validator = Engine(update_fn)

        for name, metric in metrics.items():
            metric.attach(validator, name)

        def data(y_pred, y):
            for i in range(y_pred.shape[0]):
                yield (y_pred[i], y[i])

        d = data(y_pred, y)
        state = validator.run(d, max_epochs=1)

        sklearn_output = sklearn_metic(y.view(-1).numpy(),
                                       y_pred.view(-1, num_classes).argmax(dim=1).numpy(),
                                       **sklearn_args)

        assert (state.metrics['metric_wo_index'][index] == state.metrics['metric']).all()
        assert (np.allclose(state.metrics['metric'].numpy(), sklearn_output))

    num_classes = 5

    labels = list(range(0, num_classes, 2))
    _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels)
    labels = list(range(num_classes - 1, 0, -2))
    _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels)
    labels = [1]
    _test(Precision(), precision_score, {'labels': labels, 'average': None}, index=labels)

    labels = list(range(0, num_classes, 2))
    _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels)
    labels = list(range(num_classes - 1, 0, -2))
    _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels)
    labels = [1]
    _test(Recall(), recall_score, {'labels': labels, 'average': None}, index=labels)

    # np.ix_ is used to allow a 2D slice of a matrix. This is required to get an accurate result from
    # ConfusionMatrix: it must be sliced the same way row-wise and column-wise.
    labels = list(range(0, num_classes, 2))
    _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))
    labels = list(range(num_classes - 1, 0, -2))
    _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))
    labels = [1]
    _test(ConfusionMatrix(num_classes), confusion_matrix, {'labels': labels}, index=np.ix_(labels, labels))
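
For readers unfamiliar with ``np.ix_``, a small standalone illustration (not part of the test) of the open-mesh indexing used above:

    import numpy as np

    m = np.arange(16).reshape(4, 4)
    sub = m[np.ix_([0, 2], [0, 2])]   # rows 0 and 2 AND columns 0 and 2
    # sub -> [[ 0,  2],
    #         [ 8, 10]]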
