diff --git a/bluepyefe/cell.py b/bluepyefe/cell.py index 90d594c..102cf00 100644 --- a/bluepyefe/cell.py +++ b/bluepyefe/cell.py @@ -26,6 +26,7 @@ from bluepyefe.ecode import eCodes from bluepyefe.reader import * from bluepyefe.plotting import _save_fig +from matplotlib.backends.backend_pdf import PdfPages logger = logging.getLogger(__name__) @@ -230,52 +231,58 @@ def plot_recordings(self, protocol_name, output_dir=None, show=False): recordings_sorted = [recordings[k] for k in numpy.argsort(recordings_amp)] n_cols = 6 - max_plots_per_page = 24 + max_plots_per_page = 12 total_pages = int(numpy.ceil(len(recordings_sorted) / max_plots_per_page)) - for page in range(total_pages): - start_idx = page * max_plots_per_page - end_idx = start_idx + max_plots_per_page - page_recordings = recordings_sorted[start_idx:end_idx] + if output_dir is not None: + filename = f"{self.name}_{protocol_name}_recordings.pdf" + dirname = pathlib.Path(output_dir) / self.name + dirname.mkdir(parents=True, exist_ok=True) + filepath = dirname / filename - n_rows = int(numpy.ceil(len(page_recordings) / n_cols)) * 2 + with PdfPages(filepath) as pdf: + for page in range(total_pages): + start_idx = page * max_plots_per_page + end_idx = start_idx + max_plots_per_page + page_recordings = recordings_sorted[start_idx:end_idx] - fig, axs = plt.subplots( - n_rows, n_cols, - figsize=[3.0 * n_cols, 2.5 * n_rows], - squeeze=False - ) + n_rows = int(numpy.ceil(len(page_recordings) / n_cols)) * 2 - for i, rec in enumerate(page_recordings): - col = i % n_cols - row = (i // n_cols) * 2 + fig, axs = plt.subplots( + n_rows, n_cols, + figsize=[3.0 * n_cols, 2.5 * n_rows], + squeeze=False + ) - display_ylabel = col == 0 - display_xlabel = (row // 2) + 1 == n_rows // 2 + for i, rec in enumerate(page_recordings): + col = i % n_cols + row = (i // n_cols) * 2 - rec.plot( - axis_current=axs[row][col], - axis_voltage=axs[row + 1][col], - display_xlabel=display_xlabel, - display_ylabel=display_ylabel - ) + display_ylabel = col == 0 + display_xlabel = (row // 2) + 1 == n_rows // 2 + + rec.plot( + axis_current=axs[row][col], + axis_voltage=axs[row + 1][col], + display_xlabel=display_xlabel, + display_ylabel=display_ylabel + ) - fig.suptitle(f"Cell: {self.name}, Experiment: {protocol_name}, Page: {page + 1}") - plt.subplots_adjust(wspace=0.53, hspace=0.7) + fig.suptitle(f"Cell: {self.name}, Experiment: {protocol_name}, Page: {page + 1}") + plt.subplots_adjust(wspace=0.53, hspace=0.7) - for ax in axs.flatten(): - if not ax.lines: - ax.set_visible(False) + for ax in axs.flatten(): + if not ax.lines: + ax.set_visible(False) - plt.margins(0, 0) + plt.margins(0, 0) - if show: - fig.show() + if show: + plt.show() - if output_dir is not None: - filename = f"{self.name}_{protocol_name}_recordings_page_{page + 1}.pdf" - dirname = pathlib.Path(output_dir) / self.name - _save_fig(dirname, filename) + pdf.savefig(fig, dpi=80) + plt.close("all") + plt.clf() def plot_all_recordings(self, output_dir=None, show=False): """Plot all the recordings of the cell.