Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Barcode Plot #68

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 143 additions & 1 deletion persim/visuals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, List

__all__ = ["plot_diagrams", "bottleneck_matching", "wasserstein_matching"]
__all__ = ["plot_diagrams", "plot_barcode_diagrams", "bottleneck_matching", "wasserstein_matching"]


def plot_diagrams(
Expand Down Expand Up @@ -165,6 +166,147 @@ def plot_diagrams(
if show is True:
plt.show()

def plot_barcode_diagrams(
diagrams: np.ndarray,
plot_only: Optional[List[int]] = None,
title: Optional[str] = None,
x_range: Optional[List[int]] = None,
labels: Optional[List[str]] = None,
colormap: str="default",
x_axis_label: Optional[str] = None,
legend: bool = False,
show: bool = False,
ax: plt.Axes = None
):
"""A helper function to plot persistence barcode diagrams.

Parameters
----------
diagrams: ndarray (n_pairs, 2) or list of diagrams
A diagram or list of diagrams. If diagram is a list of diagrams,
then plot all on the same plot using different colors.
plot_only: list of numeric
If specified, an array of only the diagrams that should be plotted.
title: string, default is None
If title is defined, add it as title of the plot.
x_range: list of numeric [xmin, xmax]
User provided range of x-axis. This is useful for comparing
multiple persistence diagrams.
labels: string or list of strings
Legend labels for each diagram.
If none are specified, we use H_0, H_1, H_2,... by default.
colormap: string, default is 'default'
Any of matplotlib color palettes.
Some options are 'default', 'seaborn', 'sequential'.
See all available styles with

.. code:: python

import matplotlib as mpl
print(mpl.styles.available)
x_axis_label: str, default is None
x_axis label.
legend: bool, default is False
If true, show the legend.
show: bool, default is False
Call plt.show() after plotting. If you are using self.plot() as part
of a subplot, set show=False and call plt.show() only once at the end.
ax: matplotlib axes object, default None
An axes of the current figure.
"""
if not ax:
fig, ax = plt.subplots()

plt.style.use(colormap)

if not isinstance(diagrams, list):
# Must have diagrams as a list for processing downstream
diagrams = [diagrams]

if labels is None:
# Provide default labels for diagrams if using self.dgm_
labels = [f"$H_{{{i}}}$" for i , _ in enumerate(diagrams)]

if plot_only:
diagrams = [diagrams[i] for i in plot_only]
labels = [labels[i] for i in plot_only]

if not isinstance(labels, list):
labels = [labels] * len(diagrams)

# Construct copy with proper type of each diagram
# so we can freely edit them.
diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]

# Plot each diagram
#y-cords for each hbar
total = sum([d.shape[0] + 1 for d in diagrams])
start = total

for i, d in enumerate(diagrams):
y= [i for i in range(start,start-d.shape[0],-1)]
bottom = np.zeros(d.shape[0]).reshape(1,-1)
#Ensure bars are always ordered by earliest birth
dc = d if d[:,0].sum() == 0 else d[d[:,0].argsort()]
for col in range(dc.shape[1]):
#plot 0 -> birth as white bar (not displayed)
if col == 0:
ax.barh(y,
dc[:,col],
color='white',
height=0.6,
left=bottom.ravel())
bottom = bottom + dc[:,col]
#plot birth -> death as colour bar (displayed)
else:
ax.barh(y,
(dc[:,col] - bottom[:,]).ravel(),
label=labels[i],
left=bottom.ravel(),
height=0.6

)
start -= d.shape[0]
#Add split lines between Groups
ax.axhline(start, color='k', linestyle='--',linewidth=0.75)
start -= 1

#Set yaxis ticklabels at centerpoint of each group
yticks = []
tmp = total
for d in diagrams:
if d.shape[0] == 1:
yticks.append(tmp)
else:
yticks.append(tmp - d.shape[0]/2)
tmp -= d.shape[0] + 1

ax.set_yticks(yticks)
ax.set_yticklabels(labels)

#Remove grid from yaxis
ax.grid(which='major',axis='x')
ax.set_facecolor('white')

if x_axis_label is not None:
ax.set_xlabel(x_axis_label)

if x_range is not None:
ax.set_xlim(x_range)

#Ensure bars aren't thick on plot
if len(ax.patches) < 15:
ax.set_ylim(ax.get_ylim()[0], 15)

if title is not None:
ax.set_title(title)

if legend is True:
ax.legend(loc="lower right")

if show is True:
plt.show()

Comment on lines +284 to +309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bf108 What do you think about removing all this? My philosophy with plotting functions has changed recently, and now I think they should be pretty minimal. In particular, I think this method should return the (fig,ax) pair it creates and then the user is free to modify it (e.g., with ax.set_title, ax.legend etc.) after it is constructed. They can also call plt.show() when they'd like, or save it directly from the ax object.

This would allow us to remove quite a bit from the function signature as well. We could then add documentation for how a user could modify the plot after its been created with this function, e.g. showing how to add and move a legend or a title.

def plot_a_bar(p, q, c='b', linestyle='-'):
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)

Expand Down