Add Gains chart and Lift chart (#71)
* add unfinished plot_cumulative_gain and some stylefixes

* write out interface for important helper function cumulative_gain_curve

* finish plot_cumulative_gain and add example

* add tests to plot_cumulative_gain

* add plot_lift_curve, tests, and example

* add example images and new metrics in docs
reiinakano authored Oct 25, 2017
1 parent 7d39a7b commit 3378c28
Showing 8 changed files with 389 additions and 21 deletions.
Binary file added docs/_static/examples/plot_cumulative_gain.png
Binary file added docs/_static/examples/plot_lift_curve.png
2 changes: 1 addition & 1 deletion docs/metrics.rst
@@ -5,4 +5,4 @@ Metrics Module (API Reference)
==============================

.. automodule:: scikitplot.metrics
-:members: plot_confusion_matrix, plot_roc_curve, plot_ks_statistic, plot_precision_recall_curve, plot_silhouette, plot_calibration_curve
+:members: plot_confusion_matrix, plot_roc_curve, plot_ks_statistic, plot_precision_recall_curve, plot_silhouette, plot_calibration_curve, plot_cumulative_gain, plot_lift_curve
17 changes: 17 additions & 0 deletions examples/plot_cumulative_gain.py
@@ -0,0 +1,17 @@
"""
An example showing the plot_cumulative_gain method
used with a scikit-learn classifier.
"""
from __future__ import absolute_import
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_breast_cancer as load_data
import scikitplot as skplt


X, y = load_data(return_X_y=True)
lr = LogisticRegression()
lr.fit(X, y)
probas = lr.predict_proba(X)
skplt.metrics.plot_cumulative_gain(y_true=y, y_probas=probas)
plt.show()
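For comparison, a minimal sketch of the same example evaluated on a held-out split rather than on the training data; the train/test split and random_state are illustrative additions, not part of this commit:

from __future__ import absolute_import
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_breast_cancer as load_data
from sklearn.model_selection import train_test_split
import scikitplot as skplt

# Fit on a training split and plot cumulative gains on unseen test data,
# which gives a less optimistic picture than scoring the training set.
X, y = load_data(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
lr = LogisticRegression()
lr.fit(X_train, y_train)
probas = lr.predict_proba(X_test)
skplt.metrics.plot_cumulative_gain(y_true=y_test, y_probas=probas)
plt.show()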
17 changes: 17 additions & 0 deletions examples/plot_lift_curve.py
@@ -0,0 +1,17 @@
"""
An example showing the plot_lift_curve method
used with a scikit-learn classifier.
"""
from __future__ import absolute_import
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_breast_cancer as load_data
import scikitplot as skplt


X, y = load_data(return_X_y=True)
lr = LogisticRegression()
lr.fit(X, y)
probas = lr.predict_proba(X)
skplt.metrics.plot_lift_curve(y_true=y, y_probas=probas)
plt.show()
105 changes: 85 additions & 20 deletions scikitplot/helpers.py
@@ -1,38 +1,41 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
+from __future__ import absolute_import, division, print_function, \
+unicode_literals
import numpy as np
from sklearn.preprocessing import LabelEncoder


def binary_ks_curve(y_true, y_probas):
"""This function generates the points necessary to calculate the KS Statistic curve.
"""This function generates the points necessary to calculate the KS
Statistic curve.
Args:
y_true (array-like, shape (n_samples)): True labels of the data.
-y_probas (array-like, shape (n_samples)): Probability predictions of the positive class.
+y_probas (array-like, shape (n_samples)): Probability predictions of
+the positive class.
Returns:
-thresholds (numpy.ndarray): An array containing the X-axis values for plotting the
-KS Statistic plot.
+thresholds (numpy.ndarray): An array containing the X-axis values for
+plotting the KS Statistic plot.
-pct1 (numpy.ndarray): An array containing the Y-axis values for one curve of the
-KS Statistic plot.
+pct1 (numpy.ndarray): An array containing the Y-axis values for one
+curve of the KS Statistic plot.
-pct2 (numpy.ndarray): An array containing the Y-axis values for one curve of the
-KS Statistic plot.
+pct2 (numpy.ndarray): An array containing the Y-axis values for one
+curve of the KS Statistic plot.
-ks_statistic (float): The KS Statistic, or the maximum vertical distance between the
-two curves.
+ks_statistic (float): The KS Statistic, or the maximum vertical
+distance between the two curves.
-max_distance_at (float): The X-axis value at which the maximum vertical distance between
-the two curves is seen.
+max_distance_at (float): The X-axis value at which the maximum vertical
+distance between the two curves is seen.
-classes (np.ndarray, shape (2)): An array containing the labels of the two classes making
-up `y_true`.
+classes (np.ndarray, shape (2)): An array containing the labels of the
+two classes making up `y_true`.
Raises:
-ValueError: If `y_true` is not composed of 2 classes. The KS Statistic is only relevant in
-binary classification.
+ValueError: If `y_true` is not composed of 2 classes. The KS Statistic
+is only relevant in binary classification.
"""
y_true, y_probas = np.asarray(y_true), np.asarray(y_probas)
lb = LabelEncoder()
@@ -96,7 +99,8 @@ def binary_ks_curve(y_true, y_probas):
pct2 = np.append(pct2, [1.0])

differences = pct1 - pct2
-ks_statistic, max_distance_at = np.max(differences), thresholds[np.argmax(differences)]
+ks_statistic, max_distance_at = (np.max(differences),
+thresholds[np.argmax(differences)])

return thresholds, pct1, pct2, ks_statistic, max_distance_at, lb.classes_

@@ -141,8 +145,69 @@ def validate_labels(known_classes, passed_labels, argument_name):
if np.any(passed_labels_absent):
absent_labels = [str(x) for x in passed_labels[passed_labels_absent]]

msg = "The following labels were passed into {0}, but were not found in labels: {1}" \
.format(argument_name, ", ".join(absent_labels))
msg = ("The following labels "
"were passed into {0}, "
"but were not found in "
"labels: {1}").format(argument_name, ", ".join(absent_labels))
raise ValueError(msg)

return


def cumulative_gain_curve(y_true, y_score, pos_label=None):
"""This function generates the points necessary to plot the Cumulative Gain
Note: This implementation is restricted to the binary classification task.
Args:
y_true (array-like, shape (n_samples)): True labels of the data.
y_score (array-like, shape (n_samples)): Target scores, can either be
probability estimates of the positive class, confidence values, or
non-thresholded measure of decisions (as returned by
decision_function on some classifiers).
pos_label (int or str, default=None): Label considered as positive;
all other labels are considered negative.
Returns:
percentages (numpy.ndarray): An array containing the X-axis values for
plotting the Cumulative Gains chart.
gains (numpy.ndarray): An array containing the Y-axis values for one
curve of the Cumulative Gains chart.
Raises:
ValueError: If `y_true` is not composed of 2 classes. The Cumulative
Gain Chart is only relevant in binary classification.
"""
y_true, y_score = np.asarray(y_true), np.asarray(y_score)

# ensure binary classification if pos_label is not specified
classes = np.unique(y_true)
if (pos_label is None and
not (np.array_equal(classes, [0, 1]) or
np.array_equal(classes, [-1, 1]) or
np.array_equal(classes, [0]) or
np.array_equal(classes, [-1]) or
np.array_equal(classes, [1]))):
raise ValueError("Data is not binary and pos_label is not specified")
elif pos_label is None:
pos_label = 1.

# make y_true a boolean vector
y_true = (y_true == pos_label)

sorted_indices = np.argsort(y_score)[::-1]
y_true = y_true[sorted_indices]
gains = np.cumsum(y_true)

percentages = np.arange(start=1, stop=len(y_true) + 1)

gains = gains / float(np.sum(y_true))
percentages = percentages / float(len(y_true))

gains = np.insert(gains, 0, [0])
percentages = np.insert(percentages, 0, [0])

return percentages, gains
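For reference, a minimal sketch (not part of this commit) of calling the new helper directly on a handful of made-up scores, to show the shape of the arrays it returns:

import numpy as np
from scikitplot.helpers import cumulative_gain_curve

# Illustrative labels and positive-class scores (made-up values)
y_true = np.array([1, 1, 0, 1, 0, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.3, 0.2])

percentages, gains = cumulative_gain_curve(y_true, y_score, pos_label=1)
# percentages runs from 0 to 1 in steps of 1/len(y_true):
#   [0, 1/6, 2/6, 3/6, 4/6, 5/6, 1]
# gains is the fraction of positives captured so far, here
#   [0, 1/3, 2/3, 2/3, 1, 1, 1], reaching 1.0 once every positive
#   sample appears among the top-ranked predictions.
print(percentages, gains)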
181 changes: 181 additions & 0 deletions scikitplot/metrics.py
@@ -26,6 +26,7 @@
from scipy import interp

from scikitplot.helpers import binary_ks_curve, validate_labels
+from scikitplot.helpers import cumulative_gain_curve


def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
@@ -780,3 +781,183 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
ax.legend(loc='lower right')

return ax


def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
ax=None, figsize=None, title_fontsize="large",
text_fontsize="medium"):
"""Generates the Cumulative Gains Plot from labels and scores/probabilities
The cumulative gains chart is used to determine the effectiveness of a
binary classifier. A detailed explanation can be found at
http://mlwiki.org/index.php/Cumulative_Gain_Chart. The implementation
here works only for binary classification.
Args:
y_true (array-like, shape (n_samples)):
Ground truth (correct) target values.
y_probas (array-like, shape (n_samples, n_classes)):
Prediction probabilities for each class returned by a classifier.
title (string, optional): Title of the generated plot. Defaults to
"Cumulative Gains Curve".
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
plot the curve. If None, the plot is drawn on a new set of axes.
figsize (2-tuple, optional): Tuple denoting figure size of the plot
e.g. (6, 6). Defaults to ``None``.
title_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"large".
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"medium".
Returns:
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
drawn.
Example:
>>> import scikitplot as skplt
>>> lr = LogisticRegression()
>>> lr = lr.fit(X_train, y_train)
>>> y_probas = lr.predict_proba(X_test)
>>> skplt.metrics.plot_cumulative_gain(y_test, y_probas)
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
>>> plt.show()
.. image:: _static/examples/plot_cumulative_gain.png
:align: center
:alt: Cumulative Gains Plot
"""
y_true = np.array(y_true)
y_probas = np.array(y_probas)

classes = np.unique(y_true)
if len(classes) != 2:
raise ValueError('Cannot calculate Cumulative Gains for data with '
'{} category/ies'.format(len(classes)))

# Compute Cumulative Gain Curves
percentages, gains1 = cumulative_gain_curve(y_true, y_probas[:, 0],
classes[0])
percentages, gains2 = cumulative_gain_curve(y_true, y_probas[:, 1],
classes[1])

if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)

ax.set_title(title, fontsize=title_fontsize)

ax.plot(percentages, gains1, lw=3, label='Class {}'.format(classes[0]))
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(classes[1]))

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.0])

ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Baseline')

ax.set_xlabel('Percentage of sample', fontsize=text_fontsize)
ax.set_ylabel('Gain', fontsize=text_fontsize)
ax.tick_params(labelsize=text_fontsize)
ax.grid('on')
ax.legend(loc='lower right', fontsize=text_fontsize)

return ax
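As a worked illustration of reading the resulting chart (the numbers are illustrative, not taken from the example images): with samples sorted by predicted probability for a class, a point at (0.2, 0.6) on that class's curve means the top 20% of samples contain 60% of all members of that class, while the dashed diagonal is the baseline expected from a random ordering.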


def plot_lift_curve(y_true, y_probas, title='Lift Curve',
ax=None, figsize=None, title_fontsize="large",
text_fontsize="medium"):
"""Generates the Lift Curve from labels and scores/probabilities
The lift curve is used to determine the effectiveness of a
binary classifier. A detailed explanation can be found at
http://www2.cs.uregina.ca/~dbd/cs831/notes/lift_chart/lift_chart.html.
The implementation here works only for binary classification.
Args:
y_true (array-like, shape (n_samples)):
Ground truth (correct) target values.
y_probas (array-like, shape (n_samples, n_classes)):
Prediction probabilities for each class returned by a classifier.
title (string, optional): Title of the generated plot. Defaults to
"Lift Curve".
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
plot the curve. If None, the plot is drawn on a new set of axes.
figsize (2-tuple, optional): Tuple denoting figure size of the plot
e.g. (6, 6). Defaults to ``None``.
title_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"large".
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"medium".
Returns:
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
drawn.
Example:
>>> import scikitplot as skplt
>>> lr = LogisticRegression()
>>> lr = lr.fit(X_train, y_train)
>>> y_probas = lr.predict_proba(X_test)
>>> skplt.metrics.plot_lift_curve(y_test, y_probas)
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
>>> plt.show()
.. image:: _static/examples/plot_lift_curve.png
:align: center
:alt: Lift Curve
"""
y_true = np.array(y_true)
y_probas = np.array(y_probas)

classes = np.unique(y_true)
if len(classes) != 2:
raise ValueError('Cannot calculate Lift Curve for data with '
'{} category/ies'.format(len(classes)))

# Compute Cumulative Gain Curves
percentages, gains1 = cumulative_gain_curve(y_true, y_probas[:, 0],
classes[0])
percentages, gains2 = cumulative_gain_curve(y_true, y_probas[:, 1],
classes[1])

percentages = percentages[1:]
gains1 = gains1[1:]
gains2 = gains2[1:]

gains1 = gains1 / percentages
gains2 = gains2 / percentages

if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)

ax.set_title(title, fontsize=title_fontsize)

ax.plot(percentages, gains1, lw=3, label='Class {}'.format(classes[0]))
ax.plot(percentages, gains2, lw=3, label='Class {}'.format(classes[1]))

ax.plot([0, 1], [1, 1], 'k--', lw=2, label='Baseline')

ax.set_xlabel('Percentage of sample', fontsize=text_fontsize)
ax.set_ylabel('Lift', fontsize=text_fontsize)
ax.tick_params(labelsize=text_fontsize)
ax.grid('on')
ax.legend(loc='lower right', fontsize=text_fontsize)

return ax
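Since the lift at each sample percentage is simply the cumulative gain divided by that percentage (as computed above), the two charts present the same ranking information on different scales. Both new functions accept an existing ``ax``, so they can share one figure; a sketch reusing the breast-cancer example data from this commit's examples (the side-by-side layout is an illustration, not part of the commit):

import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_breast_cancer as load_data
import scikitplot as skplt

X, y = load_data(return_X_y=True)
probas = LogisticRegression().fit(X, y).predict_proba(X)

# Draw the gains chart and the lift chart on pre-created axes in one figure.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
skplt.metrics.plot_cumulative_gain(y, probas, ax=ax1)
skplt.metrics.plot_lift_curve(y, probas, ax=ax2)
plt.show()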