""" ContextGate module """
import torch
import torch.nn as nn


def context_gate_factory(gate_type, embeddings_size, decoder_size,
                         attention_size, output_size):
    """Return the requested ContextGate instance."""
    gate_types = {'source': SourceContextGate,
                  'target': TargetContextGate,
                  'both': BothContextGate}
    assert gate_type in gate_types, "Invalid ContextGate type: {0}".format(
        gate_type)
    return gate_types[gate_type](embeddings_size, decoder_size,
                                 attention_size, output_size)
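

# Example (illustrative, not part of the original module): constructing a gate
# for a decoder whose embedding, hidden, and attention vectors are all
# 256-dimensional, then gating one decoding step:
#
#     gate = context_gate_factory('both', 256, 256, 256, 256)
#     output = gate(prev_emb, dec_state, attn_state)  # -> (batch, 256)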


class ContextGate(nn.Module):
    """
    Context gate is a decoder module that takes as input the previous word
    embedding, the current decoder state and the attention state, and
    produces a gate (Tu et al., 2017, "Context Gates for Neural Machine
    Translation"):

        z = sigmoid(W [prev_emb; dec_state; attn_state] + b)

    The gate can be used to select the input from the target side context
    (decoder state), from the source context (attention state), or both.
    """

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(ContextGate, self).__init__()
        input_size = embeddings_size + decoder_size + attention_size
        self.gate = nn.Linear(input_size, output_size, bias=True)
        self.sig = nn.Sigmoid()
        self.source_proj = nn.Linear(attention_size, output_size)
        self.target_proj = nn.Linear(embeddings_size + decoder_size,
                                     output_size)

    def forward(self, prev_emb, dec_state, attn_state):
        input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1)
        z = self.sig(self.gate(input_tensor))
        proj_source = self.source_proj(attn_state)
        proj_target = self.target_proj(
            torch.cat((prev_emb, dec_state), dim=1))
        return z, proj_source, proj_target


class SourceContextGate(nn.Module):
    """Apply the context gate only to the source context."""

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(SourceContextGate, self).__init__()
        self.context_gate = ContextGate(embeddings_size, decoder_size,
                                        attention_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        z, source, target = self.context_gate(
            prev_emb, dec_state, attn_state)
        # Only the source projection is modulated by the gate.
        return self.tanh(target + z * source)


class TargetContextGate(nn.Module):
    """Apply the context gate only to the target context."""

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(TargetContextGate, self).__init__()
        self.context_gate = ContextGate(embeddings_size, decoder_size,
                                        attention_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
        # Only the target projection is modulated by the gate.
        return self.tanh(z * target + source)


class BothContextGate(nn.Module):
    """Apply the context gate to both contexts."""

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(BothContextGate, self).__init__()
        self.context_gate = ContextGate(embeddings_size, decoder_size,
                                        attention_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
        # Convex combination: z weights the source, (1 - z) the target.
        return self.tanh((1. - z) * target + z * source)
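

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original
    # module): exercises all three gate variants on dummy tensors. The batch
    # size and dimensions below are arbitrary placeholders.
    batch_size, size = 2, 4
    prev_emb = torch.randn(batch_size, size)
    dec_state = torch.randn(batch_size, size)
    attn_state = torch.randn(batch_size, size)
    for gate_type in ('source', 'target', 'both'):
        gate = context_gate_factory(gate_type, size, size, size, size)
        output = gate(prev_emb, dec_state, attn_state)
        print(gate_type, tuple(output.shape))  # each variant yields (2, 4)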