Skip to content

Commit

Permalink
Add new slicing options to trim method
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelJanas committed Nov 8, 2023
1 parent 108c18f commit 1f8dfe0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 19 deletions.
55 changes: 36 additions & 19 deletions fortepyan/midi/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,47 @@ def time_shift(self, shift_s: float):
self.df.start += shift_s
self.df.end += shift_s

def trim(self, start: float, finish: float, shift_time: bool = True) -> "MidiPiece":
"""Trim the MidiPiece object between the specified start and finish time.
This function takes two parameters, `start` and `finish`, which represent the start and end time in seconds,
and returns a new MidiPiece object that contains only the notes that start within the specified time range.
def trim(self, start: float, finish: float, shift_time: bool = True, slice_type: str = "standard") -> "MidiPiece":
"""
Trim the MidiPiece object based on a specified slicing type.
Args:
- start (float): start time in seconds
- finish (float): end time in seconds
- start (float): Depending on `slice_type`, this is either the start time in seconds or the start index.
- finish (float): Depending on `slice_type`, this is either the end time in seconds or the end index.
- shift_time (bool, optional): If True, the trimmed piece's start time will be shifted to 0. Defaults to True.
- slice_type (str, optional): Determines the slicing method ('standard', 'by_end', 'index'). Defaults to "standard".
- "standard": Slices the MidiPiece to include notes that start within the [start, finish] time range.
- "by_end": Slices the MidiPiece to include notes where the end time is within the [start, finish] time range.
- "index": Slices the MidiPiece by note indices, where start and finish must be integer indices.
Returns:
- MidiPiece: the trimmed MidiPiece object
- MidiPiece: The trimmed MidiPiece object.
"""
# Filter the rows in the data frame that are within the specified start and end time
ids = (self.df.start >= start) & (self.df.start <= finish)
# Get the indices of the rows that meet the criteria
idxs = np.where(ids)[0]

# Get the start and end indices for the new MidiPiece object
start = idxs[0]
finish = idxs[-1] + 1
# Create a slice object to pass to __getitem__
slice_obj = slice(start, finish)
# Slice the original MidiPiece object to create the trimmed MidiPiece object
if slice_type == "index":
if not isinstance(start, int) or not isinstance(finish, int):
raise ValueError("Using 'index' slice_type requires 'start' and 'finish' to be integers.")
if start < 0 or finish >= self.size:
raise IndexError("Index out of bounds.")
if start > finish:
raise ValueError("'start' must be smaller than 'finish'.")
start_idx = start
finish_idx = finish + 1
else:
if slice_type == "by_end":
ids = (self.df.start >= start) & (self.df.end <= finish)
elif slice_type == "standard": # Standard slice type
ids = (self.df.start >= start) & (self.df.start <= finish)
else:
# not implemented
raise NotImplementedError(f"Slice type '{slice_type}' is not implemented.")
idx = np.where(ids)[0]
if len(idx) == 0:
raise IndexError("No notes found in the specified range.")
start_idx = idx[0]
finish_idx = idx[-1] + 1

slice_obj = slice(start_idx, finish_idx)

out = self.__getitem__(slice_obj, shift_time)

return out
Expand Down
22 changes: 22 additions & 0 deletions tests/midi/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@ def test_trim_within_bounds_with_shift(sample_midi_piece):
assert trimmed_piece.df["end"].iloc[-1] == 2, "New last note should end at 2 seconds."


def test_trim_index_slice_type(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(1, 3, slice_type="index")
assert len(trimmed_piece) == 3, "Trimmed MidiPiece should contain 3 notes."
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62."
assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 3 seconds."


def test_trim_by_end_slice_type(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(1, 5, slice_type="by_end")
assert len(trimmed_piece.df) == 3, "Trimmed MidiPiece should contain 3 notes."
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
assert trimmed_piece.df["pitch"].iloc[0] == 62, "New first note should have pitch 62."
assert trimmed_piece.df["end"].iloc[-1] == 3, "New last note should end at 2 seconds."
assert trimmed_piece.df["pitch"].iloc[-1] == 65, "New last note should have pitch 65."


def test_trim_with_invalid_slice_type(sample_midi_piece):
with pytest.raises(NotImplementedError):
_ = sample_midi_piece.trim(1, 3, slice_type="invalid") # Invalid slice type, should raise an error


def test_trim_within_bounds_no_shift(sample_midi_piece):
# This test should not shift the start times
trimmed_piece = sample_midi_piece.trim(2, 3, shift_time=False)
Expand Down

0 comments on commit 1f8dfe0

Please sign in to comment.