from fastai.basics import *
from .numpy_encode import *
from .music_transformer import transform

BOS = 'xxbos'
PAD = 'xxpad'
EOS = 'xxeos'
MASK = 'xxmask'  # Used for BERT masked language modeling.
CSEQ = 'xxcseq'  # Used for Seq2Seq translation - denotes start of chord sequence
MSEQ = 'xxmseq'  # Used for Seq2Seq translation - denotes start of melody sequence

# Deprecated tokens. Kept for compatibility.
S2SCLS = 'xxs2scls'  # deprecated
NSCLS = 'xxnscls'    # deprecated

SEP = 'xxsep'  # Denotes end of a timestep (required for polyphony). Separator idx = -1 (part of notes)

SPECIAL_TOKS = [BOS, PAD, EOS, S2SCLS, MASK, CSEQ, MSEQ, NSCLS, SEP]  # Important: SEP token must be last
NOTE_TOKS = [f'n{i}' for i in range(NOTE_SIZE)]
DUR_TOKS = [f'd{i}' for i in range(DUR_SIZE)]
NOTE_START, NOTE_END = NOTE_TOKS[0], NOTE_TOKS[-1]
DUR_START, DUR_END = DUR_TOKS[0], DUR_TOKS[-1]

MTEMPO_SIZE = 10
MTEMPO_OFF = 'mt0'
MTEMPO_TOKS = [f'mt{i}' for i in range(MTEMPO_SIZE)]
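
# Index layout: MusicVocab.create() below concatenates the groups as
# [SPECIAL_TOKS][NOTE_TOKS][DUR_TOKS][MTEMPO_TOKS], so the special tokens get the
# lowest ids (BOS -> 0, PAD -> 1, ..., SEP -> len(SPECIAL_TOKS) - 1) and the note and
# duration ids form one contiguous block starting right after SEP, which is what
# npenc_range, note_range and dur_range rely on.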
# Vocab - token to index mapping
class MusicVocab():
    "Contain the correspondence between numbers and tokens and numericalize."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        self.stoi = {v:k for k,v in enumerate(self.itos)}

    def numericalize(self, t:Collection[str]) -> List[int]:
        "Convert a list of tokens `t` to their ids."
        return [self.stoi[w] for w in t]

    def textify(self, nums:Collection[int], sep=' ') -> List[str]:
        "Convert a list of `nums` to their tokens."
        items = [self.itos[i] for i in nums]
        return sep.join(items) if sep is not None else items

    def to_music_item(self, idxenc):
        return transform.MusicItem(idxenc, self)
    @property
    def mask_idx(self): return self.stoi[MASK]
    @property
    def pad_idx(self): return self.stoi[PAD]
    @property
    def bos_idx(self): return self.stoi[BOS]
    @property
    def sep_idx(self): return self.stoi[SEP]
    @property
    def npenc_range(self): return (self.stoi[SEP], self.stoi[DUR_END]+1)
    @property
    def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
    @property
    def dur_range(self): return self.stoi[DUR_START], self.stoi[DUR_END]+1
    def is_duration(self, idx):
        return idx >= self.dur_range[0] and idx < self.dur_range[1]
    def is_duration_or_pad(self, idx):
        return idx == self.pad_idx or self.is_duration(idx)
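
    # __getstate__/__setstate__ keep pickling lightweight: only `itos` is serialized,
    # and the `stoi` lookup table is rebuilt when the vocab is restored.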
    def __getstate__(self):
        return {'itos':self.itos}

    def __setstate__(self, state:dict):
        self.itos = state['itos']
        self.stoi = {v:k for k,v in enumerate(self.itos)}

    def __len__(self): return len(self.itos)

    def save(self, path):
        "Save `self.itos` in `path`"
        pickle.dump(self.itos, open(path, 'wb'))
    @classmethod
    def create(cls) -> 'MusicVocab':
        "Create the default vocabulary from the predefined token lists."
        itos = SPECIAL_TOKS + NOTE_TOKS + DUR_TOKS + MTEMPO_TOKS
        if len(itos)%8 != 0:
            # Pad the vocab out to a multiple of 8.
            itos = itos + [f'dummy{i}' for i in range(8 - len(itos)%8)]
        return cls(itos)
    @classmethod
    def load(cls, path):
        "Load a `MusicVocab` from the `itos` pickled at `path`."
        itos = pickle.load(open(path, 'rb'))
        return cls(itos)
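
# Minimal usage sketch (not part of the original module). The pitch/duration tokens
# below are illustrative and assume NOTE_SIZE > 60 and DUR_SIZE > 4; because this file
# uses relative imports, run it as a module (python -m <package>.<this_module>).
if __name__ == '__main__':
    vocab = MusicVocab.create()
    toks = [BOS, 'n60', 'd4', SEP, EOS]            # tokens -> ids and back
    ids = vocab.numericalize(toks)
    assert vocab.textify(ids) == ' '.join(toks)
    print(len(vocab), vocab.bos_idx, vocab.note_range, vocab.dur_range)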