# Copyright 2024 The YourMT3 Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Please see the details in the LICENSE file. """ note2event.py Note tools: • mix_notes(notes_to_mix, sort, trim_overlap, fix_offset) -> List[Note] • validate_notes(notes, fix) -> List[Note] • trim_overlapping_notes(notes, sort) -> List[Note] • sort_notes(notes) -> List[Note] • notes2pc_notes(notes, note_offs) -> List[Note] • extract_program_from_notes(notes) -> Set[int] • extract_notes_selected_by_programs(notes, programs, sort) -> List[Note] Note to NoteEvent • note2note_event(notes, sort, return_activity) -> List[NoteEvent] NoteEvent tools: • slice_note_events_and_ties(note_events, start_time, end_time, tidyup) -> Tuple[List[NoteEvent], List[NoteEvent], int]) • slice_multiple_note_events_and_ties_to_bundle(note_events, start_times, duration_sec, tidyup) -> List[List[NoteEvent], List[NoteEvent], int]] # Note implmented yet.. • mix_note_event_lists_bundle(note_events_to_mix, sort, start_time_to_zero) -> NoteEventListsBundle • pitch_shift_note_events(note_events, semitone, use_deepcopy) -> List[NoteEvent] • separate_by_subunit_programs_from_note_event_lists_bundle( source_note_event_lists_bundle, subunit_programs) -> NoteEventListsBundle: • separate_channel_by_program_group_from_note_event_lists_bundle( source_note_event_lists_bundle, num_program_groups, program2channel_vocab) -> List[NoteEventListsBundle]: NoteEvent to Event: • note_event2event(note_events, tie_note_events, start_time, tps, sort) -> List[Event] Event tools: • check_event_len_from_bundle(note_events_dic_a, note_events_dic_b, max_len, fast_check) -> bool """ import warnings from copy import deepcopy from itertools import chain from typing import Optional, Tuple, Union, List, Set, Dict, Any import numpy as np from utils.note_event_dataclasses import Note, NoteEvent, NoteEventListsBundle from utils.note_event_dataclasses import Event DRUM_OFFSET_TIME = 0.01 # in seconds MINIMUM_OFFSET_TIME = 0.01 # this is used to avoid zero-length notes DRUM_PROGRAM = 128 def mix_notes(notes_to_mix: Tuple[List[Note]], sort: bool = True, trim_overlap: bool = True, fix_offset: bool = True) -> List[Note]: """ mix_notes: Mixes a tuple of many lists of Note instances into a single list of Note instances. This processes 'notes1 + notes2 + ... + notesN' faster. Because Note instances use absolute timing, the Note instances in the same timiming will be sorted by increasing order of program and pitch. Args: - notes_to_mix (tuple[list[Note]]): A tuple of lists of Note instances. - sort (bool): If True, sort the Note instances by increasing order of onsets, and at the same timing, by increasing order of program and pitch. Default is True. Returns: - notes (list[Note]): A list of Note instances. """ mixed_notes = list(chain(*notes_to_mix)) if sort and len(mixed_notes) > 0: mixed_notes.sort( key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch, note.offset)) # Trim overlapping notes if trim_overlap: mixed_notes = trim_overlapping_notes(mixed_notes, sort=sort) # fix offset >= onset the Note instances if fix_offset: mixed_notes = validate_notes(mixed_notes, fix=True) return mixed_notes def validate_notes(notes: Tuple[List[Note]], minimum_offset: Optional[bool] = 0.01, fix: bool = True) -> List[Note]: """ validate and fix unrealistic notes """ if len(notes) > 0: for note in list(notes): if note.onset == None: if fix: notes.remove(note) continue elif note.offset == None: if fix: note.offset = note.onset + MINIMUM_OFFSET_TIME elif note.onset > note.offset: warnings.warn(f'📙 Note at {note} has onset > offset.') if fix: note.offset = max(note.offset, note.onset + MINIMUM_OFFSET_TIME) print(f'✅\033[92m Fixed! Setting offset to onset + {MINIMUM_OFFSET_TIME}.\033[0m') elif note.is_drum is False and note.offset - note.onset < 0.01: # fix 13 Oct: too short notes issue for the dataset with non-MIDI annotations # warnings.warn(f'📙 Note at {note} has offset - onset < 0.01.') if fix: note.offset = note.onset + MINIMUM_OFFSET_TIME # print(f'✅\033[92m Fixed! Setting offset to onset + {MINIMUM_OFFSET_TIME}.\033[0m') return notes def trim_overlapping_notes(notes: List[Note], sort: bool = True) -> List[Note]: """ Trim overlapping notes and dropping zero-length notes. https://github.com/magenta/mt3/blob/3deffa260ba7de3cf03cda1ea513a4d7ba7144ca/mt3/note_sequences.py#L52 Trimming was only applied to train set, not test set in MT3. """ if len(notes) <= 1: return notes trimmed_notes = [] channels = set((note.pitch, note.program, note.is_drum) for note in notes) for pitch, program, is_drum in channels: channel_notes = [ note for note in notes if note.pitch == pitch and note.program == program and note.is_drum == is_drum ] sorted_notes = sorted(channel_notes, key=lambda note: note.onset) for i in range(1, len(sorted_notes)): if sorted_notes[i - 1].offset > sorted_notes[i].onset: sorted_notes[i - 1].offset = sorted_notes[i].onset # Filter out zero-length notes valid_notes = [note for note in sorted_notes if note.onset < note.offset] trimmed_notes.extend(valid_notes) if sort: trimmed_notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch)) return trimmed_notes def sort_notes(notes: List[Note]) -> List[Note]: """ Sort notes by increasing order of onsets, and at the same timing, by increasing order of program and pitch. """ if len(notes) > 0: notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch, note.offset)) return notes def notes2pc_notes(notes: List[Note], note_offset: int = 64) -> List[Note]: """ Convert a list of Note instances to a list of Pitch Class Set (PCS) instances. This method is implemented for octave-ignore evaluation cases. """ pc_notes = deepcopy(notes) for note in pc_notes: note.pitch = note.pitch % 12 + note_offset return pc_notes def extract_program_from_notes(notes: List[Note]) -> Set[int]: """ Extract program numbers from a list of Note instances.""" prg = set() for note in notes: if note.program not in prg: prg.add(note.program) return prg def extract_notes_selected_by_programs(notes: List[Note], programs: Set[int], sort: bool = True) -> List[Note]: """ Extract notes selected by program numbers from a list of Note instances.""" selected_notes = [] for note in notes: if note.program in programs: selected_notes.append(note) if sort: selected_notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch)) return selected_notes """ NoteEvent data class: Combines NoteEvent and NoteActivity for onset and offset events during Note to Event conversion. Features: Trackable: follow note activity by index Sliceable: extract time ranges; time is absolute Mergeable: combine two NoteEvent instances (re-index needed) Mutable: mute events by program number, pitch Transferable: easily convert to Note or Event tokens """ def note2note_event(notes: List[Note], sort: bool = True, return_activity: bool = True) -> List[NoteEvent]: """ note2note_event: Converts a list of Note instances to a list of NoteEvent instances. Args: - notes (List[Note]): A list of Note instances. - sort (bool): Sort the NoteEvent instances by increasing order of onsets, and at the same timing, by increasing order of program and pitch. Default is True. If return_activity is set to True, NoteEvent instances are sorted regardless of this argument. - return_activity (bool): If True, return a list of NoteActivity instances Returns: - note_events (List[NoteEvent]): A list of NoteEvent instances. """ note_events = [] for note in notes: # for each note, add onset and offset events note_events.append(NoteEvent(note.is_drum, note.program, note.onset, note.velocity, note.pitch)) if note.is_drum == 0: # (drum has no offset!) note_events.append(NoteEvent(note.is_drum, note.program, note.offset, 0, note.pitch)) if sort or return_activity: note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) if return_activity: # activity stores the indices of previous notes that are still active activity = set() # mutable class for i, ne in enumerate(note_events): # set a copy of the activity set ti the current note event ne.activity = activity.copy() if ne.is_drum: continue # drum's offset and activity are not tracked elif ne.velocity == 1: activity.add(i) elif ne.velocity == 0: # search for the index of matching onset event matched_onset_event_index = None for j in activity: if note_events[j].equals_only(ne, 'is_drum', 'program', 'pitch'): matched_onset_event_index = j break if matched_onset_event_index is not None: activity.remove(matched_onset_event_index) else: raise ValueError(f'📕 note2note_event: no matching onset event for {ne}') else: raise ValueError(f'📕 Invalid velocity: {ne.velocity} expected 0 or 1') if len(activity) > 0: # if there are still active notes at the end of the sequence warnings.warn(f'📙 note2note_event: {len(activity)} notes are still \ active at the end of the sequence. Please validate \ the input Note instances. ') return note_events def slice_note_events_and_ties(note_events: List[NoteEvent], start_time: float, end_time: float, tidyup: bool = False) -> Tuple[List[NoteEvent], List[NoteEvent], int]: """ Extracts a specific subsequence of note events and tie note events for the first note event in the subsequence. Args: - note_events (List[NoteEvent]): List of NoteEvent instances. - start_time (float): The start time of the subsequence in seconds. - end_time (float): The end time of the subsequence in seconds. - tidyup (Optional[bool]): If True, sort the resulting lists of NoteEvents, and remove the activity attribute of sliced_note_event, and remove the time and activity attributes of tie_note_events. Default is False. Avoid using tidyup=True without deepcopying the original note_events. Note: - The activity attribute of returned sliced_note_events, and the time and activity attributes of tie_note_events are not valid after slicing. Thus, they should be ignored in the downstream processing. Returns: - sliced_note_events (List[NoteEvent]): List of NoteEvent instances in the specified range. - tie_note_events (List[NoteEvent]): List of NoteEvent instances that are active (tie) at start_time. - start_time (float): Just bypass the start time from the input argument. """ if start_time > end_time: raise ValueError(f'📕 slice_note_events: start_time {start_time} \ is greater than end_time {end_time}') elif len(note_events) == 0: warnings.warn('📙 slice_note_events: empty note_events as input') return [], [], start_time # Get start_index and end_index start_index, end_index = None, None found_start = False for i, ne in enumerate(note_events): if not found_start and ne.time >= start_time and ne.time < end_time: start_index = i found_start = True if ne.time >= end_time: end_index = i break # Get tie_note_events if start_index == None: if end_index == 0: tie_note_events = [] elif end_index == None: tie_note_events = [] else: tie_note_events = [note_events[i] for i in note_events[end_index].activity] else: tie_note_events = [note_events[i] for i in note_events[start_index].activity] """ modifying note events here is dangerous, due to mutability of original note_events!! """ if tidyup: for tne in tie_note_events: tne.time = None tne.activity = None tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) # Get sliced note_events if start_index is None: sliced_note_events = [] else: sliced_note_events = note_events[start_index:end_index] if tidyup: for sne in sliced_note_events: sne.activity = None sliced_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) return sliced_note_events, tie_note_events, start_time """ class NoteEventListsBundle(TypedDict): note_events: List[List[NoteEvent]] tie_note_events: List[List[NoteEvent]] start_time: List[int] """ def slice_multiple_note_events_and_ties_to_bundle(note_events: List[NoteEvent], start_times: List[float], duration_sec: float, tidyup: bool = False) -> NoteEventListsBundle: """ Extracts N subsequence of note events and tie-note events by taking a list of N start_time and a list of N end_time. """ sliced_note_events_list = [] sliced_tie_note_events_list = [] for start_time in start_times: end_time = start_time + duration_sec sliced_note_events, tie_note_events, _ = slice_note_events_and_ties(note_events, start_time, end_time, tidyup) sliced_note_events_list.append(sliced_note_events) sliced_tie_note_events_list.append(tie_note_events) return NoteEventListsBundle({ 'note_events': sliced_note_events_list, 'tie_note_events': sliced_tie_note_events_list, 'start_times': start_times }) def mix_note_event_lists_bundle( note_event_lists_bundle_to_mix: NoteEventListsBundle, sort: bool = True, start_time_to_zero: bool = True, use_deepcopy: bool = False, ) -> NoteEventListsBundle: """ Mixes a tuple of many lists of NoteEvent instances into a single list of NoteEvent instances. This processes 'note_events1 + note_events2 + ... + note_eventsN'. Because each NoteEvent list instance may have different start time, it is recommended to set start_time_to_zero to True. Known issue: - Solution for overlapping note_events is not implemented yet. - Currently, it is assumed that programs have no overlap among note_events_to_mix. - For faster processing, use_deepcopy is set to False by default. Args: - note_events_bundle_to_mix (NoteEventListsBundle): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). See NoteEventListsBundle in utils/note_event_dataclasses.py for more details. - sort (bool): If True, sort the NoteEvent instances by increasing order of onsets, and at the same timing, by increasing order of program and pitch. Default is True. - start_time_to_zero (bool): If True, set the start time of each list of NoteEvents to 0. Default is True. - use_deepcopy (bool): If True, use deepcopy() to avoid modifying the original NoteEvent Returns: - mixed_note_events_dic (NoteEventListsBundle): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). """ if use_deepcopy is True: note_events_to_mix = deepcopy(note_event_lists_bundle_to_mix["note_events"]) tie_note_events_to_mix = deepcopy(note_event_lists_bundle_to_mix["tie_note_events"]) else: note_events_to_mix = note_event_lists_bundle_to_mix["note_events"] tie_note_events_to_mix = note_event_lists_bundle_to_mix["tie_note_events"] start_times = note_event_lists_bundle_to_mix["start_times"] # Reset start time to zero if start_time_to_zero is True: for note_events, tie_note_events, start_time in zip(note_events_to_mix, tie_note_events_to_mix, start_times): for ne in note_events: ne.time -= start_time assert ne.time >= 0, f'📕 mix_note_events: negative time {ne.time}' """modifying tie note events here is dangerous, due to mutability of linked note_events""" # for tne in tie_note_events: # tne.time = None # tne.activity = None # Mix mixed_note_events = list(chain(*note_events_to_mix)) mixed_tie_note_events = list(chain(*tie_note_events_to_mix)) # Sort if sort is True: mixed_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) mixed_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) mixed_note_events_dic = NoteEventListsBundle({ 'note_events': [mixed_note_events], 'tie_note_events': [mixed_tie_note_events], 'start_times': [0.] }) return mixed_note_events_dic def pitch_shift_note_events(note_events: List[NoteEvent], semitone: int, use_deepcopy: bool = False) -> List[NoteEvent]: """ Apply pitch shift to NoteEvent instances: Args: - note_events (List[NoteEvent]): A list of NoteEvent instances. Typically 'note_events' or 'tie_note_events' can be an input. - semitone (int): The number of semitones to shift. Positive value shifts up, negative value - use_deepcopy (bool): If True, use deepcopy() to avoid modifying the original NoteEvent Returns: - note_events (List[NoteEvent]): A list of NoteEvent instances with pitch shifted. Drums are excluded from pitch shift processing. """ if semitone == 0: return note_events if use_deepcopy is True: note_events = deepcopy(note_events) for ne in note_events: if ne.is_drum is False: new_pitch = ne.pitch + semitone if new_pitch >= 0 and new_pitch < 128: ne.pitch = new_pitch return note_events def separate_by_subunit_programs_from_note_event_lists_bundle(source_note_event_lists_bundle: NoteEventListsBundle, subunit_programs: List[List[int]], start_time_to_zero: bool = True, sort: bool = True) -> NoteEventListsBundle: src_note_events = source_note_event_lists_bundle['note_events'] src_tie_note_events = source_note_event_lists_bundle['tie_note_events'] src_start_times = source_note_event_lists_bundle['start_times'] # Reset start time to zero if start_time_to_zero is True and not all(t == 0. for t in src_start_times): for nes, tnes, start_time in zip(src_note_events, src_tie_note_events, src_start_times): for ne in nes: ne.time -= start_time assert ne.time >= 0, f'📕 mix_note_events: negative time {ne.time}' for tne in tnes: tne.time = None tne.activity = None src_start_times = [0. for i in range(len(src_start_times))] num_subunits = len(subunit_programs) result_note_events = [[] for _ in range(num_subunits)] result_tie_note_events = [[] for _ in range(num_subunits)] result_start_times = [0. for _ in range(num_subunits)] # Convert subunit_programs to list of sets for faster lookups subunit_program_sets = [set(sp) for sp in subunit_programs] for nes, tnes in zip(src_note_events, src_tie_note_events): for ne in nes: if ne.is_drum: target_indices = [i for i, sp_set in enumerate(subunit_program_sets) if DRUM_PROGRAM in sp_set] else: target_indices = [i for i, sp_set in enumerate(subunit_program_sets) if ne.program in sp_set] for i in target_indices: result_note_events[i].append(ne) for tne in tnes: target_indices = [i for i, sp_set in enumerate(subunit_program_sets) if tne.program in sp_set] for i in target_indices: result_tie_note_events[i].append(tne) # Sort if sort is True: for nes, tnes in zip(result_note_events, result_tie_note_events): nes.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) tnes.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) return { 'note_events': result_note_events, # List[List[NoteEvent]] 'tie_note_events': result_tie_note_events, # List[List[NoteEvent]] 'start_times': result_start_times, # List[float] } def separate_channel_by_program_group_from_note_event_lists_bundle(source_note_event_lists_bundle: NoteEventListsBundle, num_program_groups: int, program2channel_vocab: Dict[int, Dict[str, Any]], start_time_to_zero: bool = False, sort: bool = True) -> List[NoteEventListsBundle]: """ Args: - source_note_event_lists_bundle (NoteEventListsBundle): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). See NoteEventListsBundle in utils/note_event_dataclasses.py for more details. - num_program_groups (int): The number of program groups to separate. Typically this is the length of program_vocab + 1 (for drums). - program2channel_vocab (Dict[int, Dict[str, Union[List[int], np.ndarray]]]): A dictionary with keys (program, channel, instrument_group, primary_program). See program2channel_vocab in utils/utils.py, create_program2channel_vocab() for more details. example: program2channel_vocab[program_int] = { "channel": (int), "instrument_group": (str), "primary_program": (int), } - start_time_to_zero (bool): If True, set the start time of each list of NoteEvents to 0. Default is False. - sort (bool): If True, sort the NoteEvent instances by increasing order of onsets, and at the same timing, by increasing order of program and pitch. Default is True. Returns: - result_list_bundle List[NoteEventListsBundle]: A list of NoteEventListsBundle instances with length of batch_sz. NoteEventListsBundle is a dictionary with keys ('note_events', 'tie_note_events', 'start_time'). See NoteEventListsBundle in utils/note_event_dataclasses.py for more details. """ src_note_events = source_note_event_lists_bundle['note_events'] src_tie_note_events = source_note_event_lists_bundle['tie_note_events'] src_start_times = source_note_event_lists_bundle['start_times'] # Reset start time to zero if start_time_to_zero is True and not all(t == 0. for t in src_start_times): for nes, tnes, start_time in zip(src_note_events, src_tie_note_events, src_start_times): """modifying time of note events is only for mixing events within training. test set should keep the original time""" for ne in nes: ne.time -= start_time assert ne.time >= 0, f'📕 mix_note_events: negative time {ne.time}' """modifying tie note events here is dangerous, due to mutability of linked note_events""" # for tne in tnes: # tne.time = None # tne.activity = None src_start_times = [0. for i in range(len(src_start_times))] batch_sz = len(src_note_events) result_list_bundle = [{ "note_events": [[] for _ in range(num_program_groups)], "tie_note_events": [[] for _ in range(num_program_groups)], "start_times": [src_start_times[b] for _ in range(num_program_groups)], } for b in range(batch_sz)] """ Example of program2channel_vocab { 0: {'channel': 0, 'instrument_group': 'Piano', 'primary_program': 0}, 1: {'channel': 1, 'instrument_group': 'Chromatic Percussion', 'primary_program': 8}, ... 100: {'channel': 11, 'instrument_group': 'Singing Voice', 'primary_program': 100}, 128: {'channel': 12, 'instrument_group': 'Drums', 'primary_program': 128} } """ # Separate by program_vocab for b, (nes, tnes) in enumerate(zip(src_note_events, src_tie_note_events)): for ne in nes: program = DRUM_PROGRAM if ne.is_drum else ne.program mapping_info = program2channel_vocab.get(program, None) if mapping_info is not None: ch = mapping_info["channel"] result_list_bundle[b]["note_events"][ch].append(ne) else: # Temporary fix for program > 95, such as gunshot and FX. TODO: FX class pass for tne in tnes: mapping_info = program2channel_vocab.get(tne.program) if mapping_info is not None: ch = mapping_info["channel"] result_list_bundle[b]["tie_note_events"][ch].append(tne) else: # Temporary fix for program > 95, such as gunshot and FX. TODO: FX class pass # Sort if sort: for ch in range(num_program_groups): result_list_bundle[b]["note_events"][ch].sort( key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) result_list_bundle[b]["tie_note_events"][ch].sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) return result_list_bundle # List[NoteEventListsBundle] with length of batch_sz def note_event2event(note_events: List[NoteEvent], tie_note_events: Optional[List[NoteEvent]] = None, start_time: float = 0., tps: int = 100, sort: bool = True) -> List[Event]: """ note_event2event: Converts a list of NoteEvent instances to a list of Event instances. - NoteEvent instances have absolute time within a file, while Event instances have 'shift' events of absolute time within a segment. - Tie NoteEvent instances are prepended to output list of Event instances, and closed by a 'tie' event. - If start_time is not provided, start_time=0 in seconds by default. - If there is non-tie note_event instances before the start_time, raises an error. Args: - note_events (list[NoteEvent]): A list of NoteEvent instances. - tie_note_events (Optional[list[NoteEvent]]): A list of tie NoteEvent instances. See slice_note_events_and_ties() for more details. Default is None. - start_time (float): Start time in seconds. Default is 0. Any non-tie NoteEvent instances should have time >= start_time. - tps (Optional[int]): Ticks per second. Default is 100. - sort (bool): If True, sort the Event instances by increasing order of onsets, and at the same timing, by increasing order of program and pitch. Default is False. Returns: - events (list[Event]): A list of Event instances. """ if sort: if tie_note_events != None: tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) note_events.sort( key=lambda n_ev: (round(n_ev.time * tps), n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) # Initialize event list and state variables events = [] start_tick = round(start_time * tps) tick_state = start_tick program_state = None # Prepend tie events if tie_note_events: for tne in tie_note_events: if tne.program != program_state: events.append(Event(type='program', value=tne.program)) program_state = tne.program events.append(Event(type='pitch', value=tne.pitch)) # Any tie events (can be empty) are closed by a 'tie' event events.append(Event(type='tie', value=0)) # Translate NoteEvent to Event in the list velocity_state = None # reset state variables for ne in note_events: if ne.is_drum and ne.velocity == 0: # <-- bug fix continue # drum's offset should be ignored, and should not cause shift # Process time shift and update tick_state ne_tick = round(ne.time * tps) if ne_tick > tick_state: # shift_ticks = ne_tick - tick_state shift_ticks = ne_tick - start_tick events.append(Event(type='shift', value=shift_ticks)) tick_state = ne_tick elif ne_tick == tick_state: pass else: raise ValueError( f'NoteEvent tick_state {ne_tick} of time {ne.time} is smaller than tick_state {tick_state}.') # Process program change and update program_state if ne.is_drum and ne.velocity == 1: # drum events have no program and offset but velocity 1 if velocity_state != 1 or velocity_state == None: events.append(Event(type='velocity', value=1)) velocity_state = 1 events.append(Event(type='drum', value=ne.pitch)) else: if ne.program != program_state or program_state == None: events.append(Event(type='program', value=ne.program)) program_state = ne.program if ne.velocity != velocity_state or velocity_state == None: events.append(Event(type='velocity', value=ne.velocity)) velocity_state = ne.velocity events.append(Event(type='pitch', value=ne.pitch)) return events def check_event_len_from_bundle(note_events_dic_a: Dict, note_events_dic_b: Dict, max_len: int, fast_check: bool = True) -> bool: """ Check if the total length of events converted from note_events_dic exceeds the max length. This is used in cross augmentation. See augment.py for more the usage. Args: - note_events_dic_a (Dict): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). - note_events_dic_b (Dict): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). - max_len (int): Maximum length of events. - fast_check (bool): If True, check the total length of note_events only. Default is True. Returns: - bool: True (passed) or False (failed) """ if fast_check is True: ne_len_a = sum([len(ne) for ne in note_events_dic_a['note_events']]) ne_len_b = sum([len(ne) for ne in note_events_dic_b['note_events']]) total_note_events_len = ne_len_a + ne_len_b if fast_check is False or total_note_events_len >= max_len // 3: event_len_a = 0 for ne, tne, start_time in zip(note_events_dic_a['note_events'], note_events_dic_a['tie_note_events'], note_events_dic_a['start_times']): event_len_a += len(note_event2event(ne, tne, start_time)) event_len_b = 0 for ne, tne, start_time in zip(note_events_dic_b['note_events'], note_events_dic_b['tie_note_events'], note_events_dic_b['start_times']): event_len_b += len(note_event2event(ne, tne, start_time)) total_events_len = event_len_a + event_len_b if total_events_len >= max_len: return False # failed else: return True # passed