from fastai.basics import *
from fastai.text.models.transformer import TransformerXL

from ..utils.attention_mask import rand_window_mask

class MusicTransformerXL(TransformerXL):
    "Like fastai's TransformerXL, but with a more aggressive attention mask (see `rand_window_mask`) and optional beat/bar position encoding."
    def __init__(self, *args, encode_position=True, mask_steps=1, **kwargs):
        import inspect
        # Forward only the kwargs that TransformerXL's constructor actually accepts.
        sig = inspect.signature(TransformerXL)
        arg_params = {k: kwargs[k] for k in sig.parameters if k in kwargs}
        super().__init__(*args, **arg_params)

        self.encode_position = encode_position
        if self.encode_position: self.beat_enc = BeatPositionEncoder(kwargs['d_model'])
        self.mask_steps = mask_steps
    def forward(self, x):
        # The hidden state has to be initialized in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init:
            self.reset()
            self.init = True

        # Optional beat/bar position encoding, added to the token embeddings.
        benc = 0
        if self.encode_position:
            x, pos = x['x'], x['pos']
            benc = self.beat_enc(pos)

        bs, x_len = x.size()
        inp = self.drop_emb(self.encoder(x) + benc)  # .mul_(self.d_model ** 0.5)
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len

        mask = rand_window_mask(x_len, m_len, inp.device, max_size=self.mask_steps, is_eval=not self.training) if self.mask else None
        # With no memory, unmask position (0, 0) so the first token can attend to itself.
        if mask is not None and m_len == 0: mask[..., 0, 0] = 0
        # [None,:,:None] for einsum implementation of attention
        hids = []
        pos = torch.arange(seq_len - 1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:, -x_len:]
        if self.mem_len > 0: self._update_mems(hids)
        return (self.hidden if self.mem_len > 0 else [core_out]), [core_out]
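
# Usage sketch (kept as a comment so nothing runs on import). This is an
# assumption-laden illustration, not code from this repo: the constructor
# kwargs below use fastai v1's TransformerXL parameter names, and the
# hyperparameter values are placeholders. With `encode_position=True` the
# forward pass expects a dict batch holding token indices under 'x' and
# beat positions under 'pos'.
#
#   config = dict(vocab_sz=312, ctx_len=512, n_layers=6, n_heads=8, d_model=256,
#                 d_head=32, d_inner=1024, mem_len=512, mask=True)
#   model = MusicTransformerXL(encode_position=True, mask_steps=4, **config)
#   batch = {'x': torch.randint(1, 312, (2, 128)),
#            'pos': torch.randint(0, 32 * 8, (2, 128))}
#   hidden, core = model(batch)          # core[0]: (bs, seq_len, d_model)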

# Beat encoder
class BeatPositionEncoder(nn.Module):
    "Sum of beat-within-bar and bar-index embeddings, used as a positional encoding."
    def __init__(self, emb_sz:int, beat_len=32, max_bar_len=1024):
        super().__init__()
        self.beat_len, self.max_bar_len = beat_len, max_bar_len
        self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
        self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)

    def forward(self, pos):
        beat_enc = self.beat_enc(pos % self.beat_len)
        # Bar index wraps around max_bar_len; the clamp is a safety net against out-of-range indices.
        bar_pos = pos // self.beat_len % self.max_bar_len
        bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
        bar_enc = self.bar_enc(bar_pos)
        return beat_enc + bar_enc
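
# Minimal sketch of the beat/bar decomposition (illustrative only; the sizes
# below are arbitrary, not taken from this repo):
#
#   enc = BeatPositionEncoder(emb_sz=64)    # defaults: beat_len=32, max_bar_len=1024
#   pos = torch.arange(96).unsqueeze(0)     # 3 bars of 32 steps
#   emb = enc(pos)                          # -> (1, 96, 64)
#   # e.g. position 70 -> beat 70 % 32 = 6 inside bar 70 // 32 = 2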