# Source: uploaded by caslabs ("Upload 37 files", commit f35cc94).
from fastai.basics import *
from fastai.text.learner import LanguageLearner, get_language_model, _model_meta
from .model import *
from .transform import MusicItem
from ..numpy_encode import SAMPLE_FREQ
from ..utils.top_k_top_p import top_k_top_p
from ..utils.midifile import is_empty_midi
# Register MusicTransformerXL in fastai's model-metadata table by reusing TransformerXL's entry,
# so `_model_meta[arch]` lookups (e.g. the 'split_lm' layer-group function) work for the music model.
_model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata
def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
                        pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `MusicLearner` with a language model built from `data` and `arch`.

    Args:
        data: DataBunch whose `vocab.itos` length sizes the model's vocabulary.
        arch: model class; must have an entry in fastai's `_model_meta`.
        config: model config forwarded to `get_language_model`. When None and
            `pretrained_path` is given, the checkpoint's `state['config']` is used.
        drop_mult: forwarded to `get_language_model`.
        pretrained_path: optional checkpoint (a dict with a 'model' key, and
            optionally 'config' and 'opt') to initialise weights/optimizer from.
        **learn_kwargs: forwarded to the `MusicLearner` constructor.

    Returns:
        The constructed `MusicLearner`.
    """
    import gc
    import warnings

    meta = _model_meta[arch]
    state = None
    if pretrained_path:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None: config = state['config']
    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
    if state is not None:
        # strict=False: checkpoint keys that don't match the model are ignored.
        get_model(model).load_state_dict(state['model'], strict=False)
        if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
        # Restoring the optimizer is best-effort: checkpoints saved without it simply
        # skip this step, and a mismatched state is reported instead of silently dropped.
        opt_state = state.get('opt')
        if opt_state is not None:
            try:
                learn.opt.load_state_dict(opt_state)
            except Exception as e:
                warnings.warn(f'Could not restore optimizer state from {pretrained_path}: {e}')
        del state
        gc.collect()
    return learn
# Predictions
from fastai import basic_train # for predictions
class MusicLearner(LanguageLearner):
    """`LanguageLearner` with music-specific checkpoint saving, beam search and sampling."""

    def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
        """Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`.

        `file` can be file-like (file or buffer). When `config` is truthy and the parent
        `save` returned a target, the checkpoint is re-loaded, `config` is stored under its
        'config' key and the checkpoint is written back, so it can later be rebuilt from
        `state['config']`.
        NOTE(review): the re-load uses `torch.load(out_path)`; for buffer targets this
        presumably needs the buffer rewound first -- TODO confirm.

        Returns:
            The value returned by `LanguageLearner.save(..., return_path=True)`.
        """
        out_path = super().save(file, return_path=True, with_opt=with_opt)
        if config and out_path:
            state = torch.load(out_path)
            state['config'] = config
            torch.save(state, out_path)
            del state
            gc.collect()
        return out_path

    def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.,
                    ):
        """Return the `n_words` token indices that follow `xb`, chosen by beam search.

        Args:
            xb: prompt token indices; if more than one row is given only the first is used
                (assumes shape (batch, seq) -- TODO confirm).
            n_words: number of tokens to generate.
            top_k: candidates expanded per beam at each step.
            beam_sz: number of beams kept after each step.
            temperature: divides the final beam scores before one beam is sampled.

        Returns:
            List of ints: the generated continuation (the prompt is excluded).
        """
        self.model.reset()
        self.model.eval()
        xb_length = xb.shape[-1]
        if xb.shape[0] > 1: xb = xb[0][None]
        xb = xb.repeat(top_k, 1)
        nodes = xb.clone()
        # Negative cumulative log-probability of each beam (lower is better).
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for _ in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                values, indices = out.topk(top_k, dim=-1)
                scores = (-values + scores[:,None]).view(-1)
                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                # Presumably re-aligns the model's cached hidden state with the surviving beams.
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        # Pick one beam with probability proportional to exp(-score).
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        return [i.item() for i in nodes[node_idx][xb_length:]]

    def predict(self, item:MusicItem, n_words:int=128,
                temperatures:tuple=(1.0,1.0), min_bars=4,
                top_k=30, top_p=0.6):
        """Sample a continuation of `item` one token at a time.

        Args:
            item: seed `MusicItem`.
            n_words: maximum number of tokens to generate.
            temperatures: (temperature used when the previous token is a duration or
                padding, temperature used otherwise).
            min_bars: the BOS token is filtered out while the number of whole bars
                elapsed since the end of the seed is <= `min_bars`.
            top_k, top_p: forwarded to `top_k_top_p` for logit filtering.

        Returns:
            (pred, full): the generated item (`vocab.to_music_item`) and `item.append(pred)`.
        """
        self.model.reset()
        # Sample in inference mode (as beam_search does): a freshly built learner is in
        # training mode, which would leave dropout active during generation.
        self.model.eval()
        vocab = self.data.vocab
        new_idx = []
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        start_pos = last_pos
        # One bar in position units, assuming 4/4 time (16 when SAMPLE_FREQ == 4).
        bar_len = SAMPLE_FREQ * 4
        repeat_count = 0
        encode_position = getattr(self.model[0], 'encode_position', False)

        with torch.no_grad():
            for i in progress_bar(range(n_words), leave=True):
                if encode_position:
                    batch = {'x': x[None], 'pos': pos[None]}
                    logits = self.model(batch)[0][-1][-1]
                else:
                    logits = self.model(x[None])[0][-1][-1]

                prev_idx = new_idx[-1] if new_idx else vocab.pad_idx

                # Temperature: first value after a duration/pad token, second otherwise.
                # It is raised further once repeat_count >= 4 (few-choice steps in a row).
                temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
                repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
                temperature += repeat_penalty
                if temperature != 1.: logits = logits / temperature

                # Filter: forbid BOS until enough bars have elapsed, then invalid
                # token classes, then top-k / nucleus filtering.
                filter_value = -float('Inf')
                if ((last_pos - start_pos) // bar_len) <= min_bars: logits[vocab.bos_idx] = filter_value
                logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
                logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

                # Sample
                probs = F.softmax(logits, dim=-1)
                idx = torch.multinomial(probs, 1).item()

                # Count consecutive steps where at most 2 tokens had nonzero probability;
                # a wider distribution halves the count.
                num_choices = len(probs.nonzero().view(-1))
                if num_choices <= 2: repeat_count += 1
                else: repeat_count = repeat_count // 2

                if prev_idx == vocab.sep_idx:
                    # The token after a separator is treated as a duration offset from
                    # dur_range[0]; advance the running position by it.
                    duration = idx - vocab.dur_range[0]
                    last_pos = last_pos + duration
                    abs_bar = last_pos // bar_len
                    # Past 80% of the token budget, stop on a 4-bar boundary.
                    if (i / n_words > 0.80) and (abs_bar % 4 == 0): break
                if idx == vocab.bos_idx:
                    print('Predicted BOS token. Returning prediction...')
                    break

                new_idx.append(idx)
                # Only the newest token/position is fed next step (earlier context is
                # presumably kept in the model's memory -- reset() is called above).
                x = x.new_tensor([idx])
                pos = pos.new_tensor([last_pos])

        pred = vocab.to_music_item(np.array(new_idx))
        full = item.append(pred)
        return pred, full
# High level prediction functions from midi file
def predict_from_midi(learn, midi=None, n_words=400,
                      temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    """Seed `learn` with a MIDI file and return the seed plus the generated continuation.

    Args:
        learn: learner exposing `data.vocab` and `predict` (e.g. a `MusicLearner`).
        midi: MIDI file to seed from. None, or a MIDI that `is_empty_midi` reports as
            empty, starts from `MusicItem.empty(vocab)`.
        n_words, temperatures, top_k, top_p, **kwargs: forwarded to `learn.predict`.
        seed_len: if given, the seed is trimmed with `seed.trim_to_beat(seed_len)`.

    Returns:
        The second value returned by `learn.predict` (for `MusicLearner`, the seed
        with the prediction appended).
    """
    vocab = learn.data.vocab
    # Check None explicitly so the default argument never reaches the MIDI loader.
    if midi is None or is_empty_midi(midi):
        seed = MusicItem.empty(vocab)
    else:
        seed = MusicItem.from_file(midi, vocab)
    if seed_len is not None: seed = seed.trim_to_beat(seed_len)
    _, full = learn.predict(seed, n_words=n_words, temperatures=temperatures,
                            top_k=top_k, top_p=top_p, **kwargs)
    return full
def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
    """Mask out, in place, the token range that may not follow `prev_idx`.

    If `vocab.is_duration_or_pad(prev_idx)` is true, every index in
    `range(*vocab.dur_range)` is set to `filter_value`; otherwise every index in
    `range(*vocab.note_range)` is. `res` must support list-of-indices assignment
    (e.g. a torch tensor or numpy array). Returns `res` itself.
    """
    banned_span = vocab.dur_range if vocab.is_duration_or_pad(prev_idx) else vocab.note_range
    banned_positions = list(range(*banned_span))
    res[banned_positions] = filter_value
    return res