jerald committed
Commit ac4cbcf · Parent: ad497e7

added musicautobot library

Files changed (46)
  1. app.py +6 -4
  2. requirements.txt +1 -0
  3. utils/.DS_Store +0 -0
  4. utils/musicautobot/.DS_Store +0 -0
  5. utils/musicautobot/__init__.py +0 -3
  6. utils/musicautobot/__pycache__/__init__.cpython-310.pyc +0 -0
  7. utils/musicautobot/__pycache__/config.cpython-310.pyc +0 -0
  8. utils/musicautobot/__pycache__/numpy_encode.cpython-310.pyc +0 -0
  9. utils/musicautobot/__pycache__/vocab.cpython-310.pyc +0 -0
  10. utils/musicautobot/config.py +0 -47
  11. utils/musicautobot/multitask_transformer/__init__.py +0 -3
  12. utils/musicautobot/multitask_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  13. utils/musicautobot/multitask_transformer/__pycache__/dataloader.cpython-310.pyc +0 -0
  14. utils/musicautobot/multitask_transformer/__pycache__/learner.cpython-310.pyc +0 -0
  15. utils/musicautobot/multitask_transformer/__pycache__/model.cpython-310.pyc +0 -0
  16. utils/musicautobot/multitask_transformer/__pycache__/transform.cpython-310.pyc +0 -0
  17. utils/musicautobot/multitask_transformer/dataloader.py +0 -146
  18. utils/musicautobot/multitask_transformer/learner.py +0 -340
  19. utils/musicautobot/multitask_transformer/model.py +0 -258
  20. utils/musicautobot/multitask_transformer/transform.py +0 -68
  21. utils/musicautobot/music_transformer/__init__.py +0 -3
  22. utils/musicautobot/music_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  23. utils/musicautobot/music_transformer/__pycache__/dataloader.cpython-310.pyc +0 -0
  24. utils/musicautobot/music_transformer/__pycache__/learner.cpython-310.pyc +0 -0
  25. utils/musicautobot/music_transformer/__pycache__/model.cpython-310.pyc +0 -0
  26. utils/musicautobot/music_transformer/__pycache__/transform.cpython-310.pyc +0 -0
  27. utils/musicautobot/music_transformer/dataloader.py +0 -229
  28. utils/musicautobot/music_transformer/learner.py +0 -171
  29. utils/musicautobot/music_transformer/model.py +0 -66
  30. utils/musicautobot/music_transformer/transform.py +0 -235
  31. utils/musicautobot/numpy_encode.py +0 -302
  32. utils/musicautobot/utils/__init__.py +0 -0
  33. utils/musicautobot/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  34. utils/musicautobot/utils/__pycache__/attention_mask.cpython-310.pyc +0 -0
  35. utils/musicautobot/utils/__pycache__/file_processing.cpython-310.pyc +0 -0
  36. utils/musicautobot/utils/__pycache__/midifile.cpython-310.pyc +0 -0
  37. utils/musicautobot/utils/__pycache__/setup_musescore.cpython-310.pyc +0 -0
  38. utils/musicautobot/utils/__pycache__/top_k_top_p.cpython-310.pyc +0 -0
  39. utils/musicautobot/utils/attention_mask.py +0 -21
  40. utils/musicautobot/utils/file_processing.py +0 -52
  41. utils/musicautobot/utils/lamb.py +0 -106
  42. utils/musicautobot/utils/midifile.py +0 -107
  43. utils/musicautobot/utils/setup_musescore.py +0 -46
  44. utils/musicautobot/utils/stacked_dataloader.py +0 -70
  45. utils/musicautobot/utils/top_k_top_p.py +0 -35
  46. utils/musicautobot/vocab.py +0 -93
app.py CHANGED
@@ -1,7 +1,9 @@
- from utils.musicautobot.numpy_encode import *
- from utils.musicautobot.utils.file_processing import process_all, process_file
- from utils.musicautobot.config import *
- from utils.musicautobot.music_transformer import *
+ from musicautobot.numpy_encode import *
+ from musicautobot.utils.file_processing import process_all, process_file
+ from musicautobot.config import *
+ from musicautobot.music_transformer import *
+ from musicautobot.utils.setup_musescore import setup_musescore
+ setup_musescore()

  import gradio as gr
  from midi2audio import FluidSynth
requirements.txt CHANGED
@@ -2,5 +2,6 @@ gradio
  midi2audio
  music21
  git+https://github.com/fastai/fastai1.git@master
+ git+https://github.com/bearpelican/musicautobot.git
  pebble
  spacy
utils/.DS_Store DELETED
Binary file (6.15 kB)
 
utils/musicautobot/.DS_Store DELETED
Binary file (6.15 kB)
 
utils/musicautobot/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .utils.setup_musescore import setup_musescore
-
- setup_musescore()

utils/musicautobot/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (239 Bytes)
 
utils/musicautobot/__pycache__/config.cpython-310.pyc DELETED
Binary file (1.25 kB)
 
utils/musicautobot/__pycache__/numpy_encode.cpython-310.pyc DELETED
Binary file (9.77 kB)
 
utils/musicautobot/__pycache__/vocab.cpython-310.pyc DELETED
Binary file (5.24 kB)
 
utils/musicautobot/config.py DELETED
@@ -1,47 +0,0 @@
- from fastai.text.models.transformer import tfmerXL_lm_config, Activation
- # from .vocab import MusicVocab
-
- def default_config():
-     config = tfmerXL_lm_config.copy()
-     config['act'] = Activation.GeLU
-
-     config['mem_len'] = 512
-     config['d_model'] = 512
-     config['d_inner'] = 2048
-     config['n_layers'] = 16
-
-     config['n_heads'] = 8
-     config['d_head'] = 64
-
-     return config
-
- def music_config():
-     config = default_config()
-     config['encode_position'] = True
-     return config
-
- def musicm_config():
-     config = music_config()
-     config['d_model'] = 768
-     config['d_inner'] = 3072
-     config['n_heads'] = 12
-     config['d_head'] = 64
-     config['n_layers'] = 12
-     return config
-
- def multitask_config():
-     config = default_config()
-     config['bias'] = True
-     config['enc_layers'] = 8
-     config['dec_layers'] = 8
-     del config['n_layers']
-     return config
-
- def multitaskm_config():
-     config = musicm_config()
-     config['bias'] = True
-     config['enc_layers'] = 12
-     config['dec_layers'] = 12
-     del config['n_layers']
-     return config
-

utils/musicautobot/multitask_transformer/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .dataloader import *
- from .model import *
- from .learner import *

utils/musicautobot/multitask_transformer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (257 Bytes)
 
utils/musicautobot/multitask_transformer/__pycache__/dataloader.cpython-310.pyc DELETED
Binary file (6.17 kB)
 
utils/musicautobot/multitask_transformer/__pycache__/learner.cpython-310.pyc DELETED
Binary file (11.5 kB)
 
utils/musicautobot/multitask_transformer/__pycache__/model.cpython-310.pyc DELETED
Binary file (11.4 kB)
 
utils/musicautobot/multitask_transformer/__pycache__/transform.cpython-310.pyc DELETED
Binary file (3.72 kB)
 
utils/musicautobot/multitask_transformer/dataloader.py DELETED
@@ -1,146 +0,0 @@
- from fastai.basics import *
- from .transform import *
- from ..music_transformer.dataloader import MusicDataBunch, MusicItemList
- # Sequence 2 Sequence Translate
-
- class S2SFileProcessor(PreProcessor):
-     "`PreProcessor` that opens the filenames and read the texts."
-     def process_one(self,item):
-         out = np.load(item, allow_pickle=True)
-         if out.shape != (2,): return None
-         if not 16 < len(out[0]) < 2048: return None
-         if not 16 < len(out[1]) < 2048: return None
-         return out
-
-     def process(self, ds:Collection):
-         ds.items = [self.process_one(item) for item in ds.items]
-         ds.items = [i for i in ds.items if i is not None] # filter out None
-
- class S2SPartsProcessor(PreProcessor):
-     "Encodes midi file into 2 separate parts - melody and chords."
-
-     def process_one(self, item):
-         m, c = item
-         mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab)
-         return mtrack.to_idx()
-
-     def process(self, ds):
-         self.vocab = ds.vocab
-         ds.items = [self.process_one(item) for item in ds.items]
-
- class Midi2MultitrackProcessor(PreProcessor):
-     "Converts midi files to multitrack items"
-     def process_one(self, midi_file):
-         try:
-             item = MultitrackItem.from_file(midi_file, vocab=self.vocab)
-         except Exception as e:
-             print(e)
-             return None
-         return item.to_idx()
-
-     def process(self, ds):
-         self.vocab = ds.vocab
-         ds.items = [self.process_one(item) for item in ds.items]
-         ds.items = [i for i in ds.items if i is not None]
-
- class S2SPreloader(Callback):
-     def __init__(self, dataset:LabelList, bptt:int=512,
-                  transpose_range=None, **kwargs):
-         self.dataset,self.bptt = dataset,bptt
-         self.vocab = self.dataset.vocab
-         self.transpose_range = transpose_range
-         self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None
-
-     def __getitem__(self, k:int):
-         item,empty_label = self.dataset[k]
-
-         if self.rand_transpose is not None:
-             val = self.rand_transpose()
-             item = item.transpose(val)
-         item = item.pad_to(self.bptt+1)
-         ((m_x, m_pos), (c_x, c_pos)) = item.to_idx()
-         return m_x, m_pos, c_x, c_pos
-
-     def __len__(self):
-         return len(self.dataset)
-
- def rand_transpose_value(rand_range=(0,24), p=0.5):
-     if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2
-     return 0
-
- class S2SItemList(MusicItemList):
-     _bunch = MusicDataBunch
-     def get(self, i):
-         return MultitrackItem.from_idx(self.items[i], self.vocab)
-
- # DATALOADING AND TRANSFORMATIONS
- # These transforms happen on batch
-
- def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
-     # mask range (min, max)
-     # replacement vals - [x_replace, y_replace]. Usually [mask_idx, pad_idx]
-     # p = replacement probability
-     x,y = b
-     x,y = x.clone(),y.clone()
-     rand = torch.rand(x.shape, device=x.device)
-     rand[x < mask_range[0]] = 1.0
-     rand[x >= mask_range[1]] = 1.0
-
-     # p(15%) of words are replaced. Of those p(15%) - 80% are masked. 10% wrong word. 10% unchanged
-     y[rand > p] = pad_idx # pad unchanged 80%. Remove these from loss/acc metrics
-     x[rand <= (p*.8)] = mask_idx # 80% = mask
-     wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word
-     x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
-     return x, y
-
- def mask_lm_tfm_default(b, vocab, mask_p=0.3):
-     return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)
-
- def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
-     mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range
-     return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)
-
- def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
-     x,y = b
-     x_lm,x_pos = x[...,0], x[...,1]
-     y_lm,y_pos = y[...,0], y[...,1]
-
-     # Note: masking y_lm instead of x_lm. Just in case we ever do sequential s2s training
-     x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)
-     msk_pos = y_pos
-
-     x_dict = {
-         'msk': { 'x': x_msk, 'pos': msk_pos },
-         'lm': { 'x': x_lm, 'pos': msk_pos }
-     }
-     y_dict = { 'msk': y_msk, 'lm': y_lm }
-     return x_dict, y_dict
-
- def melody_chord_tfm(b):
-     m,m_pos,c,c_pos = b
-
-     # offset x and y for next word prediction
-     y_m = m[:,1:]
-     x_m, m_pos = m[:,:-1], m_pos[:,:-1]
-
-     y_c = c[:,1:]
-     x_c, c_pos = c[:,:-1], c_pos[:,:-1]
-
-     x_dict = {
-         'c2m': {
-             'enc': x_c,
-             'enc_pos': c_pos,
-             'dec': x_m,
-             'dec_pos': m_pos
-         },
-         'm2c': {
-             'enc': x_m,
-             'enc_pos': m_pos,
-             'dec': x_c,
-             'dec_pos': c_pos
-         }
-     }
-     y_dict = {
-         'c2m': y_m, 'm2c': y_c
-     }
-     return x_dict, y_dict

utils/musicautobot/multitask_transformer/learner.py DELETED
@@ -1,340 +0,0 @@
1
- from fastai.basics import *
2
- from ..vocab import *
3
- from ..utils.top_k_top_p import top_k_top_p
4
- from ..utils.midifile import is_empty_midi
5
- from ..music_transformer.transform import *
6
- from ..music_transformer.learner import filter_invalid_indexes
7
- from .model import get_multitask_model
8
- from .dataloader import *
9
-
10
- def multitask_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1.,
11
- pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
12
- "Create a `Learner` with a language model from `data` and `arch`."
13
- vocab = data.vocab
14
- vocab_size = len(vocab)
15
-
16
- if pretrained_path:
17
- state = torch.load(pretrained_path, map_location='cpu')
18
- if config is None: config = state['config']
19
-
20
- model = get_multitask_model(vocab_size, config=config, drop_mult=drop_mult, pad_idx=vocab.pad_idx)
21
- metrics = [AverageMultiMetric(partial(m, pad_idx=vocab.pad_idx)) for m in [mask_acc, lm_acc, c2m_acc, m2c_acc]]
22
- loss_func = MultiLoss(ignore_index=data.vocab.pad_idx)
23
- learn = MultitaskLearner(data, model, loss_func=loss_func, metrics=metrics, **learn_kwargs)
24
-
25
- if pretrained_path:
26
- get_model(model).load_state_dict(state['model'], strict=False)
27
- if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
28
- try: learn.opt.load_state_dict(state['opt'])
29
- except: pass
30
- del state
31
- gc.collect()
32
-
33
- return learn
34
-
35
- class MultitaskLearner(Learner):
36
- def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
37
- "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
38
- out_path = super().save(file, return_path=True, with_opt=with_opt)
39
- if config and out_path:
40
- state = torch.load(out_path)
41
- state['config'] = config
42
- torch.save(state, out_path)
43
- del state
44
- gc.collect()
45
- return out_path
46
-
47
- def predict_nw(self, item:MusicItem, n_words:int=128,
48
- temperatures:float=(1.0,1.0), min_bars=4,
49
- top_k=30, top_p=0.6):
50
- "Return the `n_words` that come after `text`."
51
- self.model.reset()
52
- new_idx = []
53
- vocab = self.data.vocab
54
- x, pos = item.to_tensor(), item.get_pos_tensor()
55
- last_pos = pos[-1] if len(pos) else 0
56
- y = torch.tensor([0])
57
-
58
- start_pos = last_pos
59
-
60
- sep_count = 0
61
- bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
62
- vocab = self.data.vocab
63
-
64
- repeat_count = 0
65
-
66
- for i in progress_bar(range(n_words), leave=True):
67
- batch = { 'lm': { 'x': x[None], 'pos': pos[None] } }, y
68
- logits = self.pred_batch(batch=batch)['lm'][-1][-1]
69
-
70
- prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx
71
-
72
- # Temperature
73
- # Use first temperatures value if last prediction was duration
74
- temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
75
- repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
76
- temperature += repeat_penalty
77
- if temperature != 1.: logits = logits / temperature
78
-
79
-
80
- # Filter
81
- # bar = 16 beats
82
- filter_value = -float('Inf')
83
- if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value
84
-
85
- logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
86
- logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
87
-
88
- # Sample
89
- probs = F.softmax(logits, dim=-1)
90
- idx = torch.multinomial(probs, 1).item()
91
-
92
- # Update repeat count
93
- num_choices = len(probs.nonzero().view(-1))
94
- if num_choices <= 2: repeat_count += 1
95
- else: repeat_count = repeat_count // 2
96
-
97
- if prev_idx==vocab.sep_idx:
98
- duration = idx - vocab.dur_range[0]
99
- last_pos = last_pos + duration
100
-
101
- bars_pred = (last_pos - start_pos) // 16
102
- abs_bar = last_pos // 16
103
- # if (bars % 8 == 0) and (bars_pred > min_bars): break
104
- if (i / n_words > 0.80) and (abs_bar % 4 == 0): break
105
-
106
-
107
- if idx==vocab.bos_idx:
108
- print('Predicted BOS token. Returning prediction...')
109
- break
110
-
111
- new_idx.append(idx)
112
- x = x.new_tensor([idx])
113
- pos = pos.new_tensor([last_pos])
114
-
115
- pred = vocab.to_music_item(np.array(new_idx))
116
- full = item.append(pred)
117
- return pred, full
118
-
119
- def predict_mask(self, masked_item:MusicItem,
120
- temperatures:float=(1.0,1.0),
121
- top_k=20, top_p=0.8):
122
- x = masked_item.to_tensor()
123
- pos = masked_item.get_pos_tensor()
124
- y = torch.tensor([0])
125
- vocab = self.data.vocab
126
- self.model.reset()
127
- mask_idxs = (x == vocab.mask_idx).nonzero().view(-1)
128
-
129
- repeat_count = 0
130
-
131
- for midx in progress_bar(mask_idxs, leave=True):
132
- prev_idx = x[midx-1]
133
-
134
- # Using original positions, otherwise model gets too off track
135
- # pos = torch.tensor(-position_enc(xb[0].cpu().numpy()), device=xb.device)[None]
136
-
137
- # Next Word
138
- logits = self.pred_batch(batch=({ 'msk': { 'x': x[None], 'pos': pos[None] } }, y) )['msk'][0][midx]
139
-
140
- # Temperature
141
- # Use first temperatures value if last prediction was duration
142
- temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
143
- repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
144
- temperature += repeat_penalty
145
- if temperature != 1.: logits = logits / temperature
146
-
147
- # Filter
148
- filter_value = -float('Inf')
149
- special_idxs = [vocab.bos_idx, vocab.sep_idx, vocab.stoi[EOS]]
150
- logits[special_idxs] = filter_value # Don't allow any special tokens (as we are only removing notes and durations)
151
- logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
152
- logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
153
-
154
- # Sampling
155
- probs = F.softmax(logits, dim=-1)
156
- idx = torch.multinomial(probs, 1).item()
157
-
158
- # Update repeat count
159
- num_choices = len(probs.nonzero().view(-1))
160
- if num_choices <= 2: repeat_count += 1
161
- else: repeat_count = repeat_count // 2
162
-
163
- x[midx] = idx
164
-
165
- return vocab.to_music_item(x.cpu().numpy())
166
-
167
- def predict_s2s(self, input_item:MusicItem, target_item:MusicItem, n_words:int=256,
168
- temperatures:float=(1.0,1.0), top_k=30, top_p=0.8,
169
- use_memory=True):
170
- vocab = self.data.vocab
171
-
172
- # Input doesn't change. We can reuse the encoder output on each prediction
173
- with torch.no_grad():
174
- inp, inp_pos = input_item.to_tensor(), input_item.get_pos_tensor()
175
- x_enc = self.model.encoder(inp[None], inp_pos[None])
176
-
177
- # target
178
- targ = target_item.data.tolist()
179
- targ_pos = target_item.position.tolist()
180
- last_pos = targ_pos[-1]
181
- self.model.reset()
182
-
183
- repeat_count = 0
184
-
185
- max_pos = input_item.position[-1] + SAMPLE_FREQ * 4 # Only predict until both tracks/parts have the same length
186
- x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)
187
-
188
- for i in progress_bar(range(n_words), leave=True):
189
- # Predict
190
- with torch.no_grad():
191
- dec = self.model.decoder(x[None], pos[None], x_enc)
192
- logits = self.model.head(dec)[-1, -1]
193
-
194
- # Temperature
195
- # Use first temperatures value if last prediction was duration
196
- prev_idx = targ[-1] if len(targ) else vocab.pad_idx
197
- temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
198
- repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
199
- temperature += repeat_penalty
200
- if temperature != 1.: logits = logits / temperature
201
-
202
- # Filter
203
- filter_value = -float('Inf')
204
- logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
205
- logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
206
-
207
- # Sample
208
- probs = F.softmax(logits, dim=-1)
209
- idx = torch.multinomial(probs, 1).item()
210
-
211
- # Update repeat count
212
- num_choices = len(probs.nonzero().view(-1))
213
- if num_choices <= 2: repeat_count += 1
214
- else: repeat_count = repeat_count // 2
215
-
216
- if idx == vocab.bos_idx | idx == vocab.stoi[EOS]:
217
- print('Predicting BOS/EOS')
218
- break
219
-
220
- if prev_idx == vocab.sep_idx:
221
- duration = idx - vocab.dur_range[0]
222
- last_pos = last_pos + duration
223
- if last_pos > max_pos:
224
- print('Predicted past counter-part length. Returning early')
225
- break
226
-
227
- targ_pos.append(last_pos)
228
- targ.append(idx)
229
-
230
- if use_memory:
231
- # Relying on memory for kv. Only need last prediction index
232
- x, pos = inp.new_tensor([targ[-1]]), inp_pos.new_tensor([targ_pos[-1]])
233
- else:
234
- # Reset memory after each prediction, since we feeding the whole sequence every time
235
- self.model.reset()
236
- x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)
237
-
238
- return vocab.to_music_item(np.array(targ))
239
-
240
- # High level prediction functions from midi file
241
- def nw_predict_from_midi(learn, midi=None, n_words=400,
242
- temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
243
- vocab = learn.data.vocab
244
- seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
245
- if seed_len is not None: seed = seed.trim_to_beat(seed_len)
246
-
247
- pred, full = learn.predict_nw(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
248
- return full
249
-
250
- def s2s_predict_from_midi(learn, midi=None, n_words=200,
251
- temperatures=(1.0,1.0), top_k=24, top_p=0.7, seed_len=None, pred_melody=True, **kwargs):
252
- multitrack_item = MultitrackItem.from_file(midi, learn.data.vocab)
253
- melody, chords = multitrack_item.melody, multitrack_item.chords
254
- inp, targ = (chords, melody) if pred_melody else (melody, chords)
255
-
256
- # if seed_len is passed, cutoff sequence so we can predict the rest
257
- if seed_len is not None: targ = targ.trim_to_beat(seed_len)
258
- targ = targ.remove_eos()
259
-
260
- pred = learn.predict_s2s(inp, targ, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
261
-
262
- part_order = (pred, inp) if pred_melody else (inp, pred)
263
- return MultitrackItem(*part_order)
264
-
265
- def mask_predict_from_midi(learn, midi=None, predict_notes=True,
266
- temperatures=(1.0,1.0), top_k=30, top_p=0.7, section=None, **kwargs):
267
- item = MusicItem.from_file(midi, learn.data.vocab)
268
- masked_item = item.mask_pitch(section) if predict_notes else item.mask_duration(section)
269
- pred = learn.predict_mask(masked_item, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
270
- return pred
271
-
272
- # LOSS AND METRICS
273
-
274
- class MultiLoss():
275
- def __init__(self, ignore_index=None):
276
- "Loss mult - Mask, NextWord, Seq2Seq"
277
- self.loss = CrossEntropyFlat(ignore_index=ignore_index)
278
-
279
- def __call__(self, inputs:Dict[str,Tensor], targets:Dict[str,Tensor])->Rank0Tensor:
280
- losses = [self.loss(inputs[key], target) for key,target in targets.items()]
281
- return sum(losses)
282
-
283
- def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx)->Rank0Tensor:
284
- if input is None or targ is None: return None
285
- n = targ.shape[0]
286
- input = input.argmax(dim=-1).view(n,-1)
287
- targ = targ.view(n,-1)
288
- mask = targ != pad_idx
289
- return (input[mask]==targ[mask]).float().mean()
290
-
291
- def acc_index(inputs, targets, key, pad_idx):
292
- return acc_ignore_pad(inputs.get(key), targets.get(key), pad_idx)
293
-
294
- def mask_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'msk', pad_idx)
295
- def lm_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'lm', pad_idx)
296
- def c2m_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'c2m', pad_idx)
297
- def m2c_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'm2c', pad_idx)
298
-
299
-
300
- class AverageMultiMetric(AverageMetric):
301
- "Updated fastai.AverageMetric to support multi task metrics."
302
- def on_batch_end(self, last_output, last_target, **kwargs):
303
- "Update metric computation with `last_output` and `last_target`."
304
- if not is_listy(last_target): last_target=[last_target]
305
- val = self.func(last_output, *last_target)
306
- if val is None: return
307
- self.count += first_el(last_target).size(0)
308
- if self.world:
309
- val = val.clone()
310
- dist.all_reduce(val, op=dist.ReduceOp.SUM)
311
- val /= self.world
312
- self.val += first_el(last_target).size(0) * val.detach().cpu()
313
-
314
- def on_epoch_end(self, last_metrics, **kwargs):
315
- "Set the final result in `last_metrics`."
316
- if self.count == 0: return add_metrics(last_metrics, 0)
317
- return add_metrics(last_metrics, self.val/self.count)
318
-
319
-
320
- # MODEL LOADING
321
- class MTTrainer(LearnerCallback):
322
- "`Callback` that regroups lr adjustment to seq_len, AR and TAR."
323
- def __init__(self, learn:Learner, dataloaders=None, starting_mask_window=1):
324
- super().__init__(learn)
325
- self.count = 1
326
- self.mw_start = starting_mask_window
327
- self.dataloaders = dataloaders
328
-
329
- def on_epoch_begin(self, **kwargs):
330
- "Reset the hidden state of the model."
331
- model = get_model(self.learn.model)
332
- model.reset()
333
- model.encoder.mask_steps = max(self.count+self.mw_start, 100)
334
-
335
- def on_epoch_end(self, last_metrics, **kwargs):
336
- "Finish the computation and sends the result to the Recorder."
337
- if self.dataloaders is not None:
338
- self.learn.data = self.dataloaders[self.count % len(self.dataloaders)]
339
- self.count += 1
340
-
 
utils/musicautobot/multitask_transformer/model.py DELETED
@@ -1,258 +0,0 @@
1
- from fastai.basics import *
2
- from fastai.text.models.transformer import Activation, PositionalEncoding, feed_forward, init_transformer, _line_shift
3
- from fastai.text.models.awd_lstm import RNNDropout
4
- from ..utils.attention_mask import *
5
-
6
- def get_multitask_model(vocab_size:int, config:dict=None, drop_mult:float=1., pad_idx=None):
7
- "Create a language model from `arch` and its `config`, maybe `pretrained`."
8
- for k in config.keys():
9
- if k.endswith('_p'): config[k] *= drop_mult
10
- n_hid = config['d_model']
11
- mem_len = config.pop('mem_len')
12
- embed = TransformerEmbedding(vocab_size, n_hid, embed_p=config['embed_p'], mem_len=mem_len, pad_idx=pad_idx)
13
- encoder = MTEncoder(embed, n_hid, n_layers=config['enc_layers'], mem_len=0, **config) # encoder doesn't need memory
14
- decoder = MTEncoder(embed, n_hid, is_decoder=True, n_layers=config['dec_layers'], mem_len=mem_len, **config)
15
- head = MTLinearDecoder(n_hid, vocab_size, tie_encoder=embed.embed, **config)
16
- model = MultiTransformer(encoder, decoder, head, mem_len=mem_len)
17
- return model.apply(init_transformer)
18
-
19
- class MultiTransformer(nn.Module):
20
- "Multitask Transformer for training mask, next word, and sequence 2 sequence"
21
- def __init__(self, encoder, decoder, head, mem_len):
22
- super().__init__()
23
- self.encoder = encoder
24
- self.decoder = decoder
25
- self.head = head
26
- self.default_mem_len = mem_len
27
- self.current_mem_len = None
28
-
29
- def forward(self, inp):
30
- # data order: mask, next word, melody, chord
31
- outputs = {}
32
- msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]
33
-
34
- if msk is not None:
35
- outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
36
- if lm is not None:
37
- outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))
38
-
39
- if c2m is not None:
40
- self.reset()
41
- c2m_enc = self.encoder(c2m['enc'], c2m['enc_pos'])
42
- c2m_dec = self.decoder(c2m['dec'], c2m['dec_pos'], c2m_enc)
43
- outputs['c2m'] = self.head(c2m_dec)
44
-
45
- if m2c is not None:
46
- self.reset()
47
- m2c_enc = self.encoder(m2c['enc'], m2c['enc_pos'])
48
- m2c_dec = self.decoder(m2c['dec'], m2c['dec_pos'], m2c_enc)
49
- outputs['m2c'] = self.head(m2c_dec)
50
-
51
- return outputs
52
-
53
- "A sequential module that passes the reset call to its children."
54
- def reset(self):
55
- for module in self.children():
56
- reset_children(module)
57
-
58
- def reset_children(mod):
59
- if hasattr(mod, 'reset'): mod.reset()
60
- for module in mod.children():
61
- reset_children(module)
62
-
63
- # COMPONENTS
64
- class TransformerEmbedding(nn.Module):
65
- "Embedding + positional encoding + dropout"
66
- def __init__(self, vocab_size:int, emb_sz:int, embed_p:float=0., mem_len=512, beat_len=32, max_bar_len=1024, pad_idx=None):
67
- super().__init__()
68
- self.emb_sz = emb_sz
69
- self.pad_idx = pad_idx
70
-
71
- self.embed = nn.Embedding(vocab_size, emb_sz, padding_idx=pad_idx)
72
- self.pos_enc = PositionalEncoding(emb_sz)
73
- self.beat_len, self.max_bar_len = beat_len, max_bar_len
74
- self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
75
- self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)
76
-
77
- self.drop = nn.Dropout(embed_p)
78
- self.mem_len = mem_len
79
-
80
- def forward(self, inp, pos):
81
- beat_enc = self.beat_enc(pos % self.beat_len)
82
- bar_pos = pos // self.beat_len % self.max_bar_len
83
- bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
84
- bar_enc = self.bar_enc((bar_pos))
85
- emb = self.drop(self.embed(inp) + beat_enc + bar_enc)
86
- return emb
87
-
88
- def relative_pos_enc(self, emb):
89
- # return torch.arange(640-1, -1, -1).float().cuda()
90
- seq_len = emb.shape[1] + self.mem_len
91
- pos = torch.arange(seq_len-1, -1, -1, device=emb.device, dtype=emb.dtype) # backwards (txl pos encoding)
92
- return self.pos_enc(pos)
93
-
94
- class MTLinearDecoder(nn.Module):
95
- "To go on top of a RNNCore module and create a Language Model."
96
- initrange=0.1
97
-
98
- def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, out_bias:bool=True, **kwargs):
99
- super().__init__()
100
- self.decoder = nn.Linear(n_hid, n_out, bias=out_bias)
101
- self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
102
- self.output_dp = RNNDropout(output_p)
103
- if out_bias: self.decoder.bias.data.zero_()
104
- if tie_encoder: self.decoder.weight = tie_encoder.weight
105
-
106
- def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
107
- output = self.output_dp(input)
108
- decoded = self.decoder(output)
109
- return decoded
110
-
111
-
112
- # DECODER TRANSLATE BLOCK
113
- class MTEncoder(nn.Module):
114
- def __init__(self, embed:nn.Module, n_hid:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
115
- resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True,
116
- act:Activation=Activation.ReLU, double_drop:bool=True, mem_len:int=512, is_decoder=False,
117
- mask_steps=1, mask_p=0.3, **kwargs):
118
- super().__init__()
119
- self.embed = embed
120
- self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
121
- self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
122
- self.n_layers,self.d_model = n_layers,d_model
123
- self.layers = nn.ModuleList([MTEncoderBlock(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
124
- ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, mem_len=mem_len,
125
- ) for k in range(n_layers)])
126
-
127
- self.mask_steps, self.mask_p = mask_steps, mask_p
128
- self.is_decoder = is_decoder
129
-
130
- nn.init.normal_(self.u, 0., 0.02)
131
- nn.init.normal_(self.v, 0., 0.02)
132
-
133
- def forward(self, x_lm, lm_pos, msk_emb=None):
134
- bs,lm_len = x_lm.size()
135
-
136
- lm_emb = self.embed(x_lm, lm_pos)
137
- if msk_emb is not None and msk_emb.shape[1] > lm_emb.shape[1]:
138
- pos_enc = self.embed.relative_pos_enc(msk_emb)
139
- else:
140
- pos_enc = self.embed.relative_pos_enc(lm_emb)
141
-
142
- # Masks
143
- if self.is_decoder:
144
- lm_mask = rand_window_mask(lm_len, self.embed.mem_len, x_lm.device,
145
- max_size=self.mask_steps, p=self.mask_p, is_eval=not self.training)
146
- else:
147
- lm_mask = None
148
-
149
- for i, layer in enumerate(self.layers):
150
- lm_emb = layer(lm_emb, msk_emb, lm_mask=lm_mask,
151
- r=pos_enc, g_u=self.u, g_v=self.v)
152
- return lm_emb
153
-
154
- class MTEncoderBlock(nn.Module):
155
- "Decoder block of a Transformer model."
156
- #Can't use Sequential directly cause more than one input...
157
- def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
158
- bias:bool=True, scale:bool=True, double_drop:bool=True, mem_len:int=512, mha2_mem_len=0, **kwargs):
159
- super().__init__()
160
- attn_cls = MemMultiHeadRelativeAttentionKV
161
- self.mha1 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mem_len, r_mask=False)
162
- self.mha2 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mha2_mem_len, r_mask=True)
163
- self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)
164
-
165
- def forward(self, enc_lm:Tensor, enc_msk:Tensor,
166
- r=None, g_u=None, g_v=None,
167
- msk_mask:Tensor=None, lm_mask:Tensor=None):
168
-
169
- y_lm = self.mha1(enc_lm, enc_lm, enc_lm, r, g_u, g_v, mask=lm_mask)
170
- if enc_msk is None: return y_lm
171
- return self.ff(self.mha2(y_lm, enc_msk, enc_msk, r, g_u, g_v, mask=msk_mask))
172
-
173
-
174
- # Attention Layer
175
-
176
-
177
- # Attn
178
-
179
- class MemMultiHeadRelativeAttentionKV(nn.Module):
180
- "Attention Layer monster - relative positioning, keeps track of own memory, separate kv weights to support sequence2sequence decoding."
181
- def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
182
- scale:bool=True, mem_len:int=512, r_mask=True):
183
- super().__init__()
184
- d_head = ifnone(d_head, d_model//n_heads)
185
- self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
186
-
187
- assert(d_model == d_head * n_heads)
188
- self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
189
- self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
190
- self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
191
-
192
- self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
193
- self.ln = nn.LayerNorm(d_model)
194
- self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)
195
- self.r_mask = r_mask
196
-
197
- self.mem_len = mem_len
198
- self.prev_k = None
199
- self.prev_v = None
200
-
201
- def forward(self, q:Tensor, k:Tensor=None, v:Tensor=None,
202
- r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
203
- mask:Tensor=None, **kwargs):
204
- if k is None: k = q
205
- if v is None: v = q
206
- return self.ln(q + self.drop_res(self._apply_attention(q, k, v, r, g_u, g_v, mask=mask, **kwargs)))
207
-
208
- def mem_k(self, k):
209
- if self.mem_len == 0: return k
210
- if self.prev_k is None or (self.prev_k.shape[0] != k.shape[0]): # reset if wrong batch size
211
- self.prev_k = k[:, -self.mem_len:]
212
- return k
213
- with torch.no_grad():
214
- k_ext = torch.cat([self.prev_k, k], dim=1)
215
- self.prev_k = k_ext[:, -self.mem_len:]
216
- return k_ext.detach()
217
-
218
- def mem_v(self, v):
219
- if self.mem_len == 0: return v
220
- if self.prev_v is None or (self.prev_v.shape[0] != v.shape[0]): # reset if wrong batch size
221
- self.prev_v = v[:, -self.mem_len:]
222
- return v
223
- with torch.no_grad():
224
- v_ext = torch.cat([self.prev_v, v], dim=1)
225
- self.prev_v = v_ext[:, -self.mem_len:]
226
- return v_ext.detach()
227
-
228
- def reset(self):
229
- self.prev_v = None
230
- self.prev_k = None
231
-
232
- def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor,
233
- r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
234
- mask:Tensor=None, **kwargs):
235
- #Notations from the paper: x input, r vector of relative distance between two elements, u et v learnable
236
- #parameters of the model common between all layers, mask to avoid cheating and mem the previous hidden states.
237
- # bs,x_len,seq_len = q.size(0),q.size(1),r.size(0)
238
- k = self.mem_k(k)
239
- v = self.mem_v(v)
240
- bs,x_len,seq_len = q.size(0),q.size(1),k.size(1)
241
- wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v)
242
- wq = wq[:,-x_len:]
243
- wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
244
- wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
245
- wkr = self.r_attn(r[-seq_len:])
246
- wkr = wkr.view(seq_len, self.n_heads, self.d_head)
247
- wkr = wkr.permute(1,2,0)
248
- #### compute attention score (AC is (a) + (c) and BS is (b) + (d) in the paper)
249
- AC = torch.matmul(wq+g_u,wk)
250
- BD = _line_shift(torch.matmul(wq+g_v, wkr), mask=self.r_mask)
251
- if self.scale: attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5))
252
- if mask is not None:
253
- mask = mask[...,-seq_len:]
254
- if hasattr(mask, 'bool'): mask = mask.bool()
255
- attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
256
- attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
257
- attn_vec = torch.matmul(attn_prob, wv)
258
- return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
 
utils/musicautobot/multitask_transformer/transform.py DELETED
@@ -1,68 +0,0 @@
- from ..music_transformer.transform import *
-
- class MultitrackItem():
-     def __init__(self, melody:MusicItem, chords:MusicItem, stream=None):
-         self.melody,self.chords = melody, chords
-         self.vocab = melody.vocab
-         self._stream = stream
-
-     @classmethod
-     def from_file(cls, midi_file, vocab):
-         return cls.from_stream(file2stream(midi_file), vocab)
-
-     @classmethod
-     def from_stream(cls, stream, vocab):
-         if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts()
-         num_parts = len(stream.parts)
-         sort_pitch = False
-         if num_parts > 2:
-             raise ValueError('Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks')
-         elif num_parts == 1:
-             print('Warning: only 1 track found. Inferring melody/chords')
-             stream = separate_melody_chord(stream)
-             sort_pitch = False
-
-         mpart, cpart = stream2npenc_parts(stream, sort_pitch=sort_pitch)
-         return cls.from_npenc_parts(mpart, cpart, vocab, stream)
-
-     @classmethod
-     def from_npenc_parts(cls, mpart, cpart, vocab, stream=None):
-         mpart = npenc2idxenc(mpart, seq_type=SEQType.Melody, vocab=vocab, add_eos=False)
-         cpart = npenc2idxenc(cpart, seq_type=SEQType.Chords, vocab=vocab, add_eos=False)
-         return MultitrackItem(MusicItem(mpart, vocab), MusicItem(cpart, vocab), stream)
-
-     @classmethod
-     def from_idx(cls, item, vocab):
-         m, c = item
-         return MultitrackItem(MusicItem.from_idx(m, vocab), MusicItem.from_idx(c, vocab))
-     def to_idx(self): return np.array((self.melody.to_idx(), self.chords.to_idx()))
-
-     @property
-     def stream(self):
-         self._stream = self.to_stream() if self._stream is None else self._stream
-         return self._stream
-
-     def to_stream(self, bpm=120):
-         ps = self.melody.to_npenc(), self.chords.to_npenc()
-         ps = [npenc2chordarr(p) for p in ps]
-         chordarr = chordarr_combine_parts(ps)
-         return chordarr2stream(chordarr, bpm=bpm)
-
-
-     def show(self, format:str=None):
-         return self.stream.show(format)
-     def play(self): self.stream.show('midi')
-
-     def transpose(self, val):
-         return MultitrackItem(self.melody.transpose(val), self.chords.transpose(val))
-     def pad_to(self, val):
-         return MultitrackItem(self.melody.pad_to(val), self.chords.pad_to(val))
-     def trim_to_beat(self, beat):
-         return MultitrackItem(self.melody.trim_to_beat(beat), self.chords.trim_to_beat(beat))
-
- def combine2chordarr(np1, np2, vocab):
-     if len(np1.shape) == 1: np1 = idxenc2npenc(np1, vocab)
-     if len(np2.shape) == 1: np2 = idxenc2npenc(np2, vocab)
-     p1 = npenc2chordarr(np1)
-     p2 = npenc2chordarr(np2)
-     return chordarr_combine_parts((p1, p2))

utils/musicautobot/music_transformer/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .dataloader import *
- from .model import *
- from .learner import *

utils/musicautobot/music_transformer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (251 Bytes)
 
utils/musicautobot/music_transformer/__pycache__/dataloader.cpython-310.pyc DELETED
Binary file (11.2 kB)
 
utils/musicautobot/music_transformer/__pycache__/learner.cpython-310.pyc DELETED
Binary file (5.94 kB)
 
utils/musicautobot/music_transformer/__pycache__/model.cpython-310.pyc DELETED
Binary file (3 kB)
 
utils/musicautobot/music_transformer/__pycache__/transform.cpython-310.pyc DELETED
Binary file (10.7 kB)
 
utils/musicautobot/music_transformer/dataloader.py DELETED
@@ -1,229 +0,0 @@
1
- "Fastai Language Model Databunch modified to work with music"
2
- from fastai.basics import *
3
- # from fastai.basic_data import DataBunch
4
- from fastai.text.data import LMLabelList
5
- from .transform import *
6
- from ..vocab import MusicVocab
7
-
8
-
9
- class MusicDataBunch(DataBunch):
10
- "Create a `TextDataBunch` suitable for training a language model."
11
- @classmethod
12
- def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None,
13
- num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate,
14
- dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70,
15
- preloader_cls=None, shuffle_dl=False, transpose_range=(0,12), **kwargs) -> DataBunch:
16
- "Create a `TextDataBunch` in `path` from the `datasets` for language modelling."
17
- datasets = cls._init_ds(train_ds, valid_ds, test_ds)
18
- preloader_cls = MusicPreloader if preloader_cls is None else preloader_cls
19
- val_bs = ifnone(val_bs, bs)
20
- datasets = [preloader_cls(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, transpose_range=transpose_range, **kwargs)
21
- for i,ds in enumerate(datasets)]
22
- val_bs = bs
23
- dl_tfms = [partially_apply_vocab(tfm, train_ds.vocab) for tfm in listify(dl_tfms)]
24
- dls = [DataLoader(d, b, shuffle=shuffle_dl) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None]
25
- return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
26
-
27
- @classmethod
28
- def from_folder(cls, path:PathOrStr, extensions='.npy', **kwargs):
29
- files = get_files(path, extensions=extensions, recurse=True);
30
- return cls.from_files(files, path, **kwargs)
31
-
32
- @classmethod
33
- def from_files(cls, files, path, processors=None, split_pct=0.1,
34
- vocab=None, list_cls=None, **kwargs):
35
- if vocab is None: vocab = MusicVocab.create()
36
- if list_cls is None: list_cls = MusicItemList
37
- src = (list_cls(items=files, path=path, processor=processors, vocab=vocab)
38
- .split_by_rand_pct(split_pct, seed=6)
39
- .label_const(label_cls=LMLabelList))
40
- return src.databunch(**kwargs)
41
-
42
- @classmethod
43
- def empty(cls, path, **kwargs):
44
- vocab = MusicVocab.create()
45
- src = MusicItemList([], path=path, vocab=vocab, ignore_empty=True).split_none()
46
- return src.label_const(label_cls=LMLabelList).databunch()
47
-
48
- def partially_apply_vocab(tfm, vocab):
49
- if 'vocab' in inspect.getfullargspec(tfm).args:
50
- return partial(tfm, vocab=vocab)
51
- return tfm
52
-
53
- class MusicItemList(ItemList):
54
- _bunch = MusicDataBunch
55
-
56
- def __init__(self, items:Iterator, vocab:MusicVocab=None, **kwargs):
57
- super().__init__(items, **kwargs)
58
- self.vocab = vocab
59
- self.copy_new += ['vocab']
60
-
61
- def get(self, i):
62
- o = super().get(i)
63
- if is_pos_enc(o):
64
- return MusicItem.from_idx(o, self.vocab)
65
- return MusicItem(o, self.vocab)
66
-
67
- def is_pos_enc(idxenc):
68
- if len(idxenc.shape) == 2 and idxenc.shape[0] == 2: return True
69
- return idxenc.dtype == np.object and idxenc.shape == (2,)
70
-
71
- class MusicItemProcessor(PreProcessor):
72
- "`PreProcessor` that transforms numpy files to indexes for training"
73
- def process_one(self,item):
74
- item = MusicItem.from_npenc(item, vocab=self.vocab)
75
- return item.to_idx()
76
-
77
- def process(self, ds):
78
- self.vocab = ds.vocab
79
- super().process(ds)
80
-
81
- class OpenNPFileProcessor(PreProcessor):
82
- "`PreProcessor` that opens the filenames and read the texts."
83
- def process_one(self,item):
84
- return np.load(item, allow_pickle=True) if isinstance(item, Path) else item
85
-
86
- class Midi2ItemProcessor(PreProcessor):
87
- "Skips midi preprocessing step. And encodes midi files to MusicItems"
88
- def process_one(self,item):
89
- item = MusicItem.from_file(item, vocab=self.vocab)
90
- return item.to_idx()
91
-
92
- def process(self, ds):
93
- self.vocab = ds.vocab
94
- super().process(ds)
95
-
96
- ## For npenc dataset
97
- class MusicPreloader(Callback):
98
- "Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling."
99
-
100
- class CircularIndex():
101
- "Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed"
102
- def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward
103
- def __getitem__(self, i):
104
- return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)]
105
- def __len__(self) -> int: return len(self.idx)
106
- def shuffle(self): np.random.shuffle(self.idx)
107
-
108
- def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False,
109
- shuffle:bool=False, y_offset:int=1,
110
- transpose_range=None, transpose_p=0.5,
111
- encode_position=True,
112
- **kwargs):
113
- self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths
114
- self.vocab = self.dataset.vocab
115
- self.bs *= num_distrib() or 1
116
- self.totalToks,self.ite_len,self.idx = int(0),None,None
117
- self.y_offset = y_offset
118
-
119
- self.transpose_range,self.transpose_p = transpose_range,transpose_p
120
- self.encode_position = encode_position
121
- self.bptt_len = self.bptt
122
-
123
- self.allocate_buffers() # needed for valid_dl on distributed training - otherwise doesn't get initialized on first epoch
124
-
125
- def __len__(self):
126
- if self.ite_len is None:
127
- if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x])
128
- self.totalToks = self.lengths.sum()
129
- self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1
130
- return self.ite_len
131
-
132
- def __getattr__(self,k:str)->Any: return getattr(self.dataset, k)
133
-
134
- def allocate_buffers(self):
135
- "Create the ragged array that will be filled when we ask for items."
136
- if self.ite_len is None: len(self)
137
- self.idx = MusicPreloader.CircularIndex(len(self.dataset.x), not self.backwards)
138
-
139
- # batch shape = (bs, bptt, 2 - [index, pos]) if encode_position. Else - (bs, bptt)
140
- buffer_len = (2,) if self.encode_position else ()
141
- self.batch = np.zeros((self.bs, self.bptt+self.y_offset) + buffer_len, dtype=np.int64)
142
- self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,self.y_offset:self.bptt+self.y_offset]
143
- #ro: index of the text we're at inside our datasets for the various batches
144
- self.ro = np.zeros(self.bs, dtype=np.int64)
145
- #ri: index of the token we're at inside our current text for the various batches
146
- self.ri = np.zeros(self.bs, dtype=np.int)
147
-
148
- # allocate random transpose values. Need to allocate this before hand.
149
- self.transpose_values = self.get_random_transpose_values()
150
-
151
- def get_random_transpose_values(self):
152
- if self.transpose_range is None: return None
153
- n = len(self.dataset)
154
- rt_arr = torch.randint(*self.transpose_range, (n,))-self.transpose_range[1]//2
155
- mask = torch.rand(rt_arr.shape) > self.transpose_p
156
- rt_arr[mask] = 0
157
- return rt_arr
158
-
159
- def on_epoch_begin(self, **kwargs):
160
- if self.idx is None: self.allocate_buffers()
161
- elif self.shuffle:
162
- self.ite_len = None
163
- self.idx.shuffle()
164
- self.transpose_values = self.get_random_transpose_values()
165
- self.bptt_len = self.bptt
166
- self.idx.forward = not self.backwards
167
-
168
- step = self.totalToks / self.bs
169
- ln_rag, countTokens, i_rag = 0, 0, -1
170
- for i in range(0,self.bs):
171
- #Compute the initial values for ro and ri
172
- while ln_rag + countTokens <= int(step * i):
173
- countTokens += ln_rag
174
- i_rag += 1
175
- ln_rag = self.lengths[self.idx[i_rag]]
176
- self.ro[i] = i_rag
177
- self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens)
178
-
179
- #Training dl gets on_epoch_begin called, val_dl, on_epoch_end
180
- def on_epoch_end(self, **kwargs): self.on_epoch_begin()
181
-
182
- def __getitem__(self, k:int):
183
- j = k % self.bs
184
- if j==0:
185
- if self.item is not None: return self.dataset[0]
186
- if self.idx is None: self.on_epoch_begin()
187
-
188
- self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x, self.idx, self.batch[j][:self.bptt_len+self.y_offset],
189
- self.ro[j], self.ri[j], overlap=1, lengths=self.lengths)
190
- return self.batch_x[j][:self.bptt_len], self.batch_y[j][:self.bptt_len]
191
-
192
- def fill_row(self, forward, items, idx, row, ro, ri, overlap, lengths):
193
- "Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented"
194
- ibuf = n = 0
195
- ro -= 1
196
- while ibuf < row.shape[0]:
197
- ro += 1
198
- ix = idx[ro]
199
-
200
- item = items[ix]
201
- if self.transpose_values is not None:
202
- item = item.transpose(self.transpose_values[ix].item())
203
-
204
- if self.encode_position:
205
- # Positions are colomn stacked with indexes. This makes it easier to keep in sync
206
- rag = np.stack([item.data, item.position], axis=1)
207
- else:
208
- rag = item.data
209
-
210
- if forward:
211
- ri = 0 if ibuf else ri
212
- n = min(lengths[ix] - ri, row.shape[0] - ibuf)
213
- row[ibuf:ibuf+n] = rag[ri:ri+n]
214
- else:
215
- ri = lengths[ix] if ibuf else ri
216
- n = min(ri, row.size - ibuf)
217
- row[ibuf:ibuf+n] = rag[ri-n:ri][::-1]
218
- ibuf += n
219
- return ro, ri + ((n-overlap) if forward else -(n-overlap))
220
-
221
- def batch_position_tfm(b):
222
- "Batch transform for training with positional encoding"
223
- x,y = b
224
- x = {
225
- 'x': x[...,0],
226
- 'pos': x[...,1]
227
- }
228
- return x, y[...,0]
229
-
 
utils/musicautobot/music_transformer/learner.py DELETED
@@ -1,171 +0,0 @@
1
- from fastai.basics import *
2
- from fastai.text.learner import LanguageLearner, get_language_model, _model_meta
3
- from .model import *
4
- from .transform import MusicItem
5
- from ..numpy_encode import SAMPLE_FREQ
6
- from ..utils.top_k_top_p import top_k_top_p
7
- from ..utils.midifile import is_empty_midi
8
-
9
- _model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata
10
-
11
- def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
12
- pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
13
- "Create a `Learner` with a language model from `data` and `arch`."
14
- meta = _model_meta[arch]
15
-
16
- if pretrained_path:
17
- state = torch.load(pretrained_path, map_location='cpu')
18
- if config is None: config = state['config']
19
-
20
- model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
21
- learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
22
-
23
- if pretrained_path:
24
- get_model(model).load_state_dict(state['model'], strict=False)
25
- if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
26
- try: learn.opt.load_state_dict(state['opt'])
27
- except: pass
28
- del state
29
- gc.collect()
30
-
31
- return learn
32
-
33
- # Predictions
34
- from fastai import basic_train # for predictions
35
- class MusicLearner(LanguageLearner):
36
- def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
37
- "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
38
- out_path = super().save(file, return_path=True, with_opt=with_opt)
39
- if config and out_path:
40
- state = torch.load(out_path)
41
- state['config'] = config
42
- torch.save(state, out_path)
43
- del state
44
- gc.collect()
45
- return out_path
46
-
47
- def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.,
48
- ):
49
- "Return the `n_words` that come after `text` using beam search."
50
- self.model.reset()
51
- self.model.eval()
52
- xb_length = xb.shape[-1]
53
- if xb.shape[0] > 1: xb = xb[0][None]
54
- yb = torch.ones_like(xb)
55
-
56
- nodes = None
57
- xb = xb.repeat(top_k, 1)
58
- nodes = xb.clone()
59
- scores = xb.new_zeros(1).float()
60
- with torch.no_grad():
61
- for k in progress_bar(range(n_words), leave=False):
62
- out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
63
- values, indices = out.topk(top_k, dim=-1)
64
- scores = (-values + scores[:,None]).view(-1)
65
- indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
66
- sort_idx = scores.argsort()[:beam_sz]
67
- scores = scores[sort_idx]
68
- nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
69
- indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
70
- nodes = nodes.view(-1, nodes.size(2))[sort_idx]
71
- self.model[0].select_hidden(indices_idx[sort_idx])
72
- xb = nodes[:,-1][:,None]
73
- if temperature != 1.: scores.div_(temperature)
74
- node_idx = torch.multinomial(torch.exp(-scores), 1).item()
75
- return [i.item() for i in nodes[node_idx][xb_length:] ]
76
-
77
- def predict(self, item:MusicItem, n_words:int=128,
78
- temperatures:float=(1.0,1.0), min_bars=4,
79
- top_k=30, top_p=0.6):
80
- "Return the `n_words` that come after `text`."
81
- self.model.reset()
82
- new_idx = []
83
- vocab = self.data.vocab
84
- x, pos = item.to_tensor(), item.get_pos_tensor()
85
- last_pos = pos[-1] if len(pos) else 0
86
- y = torch.tensor([0])
87
-
88
- start_pos = last_pos
89
-
90
- sep_count = 0
91
- bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
92
- vocab = self.data.vocab
93
-
94
- repeat_count = 0
95
- if hasattr(self.model[0], 'encode_position'):
96
- encode_position = self.model[0].encode_position
97
- else: encode_position = False
98
-
99
- for i in progress_bar(range(n_words), leave=True):
100
- with torch.no_grad():
101
- if encode_position:
102
- batch = { 'x': x[None], 'pos': pos[None] }
103
- logits = self.model(batch)[0][-1][-1]
104
- else:
105
- logits = self.model(x[None])[0][-1][-1]
106
-
107
- prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx
108
-
109
- # Temperature
110
- # Use first temperatures value if last prediction was duration
111
- temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
112
- repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
113
- temperature += repeat_penalty
114
- if temperature != 1.: logits = logits / temperature
115
-
116
-
117
- # Filter
118
- # bar = 16 timesteps (4 beats at SAMPLE_FREQ=4), assuming 4/4 time
119
- filter_value = -float('Inf')
120
- if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value
121
-
122
- logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
123
- logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
124
-
125
- # Sample
126
- probs = F.softmax(logits, dim=-1)
127
- idx = torch.multinomial(probs, 1).item()
128
-
129
- # Update repeat count
130
- num_choices = len(probs.nonzero().view(-1))
131
- if num_choices <= 2: repeat_count += 1
132
- else: repeat_count = repeat_count // 2
133
-
134
- if prev_idx==vocab.sep_idx:
135
- duration = idx - vocab.dur_range[0]
136
- last_pos = last_pos + duration
137
-
138
- bars_pred = (last_pos - start_pos) // 16
139
- abs_bar = last_pos // 16
140
- # if (bars % 8 == 0) and (bars_pred > min_bars): break
141
- if (i / n_words > 0.80) and (abs_bar % 4 == 0): break
142
-
143
-
144
- if idx==vocab.bos_idx:
145
- print('Predicted BOS token. Returning prediction...')
146
- break
147
-
148
- new_idx.append(idx)
149
- x = x.new_tensor([idx])
150
- pos = pos.new_tensor([last_pos])
151
-
152
- pred = vocab.to_music_item(np.array(new_idx))
153
- full = item.append(pred)
154
- return pred, full
155
-
156
- # High level prediction functions from midi file
157
- def predict_from_midi(learn, midi=None, n_words=400,
158
- temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
159
- vocab = learn.data.vocab
160
- seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
161
- if seed_len is not None: seed = seed.trim_to_beat(seed_len)
162
-
163
- pred, full = learn.predict(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
164
- return full
165
-
166
- def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
167
- if vocab.is_duration_or_pad(prev_idx):
168
- res[list(range(*vocab.dur_range))] = filter_value
169
- else:
170
- res[list(range(*vocab.note_range))] = filter_value
171
- return res
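
A minimal usage sketch for the learner and prediction helpers above. It assumes the pip-installed musicautobot package exposes these modules as musicautobot.*, that MusicDataBunch.empty behaves as in the upstream musicautobot examples, and that the paths below are placeholders:

from musicautobot.music_transformer import MusicDataBunch, music_model_learner, predict_from_midi

data_path = 'data/numpy/'                       # placeholder: folder used to build the (empty) databunch
pretrained_path = 'data/models/pretrained.pth'  # placeholder: checkpoint saved with MusicLearner.save(..., config=...)

data = MusicDataBunch.empty(data_path)          # only needed to supply the vocab
learn = music_model_learner(data, pretrained_path=pretrained_path)  # config is restored from the checkpoint

full = predict_from_midi(learn, midi='seed.mid', n_words=200, seed_len=8, top_k=30, top_p=0.6)
full.stream.write('midi', fp='generated.mid')   # full = seed + predicted continuation
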
utils/musicautobot/music_transformer/model.py DELETED
@@ -1,66 +0,0 @@
1
- from fastai.basics import *
2
- from fastai.text.models.transformer import TransformerXL
3
- from ..utils.attention_mask import rand_window_mask
4
-
5
- class MusicTransformerXL(TransformerXL):
6
- "Exactly like fastai's TransformerXL, but with more aggressive attention mask: see `rand_window_mask`"
7
- def __init__(self, *args, encode_position=True, mask_steps=1, **kwargs):
8
- import inspect
9
- sig = inspect.signature(TransformerXL)
10
- arg_params = { k:kwargs[k] for k in sig.parameters if k in kwargs }
11
- super().__init__(*args, **arg_params)
12
-
13
- self.encode_position = encode_position
14
- if self.encode_position: self.beat_enc = BeatPositionEncoder(kwargs['d_model'])
15
-
16
- self.mask_steps=mask_steps
17
-
18
-
19
- def forward(self, x):
20
- #The hidden state has to be initialized in the forward pass for nn.DataParallel
21
- if self.mem_len > 0 and not self.init:
22
- self.reset()
23
- self.init = True
24
-
25
- benc = 0
26
- if self.encode_position:
27
- x,pos = x['x'], x['pos']
28
- benc = self.beat_enc(pos)
29
-
30
- bs,x_len = x.size()
31
- inp = self.drop_emb(self.encoder(x) + benc) #.mul_(self.d_model ** 0.5)
32
- m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
33
- seq_len = m_len + x_len
34
-
35
- mask = rand_window_mask(x_len, m_len, inp.device, max_size=self.mask_steps, is_eval=not self.training) if self.mask else None
36
- if mask is not None and m_len == 0: mask[...,0,0] = 0
37
- #[None,:,:None] for einsum implementation of attention
38
- hids = []
39
- pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
40
- pos_enc = self.pos_enc(pos)
41
- hids.append(inp)
42
- for i, layer in enumerate(self.layers):
43
- mem = self.hidden[i] if self.mem_len > 0 else None
44
- inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
45
- hids.append(inp)
46
- core_out = inp[:,-x_len:]
47
- if self.mem_len > 0 : self._update_mems(hids)
48
- return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]
49
-
50
-
51
- # Beat encoder
52
- class BeatPositionEncoder(nn.Module):
53
- "Embedding + positional encoding + dropout"
54
- def __init__(self, emb_sz:int, beat_len=32, max_bar_len=1024):
55
- super().__init__()
56
-
57
- self.beat_len, self.max_bar_len = beat_len, max_bar_len
58
- self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
59
- self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)
60
-
61
- def forward(self, pos):
62
- beat_enc = self.beat_enc(pos % self.beat_len)
63
- bar_pos = pos // self.beat_len % self.max_bar_len
64
- bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
65
- bar_enc = self.bar_enc((bar_pos))
66
- return beat_enc + bar_enc
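
A small sketch of what BeatPositionEncoder computes: each position index is split into a within-bar beat (pos % beat_len) and a bar index (pos // beat_len), and the two embeddings are summed (module path assumes the pip-installed musicautobot package):

import torch
from musicautobot.music_transformer.model import BeatPositionEncoder

enc = BeatPositionEncoder(emb_sz=512)   # defaults: beat_len=32, max_bar_len=1024
pos = torch.arange(128)[None]           # dummy positional-beat indices, shape (batch=1, seq_len=128)
emb = enc(pos)                          # beat embedding + bar embedding
print(emb.shape)                        # torch.Size([1, 128, 512])
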
utils/musicautobot/music_transformer/transform.py DELETED
@@ -1,235 +0,0 @@
1
- from ..numpy_encode import *
2
- import numpy as np
3
- from enum import Enum
4
- import torch
5
- from ..vocab import *
6
- from functools import partial
7
-
8
- SEQType = Enum('SEQType', 'Mask, Sentence, Melody, Chords, Empty')
9
-
10
- class MusicItem():
11
- def __init__(self, data, vocab, stream=None, position=None):
12
- self.data = data
13
- self.vocab = vocab
14
- self._stream = stream
15
- self._position = position
16
- def __repr__(self): return '\n'.join([
17
- f'\n{self.__class__.__name__} - {self.data.shape}',
18
- f'{self.vocab.textify(self.data[:10])}...'])
19
- def __len__(self): return len(self.data)
20
-
21
- @classmethod
22
- def from_file(cls, midi_file, vocab):
23
- return cls.from_stream(file2stream(midi_file), vocab)
24
- @classmethod
25
- def from_stream(cls, stream, vocab):
26
- if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts()
27
- chordarr = stream2chordarr(stream) # 2.
28
- npenc = chordarr2npenc(chordarr) # 3.
29
- return cls.from_npenc(npenc, vocab, stream)
30
- @classmethod
31
- def from_npenc(cls, npenc, vocab, stream=None): return MusicItem(npenc2idxenc(npenc, vocab), vocab, stream)
32
-
33
- @classmethod
34
- def from_idx(cls, item, vocab):
35
- idx,pos = item
36
- return MusicItem(idx, vocab=vocab, position=pos)
37
- def to_idx(self): return self.data, self.position
38
-
39
- @classmethod
40
- def empty(cls, vocab, seq_type=SEQType.Sentence):
41
- return MusicItem(seq_prefix(seq_type, vocab), vocab)
42
-
43
- @property
44
- def stream(self):
45
- self._stream = self.to_stream() if self._stream is None else self._stream
46
- return self._stream
47
-
48
- def to_stream(self, bpm=120):
49
- return idxenc2stream(self.data, self.vocab, bpm=bpm)
50
-
51
- def to_tensor(self, device=None):
52
- return to_tensor(self.data, device)
53
-
54
- def to_text(self, sep=' '): return self.vocab.textify(self.data, sep)
55
-
56
- @property
57
- def position(self):
58
- self._position = position_enc(self.data, self.vocab) if self._position is None else self._position
59
- return self._position
60
-
61
- def get_pos_tensor(self, device=None): return to_tensor(self.position, device)
62
-
63
- def to_npenc(self):
64
- return idxenc2npenc(self.data, self.vocab)
65
-
66
- def show(self, format:str=None):
67
- return self.stream.show(format)
68
- def play(self): self.stream.show('midi')
69
-
70
- #Added by caslabs
71
- def download(self, filename:str=None, ext:str=None):
72
- return self.stream.write('midi', fp=filename)
73
-
74
- @property
75
- def new(self):
76
- return partial(type(self), vocab=self.vocab)
77
-
78
- def trim_to_beat(self, beat, include_last_sep=False):
79
- return self.new(trim_to_beat(self.data, self.position, self.vocab, beat, include_last_sep))
80
-
81
- def transpose(self, interval):
82
- return self.new(tfm_transpose(self.data, interval, self.vocab), position=self._position)
83
-
84
- def append(self, item):
85
- return self.new(np.concatenate((self.data, item.data), axis=0))
86
-
87
- def mask_pitch(self, section=None):
88
- return self.new(self.mask(self.vocab.note_range, section), position=self.position)
89
-
90
- def mask_duration(self, section=None, keep_position_enc=True):
91
- masked_data = self.mask(self.vocab.dur_range, section)
92
- if keep_position_enc: return self.new(masked_data, position=self.position)
93
- return self.new(masked_data)
94
-
95
- def mask(self, token_range, section_range=None):
96
- return mask_section(self.data, self.position, token_range, self.vocab.mask_idx, section_range=section_range)
97
-
98
- def pad_to(self, bptt):
99
- data = pad_seq(self.data, bptt, self.vocab.pad_idx)
100
- pos = pad_seq(self.position, bptt, 0)
101
- return self.new(data, stream=self._stream, position=pos)
102
-
103
- def split_stream_parts(self):
104
- self._stream = separate_melody_chord(self.stream)
105
- return self.stream
106
-
107
- def remove_eos(self):
108
- if self.data[-1] == self.vocab.stoi[EOS]: return self.new(self.data[:-1], stream=self.stream)
109
- return self
110
-
111
- def split_parts(self):
112
- return self.new(self.data, stream=separate_melody_chord(self.stream), position=self.position)
113
-
114
- def pad_seq(seq, bptt, value):
115
- pad_len = max(bptt-seq.shape[0], 0)
116
- return np.pad(seq, (0, pad_len), 'constant', constant_values=value)[:bptt]
117
-
118
- def to_tensor(t, device=None):
119
- t = t if isinstance(t, torch.Tensor) else torch.tensor(t)
120
- if device is None and torch.cuda.is_available(): t = t.cuda()
121
- else: t.to(device)
122
- return t.long()
123
-
124
- def midi2idxenc(midi_file, vocab):
125
- "Converts midi file to index encoding for training"
126
- npenc = midi2npenc(midi_file) # 3.
127
- return npenc2idxenc(npenc, vocab)
128
-
129
- def idxenc2stream(arr, vocab, bpm=120):
130
- "Converts index encoding to music21 stream"
131
- npenc = idxenc2npenc(arr, vocab)
132
- return npenc2stream(npenc, bpm=bpm)
133
-
134
- # single stream instead of note,dur
135
- def npenc2idxenc(t, vocab, seq_type=SEQType.Sentence, add_eos=False):
136
- "Transforms numpy array from 2 column (note, duration) matrix to a single column"
137
- "[[n1, d1], [n2, d2], ...] -> [n1, d1, n2, d2]"
138
- if isinstance(t, (list, tuple)) and len(t) == 2:
139
- return [npenc2idxenc(x, vocab, seq_type) for x in t]
140
- t = t.copy()
141
-
142
- t[:, 0] = t[:, 0] + vocab.note_range[0]
143
- t[:, 1] = t[:, 1] + vocab.dur_range[0]
144
-
145
- prefix = seq_prefix(seq_type, vocab)
146
- suffix = np.array([vocab.stoi[EOS]]) if add_eos else np.empty(0, dtype=int)
147
- return np.concatenate([prefix, t.reshape(-1), suffix])
148
-
149
- def seq_prefix(seq_type, vocab):
150
- if seq_type == SEQType.Empty: return np.empty(0, dtype=int)
151
- start_token = vocab.bos_idx
152
- if seq_type == SEQType.Chords: start_token = vocab.stoi[CSEQ]
153
- if seq_type == SEQType.Melody: start_token = vocab.stoi[MSEQ]
154
- return np.array([start_token, vocab.pad_idx])
155
-
156
- def idxenc2npenc(t, vocab, validate=True):
157
- if validate: t = to_valid_idxenc(t, vocab.npenc_range)
158
- t = t.copy().reshape(-1, 2)
159
- if t.shape[0] == 0: return t
160
-
161
- t[:, 0] = t[:, 0] - vocab.note_range[0]
162
- t[:, 1] = t[:, 1] - vocab.dur_range[0]
163
-
164
- if validate: return to_valid_npenc(t)
165
- return t
166
-
167
- def to_valid_idxenc(t, valid_range):
168
- r = valid_range
169
- t = t[np.where((t >= r[0]) & (t < r[1]))]
170
- if t.shape[-1] % 2 == 1: t = t[..., :-1]
171
- return t
172
-
173
- def to_valid_npenc(t):
174
- is_note = (t[:, 0] < VALTSEP) | (t[:, 0] >= NOTE_SIZE)
175
- invalid_note_idx = is_note.argmax()
176
- invalid_dur_idx = (t[:, 1] < 0).argmax()
177
-
178
- invalid_idx = max(invalid_dur_idx, invalid_note_idx)
179
- if invalid_idx > 0:
180
- if invalid_note_idx > 0 and invalid_dur_idx > 0: invalid_idx = min(invalid_dur_idx, invalid_note_idx)
181
- print('Non midi note detected. Only returning valid portion. Index, seed', invalid_idx, t.shape)
182
- return t[:invalid_idx]
183
- return t
184
-
185
- def position_enc(idxenc, vocab):
186
- "Calculates positional beat encoding."
187
- sep_idxs = (idxenc == vocab.sep_idx).nonzero()[0]
188
- sep_idxs = sep_idxs[sep_idxs+2 < idxenc.shape[0]] # remove any indexes right before out of bounds (sep_idx+2)
189
- dur_vals = idxenc[sep_idxs+1]
190
- dur_vals[dur_vals == vocab.mask_idx] = vocab.dur_range[0] # make sure masked durations are 0
191
- dur_vals -= vocab.dur_range[0]
192
-
193
- posenc = np.zeros_like(idxenc)
194
- posenc[sep_idxs+2] = dur_vals
195
- return posenc.cumsum()
196
-
197
- def beat2index(idxenc, pos, vocab, beat, include_last_sep=False):
198
- cutoff = find_beat(pos, beat)
199
- if cutoff < 2: return 2 # always leave starter tokens
200
- if len(idxenc) < 2 or include_last_sep: return cutoff
201
- if idxenc[cutoff - 2] == vocab.sep_idx: return cutoff - 2
202
- return cutoff
203
-
204
- def find_beat(pos, beat, sample_freq=SAMPLE_FREQ, side='left'):
205
- return np.searchsorted(pos, beat * sample_freq, side=side)
206
-
207
- # TRANSFORMS
208
-
209
- def tfm_transpose(x, value, vocab):
210
- x = x.copy()
211
- x[(x >= vocab.note_range[0]) & (x < vocab.note_range[1])] += value
212
- return x
213
-
214
- def trim_to_beat(idxenc, pos, vocab, to_beat=None, include_last_sep=True):
215
- if to_beat is None: return idxenc
216
- cutoff = beat2index(idxenc, pos, vocab, to_beat, include_last_sep=include_last_sep)
217
- return idxenc[:cutoff]
218
-
219
- def mask_input(xb, mask_range, replacement_idx):
220
- xb = xb.copy()
221
- xb[(xb >= mask_range[0]) & (xb < mask_range[1])] = replacement_idx
222
- return xb
223
-
224
- def mask_section(xb, pos, token_range, replacement_idx, section_range=None):
225
- xb = xb.copy()
226
- token_mask = (xb >= token_range[0]) & (xb < token_range[1])
227
-
228
- if section_range is None: section_range = (None, None)
229
- section_mask = np.zeros_like(xb, dtype=bool)
230
- start_idx = find_beat(pos, section_range[0]) if section_range[0] is not None else 0
231
- end_idx = find_beat(pos, section_range[1]) if section_range[1] is not None else xb.shape[0]
232
- section_mask[start_idx:end_idx] = True
233
-
234
- xb[token_mask & section_mask] = replacement_idx
235
- return xb
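
A minimal round trip with MusicItem, as a sketch ('example.mid' is a placeholder file):

from musicautobot.vocab import MusicVocab
from musicautobot.music_transformer.transform import MusicItem

vocab = MusicVocab.create()
item = MusicItem.from_file('example.mid', vocab)   # midi -> stream -> chordarr -> npenc -> index encoding

seed = item.trim_to_beat(8)                        # keep roughly the first 8 beats
up4 = seed.transpose(4)                            # shift all note tokens up by 4 semitones
print(up4.to_text()[:60])                          # token view, e.g. 'xxbos xxpad n60 d4 xxsep ...'
up4.stream.write('midi', fp='seed_transposed.mid')
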
utils/musicautobot/numpy_encode.py DELETED
@@ -1,302 +0,0 @@
1
- "Encoding music21 streams -> numpy array -> text"
2
-
3
- # import re
4
- import music21
5
- import numpy as np
6
- # from pathlib import Path
7
-
8
- BPB = 4 # beats per bar
9
- TIMESIG = f'{BPB}/4' # default time signature
10
- PIANO_RANGE = (21, 108)
11
- VALTSEP = -1 # separator value for numpy encoding
12
- VALTCONT = -2 # numpy value for TCONT - needed for compressing chord array
13
-
14
- SAMPLE_FREQ = 4
15
- NOTE_SIZE = 128
16
- DUR_SIZE = (10*BPB*SAMPLE_FREQ)+1 # Max duration value - 10 bars (40 beats/quarter notes), headroom above MAX_NOTE_DUR
17
- MAX_NOTE_DUR = (8*BPB*SAMPLE_FREQ)
18
-
19
- # Encoding process
20
- # 1. midi -> music21.Stream
21
- # 2. Stream -> numpy chord array (timestep X instrument X noterange)
22
- # 3. numpy array -> List[Timestep][NoteEnc]
23
- def midi2npenc(midi_file, skip_last_rest=True):
24
- "Converts midi file to numpy encoding for language model"
25
- stream = file2stream(midi_file) # 1.
26
- chordarr = stream2chordarr(stream) # 2.
27
- return chordarr2npenc(chordarr, skip_last_rest=skip_last_rest) # 3.
28
-
29
- # Decoding process
30
- # 1. NoteEnc -> numpy chord array
31
- # 2. numpy array -> music21.Stream
32
- def npenc2stream(arr, bpm=120):
33
- "Converts numpy encoding to music21 stream"
34
- chordarr = npenc2chordarr(np.array(arr)) # 1.
35
- return chordarr2stream(chordarr, bpm=bpm) # 2.
36
-
37
- ##### ENCODING ######
38
-
39
- # 1. File to Stream
40
-
41
- def file2stream(fp):
42
- if isinstance(fp, music21.midi.MidiFile): return music21.midi.translate.midiFileToStream(fp)
43
- return music21.converter.parse(fp)
44
-
45
- # 2.
46
- def stream2chordarr(s, note_size=NOTE_SIZE, sample_freq=SAMPLE_FREQ, max_note_dur=MAX_NOTE_DUR):
47
- "Converts music21.Stream to 1-hot numpy array"
48
- # assuming 4/4 time
49
- # note x instrument x pitch
50
- # FYI: midi middle C value=60
51
-
52
- # (AS) TODO: need to order by instruments most played and filter out percussion or include the channel
53
- highest_time = max(s.flat.getElementsByClass('Note').highestTime, s.flat.getElementsByClass('Chord').highestTime)
54
- maxTimeStep = round(highest_time * sample_freq)+1
55
- score_arr = np.zeros((maxTimeStep, len(s.parts), NOTE_SIZE))
56
-
57
- def note_data(pitch, note):
58
- return (pitch.midi, int(round(note.offset*sample_freq)), int(round(note.duration.quarterLength*sample_freq)))
59
-
60
- for idx,part in enumerate(s.parts):
61
- notes=[]
62
- for elem in part.flat:
63
- if isinstance(elem, music21.note.Note):
64
- notes.append(note_data(elem.pitch, elem))
65
- if isinstance(elem, music21.chord.Chord):
66
- for p in elem.pitches:
67
- notes.append(note_data(p, elem))
68
-
69
- # sort notes by offset (1), duration (2) so that hits are not overwritten and longer notes have priority
70
- notes_sorted = sorted(notes, key=lambda x: (x[1], x[2]))
71
- for n in notes_sorted:
72
- if n is None: continue
73
- pitch,offset,duration = n
74
- if max_note_dur is not None and duration > max_note_dur: duration = max_note_dur
75
- score_arr[offset, idx, pitch] = duration
76
- score_arr[offset+1:offset+duration, idx, pitch] = VALTCONT # Continue holding note
77
- return score_arr
78
-
79
- def chordarr2npenc(chordarr, skip_last_rest=True):
80
- # combine instruments
81
- result = []
82
- wait_count = 0
83
- for idx,timestep in enumerate(chordarr):
84
- flat_time = timestep2npenc(timestep)
85
- if len(flat_time) == 0:
86
- wait_count += 1
87
- else:
88
- # pitch, octave, duration, instrument
89
- if wait_count > 0: result.append([VALTSEP, wait_count])
90
- result.extend(flat_time)
91
- wait_count = 1
92
- if wait_count > 0 and not skip_last_rest: result.append([VALTSEP, wait_count])
93
- return np.array(result, dtype=int).reshape(-1, 2) # reshaping. Just in case result is empty
94
-
95
- # Note: not worrying about overlaps - as notes will still play. just look tied
96
- # http://web.mit.edu/music21/doc/moduleReference/moduleStream.html#music21.stream.Stream.getOverlaps
97
- def timestep2npenc(timestep, note_range=PIANO_RANGE, enc_type=None):
98
- # inst x pitch
99
- notes = []
100
- for i,n in zip(*timestep.nonzero()):
101
- d = timestep[i,n]
102
- if d < 0: continue # only supporting short duration encoding for now
103
- if n < note_range[0] or n >= note_range[1]: continue # must be within midi range
104
- notes.append([n,d,i])
105
-
106
- notes = sorted(notes, key=lambda x: x[0], reverse=True) # sort by note (highest to lowest)
107
-
108
- if enc_type is None:
109
- # note, duration
110
- return [n[:2] for n in notes]
111
- if enc_type == 'parts':
112
- # note, duration, part
113
- return [n for n in notes]
114
- if enc_type == 'full':
115
- # note_class, duration, octave, instrument
116
- return [[n%12, d, n//12, i] for n,d,i in notes]
117
-
118
- ##### DECODING #####
119
-
120
- # 1.
121
- def npenc2chordarr(npenc, note_size=NOTE_SIZE):
122
- num_instruments = 1 if len(npenc.shape) <= 2 else npenc.max(axis=0)[-1]
123
-
124
- max_len = npenc_len(npenc)
125
- # score_arr = (steps, inst, note)
126
- score_arr = np.zeros((max_len, num_instruments, note_size))
127
-
128
- idx = 0
129
- for step in npenc:
130
- n,d,i = (step.tolist()+[0])[:3] # note, duration, instrument (instrument defaults to 0)
131
- if n < VALTSEP: continue # special token
132
- if n == VALTSEP:
133
- idx += d
134
- continue
135
- score_arr[idx,i,n] = d
136
- return score_arr
137
-
138
- def npenc_len(npenc):
139
- duration = 0
140
- for t in npenc:
141
- if t[0] == VALTSEP: duration += t[1]
142
- return duration + 1
143
-
144
-
145
- # 2.
146
- def chordarr2stream(arr, sample_freq=SAMPLE_FREQ, bpm=120):
147
- duration = music21.duration.Duration(1. / sample_freq)
148
- stream = music21.stream.Score()
149
- stream.append(music21.meter.TimeSignature(TIMESIG))
150
- stream.append(music21.tempo.MetronomeMark(number=bpm))
151
- stream.append(music21.key.KeySignature(0))
152
- for inst in range(arr.shape[1]):
153
- p = partarr2stream(arr[:,inst,:], duration)
154
- stream.append(p)
155
- stream = stream.transpose(0)
156
- return stream
157
-
158
- # 2b.
159
- def partarr2stream(partarr, duration):
160
- "convert instrument part to music21 chords"
161
- part = music21.stream.Part()
162
- part.append(music21.instrument.Piano())
163
- part_append_duration_notes(partarr, duration, part) # notes already have duration calculated
164
-
165
- return part
166
-
167
- def part_append_duration_notes(partarr, duration, stream):
168
- "convert instrument part to music21 chords"
169
- for tidx,t in enumerate(partarr):
170
- note_idxs = np.where(t > 0)[0] # filter out any negative values (continuous mode)
171
- if len(note_idxs) == 0: continue
172
- notes = []
173
- for nidx in note_idxs:
174
- note = music21.note.Note(nidx)
175
- note.duration = music21.duration.Duration(partarr[tidx,nidx]*duration.quarterLength)
176
- notes.append(note)
177
- for g in group_notes_by_duration(notes):
178
- if len(g) == 1:
179
- stream.insert(tidx*duration.quarterLength, g[0])
180
- else:
181
- chord = music21.chord.Chord(g)
182
- stream.insert(tidx*duration.quarterLength, chord)
183
- return stream
184
-
185
- from itertools import groupby
186
- # combining notes with different durations into a single chord may overwrite conflicting durations. Example: aylictal/still-waters-run-deep
187
- def group_notes_by_duration(notes):
188
- "separate notes into chord groups"
189
- keyfunc = lambda n: n.duration.quarterLength
190
- notes = sorted(notes, key=keyfunc)
191
- return [list(g) for k,g in groupby(notes, keyfunc)]
192
-
193
-
194
- # Midi -> npenc Conversion helpers
195
- def is_valid_npenc(npenc, note_range=PIANO_RANGE, max_dur=DUR_SIZE,
196
- min_notes=32, input_path=None, verbose=True):
197
- if len(npenc) < min_notes:
198
- if verbose: print('Sequence too short:', len(npenc), input_path)
199
- return False
200
- if (npenc[:,1] >= max_dur).any():
201
- if verbose: print(f'npenc exceeds max {max_dur} duration:', npenc[:,1].max(), input_path)
202
- return False
203
- # https://en.wikipedia.org/wiki/Scientific_pitch_notation - 88 key range - 21 = A0, 108 = C8
204
- if ((npenc[...,0] > VALTSEP) & ((npenc[...,0] < note_range[0]) | (npenc[...,0] >= note_range[1]))).any():
205
- print(f'npenc out of piano note range {note_range}:', input_path)
206
- return False
207
- return True
208
-
209
- # separates overlapping notes into different tracks
210
- def remove_overlaps(stream, separate_chords=True):
211
- if not separate_chords:
212
- return stream.flat.makeVoices().voicesToParts()
213
- return separate_melody_chord(stream)
214
-
215
- # separates notes and chords into different tracks
216
- def separate_melody_chord(stream):
217
- new_stream = music21.stream.Score()
218
- if stream.timeSignature: new_stream.append(stream.timeSignature)
219
- new_stream.append(stream.metronomeMarkBoundaries()[0][-1])
220
- if stream.keySignature: new_stream.append(stream.keySignature)
221
-
222
- melody_part = music21.stream.Part(stream.flat.getElementsByClass('Note'))
223
- melody_part.insert(0, stream.getInstrument())
224
- chord_part = music21.stream.Part(stream.flat.getElementsByClass('Chord'))
225
- chord_part.insert(0, stream.getInstrument())
226
- new_stream.append(melody_part)
227
- new_stream.append(chord_part)
228
- return new_stream
229
-
230
- # processing functions for sanitizing data
231
-
232
- def compress_chordarr(chordarr):
233
- return shorten_chordarr_rests(trim_chordarr_rests(chordarr))
234
-
235
- def trim_chordarr_rests(arr, max_rests=4, sample_freq=SAMPLE_FREQ):
236
- # max rests is in quarter notes
237
- # max 1 bar between song start and end
238
- start_idx = 0
239
- max_sample = max_rests*sample_freq
240
- for idx,t in enumerate(arr):
241
- if (t != 0).any(): break
242
- start_idx = idx+1
243
-
244
- end_idx = 0
245
- for idx,t in enumerate(reversed(arr)):
246
- if (t != 0).any(): break
247
- end_idx = idx+1
248
- start_idx = start_idx - start_idx % max_sample
249
- end_idx = end_idx - end_idx % max_sample
250
- # if start_idx > 0 or end_idx > 0: print('Trimming rests. Start, end:', start_idx, len(arr)-end_idx, end_idx)
251
- return arr[start_idx:(len(arr)-end_idx)]
252
-
253
- def shorten_chordarr_rests(arr, max_rests=8, sample_freq=SAMPLE_FREQ):
254
- # max rests is in quarter notes
255
- # max 2 bar pause
256
- rest_count = 0
257
- result = []
258
- max_sample = max_rests*sample_freq
259
- for timestep in arr:
260
- if (timestep==0).all():
261
- rest_count += 1
262
- else:
263
- if rest_count > max_sample:
264
- # old_count = rest_count
265
- rest_count = (rest_count % sample_freq) + max_sample
266
- # print(f'Compressing rests: {old_count} -> {rest_count}')
267
- for i in range(rest_count): result.append(np.zeros(timestep.shape))
268
- rest_count = 0
269
- result.append(timestep)
270
- for i in range(rest_count): result.append(np.zeros(timestep.shape))
271
- return np.array(result)
272
-
273
- # sequence 2 sequence convenience functions
274
-
275
- def stream2npenc_parts(stream, sort_pitch=True):
276
- chordarr = stream2chordarr(stream)
277
- _,num_parts,_ = chordarr.shape
278
- parts = [part_enc(chordarr, i) for i in range(num_parts)]
279
- return sorted(parts, key=avg_pitch, reverse=True) if sort_pitch else parts
280
-
281
- def chordarr_combine_parts(parts):
282
- max_ts = max([p.shape[0] for p in parts])
283
- parts_padded = [pad_part_to(p, max_ts) for p in parts]
284
- chordarr_comb = np.concatenate(parts_padded, axis=1)
285
- return chordarr_comb
286
-
287
- def pad_part_to(p, target_size):
288
- pad_width = ((0,target_size-p.shape[0]),(0,0),(0,0))
289
- return np.pad(p, pad_width, 'constant')
290
-
291
- def part_enc(chordarr, part):
292
- partarr = chordarr[:,part:part+1,:]
293
- npenc = chordarr2npenc(partarr)
294
- return npenc
295
-
296
- def avg_tempo(t, sep_idx=VALTSEP):
297
- avg = t[t[:, 0] == sep_idx][:, 1].sum()/t.shape[0]
298
- avg = int(round(avg/SAMPLE_FREQ))
299
- return 'mt'+str(min(avg, MTEMPO_SIZE-1))
300
-
301
- def avg_pitch(t, sep_idx=VALTSEP):
302
- return t[t[:, 0] > sep_idx][:, 0].mean()
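
The encode/decode pipeline described in the comments at the top of this module, as a short sketch ('example.mid' is a placeholder):

from musicautobot.numpy_encode import midi2npenc, npenc2stream

npenc = midi2npenc('example.mid')        # midi -> music21.Stream -> chord array -> (note, duration) rows
print(npenc[:4])                         # e.g. [[67 4] [60 4] [-1 4] [64 8]]; rows starting with -1 are timestep separators
stream = npenc2stream(npenc, bpm=120)    # back to a music21 score
stream.write('midi', fp='roundtrip.mid')
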
utils/musicautobot/utils/__init__.py DELETED
File without changes
utils/musicautobot/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (176 Bytes)
 
utils/musicautobot/utils/__pycache__/attention_mask.cpython-310.pyc DELETED
Binary file (1.3 kB)
 
utils/musicautobot/utils/__pycache__/file_processing.cpython-310.pyc DELETED
Binary file (2.62 kB)
 
utils/musicautobot/utils/__pycache__/midifile.cpython-310.pyc DELETED
Binary file (4.5 kB)
 
utils/musicautobot/utils/__pycache__/setup_musescore.cpython-310.pyc DELETED
Binary file (1.79 kB)
 
utils/musicautobot/utils/__pycache__/top_k_top_p.cpython-310.pyc DELETED
Binary file (1.24 kB)
 
utils/musicautobot/utils/attention_mask.py DELETED
@@ -1,21 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
- def window_mask(x_len, device, m_len=0, size=(1,1)):
5
- win_size,k = size
6
- mem_mask = torch.zeros((x_len,m_len), device=device)
7
- tri_mask = torch.triu(torch.ones((x_len//win_size+1,x_len//win_size+1), device=device),diagonal=k)
8
- window_mask = tri_mask.repeat_interleave(win_size,dim=0).repeat_interleave(win_size,dim=1)[:x_len,:x_len]
9
- if x_len: window_mask[...,0] = 0 # Always allowing first index to see. Otherwise you'll get NaN loss
10
- mask = torch.cat((mem_mask, window_mask), dim=1)[None,None]
11
- return mask.bool() if hasattr(mask, 'bool') else mask.byte()
12
-
13
- def rand_window_mask(x_len,m_len,device,max_size:int=None,p:float=0.2,is_eval:bool=False):
14
- if is_eval or np.random.rand() >= p or max_size is None:
15
- win_size,k = (1,1)
16
- else: win_size,k = (np.random.randint(0,max_size)+1,0)
17
- return window_mask(x_len, device, m_len, size=(win_size,k))
18
-
19
- def lm_mask(x_len, device):
20
- mask = torch.triu(torch.ones((x_len, x_len), device=device), diagonal=1)[None,None]
21
- return mask.bool() if hasattr(mask, 'bool') else mask.byte()
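
A quick sketch of the two masks: lm_mask is a plain causal mask, while rand_window_mask occasionally widens the masked window during training so the model must predict several timesteps at once:

import torch
from musicautobot.utils.attention_mask import lm_mask, rand_window_mask

device = torch.device('cpu')
print(lm_mask(4, device).squeeze())      # standard upper-triangular causal mask, shape (4, 4)

mask = rand_window_mask(8, 0, device, max_size=4, p=1.0, is_eval=False)
print(mask.squeeze())                    # causal mask with a random window size between 1 and 4
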
utils/musicautobot/utils/file_processing.py DELETED
@@ -1,52 +0,0 @@
1
- "Parallel processing for midi files"
2
- import csv
3
- from fastprogress.fastprogress import master_bar, progress_bar
4
- from pathlib import Path
5
- from pebble import ProcessPool
6
- from concurrent.futures import TimeoutError
7
- import numpy as np
8
-
9
- # https://stackoverflow.com/questions/20991968/asynchronous-multiprocessing-with-a-worker-pool-in-python-how-to-keep-going-aft
10
- def process_all(func, arr, timeout_func=None, total=None, max_workers=None, timeout=None):
11
- with ProcessPool() as pool:
12
- future = pool.map(func, arr, timeout=timeout)
13
-
14
- iterator = future.result()
15
- results = []
16
- for i in progress_bar(range(len(arr)), total=len(arr)):
17
- try:
18
- result = next(iterator)
19
- if result: results.append(result)
20
- except StopIteration:
21
- break
22
- except TimeoutError as error:
23
- if timeout_func: timeout_func(arr[i], error.args[1])
24
- return results
25
-
26
- def process_file(file_path, tfm_func=None, src_path=None, dest_path=None):
27
- "Utility function that transforms midi file to numpy array."
28
- output_file = Path(str(file_path).replace(str(src_path), str(dest_path))).with_suffix('.npy')
29
- if output_file.exists(): return output_file
30
- output_file.parent.mkdir(parents=True, exist_ok=True)
31
-
32
- # Call tfm_func and save file
33
- npenc = tfm_func(file_path)
34
- if npenc is not None:
35
- np.save(output_file, npenc)
36
- return output_file
37
-
38
- def arr2csv(arr, out_file):
39
- "Convert metadata array to csv"
40
- all_keys = {k for d in arr for k in d.keys()}
41
- arr = [format_values(x) for x in arr]
42
- with open(out_file, 'w') as f:
43
- dict_writer = csv.DictWriter(f, list(all_keys))
44
- dict_writer.writeheader()
45
- dict_writer.writerows(arr)
46
-
47
- def format_values(d):
48
- "Format array values for csv encoding"
49
- def format_value(v):
50
- if isinstance(v, list): return ','.join(v)
51
- return v
52
- return {k:format_value(v) for k,v in d.items()}
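
A sketch of how these helpers are typically combined to preprocess a folder of MIDI files in parallel (paths are placeholders):

from functools import partial
from pathlib import Path
from musicautobot.numpy_encode import midi2npenc
from musicautobot.utils.file_processing import process_all, process_file

src, dest = Path('data/midi'), Path('data/numpy')
midi_files = list(src.rglob('*.mid'))

# Each midi file is converted to a (note, duration) numpy encoding and cached under dest as .npy
tfm = partial(process_file, tfm_func=midi2npenc, src_path=src, dest_path=dest)
saved = process_all(tfm, midi_files, timeout=120)   # skips files that hang past the timeout
print(len(saved), 'files processed')
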
utils/musicautobot/utils/lamb.py DELETED
@@ -1,106 +0,0 @@
1
- # SOURCE: https://github.com/cybertronai/pytorch-lamb/
2
-
3
- import collections
4
- import math
5
-
6
- import torch
7
- from torch.optim import Optimizer
8
-
9
-
10
- class Lamb(Optimizer):
11
- r"""Implements Lamb algorithm.
12
-
13
- It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_.
14
-
15
- Arguments:
16
- params (iterable): iterable of parameters to optimize or dicts defining
17
- parameter groups
18
- lr (float, optional): learning rate (default: 1e-3)
19
- betas (Tuple[float, float], optional): coefficients used for computing
20
- running averages of gradient and its square (default: (0.9, 0.999))
21
- eps (float, optional): term added to the denominator to improve
22
- numerical stability (default: 1e-4)
23
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
24
- adam (bool, optional): always use trust ratio = 1, which turns this into
25
- Adam. Useful for comparison purposes.
26
-
27
- .. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
28
- https://arxiv.org/abs/1904.00962
29
- """
30
-
31
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-4,
32
- weight_decay=0, adam=False):
33
- if not 0.0 <= lr:
34
- raise ValueError("Invalid learning rate: {}".format(lr))
35
- if not 0.0 <= eps:
36
- raise ValueError("Invalid epsilon value: {}".format(eps))
37
- if not 0.0 <= betas[0] < 1.0:
38
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
39
- if not 0.0 <= betas[1] < 1.0:
40
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
41
- defaults = dict(lr=lr, betas=betas, eps=eps,
42
- weight_decay=weight_decay)
43
- self.adam = adam
44
- super(Lamb, self).__init__(params, defaults)
45
-
46
- def step(self, closure=None):
47
- """Performs a single optimization step.
48
-
49
- Arguments:
50
- closure (callable, optional): A closure that reevaluates the model
51
- and returns the loss.
52
- """
53
- loss = None
54
- if closure is not None:
55
- loss = closure()
56
-
57
- for group in self.param_groups:
58
- for p in group['params']:
59
- if p.grad is None:
60
- continue
61
- grad = p.grad.data
62
- if grad.is_sparse:
63
- raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
64
-
65
- state = self.state[p]
66
-
67
- # State initialization
68
- if len(state) == 0:
69
- state['step'] = 0
70
- # Exponential moving average of gradient values
71
- state['exp_avg'] = torch.zeros_like(p.data)
72
- # Exponential moving average of squared gradient values
73
- state['exp_avg_sq'] = torch.zeros_like(p.data)
74
-
75
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
76
- beta1, beta2 = group['betas']
77
-
78
- state['step'] += 1
79
-
80
- if group['weight_decay'] != 0:
81
- grad.add_(group['weight_decay'], p.data)
82
-
83
- # Decay the first and second moment running average coefficient
84
- exp_avg.mul_(beta1).add_(1 - beta1, grad)
85
- exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
86
- denom = exp_avg_sq.sqrt().add_(group['eps'])
87
-
88
- bias_correction1 = 1 - beta1 ** state['step']
89
- bias_correction2 = 1 - beta2 ** state['step']
90
- # Apply bias to lr to avoid broadcast.
91
- step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
92
-
93
- adam_step = exp_avg / denom
94
- # L2 norm uses sum, but here since we're dividing, use mean to avoid overflow.
95
- r1 = p.data.pow(2).mean().sqrt()
96
- r2 = adam_step.pow(2).mean().sqrt()
97
- r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10)
98
- state['r1'] = r1
99
- state['r2'] = r2
100
- state['r'] = r
101
- if self.adam:
102
- r = 1
103
-
104
- p.data.add_(-step_size * r, adam_step)
105
-
106
- return loss
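
A minimal sketch of plugging Lamb into a training step (toy model and batch; the implementation above uses pre-1.5 PyTorch overloads such as add_(scalar, tensor), so this assumes the older torch that fastai v1 targets):

import torch
import torch.nn as nn
import torch.nn.functional as F
from musicautobot.utils.lamb import Lamb

model = nn.Linear(16, 4)
opt = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
opt.step()        # update is scaled by the layer-wise trust ratio r = ||w|| / ||adam_step||
opt.zero_grad()
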
utils/musicautobot/utils/midifile.py DELETED
@@ -1,107 +0,0 @@
1
- "Transform functions for raw midi files"
2
- from enum import Enum
3
- import music21
4
-
5
- PIANO_TYPES = list(range(24)) + list(range(80, 96)) # Piano, Synths
6
- PLUCK_TYPES = list(range(24, 40)) + list(range(104, 112)) # Guitar, Bass, Ethnic
7
- BRIGHT_TYPES = list(range(40, 56)) + list(range(56, 80))
8
-
9
- PIANO_RANGE = (21, 109) # https://en.wikipedia.org/wiki/Scientific_pitch_notation
10
-
11
- class Track(Enum):
12
- PIANO = 0 # discrete instruments - keyboard, woodwinds
13
- PLUCK = 1 # continuous instruments with pitch bend: violin, trombone, synths
14
- BRIGHT = 2
15
- PERC = 3
16
- UNDEF = 4
17
-
18
- type2inst = {
19
- # use print_music21_instruments() to see supported types
20
- Track.PIANO: 0, # Piano
21
- Track.PLUCK: 24, # Guitar
22
- Track.BRIGHT: 40, # Violin
23
- Track.PERC: 114, # Steel Drum
24
- }
25
-
26
- # INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE'])
27
- INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE', 'SET_TEMPO'])
28
-
29
- def file2mf(fp):
30
- mf = music21.midi.MidiFile()
31
- if isinstance(fp, bytes):
32
- mf.readstr(fp)
33
- else:
34
- mf.open(fp)
35
- mf.read()
36
- mf.close()
37
- return mf
38
-
39
- def mf2stream(mf): return music21.midi.translate.midiFileToStream(mf)
40
-
41
- def is_empty_midi(fp):
42
- if fp is None: return False
43
- mf = file2mf(fp)
44
- return not any([t.hasNotes() for t in mf.tracks])
45
-
46
- def num_piano_tracks(fp):
47
- music_file = file2mf(fp)
48
- note_tracks = [t for t in music_file.tracks if t.hasNotes() and get_track_type(t) == Track.PIANO]
49
- return len(note_tracks)
50
-
51
- def is_channel(t, c_val):
52
- return any([c == c_val for c in t.getChannels()])
53
-
54
- def track_sort(t): # sort by 1. variation of pitch, 2. number of notes
55
- return len(unique_track_notes(t)), len(t.events)
56
-
57
- def is_piano_note(pitch):
58
- return (pitch >= PIANO_RANGE[0]) and (pitch < PIANO_RANGE[1])
59
-
60
- def unique_track_notes(t):
61
- return { e.pitch for e in t.events if e.pitch is not None }
62
-
63
- def compress_midi_file(fp, cutoff=6, min_variation=3, supported_types=set([Track.PIANO, Track.PLUCK, Track.BRIGHT])):
64
- music_file = file2mf(fp)
65
-
66
- info_tracks = [t for t in music_file.tracks if not t.hasNotes()]
67
- note_tracks = [t for t in music_file.tracks if t.hasNotes()]
68
-
69
- if len(note_tracks) > cutoff:
70
- note_tracks = sorted(note_tracks, key=track_sort, reverse=True)
71
-
72
- supported_tracks = []
73
- for idx,t in enumerate(note_tracks):
74
- if len(supported_tracks) >= cutoff: break
75
- track_type = get_track_type(t)
76
- if track_type not in supported_types: continue
77
- pitch_set = unique_track_notes(t)
78
- if (len(pitch_set) < min_variation): continue # must have more than x unique notes
79
- if not all(map(is_piano_note, pitch_set)): continue # must not contain midi notes outside of piano range
80
- # if track_type == Track.UNDEF: print('Could not designate track:', fp, t)
81
- change_track_instrument(t, type2inst[track_type])
82
- supported_tracks.append(t)
83
- if not supported_tracks: return None
84
- music_file.tracks = info_tracks + supported_tracks
85
- return music_file
86
-
87
- def get_track_type(t):
88
- if is_channel(t, 10): return Track.PERC
89
- i = get_track_instrument(t)
90
- if i in PIANO_TYPES: return Track.PIANO
91
- if i in PLUCK_TYPES: return Track.PLUCK
92
- if i in BRIGHT_TYPES: return Track.BRIGHT
93
- return Track.UNDEF
94
-
95
- def get_track_instrument(t):
96
- for idx,e in enumerate(t.events):
97
- if e.type == 'PROGRAM_CHANGE': return e.data
98
- return None
99
-
100
- def change_track_instrument(t, value):
101
- for idx,e in enumerate(t.events):
102
- if e.type == 'PROGRAM_CHANGE': e.data = value
103
-
104
- def print_music21_instruments():
105
- for i in range(200):
106
- try: print(i, music21.instrument.instrumentFromMidiProgram(i))
107
- except: pass
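
A short sketch of the track-compression helpers above ('example.mid' is a placeholder multi-track file):

from musicautobot.utils.midifile import is_empty_midi, num_piano_tracks, compress_midi_file

fp = 'example.mid'
print(is_empty_midi(fp), num_piano_tracks(fp))

mf = compress_midi_file(fp, cutoff=6)       # keep at most 6 piano/pluck/bright tracks, remap their instruments
if mf is not None:
    mf.open('compressed.mid', attrib='wb')  # music21 MidiFile write-out
    mf.write()
    mf.close()
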
utils/musicautobot/utils/setup_musescore.py DELETED
@@ -1,46 +0,0 @@
1
- def setup_musescore(musescore_path=None):
2
- if not is_ipython(): return
3
-
4
- import platform
5
- from music21 import environment
6
- from pathlib import Path
7
-
8
- system = platform.system()
9
- if system == 'Linux':
10
- import os
11
- os.environ['QT_QPA_PLATFORM']='offscreen' # https://musescore.org/en/node/29041
12
-
13
- existing_path = environment.get('musicxmlPath')
14
- if existing_path: return
15
- if musescore_path is None:
16
- if system == 'Darwin':
17
- app_paths = list(Path('/Applications').glob('MuseScore *.app'))
18
- if len(app_paths): musescore_path = app_paths[-1]/'Contents/MacOS/mscore'
19
- elif system == 'Linux':
20
- musescore_path = '/usr/bin/musescore'
21
-
22
- if musescore_path is None or not Path(musescore_path).exists():
23
- print('Warning: Could not find musescore installation. Please install musescore (see README) and/or update music21 environment paths')
24
- else :
25
- environment.set('musicxmlPath', musescore_path)
26
- environment.set('musescoreDirectPNGPath', musescore_path)
27
-
28
- def is_ipython():
29
- try: get_ipython
30
- except: return False
31
- return True
32
-
33
- def is_colab():
34
- try: import google.colab
35
- except: return False
36
- return True
37
-
38
- def setup_fluidsynth():
39
- from midi2audio import FluidSynth
40
- from IPython.display import Audio
41
-
42
- def play_wav(stream):
43
- out_midi = stream.write('midi')
44
- out_wav = str(Path(out_midi).with_suffix('.wav'))
45
- FluidSynth("font.sf2").midi_to_audio(out_midi, out_wav)
46
- return Audio(out_wav)
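
Typical use, as a sketch: call this once before rendering notation in a notebook; pass musescore_path explicitly if auto-detection fails:

from musicautobot.utils.setup_musescore import setup_musescore

setup_musescore()   # no-op outside IPython; otherwise points music21 at a MuseScore binary
# after this, MusicItem.show() / stream.show() can render scores inline
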
utils/musicautobot/utils/stacked_dataloader.py DELETED
@@ -1,70 +0,0 @@
1
- "Dataloader wrapper that can combine and handle multiple dataloaders for multitask training"
2
- from fastai.callback import Callback
3
- from typing import Callable
4
-
5
- __all__ = ['StackedDataBunch']
6
-
7
- # DataLoading
8
- class StackedDataBunch():
9
- def __init__(self, dbs, num_it=100):
10
- self.dbs = dbs
11
- self.train_dl = StackedDataloader([db.train_dl for db in self.dbs], num_it)
12
- self.valid_dl = StackedDataloader([db.valid_dl for db in self.dbs], num_it)
13
- self.train_ds = None
14
- self.path = dbs[0].path
15
- self.device = dbs[0].device
16
- self.vocab = dbs[0].vocab
17
- self.empty_val = False
18
-
19
- def add_tfm(self,tfm:Callable)->None:
20
- for dl in self.dbs: dl.add_tfm(tfm)
21
-
22
- def remove_tfm(self,tfm:Callable)->None:
23
- for dl in self.dbs: dl.remove_tfm(tfm)
24
-
25
- # Helper functions
26
- class StackedDataset(Callback):
27
- def __init__(self, dss):
28
- self.dss = dss
29
- def __getattribute__(self, attr):
30
- if attr == 'dss': return super().__getattribute__(attr)
31
- def redirected(*args, **kwargs):
32
- for ds in self.dss:
33
- if hasattr(ds, attr): getattr(ds, attr)(*args, **kwargs)
34
- return redirected
35
- def __len__(self)->int: return sum([len(ds) for ds in self.dss])
36
- def __repr__(self): return '\n'.join([self.__class__.__name__] + [repr(ds) for ds in self.dss])
37
-
38
- class StackedDataloader():
39
- def __init__(self, dls, num_it=100):
40
- self.dls = dls
41
- self.dataset = StackedDataset([dl.dataset for dl in dls if hasattr(dl, 'dataset')])
42
- self.num_it = num_it
43
- self.dl_idx = -1
44
-
45
- def __len__(self)->int: return sum([len(dl) for dl in self.dls])
46
- def __getattr__(self, attr):
47
- def redirected(*args, **kwargs):
48
- for dl in self.dls:
49
- if hasattr(dl, attr):
50
- getattr(dl, attr)(*args, **kwargs)
51
- return redirected
52
-
53
- def __iter__(self):
54
- "Process and returns items from `DataLoader`."
55
- iters = [iter(dl) for dl in self.dls]
56
- self.dl_idx = -1
57
- while len(iters):
58
- self.dl_idx = (self.dl_idx+1) % len(iters)
59
- for b in range(self.num_it):
60
- try:
61
- yield next(iters[self.dl_idx])
62
- except StopIteration as e:
63
- iters.remove(iters[self.dl_idx])
64
- break
65
- # raise StopIteration
66
-
67
- def new(self, **kwargs):
68
- "Create a new copy of `self` with `kwargs` replacing current values."
69
- new_dls = [dl.new(**kwargs) for dl in self.dls]
70
- return StackedDataloader(new_dls, self.num_it)
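
A sketch of the round-robin behaviour using two plain PyTorch dataloaders in place of the fastai ones (in practice the dls would be the train_dl/valid_dl of two task-specific DataBunches):

import torch
from torch.utils.data import DataLoader, TensorDataset
from musicautobot.utils.stacked_dataloader import StackedDataloader

dl_a = DataLoader(TensorDataset(torch.zeros(6, 2)), batch_size=2)   # three batches of zeros
dl_b = DataLoader(TensorDataset(torch.ones(6, 2)), batch_size=2)    # three batches of ones
stacked = StackedDataloader([dl_a, dl_b], num_it=1)                 # switch source after every batch

for (xb,) in stacked:
    print(int(xb[0, 0].item()), end=' ')    # 0 1 0 1 0 1 - batches alternate between the two loaders
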
utils/musicautobot/utils/top_k_top_p.py DELETED
@@ -1,35 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
-
4
- __all__ = ['top_k_top_p']
5
-
6
- # top_k + nucleus filter - https://twitter.com/thom_wolf/status/1124263861727760384?lang=en
7
- # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
8
- def top_k_top_p(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
9
- """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
10
- Args:
11
- logits: logits distribution shape (vocabulary size)
12
- top_k >0: keep only top k tokens with highest probability (top-k filtering).
13
- top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
14
- """
15
- logits = logits.clone()
16
- assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
17
- top_k = min(top_k, logits.size(-1)) # Safety check
18
- if top_k > 0:
19
- # Remove all tokens with a probability less than the last token of the top-k
20
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
21
- logits[indices_to_remove] = filter_value
22
-
23
- if top_p > 0.0:
24
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
25
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
26
-
27
- # Remove tokens with cumulative probability above the threshold
28
- sorted_indices_to_remove = cumulative_probs > top_p
29
- # Shift the indices to the right to keep also the first token above the threshold
30
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
31
- sorted_indices_to_remove[..., 0] = 0
32
-
33
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
34
- logits[indices_to_remove] = filter_value
35
- return logits
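
A small sketch of the filter in isolation: filtered positions are set to -inf, so after softmax only the surviving tokens can be sampled:

import torch
import torch.nn.functional as F
from musicautobot.utils.top_k_top_p import top_k_top_p

logits = torch.tensor([2.0, 1.5, 0.5, -1.0, -3.0])
filtered = top_k_top_p(logits, top_k=3, top_p=0.9)   # keep the 3 best tokens, then trim the nucleus tail
probs = F.softmax(filtered, dim=-1)                  # filtered entries get probability 0
idx = torch.multinomial(probs, 1).item()             # sample only from the survivors
print(filtered, idx)
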
utils/musicautobot/vocab.py DELETED
@@ -1,93 +0,0 @@
1
- from fastai.basics import *
2
- from .numpy_encode import *
3
- from .music_transformer import transform
4
-
5
- BOS = 'xxbos'
6
- PAD = 'xxpad'
7
- EOS = 'xxeos'
8
- MASK = 'xxmask' # Used for BERT masked language modeling.
9
- CSEQ = 'xxcseq' # Used for Seq2Seq translation - denotes start of chord sequence
10
- MSEQ = 'xxmseq' # Used for Seq2Seq translation - denotes start of melody sequence
11
-
12
- # Deprecated tokens. Kept for compatibility
13
- S2SCLS = 'xxs2scls' # deprecated
14
- NSCLS = 'xxnscls' # deprecated
15
-
16
- SEP = 'xxsep' # Used to denote end of timestep (required for polyphony). separator idx = -1 (part of notes)
17
-
18
- SPECIAL_TOKS = [BOS, PAD, EOS, S2SCLS, MASK, CSEQ, MSEQ, NSCLS, SEP] # Important: SEP token must be last
19
-
20
- NOTE_TOKS = [f'n{i}' for i in range(NOTE_SIZE)]
21
- DUR_TOKS = [f'd{i}' for i in range(DUR_SIZE)]
22
- NOTE_START, NOTE_END = NOTE_TOKS[0], NOTE_TOKS[-1]
23
- DUR_START, DUR_END = DUR_TOKS[0], DUR_TOKS[-1]
24
-
25
- MTEMPO_SIZE = 10
26
- MTEMPO_OFF = 'mt0'
27
- MTEMPO_TOKS = [f'mt{i}' for i in range(MTEMPO_SIZE)]
28
-
29
- # Vocab - token to index mapping
30
- class MusicVocab():
31
- "Contain the correspondence between numbers and tokens and numericalize."
32
- def __init__(self, itos:Collection[str]):
33
- self.itos = itos
34
- self.stoi = {v:k for k,v in enumerate(self.itos)}
35
-
36
- def numericalize(self, t:Collection[str]) -> List[int]:
37
- "Convert a list of tokens `t` to their ids."
38
- return [self.stoi[w] for w in t]
39
-
40
- def textify(self, nums:Collection[int], sep=' ') -> List[str]:
41
- "Convert a list of `nums` to their tokens."
42
- items = [self.itos[i] for i in nums]
43
- return sep.join(items) if sep is not None else items
44
-
45
- def to_music_item(self, idxenc):
46
- return transform.MusicItem(idxenc, self)
47
-
48
- @property
49
- def mask_idx(self): return self.stoi[MASK]
50
- @property
51
- def pad_idx(self): return self.stoi[PAD]
52
- @property
53
- def bos_idx(self): return self.stoi[BOS]
54
- @property
55
- def sep_idx(self): return self.stoi[SEP]
56
- @property
57
- def npenc_range(self): return (self.stoi[SEP], self.stoi[DUR_END]+1)
58
- @property
59
- def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
60
- @property
61
- def dur_range(self): return self.stoi[DUR_START], self.stoi[DUR_END]+1
62
-
63
- def is_duration(self, idx):
64
- return idx >= self.dur_range[0] and idx < self.dur_range[1]
65
- def is_duration_or_pad(self, idx):
66
- return idx == self.pad_idx or self.is_duration(idx)
67
-
68
- def __getstate__(self):
69
- return {'itos':self.itos}
70
-
71
- def __setstate__(self, state:dict):
72
- self.itos = state['itos']
73
- self.stoi = {v:k for k,v in enumerate(self.itos)}
74
-
75
- def __len__(self): return len(self.itos)
76
-
77
- def save(self, path):
78
- "Save `self.itos` in `path`"
79
- pickle.dump(self.itos, open(path, 'wb'))
80
-
81
- @classmethod
82
- def create(cls) -> 'Vocab':
83
- "Create a vocabulary from a set of `tokens`."
84
- itos = SPECIAL_TOKS + NOTE_TOKS + DUR_TOKS + MTEMPO_TOKS
85
- if len(itos)%8 != 0:
86
- itos = itos + [f'dummy{i}' for i in range(8 - len(itos)%8)]
87
- return cls(itos)
88
-
89
- @classmethod
90
- def load(cls, path):
91
- "Load the `Vocab` contained in `path`"
92
- itos = pickle.load(open(path, 'rb'))
93
- return cls(itos)
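
A short sketch of the vocabulary in use:

from musicautobot.vocab import MusicVocab

vocab = MusicVocab.create()
print(len(vocab))                              # special + note + duration + tempo tokens, padded to a multiple of 8
ids = vocab.numericalize(['xxbos', 'xxpad', 'n60', 'd4', 'xxsep'])
print(ids, '->', vocab.textify(ids))           # round trip back to 'xxbos xxpad n60 d4 xxsep'
print(vocab.note_range, vocab.dur_range)       # index ranges that distinguish note tokens from duration tokens
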