# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import json
import math
import re
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
def gelu(x):
return (
0.5
* x
* (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
)
def swish(x):
return x * torch.sigmoid(x)
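# Note (descriptive, added for clarity): gelu above is the tanh approximation of the
# Gaussian Error Linear Unit,
#   gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
# and swish(x) = x * sigmoid(x) is also known as SiLU.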
class LayerNorm(nn.Module):
"Construct a layernorm module in the OpenAI style (epsilon inside the square root)."
def __init__(self, n_state, e=1e-5):
super(LayerNorm, self).__init__()
self.g = nn.Parameter(torch.ones(n_state))
self.b = nn.Parameter(torch.zeros(n_state))
self.e = e
"""
Input:
x: n_state-dim
Output:
o: n_state-dim
"""
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.e)
return self.g * x + self.b
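# Usage sketch (illustrative only; the helper below is hypothetical and never called):
# LayerNorm normalizes over the last dimension, then applies the learned gain g and bias b.
def _example_layernorm():
    ln = LayerNorm(n_state=8)
    x = torch.randn(2, 5, 8)  # batch x len x n_state
    y = ln(x)  # same shape, normalized per position
    assert y.shape == x.shape
    return y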
"""
Conv1D: a 1x1 convolution, i.e. a position-wise linear projection (GPT-style naming)
nx is the last input dim
nf is the last output dim
"""
class Conv1D(nn.Module):
def __init__(self, nf, nx):
super(Conv1D, self).__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.w = Parameter(w)
self.b = Parameter(torch.zeros(nf))
"""
Input:
x: batch x len x nx
Output:
x: batch x len x nf
"""
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
x = x.view(*size_out)
return x
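# Usage sketch (illustrative only; the helper below is hypothetical and never called):
# Conv1D acts as a position-wise linear layer mapping the last dimension from nx to nf.
def _example_conv1d():
    proj = Conv1D(nf=6, nx=4)
    x = torch.randn(2, 5, 4)  # batch x len x nx
    y = proj(x)  # batch x len x nf
    assert y.shape == (2, 5, 6)
    return y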
class PositionalEmbedding(nn.Module):
def __init__(self, opt, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.pos_discount = float(opt["TRANSFORMER_POS_DISCOUNT"])
self.register_buffer("inv_freq", inv_freq)
"""
Input:
pos_seq: len
Output:
pos_emb: len x demb
"""
def forward(self, pos_seq):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = (
torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
/ self.pos_discount
)
return pos_emb
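# Usage sketch (illustrative only; the helper and the opt value below are hypothetical):
# each position p is encoded as [sin(p * inv_freq), cos(p * inv_freq)], scaled by
# 1 / TRANSFORMER_POS_DISCOUNT.
def _example_positional_embedding():
    opt = {"TRANSFORMER_POS_DISCOUNT": 1.0}  # placeholder setting
    pe = PositionalEmbedding(opt, demb=16)
    pos_seq = torch.arange(0.0, 10.0)  # positions 0..9
    emb = pe(pos_seq)  # len x demb
    assert emb.shape == (10, 16)
    return emb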
"""
Splitter
"""
class Splitter(nn.Module):
def __init__(self, nx):
super(Splitter, self).__init__()
self.nx = nx
self.augmenter = Conv1D(nx * 3, nx)
"""
Input:
x: batch x len x nx
Output:
query,key,value: batch x len x nx
"""
def forward(self, x):
x = self.augmenter(x)
# x: batch x len x (3 x nx)
query, key, value = x.split(self.nx, dim=2)
# query,key,value: batch x len x nx
return query, key, value
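# Usage sketch (illustrative only; the helper below is hypothetical and never called):
# Splitter projects nx -> 3 * nx and splits the result into query/key/value of width nx.
def _example_splitter():
    splitter = Splitter(nx=8)
    x = torch.randn(2, 5, 8)  # batch x len x nx
    q, k, v = splitter(x)  # each: batch x len x nx
    assert q.shape == k.shape == v.shape == (2, 5, 8)
    return q, k, v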
"""
Multi-head Attention
"""
class Attention(nn.Module):
"""
nx: input dimension
"""
def __init__(self, nx, opt):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
n_head = int(opt["TRANSFORMER_HEAD"])
resid_pdrop = opt["TRANSFORMER_RESIDUAL_DROPOUT"]
attn_pdrop = opt["TRANSFORMER_ATTENTION_DROPOUT"]
use_cuda = opt["cuda"]
assert n_state % n_head == 0
# if mask is needed, uncomment this
        self.maxlen = 2048  # masks up to this length are precomputed; longer ones are built on the fly
self.mask = (
Variable(
torch.tril(torch.ones(self.maxlen, self.maxlen)).view(
1, 1, self.maxlen, self.maxlen
),
requires_grad=False,
).cuda()
if use_cuda
else Variable(
torch.tril(torch.ones(self.maxlen, self.maxlen)).view(
1, 1, self.maxlen, self.maxlen
),
requires_grad=False,
)
)
self.n_head = n_head
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(attn_pdrop)
self.resid_dropout = nn.Dropout(resid_pdrop)
self.use_cuda = use_cuda
"""
Input:
q: batch x n_head x len x dim
k: batch x n_head x dim x kv_len
v: batch x n_head x kv_len x dim
x_mask: batch x kv_len # key and value's mask (if not None, used for encoder's self-attention and decoder's src-tgt attention)
one_dir_visible: only sees previous history (used for decoder's self-attention)
return_attn_weight: if true, also return the attention weights
Output:
        a: batch x n_head x len x dim
attn_weight (if return_attn_weight): attn_weight: batch x n_head x len x kv_len
"""
def _attn(self, q, k, v, x_mask, one_dir_visible, return_attn_weight):
w = torch.matmul(q, k)
# batch x n_head x len x kv_len
w = w / math.sqrt(v.size(-1))
mask = None
if one_dir_visible: # mask "seeing the future"
if w.size(-2) <= self.maxlen and w.size(-1) <= self.maxlen:
mask = (
self.mask[:, :, : w.size(-2), : w.size(-1)].cuda()
if self.use_cuda
else self.mask[:, :, : w.size(-2), : w.size(-1)]
)
else:
mask = (
Variable(
torch.tril(torch.ones(w.size(-2), w.size(-1))).view(
1, 1, w.size(-2), w.size(-1)
),
requires_grad=False,
).cuda()
if self.use_cuda
else Variable(
torch.tril(torch.ones(w.size(-2), w.size(-1))).view(
1, 1, w.size(-2), w.size(-1)
),
requires_grad=False,
)
)
if x_mask is not None:
mask = x_mask.unsqueeze(1).unsqueeze(1).expand_as(w).float()
# batch x n_head x len x kv_len
if mask is not None:
w = w * mask + -1e9 * (1 - mask)
w_prob = nn.Softmax(dim=-1)(w)
w_prob = self.attn_dropout(w_prob)
if return_attn_weight:
return torch.matmul(w_prob, v), w
else:
return torch.matmul(w_prob, v)
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
"""
Input:
x: batch x len x dim
Output:
        k: batch x n_head x (dim/n_head) x len
        not k: batch x n_head x len x (dim/n_head)
"""
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1)
else:
return x.permute(0, 2, 1, 3)
"""
Input:
query: batch x len x n_state
key, value: batch x kv_len x n_state
x_mask: batch x kv_len # key and value's mask (if not None, used for encoder's self-attention and decoder's src-tgt attention)
one_dir_visible: only sees previous history (used for decoder's self-attention)
return_attn_weight: if true, also return the attention weights
Output:
a: batch x len x n_state
attn_weight (if return_attn_weight): batch x len x kv_len
"""
def forward(
self, query, key, value, x_mask, one_dir_visible=False, return_attn_weight=False
):
query = self.split_heads(query)
# batch x n_head x len x (n_state/n_head)
key = self.split_heads(key, k=True)
# batch x n_head x (n_state/n_head) x kv_len
value = self.split_heads(value)
# batch x n_head x kv_len x (n_state/n_head)
out = self._attn(query, key, value, x_mask, one_dir_visible, return_attn_weight)
if return_attn_weight:
a, attn_weight = out
# a: batch x n_head x len x (n_state/n_head)
# attn_weight: batch x n_head x len x kv_len
attn_weight = attn_weight.permute(0, 2, 3, 1).contiguous()
# batch x len x kv_len x n_head
attn_weight = torch.sum(attn_weight, dim=3)
# batch x len x kv_len
else:
a = out
# batch x n_head x len x (n_state/n_head)
a = self.merge_heads(a)
# batch x len x n_state
a = self.c_proj(a)
# batch x len x n_state
a = self.resid_dropout(a)
# batch x len x n_state
if return_attn_weight:
return a, attn_weight
else:
return a
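# Usage sketch (illustrative only; the helper and the opt values below are hypothetical):
# self-attention over a padded batch, with x_mask marking the real (non-PAD) key/value positions.
def _example_attention():
    opt = {
        "TRANSFORMER_HEAD": 2,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "cuda": False,
    }
    attn = Attention(nx=8, opt=opt)
    splitter = Splitter(8)
    x = torch.randn(2, 5, 8)  # batch x len x nx
    x_mask = torch.ones(2, 5)  # 1 = real token, 0 = PAD
    q, k, v = splitter(x)
    a = attn(q, k, v, x_mask)  # batch x len x nx
    assert a.shape == (2, 5, 8)
    return a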
"""
Two-layer network
"""
class MLP(nn.Module):
"""
Input:
n_state: intermediate dim
"""
def __init__(self, n_state, opt): # in MLP: n_state=3072 (4 * n_embd)
super(MLP, self).__init__()
nx = int(opt["transformer_embed_dim"])
resid_pdrop = opt["TRANSFORMER_RESIDUAL_DROPOUT"]
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
self.dropout = nn.Dropout(resid_pdrop)
"""
Input:
x: batch x len x nx
Output: batch x len x nx
"""
def forward(self, x):
h = F.relu(self.c_fc(x))
h2 = self.c_proj(h)
return self.dropout(h2)
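# Usage sketch (illustrative only; the helper and the opt values below are hypothetical):
# the MLP expands nx -> 4 * nx with ReLU, then projects back to nx.
def _example_mlp():
    opt = {"transformer_embed_dim": 8, "TRANSFORMER_RESIDUAL_DROPOUT": 0.1}
    mlp = MLP(4 * 8, opt)
    x = torch.randn(2, 5, 8)  # batch x len x nx
    y = mlp(x)  # batch x len x nx
    assert y.shape == x.shape
    return y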
"""
One encoder block of transformer
"""
class EncoderBlock(nn.Module):
def __init__(self, opt):
super(EncoderBlock, self).__init__()
nx = int(opt["transformer_embed_dim"])
self.one_dir_visible = False
if "transformer_encoder_one_dir_visible" in opt:
self.one_dir_visible = opt["transformer_encoder_one_dir_visible"]
self.splitter = Splitter(nx)
self.attn = Attention(nx, opt)
self.ln_1 = LayerNorm(nx)
self.mlp = MLP(4 * nx, opt)
self.ln_2 = LayerNorm(nx)
"""
Input:
x: batch x len x n_state
x_mask: batch x len (1 means there's something)
Output:
h: batch x len x n_state
"""
def forward(self, x, x_mask):
query, key, value = self.splitter(x)
if self.one_dir_visible:
# in this case, use triangle masking, as it's one_direction
a = self.attn(query, key, value, None, one_dir_visible=True)
else:
# in this case, use x_mask for attention masking
a = self.attn(query, key, value, x_mask, one_dir_visible=False)
n = self.ln_1(x + a) # residual
m = self.mlp(n)
h = self.ln_2(n + m)
return h
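# Usage sketch (illustrative only; the helper and the opt values below are hypothetical):
# one encoder block = multi-head self-attention + residual/LayerNorm + MLP + residual/LayerNorm.
def _example_encoder_block():
    opt = {
        "transformer_embed_dim": 8,
        "TRANSFORMER_HEAD": 2,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "cuda": False,
    }
    block = EncoderBlock(opt)
    x = torch.randn(2, 5, 8)  # batch x len x n_state
    x_mask = torch.ones(2, 5)  # 1 = real token, 0 = PAD
    h = block(x, x_mask)  # batch x len x n_state
    assert h.shape == x.shape
    return h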
"""
One decoder block of transformer
"""
class DecoderBlock(nn.Module):
def __init__(self, opt):
super(DecoderBlock, self).__init__()
nx = int(opt["transformer_embed_dim"])
self.decoder_splitter = Splitter(nx)
self.self_attn = Attention(nx, opt)
self.cross_attn = Attention(nx, opt)
self.ln_1 = LayerNorm(nx)
self.ln_2 = LayerNorm(nx)
self.mlp = MLP(4 * nx, opt)
self.ln_3 = LayerNorm(nx)
"""
Input:
x_mask: batch x len, mask for encoder's input
y: batch x len x n_state (decoder part)
enc_key: batch x encoder_len x n_state
enc_value: batch x encoder_len x n_state
lang_model: whether it's for language model training (no encoder part is used)
Output:
h: batch x len x n_state
"""
def forward(self, x_mask, y, enc_key, enc_value, lang_model=False):
query, key, value = self.decoder_splitter(y)
# batch x len x n_state
# self-attention
a = self.self_attn(query, key, value, None, one_dir_visible=True)
# batch x len x n_state
n = self.ln_1(y + a) # residual
# seq2seq
if not lang_model:
# src-tgt attention
o = self.cross_attn(n, enc_key, enc_value, x_mask)
p = self.ln_2(n + o) # residual
# batch x len x n_state
else: # language model
p = n
m = self.mlp(p)
h = self.ln_3(p + m)
return h
"""
Embedder
"""
class Embedder(nn.Module):
"""
Input:
vocab: size of vocabulary
"""
def __init__(self, opt, embed=None):
super(Embedder, self).__init__()
n_state = int(opt["transformer_embed_dim"]) # n_state
embed_dropout_rate = opt["TRANSFORMER_EMBED_DROPOUT"]
if embed is None:
self.embed = nn.Embedding(opt["vocab_size"], n_state)
nn.init.normal_(self.embed.weight, std=0.02)
else:
self.embed = embed
self.drop = nn.Dropout(embed_dropout_rate)
self.pos_emb = PositionalEmbedding(opt, n_state)
self.use_cuda = opt["cuda"]
"""
Input:
x: batch x len (word_id)
Output:
h: batch x len x n_state
"""
def forward(self, x):
x_emb = self.embed(x)
batch_size = x.shape[0]
x_len = x.shape[1]
x_pos = self.pos_emb(
torch.arange(x_len).type(
torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
)
) # len x n_state
x_pos = (
Variable(
x_pos.unsqueeze(0).repeat(batch_size, 1, 1), requires_grad=False
).cuda()
if self.use_cuda
else Variable(
x_pos.unsqueeze(0).repeat(batch_size, 1, 1), requires_grad=False
)
)
x_input = x_emb + x_pos
h = self.drop(x_input)
return h
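# Usage sketch (illustrative only; the helper and the opt values below are hypothetical):
# the Embedder sums word embeddings and sinusoidal position embeddings, then applies dropout.
def _example_embedder():
    opt = {
        "transformer_embed_dim": 8,
        "TRANSFORMER_EMBED_DROPOUT": 0.1,
        "TRANSFORMER_POS_DISCOUNT": 1.0,
        "vocab_size": 100,
        "cuda": False,
    }
    embedder = Embedder(opt)
    x = torch.randint(1, 100, (2, 5))  # batch x len of word ids
    h = embedder(x)  # batch x len x n_state
    assert h.shape == (2, 5, 8)
    return h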
"""
Transformer encoder
"""
class TransformerEncoder(nn.Module):
"""
Input:
embed: (if not None) pre-computed vocab embeddings
"""
def __init__(self, opt, embed=None):
super(TransformerEncoder, self).__init__()
vocab = int(opt["vocab_size"])
n_state = int(opt["transformer_embed_dim"])
n_layer = int(opt["TRANSFORMER_LAYER"])
if "vae_z_scale_factor" in opt:
self.vae_z_scale_factor = float(opt["vae_z_scale_factor"])
self.embedder = Embedder(opt, embed)
block = EncoderBlock(opt)
self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)])
self.use_cuda = opt["cuda"]
"""
Input:
x: batch x len (word_id)
z (optional): batch x len x n_state (for VAE)
Output:
        h: batch x len x n_state
"""
def forward(self, x, z=None):
        x_mask = ~x.eq(0)  # 0 is the PAD id; mask is 1 where there is a real token
x_mask = x_mask.type(
torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
)
h = self.embedder(x)
if z is not None:
z *= self.vae_z_scale_factor
h += z
for block in self.blocks:
h = block(h, x_mask)
return h
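# Usage sketch (illustrative only; the helper and the tiny opt config below are hypothetical):
# the encoder embeds the word ids, builds the PAD mask from id 0, and stacks encoder blocks.
def _example_transformer_encoder():
    opt = {
        "vocab_size": 100,
        "transformer_embed_dim": 8,
        "TRANSFORMER_LAYER": 1,
        "TRANSFORMER_HEAD": 2,
        "TRANSFORMER_EMBED_DROPOUT": 0.1,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "TRANSFORMER_POS_DISCOUNT": 1.0,
        "cuda": False,
    }
    encoder = TransformerEncoder(opt)
    x = torch.randint(1, 100, (2, 5))  # word ids; 0 is treated as PAD
    h = encoder(x)  # batch x len x n_state
    assert h.shape == (2, 5, 8)
    return h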
"""
Transformer decoder
"""
class TransformerDecoder(nn.Module):
"""
Input:
embed: (if not None) pre-computed vocab embeddings
"""
def __init__(self, opt, embed=None):
super(TransformerDecoder, self).__init__()
self.opt = opt
vocab_size = int(opt["vocab_size"])
n_state = int(opt["transformer_embed_dim"]) # n_state
n_layer = int(opt["TRANSFORMER_LAYER"])
self.embedder = Embedder(opt, embed)
self.encoder_splitter = Splitter(n_state)
block = DecoderBlock(opt)
self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)])
if embed is None:
self.linear = Conv1D(vocab_size, n_state)
else:
self.linear = nn.Linear(n_state, vocab_size, bias=False)
if (
"FINETUNE_RETRAIN_SOFTMAX" not in opt
        ):  # if FINETUNE_RETRAIN_SOFTMAX, the linear layer needs to be trained separately
self.linear.weight = embed.weight # share weight
        self.use_cuda = opt["cuda"]
"""
Input:
x: batch x encoder_len (word id)
x_out: batch x encoder_len x n_state
y: batch x len (word_id) (decoder part)
lang_model: whether it's for language model training (no encoder part is used)
Output:
prob: batch x len x vocab_size (probabilities after softmax)
"""
def forward(self, x, x_out, y, lang_model=False):
# seq2seq
if not lang_model:
_, enc_key, enc_value = self.encoder_splitter(x_out)
# enc_key: batch x encoder_len x n_state
# enc_value: batch x encoder_len x n_state
            x_mask = ~x.eq(0)  # 0 is the PAD id
x_mask = x_mask.type(
torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
)
else:
enc_key = None
enc_value = None
x_mask = None
h = self.embedder(y)
for block in self.blocks:
h = block(x_mask, h, enc_key, enc_value, lang_model)
prob = F.softmax(self.linear(h), dim=-1)
return prob
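# Usage sketch (illustrative only; the helper and the tiny opt config below are hypothetical):
# encoder and decoder share one embedding matrix, which is also tied to the output projection.
def _example_transformer_decoder():
    opt = {
        "vocab_size": 100,
        "transformer_embed_dim": 8,
        "TRANSFORMER_LAYER": 1,
        "TRANSFORMER_HEAD": 2,
        "TRANSFORMER_EMBED_DROPOUT": 0.1,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "TRANSFORMER_POS_DISCOUNT": 1.0,
        "cuda": False,
    }
    embed = nn.Embedding(opt["vocab_size"], opt["transformer_embed_dim"])
    encoder = TransformerEncoder(opt, embed)
    decoder = TransformerDecoder(opt, embed)
    x = torch.randint(1, 100, (2, 6))  # encoder word ids
    y = torch.randint(1, 100, (2, 4))  # decoder word ids (teacher forcing)
    x_out = encoder(x)  # batch x encoder_len x n_state
    prob = decoder(x, x_out, y)  # batch x len x vocab_size
    assert prob.shape == (2, 4, 100)
    return prob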
class TransformerBeam:
"""
Input:
encoder: TransformerEncoder class
decoder: TransformerDecoder class
begin_id: word id of '<BEGIN>'
vocab: list of words
"""
def __init__(self, opt, encoder, decoder, begin_id, vocab):
self.encoder = encoder
self.decoder = decoder
self.opt = opt
self.max_sent_len = int(opt["max_sent_len"])
self.begin_id = begin_id
self.vocab = vocab
self.beam_width = int(opt["beam_width"])
self.use_cuda = opt["cuda"]
    # each candidate is a tuple (index in last_nodes, log_prob, word_id); candidate lists are kept sorted by log_prob in descending order
def merge_candidates(self, cand_A, cand_B):
C = []
pA, lA, pB, lB = 0, len(cand_A), 0, len(cand_B)
lC = 0
while (pA < lA or pB < lB) and (lC < self.beam_width):
if pA < lA and (pB >= lB or cand_A[pA][1] > cand_B[pB][1]):
C.append(cand_A[pA])
pA += 1
else:
C.append(cand_B[pB])
pB += 1
lC += 1
return C
"""
Input:
        x: batch x encoder_len (word_ids), the encoder's input
k: top-k sampling
Output:
        sents: a list with batch_size items; each item is a list of (sentence, log_prob) pairs, and each sentence is a list of at most max_sent_len words
"""
def topk(self, x, k):
batch_size = x.shape[0]
x_len = x.shape[1]
x_out = self.encoder(x)
# x_out: batch x encoder_len x n_state
# sent_ids is the words for each of the batch_size sentences
sent_ids = []
for i in range(batch_size):
sent_ids.append([self.begin_id])
topk = 1
MIN_GEN_LENGTH = 45
if "MIN_GEN_LENGTH" in self.opt:
MIN_GEN_LENGTH = int(self.opt["MIN_GEN_LENGTH"])
for l in range(self.max_sent_len):
y = (
Variable(torch.LongTensor(sent_ids)).cuda()
if self.use_cuda
else Variable(torch.LongTensor(sent_ids))
) # batch_size x l
decoder_outputs = self.decoder(x, x_out, y)
probs = decoder_outputs[
:, -1, :
] # batch_size x vocab_size (only take the last output)
for i in range(batch_size):
topk_probs, _ = torch.topk(probs[i], k)
threshold = float(topk_probs[-1])
probs[i][probs[i] < threshold] = 0.0
samples = torch.multinomial(
probs, 2
) # sample 2 since the first one may be <END>
for i in range(batch_size):
if l < MIN_GEN_LENGTH and self.vocab[int(samples[i, 0])] == "<END>":
sent_ids[i].append(int(samples[i, 1]))
else:
sent_ids[i].append(int(samples[i, 0]))
sents = []
for i in range(batch_size):
utt = []
for j in range(len(sent_ids[i])):
w = self.vocab[sent_ids[i][j]]
if w == "<BEGIN>":
continue
if w == "<END>":
break
utt.append(w)
sents.append([(utt, 0)])
return sents
"""
Input:
        x: batch x encoder_len (word_ids), the encoder's input
Output:
        sents: a list with batch_size items; each item holds up to beam_width (sentence, log_prob) pairs, and each sentence is a list of at most max_sent_len words
"""
def beam_search(self, x):
batch_size = x.shape[0]
x_len = x.shape[1]
x_out = self.encoder(x)
# x_out: batch x encoder_len x n_state
sents = []
topk = 1
history_nodes = [{}]
end_nodes = {}
for idx in range(batch_size):
start_node = BeamSearchNode([self.begin_id], 0, 1)
history_nodes[0][idx] = [start_node]
end_nodes[idx] = []
for l in range(self.max_sent_len):
last_nodes = history_nodes[-1]
            if sum(len(nodes) for nodes in last_nodes.values()) == 0:  # no nodes left
break
ys = []
x_outs = []
xs = []
for idx in range(batch_size):
ys.extend([node.word_ids for node in last_nodes[idx]])
x_outs.extend(
[x_out[idx, :, :].unsqueeze(0) for node in last_nodes[idx]]
)
xs.extend([x[idx, :].unsqueeze(0) for node in last_nodes[idx]])
ys = (
Variable(torch.LongTensor(ys)).cuda()
if self.use_cuda
else Variable(torch.LongTensor(ys))
) # N x l
x_outs = torch.cat(x_outs, dim=0) # N x x_len x n_state
xs = torch.cat(xs, dim=0) # N x x_len
probs = self.decoder(xs, x_outs, ys)
log_probs = torch.log(
probs[:, -1, :] + 1e-15
) # N x vocab_size (only take the last output)
history_nodes.append({})
p = 0
for idx in range(batch_size):
history_nodes[-1][idx] = []
N = len(last_nodes[idx])
if N == 0:
continue
log_prob = log_probs[p : p + N]
p += N
# log_prob = N x extended_vocab_size
# generate
candidates = []
for k in range(N):
logprobs, ids = torch.topk(log_prob[k], self.beam_width)
candidates = self.merge_candidates(
candidates, [(k, p, d) for p, d in zip(logprobs, ids)]
)
candidates = candidates[: self.beam_width]
extended_nodes_in_last_nodes = set()
for k in range(len(candidates)):
h, logp, next_word_id = candidates[
k
] # h means "the h-th node in last_nodes"
logp = float(logp)
next_word_id = int(next_word_id)
prev_node = last_nodes[idx][h]
next_wordids = prev_node.word_ids + [next_word_id]
next_word = self.vocab[next_word_id]
next_node = BeamSearchNode(
next_wordids, prev_node.log_prob + logp, prev_node.length + 1
)
                    if not next_node.duplicate:  # no duplicate trigram generated
extended_nodes_in_last_nodes.add(h)
if next_word == "<END>" or l == self.max_sent_len - 1:
end_nodes[idx].append((next_node.eval(), next_node))
else:
history_nodes[-1][idx].append(next_node)
special_words = ["<PAD>", "<UNK>", "<s>", "</s>", "<BEGIN>", "<END>"]
for k in range(N):
if k not in extended_nodes_in_last_nodes:
node = last_nodes[idx][k]
effective_word_count = sum(
[
1
for x in node.word_ids
if self.vocab[x] not in special_words
]
)
if effective_word_count >= 5:
end_nodes[idx].append((node.eval(), node))
MIN_GEN_LENGTH = 45
if "MIN_GEN_LENGTH" in self.opt:
MIN_GEN_LENGTH = int(self.opt["MIN_GEN_LENGTH"])
for idx in range(batch_size):
t = len([w for w in end_nodes[idx] if w[1].length > MIN_GEN_LENGTH])
if t > 0:
end_nodes[idx] = [
w for w in end_nodes[idx] if w[1].length > MIN_GEN_LENGTH
]
end_nodes[idx].sort(key=lambda tup: tup[0], reverse=True)
candidates = []
for score, node in end_nodes[idx][:topk]:
utt = [self.vocab[x] for x in node.word_ids]
utt = [x for x in utt if x not in ["<BEGIN>", "<END>"]]
candidates.append((utt, score))
if len(candidates) == 0:
candidates.append(("", 0))
sents.append(candidates)
return sents
class BeamSearchNode(object):
def __init__(self, word_ids, log_prob, length):
self.word_ids = word_ids
self.log_prob = log_prob
self.length = length
trigram_set = set()
self.duplicate = False
for i in range(2, len(word_ids)):
trigram = (
str(word_ids[i - 2])
+ " "
+ str(word_ids[i - 1])
+ " "
+ str(word_ids[i])
)
if trigram in trigram_set:
self.duplicate = True
break
trigram_set.add(trigram)
def eval(self):
return self.log_prob / float(self.length - 1.0 + 1e-6)
def __lt__(self, other):
return self.length < other.length
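# Usage sketch (illustrative only; the helper below is hypothetical and never called):
# a node flags any repeated trigram of word ids, which beam_search uses to prune repetitive
# hypotheses, and eval() returns the length-normalized log probability.
def _example_beam_search_node():
    repeated = BeamSearchNode([1, 2, 3, 1, 2, 3], log_prob=-4.0, length=6)
    assert repeated.duplicate is True
    fresh = BeamSearchNode([1, 2, 3, 4], log_prob=-2.0, length=4)
    assert fresh.duplicate is False
    return repeated.eval(), fresh.eval()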