""" ContextGate module """
import torch
import torch.nn as nn

def context_gate_factory(gate_type, embeddings_size, decoder_size,
                         attention_size, output_size):
    """Returns an instance of the ContextGate subclass for ``gate_type``."""

    gate_types = {'source': SourceContextGate,
                  'target': TargetContextGate,
                  'both': BothContextGate}

    assert gate_type in gate_types, "Invalid ContextGate type: {0}".format(
        gate_type)
    return gate_types[gate_type](embeddings_size, decoder_size,
                                 attention_size, output_size)
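
# Example usage (a sketch with illustrative sizes; the sizes are
# assumptions, not prescribed by this module):
#   gate = context_gate_factory('both', embeddings_size=8, decoder_size=16,
#                               attention_size=16, output_size=16)
#   output = gate(prev_emb, dec_state, attn_state)  # (batch, output_size)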


class ContextGate(nn.Module):
    """
    Context gate is a decoder module that takes as input the previous word
    embedding, the current decoder state and the attention state, and
    produces a gate.
    The gate can be used to select the input from the target side context
    (decoder state), from the source context (attention state), or both.
    """

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(ContextGate, self).__init__()
        input_size = embeddings_size + decoder_size + attention_size
        self.gate = nn.Linear(input_size, output_size, bias=True)
        self.sig = nn.Sigmoid()
        self.source_proj = nn.Linear(attention_size, output_size)
        self.target_proj = nn.Linear(embeddings_size + decoder_size,
                                     output_size)

    def forward(self, prev_emb, dec_state, attn_state):
        # Gate computed from the full context; z is in (0, 1) elementwise.
        input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1)
        z = self.sig(self.gate(input_tensor))
        # Project the source (attention) and target (embedding + decoder)
        # contexts to the same size so the subclasses can gate and sum them.
        proj_source = self.source_proj(attn_state)
        proj_target = self.target_proj(
            torch.cat((prev_emb, dec_state), dim=1))
        return z, proj_source, proj_target


class SourceContextGate(nn.Module):
    """Apply the context gate only to the source context."""

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(SourceContextGate, self).__init__()
        self.context_gate = ContextGate(embeddings_size, decoder_size,
                                        attention_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        z, source, target = self.context_gate(
            prev_emb, dec_state, attn_state)
        # Only the source context is gated: tanh(t + z * s).
        return self.tanh(target + z * source)


class TargetContextGate(nn.Module):
    """Apply the context gate only to the target context."""

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(TargetContextGate, self).__init__()
        self.context_gate = ContextGate(embeddings_size, decoder_size,
                                        attention_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
        # Only the target context is gated: tanh(z * t + s).
        return self.tanh(z * target + source)


class BothContextGate(nn.Module):
    """Apply the context gate to both contexts."""

    def __init__(self, embeddings_size, decoder_size,
                 attention_size, output_size):
        super(BothContextGate, self).__init__()
        self.context_gate = ContextGate(embeddings_size, decoder_size,
                                        attention_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
        # Both contexts are gated complementarily: tanh((1 - z) * t + z * s).
        return self.tanh((1. - z) * target + z * source)
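

if __name__ == "__main__":
    # Minimal smoke test; a sketch only, with illustrative sizes that are
    # assumptions, not part of the original module.
    batch, emb, dec, attn, out = 4, 8, 16, 16, 16
    gate = context_gate_factory('both', emb, dec, attn, out)
    h = gate(torch.randn(batch, emb),
             torch.randn(batch, dec),
             torch.randn(batch, attn))
    print(h.shape)  # expected: torch.Size([4, 16])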