Skip to content

Commit

Permalink
Add pagination to plot_recordings (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkilic authored Feb 2, 2024
1 parent 87de1b6 commit eb83ba3
Showing 1 changed file with 37 additions and 34 deletions.
71 changes: 37 additions & 34 deletions bluepyefe/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,55 +224,58 @@ def plot_recordings(self, protocol_name, output_dir=None, show=False):
recordings = self.get_recordings_by_protocol_name(protocol_name)

if not len(recordings):
return None, None
return

recordings_amp = [rec.amp for rec in recordings]
recordings = [recordings[k] for k in numpy.argsort(recordings_amp)]
recordings_sorted = [recordings[k] for k in numpy.argsort(recordings_amp)]

n_cols = 6
n_rows = int(2 * numpy.ceil(len(recordings) / n_cols))
max_plots_per_page = 24
total_pages = int(numpy.ceil(len(recordings_sorted) / max_plots_per_page))

fig, axs = plt.subplots(
n_rows, n_cols,
figsize=[3.0 + 3.0 * int(n_cols), 2.5 * n_rows],
squeeze=False
)

for i, rec in enumerate(recordings):

col = i % int(n_cols)
row = 2 * int(i / n_cols)
for page in range(total_pages):
start_idx = page * max_plots_per_page
end_idx = start_idx + max_plots_per_page
page_recordings = recordings_sorted[start_idx:end_idx]

display_ylabel = col == 0
display_xlabel = row + 1 == axs.shape[0]
n_rows = int(numpy.ceil(len(page_recordings) / n_cols)) * 2

_, _ = rec.plot(
axis_current=axs[row][col],
axis_voltage=axs[row + 1][col],
display_xlabel=display_xlabel,
display_ylabel=display_ylabel
fig, axs = plt.subplots(
n_rows, n_cols,
figsize=[3.0 * n_cols, 2.5 * n_rows],
squeeze=False
)

fig.suptitle("Cell: {}, Experiment: {}".format(self.name, protocol_name))
for i, rec in enumerate(page_recordings):
col = i % n_cols
row = (i // n_cols) * 2

plt.subplots_adjust(wspace=0.53, hspace=0.7)
display_ylabel = col == 0
display_xlabel = (row // 2) + 1 == n_rows // 2

rec.plot(
axis_current=axs[row][col],
axis_voltage=axs[row + 1][col],
display_xlabel=display_xlabel,
display_ylabel=display_ylabel
)

for ax in axs.flatten():
if not ax.lines:
ax.set_visible(False)
fig.suptitle(f"Cell: {self.name}, Experiment: {protocol_name}, Page: {page + 1}")
plt.subplots_adjust(wspace=0.53, hspace=0.7)

# Do not use tight-layout, it significantly increases the runtime
plt.margins(0, 0)
for ax in axs.flatten():
if not ax.lines:
ax.set_visible(False)

if show:
fig.show()
plt.margins(0, 0)

if output_dir is not None:
filename = "{}_{}_recordings.pdf".format(self.name, protocol_name)
dirname = pathlib.Path(output_dir) / self.name
_save_fig(dirname, filename)
if show:
fig.show()

return fig, axs
if output_dir is not None:
filename = f"{self.name}_{protocol_name}_recordings_page_{page + 1}.pdf"
dirname = pathlib.Path(output_dir) / self.name
_save_fig(dirname, filename)

def plot_all_recordings(self, output_dir=None, show=False):
"""Plot all the recordings of the cell.
Expand Down

0 comments on commit eb83ba3

Please sign in to comment.