from fastai.basics import *
from fastai.text.models.transformer import Activation, PositionalEncoding, feed_forward, init_transformer, _line_shift
from fastai.text.models.awd_lstm import RNNDropout

from ..utils.attention_mask import *

def get_multitask_model(vocab_size:int, config:dict=None, drop_mult:float=1., pad_idx=None):
    "Create a multitask encoder/decoder transformer from `config`."
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    n_hid = config['d_model']
    mem_len = config.pop('mem_len')
    embed = TransformerEmbedding(vocab_size, n_hid, embed_p=config['embed_p'], mem_len=mem_len, pad_idx=pad_idx)
    encoder = MTEncoder(embed, n_hid, n_layers=config['enc_layers'], mem_len=0, **config)  # encoder doesn't need memory
    decoder = MTEncoder(embed, n_hid, is_decoder=True, n_layers=config['dec_layers'], mem_len=mem_len, **config)
    head = MTLinearDecoder(n_hid, vocab_size, tie_encoder=embed.embed, **config)
    model = MultiTransformer(encoder, decoder, head, mem_len=mem_len)
    return model.apply(init_transformer)
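
# A minimal usage sketch: the config keys below mirror what `get_multitask_model` and the
# component constructors read; the specific values are illustrative assumptions, not
# project defaults.
def _example_build_model():
    config = dict(d_model=512, mem_len=512, embed_p=0.1, enc_layers=4, dec_layers=4,
                  n_heads=8, d_head=64, d_inner=2048,
                  output_p=0.1, resid_p=0.1, attn_p=0.1, ff_p=0.1)
    return get_multitask_model(vocab_size=300, config=config, drop_mult=1., pad_idx=1)
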
class MultiTransformer(nn.Module):
    "Multitask transformer for masked-token, next-word, and sequence-to-sequence training."
    def __init__(self, encoder, decoder, head, mem_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.head = head
        self.default_mem_len = mem_len
        self.current_mem_len = None

    def forward(self, inp):
        # data order: mask, next word, melody, chord
        outputs = {}
        msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]

        if msk is not None:
            outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
        if lm is not None:
            outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))

        if c2m is not None:
            self.reset()
            c2m_enc = self.encoder(c2m['enc'], c2m['enc_pos'])
            c2m_dec = self.decoder(c2m['dec'], c2m['dec_pos'], c2m_enc)
            outputs['c2m'] = self.head(c2m_dec)

        if m2c is not None:
            self.reset()
            m2c_enc = self.encoder(m2c['enc'], m2c['enc_pos'])
            m2c_dec = self.decoder(m2c['dec'], m2c['dec_pos'], m2c_enc)
            outputs['m2c'] = self.head(m2c_dec)

        return outputs
"A sequential module that passes the reset call to its children." | |
def reset(self): | |
for module in self.children(): | |
reset_children(module) | |

def reset_children(mod):
    if hasattr(mod, 'reset'): mod.reset()
    for module in mod.children():
        reset_children(module)
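
# A minimal sketch of the batch dict `MultiTransformer.forward` expects. The key names come
# from the forward pass above; the tensor shapes and dummy values are illustrative
# assumptions. Any of the four tasks can be skipped by omitting its key.
def _example_multitask_batch(model, vocab_size=300, bs=2, seq_len=16):
    x   = torch.randint(0, vocab_size, (bs, seq_len))
    pos = torch.arange(seq_len)[None].repeat(bs, 1)
    inp = {'msk': {'x': x, 'pos': pos},                                   # masked-token task -> encoder + head
           'lm':  {'x': x, 'pos': pos},                                   # next-word task    -> decoder + head
           'c2m': {'enc': x, 'enc_pos': pos, 'dec': x, 'dec_pos': pos},   # chord -> melody
           'm2c': {'enc': x, 'enc_pos': pos, 'dec': x, 'dec_pos': pos}}   # melody -> chord
    return model(inp)  # dict of logits keyed by the same task names
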
# COMPONENTS

class TransformerEmbedding(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_size:int, emb_sz:int, embed_p:float=0., mem_len=512, beat_len=32, max_bar_len=1024, pad_idx=None):
        super().__init__()
        self.emb_sz = emb_sz
        self.pad_idx = pad_idx
        self.embed = nn.Embedding(vocab_size, emb_sz, padding_idx=pad_idx)
        self.pos_enc = PositionalEncoding(emb_sz)
        self.beat_len, self.max_bar_len = beat_len, max_bar_len
        self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
        self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)
        self.drop = nn.Dropout(embed_p)
        self.mem_len = mem_len

    def forward(self, inp, pos):
        beat_enc = self.beat_enc(pos % self.beat_len)
        bar_pos = pos // self.beat_len % self.max_bar_len
        bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
        bar_enc = self.bar_enc(bar_pos)
        emb = self.drop(self.embed(inp) + beat_enc + bar_enc)
        return emb

    def relative_pos_enc(self, emb):
        seq_len = emb.shape[1] + self.mem_len
        pos = torch.arange(seq_len-1, -1, -1, device=emb.device, dtype=emb.dtype)  # backwards (txl pos encoding)
        return self.pos_enc(pos)
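
# A small sketch of how `TransformerEmbedding.forward` splits a raw position index into the
# two indices it embeds separately (default `beat_len`/`max_bar_len` shown; the helper name
# is illustrative).
def _example_beat_bar_split(pos, beat_len=32, max_bar_len=1024):
    beat_idx = pos % beat_len                  # position within the bar -> `beat_enc`
    bar_idx  = pos // beat_len % max_bar_len   # which bar               -> `bar_enc`
    return beat_idx, bar_idx
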
class MTLinearDecoder(nn.Module):
    "Linear output head mapping hidden states to vocabulary logits, optionally tying weights with the input embedding."
    initrange = 0.1

    def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, out_bias:bool=True, **kwargs):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=out_bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if out_bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tensor)->Tensor:
        output = self.output_dp(input)
        decoded = self.decoder(output)
        return decoded
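
# A minimal sketch of the output head on its own: hidden states of shape [bs, seq_len, d_model]
# become vocabulary logits, and the projection weight is shared with the input embedding when
# `tie_encoder` is passed. Sizes here are illustrative assumptions.
def _example_head_shapes():
    embed  = TransformerEmbedding(vocab_size=300, emb_sz=512)
    head   = MTLinearDecoder(n_hid=512, n_out=300, output_p=0.1, tie_encoder=embed.embed)
    hidden = torch.randn(2, 16, 512)
    logits = head(hidden)                               # -> [2, 16, 300]
    assert head.decoder.weight is embed.embed.weight    # weight tying with the embedding
    return logits
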
# DECODER TRANSLATE BLOCK

class MTEncoder(nn.Module):
    def __init__(self, embed:nn.Module, n_hid:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, mem_len:int=512, is_decoder=False,
                 mask_steps=1, mask_p=0.3, **kwargs):
        super().__init__()
        self.embed = embed
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head))  # Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head))  # Remove 1 for einsum implementation of attention
        self.n_layers,self.d_model = n_layers,d_model
        self.layers = nn.ModuleList([MTEncoderBlock(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                                                    ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop,
                                                    mem_len=mem_len) for k in range(n_layers)])
        self.mask_steps, self.mask_p = mask_steps, mask_p
        self.is_decoder = is_decoder

        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)

    def forward(self, x_lm, lm_pos, msk_emb=None):
        bs,lm_len = x_lm.size()

        lm_emb = self.embed(x_lm, lm_pos)
        if msk_emb is not None and msk_emb.shape[1] > lm_emb.shape[1]:
            pos_enc = self.embed.relative_pos_enc(msk_emb)
        else:
            pos_enc = self.embed.relative_pos_enc(lm_emb)

        # Masks
        if self.is_decoder:
            lm_mask = rand_window_mask(lm_len, self.embed.mem_len, x_lm.device,
                                       max_size=self.mask_steps, p=self.mask_p, is_eval=not self.training)
        else:
            lm_mask = None

        for i, layer in enumerate(self.layers):
            lm_emb = layer(lm_emb, msk_emb, lm_mask=lm_mask,
                           r=pos_enc, g_u=self.u, g_v=self.v)
        return lm_emb
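
# A minimal sketch of how an encoder/decoder pair built from `MTEncoder` is wired for a
# sequence-to-sequence step; it mirrors the `c2m`/`m2c` branches of `MultiTransformer.forward`
# (argument names here are illustrative).
def _example_seq2seq_step(encoder, decoder, head, enc_tokens, enc_pos, dec_tokens, dec_pos):
    enc_out = encoder(enc_tokens, enc_pos)            # source side: no causal mask, no memory
    dec_out = decoder(dec_tokens, dec_pos, enc_out)   # target side: masked self-attention + cross-attention
    return head(dec_out)                              # vocabulary logits for the target sequence
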
class MTEncoderBlock(nn.Module):
    "Encoder/decoder block: self-attention, optional cross-attention over the encoder output, then feed-forward."
    # Can't use nn.Sequential directly because forward takes more than one input.
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, double_drop:bool=True, mem_len:int=512, mha2_mem_len=0, **kwargs):
        super().__init__()
        attn_cls = MemMultiHeadRelativeAttentionKV
        self.mha1 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mem_len, r_mask=False)
        self.mha2 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mha2_mem_len, r_mask=True)
        self.ff   = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)

    def forward(self, enc_lm:Tensor, enc_msk:Tensor,
                r=None, g_u=None, g_v=None,
                msk_mask:Tensor=None, lm_mask:Tensor=None):
        y_lm = self.mha1(enc_lm, enc_lm, enc_lm, r, g_u, g_v, mask=lm_mask)
        if enc_msk is None: return y_lm
        return self.ff(self.mha2(y_lm, enc_msk, enc_msk, r, g_u, g_v, mask=msk_mask))

# ATTENTION LAYER

class MemMultiHeadRelativeAttentionKV(nn.Module):
    "Multi-head attention with relative positional encoding; keeps its own key/value memory and uses separate k/v projections to support sequence-to-sequence decoding."
    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True, mem_len:int=512, r_mask=True):
        super().__init__()
        d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
        assert(d_model == d_head * n_heads)
        self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.r_mask = r_mask
        self.mem_len = mem_len
        self.prev_k = None
        self.prev_v = None

    def forward(self, q:Tensor, k:Tensor=None, v:Tensor=None,
                r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
                mask:Tensor=None, **kwargs):
        if k is None: k = q
        if v is None: v = q
        return self.ln(q + self.drop_res(self._apply_attention(q, k, v, r, g_u, g_v, mask=mask, **kwargs)))

    def mem_k(self, k):
        if self.mem_len == 0: return k
        if self.prev_k is None or (self.prev_k.shape[0] != k.shape[0]):  # reset if wrong batch size
            self.prev_k = k[:, -self.mem_len:]
            return k
        with torch.no_grad():
            k_ext = torch.cat([self.prev_k, k], dim=1)
            self.prev_k = k_ext[:, -self.mem_len:]
        return k_ext.detach()

    def mem_v(self, v):
        if self.mem_len == 0: return v
        if self.prev_v is None or (self.prev_v.shape[0] != v.shape[0]):  # reset if wrong batch size
            self.prev_v = v[:, -self.mem_len:]
            return v
        with torch.no_grad():
            v_ext = torch.cat([self.prev_v, v], dim=1)
            self.prev_v = v_ext[:, -self.mem_len:]
        return v_ext.detach()

    def reset(self):
        self.prev_v = None
        self.prev_k = None

    def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor,
                         r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
                         mask:Tensor=None, **kwargs):
        # Notation follows the TransformerXL paper: x is the input, r the vector of relative distances
        # between two elements, u and v the learnable parameters shared across all layers, mask prevents
        # attending to future positions, and mem holds the previous hidden states.
        k = self.mem_k(k)
        v = self.mem_v(v)
        bs,x_len,seq_len = q.size(0),q.size(1),k.size(1)
        wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v)
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x: x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        wkr = self.r_attn(r[-seq_len:])
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)
        # Compute the attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+g_u, wk)
        BD = _line_shift(torch.matmul(wq+g_v, wkr), mask=self.r_mask)
        attn_score = AC + BD
        if self.scale: attn_score = attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            mask = mask[...,-seq_len:]
            if hasattr(mask, 'bool'): mask = mask.bool()
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
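
# A minimal sketch of the TransformerXL-style memory: keys/values from previous forward calls
# are cached up to `mem_len` timesteps, so the second call below attends over 8 positions while
# only 4 new timesteps are fed in. Shapes and values are illustrative assumptions; `reset()`
# clears the cache between unrelated sequences.
def _example_attention_memory():
    attn = MemMultiHeadRelativeAttentionKV(n_heads=2, d_model=32, d_head=16, mem_len=8, r_mask=False)
    g_u, g_v = torch.zeros(2, 1, 16), torch.zeros(2, 1, 16)
    pos_enc = PositionalEncoding(32)

    x1 = torch.randn(1, 4, 32)                        # first chunk: 4 timesteps, no memory yet
    r1 = pos_enc(torch.arange(3, -1, -1).float())     # relative positions, newest last
    out1 = attn(x1, r=r1, g_u=g_u, g_v=g_v)

    x2 = torch.randn(1, 4, 32)                        # second chunk: keys/values now include the first chunk
    r2 = pos_enc(torch.arange(7, -1, -1).float())     # 4 new + 4 remembered positions
    out2 = attn(x2, r=r2, g_u=g_u, g_v=g_v)

    attn.reset()                                      # drop the cached memory
    return out1, out2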