from fastai.basics import *
from ..vocab import *
from ..utils.top_k_top_p import top_k_top_p
from ..utils.midifile import is_empty_midi
from ..music_transformer.transform import *
from ..music_transformer.learner import filter_invalid_indexes
from .model import get_multitask_model
from .dataloader import *
def multitask_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1.,
                            pretrained_path:PathOrStr=None, **learn_kwargs) -> 'MultitaskLearner':
    "Create a `MultitaskLearner` with a multitask transformer model built from `data`."
vocab = data.vocab
vocab_size = len(vocab)
if pretrained_path:
state = torch.load(pretrained_path, map_location='cpu')
if config is None: config = state['config']
model = get_multitask_model(vocab_size, config=config, drop_mult=drop_mult, pad_idx=vocab.pad_idx)
metrics = [AverageMultiMetric(partial(m, pad_idx=vocab.pad_idx)) for m in [mask_acc, lm_acc, c2m_acc, m2c_acc]]
loss_func = MultiLoss(ignore_index=data.vocab.pad_idx)
learn = MultitaskLearner(data, model, loss_func=loss_func, metrics=metrics, **learn_kwargs)
if pretrained_path:
get_model(model).load_state_dict(state['model'], strict=False)
if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
try: learn.opt.load_state_dict(state['opt'])
        except: pass # saved optimizer state may not match the new parameter groups
del state
gc.collect()
return learn
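
# Example usage (a sketch, not executed on import): `data` is assumed to be a
# DataBunch built by this package's multitask dataloader, and `my_config` /
# the checkpoint name are hypothetical.
#
#   learn = multitask_model_learner(data, config=my_config)
#   learn.fit_one_cycle(4)
#   learn.save('multitask_checkpoint', config=my_config)  # stores the config next to the weights
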
class MultitaskLearner(Learner):
def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
"Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
out_path = super().save(file, return_path=True, with_opt=with_opt)
if config and out_path:
state = torch.load(out_path)
state['config'] = config
torch.save(state, out_path)
del state
gc.collect()
return out_path
def predict_nw(self, item:MusicItem, n_words:int=128,
                   temperatures=(1.0,1.0), min_bars=4,
top_k=30, top_p=0.6):
"Return the `n_words` that come after `text`."
self.model.reset()
new_idx = []
vocab = self.data.vocab
x, pos = item.to_tensor(), item.get_pos_tensor()
last_pos = pos[-1] if len(pos) else 0
y = torch.tensor([0])
start_pos = last_pos
sep_count = 0
bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
repeat_count = 0
for i in progress_bar(range(n_words), leave=True):
batch = { 'lm': { 'x': x[None], 'pos': pos[None] } }, y
logits = self.pred_batch(batch=batch)['lm'][-1][-1]
prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx
# Temperature
            # Use the first temperature value if the last prediction was a duration
temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            # Raise the temperature when sampling keeps collapsing to very few choices (see repeat_count below)
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
if temperature != 1.: logits = logits / temperature
            # Filter
            # bar = SAMPLE_FREQ * 4 = 16 position steps (assuming 4/4 time)
            filter_value = -float('Inf')
            # Suppress the BOS token until at least `min_bars` bars have been generated
            if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value
logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
# Sample
probs = F.softmax(logits, dim=-1)
idx = torch.multinomial(probs, 1).item()
# Update repeat count
num_choices = len(probs.nonzero().view(-1))
if num_choices <= 2: repeat_count += 1
else: repeat_count = repeat_count // 2
if prev_idx==vocab.sep_idx:
duration = idx - vocab.dur_range[0]
last_pos = last_pos + duration
bars_pred = (last_pos - start_pos) // 16
abs_bar = last_pos // 16
# if (bars % 8 == 0) and (bars_pred > min_bars): break
                # Stop after 80% of the requested length, but only on a 4-bar boundary
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break
if idx==vocab.bos_idx:
print('Predicted BOS token. Returning prediction...')
break
new_idx.append(idx)
x = x.new_tensor([idx])
pos = pos.new_tensor([last_pos])
pred = vocab.to_music_item(np.array(new_idx))
full = item.append(pred)
return pred, full
def predict_mask(self, masked_item:MusicItem,
                     temperatures=(1.0,1.0),
top_k=20, top_p=0.8):
x = masked_item.to_tensor()
pos = masked_item.get_pos_tensor()
y = torch.tensor([0])
vocab = self.data.vocab
self.model.reset()
mask_idxs = (x == vocab.mask_idx).nonzero().view(-1)
repeat_count = 0
for midx in progress_bar(mask_idxs, leave=True):
prev_idx = x[midx-1]
            # Use the original positions; otherwise the model drifts off track
# pos = torch.tensor(-position_enc(xb[0].cpu().numpy()), device=xb.device)[None]
# Next Word
logits = self.pred_batch(batch=({ 'msk': { 'x': x[None], 'pos': pos[None] } }, y) )['msk'][0][midx]
# Temperature
            # Use the first temperature value if the last prediction was a duration
temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
temperature += repeat_penalty
if temperature != 1.: logits = logits / temperature
# Filter
filter_value = -float('Inf')
special_idxs = [vocab.bos_idx, vocab.sep_idx, vocab.stoi[EOS]]
logits[special_idxs] = filter_value # Don't allow any special tokens (as we are only removing notes and durations)
logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
# Sampling
probs = F.softmax(logits, dim=-1)
idx = torch.multinomial(probs, 1).item()
# Update repeat count
num_choices = len(probs.nonzero().view(-1))
if num_choices <= 2: repeat_count += 1
else: repeat_count = repeat_count // 2
x[midx] = idx
return vocab.to_music_item(x.cpu().numpy())
def predict_s2s(self, input_item:MusicItem, target_item:MusicItem, n_words:int=256,
                    temperatures=(1.0,1.0), top_k=30, top_p=0.8,
use_memory=True):
vocab = self.data.vocab
# Input doesn't change. We can reuse the encoder output on each prediction
with torch.no_grad():
inp, inp_pos = input_item.to_tensor(), input_item.get_pos_tensor()
x_enc = self.model.encoder(inp[None], inp_pos[None])
# target
targ = target_item.data.tolist()
targ_pos = target_item.position.tolist()
last_pos = targ_pos[-1]
self.model.reset()
repeat_count = 0
max_pos = input_item.position[-1] + SAMPLE_FREQ * 4 # Only predict until both tracks/parts have the same length
x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)
for i in progress_bar(range(n_words), leave=True):
# Predict
with torch.no_grad():
dec = self.model.decoder(x[None], pos[None], x_enc)
logits = self.model.head(dec)[-1, -1]
# Temperature
            # Use the first temperature value if the last prediction was a duration
prev_idx = targ[-1] if len(targ) else vocab.pad_idx
temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
temperature += repeat_penalty
if temperature != 1.: logits = logits / temperature
# Filter
filter_value = -float('Inf')
logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
# Sample
probs = F.softmax(logits, dim=-1)
idx = torch.multinomial(probs, 1).item()
# Update repeat count
num_choices = len(probs.nonzero().view(-1))
if num_choices <= 2: repeat_count += 1
else: repeat_count = repeat_count // 2
            if idx == vocab.bos_idx or idx == vocab.stoi[EOS]:
print('Predicting BOS/EOS')
break
if prev_idx == vocab.sep_idx:
duration = idx - vocab.dur_range[0]
last_pos = last_pos + duration
if last_pos > max_pos:
print('Predicted past counter-part length. Returning early')
break
targ_pos.append(last_pos)
targ.append(idx)
if use_memory:
                # Rely on the cached key/value memory; only the last predicted token needs to be fed
x, pos = inp.new_tensor([targ[-1]]), inp_pos.new_tensor([targ_pos[-1]])
else:
                # Reset memory after each prediction, since we're feeding the whole sequence every time
self.model.reset()
x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)
return vocab.to_music_item(np.array(targ))
# High-level prediction functions from a MIDI file
def nw_predict_from_midi(learn, midi=None, n_words=400,
temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
vocab = learn.data.vocab
seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
if seed_len is not None: seed = seed.trim_to_beat(seed_len)
pred, full = learn.predict_nw(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
return full
def s2s_predict_from_midi(learn, midi=None, n_words=200,
temperatures=(1.0,1.0), top_k=24, top_p=0.7, seed_len=None, pred_melody=True, **kwargs):
multitrack_item = MultitrackItem.from_file(midi, learn.data.vocab)
melody, chords = multitrack_item.melody, multitrack_item.chords
inp, targ = (chords, melody) if pred_melody else (melody, chords)
    # If seed_len is passed, cut off the target sequence so the model predicts the rest
if seed_len is not None: targ = targ.trim_to_beat(seed_len)
targ = targ.remove_eos()
pred = learn.predict_s2s(inp, targ, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
part_order = (pred, inp) if pred_melody else (inp, pred)
return MultitrackItem(*part_order)
def mask_predict_from_midi(learn, midi=None, predict_notes=True,
temperatures=(1.0,1.0), top_k=30, top_p=0.7, section=None, **kwargs):
item = MusicItem.from_file(midi, learn.data.vocab)
masked_item = item.mask_pitch(section) if predict_notes else item.mask_duration(section)
pred = learn.predict_mask(masked_item, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
return pred
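
# Example usage of the high-level helpers (a sketch; 'song.mid', 'duet.mid' and the
# trained `learn` object are assumptions, not provided by this module):
#
#   full  = nw_predict_from_midi(learn, 'song.mid', n_words=400, seed_len=16)   # continue a seed
#   duet  = s2s_predict_from_midi(learn, 'duet.mid', pred_melody=True)          # melody from chords
#   remix = mask_predict_from_midi(learn, 'song.mid', predict_notes=True)       # re-predict masked pitches
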
# LOSS AND METRICS
class MultiLoss():
def __init__(self, ignore_index=None):
"Loss mult - Mask, NextWord, Seq2Seq"
self.loss = CrossEntropyFlat(ignore_index=ignore_index)
def __call__(self, inputs:Dict[str,Tensor], targets:Dict[str,Tensor])->Rank0Tensor:
losses = [self.loss(inputs[key], target) for key,target in targets.items()]
return sum(losses)
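
# Shape sketch (illustrative, inferred from the metric keys below): `inputs` and
# `targets` are dicts keyed by task ('msk', 'lm', 'c2m', 'm2c'); each input holds
# [bs, seq_len, vocab_size] logits and each target [bs, seq_len] token indices, so
# the total loss is the sum of the per-task flattened cross entropies.
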
def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx)->Rank0Tensor:
if input is None or targ is None: return None
n = targ.shape[0]
input = input.argmax(dim=-1).view(n,-1)
targ = targ.view(n,-1)
mask = targ != pad_idx
return (input[mask]==targ[mask]).float().mean()
def acc_index(inputs, targets, key, pad_idx):
return acc_ignore_pad(inputs.get(key), targets.get(key), pad_idx)
def mask_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'msk', pad_idx)
def lm_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'lm', pad_idx)
def c2m_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'c2m', pad_idx)
def m2c_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'm2c', pad_idx)
class AverageMultiMetric(AverageMetric):
"Updated fastai.AverageMetric to support multi task metrics."
def on_batch_end(self, last_output, last_target, **kwargs):
"Update metric computation with `last_output` and `last_target`."
if not is_listy(last_target): last_target=[last_target]
val = self.func(last_output, *last_target)
if val is None: return
self.count += first_el(last_target).size(0)
if self.world:
val = val.clone()
dist.all_reduce(val, op=dist.ReduceOp.SUM)
val /= self.world
self.val += first_el(last_target).size(0) * val.detach().cpu()
def on_epoch_end(self, last_metrics, **kwargs):
"Set the final result in `last_metrics`."
if self.count == 0: return add_metrics(last_metrics, 0)
return add_metrics(last_metrics, self.val/self.count)
# TRAINING CALLBACK
class MTTrainer(LearnerCallback):
"`Callback` that regroups lr adjustment to seq_len, AR and TAR."
def __init__(self, learn:Learner, dataloaders=None, starting_mask_window=1):
super().__init__(learn)
self.count = 1
self.mw_start = starting_mask_window
self.dataloaders = dataloaders
def on_epoch_begin(self, **kwargs):
"Reset the hidden state of the model."
model = get_model(self.learn.model)
model.reset()
model.encoder.mask_steps = max(self.count+self.mw_start, 100)
def on_epoch_end(self, last_metrics, **kwargs):
"Finish the computation and sends the result to the Recorder."
if self.dataloaders is not None:
self.learn.data = self.dataloaders[self.count % len(self.dataloaders)]
self.count += 1
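
# Example wiring (a sketch; `data_small` and `data_large` are hypothetical DataBunch
# objects): attach MTTrainer so each epoch alternates dataloaders and widens the
# encoder mask window.
#
#   learn = multitask_model_learner(data_small)
#   learn.callbacks.append(MTTrainer(learn, dataloaders=[data_small, data_large]))
#   learn.fit_one_cycle(8)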