from fastai.basics import *
from .transform import *
from ..music_transformer.dataloader import MusicDataBunch, MusicItemList

# Sequence to Sequence Translation

class S2SFileProcessor(PreProcessor):
    "`PreProcessor` that opens each filename and loads the saved numpy encoding."
    def process_one(self, item):
        out = np.load(item, allow_pickle=True)
        # expect a 2-element array (melody, chords); skip pieces that are too short or too long
        if out.shape != (2,): return None
        if not 16 < len(out[0]) < 2048: return None
        if not 16 < len(out[1]) < 2048: return None
        return out
    def process(self, ds:Collection):
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None] # filter out None
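
# A minimal sketch of the file format `S2SFileProcessor` expects (an assumption
# inferred from the checks above): each .npy file holds a 2-element object array
# of (melody, chord) numpy encodings, each between 17 and 2047 steps long, e.g.:
#
#   melody, chords = np.load('piece.npy', allow_pickle=True)  # 'piece.npy' is hypothetical
#   assert 16 < len(melody) < 2048 and 16 < len(chords) < 2048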
class S2SPartsProcessor(PreProcessor):
    "Combines the 2 separate parts - melody and chords - into an indexed `MultitrackItem`."
    def process_one(self, item):
        m, c = item
        mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab)
        return mtrack.to_idx()
    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
class Midi2MultitrackProcessor(PreProcessor):
    "Converts MIDI files to indexed multitrack items."
    def process_one(self, midi_file):
        try:
            item = MultitrackItem.from_file(midi_file, vocab=self.vocab)
        except Exception as e:
            print(e)
            return None
        return item.to_idx()
    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None] # drop files that failed to parse
class S2SPreloader(Callback):
    "Dataset wrapper - transposes, pads, and converts a `MultitrackItem` to model inputs."
    def __init__(self, dataset:LabelList, bptt:int=512,
                 transpose_range=None, **kwargs):
        self.dataset,self.bptt = dataset,bptt
        self.vocab = self.dataset.vocab
        self.transpose_range = transpose_range
        self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None
    def __getitem__(self, k:int):
        item,empty_label = self.dataset[k]
        if self.rand_transpose is not None:
            # data augmentation - shift the whole item by a random number of semitones
            val = self.rand_transpose()
            item = item.transpose(val)
        item = item.pad_to(self.bptt+1)
        ((m_x, m_pos), (c_x, c_pos)) = item.to_idx()
        return m_x, m_pos, c_x, c_pos
    def __len__(self):
        return len(self.dataset)
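
# Hedged usage sketch (the names below are assumptions, not part of this module):
# since `S2SPreloader` only needs `__getitem__`/`__len__`, it can be wrapped
# directly in a PyTorch DataLoader:
#
#   preloader = S2SPreloader(label_list.train, bptt=512, transpose_range=(0, 12))
#   dl = torch.utils.data.DataLoader(preloader, batch_size=16)
#   m_x, m_pos, c_x, c_pos = next(iter(dl))  # melody/chord tokens + beat positions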
def rand_transpose_value(rand_range=(0,24), p=0.5):
    "With probability `p`, return a random transpose offset centered around 0; otherwise 0."
    if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2
    return 0
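
# Worked example: with the default rand_range=(0, 24), np.random.randint(0, 24)
# yields 0..23; subtracting rand_range[1] // 2 = 12 centers this to a transpose
# of -12..+11 semitones, returned with probability p (otherwise 0).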
class S2SItemList(MusicItemList):
    _bunch = MusicDataBunch
    def get(self, i):
        return MultitrackItem.from_idx(self.items[i], self.vocab)
# DATALOADING AND TRANSFORMATIONS
# These transforms happen on the batch

def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
    # mask_range - (min, max) range of token indices eligible for masking
    # replacement vals - [x_replace, y_replace]. Usually [mask_idx, pad_idx]
    # p - probability that a token is selected for prediction
    x,y = b
    x,y = x.clone(),y.clone()
    rand = torch.rand(x.shape, device=x.device)
    # tokens outside mask_range (e.g. special tokens) are never selected
    rand[x < mask_range[0]] = 1.0
    rand[x >= mask_range[1]] = 1.0
    # As in BERT (which uses p=15%): p of the tokens are selected for prediction.
    # Of those, 80% are masked, 10% get a wrong token, 10% stay unchanged.
    y[rand > p] = pad_idx        # unselected positions are padded in y - excluded from loss/acc metrics
    x[rand <= (p*.8)] = mask_idx # 80% of selected tokens -> mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% of selected tokens -> wrong word
    x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
    return x, y
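
# Toy example of the scheme above (illustrative values only): with the default
# p=0.3 and a vocab where tokens 4..99 are maskable,
#
#   x = torch.tensor([[10, 2, 42]]); y = x.clone()
#   x_msk, y_msk = mask_tfm((x, y), mask_range=(4, 100), mask_idx=1, pad_idx=0)
#
# token 2 lies outside mask_range so it is never selected; roughly 30% of the
# remaining tokens end up masked or corrupted in x_msk, and every unselected
# position in y_msk becomes pad_idx so the loss skips it.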
def mask_lm_tfm_default(b, vocab, mask_p=0.3):
    # mask any token in the full note-encoding range
    return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)

def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
    # mask only durations or only notes - the model must infer one from the other
    mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range
    return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)
def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
    x,y = b
    x_lm,x_pos = x[...,0], x[...,1]
    y_lm,y_pos = y[...,0], y[...,1]
    # Note: masking y_lm instead of x_lm. Just in case we ever do sequential s2s training
    x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)
    msk_pos = y_pos
    x_dict = {
        'msk': { 'x': x_msk, 'pos': msk_pos },
        'lm': { 'x': x_lm, 'pos': msk_pos }
    }
    y_dict = { 'msk': y_msk, 'lm': y_lm }
    return x_dict, y_dict
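
# Shape sketch (assumed from the indexing above): `b` arrives as (x, y) tensors
# of shape [batch, seq_len, 2], where channel 0 holds token indices and channel
# 1 holds beat positions. The returned dicts feed the multitask heads, e.g.:
#
#   x_dict, y_dict = mask_lm_tfm_default((xb, yb), vocab=vocab)
#   x_dict['msk']['x']  # [batch, seq_len] masked tokens
#   y_dict['msk']       # [batch, seq_len] targets, pad_idx where unselected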
def melody_chord_tfm(b):
    m,m_pos,c,c_pos = b
    # offset x and y by one step for next-token prediction
    y_m = m[:,1:]
    x_m, m_pos = m[:,:-1], m_pos[:,:-1]
    y_c = c[:,1:]
    x_c, c_pos = c[:,:-1], c_pos[:,:-1]
    x_dict = {
        'c2m': {
            'enc': x_c,
            'enc_pos': c_pos,
            'dec': x_m,
            'dec_pos': m_pos
        },
        'm2c': {
            'enc': x_m,
            'enc_pos': m_pos,
            'dec': x_c,
            'dec_pos': c_pos
        }
    }
    y_dict = {
        'c2m': y_m, 'm2c': y_c
    }
    return x_dict, y_dict
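
# Worked example of the one-step offset (illustrative): for a melody batch
# m = [[A, B, C, D]], the decoder input is x_m = [[A, B, C]] and the target is
# y_m = [[B, C, D]], so each position predicts the next token. The same batch
# serves both translation directions: chords encode -> melody decode ('c2m')
# and melody encode -> chord decode ('m2c').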