Skip to content

Commit

Permalink
Add xtick rotation (#38)
Browse files Browse the repository at this point in the history
* add xtick rotation

* fix bug in np.logical_not
  • Loading branch information
reiinakano authored Jul 9, 2017
1 parent 477871d commit 1d1dc99
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion scikitplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals
__version__ = '0.2.6'
__version__ = '0.2.7'


from scikitplot.classifiers import classifier_factory
Expand Down
10 changes: 7 additions & 3 deletions scikitplot/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def classifier_factory(clf):
return clf


def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, do_cv=True, cv=None,
shuffle=True, random_state=None, ax=None, figsize=None,
def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, x_tick_rotation=0,
do_cv=True, cv=None, shuffle=True, random_state=None, ax=None, figsize=None,
title_fontsize="large", text_fontsize="medium"):
"""Generates the confusion matrix for a given classifier and dataset.
Expand All @@ -79,6 +79,9 @@ def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, d
normalize (bool, optional): If True, normalizes the confusion matrix before plotting.
Defaults to False.
x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is
useful in cases where there are numerous categories and the labels overlap each other.
do_cv (bool, optional): If True, the classifier is cross-validated on the dataset using the
cross-validation strategy in `cv` to generate the confusion matrix. If False, the
confusion matrix is generated without training or cross-validating the classifier.
Expand Down Expand Up @@ -158,7 +161,8 @@ def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, d
y_true = np.concatenate(trues_list)

ax = plotters.plot_confusion_matrix(y_true=y_true, y_pred=y_pred, labels=labels,
title=title, normalize=normalize, ax=ax, figsize=figsize,
title=title, normalize=normalize,
x_tick_rotation=x_tick_rotation, ax=ax, figsize=figsize,
title_fontsize=title_fontsize, text_fontsize=text_fontsize)

return ax
Expand Down
2 changes: 1 addition & 1 deletion scikitplot/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def binary_ks_curve(y_true, y_probas):
'{} category/ies'.format(len(lb.classes_)))
idx = encoded_labels == 0
data1 = np.sort(y_probas[idx])
data2 = np.sort(y_probas[-idx])
data2 = np.sort(y_probas[np.logical_not(idx)])

ctr1, ctr2 = 0, 0
thresholds, pct1, pct2 = [], [], []
Expand Down
9 changes: 6 additions & 3 deletions scikitplot/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from sklearn.metrics import silhouette_samples


def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=False, ax=None,
figsize=None, title_fontsize="large", text_fontsize="medium"):
def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=False, x_tick_rotation=0,
ax=None, figsize=None, title_fontsize="large", text_fontsize="medium"):
"""Generates confusion matrix plot for a given set of ground truth labels and classifier predictions.
Args:
Expand All @@ -44,6 +44,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal
normalize (bool, optional): If True, normalizes the confusion matrix before plotting.
Defaults to False.
x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is
useful in cases where there are numerous categories and the labels overlap each other.
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to plot
the learning curve. If None, the plot is drawn on a new set of axes.
Expand Down Expand Up @@ -96,7 +99,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal
plt.colorbar(mappable=image)
tick_marks = np.arange(len(classes))
ax.set_xticks(tick_marks)
ax.set_xticklabels(classes, fontsize=text_fontsize)
ax.set_xticklabels(classes, fontsize=text_fontsize, rotation=x_tick_rotation)
ax.set_yticks(tick_marks)
ax.set_yticklabels(classes, fontsize=text_fontsize)

Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import os
import sys

import scikitplot

here = os.path.abspath(os.path.dirname(__file__))


Expand Down Expand Up @@ -37,7 +35,7 @@ def run_tests(self):

setup(
name='scikit-plot',
version=scikitplot.__version__,
version='0.2.7',
url='https://github.com/reiinakano/scikit-plot',
license='MIT License',
author='Reiichiro Nakano',
Expand Down

0 comments on commit 1d1dc99

Please sign in to comment.