import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor

from .mhsa_pro import MHA_rotary, MHA_decoder
from .cnn import ConvBlock, ConvBlockDecoder


class ResidualConnectionModule(nn.Module):
    """
    Residual connection module.
    outputs = (module(inputs) * module_factor) + (inputs * input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        # Both factors are fixed to 1; `dims` and `args` are unused here and kept only so the
        # constructor matches the signature of the other residual wrappers (e.g. PostNorm).
        self.module_factor = 1
        self.input_factor = 1

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return (self.module(inputs, **kwargs) * self.module_factor) + (inputs * self.input_factor)


class PostNorm(nn.Module):
    """
    Post-norm residual connection.
    outputs = LayerNorm(module(inputs) + inputs * input_factor)
    where input_factor defaults to 1 and can be overridden via args.alpha.
    """
    def __init__(self, module: nn.Module, dims, args):
        super(PostNorm, self).__init__()
        self.module = module
        alpha = getattr(args, 'alpha', None)
        input_factor = torch.tensor(alpha, dtype=torch.float32) if alpha is not None else torch.tensor(1.0)
        self.register_buffer('input_factor', input_factor)
        self.norm = nn.LayerNorm(dims)

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return self.norm(self.module(inputs, **kwargs) + (inputs * self.input_factor))


class Linear(nn.Module):
    """
    Wrapper around torch.nn.Linear.
    Weights are initialized with Xavier uniform initialization and the bias is initialized to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class View(nn.Module):
    """ Wrapper class of torch.Tensor.view() for use inside nn.Sequential. """
    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: Tensor) -> Tensor:
        if self.contiguous:
            x = x.contiguous()
        return x.view(*self.shape)


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for use inside nn.Sequential. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)


class FeedForwardModule(nn.Module):
    """
    Conformer feed-forward module. It follows pre-norm residual units: layer normalization is applied
    to the input before the first linear layer. The module also applies the Swish (SiLU) activation and
    dropout, which help regularize the network.

    Args:
        args: Configuration namespace providing
            encoder_dim (int): Dimension of the conformer encoder
            dropout_p (float): Dropout probability

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input sequences

    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the feed-forward module
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        expansion_factor = 4
        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Linear(args.encoder_dim, args.encoder_dim * expansion_factor, bias=True),
            nn.SiLU(),
            nn.Dropout(p=args.dropout_p),
            Linear(args.encoder_dim * expansion_factor, args.encoder_dim, bias=True),
            nn.Dropout(p=args.dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs)
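
# Shape sketch for FeedForwardModule (an illustration, not part of the original code; it assumes
# encoder_dim = 128 together with the hard-coded expansion_factor = 4):
#   (batch, time, 128) --LayerNorm--> (batch, time, 128)
#                      --Linear-----> (batch, time, 512) --SiLU/Dropout-->
#                      --Linear-----> (batch, time, 128) --Dropout--> output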


class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vector

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
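
# Parameter-count illustration (not part of the original code): with in_channels = out_channels = 128
# and kernel_size = 31, a dense nn.Conv1d holds 128 * 128 * 31 ≈ 508k weights, whereas the depthwise
# version above (groups=128) holds only 128 * 1 * 31 ≈ 4k, since each channel is filtered independently.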


class PointwiseConv1d(nn.Module):
    """
    A conv1d with kernel_size == 1 is termed in the literature as pointwise convolution.
    This operation is often used to match dimensions.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vector

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    """
    The conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batch normalization is applied just
    after the depthwise convolution to aid the training of deep models.

    Args:
        args: Configuration namespace providing
            encoder_dim (int): Number of channels in the input
            kernel_size (int): Size of the depthwise convolving kernel (must be odd)

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input sequences

    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer convolution module
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (args.kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        expansion_factor = 2
        dropout_p = 0.1
        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim * expansion_factor, stride=1, padding=0, bias=True),
            nn.GLU(dim=1),
            DepthwiseConv1d(args.encoder_dim, args.encoder_dim, args.kernel_size, stride=1, padding=(args.kernel_size - 1) // 2),
            nn.BatchNorm1d(args.encoder_dim),
            nn.SiLU(),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs).transpose(1, 2)
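
# Shape sketch for ConformerConvModule (an illustration, not part of the original code; D = encoder_dim,
# T = time, B = batch):
#   (B, T, D) --LayerNorm/Transpose-----> (B, D, T)
#             --PointwiseConv------------> (B, 2D, T) --GLU(dim=1)--> (B, D, T)
#             --DepthwiseConv/BN/SiLU----> (B, D, T)
#             --PointwiseConv/Dropout----> (B, D, T) --transpose in forward()--> (B, T, D)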


class PositionalEncoding(nn.Module):
    """
    Positional encoding proposed in "Attention Is All You Need".
    Since the transformer contains no recurrence and no convolution, some positional information must be
    injected for the model to make use of the order of the sequence.
    "Attention Is All You Need" uses sine and cosine functions of different frequencies:
        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    """
    def __init__(self, d_model: int = 128, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]
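
# Usage sketch (an illustration, not part of the original code):
#   pos_enc = PositionalEncoding(d_model=128)
#   pe = pos_enc(50)   # (1, 50, 128); repeat over the batch dimension before feeding it to attention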


class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".

    Args:
        encoder_dim (int): The dimension of the model
        num_heads (int): The number of attention heads
        dropout_p (float): Dropout probability

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing the query vectors
        - **key** (batch, time, dim): Tensor containing the key vectors
        - **value** (batch, time, dim): Tensor containing the value vectors
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Boolean tensor marking positions to be masked

    Returns:
        - **outputs**: Tensor produced by the relative multi-head attention module
    """
    def __init__(
            self,
            encoder_dim: int = 128,
            num_heads: int = 8,
            dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert encoder_dim % num_heads == 0, "encoder_dim % num_heads should be zero."
        self.d_model = encoder_dim
        self.d_head = int(encoder_dim / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(encoder_dim)
        self.query_proj = Linear(encoder_dim, encoder_dim)
        self.key_proj = Linear(encoder_dim, encoder_dim)
        self.value_proj = Linear(encoder_dim, encoder_dim)
        self.pos_proj = Linear(encoder_dim, encoder_dim, bias=False)
        self.dropout = nn.Dropout(p=dropout_p)
        # Learned content (u) and position (v) biases from Transformer-XL.
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)
        self.out_proj = Linear(encoder_dim, encoder_dim)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Variant without the learned biases:
        # content_score = torch.matmul(query.transpose(1, 2), key.transpose(2, 3))
        # pos_score = torch.matmul(query.transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Q (B, num_heads, length, d_head) x PE (B, num_heads, d_head, length) -> pos_score (B, num_heads, length, length)
        pos_score = self._relative_shift(pos_score)
        score = (content_score + pos_score) / self.sqrt_dim
        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)
        score = F.softmax(score, -1)
        attn = self.dropout(score)
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(context)

    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        # Shift the position scores along the length axis; see the note below this class.
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        return pos_score
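
# Note on RelativeMultiHeadAttention._relative_shift (explanatory, not part of the original code):
# this is the standard Transformer-XL shift trick. Padding a zero column on the left, viewing the
# (length, length + 1) matrix as (length + 1, length), and dropping the first row shifts each row of the
# position-score matrix by a different offset, so entries line up by relative distance between query and
# key rather than by the absolute index of the positional embedding. It costs only a reshape instead of
# building an explicit (length x length) shift matrix.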


class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from
    Transformer-XL: the relative sinusoidal positional encoding scheme. Relative positional encoding allows
    the self-attention module to generalize better to different input lengths, and the resulting encoder is
    more robust to variance in utterance length. Conformer uses pre-norm residual units with dropout, which
    helps training and regularizing deeper models.

    Args:
        args: Configuration namespace providing
            encoder_dim (int): The dimension of the model
            num_heads (int): The number of attention heads
            dropout_p (float): Dropout probability for the attention weights

    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing the input vectors
        - **mask** (batch, 1, time2) or (batch, time1, time2): Boolean tensor marking positions to be masked

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module
    """
    def __init__(self, args):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        dropout_p = 0.1
        self.positional_encoding = PositionalEncoding(args.encoder_dim)
        self.layer_norm = nn.LayerNorm(args.encoder_dim)
        self.attention = RelativeMultiHeadAttention(args.encoder_dim, args.num_heads, args.dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
        batch_size, seq_length, _ = inputs.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
        inputs = self.layer_norm(inputs)
        outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)
        return self.dropout(outputs)
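
# Mask convention sketch (an illustration, not part of the original code): the attention applies
# masked_fill_, so `mask` is expected to be a boolean tensor in which True marks positions to ignore,
# e.g. padding. Given a hypothetical LongTensor `lengths` of shape (batch,) with the valid lengths:
#   mask = (lengths.unsqueeze(1) <= torch.arange(seq_length)).unsqueeze(1)   # (batch, 1, time)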


class ConformerBlock(nn.Module):
    """
    A Conformer block contains two feed-forward modules sandwiching the multi-headed self-attention module
    and the convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
    the original feed-forward layer in the Transformer block with two half-step feed-forward layers,
    one before the attention layer and one after.

    Args:
        args: Configuration namespace providing, among others,
            encoder_dim (int): Dimension of the conformer encoder
            norm (str): Residual wrapper to use ('shortcut' or 'postnorm')
            encoder (list of str): Sub-module spec built from block_dict keys, e.g. ['ffn', 'mhsa_pro', 'conv', 'ffn']

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input vectors

    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer block
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerBlock, self).__init__()
        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm,
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'conv': ConvBlock,
            'conformerconv': ConformerConvModule,
        }
        self.modlist = nn.ModuleList(
            [norm_dict[args.norm](block_dict[block](args), args.encoder_dim, args) for block in args.encoder]
        )

    def forward(self, x: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask)
            else:
                x = m(x)
        return x


class DecoderBlock(nn.Module):
    """
    Decoder counterpart of the Conformer block. It follows the same sandwich structure, but may additionally
    include an MHA_decoder sub-module that attends over the encoder output (cross-attention on `memory`).

    Args:
        args: Configuration namespace providing, among others,
            decoder_dim (int): Dimension of the decoder
            norm (str): Residual wrapper to use ('shortcut' or 'postnorm')
            decoder (list of str): Sub-module spec built from block_dict keys, e.g. ['ffn', 'mhsa_decoder', 'ffn']

    Inputs: inputs, memory
        - **inputs** (batch, time, dim): Tensor containing the input vectors
        - **memory** (batch, time, dim): Encoder output attended to by MHA_decoder sub-modules

    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the decoder block
    """
    def __init__(
            self,
            args,
    ):
        super(DecoderBlock, self).__init__()
        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm,
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'mhsa_decoder': MHA_decoder,
            'conv': ConvBlockDecoder,
            'conformerconv': ConformerConvModule,
        }
        self.modlist = nn.ModuleList(
            [norm_dict[args.norm](block_dict[block](args), args.decoder_dim, args) for block in args.decoder]
        )

    def forward(self, x: Tensor, memory: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_decoder):
                x = m(x, memory=memory, RoPE=RoPE, key_padding_mask=key_padding_mask)
            elif isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask).transpose(0, 1)
            else:
                x = m(x)
        return x


class ConformerEncoder(nn.Module):
    """
    Conformer encoder: a stack of ConformerBlock modules applied sequentially to the input.

    Args:
        args: Configuration namespace providing, among others,
            num_layers (int): Number of conformer blocks

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input vectors

    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer encoder
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerEncoder, self).__init__()
        self.blocks = nn.ModuleList([ConformerBlock(args) for _ in range(args.num_layers)])

    def forward(self, x: Tensor, RoPE=None, key_padding_mask=None) -> Tensor:
        """
        Forward propagate the inputs through the encoder.

        Args:
            x (torch.FloatTensor): Input sequence, typically a padded FloatTensor of size
                ``(batch, seq_length, dimension)``.
            RoPE: Rotary positional embedding passed through to MHA_rotary sub-modules.
            key_padding_mask: Padding mask passed through to MHA_rotary sub-modules.

        Returns:
            torch.FloatTensor: Output sequence of size ``(batch, seq_length, dimension)``.
        """
        for block in self.blocks:
            x = block(x, RoPE=RoPE, key_padding_mask=key_padding_mask)
        return x


class ConformerDecoder(nn.Module):
    """
    Conformer decoder: a stack of DecoderBlock modules applied sequentially to the input,
    each able to attend over the encoder output (memory).

    Args:
        args: Configuration namespace providing, among others,
            num_decoder_layers (int): Number of decoder blocks

    Inputs: inputs, memory
        - **inputs** (batch, time, dim): Tensor containing the input vectors
        - **memory** (batch, time, dim): Encoder output

    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer decoder
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerDecoder, self).__init__()
        self.blocks = nn.ModuleList([DecoderBlock(args) for _ in range(args.num_decoder_layers)])

    def forward(self, x: Tensor, memory: Tensor, RoPE=None, key_padding_mask=None) -> Tensor:
        """
        Forward propagate the inputs through the decoder.

        Args:
            x (torch.FloatTensor): Input sequence, typically a padded FloatTensor of size
                ``(batch, seq_length, dimension)``.
            memory (torch.FloatTensor): Encoder output of size ``(batch, seq_length, dimension)``.
            RoPE: Rotary positional embedding passed through to the attention sub-modules.
            key_padding_mask: Padding mask passed through to the attention sub-modules.

        Returns:
            torch.FloatTensor: Output sequence of size ``(batch, seq_length, dimension)``.
        """
        for block in self.blocks:
            x = block(x, memory, RoPE=RoPE, key_padding_mask=key_padding_mask)
        return x
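

# Minimal smoke-test sketch (not part of the original module). It assumes `args` only needs the fields
# read above, uses the pure-PyTorch sub-modules ('ffn', 'mhsa', 'conformerconv') so that no RoPE input
# is required, and must be run from inside the package so the relative imports (.mhsa_pro, .cnn) resolve.
# All hyperparameter values below are illustrative choices, not values from the original repository.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        encoder_dim=128,
        num_heads=8,
        dropout_p=0.1,
        kernel_size=31,
        norm='postnorm',
        encoder=['ffn', 'mhsa', 'conformerconv', 'ffn'],
        num_layers=2,
    )
    encoder = ConformerEncoder(args)
    x = torch.randn(4, 50, args.encoder_dim)   # (batch, time, dim)
    out = encoder(x)
    print(out.shape)                           # expected: torch.Size([4, 50, 128])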