from fastai.basics import *
from .transform import *
from ..music_transformer.dataloader import MusicDataBunch, MusicItemList
# Sequence-to-sequence translation: melody <-> chords dataloading

class S2SFileProcessor(PreProcessor):
    "`PreProcessor` that opens the filenames and read the texts."
    def process_one(self,item):
        out = np.load(item, allow_pickle=True)
        if out.shape != (2,): return None # expect exactly 2 parts: (melody, chords)
        if not 16 < len(out[0]) < 2048: return None # melody length out of range
        if not 16 < len(out[1]) < 2048: return None # chord length out of range
        return out
    
    def process(self, ds:Collection):
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None] # filter out None
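
# A minimal sketch of the on-disk layout this processor expects, based on the
# checks above (file name is hypothetical): each .npy file holds a 2-element
# object array of npenc-encoded (melody, chords) parts, e.g.
#   parts = np.array([melody_npenc, chord_npenc], dtype=object)
#   np.save('example_song_s2s.npy', parts)
# Files where either part falls outside the 16..2048 length window are dropped.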

class S2SPartsProcessor(PreProcessor):
    "Encodes midi file into 2 separate parts - melody and chords."
    
    def process_one(self, item):
        m, c = item
        mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab)
        return mtrack.to_idx()
    
    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
        
class Midi2MultitrackProcessor(PreProcessor):
    "Converts midi files to multitrack items"
    def process_one(self, midi_file):
        try:
            item = MultitrackItem.from_file(midi_file, vocab=self.vocab)
        except Exception as e:
            print(f'Error converting {midi_file}: {e}')
            return None
        return item.to_idx()
        
    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None]
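
# Rough sketch of how these processors could be chained when building the
# item list (paths and exact wiring are assumptions, not part of this module):
#   processors = [S2SFileProcessor(), S2SPartsProcessor()]
#   items = S2SItemList.from_folder('data/numpy', extensions=['.npy'],
#                                   processor=processors)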
    
class S2SPreloader(Callback):
    "Wraps a `LabelList`, applying random transposition and padding to each item."
    def __init__(self, dataset:LabelList, bptt:int=512, 
                 transpose_range=None, **kwargs):
        self.dataset,self.bptt = dataset,bptt
        self.vocab = self.dataset.vocab
        self.transpose_range = transpose_range
        self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None
        
    def __getitem__(self, k:int):
        item,empty_label = self.dataset[k]
        
        if self.rand_transpose is not None:
            val = self.rand_transpose()
            item = item.transpose(val)
        item = item.pad_to(self.bptt+1)
        ((m_x, m_pos), (c_x, c_pos)) = item.to_idx()
        return m_x, m_pos, c_x, c_pos
    
    def __len__(self):
        return len(self.dataset)
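
# Minimal usage sketch (assumes `ll` is a LabelList built from S2SItemList):
#   preloader = S2SPreloader(ll.train, bptt=512, transpose_range=(0, 12))
#   m_x, m_pos, c_x, c_pos = preloader[0]
# Each item is padded to bptt+1 so melody_chord_tfm below can shift
# inputs/targets by one step for next-token prediction.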

def rand_transpose_value(rand_range=(0,24), p=0.5):
    "With probability `p`, return a transposition offset centered on zero - e.g. (0,24) yields [-12,11]."
    if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2
    return 0

class S2SItemList(MusicItemList):
    _bunch = MusicDataBunch
    def get(self, i):
        return MultitrackItem.from_idx(self.items[i], self.vocab)

# DATALOADING AND TRANSFORMATIONS
# These transforms happen on batch

def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
    # mask_range = (min, max) - only token ids inside this range are eligible for masking
    # mask_idx, pad_idx - replacement values for the input and target respectively
    # p = probability that a token is selected for prediction
    x,y = b
    x,y = x.clone(),y.clone()
    rand = torch.rand(x.shape, device=x.device)
    rand[x < mask_range[0]] = 1.0 # never select tokens outside mask_range (special tokens)
    rand[x >= mask_range[1]] = 1.0
    
    # BERT-style masking: p of eligible tokens are selected. Of those - 80% masked, 10% random token, 10% unchanged
    y[rand > p] = pad_idx # unselected targets get pad_idx - removed from loss/acc metrics
    x[rand <= (p*.8)] = mask_idx # 80% of selected tokens -> mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% -> random in-range token
    x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
    return x, y
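
# Toy demonstration of mask_tfm on made-up token ids (values are illustrative,
# not a real vocab). With mask_range=(4, 100), ids below 4 act as special
# tokens and are never selected:
#   x = torch.randint(4, 100, (2, 8))
#   x_msk, y_msk = mask_tfm((x, x), mask_range=(4, 100), mask_idx=1, pad_idx=0)
# With the default p=0.3, roughly 30% of positions keep their true target in
# y_msk; the rest are set to pad_idx and ignored by the loss.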

def mask_lm_tfm_default(b, vocab, mask_p=0.3):
    return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)

def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
    mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range
    return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)

def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
    x,y = b
    x_lm,x_pos = x[...,0], x[...,1]
    y_lm,y_pos = y[...,0], y[...,1]
    
    # Note: masking y_lm instead of x_lm. Just in case we ever do sequential s2s training
    x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)
    msk_pos = y_pos
    
    x_dict = { 
        'msk': { 'x': x_msk, 'pos': msk_pos },
        'lm': { 'x': x_lm, 'pos': msk_pos }
    }
    y_dict = { 'msk': y_msk, 'lm': y_lm }
    return x_dict, y_dict
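
# Shape sketch (bs = batch size, sl = sequence length): the incoming x and y
# stack token ids and positional beat encodings along the last axis, so
#   x[..., 0] -> token ids (bs, sl), x[..., 1] -> positions (bs, sl)
# The 'msk' branch trains BERT-style masked prediction while 'lm' trains
# next-word prediction; both branches share msk_pos.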

def melody_chord_tfm(b):
    m,m_pos,c,c_pos = b
    
    # shift inputs/targets by one step for next-word prediction
    y_m = m[:,1:]
    x_m, m_pos = m[:,:-1], m_pos[:,:-1]
    
    y_c = c[:,1:]
    x_c, c_pos = c[:,:-1], c_pos[:,:-1]
    
    # 'c2m' = chords-to-melody translation, 'm2c' = melody-to-chords
    x_dict = { 
        'c2m': {
            'enc': x_c,
            'enc_pos': c_pos,
            'dec': x_m,
            'dec_pos': m_pos
        },
        'm2c': {
            'enc': x_m,
            'enc_pos': m_pos,
            'dec': x_c,
            'dec_pos': c_pos
        }
    }
    y_dict = {
        'c2m': y_m, 'm2c': y_c
    }
    return x_dict, y_dict
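
# Usage sketch, wiring S2SPreloader output through this transform (batch
# tensors assumed to be (bs, bptt+1)):
#   b = next(iter(dataloader))          # (m, m_pos, c, c_pos)
#   x_dict, y_dict = melody_chord_tfm(b)
#   x_dict['c2m']['enc'].shape          # (bs, bptt) - chords as encoder input
#   y_dict['c2m'].shape                 # (bs, bptt) - melody targets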