From bb8cac9ce1d10f360cb7c617710dec108c56dbd2 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 14 Nov 2024 20:59:09 -0500 Subject: [PATCH] Create thinkstats.py Restoring thinkstats.py --- code/thinkstats.py | 890 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 890 insertions(+) create mode 100644 code/thinkstats.py diff --git a/code/thinkstats.py b/code/thinkstats.py new file mode 100644 index 00000000..e68f2467 --- /dev/null +++ b/code/thinkstats.py @@ -0,0 +1,890 @@ +"""This file contains code for use with "Think Stats", +by Allen B. Downey, available from greenteapress.com + +Copyright 2014 Allen B. Downey +License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html +""" + +from __future__ import print_function + +import math +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import warnings + +# customize some matplotlib attributes +#matplotlib.rc('figure', figsize=(4, 3)) + +#matplotlib.rc('font', size=14.0) +#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0) +#matplotlib.rc('legend', fontsize=20.0) + +#matplotlib.rc('xtick.major', size=6.0) +#matplotlib.rc('xtick.minor', size=3.0) + +#matplotlib.rc('ytick.major', size=6.0) +#matplotlib.rc('ytick.minor', size=3.0) + + +class _Brewer(object): + """Encapsulates a nice sequence of colors. + + Shades of blue that look good in color and can be distinguished + in grayscale (up to a point). + + Borrowed from http://colorbrewer2.org/ + """ + color_iter = None + + colors = ['#f7fbff', '#deebf7', '#c6dbef', + '#9ecae1', '#6baed6', '#4292c6', + '#2171b5','#08519c','#08306b'][::-1] + + # lists that indicate which colors to use depending on how many are used + which_colors = [[], + [1], + [1, 3], + [0, 2, 4], + [0, 2, 4, 6], + [0, 2, 3, 5, 6], + [0, 2, 3, 4, 5, 6], + [0, 1, 2, 3, 4, 5, 6], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7, 8], + ] + + current_figure = None + + @classmethod + def Colors(cls): + """Returns the list of colors. + """ + return cls.colors + + @classmethod + def ColorGenerator(cls, num): + """Returns an iterator of color strings. + + n: how many colors will be used + """ + for i in cls.which_colors[num]: + yield cls.colors[i] + raise StopIteration('Ran out of colors in _Brewer.') + + @classmethod + def InitIter(cls, num): + """Initializes the color iterator with the given number of colors.""" + cls.color_iter = cls.ColorGenerator(num) + fig = plt.gcf() + cls.current_figure = fig + + @classmethod + def ClearIter(cls): + """Sets the color iterator to None.""" + cls.color_iter = None + cls.current_figure = None + + @classmethod + def GetIter(cls, num): + """Gets the color iterator.""" + fig = plt.gcf() + if fig != cls.current_figure: + cls.InitIter(num) + cls.current_figure = fig + + if cls.color_iter is None: + cls.InitIter(num) + + return cls.color_iter + + +def _UnderrideColor(options): + """If color is not in the options, chooses a color. + """ + if 'color' in options: + return options + + # get the current color iterator; if there is none, init one + color_iter = _Brewer.GetIter(5) + + try: + options['color'] = next(color_iter) + except StopIteration: + # if you run out of colors, initialize the color iterator + # and try again + warnings.warn('Ran out of colors. Starting over.') + _Brewer.ClearIter() + _UnderrideColor(options) + + return options + + +def PrePlot(num=None, rows=None, cols=None): + """Takes hints about what's coming. + + num: number of lines that will be plotted + rows: number of rows of subplots + cols: number of columns of subplots + """ + if num: + _Brewer.InitIter(num) + + if rows is None and cols is None: + return + + if rows is not None and cols is None: + cols = 1 + + if cols is not None and rows is None: + rows = 1 + + # resize the image, depending on the number of rows and cols + size_map = {(1, 1): (8, 6), + (1, 2): (12, 6), + (1, 3): (12, 6), + (1, 4): (12, 5), + (1, 5): (12, 4), + (2, 2): (10, 10), + (2, 3): (16, 10), + (3, 1): (8, 10), + (4, 1): (8, 12), + } + + if (rows, cols) in size_map: + fig = plt.gcf() + fig.set_size_inches(*size_map[rows, cols]) + + # create the first subplot + if rows > 1 or cols > 1: + ax = plt.subplot(rows, cols, 1) + global SUBPLOT_ROWS, SUBPLOT_COLS + SUBPLOT_ROWS = rows + SUBPLOT_COLS = cols + else: + ax = plt.gca() + + return ax + + +def SubPlot(plot_number, rows=None, cols=None, **options): + """Configures the number of subplots and changes the current plot. + + rows: int + cols: int + plot_number: int + options: passed to subplot + """ + rows = rows or SUBPLOT_ROWS + cols = cols or SUBPLOT_COLS + return plt.subplot(rows, cols, plot_number, **options) + + +def _Underride(d, **options): + """Add key-value pairs to d only if key is not in d. + + If d is None, create a new dictionary. + + d: dictionary + options: keyword args to add to d + """ + if d is None: + d = {} + + for key, val in options.items(): + d.setdefault(key, val) + + return d + + +def Clf(): + """Clears the figure and any hints that have been set.""" + global LOC + LOC = None + _Brewer.ClearIter() + plt.clf() + fig = plt.gcf() + fig.set_size_inches(8, 6) + + +def Figure(**options): + """Sets options for the current figure.""" + _Underride(options, figsize=(6, 8)) + plt.figure(**options) + + +def Plot(obj, ys=None, style='', **options): + """Plots a line. + + Args: + obj: sequence of x values, or Series, or anything with Render() + ys: sequence of y values + style: style string passed along to plt.plot + options: keyword args passed to plt.plot + """ + options = _UnderrideColor(options) + label = getattr(obj, 'label', '_nolegend_') + options = _Underride(options, linewidth=3, alpha=0.7, label=label) + + xs = obj + if ys is None: + if hasattr(obj, 'Render'): + xs, ys = obj.Render() + if isinstance(obj, pd.Series): + ys = obj.values + xs = obj.index + + if ys is None: + plt.plot(xs, style, **options) + else: + plt.plot(xs, ys, style, **options) + + +def Vlines(xs, y1, y2, **options): + """Plots a set of vertical lines. + + Args: + xs: sequence of x values + y1: sequence of y values + y2: sequence of y values + options: keyword args passed to plt.vlines + """ + options = _UnderrideColor(options) + options = _Underride(options, linewidth=1, alpha=0.5) + plt.vlines(xs, y1, y2, **options) + + +def Hlines(ys, x1, x2, **options): + """Plots a set of horizontal lines. + + Args: + ys: sequence of y values + x1: sequence of x values + x2: sequence of x values + options: keyword args passed to plt.vlines + """ + options = _UnderrideColor(options) + options = _Underride(options, linewidth=1, alpha=0.5) + plt.hlines(ys, x1, x2, **options) + + +def axvline(x, **options): + """Plots a vertical line. + + Args: + x: x location + options: keyword args passed to plt.axvline + """ + options = _UnderrideColor(options) + options = _Underride(options, linewidth=1, alpha=0.5) + plt.axvline(x, **options) + + +def axhline(y, **options): + """Plots a horizontal line. + + Args: + y: y location + options: keyword args passed to plt.axhline + """ + options = _UnderrideColor(options) + options = _Underride(options, linewidth=1, alpha=0.5) + plt.axhline(y, **options) + + +def tight_layout(**options): + """Adjust subplots to minimize padding and margins. + """ + options = _Underride(options, + wspace=0.1, hspace=0.1, + left=0, right=1, + bottom=0, top=1) + plt.tight_layout() + plt.subplots_adjust(**options) + + +def FillBetween(xs, y1, y2=None, where=None, **options): + """Fills the space between two lines. + + Args: + xs: sequence of x values + y1: sequence of y values + y2: sequence of y values + where: sequence of boolean + options: keyword args passed to plt.fill_between + """ + options = _UnderrideColor(options) + options = _Underride(options, linewidth=0, alpha=0.5) + plt.fill_between(xs, y1, y2, where, **options) + + +def Bar(xs, ys, **options): + """Plots a line. + + Args: + xs: sequence of x values + ys: sequence of y values + options: keyword args passed to plt.bar + """ + options = _UnderrideColor(options) + options = _Underride(options, linewidth=0, alpha=0.6) + plt.bar(xs, ys, **options) + + +def Scatter(xs, ys=None, **options): + """Makes a scatter plot. + + xs: x values + ys: y values + options: options passed to plt.scatter + """ + options = _Underride(options, color='blue', alpha=0.2, + s=30, edgecolors='none') + + if ys is None and isinstance(xs, pd.Series): + ys = xs.values + xs = xs.index + + plt.scatter(xs, ys, **options) + + +def HexBin(xs, ys, **options): + """Makes a scatter plot. + + xs: x values + ys: y values + options: options passed to plt.scatter + """ + options = _Underride(options, cmap=matplotlib.cm.Blues) + plt.hexbin(xs, ys, **options) + + +def Pdf(pdf, **options): + """Plots a Pdf, Pmf, or Hist as a line. + + Args: + pdf: Pdf, Pmf, or Hist object + options: keyword args passed to plt.plot + """ + low, high = options.pop('low', None), options.pop('high', None) + n = options.pop('n', 101) + xs, ps = pdf.Render(low=low, high=high, n=n) + options = _Underride(options, label=pdf.label) + Plot(xs, ps, **options) + + +def Pdfs(pdfs, **options): + """Plots a sequence of PDFs. + + Options are passed along for all PDFs. If you want different + options for each pdf, make multiple calls to Pdf. + + Args: + pdfs: sequence of PDF objects + options: keyword args passed to plt.plot + """ + for pdf in pdfs: + Pdf(pdf, **options) + + +def Hist(hist, **options): + """Plots a Pmf or Hist with a bar plot. + + The default width of the bars is based on the minimum difference + between values in the Hist. If that's too small, you can override + it by providing a width keyword argument, in the same units + as the values. + + Args: + hist: Hist or Pmf object + options: keyword args passed to plt.bar + """ + # find the minimum distance between adjacent values + xs, ys = hist.Render() + + # see if the values support arithmetic + try: + xs[0] - xs[0] + except TypeError: + # if not, replace values with numbers + labels = [str(x) for x in xs] + xs = np.arange(len(xs)) + plt.xticks(xs+0.5, labels) + + if 'width' not in options: + try: + options['width'] = 0.9 * np.diff(xs).min() + except TypeError: + warnings.warn("Hist: Can't compute bar width automatically." + "Check for non-numeric types in Hist." + "Or try providing width option." + ) + + options = _Underride(options, label=hist.label) + options = _Underride(options, align='center') + if options['align'] == 'left': + options['align'] = 'edge' + elif options['align'] == 'right': + options['align'] = 'edge' + options['width'] *= -1 + + Bar(xs, ys, **options) + + +def Hists(hists, **options): + """Plots two histograms as interleaved bar plots. + + Options are passed along for all PMFs. If you want different + options for each pmf, make multiple calls to Pmf. + + Args: + hists: list of two Hist or Pmf objects + options: keyword args passed to plt.plot + """ + for hist in hists: + Hist(hist, **options) + + +def Pmf(pmf, **options): + """Plots a Pmf or Hist as a line. + + Args: + pmf: Hist or Pmf object + options: keyword args passed to plt.plot + """ + xs, ys = pmf.Render() + low, high = min(xs), max(xs) + + width = options.pop('width', None) + if width is None: + try: + width = np.diff(xs).min() + except TypeError: + warnings.warn("Pmf: Can't compute bar width automatically." + "Check for non-numeric types in Pmf." + "Or try providing width option.") + points = [] + + lastx = np.nan + lasty = 0 + for x, y in zip(xs, ys): + if (x - lastx) > 1e-5: + points.append((lastx, 0)) + points.append((x, 0)) + + points.append((x, lasty)) + points.append((x, y)) + points.append((x+width, y)) + + lastx = x + width + lasty = y + points.append((lastx, 0)) + pxs, pys = zip(*points) + + align = options.pop('align', 'center') + if align == 'center': + pxs = np.array(pxs) - width/2.0 + if align == 'right': + pxs = np.array(pxs) - width + + options = _Underride(options, label=pmf.label) + Plot(pxs, pys, **options) + + +def Pmfs(pmfs, **options): + """Plots a sequence of PMFs. + + Options are passed along for all PMFs. If you want different + options for each pmf, make multiple calls to Pmf. + + Args: + pmfs: sequence of PMF objects + options: keyword args passed to plt.plot + """ + for pmf in pmfs: + Pmf(pmf, **options) + + +def Diff(t): + """Compute the differences between adjacent elements in a sequence. + + Args: + t: sequence of number + + Returns: + sequence of differences (length one less than t) + """ + diffs = [t[i+1] - t[i] for i in range(len(t)-1)] + return diffs + + +def Cdf(cdf, complement=False, transform=None, **options): + """Plots a CDF as a line. + + Args: + cdf: Cdf object + complement: boolean, whether to plot the complementary CDF + transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel' + options: keyword args passed to plt.plot + + Returns: + dictionary with the scale options that should be passed to + Config, Show or Save. + """ + xs, ps = cdf.Render() + xs = np.asarray(xs) + ps = np.asarray(ps) + + scale = dict(xscale='linear', yscale='linear') + + for s in ['xscale', 'yscale']: + if s in options: + scale[s] = options.pop(s) + + if transform == 'exponential': + complement = True + scale['yscale'] = 'log' + + if transform == 'pareto': + complement = True + scale['yscale'] = 'log' + scale['xscale'] = 'log' + + if complement: + ps = [1.0-p for p in ps] + + if transform == 'weibull': + xs = np.delete(xs, -1) + ps = np.delete(ps, -1) + ps = [-math.log(1.0-p) for p in ps] + scale['xscale'] = 'log' + scale['yscale'] = 'log' + + if transform == 'gumbel': + xs = np.delete(xs, 0) + ps = np.delete(ps, 0) + ps = [-math.log(p) for p in ps] + scale['yscale'] = 'log' + + options = _Underride(options, label=cdf.label) + Plot(xs, ps, **options) + return scale + + +def Cdfs(cdfs, complement=False, transform=None, **options): + """Plots a sequence of CDFs. + + cdfs: sequence of CDF objects + complement: boolean, whether to plot the complementary CDF + transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel' + options: keyword args passed to plt.plot + """ + for cdf in cdfs: + Cdf(cdf, complement, transform, **options) + + +def Contour(obj, pcolor=False, contour=True, imshow=False, **options): + """Makes a contour plot. + + d: map from (x, y) to z, or object that provides GetDict + pcolor: boolean, whether to make a pseudocolor plot + contour: boolean, whether to make a contour plot + imshow: boolean, whether to use plt.imshow + options: keyword args passed to plt.pcolor and/or plt.contour + """ + try: + d = obj.GetDict() + except AttributeError: + d = obj + + _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues) + + xs, ys = zip(*d.keys()) + xs = sorted(set(xs)) + ys = sorted(set(ys)) + + X, Y = np.meshgrid(xs, ys) + func = lambda x, y: d.get((x, y), 0) + func = np.vectorize(func) + Z = func(X, Y) + + x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False) + axes = plt.gca() + axes.xaxis.set_major_formatter(x_formatter) + + if pcolor: + plt.pcolormesh(X, Y, Z, **options) + if contour: + cs = plt.contour(X, Y, Z, **options) + plt.clabel(cs, inline=1, fontsize=10) + if imshow: + extent = xs[0], xs[-1], ys[0], ys[-1] + plt.imshow(Z, extent=extent, **options) + + +def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options): + """Makes a pseudocolor plot. + + xs: + ys: + zs: + pcolor: boolean, whether to make a pseudocolor plot + contour: boolean, whether to make a contour plot + options: keyword args passed to plt.pcolor and/or plt.contour + """ + _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues) + + X, Y = np.meshgrid(xs, ys) + Z = zs + + x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False) + axes = plt.gca() + axes.xaxis.set_major_formatter(x_formatter) + + if pcolor: + plt.pcolormesh(X, Y, Z, **options) + + if contour: + cs = plt.contour(X, Y, Z, **options) + plt.clabel(cs, inline=1, fontsize=10) + + +def Text(x, y, s, **options): + """Puts text in a figure. + + x: number + y: number + s: string + options: keyword args passed to plt.text + """ + options = _Underride(options, + fontsize=16, + verticalalignment='top', + horizontalalignment='left') + plt.text(x, y, s, **options) + + +LEGEND = True +LOC = None + +def Config(**options): + """Configures the plot. + + Pulls options out of the option dictionary and passes them to + the corresponding plt functions. + """ + names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale', + 'xticks', 'yticks', 'axis', 'xlim', 'ylim'] + + for name in names: + if name in options: + getattr(plt, name)(options[name]) + + global LEGEND + LEGEND = options.get('legend', LEGEND) + + # see if there are any elements with labels; + # if not, don't draw a legend + ax = plt.gca() + handles, labels = ax.get_legend_handles_labels() + + if LEGEND and len(labels) > 0: + global LOC + LOC = options.get('loc', LOC) + frameon = options.get('frameon', True) + + try: + plt.legend(loc=LOC, frameon=frameon) + except UserWarning: + pass + + # x and y ticklabels can be made invisible + val = options.get('xticklabels', None) + if val is not None: + if val == 'invisible': + ax = plt.gca() + labels = ax.get_xticklabels() + plt.setp(labels, visible=False) + + val = options.get('yticklabels', None) + if val is not None: + if val == 'invisible': + ax = plt.gca() + labels = ax.get_yticklabels() + plt.setp(labels, visible=False) + +def set_font_size(title_size=16, label_size=16, ticklabel_size=14, legend_size=14): + """Set font sizes for the title, labels, ticklabels, and legend. + """ + def set_text_size(texts, size): + for text in texts: + text.set_size(size) + + ax = plt.gca() + + # TODO: Make this function more robust if any of these elements + # is missing. + + # title + ax.title.set_size(title_size) + + # x axis + ax.xaxis.label.set_size(label_size) + set_text_size(ax.xaxis.get_ticklabels(), ticklabel_size) + + # y axis + ax.yaxis.label.set_size(label_size) + set_text_size(ax.yaxis.get_ticklabels(), ticklabel_size) + + # legend + legend = ax.get_legend() + if legend is not None: + set_text_size(legend.texts, legend_size) + + +def bigger_text(): + sizes = dict(title_size=16, label_size=16, ticklabel_size=14, legend_size=14) + set_font_size(**sizes) + + +def Show(**options): + """Shows the plot. + + For options, see Config. + + options: keyword args used to invoke various plt functions + """ + clf = options.pop('clf', True) + Config(**options) + plt.show() + if clf: + Clf() + + +def Plotly(**options): + """Shows the plot. + + For options, see Config. + + options: keyword args used to invoke various plt functions + """ + clf = options.pop('clf', True) + Config(**options) + import plotly.plotly as plotly + url = plotly.plot_mpl(plt.gcf()) + if clf: + Clf() + return url + + +def Save(root=None, formats=None, **options): + """Saves the plot in the given formats and clears the figure. + + For options, see Config. + + Note: With a capital S, this is the original save, maintained for + compatibility. New code should use save(), which works better + with my newer code, especially in Jupyter notebooks. + + Args: + root: string filename root + formats: list of string formats + options: keyword args used to invoke various plt functions + """ + clf = options.pop('clf', True) + + save_options = {} + for option in ['bbox_inches', 'pad_inches']: + if option in options: + save_options[option] = options.pop(option) + + # TODO: falling Config inside Save was probably a mistake, but removing + # it will require some work + Config(**options) + + if formats is None: + formats = ['pdf', 'png'] + + try: + formats.remove('plotly') + Plotly(clf=False) + except ValueError: + pass + + if root: + for fmt in formats: + SaveFormat(root, fmt, **save_options) + if clf: + Clf() + + +def save(root, formats=None, **options): + """Saves the plot in the given formats and clears the figure. + + For options, see plt.savefig. + + Args: + root: string filename root + formats: list of string formats + options: keyword args passed to plt.savefig + """ + if formats is None: + formats = ['pdf', 'png'] + + try: + formats.remove('plotly') + Plotly(clf=False) + except ValueError: + pass + + for fmt in formats: + SaveFormat(root, fmt, **options) + + +def SaveFormat(root, fmt='eps', **options): + """Writes the current figure to a file in the given format. + + Args: + root: string filename root + fmt: string format + """ + _Underride(options, dpi=300) + filename = '%s.%s' % (root, fmt) + print('Writing', filename) + plt.savefig(filename, format=fmt, **options) + + +# provide aliases for calling functions with lower-case names +preplot = PrePlot +subplot = SubPlot +clf = Clf +figure = Figure +plot = Plot +vlines = Vlines +hlines = Hlines +fill_between = FillBetween +text = Text +scatter = Scatter +pmf = Pmf +pmfs = Pmfs +hist = Hist +hists = Hists +diff = Diff +cdf = Cdf +cdfs = Cdfs +contour = Contour +pcolor = Pcolor +config = Config +show = Show + + +def main(): + color_iter = _Brewer.ColorGenerator(7) + for color in color_iter: + print(color) + + +if __name__ == '__main__': + main()