From 1f8dfe0a52e9b2c57da1f2d5a1212a244f5e4519 Mon Sep 17 00:00:00 2001 From: Samuel Janas Date: Wed, 8 Nov 2023 19:23:48 +0100 Subject: [PATCH] Add new slicing options to trim method --- fortepyan/midi/structures.py | 55 +++++++++++++++++++++++------------ tests/midi/test_structures.py | 22 ++++++++++++++ 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/fortepyan/midi/structures.py b/fortepyan/midi/structures.py index f1e4dbf..d13feec 100644 --- a/fortepyan/midi/structures.py +++ b/fortepyan/midi/structures.py @@ -57,30 +57,47 @@ def time_shift(self, shift_s: float): self.df.start += shift_s self.df.end += shift_s - def trim(self, start: float, finish: float, shift_time: bool = True) -> "MidiPiece": - """Trim the MidiPiece object between the specified start and finish time. - - This function takes two parameters, `start` and `finish`, which represent the start and end time in seconds, - and returns a new MidiPiece object that contains only the notes that start within the specified time range. + def trim(self, start: float, finish: float, shift_time: bool = True, slice_type: str = "standard") -> "MidiPiece": + """ + Trim the MidiPiece object based on a specified slicing type. Args: - - start (float): start time in seconds - - finish (float): end time in seconds + - start (float): Depending on `slice_type`, this is either the start time in seconds or the start index. + - finish (float): Depending on `slice_type`, this is either the end time in seconds or the end index. + - shift_time (bool, optional): If True, the trimmed piece's start time will be shifted to 0. Defaults to True. + - slice_type (str, optional): Determines the slicing method ('standard', 'by_end', 'index'). Defaults to "standard". + - "standard": Slices the MidiPiece to include notes that start within the [start, finish] time range. + - "by_end": Slices the MidiPiece to include notes where the end time is within the [start, finish] time range. + - "index": Slices the MidiPiece by note indices, where start and finish must be integer indices. Returns: - - MidiPiece: the trimmed MidiPiece object + - MidiPiece: The trimmed MidiPiece object. """ - # Filter the rows in the data frame that are within the specified start and end time - ids = (self.df.start >= start) & (self.df.start <= finish) - # Get the indices of the rows that meet the criteria - idxs = np.where(ids)[0] - - # Get the start and end indices for the new MidiPiece object - start = idxs[0] - finish = idxs[-1] + 1 - # Create a slice object to pass to __getitem__ - slice_obj = slice(start, finish) - # Slice the original MidiPiece object to create the trimmed MidiPiece object + if slice_type == "index": + if not isinstance(start, int) or not isinstance(finish, int): + raise ValueError("Using 'index' slice_type requires 'start' and 'finish' to be integers.") + if start < 0 or finish >= self.size: + raise IndexError("Index out of bounds.") + if start > finish: + raise ValueError("'start' must be smaller than 'finish'.") + start_idx = start + finish_idx = finish + 1 + else: + if slice_type == "by_end": + ids = (self.df.start >= start) & (self.df.end <= finish) + elif slice_type == "standard": # Standard slice type + ids = (self.df.start >= start) & (self.df.start <= finish) + else: + # not implemented + raise NotImplementedError(f"Slice type '{slice_type}' is not implemented.") + idx = np.where(ids)[0] + if len(idx) == 0: + raise IndexError("No notes found in the specified range.") + start_idx = idx[0] + finish_idx = idx[-1] + 1 + + slice_obj = slice(start_idx, finish_idx) + out = self.__getitem__(slice_obj, shift_time) return out diff --git a/tests/midi/test_structures.py b/tests/midi/test_structures.py index 1ff9fdd..34ee5fc 100644 --- a/tests/midi/test_structures.py +++ b/tests/midi/test_structures.py @@ -84,6 +84,28 @@ def test_trim_within_bounds_with_shift(sample_midi_piece): assert trimmed_piece.df["end"].iloc[-1] == 2, "New last note should end at 2 seconds." +def test_trim_index_slice_type(sample_midi_piece): + trimmed_piece = sample_midi_piece.trim(1, 3, slice_type="index") + assert len(trimmed_piece) == 3, "Trimmed MidiPiece should contain 3 notes." + assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds." + assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62." + assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 3 seconds." + + +def test_trim_by_end_slice_type(sample_midi_piece): + trimmed_piece = sample_midi_piece.trim(1, 5, slice_type="by_end") + assert len(trimmed_piece.df) == 3, "Trimmed MidiPiece should contain 3 notes." + assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds." + assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62." + assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 2 seconds." + assert trimmed_piece.df["pitch"].iloc[-1] == 65, "New last note should have pitch 65." + + +def test_trim_with_invalid_slice_type(sample_midi_piece): + with pytest.raises(NotImplementedError): + _ = sample_midi_piece.trim(1, 3, slice_type="invalid") # Invalid slice type, should raise an error + + def test_trim_within_bounds_no_shift(sample_midi_piece): # This test should not shift the start times trimmed_piece = sample_midi_piece.trim(2, 3, shift_time=False)