From 108c18f834b9130ca91c8bf886a286ebb7bd56e2 Mon Sep 17 00:00:00 2001 From: Samuel Janas Date: Wed, 8 Nov 2023 18:47:43 +0100 Subject: [PATCH] Add optional parameter to trim method for shifting start times --- fortepyan/midi/structures.py | 27 ++++++++++++++++++--------- tests/midi/test_structures.py | 12 +++++++++++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/fortepyan/midi/structures.py b/fortepyan/midi/structures.py index d671151..f1e4dbf 100644 --- a/fortepyan/midi/structures.py +++ b/fortepyan/midi/structures.py @@ -57,7 +57,7 @@ def time_shift(self, shift_s: float): self.df.start += shift_s self.df.end += shift_s - def trim(self, start: float, finish: float) -> "MidiPiece": + def trim(self, start: float, finish: float, shift_time: bool = True) -> "MidiPiece": """Trim the MidiPiece object between the specified start and finish time. This function takes two parameters, `start` and `finish`, which represent the start and end time in seconds, @@ -78,9 +78,11 @@ def trim(self, start: float, finish: float) -> "MidiPiece": # Get the start and end indices for the new MidiPiece object start = idxs[0] finish = idxs[-1] + 1 + # Create a slice object to pass to __getitem__ + slice_obj = slice(start, finish) # Slice the original MidiPiece object to create the trimmed MidiPiece object - out = self[start:finish] - # Return the trimmed MidiPiece object + out = self.__getitem__(slice_obj, shift_time) + return out def __sanitize_get_index(self, index: slice) -> slice: @@ -97,19 +99,26 @@ def __sanitize_get_index(self, index: slice) -> slice: return index - def __getitem__(self, index: slice) -> "MidiPiece": + def __getitem__(self, index: slice, shift_time: bool = True) -> "MidiPiece": index = self.__sanitize_get_index(index) part = self.df[index].reset_index(drop=True) - first_sound = part.start.min() - part.start -= first_sound - part.end -= first_sound + if shift_time: + # Shift the start and end times so that the first note starts at 0 + first_sound = part.start.min() + part.start -= first_sound + part.end -= first_sound + # Adjust the source to reflect the new start time + start_time_adjustment = first_sound + else: + # No adjustment to the start time + start_time_adjustment = 0 - # Make sure the piece can always be track back to the original file exactly + # Make sure the piece can always be tracked back to the original file exactly out_source = dict(self.source) out_source["start"] = self.source.get("start", 0) + index.start out_source["finish"] = self.source.get("start", 0) + index.stop - out_source["start_time"] = self.source.get("start_time", 0) + first_sound + out_source["start_time"] = self.source.get("start_time", 0) + start_time_adjustment out = MidiPiece(df=part, source=out_source) return out diff --git a/tests/midi/test_structures.py b/tests/midi/test_structures.py index 801f80f..1ff9fdd 100644 --- a/tests/midi/test_structures.py +++ b/tests/midi/test_structures.py @@ -73,7 +73,7 @@ def test_midi_piece_duration_calculation(sample_df): assert piece.duration == 5.5 -def test_trim_within_bounds(sample_midi_piece): +def test_trim_within_bounds_with_shift(sample_midi_piece): # Test currently works as in the original code. # We might want to change this behavior so that # we do not treat the trimed piece as a new piece @@ -84,6 +84,16 @@ def test_trim_within_bounds(sample_midi_piece): assert trimmed_piece.df["end"].iloc[-1] == 2, "New last note should end at 2 seconds." +def test_trim_within_bounds_no_shift(sample_midi_piece): + # This test should not shift the start times + trimmed_piece = sample_midi_piece.trim(2, 3, shift_time=False) + assert len(trimmed_piece.df) == 2, "Trimmed MidiPiece should contain 2 notes." + # Since we're not shifting, the start should not be 0 but the actual start time + assert trimmed_piece.df["start"].iloc[0] == 2, "First note should retain its original start time." + assert trimmed_piece.df["pitch"].iloc[0] == 64, "First note should have pitch 64." + assert trimmed_piece.df["end"].iloc[-1] == 4, "Last note should end at 4 seconds." + + def test_trim_at_boundaries(sample_midi_piece): trimmed_piece = sample_midi_piece.trim(0, 5.5) assert trimmed_piece.size == sample_midi_piece.size, "Trimming at boundaries should not change the size."