diff --git a/bluepyefe/cell.py b/bluepyefe/cell.py index b91b1a1..7c57dbf 100644 --- a/bluepyefe/cell.py +++ b/bluepyefe/cell.py @@ -18,6 +18,7 @@ along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. """ +from collections import defaultdict import logging from multiprocessing import Pool import numpy @@ -29,10 +30,19 @@ from bluepyefe.plotting import _save_fig from matplotlib.backends.backend_pdf import PdfPages +from bluepyefe.recording import Recording + logger = logging.getLogger(__name__) -class Cell(object): +def extract_efeatures_helper(recording, efeatures, efeature_names, efel_settings): + """Helper function to compute efeatures for a single recording.""" + recording.compute_efeatures( + efeatures, efeature_names, efel_settings) + return recording + + +class Cell: """Contains the metadata related to a cell as well as the electrophysiological recordings once they are read""" @@ -47,7 +57,7 @@ def __init__(self, name): self.name = name - self.recordings = [] + self.recordings: dict[str, list[Recording]] = defaultdict(list) self.rheobase = None def reader(self, config_data, recording_reader=None): @@ -91,9 +101,8 @@ def reader(self, config_data, recording_reader=None): ) def get_protocol_names(self): - """List of all the protocols available for the present cell.""" - - return list(set([rec.protocol_name for rec in self.recordings])) + """List of all the protocol names available for the present cell.""" + return list(self.recordings.keys()) def get_recordings_by_protocol_name(self, protocol_name): """List of all the recordings available for the present cell for a @@ -103,27 +112,7 @@ def get_recordings_by_protocol_name(self, protocol_name): protocol_name (str): name of the protocol for which to get the recordings. """ - - return [ - rec - for rec in self.recordings - if rec.protocol_name == protocol_name - ] - - def get_recordings_id_by_protocol_name(self, protocol_name): - """List of the indexes of the recordings available for the present - cell for a given protocol. - - Args: - protocol_name (str): name of the protocol for which to get - the recordings. - """ - - return [ - i - for i, trace in enumerate(self.recordings) - if trace.protocol_name == protocol_name - ] + return self.recordings.get(protocol_name) def read_recordings( self, @@ -164,7 +153,7 @@ def read_recordings( protocol_name, efel_settings ) - self.recordings.append(rec) + self.recordings[protocol_name].append(rec) break else: raise KeyError( @@ -173,12 +162,6 @@ def read_recordings( f"the available stimuli names" ) - def extract_efeatures_helper(self, recording_id, efeatures, efeature_names, efel_settings): - """Helper function to compute efeatures for a single recording.""" - self.recordings[recording_id].compute_efeatures( - efeatures, efeature_names, efel_settings) - return self.recordings[recording_id] - def extract_efeatures( self, protocol_name, @@ -199,26 +182,26 @@ def extract_efeatures( is to be extracted several time on different sections of the same recording. """ - recording_ids = self.get_recordings_id_by_protocol_name(protocol_name) + recordings_of_protocol: list[Recording] = self.recordings.get(protocol_name) # Run in parallel via multiprocessing with Pool(maxtasksperchild=1) as pool: tasks = [ - (rec_id, efeatures, efeature_names, efel_settings) - for rec_id in recording_ids + (recording, efeatures, efeature_names, efel_settings) + for recording in recordings_of_protocol ] - results = pool.starmap(self.extract_efeatures_helper, tasks) + results = pool.starmap(extract_efeatures_helper, tasks) - self.recordings = results + self.recordings[protocol_name] = results def compute_relative_amp(self): """Compute the relative current amplitude for all the recordings as a percentage of the rheobase.""" if self.rheobase not in (0.0, None, False, numpy.nan): - - for i in range(len(self.recordings)): - self.recordings[i].compute_relative_amp(self.rheobase) + for _, recordings_list in self.recordings.items(): + for recording in recordings_list: + recording.compute_relative_amp(self.rheobase) else: diff --git a/bluepyefe/rheobase.py b/bluepyefe/rheobase.py index 5490e47..a299809 100644 --- a/bluepyefe/rheobase.py +++ b/bluepyefe/rheobase.py @@ -25,34 +25,39 @@ def _get_list_spiking_amplitude(cell, protocols_rheobase): - """Return the list of sorted list of amplitude that triggered at least - one spike""" + """Return the list of sorted amplitudes that triggered at least + one spike, along with their corresponding spike counts.""" amps = [] spike_counts = [] - for i, rec in enumerate(cell.recordings): - if rec.protocol_name in protocols_rheobase: - if rec.spikecount is not None: + for protocol_name, recordings_list in cell.recordings.items(): + if protocol_name in protocols_rheobase: + for rec in recordings_list: + if rec.spikecount is not None: - amps.append(rec.amp) - spike_counts.append(rec.spikecount) + amps.append(rec.amp) + spike_counts.append(rec.spikecount) - if rec.amp < 0.01 and rec.spikecount >= 1: - logger.warning( - f"A recording of cell {cell.name} protocol " - f"{rec.protocol_name} shows spikes at a " - "suspiciously low current in a trace from file" - f" {rec.files}. Check that the ton and toff are" - "correct or for the presence of unwanted spikes." - ) + if rec.amp < 0.01 and rec.spikecount >= 1: + logger.warning( + f"A recording of cell {cell.name} protocol " + f"{protocol_name} shows spikes at a " + "suspiciously low current in a trace from file " + f"{rec.files}. Check that the ton and toff are " + "correct or for the presence of unwanted spikes." + ) + # Sort amplitudes and their corresponding spike counts if amps: amps, spike_counts = zip(*sorted(zip(amps, spike_counts))) + else: + amps, spike_counts = (), () return amps, spike_counts + def compute_rheobase_absolute(cell, protocols_rheobase, spike_threshold=1): """ Compute the rheobase by finding the smallest current amplitude triggering at least one spike. diff --git a/tests/test_cell.py b/tests/test_cell.py index a111dc6..181b1bf 100644 --- a/tests/test_cell.py +++ b/tests/test_cell.py @@ -11,6 +11,7 @@ class CellTest(unittest.TestCase): def setUp(self): self.cell = bluepyefe.cell.Cell(name="MouseNeuron") + self.protocol_name = "IDRest" file_metadata = { "i_file": "./tests/exp_data/B95_Ch0_IDRest_107.ibw", @@ -25,18 +26,18 @@ def setUp(self): self.cell.read_recordings(protocol_data=[file_metadata], protocol_name="IDRest") self.cell.extract_efeatures( - protocol_name="IDRest", efeatures=["Spikecount", "AP1_amp"] + protocol_name=self.protocol_name, efeatures=["Spikecount", "AP1_amp"] ) def test_efeature_extraction(self): - recording = self.cell.recordings[0] + recording = self.cell.recordings[self.protocol_name][0] self.assertEqual(2, len(recording.efeatures)) self.assertEqual(recording.efeatures["Spikecount"], 9.0) self.assertLess(abs(recording.efeatures["AP1_amp"] - 66.4), 2.0) def test_amp_threshold(self): - recording = self.cell.recordings[0] - compute_rheobase_absolute(self.cell, ["IDRest"]) + recording = self.cell.recordings[self.protocol_name][0] + compute_rheobase_absolute(self.cell, [self.protocol_name]) self.cell.compute_relative_amp() self.assertEqual(recording.amp, self.cell.rheobase) self.assertEqual(recording.amp_rel, 100.0) diff --git a/tests/test_efel_settings.py b/tests/test_efel_settings.py index 045407f..59959aa 100644 --- a/tests/test_efel_settings.py +++ b/tests/test_efel_settings.py @@ -28,7 +28,7 @@ def setUp(self): def test_efel_threshold(self): - self.cell.recordings[0].efeatures = {} + self.cell.recordings["IDRest"][0].efeatures = {} self.cell.extract_efeatures( protocol_name="IDRest", @@ -36,13 +36,13 @@ def test_efel_threshold(self): efel_settings={'Threshold': 40.} ) - recording = self.cell.recordings[0] + recording = self.cell.recordings["IDRest"][0] self.assertEqual(recording.efeatures["Spikecount"], 0.) self.assertLess(abs(recording.efeatures["AP1_amp"] - 66.68), 0.01) def test_efel_strictstim(self): - self.cell.recordings[0].efeatures = {} + self.cell.recordings["IDRest"][0].efeatures = {} self.cell.extract_efeatures( protocol_name="IDRest", @@ -54,11 +54,11 @@ def test_efel_strictstim(self): } ) - self.assertEqual(self.cell.recordings[0].efeatures["Spikecount"], 0.) + self.assertEqual(self.cell.recordings["IDRest"][0].efeatures["Spikecount"], 0.) def test_efel_threshold(self): - self.cell.recordings[0].efeatures = {} + self.cell.recordings["IDRest"][0].efeatures = {} self.cell.extract_efeatures( protocol_name="IDRest", @@ -66,7 +66,7 @@ def test_efel_threshold(self): efel_settings={'Threshold': 40.} ) - recording = self.cell.recordings[0] + recording = self.cell.recordings["IDRest"][0] self.assertEqual(recording.efeatures["Spikecount"], 0.)