from fastai.basics import *
from .transform import *
from ..music_transformer.dataloader import MusicDataBunch, MusicItemList

# Sequence 2 Sequence Translate
class S2SFileProcessor(PreProcessor):
    "`PreProcessor` that opens the filenames and loads the numpy arrays."
    def process_one(self, item):
        out = np.load(item, allow_pickle=True)
        if out.shape != (2,): return None
        if not 16 < len(out[0]) < 2048: return None
        if not 16 < len(out[1]) < 2048: return None
        return out

    def process(self, ds:Collection):
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None]  # filter out None
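# Hedged usage sketch (not from the original repo): each .npy file is expected to
# hold a length-2 object array of (melody, chord) sequences, each between 16 and
# 2048 steps long. A synthetic file that would pass the filters above - the
# (timestep, [note, dur]) npenc layout is an assumption:
#
#   melody = np.zeros((100, 2), dtype=np.int64)  # assumed npenc layout
#   chords = np.zeros((100, 2), dtype=np.int64)
#   arr = np.empty(2, dtype=object)
#   arr[0], arr[1] = melody, chords
#   np.save('example.npy', arr)  # loads back with shape (2,)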
class S2SPartsProcessor(PreProcessor):
    "Encodes a midi file into 2 separate parts - melody and chords."
    def process_one(self, item):
        m, c = item
        mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab)
        return mtrack.to_idx()

    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
class Midi2MultitrackProcessor(PreProcessor):
    "Converts midi files to multitrack items."
    def process_one(self, midi_file):
        try:
            item = MultitrackItem.from_file(midi_file, vocab=self.vocab)
        except Exception as e:
            print(e)
            return None
        return item.to_idx()

    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None]
class S2SPreloader(Callback):
    def __init__(self, dataset:LabelList, bptt:int=512,
                 transpose_range=None, **kwargs):
        self.dataset,self.bptt = dataset,bptt
        self.vocab = self.dataset.vocab
        self.transpose_range = transpose_range
        self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None

    def __getitem__(self, k:int):
        item,empty_label = self.dataset[k]
        if self.rand_transpose is not None:
            val = self.rand_transpose()
            item = item.transpose(val)
        item = item.pad_to(self.bptt+1)
        ((m_x, m_pos), (c_x, c_pos)) = item.to_idx()
        return m_x, m_pos, c_x, c_pos

    def __len__(self):
        return len(self.dataset)
def rand_transpose_value(rand_range=(0,24), p=0.5):
    "With probability `p`, return a random transpose offset centered around 0."
    if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2
    return 0
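# Example: with the default rand_range=(0, 24), half of the calls return 0 and
# the rest a value in [-12, 11] - up to roughly an octave in either direction:
#
#   vals = [rand_transpose_value() for _ in range(1000)]
#   assert all(-12 <= v <= 11 for v in vals)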
class S2SItemList(MusicItemList):
    _bunch = MusicDataBunch
    def get(self, i):
        return MultitrackItem.from_idx(self.items[i], self.vocab)
# DATALOADING AND TRANSFORMATIONS
# These transforms happen on the batch
def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
    # mask_range - (min, max) token range eligible for masking
    # mask_idx, pad_idx - replacement values for x and y respectively
    # p - replacement probability
    x,y = b
    x,y = x.clone(),y.clone()
    rand = torch.rand(x.shape, device=x.device)
    rand[x < mask_range[0]] = 1.0
    rand[x >= mask_range[1]] = 1.0
    # BERT-style masking: p of the tokens are selected. Of those, 80% are masked,
    # 10% replaced with a random (wrong) token, and 10% left unchanged.
    y[rand > p] = pad_idx                             # unselected targets become pad. Remove these from loss/acc metrics
    x[rand <= (p*.8)] = mask_idx                      # 80% of selected = mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9))   # 10% = wrong word
    x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
    return x, y
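# Toy sanity check with made-up vocab indices (0=pad, 1=mask, tokens in [10, 50)):
#
#   xb = torch.randint(10, 50, (2, 8))
#   x_msk, y_msk = mask_tfm((xb, xb.clone()), mask_range=(10, 50),
#                           mask_idx=1, pad_idx=0, p=0.3)
#   # roughly 70% of y_msk is pad_idx (ignored by the loss); every position that
#   # was masked or corrupted in x_msk keeps its original token as a target in y_msk.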
def mask_lm_tfm_default(b, vocab, mask_p=0.3):
    return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)

def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
    # randomly mask either only durations or only notes for the whole batch
    mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range
    return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)
def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
    x,y = b
    x_lm,x_pos = x[...,0], x[...,1]
    y_lm,y_pos = y[...,0], y[...,1]

    # Note: masking y_lm instead of x_lm. Just in case we ever do sequential s2s training
    x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)
    msk_pos = y_pos

    x_dict = {
        'msk': { 'x': x_msk, 'pos': msk_pos },
        'lm': { 'x': x_lm, 'pos': msk_pos }
    }
    y_dict = { 'msk': y_msk, 'lm': y_lm }
    return x_dict, y_dict
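# Assumed batch layout (inferred from the indexing above, not stated in the repo):
# x and y have shape [bs, seq_len, 2], channel 0 holding token indices and
# channel 1 the positional (beat) encoding, e.g.:
#
#   xb = torch.stack([tokens, positions], dim=-1)  # hypothetical construction
#   x_dict, y_dict = mask_lm_tfm((xb, xb.clone()), mask_range=(10, 50),
#                                mask_idx=1, pad_idx=0, mask_p=0.3)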
def melody_chord_tfm(b):
    m,m_pos,c,c_pos = b

    # offset x and y for next word prediction
    y_m = m[:,1:]
    x_m, m_pos = m[:,:-1], m_pos[:,:-1]

    y_c = c[:,1:]
    x_c, c_pos = c[:,:-1], c_pos[:,:-1]

    x_dict = {
        'c2m': {
            'enc': x_c,
            'enc_pos': c_pos,
            'dec': x_m,
            'dec_pos': m_pos
        },
        'm2c': {
            'enc': x_m,
            'enc_pos': m_pos,
            'dec': x_c,
            'dec_pos': c_pos
        }
    }
    y_dict = {
        'c2m': y_m, 'm2c': y_c
    }
    return x_dict, y_dict
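# Minimal shape check, assuming a batch from S2SPreloader (each tensor [bs, seq_len]):
#
#   bs, seq = 4, 65
#   m, c = torch.randint(0, 100, (bs, seq)), torch.randint(0, 100, (bs, seq))
#   m_pos = torch.arange(seq).expand(bs, seq); c_pos = m_pos.clone()
#   x_dict, y_dict = melody_chord_tfm((m, m_pos, c, c_pos))
#   assert x_dict['c2m']['dec'].shape == (bs, seq - 1)
#   assert y_dict['c2m'].shape == (bs, seq - 1)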