Spaces:
Build error
Build error
File size: 7,253 Bytes
f35cc94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from fastai.basics import *
from fastai.text.learner import LanguageLearner, get_language_model, _model_meta
from .model import *
from .transform import MusicItem
from ..numpy_encode import SAMPLE_FREQ
from ..utils.top_k_top_p import top_k_top_p
from ..utils.midifile import is_empty_midi
# Register MusicTransformerXL in fastai's model-metadata table (reusing TransformerXL's
# entry) so `get_language_model` and `music_model_learner` can look up its split functions.
_model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata
def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
                        pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `Learner` with a language model from `data` and `arch`.

    If `pretrained_path` is given, model (and, when present, optimizer) state is
    restored from that checkpoint; a `config` stored in the checkpoint is used
    when the `config` argument is None.
    """
    meta = _model_meta[arch]

    state = None  # checkpoint dict; stays None when training from scratch
    if pretrained_path:
        # Load on CPU first; fastai moves the weights to the right device later.
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None: config = state['config']  # fall back to the saved config

    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)

    if state is not None:
        # strict=False tolerates checkpoints whose head/embedding sizes differ.
        get_model(model).load_state_dict(state['model'], strict=False)
        if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
        # Optimizer state is optional (e.g. checkpoint saved with with_opt=False);
        # best-effort restore, but never swallow SystemExit/KeyboardInterrupt.
        try: learn.opt.load_state_dict(state['opt'])
        except Exception: pass
        del state
        gc.collect()

    return learn
# Predictions
from fastai import basic_train # for predictions
class MusicLearner(LanguageLearner):
    "`LanguageLearner` with music-aware checkpointing, beam search, and sampling."

    def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
        "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
        out_path = super().save(file, return_path=True, with_opt=with_opt)
        if config and out_path:
            # Re-open the checkpoint to embed the model config so
            # `music_model_learner` can rebuild the architecture from it.
            state = torch.load(out_path)
            state['config'] = config
            torch.save(state, out_path)
            del state
            gc.collect()
        return out_path

    def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.,
                    ):
        "Return the `n_words` that come after `text` using beam search."
        self.model.reset()
        self.model.eval()
        xb_length = xb.shape[-1]
        if xb.shape[0] > 1: xb = xb[0][None]

        # Start every candidate row as a copy of the seed sequence.
        xb = xb.repeat(top_k, 1)
        nodes = xb.clone()
        scores = xb.new_zeros(1).float()  # accumulated negative log-likelihood per beam
        with torch.no_grad():
            for k in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                values, indices = out.topk(top_k, dim=-1)
                # Lower score == more likely (scores hold -log p).
                scores = (-values + scores[:,None]).view(-1)
                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]  # keep the `beam_sz` best expansions
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                # Keep the transformer's cached hidden state aligned with surviving beams.
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        # Sample one finished beam, weighting by its likelihood exp(-score).
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        return [i.item() for i in nodes[node_idx][xb_length:] ]

    def predict(self, item:MusicItem, n_words:int=128,
                temperatures:tuple=(1.0,1.0), min_bars=4,
                top_k=30, top_p=0.6):
        """Return the `n_words` that come after `text`.

        `temperatures` is (duration_temperature, note_temperature). Returns
        `(pred, full)`: the newly generated `MusicItem` and the seed+prediction.
        """
        self.model.reset()
        new_idx = []
        vocab = self.data.vocab
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        start_pos = last_pos

        bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time (== 16 position steps per bar)
        repeat_count = 0  # grows when sampling keeps collapsing to <=2 choices
        if hasattr(self.model[0], 'encode_position'):
            encode_position = self.model[0].encode_position
        else: encode_position = False

        for i in progress_bar(range(n_words), leave=True):
            with torch.no_grad():
                if encode_position:
                    batch = { 'x': x[None], 'pos': pos[None] }
                    logits = self.model(batch)[0][-1][-1]
                else:
                    logits = self.model(x[None])[0][-1][-1]

            prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx

            # Temperature
            # Use first temperatures value if last prediction was duration
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            # Heat up sampling when the model has been repeating itself.
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter: suppress BOS until at least `min_bars` bars are generated,
            # then drop grammatically invalid successors and low-probability tail.
            filter_value = -float('Inf')
            if ((last_pos - start_pos) // bar_len) <= min_bars: logits[vocab.bos_idx] = filter_value
            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            if prev_idx==vocab.sep_idx:
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration

                # Late in generation, stop at a 4-bar boundary for a clean ending.
                abs_bar = last_pos // bar_len
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break

            if idx==vocab.bos_idx:
                print('Predicted BOS token. Returning prediction...')
                break

            new_idx.append(idx)
            x = x.new_tensor([idx])
            pos = pos.new_tensor([last_pos])

        pred = vocab.to_music_item(np.array(new_idx))
        full = item.append(pred)
        return pred, full
# High level prediction functions from midi file
def predict_from_midi(learn, midi=None, n_words=400,
                      temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    "Generate a continuation of `midi` with `learn`; return seed + prediction as one item."
    vocab = learn.data.vocab
    # Fall back to an empty seed when no usable midi is supplied.
    if is_empty_midi(midi):
        seed = MusicItem.empty(vocab)
    else:
        seed = MusicItem.from_file(midi, vocab)
    if seed_len is not None:
        seed = seed.trim_to_beat(seed_len)
    _, full = learn.predict(seed, n_words=n_words, temperatures=temperatures,
                            top_k=top_k, top_p=top_p, **kwargs)
    return full
def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
    "Mask token scores in `res` that may not legally follow `prev_idx`; return `res`."
    # A duration/pad token cannot be followed by another duration;
    # any other token cannot be followed directly by a note.
    blocked = vocab.dur_range if vocab.is_duration_or_pad(prev_idx) else vocab.note_range
    res[list(range(*blocked))] = filter_value
    return res
|