Spaces:
Build error
Build error
from fastai.basics import * | |
from fastai.text.learner import LanguageLearner, get_language_model, _model_meta | |
from .model import * | |
from .transform import MusicItem | |
from ..numpy_encode import SAMPLE_FREQ | |
from ..utils.top_k_top_p import top_k_top_p | |
from ..utils.midifile import is_empty_midi | |
# Register MusicTransformerXL in fastai's `_model_meta` table by reusing the entry
# already defined for TransformerXL, so `_model_meta[arch]` lookups work for it.
# Note: both keys refer to the same metadata object (an alias, not a copy).
_model_meta[MusicTransformerXL] = _model_meta[TransformerXL]
def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
                        pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `Learner` with a language model from `data` and `arch`.

    - `data`: the `DataBunch`; the model's vocabulary size is `len(data.vocab.itos)`.
    - `arch`: model class; must have an entry in `_model_meta`.
    - `config`: model config dict. When it is `None` and `pretrained_path` is given,
      the config stored in the checkpoint under the `'config'` key is used instead.
    - `drop_mult`: forwarded to `get_language_model`.
    - `pretrained_path`: optional checkpoint holding `'model'` weights and, optionally,
      `'opt'` optimizer state. Weights are loaded with `strict=False`.
    - `learn_kwargs`: forwarded to the `MusicLearner` constructor.

    Returns the constructed `MusicLearner`.
    """
    import warnings  # function-scope import keeps this block self-contained

    meta = _model_meta[arch]

    if pretrained_path:
        # NOTE: `torch.load` unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None: config = state['config']

    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)

    if pretrained_path:
        get_model(model).load_state_dict(state['model'], strict=False)
        if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
        # Restoring the optimizer is best-effort: the weights above are what matter.
        # Catch `Exception` (not a bare `except`) so KeyboardInterrupt/SystemExit still
        # propagate, and report the failure instead of hiding it.
        try:
            learn.opt.load_state_dict(state['opt'])
        except Exception as e:
            warnings.warn(f"Could not restore optimizer state from {pretrained_path}: {e!r}")
        # Drop the checkpoint dict right away so its tensors can be reclaimed.
        del state
        gc.collect()

    return learn
# Predictions
from fastai import basic_train # for predictions | |
class MusicLearner(LanguageLearner):
    "`LanguageLearner` subclass adding config-aware checkpoints and music token generation."

    def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
        """Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`.

        `file` can be file-like (file or buffer). When `config` is truthy and the parent
        `save` returned a path, the checkpoint is loaded back, `config` is stored under
        its `'config'` key and the checkpoint is written again to the same path.

        Returns the value returned by the parent `save` (called with `return_path=True`).
        """
        out_path = super().save(file, return_path=True, with_opt=with_opt)
        if config and out_path:
            state = torch.load(out_path)
            state['config'] = config
            torch.save(state, out_path)
            # Release the reloaded checkpoint as soon as it has been rewritten.
            del state
            gc.collect()
        return out_path

    def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.):
        """Return the `n_words` token indexes that come after the seed `xb`, using beam search.

        - `xb`: seed tensor; if it has more than one row only the first row is used.
        - `n_words`: number of tokens to generate.
        - `top_k`: candidates expanded per beam at each step (the seed row is also
          repeated `top_k` times to form the initial beams).
        - `beam_sz`: maximum number of beams kept after each step.
        - `temperature`: divides the accumulated scores before the final beam is sampled.

        Returns a list of ints: the tokens of the sampled beam after the first
        `xb.shape[-1]` positions.
        """
        self.model.reset()
        self.model.eval()
        xb_length = xb.shape[-1]
        if xb.shape[0] > 1: xb = xb[0][None]

        xb = xb.repeat(top_k, 1)
        nodes = xb.clone()
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for k in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                values, indices = out.topk(top_k, dim=-1)
                # Scores are accumulated negative log-probabilities: lower is better.
                scores = (-values + scores[:,None]).view(-1)
                # For every candidate, the index of the beam it was expanded from.
                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                # Append each candidate token to a copy of its parent beam, then keep the best.
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                # Keep the model's hidden state aligned with the surviving beams.
                self.model[0].select_hidden(indices_idx[sort_idx])
                # Only the newest token of each beam is fed on the next step.
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        return [i.item() for i in nodes[node_idx][xb_length:]]

    def predict(self, item:MusicItem, n_words:int=128,
                temperatures:float=(1.0,1.0), min_bars=4,
                top_k=30, top_p=0.6):
        """Sample up to `n_words` new tokens that continue `item`.

        - `item`: seed `MusicItem`; its token and position tensors start the generation.
        - `n_words`: maximum number of sampling steps.
        - `temperatures`: pair of temperatures (despite the `float` annotation); index 0
          is used when the previous token is a duration or pad token according to
          `vocab.is_duration_or_pad`, index 1 otherwise.
        - `min_bars`: the BOS token is masked out until more than `min_bars` bars
          (16 position units each) have been generated past the seed.
        - `top_k`, `top_p`: forwarded to `top_k_top_p` to filter logits before sampling.

        Generation stops early when BOS is sampled, or, once more than 80% of `n_words`
        steps have run, when a sampled duration lands on a bar index divisible by 4.

        Returns `(pred, full)`: the generated tokens converted with
        `vocab.to_music_item`, and `item.append(pred)`.
        """
        self.model.reset()
        # Sample in inference mode, consistent with `beam_search`; without this any
        # training-only behaviour of the model (e.g. dropout) would stay active.
        self.model.eval()

        # NOTE(review): the original computed `SAMPLE_FREQ * 4` ("assuming 4/4 time") but
        # never used it and hard-coded 16 instead; the literal is kept to preserve behaviour.
        steps_per_bar = 16
        filter_value = -float('Inf')

        new_idx = []
        vocab = self.data.vocab
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        start_pos = last_pos
        repeat_count = 0

        if hasattr(self.model[0], 'encode_position'):
            encode_position = self.model[0].encode_position
        else: encode_position = False

        for i in progress_bar(range(n_words), leave=True):
            with torch.no_grad():
                if encode_position:
                    batch = { 'x': x[None], 'pos': pos[None] }
                    logits = self.model(batch)[0][-1][-1]
                else:
                    logits = self.model(x[None])[0][-1][-1]

            prev_idx = new_idx[-1] if new_idx else vocab.pad_idx

            # Temperature
            # Use first temperatures value if last prediction was duration
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            # The penalty becomes positive once `repeat_count + 1` exceeds 4, raising the
            # temperature after a run of steps that had very few candidates.
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter
            # Forbid BOS until more than `min_bars` bars have been generated.
            if ((last_pos - start_pos) // steps_per_bar) <= min_bars: logits[vocab.bos_idx] = filter_value
            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            if prev_idx==vocab.sep_idx:
                # The token sampled right after a separator is read as a duration,
                # offset from the start of the duration range, and advances the position.
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration

                abs_bar = last_pos // steps_per_bar
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break

            if idx==vocab.bos_idx:
                print('Predicted BOS token. Returning prediction...')
                break

            new_idx.append(idx)
            # Only the newly sampled token (and its position) is fed on the next step.
            x = x.new_tensor([idx])
            pos = pos.new_tensor([last_pos])

        pred = vocab.to_music_item(np.array(new_idx))
        full = item.append(pred)
        return pred, full
# High-level prediction functions from a MIDI file
def predict_from_midi(learn, midi=None, n_words=400,
                      temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    """Continue the music in `midi` with `learn.predict` and return the full item.

    The seed is built with `MusicItem.from_file(midi, vocab)`, or `MusicItem.empty(vocab)`
    when `is_empty_midi(midi)` is true. If `seed_len` is not `None` the seed is first cut
    with `trim_to_beat(seed_len)`. The remaining arguments and `kwargs` are forwarded to
    `learn.predict`; only the second element of its result (seed plus prediction) is returned.
    """
    vocab = learn.data.vocab

    if is_empty_midi(midi):
        seed = MusicItem.empty(vocab)
    else:
        seed = MusicItem.from_file(midi, vocab)

    if seed_len is not None:
        seed = seed.trim_to_beat(seed_len)

    _, full = learn.predict(
        seed,
        n_words=n_words,
        temperatures=temperatures,
        top_k=top_k,
        top_p=top_p,
        **kwargs,
    )
    return full
def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
    """Overwrite the entries of `res` that may not follow `prev_idx` with `filter_value`.

    If `vocab.is_duration_or_pad(prev_idx)` is true the indexes in `vocab.dur_range`
    are overwritten, otherwise the indexes in `vocab.note_range` are.
    `res` is modified in place and also returned.
    """
    # Pick the index range that is not allowed to come next.
    blocked_range = vocab.dur_range if vocab.is_duration_or_pad(prev_idx) else vocab.note_range
    blocked = list(range(*blocked_range))
    res[blocked] = filter_value
    return res