|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import re |
|
from functools import partial |
|
from typing import List, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
from fairseq.models import ( |
|
FairseqEncoder, |
|
) |
|
from fairseq.models.speech_to_text.utils import ( |
|
NoOp, |
|
lengths_to_padding_mask, |
|
segments_to_sequence, |
|
) |
|
from fairseq.models.speech_to_text.utils import ( |
|
attention_suppression, |
|
layer_norm_backward_hook, |
|
) |
|
from torch import Tensor, device as Device |
|
from torch.quantization.qconfig import ( |
|
default_dynamic_qconfig, |
|
per_channel_dynamic_qconfig, |
|
) |
|
|
|
|
|
class RelativePositionEmbedding(nn.Module): |
|
""" |
|
    Relative position embedding, implemented according to https://arxiv.org/abs/1803.02155
|
""" |
|
|
|
def __init__(self, head_dim, max_position, norm_init=True): |
|
super().__init__() |
|
self.head_dim = head_dim |
|
self.max_position = max_position |
|
self.embeddings = nn.Parameter(torch.Tensor(max_position * 2 + 1, head_dim)) |
|
if norm_init: |
|
nn.init.xavier_normal_(self.embeddings) |
|
else: |
|
nn.init.xavier_uniform_(self.embeddings) |
|
|
|
def forward(self, input: Tensor): |
|
output = nn.functional.embedding(input.long(), self.embeddings) |
|
return output |
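# A minimal usage sketch for RelativePositionEmbedding (illustrative only; the
# sizes below are hypothetical). Relative distances are clamped to
# [-max_position, max_position] and shifted by max_position before the lookup,
# which is how _get_relative_position below builds its indices:
#
#   rpe = RelativePositionEmbedding(head_dim=64, max_position=8)
#   distance = torch.arange(-10, 11)              # raw relative offsets
#   index = torch.clamp(distance, -8, 8) + 8      # shifted into [0, 16]
#   embeddings = rpe(index)                       # -> shape (21, 64)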
|
|
|
|
|
class Fp32LayerNorm(nn.Module): |
|
def __init__( |
|
self, |
|
input_dim, |
|
clamp_grad=True, |
|
max_grad_value=256, |
|
eps=1e-5, |
|
elementwise_affine=True, |
|
): |
|
super().__init__() |
|
self.torch_module = torch.nn.LayerNorm( |
|
input_dim, eps=eps, elementwise_affine=elementwise_affine |
|
) |
|
if clamp_grad: |
|
hook = partial(layer_norm_backward_hook, clamp_value=max_grad_value) |
|
self.torch_module.register_backward_hook(hook) |
|
|
|
def forward(self, input): |
|
output = torch.nn.functional.layer_norm( |
|
input.float(), |
|
self.torch_module.normalized_shape, |
|
self.torch_module.weight.float() |
|
if self.torch_module.weight is not None |
|
else None, |
|
self.torch_module.bias.float() |
|
if self.torch_module.bias is not None |
|
else None, |
|
self.torch_module.eps, |
|
).type_as(input) |
|
return output |
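# Fp32LayerNorm runs LayerNorm in fp32 even for fp16 inputs and casts the
# result back to the input dtype; clamp_grad optionally clips gradients via a
# backward hook. A rough usage sketch (hypothetical shapes):
#
#   ln = Fp32LayerNorm(input_dim=512)
#   x = torch.randn(10, 2, 512).half()
#   y = ln(x)          # normalized in fp32, returned as fp16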
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PositionwiseFF(nn.Module): |
|
""" |
|
FFN layer in transformer. |
|
|
|
Args: |
|
input_dim: input embedding dimension |
|
ffn_dim: FFN layer inner dimension |
|
        dropout_on_fc1: dropout for the first linear layer
        dropout_on_fc2: dropout for the second linear layer
        activation_fn: activation function used after the first linear layer. \
                Only relu or gelu is supported.
|
|
|
""" |
|
|
|
def __init__( |
|
self, input_dim, ffn_dim, dropout_on_fc1, dropout_on_fc2, activation_fn |
|
): |
|
super(PositionwiseFF, self).__init__() |
|
|
|
self.input_dim = input_dim |
|
self.ffn_dim = ffn_dim |
|
if activation_fn == "relu": |
|
ac = nn.ReLU() |
|
elif activation_fn == "gelu": |
|
ac = nn.GELU() |
|
else: |
|
raise ValueError("Unsupported activation_fn = ({})".format(activation_fn)) |
|
|
|
|
|
self.module = nn.Sequential( |
|
nn.Linear(input_dim, ffn_dim), |
|
ac, |
|
nn.Dropout(dropout_on_fc1), |
|
nn.Linear(ffn_dim, input_dim), |
|
nn.Dropout(dropout_on_fc2), |
|
) |
|
|
|
self.layer_norm = Fp32LayerNorm(input_dim) |
|
|
|
def forward(self, input): |
|
module_out = self.module(self.layer_norm(input)) |
|
output = module_out + input |
|
|
|
return output |
|
|
|
def quantize_(self, params=None): |
|
if params and "per_channel" in params and params["per_channel"]: |
|
qconfig = per_channel_dynamic_qconfig |
|
else: |
|
qconfig = default_dynamic_qconfig |
|
torch.quantization.quantize_dynamic( |
|
self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True |
|
) |
|
return self |
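# PositionwiseFF is a pre-norm residual FFN block: LayerNorm -> Linear ->
# activation -> dropout -> Linear -> dropout, plus a residual connection.
# A rough usage sketch (hypothetical sizes):
#
#   ffn = PositionwiseFF(input_dim=512, ffn_dim=2048,
#                        dropout_on_fc1=0.1, dropout_on_fc2=0.1,
#                        activation_fn="relu")
#   x = torch.randn(10, 2, 512)      # (T, B, D)
#   y = ffn(x)                       # same shape: x + FFN(LayerNorm(x))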
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SummarizationLayer(nn.Module): |
|
def __init__(self, method, segment_size, embedding_dim): |
|
super(SummarizationLayer, self).__init__() |
|
self.segment_size = segment_size |
|
self.embedding_dim = embedding_dim |
|
nonlin_match = re.match(r"nonlinear\((?P<act>[a-z]+),(?P<dim>[0-9]+)\)", method) |
|
self.method = method |
|
if method == "mean": |
|
self.module = nn.AvgPool1d( |
|
kernel_size=segment_size, |
|
stride=segment_size, |
|
ceil_mode=True, |
|
) |
|
elif method == "max": |
|
self.module = nn.MaxPool1d( |
|
kernel_size=segment_size, |
|
stride=segment_size, |
|
ceil_mode=True, |
|
) |
|
elif method == "linear": |
|
self.module = nn.Linear(segment_size, 1) |
|
elif nonlin_match: |
|
nonlin_args = nonlin_match.groupdict() |
|
act_type = nonlin_args["act"] |
|
hid_dim = int(nonlin_args["dim"]) |
|
if act_type == "relu": |
|
act = nn.ReLU() |
|
elif act_type == "gelu": |
|
act = nn.GELU() |
|
else: |
|
raise ValueError("Unsupported activation_fn = ({})".format(act_type)) |
|
self.module = nn.Sequential( |
|
nn.Linear(segment_size, hid_dim), |
|
act, |
|
nn.Linear(hid_dim, 1), |
|
) |
|
else: |
|
raise ValueError("Unsupported summarization method = ({})".format(method)) |
|
|
|
def forward(self, input): |
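        # Input is (T, B, D); the pooling / linear summarization modules below
        # operate over the time axis, so permute to (B, D, T) first and permute
        # back to (T', B, D) before returning.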
|
|
|
input = input.permute(1, 2, 0) |
|
|
|
if self.method == "mean" or self.method == "max": |
|
output = self.module(input) |
|
output = output.permute(2, 0, 1) |
|
return output |
|
|
|
full_seg_length = input.size(2) // self.segment_size * self.segment_size |
|
if full_seg_length > 0: |
|
|
|
B = input.size(0) |
|
D = input.size(1) |
|
input_todo = ( |
|
input[:, :, :full_seg_length] |
|
.contiguous() |
|
.view(B, -1, self.segment_size) |
|
) |
|
output = self.module(input_todo) |
|
output = output.view(B, D, -1) |
|
else: |
|
output = input.new_zeros(input.size(0), input.size(1), 0) |
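        # Frames left over after the last full segment do not go through the
        # summarization module; they receive a single all-zero summary vector
        # appended below.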
|
left = input.size(2) - full_seg_length |
|
if left > 0: |
|
|
|
zeros = input.new_zeros(input.size(0), input.size(1), 1) |
|
output = torch.cat([output, zeros], dim=2) |
|
output = output.permute(2, 0, 1) |
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NoSegAugmentedMemoryMultiheadAttentionBmm(nn.Module): |
|
""" |
|
Whole utterance augmented memory multihead attention using BMM. |
|
|
|
    Different from the previous augmented memory multihead attention, where
    the utterance is chunked into segments, here we use an attention mask to
    achieve the same effect. The input embedding [right_context, utterance,
    summary] is a concatenation of right context, utterance and summary.
|
|
|
    The right context block is the concatenation of the right context for
    each segment: [right_context_0, right_context_1, ..., right_context_n].
    For example, with utterance = [v0, v1, v2, ..., v20], segment size 8 and
    right_context size 4, the right context block is
    [v8, v9, v10, v11, v16, v17, v18, v19, 0, 0, 0, 0], where v8, v9, v10
    and v11 are the right context for the first segment, v16, v17, v18 and
    v19 are the right context for the second segment, and 0, 0, 0, 0 are the
    right context for the last segment.
|
|
|
    utterance corresponds to the input embedding sequence.
|
|
|
    summary is the concatenation of the average of each segment:
    [summary_0, summary_1, ...].
|
|
|
    In augmented memory multihead attention, the query is [right_context,
    utterance, summary] and the key is [memory, right_context, utterance].
    Different from AugmentedMemoryMultiheadAttentionBmm, the memory here is
    passed from the previous attention layer. For the first attention layer,
    the memory is the average of each segment.
|
|
|
    Memory is a concatenation of the memory from each segment in the previous
    attention layer. For example, if the current layer is i, then memory is
    [m_0, m_1, ..., m_n], where each m_k is the output for seg_k in layer i-1.
|
|
|
args: |
|
input_dim: input embedding dimension |
|
num_heads: number of heads in multihead self-attention |
|
dropout: attention dropout |
|
        std_scale: if std_scale is not None, weak attention suppression is
            turned on. For std_scale = 0.5, all attention weights smaller than
            mean + 0.5 * std will be suppressed.
|
scaled_init: whether to use scaled init for linear weight |
|
tanh_on_mem: whether to use tanh on memory output |
|
        use_mem: whether to use memory. When max_memory_size is 0, memory is
            disabled.
|
layer_index: current self-attention layer index that is used in depth |
|
initialization |
|
max_relative_position: max relative position used in relative position |
|
embedding |
|
        rpe_old_option: kept for compatibility with previously trained models,
            which used attention += attention + rpe. The correct equation
            should be attention = attention + rpe.
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dim, |
|
num_heads, |
|
dropout=0.0, |
|
std_scale=None, |
|
scaled_init=False, |
|
tanh_on_mem=False, |
|
use_mem=True, |
|
mini_batches=False, |
|
negative_inf="-inf", |
|
layer_index=-1, |
|
max_relative_position=0, |
|
rpe_old_option=True, |
|
): |
|
if input_dim % num_heads: |
|
raise ValueError( |
|
"input_dim ({}) must be divisible by num_heads ({})".format( |
|
input_dim, num_heads |
|
) |
|
) |
|
|
|
super().__init__() |
|
|
|
embed_dim = input_dim |
|
self.e2h_kv = torch.nn.Linear(input_dim, 2 * input_dim, bias=True) |
|
self.e2h_q = torch.nn.Linear(input_dim, input_dim, bias=True) |
|
self.rpe_old_option = rpe_old_option |
|
if max_relative_position > 0: |
|
self.use_rpe = True |
|
self.rpe_k = RelativePositionEmbedding( |
|
head_dim=input_dim // num_heads, |
|
max_position=max_relative_position, |
|
) |
|
self.rpe_v = RelativePositionEmbedding( |
|
head_dim=input_dim // num_heads, |
|
max_position=max_relative_position, |
|
) |
|
else: |
|
self.use_rpe = False |
|
self.rpe_k = None |
|
self.rpe_v = None |
|
if scaled_init: |
|
if layer_index == -1: |
|
gain = 1.0 / math.sqrt(2) |
|
else: |
|
|
|
|
|
|
|
gain = 1.0 / math.sqrt(layer_index + 1) |
|
torch.nn.init.xavier_uniform_(self.e2h_kv.weight, gain=gain) |
|
torch.nn.init.xavier_uniform_(self.e2h_q.weight, gain=gain) |
|
|
|
self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
|
self.embed_dim = embed_dim |
|
self.num_heads = num_heads |
|
self.dropout = dropout |
|
|
|
self.head_dim = embed_dim // num_heads |
|
self.scaling = self.head_dim ** -0.5 |
|
|
|
self.std_scale = std_scale |
|
self.use_mem = use_mem |
|
self.mini_batches = mini_batches |
|
self.negative_inf = negative_inf |
|
|
|
if tanh_on_mem: |
|
self.squash_mem = torch.tanh |
|
self.nonlinear_squash_mem = True |
|
else: |
|
self.squash_mem = NoOp() |
|
self.nonlinear_squash_mem = False |
|
|
|
def prepare_qkv( |
|
self, |
|
input: Tensor, |
|
mems: Tensor, |
|
lengths: Tensor, |
|
summary_length: int, |
|
lc_length: int, |
|
): |
|
|
|
T, B, D = input.shape |
|
mem_length = mems.size(0) |
|
utterance_length = torch.max(lengths) |
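        # The input covers [right_context_blocks | utterance | summary] along
        # the time axis, so the right-context length is whatever remains after
        # removing the utterance frames and the summary vectors.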
|
|
|
right_context_blocks_length = T - utterance_length - summary_length |
|
rc_block = input[:right_context_blocks_length, :, :] |
|
utterance_block = input[right_context_blocks_length : T - summary_length, :, :] |
|
|
|
if B == 1: |
|
padding_mask = None |
|
else: |
|
klengths = lengths + mem_length + right_context_blocks_length + lc_length |
|
padding_mask = lengths_to_padding_mask(lengths=klengths) |
|
|
|
mem_rc_input = torch.cat([mems, rc_block, utterance_block], dim=0) |
|
|
|
|
|
key_length = mem_rc_input.size(0) + lc_length |
|
rc_input_sum = input |
|
q = self.e2h_q(rc_input_sum) |
|
kv = self.e2h_kv(mem_rc_input) |
|
k, v = kv.chunk(chunks=2, dim=2) |
|
result_qkv = (q, k, v) |
|
input_shape = (T, B, D) |
|
result_lengths_info = ( |
|
mem_length, |
|
utterance_length, |
|
right_context_blocks_length, |
|
key_length, |
|
) |
|
if padding_mask is not None: |
|
assert padding_mask.size(0) == B |
|
assert padding_mask.size(1) == key_length |
|
|
|
return result_qkv, input_shape, result_lengths_info, padding_mask |
|
|
|
def prepare_attention_weights( |
|
self, |
|
q: Tensor, |
|
new_k: Tensor, |
|
new_v: Tensor, |
|
input_shape: Tuple[int, int, int], |
|
rpe: Optional[Tensor], |
|
) -> Tuple[Tensor, Tensor, Tensor]: |
|
T, B, D = input_shape |
|
q = ( |
|
q.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1) |
|
* self.scaling |
|
) |
|
|
|
k = ( |
|
new_k.contiguous() |
|
.view(-1, B * self.num_heads, self.head_dim) |
|
.transpose(0, 1) |
|
) |
|
|
|
v = ( |
|
new_v.contiguous() |
|
.view(-1, B * self.num_heads, self.head_dim) |
|
.transpose(0, 1) |
|
) |
|
|
|
attention_weights = torch.bmm(q, k.transpose(1, 2)) |
|
        if self.use_rpe and rpe is not None and self.rpe_k is not None:
|
r_k = self.rpe_k(rpe) |
|
|
|
attention_weights_rpe = torch.matmul( |
|
q.transpose(0, 1), r_k.transpose(1, 2) |
|
).transpose(0, 1) |
|
attention_weights = attention_weights + attention_weights_rpe |
|
attention_weights_float = attention_weights.float() |
|
|
|
return attention_weights, attention_weights_float, v |
|
|
|
def prepare_attention_output( |
|
self, |
|
attention_weights: Tensor, |
|
attention_weights_float: Tensor, |
|
v: Tensor, |
|
input_shape: Tuple[int, int, int], |
|
key_length: int, |
|
padding_mask: Optional[Tensor], |
|
rpe: Optional[Tensor], |
|
) -> Tensor: |
|
T, B, D = input_shape |
|
if padding_mask is not None: |
|
attention_weights_float = attention_weights_float.view( |
|
B, self.num_heads, T, key_length |
|
) |
|
attention_weights_float = attention_weights_float.masked_fill( |
|
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") |
|
) |
|
attention_weights_float = attention_weights_float.view( |
|
B * self.num_heads, T, key_length |
|
) |
|
|
|
if self.std_scale is not None: |
|
attention_weights_float = attention_suppression( |
|
attention_weights_float, self.std_scale |
|
) |
|
|
|
attention_weights_float = torch.nn.functional.softmax( |
|
attention_weights_float, dim=-1 |
|
) |
|
attention_weights = attention_weights_float.type_as(attention_weights) |
|
|
|
attention_probs = torch.nn.functional.dropout( |
|
attention_weights, p=self.dropout, training=self.training |
|
) |
|
|
|
|
|
|
|
attention = torch.bmm(attention_probs, v) |
|
if self.use_rpe and rpe is not None and self.rpe_v is not None: |
|
r_v = self.rpe_v(rpe) |
|
attention_rpe = torch.matmul( |
|
attention_probs.transpose(0, 1), r_v |
|
).transpose(0, 1) |
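            # rpe_old_option reproduces a historical bug: it doubles the
            # attention term (attention += attention + attention_rpe) so that
            # checkpoints trained with the old code keep their behavior.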
|
|
|
if self.rpe_old_option: |
|
attention += attention + attention_rpe |
|
else: |
|
attention = attention + attention_rpe |
|
|
|
assert list(attention.shape) == [B * self.num_heads, T, self.head_dim] |
|
|
|
attention = attention.transpose(0, 1).contiguous().view(T, B, self.embed_dim) |
|
|
|
rc_output_memory = self.out_proj(attention) |
|
return rc_output_memory |
|
|
|
@torch.jit.unused |
|
def forward( |
|
self, |
|
input: Tensor, |
|
lengths: Tensor, |
|
mems: Tensor, |
|
attention_mask: Tensor, |
|
pre_mems: Optional[Tensor] = None, |
|
left_context_key: Optional[Tensor] = None, |
|
left_context_val: Optional[Tensor] = None, |
|
rpe: Optional[Tensor] = None, |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
""" |
|
forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in training. |
|
|
|
args: |
|
input: formed in the following way |
|
                [right_context_0, right_context_1, ..., seg_0, seg_1,
|
..., summary_0, summary_1,..] |
|
lengths: the length of query which is [seg_0, seg_1, ....] |
|
mems: [mem_0, mem_1, ...]. |
|
attention_mask: attention mask for query = [right_context, query, summary] |
|
                key = [mem, right_context, query]. This is only used for training.
|
|
|
""" |
|
if self.use_mem: |
|
mem_length = mems.size(0) |
|
summary_length = mem_length + 1 |
|
if pre_mems is not None: |
|
mems = torch.cat([pre_mems, mems], dim=0) |
|
else: |
|
mem_length = 0 |
|
summary_length = 0 |
|
|
|
|
|
if left_context_key is not None: |
|
lc_length = left_context_key.size(0) |
|
else: |
|
lc_length = 0 |
|
results = self.prepare_qkv( |
|
input=input, |
|
mems=mems, |
|
lengths=lengths, |
|
summary_length=summary_length, |
|
lc_length=lc_length, |
|
) |
|
result_qkv, input_shape, result_lengths_info, padding_mask = results |
|
q, k, v = result_qkv |
|
( |
|
mem_length, |
|
utterance_length, |
|
right_context_blocks_length, |
|
key_length, |
|
) = result_lengths_info |
|
|
|
if left_context_key is not None: |
|
|
|
new_k = torch.cat( |
|
[ |
|
k[: mem_length + right_context_blocks_length, :, :], |
|
left_context_key, |
|
k[-utterance_length:, :, :], |
|
], |
|
dim=0, |
|
) |
|
new_v = torch.cat( |
|
[ |
|
v[: mem_length + right_context_blocks_length, :, :], |
|
left_context_val, |
|
v[-utterance_length:, :, :], |
|
], |
|
dim=0, |
|
) |
|
next_k = new_k[mem_length + right_context_blocks_length :, :, :] |
|
next_v = new_v[mem_length + right_context_blocks_length :, :, :] |
|
else: |
|
new_k = k |
|
new_v = v |
|
next_k = None |
|
next_v = None |
|
|
|
attention_weights, attention_weights_float, v = self.prepare_attention_weights( |
|
q=q, |
|
new_k=new_k, |
|
new_v=new_v, |
|
input_shape=input_shape, |
|
rpe=rpe, |
|
) |
|
|
|
|
|
attention_mask = attention_mask.unsqueeze(0) |
|
attention_weights_float = attention_weights_float.masked_fill( |
|
attention_mask, float(self.negative_inf) |
|
) |
|
|
|
rc_output_memory = self.prepare_attention_output( |
|
attention_weights=attention_weights, |
|
attention_weights_float=attention_weights_float, |
|
v=v, |
|
input_shape=input_shape, |
|
key_length=key_length, |
|
padding_mask=padding_mask, |
|
rpe=rpe, |
|
) |
|
|
|
if self.use_mem: |
|
|
|
|
|
if self.mini_batches: |
|
next_m = rc_output_memory[-summary_length:] |
|
else: |
|
next_m = rc_output_memory[-summary_length:-1] |
|
|
|
next_m = self.squash_mem(next_m) |
|
|
|
rc_output = rc_output_memory[:-summary_length] |
|
if not self.nonlinear_squash_mem: |
|
next_m = torch.clamp(next_m, min=-10, max=10) |
|
else: |
|
next_m = mems |
|
rc_output = rc_output_memory |
|
|
|
return rc_output, next_m, next_k, next_v |
|
|
|
@torch.jit.export |
|
def forward_jit( |
|
self, |
|
input: Tensor, |
|
lengths: Tensor, |
|
mems: Tensor, |
|
left_context_key: Tensor, |
|
left_context_val: Tensor, |
|
rpe: Optional[Tensor], |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
""" |
|
forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in decoding. |
|
|
|
args: |
|
input: formed in the following way |
|
                [right_context_0, right_context_1, ..., seg_0, seg_1,
|
..., summary_0, summary_1,..] |
|
lengths: the length of query which is [seg_0, seg_1, ....] |
|
mems: [mem_0, mem_1, ...]. |
|
            left_context_key: left context for the key part. This is only used
                for online decoding. In training, this is an empty tensor.
            left_context_val: left context for the value part. This is only used
                for online decoding. In training, this is an empty tensor.
|
|
|
""" |
|
lc_length = left_context_key.size(0) |
|
|
|
|
|
if self.use_mem: |
|
summary_length = 1 |
|
else: |
|
summary_length = 0 |
|
|
|
results = self.prepare_qkv( |
|
input=input, |
|
mems=mems, |
|
lengths=lengths, |
|
summary_length=summary_length, |
|
lc_length=lc_length, |
|
) |
|
result_qkv, input_shape, result_lengths_info, padding_mask = results |
|
q, k, v = result_qkv |
|
( |
|
mem_length, |
|
utterance_length, |
|
right_context_blocks_length, |
|
key_length, |
|
) = result_lengths_info |
|
|
|
|
|
new_k = torch.cat( |
|
[ |
|
k[: mem_length + right_context_blocks_length, :, :], |
|
left_context_key, |
|
k[-utterance_length:, :, :], |
|
], |
|
dim=0, |
|
) |
|
new_v = torch.cat( |
|
[ |
|
v[: mem_length + right_context_blocks_length, :, :], |
|
left_context_val, |
|
v[-utterance_length:, :, :], |
|
], |
|
dim=0, |
|
) |
|
next_k = new_k[mem_length + right_context_blocks_length :, :, :] |
|
next_v = new_v[mem_length + right_context_blocks_length :, :, :] |
|
|
|
attention_weights, attention_weights_float, v = self.prepare_attention_weights( |
|
q=q, |
|
new_k=new_k, |
|
new_v=new_v, |
|
input_shape=input_shape, |
|
rpe=rpe, |
|
) |
|
|
|
|
|
attention_weights_float[:, -1, :mem_length] = float(self.negative_inf) |
|
rc_output_memory = self.prepare_attention_output( |
|
attention_weights=attention_weights, |
|
attention_weights_float=attention_weights_float, |
|
v=v, |
|
input_shape=input_shape, |
|
key_length=key_length, |
|
padding_mask=padding_mask, |
|
rpe=rpe, |
|
) |
|
|
|
|
|
if self.use_mem: |
|
next_m = rc_output_memory[-1:] |
|
next_m = self.squash_mem(next_m) |
|
|
|
rc_output = rc_output_memory[:-1] |
|
if not self.nonlinear_squash_mem: |
|
next_m = torch.clamp(next_m, min=-10, max=10) |
|
else: |
|
rc_output = rc_output_memory |
|
|
|
next_m = mems |
|
|
|
return rc_output, next_m, next_k, next_v |
|
|
|
def quantize_(self, params=None): |
|
if params and "per_channel" in params and params["per_channel"]: |
|
qconfig = per_channel_dynamic_qconfig |
|
else: |
|
qconfig = default_dynamic_qconfig |
|
torch.quantization.quantize_dynamic( |
|
self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True |
|
) |
|
return self |
|
|
|
|
|
class NoSegAugmentedMemoryTransformer(nn.Module): |
|
""" |
|
Whole utterance augmented memory transformer. |
|
|
|
    This is not a pyspeech nn layer. It is used as a module in a master layer
    where multiple transformers are used.
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dim, |
|
num_heads, |
|
ffn_dim, |
|
dropout_in_attn=0.0, |
|
dropout_on_attn=None, |
|
dropout_on_fc1=None, |
|
dropout_on_fc2=None, |
|
activation_fn="relu", |
|
tanh_on_mem=False, |
|
std_scale=None, |
|
scaled_init=False, |
|
segment_size=128, |
|
use_mem=True, |
|
mini_batches=False, |
|
negative_inf="-inf", |
|
layer_index=-1, |
|
summarization_method="mean", |
|
max_relative_position=0, |
|
rpe_old_option=True, |
|
): |
|
super(NoSegAugmentedMemoryTransformer, self).__init__() |
|
|
|
self.attention = NoSegAugmentedMemoryMultiheadAttentionBmm( |
|
input_dim=input_dim, |
|
num_heads=num_heads, |
|
dropout=dropout_in_attn, |
|
scaled_init=scaled_init, |
|
tanh_on_mem=tanh_on_mem, |
|
std_scale=std_scale, |
|
use_mem=use_mem, |
|
mini_batches=mini_batches, |
|
negative_inf=negative_inf, |
|
layer_index=layer_index, |
|
max_relative_position=max_relative_position, |
|
) |
|
self.dropout = nn.Dropout(dropout_on_attn) |
|
self.pos_ff = PositionwiseFF( |
|
input_dim=input_dim, |
|
ffn_dim=ffn_dim, |
|
dropout_on_fc1=dropout_on_fc1, |
|
dropout_on_fc2=dropout_on_fc2, |
|
activation_fn=activation_fn, |
|
) |
|
self.layer_norm_pre = Fp32LayerNorm(input_dim) |
|
self.layer_norm = Fp32LayerNorm(input_dim) |
|
self.segment_size = segment_size |
|
self.use_mem = use_mem |
|
|
|
self.memory_op = SummarizationLayer( |
|
summarization_method, segment_size, input_dim |
|
) |
|
|
|
def set_mini_batches(self, mini_batches): |
|
self.attention.mini_batches = mini_batches |
|
|
|
def gen_summary_queries(self, input): |
|
sum_input = self.memory_op(input) |
|
return sum_input |
|
|
|
def pre_attention_ops(self, input, right_context_blocks): |
|
rc_length = right_context_blocks.size(0) |
|
input_length = input.size(0) |
|
|
|
rc_and_input = torch.cat([right_context_blocks, input], dim=0) |
|
residual_input = rc_and_input |
|
rc_and_input = self.layer_norm_pre(rc_and_input) |
|
|
|
query_input = rc_and_input[-input_length:, :, :] |
|
return rc_length, input_length, residual_input, query_input, rc_and_input |
|
|
|
def after_attention_ops(self, attention_output, residual_input): |
|
output = self.dropout(attention_output) |
|
output = output + residual_input |
|
output = self.pos_ff(output) |
|
output = self.layer_norm(output) |
|
return output |
|
|
|
@torch.jit.export |
|
def forward_jit( |
|
self, |
|
input: Tensor, |
|
lengths: Tensor, |
|
mems: Tensor, |
|
left_context_key: Tensor, |
|
left_context_val: Tensor, |
|
right_context_blocks: Tensor, |
|
rpe: Optional[Tensor], |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: |
|
|
|
results = self.pre_attention_ops(input, right_context_blocks) |
|
rc_length, input_length, residual_input, query_input, rc_and_input = results |
|
|
|
|
|
if self.use_mem: |
|
summary_query = self.gen_summary_queries(query_input) |
|
summary_query = summary_query[0:1, :, :] |
|
rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) |
|
else: |
|
rc_qu_su = rc_and_input |
|
|
|
rc_output, next_m, next_k, next_v = self.attention.forward_jit( |
|
input=rc_qu_su, |
|
lengths=lengths, |
|
mems=mems, |
|
left_context_key=left_context_key, |
|
left_context_val=left_context_val, |
|
rpe=rpe, |
|
) |
|
rc_output = self.after_attention_ops(rc_output, residual_input) |
|
results = ( |
|
rc_output[-input_length:, :, :], |
|
next_m, |
|
rc_output[0:rc_length, :, :], |
|
next_k, |
|
next_v, |
|
) |
|
return results |
|
|
|
@torch.jit.unused |
|
def forward( |
|
self, |
|
input, |
|
lengths, |
|
mems, |
|
right_context_blocks, |
|
attention_mask, |
|
pre_mems, |
|
left_context_key, |
|
left_context_val, |
|
rpe, |
|
): |
|
|
|
results = self.pre_attention_ops(input, right_context_blocks) |
|
rc_length, input_length, residual_input, query_input, rc_and_input = results |
|
if self.use_mem: |
|
summary_query = self.gen_summary_queries(query_input) |
|
rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) |
|
else: |
|
rc_qu_su = rc_and_input |
|
|
|
rc_output, next_m, next_k, next_v = self.attention( |
|
input=rc_qu_su, |
|
lengths=lengths, |
|
mems=mems, |
|
attention_mask=attention_mask, |
|
pre_mems=pre_mems, |
|
left_context_key=left_context_key, |
|
left_context_val=left_context_val, |
|
rpe=rpe, |
|
) |
|
|
|
|
|
|
|
rc_output = self.after_attention_ops(rc_output, residual_input) |
|
results = ( |
|
rc_output[-input_length:, :, :], |
|
next_m, |
|
rc_output[0:rc_length, :, :], |
|
next_k, |
|
next_v, |
|
) |
|
|
|
return results |
|
|
|
|
|
class NoSegAugmentedMemoryTransformerEncoderLayer(FairseqEncoder): |
|
""" |
|
    Whole utterance augmented memory transformer encoder layer. This is a master
    layer where we can define multiple augmented memory transformers. There are
    two reasons to set up the master layer.
    1. We only need to define the attention mask once. All the layers in the
    master layer share the same mask.
    2. The pyspeech nn layer has a special input and output format. Defining one
    master layer makes it easier to pass memory between different layers inside
    the master layer.
|
|
|
args: |
|
input_dim: input embedding dimension |
|
num_heads: number of heads in multihead self-attention |
|
ffn_dim: ffn dimension in FFN layer |
|
num_layers: number of augmented memory transformer layers |
|
dropout_in_attn: dropout used in multi-head self-attention |
|
        dropout_on_attn: dropout used for the output from the multihead self-attention
|
dropout_on_fc1: dropout used in FFN layer for the first linear layer |
|
dropout_on_fc2: dropout used in FFN layer for the second linear layer |
|
segment_size: segment size for each segment |
|
        context_config: (left_context_size, right_context_size) defines the surrounding context size
|
for each segment |
|
max_memory_size: maximum memory size used for each segment |
|
        scaled_init: whether to use scaled init for weight initialization in the attention layer
|
        std_scale: if std_scale is not None, weak attention suppression is
            turned on. For std_scale = 0.5, all attention weights smaller than
            mean + 0.5 * std will be suppressed.
|
activation_fn: activation function used in FFN layer. [ReLU, GELU] supported |
|
        tanh_on_mem: whether to use tanh on memory
        mini_batches: use mini-batch training
        negative_inf: the negative infinity value used in attention masking. default is "-inf".
            For some situations, e.g. LM, it is better to use "-1e8" to avoid NaN issues.
|
        summarization_method: method to generate the segment summarization embedding
        max_relative_position: max relative position for relative position embedding
|
        rpe_old_option: kept for compatibility with previously trained models,
            which used attention += attention + rpe. The correct equation
            should be attention = attention + rpe.
|
[TODO]: remove the rpe_old_option by the end of 2021 Q1. |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dim, |
|
num_heads, |
|
ffn_dim, |
|
num_layers=1, |
|
dropout_in_attn=0.0, |
|
dropout_on_attn=0.0, |
|
dropout_on_fc1=0.0, |
|
dropout_on_fc2=0.0, |
|
segment_size=128, |
|
context_config=(0, 0), |
|
max_memory_size=0, |
|
scaled_init=True, |
|
std_scale=None, |
|
activation_fn="relu", |
|
tanh_on_mem=False, |
|
mini_batches=False, |
|
negative_inf="-inf", |
|
deep_init=True, |
|
summarization_method="mean", |
|
max_relative_position=0, |
|
rpe_old_option=True, |
|
): |
|
super().__init__(None) |
|
if input_dim % num_heads: |
|
raise ValueError( |
|
"input_dim ({}) must be divisible by num_heads ({})".format( |
|
input_dim, num_heads |
|
) |
|
) |
|
|
|
|
|
|
|
if max_memory_size < 0: |
|
raise ValueError("max_memory_size must be >= 0") |
|
|
|
|
|
|
|
self.left_context, self.right_context = context_config |
|
self.segment_size = segment_size |
|
self.memory_dim = input_dim |
|
self.max_memory_size = max_memory_size |
|
self.mini_batches = mini_batches |
|
if self.max_memory_size != 0: |
|
self.use_mem = True |
|
else: |
|
self.use_mem = False |
|
|
|
self.memory_op = SummarizationLayer( |
|
summarization_method, segment_size, input_dim |
|
) |
|
|
|
self.layers = torch.nn.ModuleList() |
|
self.num_layers = num_layers |
|
self.max_relative_position = max_relative_position |
|
if self.max_relative_position > 0: |
|
self.use_rpe = True |
|
else: |
|
self.use_rpe = False |
|
for i in range(self.num_layers): |
|
if deep_init: |
|
layer_index = i |
|
else: |
|
layer_index = -1 |
|
|
|
self.layers.append( |
|
NoSegAugmentedMemoryTransformer( |
|
num_heads=num_heads, |
|
input_dim=input_dim, |
|
ffn_dim=ffn_dim, |
|
dropout_in_attn=dropout_in_attn, |
|
dropout_on_attn=dropout_on_attn, |
|
dropout_on_fc1=dropout_on_fc1, |
|
dropout_on_fc2=dropout_on_fc2, |
|
segment_size=segment_size, |
|
std_scale=std_scale, |
|
activation_fn=activation_fn, |
|
tanh_on_mem=tanh_on_mem, |
|
scaled_init=scaled_init, |
|
use_mem=self.use_mem, |
|
mini_batches=mini_batches, |
|
negative_inf=negative_inf, |
|
layer_index=layer_index, |
|
summarization_method=summarization_method, |
|
max_relative_position=max_relative_position, |
|
rpe_old_option=rpe_old_option, |
|
) |
|
) |
|
|
|
def set_mini_batches(self, mini_batches): |
|
|
|
self.mini_batches = mini_batches |
|
for layer in self.layers: |
|
layer.set_mini_batches(mini_batches) |
|
|
|
def _get_relative_position( |
|
self, |
|
input: Tensor, |
|
max_relative_position: int, |
|
left_context_length: int, |
|
past_length: int, |
|
is_decoding: bool, |
|
): |
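        # Build the relative-distance matrix between query positions
        # ([right_context, utterance, summary]) and key positions
        # ([memory, right_context, left_context, utterance] when memory is
        # used), then clamp it to [-max_relative_position, max_relative_position]
        # and shift it to be non-negative so it can index the
        # RelativePositionEmbedding table.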
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T, B, D = input.shape |
|
num_segs = math.ceil((T - self.right_context) / self.segment_size) |
|
|
|
|
|
u_st = past_length * self.segment_size |
|
u_ed = u_st + T |
|
utterance_ranges = torch.arange(u_st, u_ed - self.right_context) |
|
|
|
|
|
left_context_ranges = torch.arange(u_st - left_context_length, u_st) |
|
|
|
|
|
|
|
right_context_blocks = [] |
|
for i in range(0, num_segs - 1): |
|
st = (i + 1) * self.segment_size + u_st |
|
ed = st + self.right_context |
|
assert ed < u_ed |
|
temp = torch.arange(st, ed) |
|
right_context_blocks.append(temp) |
|
right_context_blocks.append(torch.arange(u_ed - self.right_context, u_ed)) |
|
right_context_ranges = torch.cat(right_context_blocks) |
|
|
|
if self.use_mem: |
|
|
|
|
|
if is_decoding: |
|
memory_size = min(past_length, self.max_memory_size) |
|
else: |
|
memory_size = num_segs + past_length - 1 |
|
memory_bank_ranges = torch.arange( |
|
-max_relative_position - 1, -max_relative_position - 1 - memory_size, -1 |
|
) |
|
|
|
|
|
|
|
|
|
summary_pos_st = u_ed + max_relative_position + 1 |
|
summary_vector_ranges = torch.arange( |
|
summary_pos_st, summary_pos_st + num_segs |
|
) |
|
|
|
key_ranges = torch.cat( |
|
[ |
|
memory_bank_ranges, |
|
right_context_ranges, |
|
left_context_ranges, |
|
utterance_ranges, |
|
] |
|
) |
|
|
|
query_ranges = torch.cat( |
|
[right_context_ranges, utterance_ranges, summary_vector_ranges] |
|
) |
|
else: |
|
key_ranges = torch.cat( |
|
[right_context_ranges, left_context_ranges, utterance_ranges] |
|
) |
|
|
|
query_ranges = torch.cat([right_context_ranges, utterance_ranges]) |
|
|
|
distance = key_ranges[None, :] - query_ranges[:, None] |
|
distance_clamp = ( |
|
torch.clamp(distance, -max_relative_position, max_relative_position) |
|
+ max_relative_position |
|
) |
|
distance_clamp = distance_clamp.to(input.device).long().detach() |
|
return distance_clamp |
|
|
|
def _get_attention_mask(self, input, past_length=0, left_context_cache=0): |
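        # Build the (query x key) mask block by block for each segment: memory
        # columns first, then right-context columns, then utterance columns
        # (including cached left context). Ones mark attendable positions; the
        # final mask is inverted (1 - ...), so True means "masked out" when it
        # is later used with masked_fill.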
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utterance_length, batch_size, _ = input.shape |
|
summary_length = math.ceil(utterance_length / self.segment_size) |
|
num_segs = summary_length |
|
rc_length = self.right_context * num_segs |
|
rc = self.right_context |
|
lc = self.left_context |
|
|
|
|
|
|
|
lcc = left_context_cache |
|
|
|
|
|
|
|
if self.use_mem: |
|
mem_length = num_segs - 1 + past_length |
|
else: |
|
mem_length = 0 |
|
rc_mask = [] |
|
query_mask = [] |
|
summary_mask = [] |
|
for j in range(0, num_segs): |
|
ssize = min(self.segment_size, utterance_length - j * self.segment_size) |
|
|
|
rc_size = rc |
|
rc_mat = [] |
|
q_mat = [] |
|
s_mat = [] |
|
m_start = max(j + past_length - self.max_memory_size, 0) |
|
|
|
|
|
if self.use_mem: |
|
|
|
rc_mat.append(input.new_zeros(rc_size, m_start)) |
|
q_mat.append(input.new_zeros(ssize, m_start)) |
|
s_mat.append(input.new_zeros(1, m_start)) |
|
|
|
|
|
col_1 = j + past_length - m_start |
|
rc_mat.append(torch.ones(rc_size, col_1, device=input.device)) |
|
q_mat.append(torch.ones(ssize, col_1, device=input.device)) |
|
|
|
|
|
s_mat.append(input.new_zeros(1, col_1)) |
|
|
|
|
|
col_2 = mem_length - (j + past_length) |
|
rc_mat.append(input.new_zeros(rc_size, col_2)) |
|
q_mat.append(input.new_zeros(ssize, col_2)) |
|
s_mat.append(input.new_zeros(1, col_2)) |
|
|
|
|
|
rc_start = j * rc |
|
rc_mat.append(input.new_zeros(rc_size, rc_start)) |
|
q_mat.append(input.new_zeros(ssize, rc_start)) |
|
s_mat.append(input.new_zeros(1, rc_start)) |
|
|
|
|
|
rc_end = rc_start + rc |
|
col_4 = rc |
|
rc_mat.append(torch.ones(rc_size, col_4, device=input.device)) |
|
q_mat.append(torch.ones(ssize, col_4, device=input.device)) |
|
s_mat.append(torch.ones(1, col_4, device=input.device)) |
|
|
|
|
|
col_5 = rc_length - rc_end |
|
rc_mat.append(input.new_zeros(rc_size, col_5)) |
|
q_mat.append(input.new_zeros(ssize, col_5)) |
|
s_mat.append(input.new_zeros(1, col_5)) |
|
|
|
|
|
seg_start = max(j * self.segment_size + lcc - lc, 0) |
|
rc_mat.append(input.new_zeros(rc_size, seg_start)) |
|
q_mat.append(input.new_zeros(ssize, seg_start)) |
|
s_mat.append(input.new_zeros(1, seg_start)) |
|
|
|
|
|
|
|
|
|
seg_end = min((j + 1) * self.segment_size + lcc, utterance_length + lcc) |
|
col_7 = seg_end - seg_start |
|
rc_mat.append(torch.ones(rc_size, col_7, device=input.device)) |
|
q_mat.append(torch.ones(ssize, col_7, device=input.device)) |
|
s_mat.append(torch.ones(1, col_7, device=input.device)) |
|
|
|
|
|
col_8 = utterance_length + lcc - seg_end |
|
rc_mat.append(input.new_zeros(rc_size, col_8)) |
|
q_mat.append(input.new_zeros(ssize, col_8)) |
|
s_mat.append(input.new_zeros(1, col_8)) |
|
|
|
rc_mask.append(torch.cat(rc_mat, dim=1)) |
|
query_mask.append(torch.cat(q_mat, dim=1)) |
|
summary_mask.append(torch.cat(s_mat, dim=1)) |
|
|
|
|
|
if self.use_mem: |
|
attention_mask = ( |
|
1 |
|
- torch.cat( |
|
[ |
|
torch.cat(rc_mask, dim=0), |
|
torch.cat(query_mask, dim=0), |
|
torch.cat(summary_mask, dim=0), |
|
], |
|
dim=0, |
|
) |
|
).to(torch.bool) |
|
else: |
|
attention_mask = ( |
|
1 |
|
- torch.cat( |
|
[torch.cat(rc_mask, dim=0), torch.cat(query_mask, dim=0)], dim=0 |
|
) |
|
).to(torch.bool) |
|
|
|
return attention_mask |
|
|
|
@torch.jit.export |
|
def init_state( |
|
self, batch_size: int, device: Optional[Device] = None |
|
) -> List[Tensor]: |
|
empty_memory = torch.zeros( |
|
self.num_layers, |
|
self.max_memory_size, |
|
batch_size, |
|
self.memory_dim, |
|
device=device, |
|
) |
|
left_context_key = torch.zeros( |
|
self.num_layers, |
|
self.left_context, |
|
batch_size, |
|
self.memory_dim, |
|
device=device, |
|
) |
|
left_context_val = torch.zeros( |
|
self.num_layers, |
|
self.left_context, |
|
batch_size, |
|
self.memory_dim, |
|
device=device, |
|
) |
|
past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) |
|
|
|
return [empty_memory, left_context_key, left_context_val, past_length] |
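    # State layout, one entry per element returned by init_state():
    #   state[0]: memory bank         (num_layers, max_memory_size, B, D)
    #   state[1]: left-context keys   (num_layers, left_context, B, D)
    #   state[2]: left-context values (num_layers, left_context, B, D)
    #   state[3]: past segment count  (1, B), int32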
|
|
|
@torch.jit.export |
|
def batch_state(self, states: List[List[Tensor]]) -> List[Tensor]: |
|
if len(states) == 0: |
|
return [] |
|
batched_m = [] |
|
batched_lc_key = [] |
|
batched_lc_val = [] |
|
batched_past_length = [] |
|
for state in states: |
|
if len(state) == 0: |
|
continue |
|
m, lc_key, lc_val, past_length = state |
|
batched_m.append(m) |
|
batched_lc_key.append(lc_key) |
|
batched_lc_val.append(lc_val) |
|
batched_past_length.append(past_length) |
|
|
|
if ( |
|
(len(batched_m) == 0) |
|
or (len(batched_lc_key) == 0) |
|
or (len(batched_lc_val) == 0) |
|
or (len(batched_past_length) == 0) |
|
): |
|
return [ |
|
torch.tensor([]), |
|
torch.tensor([]), |
|
torch.tensor([]), |
|
torch.tensor([]), |
|
] |
|
|
|
batched_m = torch.cat(batched_m, dim=2) |
|
batched_lc_key = torch.cat(batched_lc_key, dim=2) |
|
batched_lc_val = torch.cat(batched_lc_val, dim=2) |
|
batched_past_length = torch.cat(batched_past_length, dim=1) |
|
return [batched_m, batched_lc_key, batched_lc_val, batched_past_length] |
|
|
|
@torch.jit.export |
|
def reorder_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: |
|
if len(state) == 0: |
|
return [] |
|
m, lc_key, lc_val, past_length = state |
|
indices = indices.to(device=m.device) |
|
reord_m = torch.index_select(m, 2, indices) |
|
reord_lc_key = torch.index_select(lc_key, 2, indices) |
|
reord_lc_val = torch.index_select(lc_val, 2, indices) |
|
reord_past_length = torch.index_select(past_length, 1, indices) |
|
return [reord_m, reord_lc_key, reord_lc_val, reord_past_length] |
|
|
|
@torch.jit.export |
|
def reset_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: |
|
m, lc_key, lc_val, past_length = state |
|
m = m.index_fill(dim=2, index=indices, value=0.0) |
|
lc_key = lc_key.index_fill(dim=2, index=indices, value=0.0) |
|
lc_val = lc_val.index_fill(dim=2, index=indices, value=0.0) |
|
past_length = past_length.index_fill(dim=1, index=indices, value=0) |
|
|
|
return [m, lc_key, lc_val, past_length] |
|
|
|
@torch.jit.export |
|
def state_size(self) -> int: |
|
return 4 |
|
|
|
@torch.jit.export |
|
def batch_size_in_state( |
|
self, state: Optional[List[Tensor]], sloppy: bool = True |
|
) -> Optional[int]: |
|
if state is None: |
|
return None |
|
return state[0].size(2) |
|
|
|
def gen_summary_queries(self, input): |
|
sum_input = self.memory_op(input) |
|
return sum_input |
|
|
|
def _gen_right_context_padded_input(self, input): |
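        # Concatenate the right-context window of every segment (the last
        # segment reuses the final self.right_context frames of the input) into
        # one block; the attention layers prepend this block to the utterance.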
|
|
|
|
|
right_context_blocks = [] |
|
T, B, D = input.shape |
|
num_segs = math.ceil((T - self.right_context) / self.segment_size) |
|
for i in range(0, num_segs - 1): |
|
st = (i + 1) * self.segment_size |
|
ed = st + self.right_context |
|
assert ed < T |
|
temp = input[st:ed, :, :] |
|
right_context_blocks.append(temp) |
|
|
|
|
|
right_context_blocks.append(input[T - self.right_context :, :, :]) |
|
return torch.cat(right_context_blocks, dim=0) |
|
|
|
def _gen_segs_right_context(self, input, lengths): |
|
segments = [] |
|
T, B, D = input.size() |
|
nT = T - self.right_context |
|
|
|
|
|
num_segs = math.ceil(nT / self.segment_size) |
|
|
|
|
|
        rest_lengths = lengths
        for i in range(0, num_segs - 1):
            st = i * self.segment_size
            ed = min(T, st + self.segment_size + self.right_context)
            temp = input[st:ed, :, :]
            rest_lengths = torch.clamp(
                lengths - self.segment_size, min=0, max=nT - (i + 1) * self.segment_size
            )
            segments.append((temp, lengths - rest_lengths + self.right_context))
            lengths = rest_lengths

        last_seg = input[(num_segs - 1) * self.segment_size :, :, :]
        segments.append((last_seg, rest_lengths + self.right_context))
|
|
|
return segments |
|
|
|
@torch.jit.unused |
|
def forward( |
|
self, input: Tensor, padding_masks: Tensor, state: Optional[List[Tensor]] = None |
|
) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: |
|
|
|
lengths = (~padding_masks).sum(dim=1).long() |
|
|
|
if self.mini_batches: |
|
return self.forward_mini_batches(input, lengths, state) |
|
|
|
|
|
|
|
T, B, D = input.size() |
|
right_context_blocks = self._gen_right_context_padded_input(input) |
|
|
|
|
|
if self.use_rpe: |
|
rpe = self._get_relative_position( |
|
input=input, |
|
max_relative_position=self.max_relative_position, |
|
left_context_length=0, |
|
past_length=0, |
|
is_decoding=False, |
|
) |
|
else: |
|
rpe = None |
|
input = input[: T - self.right_context, :, :] |
|
|
|
attention_mask = self._get_attention_mask(input) |
|
|
|
|
|
|
|
if self.use_mem: |
|
mems = self.gen_summary_queries(input)[:-1, :, :] |
|
else: |
|
mems = torch.zeros(0, input.size(1), input.size(2), device=input.device) |
|
mems = mems.type_as(input) |
|
|
|
output = input |
|
all_outputs = [] |
|
|
|
for layer in self.layers: |
|
output, mems, right_context_blocks, _, _ = layer( |
|
input=output, |
|
lengths=lengths, |
|
attention_mask=attention_mask, |
|
mems=mems, |
|
right_context_blocks=right_context_blocks, |
|
pre_mems=None, |
|
left_context_key=None, |
|
left_context_val=None, |
|
rpe=rpe, |
|
) |
|
all_outputs.append(output) |
|
return output, padding_masks, [], all_outputs |
|
|
|
def forward_jit_mini_batch_init( |
|
self, |
|
seg: Tensor, |
|
state: Optional[List[Tensor]] = None, |
|
is_decoding: bool = False, |
|
): |
|
|
|
|
|
if state is None: |
|
state = self.init_state(batch_size=seg.size(1), device=seg.device) |
|
if seg.dtype == torch.half: |
|
state = [state[0].half(), state[1].half(), state[2].half(), state[3]] |
|
|
|
if self.use_mem: |
|
|
|
|
|
|
|
full_mems = self.gen_summary_queries(seg) |
|
if is_decoding: |
|
mems = full_mems[0:1, :, :] |
|
state_mems = torch.cat([state[0][0], mems], dim=0) |
|
else: |
|
mems = full_mems[:-1, :, :] |
|
state_mems = torch.cat([state[0][0], full_mems], dim=0) |
|
else: |
|
mems = state[0][0] |
|
state_mems = mems |
|
|
|
|
|
|
|
past_length = state[3][0][0].item() |
|
past_left_context = min(past_length * self.segment_size, self.left_context) |
|
past_length = min(self.max_memory_size, past_length) |
|
|
|
return state, mems, state_mems, past_length, past_left_context |
|
|
|
def state_update_before( |
|
self, layer: int, state: List[Tensor], past_length: int, past_left_context: int |
|
): |
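        # Slice off the zero-padded prefix of the cached memory and left
        # context so only the valid past_length / past_left_context entries are
        # handed to the attention layer.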
|
pre_mems = state[0][layer][self.max_memory_size - past_length :, :, :] |
|
lc_key = state[1][layer][self.left_context - past_left_context :, :, :] |
|
lc_val = state[2][layer][self.left_context - past_left_context :, :, :] |
|
return pre_mems, lc_key, lc_val |
|
|
|
def state_update_after( |
|
self, |
|
layer: int, |
|
state: List[Tensor], |
|
mems: Tensor, |
|
next_key: Tensor, |
|
next_val: Tensor, |
|
mems_list: List[Tensor], |
|
lc_key_list: List[Tensor], |
|
lc_val_list: List[Tensor], |
|
): |
|
|
|
if layer < self.num_layers - 1: |
|
state_mems = torch.cat([state[0][layer + 1], mems], dim=0) |
|
mems_list.append(state_mems[-self.max_memory_size :, :, :]) |
|
|
|
|
|
|
|
mems = mems[:-1, :, :] |
|
|
|
|
|
new_k = torch.cat([state[1][layer], next_key], dim=0) |
|
new_v = torch.cat([state[2][layer], next_val], dim=0) |
|
lc_key_list.append(new_k[-self.left_context :, :, :]) |
|
lc_val_list.append(new_v[-self.left_context :, :, :]) |
|
return mems_list, lc_key_list, lc_val_list, mems |
|
|
|
def state_update_after_loop( |
|
self, |
|
state: List[Tensor], |
|
mems_list: List[Tensor], |
|
lc_key_list: List[Tensor], |
|
lc_val_list: List[Tensor], |
|
update_length: int, |
|
): |
|
state[0] = torch.stack(mems_list, dim=0) |
|
state[1] = torch.stack(lc_key_list, dim=0) |
|
state[2] = torch.stack(lc_val_list, dim=0) |
|
state[3] = state[3] + update_length |
|
return state |
|
|
|
@torch.jit.unused |
|
def forward_mini_batches( |
|
self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None |
|
) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: |
|
T, B, D = input.size() |
|
|
|
|
|
seg = input[: T - self.right_context, :, :] |
|
|
|
|
|
right_context_blocks = self._gen_right_context_padded_input(input) |
|
|
|
mems_list = [] |
|
lc_key_list = [] |
|
lc_val_list = [] |
|
results = self.forward_jit_mini_batch_init(seg, state, False) |
|
state, mems, state_mems, past_length, past_left_context = results |
|
|
|
|
|
if self.use_rpe: |
|
rpe = self._get_relative_position( |
|
input=input, |
|
max_relative_position=self.max_relative_position, |
|
left_context_length=past_left_context, |
|
past_length=past_length, |
|
is_decoding=False, |
|
) |
|
else: |
|
rpe = None |
|
|
|
|
|
|
|
attention_mask = self._get_attention_mask(seg, past_length, past_left_context) |
|
mems_list.append(state_mems[-self.max_memory_size :, :, :]) |
|
output = seg |
|
i = 0 |
|
all_outputs = [] |
|
for layer in self.layers: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pre_mems, lc_key, lc_val = self.state_update_before( |
|
i, state, past_length, past_left_context |
|
) |
|
|
|
output, mems, right_context_blocks, next_key, next_val = layer.forward( |
|
input=output, |
|
lengths=lengths, |
|
attention_mask=attention_mask, |
|
mems=mems, |
|
right_context_blocks=right_context_blocks, |
|
pre_mems=pre_mems, |
|
left_context_key=lc_key, |
|
left_context_val=lc_val, |
|
rpe=rpe, |
|
) |
|
all_outputs.append(output) |
|
mems_list, lc_key_list, lc_val_list, mems = self.state_update_after( |
|
layer=i, |
|
state=state, |
|
mems=mems, |
|
next_key=next_key, |
|
next_val=next_val, |
|
mems_list=mems_list, |
|
lc_key_list=lc_key_list, |
|
lc_val_list=lc_val_list, |
|
) |
|
|
|
i += 1 |
|
|
|
|
|
update_length = math.ceil((T - self.right_context) / self.segment_size) |
|
state = self.state_update_after_loop( |
|
state=state, |
|
mems_list=mems_list, |
|
lc_key_list=lc_key_list, |
|
lc_val_list=lc_val_list, |
|
update_length=update_length, |
|
) |
|
|
|
return output, lengths, state, all_outputs |
|
|
|
def forward_jit_test( |
|
self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None |
|
) -> Tuple[Tensor, Tensor, List[Tensor]]: |
|
""" |
|
        This one simulates the sequence encoder forward_jit. It is for unit test purposes.
|
It is not used in training or decoding. Note, extra_right_context is set in |
|
the model. In unit test, input = [utterance, right_context], lengths = |
|
[utterance_length]. |
|
args: |
|
input: input utterance |
|
lengths: utterance input length |
|
state: None here. input is whole utterance |
|
""" |
|
|
|
seg_src_tokens_lengths = self._gen_segs_right_context(input, lengths) |
|
|
|
seg_enc_tokens_lengths: List[Tuple[Tensor, Tensor]] = [] |
|
state: Optional[List[Tensor]] = None |
|
for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: |
|
seg_enc_tokens, seg_enc_lengths, state = self.forward_jit( |
|
input=seg_src_tokens, lengths=seg_src_lengths, state=state |
|
) |
|
seg_enc_tokens_lengths.append((seg_enc_tokens, seg_enc_lengths)) |
|
|
|
enc_tokens, enc_lengths = segments_to_sequence( |
|
segments=seg_enc_tokens_lengths, time_axis=0 |
|
) |
|
|
|
state = [] |
|
|
|
return enc_tokens, enc_lengths, state |
|
|
|
@torch.jit.export |
|
def forward_jit( |
|
self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None |
|
) -> Tuple[Tensor, Tensor, List[Tensor]]: |
|
""" |
|
Forward helper for online decoding. |
|
|
|
args: |
|
            input: [seg, right_context]. We assume that in online decoding we
                always pad the right context to the preset right context size.
                For the last segment, the segment size may be shorter, but the
                right context size is the same as for the other segments.
            lengths: input length, which is the utterance segment length plus
                the right context size
|
state: [memory, left_context_key, left_context_val]. To improve throughput, |
|
in addition to memory, we also cache key and value for left_context in |
|
multihead self-attention |
|
""" |
|
|
|
|
|
|
|
T, B, D = input.size() |
|
rc_str = T - self.right_context |
|
rc_end = T |
|
right_context_blocks = input[rc_str:rc_end, :, :] |
|
seg = input[:rc_str, :, :] |
|
lengths = torch.clamp(lengths - self.right_context, min=0) |
|
mems_list = [] |
|
lc_key_list = [] |
|
lc_val_list = [] |
|
|
|
results = self.forward_jit_mini_batch_init(seg, state, True) |
|
state, mems, state_mems, past_length, past_left_context = results |
|
|
|
|
|
if self.use_rpe: |
|
rpe = self._get_relative_position( |
|
input=input, |
|
max_relative_position=self.max_relative_position, |
|
left_context_length=past_left_context, |
|
past_length=past_length, |
|
is_decoding=True, |
|
) |
|
else: |
|
rpe = None |
|
|
|
|
|
mems_list.append(state_mems[-self.max_memory_size :, :, :]) |
|
output = seg |
|
i = 0 |
|
for layer in self.layers: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
true_mems, lc_key, lc_val = self.state_update_before( |
|
layer=i, |
|
state=state, |
|
past_length=past_length, |
|
past_left_context=past_left_context, |
|
) |
|
|
|
output, mems, right_context_blocks, next_key, next_val = layer.forward_jit( |
|
input=output, |
|
lengths=lengths, |
|
mems=true_mems, |
|
right_context_blocks=right_context_blocks, |
|
left_context_key=lc_key, |
|
left_context_val=lc_val, |
|
rpe=rpe, |
|
) |
|
|
|
mems_list, lc_key_list, lc_val_list, _ = self.state_update_after( |
|
layer=i, |
|
state=state, |
|
mems_list=mems_list, |
|
mems=mems, |
|
next_key=next_key, |
|
next_val=next_val, |
|
lc_key_list=lc_key_list, |
|
lc_val_list=lc_val_list, |
|
) |
|
i += 1 |
|
|
|
|
|
state = self.state_update_after_loop( |
|
state=state, |
|
mems_list=mems_list, |
|
lc_key_list=lc_key_list, |
|
lc_val_list=lc_val_list, |
|
update_length=1, |
|
) |
|
|
|
return output, lengths, state |
|
|
|
def quantize_(self, params=None): |
|
if params and "per_channel" in params and params["per_channel"]: |
|
qconfig = per_channel_dynamic_qconfig |
|
else: |
|
qconfig = default_dynamic_qconfig |
|
torch.quantization.quantize_dynamic( |
|
self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True |
|
) |
|
return self |
|
|
|
|
|
|
|
|
|
|
|
|
|
def emformer_encoder(klass): |
|
class SpeechEncoder(klass): |
|
def __init__(self, args): |
|
super().__init__(args) |
|
stride = SpeechEncoder.conv_layer_stride(args) |
|
trf_left_context = args.segment_left_context // stride |
|
trf_right_context = args.segment_right_context // stride |
|
context_config = [trf_left_context, trf_right_context] |
|
self.transformer_layers = nn.ModuleList( |
|
[ |
|
NoSegAugmentedMemoryTransformerEncoderLayer( |
|
input_dim=args.encoder_embed_dim, |
|
num_heads=args.encoder_attention_heads, |
|
ffn_dim=args.encoder_ffn_embed_dim, |
|
num_layers=args.encoder_layers, |
|
dropout_in_attn=args.dropout, |
|
dropout_on_attn=args.dropout, |
|
dropout_on_fc1=args.dropout, |
|
dropout_on_fc2=args.dropout, |
|
activation_fn=args.activation_fn, |
|
context_config=context_config, |
|
segment_size=args.segment_length, |
|
max_memory_size=args.max_memory_size, |
|
scaled_init=True, |
|
tanh_on_mem=args.amtrf_tanh_on_mem, |
|
) |
|
] |
|
) |
|
|
|
def forward(self, src_tokens, src_lengths): |
|
encoder_out = super().forward(src_tokens, src_lengths) |
|
output = encoder_out["encoder_out"][0] |
|
encoder_padding_masks = encoder_out["encoder_padding_mask"][0] |
|
|
|
|
|
|
|
encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] |
|
|
|
return { |
|
"encoder_out": [output], |
|
"encoder_padding_mask": [encoder_padding_masks], |
|
"encoder_embedding": [], |
|
"encoder_states": [], |
|
"src_tokens": [], |
|
"src_lengths": [], |
|
} |
|
|
|
@staticmethod |
|
def conv_layer_stride(args): |
|
|
|
return 4 |
|
|
|
SpeechEncoder.__name__ = klass.__name__ |
|
return SpeechEncoder |
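# A hedged usage sketch for emformer_encoder (the wrapped class below is an
# assumption, not something defined in this file): it is meant to decorate an
# existing fairseq speech encoder class, replacing its transformer layers with
# the augmented-memory version, e.g.
#
#   @emformer_encoder
#   class MySpeechEncoder(ConvTransformerEncoder):   # hypothetical base class
#       pass
#
# args must provide segment_length, segment_left_context, segment_right_context,
# max_memory_size and amtrf_tanh_on_mem in addition to the usual encoder args.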
|
|