# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" tokenizer.py: Encodes and decodes events to/from tokens. """
import numpy as np
import warnings
from abc import ABC, abstractmethod
from utils.note_event_dataclasses import Event, EventRange, Note  #, Codec
from utils.event_codec import FastCodec as Codec
from utils.note_event_dataclasses import NoteEvent
from utils.note2event import note_event2event
from utils.event2note import event2note_event, note_event2note
from typing import List, Optional, Union, Tuple, Dict, Counter


# TODO: Too complex to be an abstract class.
class EventTokenizerBase(ABC):
    """
    A base class for encoding and decoding events to and from tokens.
    """

    def __init__(
        self,
        base_codec: Union[Codec, str] = 'mt3',
        special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
        extra_tokens: List[str] = [],
        max_shift_steps: int = 206,  # 1001 in Gardner et al.
        program_vocabulary: Optional[Dict] = None,
        drum_vocabulary: Optional[Dict] = None,
    ) -> None:
        """
        Initializes the EventTokenizerBase object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        """
        # Initialize the codec attribute based on the input codec parameter.
        if isinstance(base_codec, str):
            # If codec is a string, initialize codec with the appropriate Codec object.
            if base_codec.lower() == 'mt3':
                event_ranges = [
                    EventRange('pitch', min_value=0, max_value=127),
                    EventRange('velocity', min_value=0, max_value=1),
                    EventRange('tie', min_value=0, max_value=0),
                    EventRange('program', min_value=0, max_value=127),
                    EventRange('drum', min_value=0, max_value=127),
                ]
            else:
                raise ValueError(f'Unknown codec name: {base_codec}')

            # Initialize codec
            self.codec = Codec(special_tokens=special_tokens + extra_tokens,
                               max_shift_steps=max_shift_steps,
                               event_ranges=event_ranges,
                               program_vocabulary=program_vocabulary,
                               drum_vocabulary=drum_vocabulary,
                               name='mt3')
        elif isinstance(base_codec, Codec):
            # If codec is a Codec object, store it directly.
            self.codec = base_codec
            if program_vocabulary is not None or drum_vocabulary is not None:
                warnings.warn("Vocabulary cannot be applied when using a custom codec.")
        else:
            # If codec is neither a string nor a Codec object, raise a TypeError.
            raise TypeError(f'Unknown codec type: {type(base_codec)}')

        self.num_tokens = self.codec._num_classes
    def _encode(self, events: List[Event]) -> List[int]:
        return [self.codec.encode_event(e) for e in events]

    def _decode(self, tokens: List[int]) -> List[Event]:
        return [self.codec.decode_event_index(idx) for idx in tokens]

    def encode(self):
        """ Encode your custom events to tokens. """
        pass

    def decode(self):
        """ Decode your custom tokens to events. """
        pass

class EventTokenizer(EventTokenizerBase):
    """
    Encodes and decodes events to and from tokens.
    """

    def __init__(self,
                 base_codec: Union[Codec, str] = 'mt3',
                 special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
                 extra_tokens: List[str] = [],
                 max_shift_steps: int = 206,
                 program_vocabulary: Optional[Dict] = None,
                 drum_vocabulary: Optional[Dict] = None) -> None:
        """
        Initializes the EventTokenizer object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        """
        # Initialize the codec attribute based on the input codec parameter.
        super().__init__(
            base_codec=base_codec,
            special_tokens=special_tokens,
            extra_tokens=extra_tokens,
            max_shift_steps=max_shift_steps,
            program_vocabulary=program_vocabulary,
            drum_vocabulary=drum_vocabulary,
        )

    def encode(self, events):
        """ Encode your custom events to tokens. """
        return super()._encode(events)

    def decode(self, tokens):
        """ Decode your custom tokens to events. """
        return super()._decode(tokens)
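
# --- Usage sketch (illustrative only; not part of the original module) ---
# Assuming the Event dataclass from utils.note_event_dataclasses exposes `type`
# and `value` fields (as suggested by the MT3 event ranges above), a round trip
# through the default 'mt3' codec might look like:
#
#   tokenizer = EventTokenizer()  # default 'mt3' codec
#   events = [Event(type='program', value=33), Event(type='pitch', value=60)]
#   tokens = tokenizer.encode(events)   # List[int] indices into the vocabulary
#   decoded = tokenizer.decode(tokens)  # back to List[Event]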

class NoteEventTokenizer(EventTokenizerBase):
    """ Encodes and decodes note events to/from tokens. """

    def __init__(
            self,
            base_codec: Union[Codec, str] = 'mt3',
            max_length: int = 1024,  # max length of tokens
            tps: int = 100,
            sort_note_event: bool = True,
            special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
            extra_tokens: List[str] = [],
            max_shift_steps: int = 206,
            program_vocabulary: Optional[Dict] = None,
            drum_vocabulary: Optional[Dict] = None,
            ignore_decoding_tokens: List[str] = [],
            ignore_decoding_tokens_from_and_to: Optional[List[str]] = None,
            debug_mode: bool = False) -> None:
        """
        Initializes the NoteEventTokenizer object.

        encode: List[NoteEvent] -> List[int]
        decode: List[int] -> Tuple[List[NoteEvent], List[NoteEvent], ...]

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param ignore_decoding_tokens: List of tokens to ignore during decoding.
        :param ignore_decoding_tokens_from_and_to: None or a pair of tokens [from, to]; spans
            between them are ignored during decoding.
        """
        super().__init__(base_codec=base_codec,
                         special_tokens=special_tokens,
                         extra_tokens=extra_tokens,
                         max_shift_steps=max_shift_steps,
                         program_vocabulary=program_vocabulary,
                         drum_vocabulary=drum_vocabulary)
        self.max_length = max_length
        self.tps = tps
        self.sort = sort_note_event

        # Prepare prefix, suffix and pad tokens.
        self._prefix = []
        self._suffix = []
        for stk in self.codec.special_tokens:
            if stk == 'EOS':
                self._suffix.append(self.codec.special_tokens.index('EOS'))
            elif stk == 'PAD':
                # Pre-allocated pad buffer; rebuilt in encode_plus() if max_length differs.
                self._zero_pad = [self.codec.special_tokens.index('PAD')] * max_length
            elif stk == 'UNK':
                pass
            else:
                pass
                # raise NotImplementedError(f'Unknown special token: {stk}')

        self.eos_id = self.codec.special_tokens.index('EOS')
        self.pad_id = self.codec.special_tokens.index('PAD')
        # Note: each token in ignore_decoding_tokens must be a member of codec.special_tokens.
        self.ids_to_ignore_decoding = [self.codec.special_tokens.index(t) for t in ignore_decoding_tokens]
        self.ignore_tokens_from_and_to = ignore_decoding_tokens_from_and_to
        self.debug_mode = debug_mode

    def _decode(self, tokens):
        # This is an event detokenizer (not note_event); it is needed for displaying
        # events in the validation dashboard.
        return super()._decode(tokens)

    def encode(
        self,
        note_events: List[NoteEvent],
        tie_note_events: Optional[List[NoteEvent]] = None,
        start_time: float = 0.,
    ) -> List[int]:
        """ Encodes note events and tie note events to tokens. """
        events = note_event2event(
            note_events=note_events,
            tie_note_events=tie_note_events,
            start_time=start_time,  # required for calculating relative time
            tps=self.tps,
            sort=self.sort)
        return super()._encode(events)

    def encode_plus(
            self,
            note_events: List[NoteEvent],
            tie_note_events: Optional[List[NoteEvent]] = None,
            start_times: float = 0.,  # Fixing bug: start_time --> start_times
            add_special_tokens: Optional[bool] = True,
            max_length: Optional[int] = None,  # if None, use self.max_length
            pad_to_max_length: Optional[bool] = True,
            return_attention_mask: bool = False) -> Union[List[int], Tuple[List[int], List[int]]]:
        """ Encodes note events and tie note info to padded tokens. """
        encoded = self.encode(note_events, tie_note_events, start_times)
        # if task_events:
        #     encoded = super()._encode(task_events) + encoded
        if add_special_tokens:
            if self._prefix:
                encoded = self._prefix + encoded
            if self._suffix:
                encoded = encoded + self._suffix

        if max_length is None:
            max_length = self.max_length

        length = len(encoded)
        if length >= max_length:
            encoded = encoded[:max_length]
            length = max_length

        if return_attention_mask:
            attention_mask = [1] * length

        # <PAD>
        if pad_to_max_length is True:
            if len(self._zero_pad) != max_length:
                self._zero_pad = [self.pad_id] * max_length
            if return_attention_mask:
                attention_mask += [0] * (max_length - length)  # mask is 0 over pad positions
            encoded = encoded + self._zero_pad[length:]

        if return_attention_mask:
            return encoded, attention_mask
        return encoded
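
    # --- Usage sketch (illustrative only) ---
    # Assuming `note_events` is a List[NoteEvent] for one segment, a padded
    # encoding with an attention mask might look like:
    #
    #   tokenizer = NoteEventTokenizer(max_length=1024)
    #   tokens, mask = tokenizer.encode_plus(note_events, return_attention_mask=True)
    #   assert len(tokens) == len(mask) == 1024  # EOS appended, then padded with PAD ids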

    def encode_task(self, task_events: List[Event], max_length: Optional[int] = None) -> List[int]:
        # NOTE: This encodes task events to token ids, not a list of note_event objects.
        encoded = super()._encode(task_events)

        # <PAD>
        if max_length is not None:
            # Lazily initialize the task pad buffer; __init__ does not create it.
            if getattr(self, '_zero_pad_task', None) is None or len(self._zero_pad_task) != max_length:
                self._zero_pad_task = [self.pad_id] * max_length
            length = len(encoded)
            encoded = encoded + self._zero_pad_task[length:]
        return encoded
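
    # --- Usage sketch (illustrative only) ---
    # `encode_task` expects raw Event objects (e.g. task/conditioning tokens added
    # via `extra_tokens`), not NoteEvents. A hypothetical call, assuming a custom
    # 'task' event type exists in the codec:
    #
    #   task_ids = tokenizer.encode_task([Event(type='task', value=0)], max_length=8)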

    def decode(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False,
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], int],
               Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[Event], int]]:
        """Decodes a sequence of tokens into note events.

        Args:
            tokens (List[int]): The list of tokens to be decoded.
            start_time (float, optional): The starting time for the note events. Defaults to 0.
            return_events (bool, optional): Indicates whether to include the raw events in the
                return value. Defaults to False.

        Returns:
            If `return_events` is False, a tuple of `note_events`, `tie_note_events`,
            `last_activity`, and `err_cnt`.
            If `return_events` is True, a tuple of `note_events`, `tie_note_events`,
            `last_activity`, `events`, and `err_cnt`.
        """
        if self.debug_mode:
            ignored_tokens_from_input = [t for t in tokens if t in self.ids_to_ignore_decoding]
            print(ignored_tokens_from_input)

        if self.ids_to_ignore_decoding:
            tokens = [t for t in tokens if t not in self.ids_to_ignore_decoding]

        events = super()._decode(tokens)
        note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time, True, self.tps)
        if return_events:
            return note_events, tie_note_events, last_activity, events, err_cnt
        else:
            return note_events, tie_note_events, last_activity, err_cnt
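
    # --- Usage sketch (illustrative only) ---
    # Decoding model output for one segment that starts at t = 2.0 s:
    #
    #   note_events, tie_note_events, last_activity, err_cnt = tokenizer.decode(tokens, start_time=2.0)
    #   # err_cnt counts malformed event sequences encountered while decoding.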

    def decode_batch(
        self,
        batch_tokens: Union[List[List[int]], np.ndarray],
        start_times: List[float],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]],
                     List[List[Event]], Counter[str]]]:
        """
        Decodes a batch of tokens to note_events and tie_note_events.

        Args:
            batch_tokens (List[List[int]] or np.ndarray): Tokens to be decoded.
            start_times (List[float]): List of start times for each token set.
            return_events (bool, optional): Flag to determine if events should be returned.
                Defaults to False.
        """
        if isinstance(batch_tokens, np.ndarray):
            batch_tokens = batch_tokens.tolist()

        if len(batch_tokens) != len(start_times):
            raise ValueError('The lengths of batch_tokens and start_times must be the same.')

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0

        for tokens, start_time in zip(batch_tokens, start_times):
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                    tokens, start_time, return_events)
                list_events.append(events)
            else:
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt

    def decode_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]]],
                     Counter[str]],
               Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]]],
                     List[List[Event]], Counter[str]]]:
        """
        Decodes a list of variable-size batches of token arrays to a list of
        zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where each array has shape (batch_size, variable_length).
            list_start_times: List[float], where the length is the sum of all batch sizes.
            return_events: bool. Defaults to False.

        Returns:
            list_list_zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int]]: A list of last activities of the segment, [(program, pitch), ...].
                            This is useful for validating notes within a batch of segments extracted
                            from a file.
                        List[float]: A list of segment start times.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                Counter[str]: error counter.
        """
        list_tokens = []
        for arr in list_batch_tokens:
            for tokens in arr:
                list_tokens.append(tokens)
        assert len(list_tokens) == len(list_start_times)

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = Counter()

        for tokens, start_time in zip(list_tokens, list_start_times):
            # Always request raw events here: decode() returns 5 values only when
            # its return_events argument is True.
            note_events, tie_note_events, last_activity, events, err_cnt = self.decode(tokens, start_time, True)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            if return_events:
                list_events.append(events)
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt
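

# --- End-to-end usage sketch (illustrative only; the NoteEvent inputs and the
# conversion back to notes are assumptions, see utils.note2event / utils.event2note) ---
#
#   tokenizer = NoteEventTokenizer(base_codec='mt3', max_length=1024, tps=100)
#   tokens = tokenizer.encode_plus(note_events, tie_note_events, 0.0)
#   batch = np.array([tokens])  # (batch, length) stand-in for model output
#   zipped, total_err_cnt = tokenizer.decode_batch(batch, start_times=[0.0])
#   for note_events, tie_note_events, last_activity, t0 in zipped:
#       ...  # e.g. convert back to Note objects via note_event2note (imported above)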