from fastai.basics import *
from fastai.text.models.transformer import Activation, PositionalEncoding, feed_forward, init_transformer, _line_shift
from fastai.text.models.awd_lstm import RNNDropout
from ..utils.attention_mask import *
def get_multitask_model(vocab_size:int, config:dict=None, drop_mult:float=1., pad_idx=None):
    "Create the multitask (encoder/decoder) transformer from `config`, scaling dropout keys by `drop_mult`."
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    n_hid = config['d_model']
    mem_len = config.pop('mem_len')
    embed = TransformerEmbedding(vocab_size, n_hid, embed_p=config['embed_p'], mem_len=mem_len, pad_idx=pad_idx)
    encoder = MTEncoder(embed, n_hid, n_layers=config['enc_layers'], mem_len=0, **config) # encoder doesn't need memory
    decoder = MTEncoder(embed, n_hid, is_decoder=True, n_layers=config['dec_layers'], mem_len=mem_len, **config)
    head = MTLinearDecoder(n_hid, vocab_size, tie_encoder=embed.embed, **config)
    model = MultiTransformer(encoder, decoder, head, mem_len=mem_len)
    return model.apply(init_transformer)
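
# Illustrative usage sketch: the config keys below are the ones read by `get_multitask_model`
# and consumed by MTEncoder/MTLinearDecoder further down; the numeric values, `vocab_size`,
# and `pad_idx` are hypothetical placeholders, not values prescribed by this repository.
def _example_multitask_config():
    config = dict(
        d_model=512, mem_len=512,             # hidden size; 'mem_len' is popped and handed to embed/decoder
        enc_layers=4, dec_layers=4,           # depth of the encoder and decoder stacks
        n_heads=8, d_head=64, d_inner=2048,   # attention heads and feed-forward width (d_model == n_heads * d_head)
        embed_p=0.1, output_p=0.1,            # keys ending in '_p' are rescaled by `drop_mult`
        resid_p=0.1, attn_p=0.1, ff_p=0.1,
    )
    # Note: get_multitask_model mutates `config` in place (pops 'mem_len', rescales '*_p' keys).
    return get_multitask_model(vocab_size=312, config=config, drop_mult=1., pad_idx=1)
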
class MultiTransformer(nn.Module):
    "Multitask Transformer for training mask, next word, and sequence 2 sequence"
    def __init__(self, encoder, decoder, head, mem_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.head = head
        self.default_mem_len = mem_len
        self.current_mem_len = None

    def forward(self, inp):
        # data order: mask, next word, melody, chord
        outputs = {}
        msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]

        if msk is not None:
            outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
        if lm is not None:
            outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))

        if c2m is not None:
            self.reset()
            c2m_enc = self.encoder(c2m['enc'], c2m['enc_pos'])
            c2m_dec = self.decoder(c2m['dec'], c2m['dec_pos'], c2m_enc)
            outputs['c2m'] = self.head(c2m_dec)

        if m2c is not None:
            self.reset()
            m2c_enc = self.encoder(m2c['enc'], m2c['enc_pos'])
            m2c_dec = self.decoder(m2c['dec'], m2c['dec_pos'], m2c_enc)
            outputs['m2c'] = self.head(m2c_dec)

        return outputs
"A sequential module that passes the reset call to its children."
def reset(self):
for module in self.children():
reset_children(module)
def reset_children(mod):
if hasattr(mod, 'reset'): mod.reset()
for module in mod.children():
reset_children(module)
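
# Illustrative sketch (hypothetical batch size, sequence length, and token ids; assumes the model's
# vocab_size is larger than 100): the shape of the input dict that MultiTransformer.forward expects.
# 'msk' and 'lm' carry one token sequence ('x') plus its position indices ('pos'); 'c2m' and 'm2c'
# carry encoder and decoder sides for the two sequence-to-sequence directions.
def _demo_multitask_batch(model:MultiTransformer, bs:int=2, seq_len:int=16):
    def tok(): return torch.randint(1, 100, (bs, seq_len))
    def pos(): return torch.arange(seq_len).expand(bs, seq_len)
    inp = {
        'msk': {'x': tok(), 'pos': pos()},   # masked modeling: encoder + head
        'lm':  {'x': tok(), 'pos': pos()},   # next-token modeling: decoder + head
        'c2m': {'enc': tok(), 'enc_pos': pos(), 'dec': tok(), 'dec_pos': pos()},   # c2m seq2seq task
        'm2c': {'enc': tok(), 'enc_pos': pos(), 'dec': tok(), 'dec_pos': pos()},   # m2c seq2seq task
    }
    outputs = model(inp)   # logits of shape (bs, seq_len, vocab_size) for each task key present
    return {k: v.shape for k, v in outputs.items()}
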
# COMPONENTS
class TransformerEmbedding(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_size:int, emb_sz:int, embed_p:float=0., mem_len=512, beat_len=32, max_bar_len=1024, pad_idx=None):
        super().__init__()
        self.emb_sz = emb_sz
        self.pad_idx = pad_idx
        self.embed = nn.Embedding(vocab_size, emb_sz, padding_idx=pad_idx)
        self.pos_enc = PositionalEncoding(emb_sz)
        self.beat_len, self.max_bar_len = beat_len, max_bar_len
        self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
        self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)
        self.drop = nn.Dropout(embed_p)
        self.mem_len = mem_len

    def forward(self, inp, pos):
        beat_enc = self.beat_enc(pos % self.beat_len)
        bar_pos = pos // self.beat_len % self.max_bar_len
        bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1  # safety clamp to the last bar index
        bar_enc = self.bar_enc(bar_pos)
        emb = self.drop(self.embed(inp) + beat_enc + bar_enc)
        return emb

    def relative_pos_enc(self, emb):
        seq_len = emb.shape[1] + self.mem_len
        pos = torch.arange(seq_len-1, -1, -1, device=emb.device, dtype=emb.dtype) # backwards (txl pos encoding)
        return self.pos_enc(pos)
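
# Illustrative sketch of how TransformerEmbedding.forward splits each position index into a
# within-bar beat index (pos % beat_len) and a bar index (pos // beat_len), each with its own
# learned embedding added to the token embedding. The positions here are hypothetical; `beat_len`
# mirrors the default above.
def _example_beat_bar_split(beat_len:int=32):
    pos = torch.arange(0, 96, 16)    # positions 0, 16, 32, 48, 64, 80
    beat = pos % beat_len            # -> tensor([ 0, 16,  0, 16,  0, 16])  position inside the bar
    bar = pos // beat_len            # -> tensor([0, 0, 1, 1, 2, 2])        which bar we are in
    return beat, bar
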
class MTLinearDecoder(nn.Module):
    "Final linear head mapping hidden states to vocabulary logits; can tie its weight to the token embedding."
    initrange = 0.1

    def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, out_bias:bool=True, **kwargs):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=out_bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if out_bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tensor) -> Tensor:
        output = self.output_dp(input)
        decoded = self.decoder(output)
        return decoded
# ENCODER / DECODER STACK (the same class serves both directions)
class MTEncoder(nn.Module):
    "Stack of MTEncoderBlocks; acts as the encoder, or as the decoder when `is_decoder=True`."
    def __init__(self, embed:nn.Module, n_hid:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, mem_len:int=512, is_decoder=False,
                 mask_steps=1, mask_p=0.3, **kwargs):
        super().__init__()
        self.embed = embed
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) # remove the 1 for an einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) # remove the 1 for an einsum implementation of attention
        self.n_layers,self.d_model = n_layers,d_model
        self.layers = nn.ModuleList([MTEncoderBlock(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                                                    ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop,
                                                    mem_len=mem_len) for k in range(n_layers)])
        self.mask_steps, self.mask_p = mask_steps, mask_p
        self.is_decoder = is_decoder

        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)

    def forward(self, x_lm, lm_pos, msk_emb=None):
        bs,lm_len = x_lm.size()
        lm_emb = self.embed(x_lm, lm_pos)
        if msk_emb is not None and msk_emb.shape[1] > lm_emb.shape[1]:
            pos_enc = self.embed.relative_pos_enc(msk_emb)
        else:
            pos_enc = self.embed.relative_pos_enc(lm_emb)

        # Masks: only the decoder masks future positions (random-window mask during training)
        if self.is_decoder:
            lm_mask = rand_window_mask(lm_len, self.embed.mem_len, x_lm.device,
                                       max_size=self.mask_steps, p=self.mask_p, is_eval=not self.training)
        else:
            lm_mask = None

        for i, layer in enumerate(self.layers):
            lm_emb = layer(lm_emb, msk_emb, lm_mask=lm_mask,
                           r=pos_enc, g_u=self.u, g_v=self.v)
        return lm_emb
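
# Illustrative sketch (hypothetical sizes): the same MTEncoder class plays both roles. With
# `is_decoder=True` it applies the random-window mask, and when an encoder output is passed in,
# every block also cross-attends to it through its second attention layer.
def _demo_encoder_decoder_pair():
    emb = TransformerEmbedding(vocab_size=128, emb_sz=64, mem_len=0)
    enc = MTEncoder(emb, n_hid=64, n_layers=2, n_heads=4, d_model=64, d_head=16, d_inner=128, mem_len=0)
    dec = MTEncoder(emb, n_hid=64, n_layers=2, n_heads=4, d_model=64, d_head=16, d_inner=128,
                    mem_len=0, is_decoder=True)
    x, pos = torch.randint(1, 128, (2, 16)), torch.arange(16).expand(2, 16)
    ctx = enc(x, pos)       # source context, no causal mask
    out = dec(x, pos, ctx)  # target side: masked self-attention + cross-attention to ctx
    return out.shape        # torch.Size([2, 16, 64])
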
class MTEncoderBlock(nn.Module):
    "Encoder/decoder block: self-attention (mha1), optional cross-attention over the encoder output (mha2), feed-forward."
    # Can't use Sequential directly because the block takes more than one input.
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, double_drop:bool=True, mem_len:int=512, mha2_mem_len=0, **kwargs):
        super().__init__()
        attn_cls = MemMultiHeadRelativeAttentionKV
        self.mha1 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mem_len, r_mask=False)
        self.mha2 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mha2_mem_len, r_mask=True)
        self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)

    def forward(self, enc_lm:Tensor, enc_msk:Tensor,
                r=None, g_u=None, g_v=None,
                msk_mask:Tensor=None, lm_mask:Tensor=None):
        y_lm = self.mha1(enc_lm, enc_lm, enc_lm, r, g_u, g_v, mask=lm_mask)  # self-attention over the target stream
        if enc_msk is None: return y_lm                                       # encoder-only path: no cross-attention
        return self.ff(self.mha2(y_lm, enc_msk, enc_msk, r, g_u, g_v, mask=msk_mask))  # cross-attend to the encoder output
# Attention Layer
class MemMultiHeadRelativeAttentionKV(nn.Module):
    "Attention layer monster - relative positioning, keeps track of its own memory, separate k/v weights to support sequence2sequence decoding."
    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True, mem_len:int=512, r_mask=True):
        super().__init__()
        d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
        assert(d_model == d_head * n_heads)
        self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.r_mask = r_mask
        self.mem_len = mem_len
        self.prev_k = None
        self.prev_v = None

    def forward(self, q:Tensor, k:Tensor=None, v:Tensor=None,
                r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
                mask:Tensor=None, **kwargs):
        if k is None: k = q
        if v is None: v = q
        return self.ln(q + self.drop_res(self._apply_attention(q, k, v, r, g_u, g_v, mask=mask, **kwargs)))

    def mem_k(self, k):
        "Prepend up to `mem_len` cached key steps from previous calls and update the cache."
        if self.mem_len == 0: return k
        if self.prev_k is None or (self.prev_k.shape[0] != k.shape[0]): # reset if wrong batch size
            self.prev_k = k[:, -self.mem_len:]
            return k
        with torch.no_grad():
            k_ext = torch.cat([self.prev_k, k], dim=1)
            self.prev_k = k_ext[:, -self.mem_len:]
        return k_ext.detach()

    def mem_v(self, v):
        "Prepend up to `mem_len` cached value steps from previous calls and update the cache."
        if self.mem_len == 0: return v
        if self.prev_v is None or (self.prev_v.shape[0] != v.shape[0]): # reset if wrong batch size
            self.prev_v = v[:, -self.mem_len:]
            return v
        with torch.no_grad():
            v_ext = torch.cat([self.prev_v, v], dim=1)
            self.prev_v = v_ext[:, -self.mem_len:]
        return v_ext.detach()

    def reset(self):
        self.prev_v = None
        self.prev_k = None

    def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor,
                         r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
                         mask:Tensor=None, **kwargs):
        # Notation from the Transformer-XL paper: x input, r vector of relative distances between two elements,
        # u and v learnable parameters of the model shared between all layers, mask to avoid cheating,
        # and mem the previous hidden states.
        k = self.mem_k(k)
        v = self.mem_v(v)
        bs,x_len,seq_len = q.size(0),q.size(1),k.size(1)
        wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v)
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x: x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        wkr = self.r_attn(r[-seq_len:])
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)
        # compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+g_u,wk)
        BD = _line_shift(torch.matmul(wq+g_v, wkr), mask=self.r_mask)
        attn_score = AC + BD
        if self.scale: attn_score = attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            mask = mask[...,-seq_len:]
            if hasattr(mask, 'bool'): mask = mask.bool()
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
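

# Illustrative sketch (hypothetical sizes): mem_k/mem_v cache the last `mem_len` key/value steps,
# so consecutive forward calls attend over the cached context plus the new one until reset() is
# called. The zero g_u/g_v biases stand in for the u/v parameters that MTEncoder normally supplies.
def _demo_attention_memory():
    attn = MemMultiHeadRelativeAttentionKV(n_heads=4, d_model=64, d_head=16, mem_len=8)
    g_u, g_v = torch.zeros(4, 1, 16), torch.zeros(4, 1, 16)
    x1, x2 = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
    r = torch.randn(16, 64)                 # relative position encodings: current len (8) + mem_len (8)
    _ = attn(x1, r=r, g_u=g_u, g_v=g_v)     # first call: nothing cached yet, keys span 8 steps
    _ = attn(x2, r=r, g_u=g_u, g_v=g_v)     # second call: keys/values span 8 cached + 8 new steps
    cached = attn.prev_k.shape              # torch.Size([2, 8, 64]): only the last `mem_len` steps are kept
    attn.reset()                            # drop the cache before an unrelated sequence
    return cached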