diff --git a/scikitplot/__init__.py b/scikitplot/__init__.py index 82ef1e5..e223419 100644 --- a/scikitplot/__init__.py +++ b/scikitplot/__init__.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, division, print_function, unicode_literals -__version__ = '0.2.6' +__version__ = '0.2.7' from scikitplot.classifiers import classifier_factory diff --git a/scikitplot/classifiers.py b/scikitplot/classifiers.py index d5fb36b..eba688e 100644 --- a/scikitplot/classifiers.py +++ b/scikitplot/classifiers.py @@ -53,8 +53,8 @@ def classifier_factory(clf): return clf -def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, do_cv=True, cv=None, - shuffle=True, random_state=None, ax=None, figsize=None, +def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, x_tick_rotation=0, + do_cv=True, cv=None, shuffle=True, random_state=None, ax=None, figsize=None, title_fontsize="large", text_fontsize="medium"): """Generates the confusion matrix for a given classifier and dataset. @@ -79,6 +79,9 @@ def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, d normalize (bool, optional): If True, normalizes the confusion matrix before plotting. Defaults to False. + x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is + useful in cases where there are numerous categories and the labels overlap each other. + do_cv (bool, optional): If True, the classifier is cross-validated on the dataset using the cross-validation strategy in `cv` to generate the confusion matrix. If False, the confusion matrix is generated without training or cross-validating the classifier. @@ -158,7 +161,8 @@ def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, d y_true = np.concatenate(trues_list) ax = plotters.plot_confusion_matrix(y_true=y_true, y_pred=y_pred, labels=labels, - title=title, normalize=normalize, ax=ax, figsize=figsize, + title=title, normalize=normalize, + x_tick_rotation=x_tick_rotation, ax=ax, figsize=figsize, title_fontsize=title_fontsize, text_fontsize=text_fontsize) return ax diff --git a/scikitplot/helpers.py b/scikitplot/helpers.py index 1f213ad..e04d260 100644 --- a/scikitplot/helpers.py +++ b/scikitplot/helpers.py @@ -42,7 +42,7 @@ def binary_ks_curve(y_true, y_probas): '{} category/ies'.format(len(lb.classes_))) idx = encoded_labels == 0 data1 = np.sort(y_probas[idx]) - data2 = np.sort(y_probas[-idx]) + data2 = np.sort(y_probas[np.logical_not(idx)]) ctr1, ctr2 = 0, 0 thresholds, pct1, pct2 = [], [], [] diff --git a/scikitplot/plotters.py b/scikitplot/plotters.py index 8ba096b..699bed6 100644 --- a/scikitplot/plotters.py +++ b/scikitplot/plotters.py @@ -22,8 +22,8 @@ from sklearn.metrics import silhouette_samples -def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=False, ax=None, - figsize=None, title_fontsize="large", text_fontsize="medium"): +def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=False, x_tick_rotation=0, + ax=None, figsize=None, title_fontsize="large", text_fontsize="medium"): """Generates confusion matrix plot for a given set of ground truth labels and classifier predictions. Args: @@ -44,6 +44,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal normalize (bool, optional): If True, normalizes the confusion matrix before plotting. Defaults to False. + x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is + useful in cases where there are numerous categories and the labels overlap each other. + ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to plot the learning curve. If None, the plot is drawn on a new set of axes. @@ -96,7 +99,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal plt.colorbar(mappable=image) tick_marks = np.arange(len(classes)) ax.set_xticks(tick_marks) - ax.set_xticklabels(classes, fontsize=text_fontsize) + ax.set_xticklabels(classes, fontsize=text_fontsize, rotation=x_tick_rotation) ax.set_yticks(tick_marks) ax.set_yticklabels(classes, fontsize=text_fontsize) diff --git a/setup.py b/setup.py index bdf07cd..b18a76a 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,6 @@ import os import sys -import scikitplot - here = os.path.abspath(os.path.dirname(__file__)) @@ -37,7 +35,7 @@ def run_tests(self): setup( name='scikit-plot', - version=scikitplot.__version__, + version='0.2.7', url='https://github.com/reiinakano/scikit-plot', license='MIT License', author='Reiichiro Nakano',