diff --git a/persim/visuals.py b/persim/visuals.py index 9b6aeba..5ab8289 100644 --- a/persim/visuals.py +++ b/persim/visuals.py @@ -1,7 +1,8 @@ import numpy as np import matplotlib.pyplot as plt +from typing import Optional, List -__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"] +__all__ = ["plot_diagrams", "plot_barcode_diagrams", "bottleneck_matching", "wasserstein_matching"] def plot_diagrams( @@ -165,6 +166,147 @@ def plot_diagrams( if show is True: plt.show() +def plot_barcode_diagrams( + diagrams: np.ndarray, + plot_only: Optional[List[int]] = None, + title: Optional[str] = None, + x_range: Optional[List[int]] = None, + labels: Optional[List[str]] = None, + colormap: str="default", + x_axis_label: Optional[str] = None, + legend: bool = False, + show: bool = False, + ax: plt.Axes = None +): + """A helper function to plot persistence barcode diagrams. + + Parameters + ---------- + diagrams: ndarray (n_pairs, 2) or list of diagrams + A diagram or list of diagrams. If diagram is a list of diagrams, + then plot all on the same plot using different colors. + plot_only: list of numeric + If specified, an array of only the diagrams that should be plotted. + title: string, default is None + If title is defined, add it as title of the plot. + x_range: list of numeric [xmin, xmax] + User provided range of x-axis. This is useful for comparing + multiple persistence diagrams. + labels: string or list of strings + Legend labels for each diagram. + If none are specified, we use H_0, H_1, H_2,... by default. + colormap: string, default is 'default' + Any of matplotlib color palettes. + Some options are 'default', 'seaborn', 'sequential'. + See all available styles with + + .. code:: python + + import matplotlib as mpl + print(mpl.styles.available) + x_axis_label: str, default is None + x_axis label. + legend: bool, default is False + If true, show the legend. + show: bool, default is False + Call plt.show() after plotting. If you are using self.plot() as part + of a subplot, set show=False and call plt.show() only once at the end. + ax: matplotlib axes object, default None + An axes of the current figure. + """ + if not ax: + fig, ax = plt.subplots() + + plt.style.use(colormap) + + if not isinstance(diagrams, list): + # Must have diagrams as a list for processing downstream + diagrams = [diagrams] + + if labels is None: + # Provide default labels for diagrams if using self.dgm_ + labels = [f"$H_{{{i}}}$" for i , _ in enumerate(diagrams)] + + if plot_only: + diagrams = [diagrams[i] for i in plot_only] + labels = [labels[i] for i in plot_only] + + if not isinstance(labels, list): + labels = [labels] * len(diagrams) + + # Construct copy with proper type of each diagram + # so we can freely edit them. + diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams] + + # Plot each diagram + #y-cords for each hbar + total = sum([d.shape[0] + 1 for d in diagrams]) + start = total + + for i, d in enumerate(diagrams): + y= [i for i in range(start,start-d.shape[0],-1)] + bottom = np.zeros(d.shape[0]).reshape(1,-1) + #Ensure bars are always ordered by earliest birth + dc = d if d[:,0].sum() == 0 else d[d[:,0].argsort()] + for col in range(dc.shape[1]): + #plot 0 -> birth as white bar (not displayed) + if col == 0: + ax.barh(y, + dc[:,col], + color='white', + height=0.6, + left=bottom.ravel()) + bottom = bottom + dc[:,col] + #plot birth -> death as colour bar (displayed) + else: + ax.barh(y, + (dc[:,col] - bottom[:,]).ravel(), + label=labels[i], + left=bottom.ravel(), + height=0.6 + + ) + start -= d.shape[0] + #Add split lines between Groups + ax.axhline(start, color='k', linestyle='--',linewidth=0.75) + start -= 1 + + #Set yaxis ticklabels at centerpoint of each group + yticks = [] + tmp = total + for d in diagrams: + if d.shape[0] == 1: + yticks.append(tmp) + else: + yticks.append(tmp - d.shape[0]/2) + tmp -= d.shape[0] + 1 + + ax.set_yticks(yticks) + ax.set_yticklabels(labels) + + #Remove grid from yaxis + ax.grid(which='major',axis='x') + ax.set_facecolor('white') + + if x_axis_label is not None: + ax.set_xlabel(x_axis_label) + + if x_range is not None: + ax.set_xlim(x_range) + + #Ensure bars aren't thick on plot + if len(ax.patches) < 15: + ax.set_ylim(ax.get_ylim()[0], 15) + + if title is not None: + ax.set_title(title) + + if legend is True: + ax.legend(loc="lower right") + + if show is True: + plt.show() + def plot_a_bar(p, q, c='b', linestyle='-'): plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)