OSUM / wenet /paraformer /subsampling.py
tomxxie
适配zeroGPU
568e264
raw
history blame
1.54 kB
from typing import Tuple, Union
import torch
from wenet.transformer.subsampling import BaseSubsampling
class IdentitySubsampling(BaseSubsampling):
""" Paraformer subsampling
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
super().__init__()
_, _ = idim, odim
self.right_context = 6
self.subsampling_rate = 6
self.pos_enc = pos_enc_class
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[torch.Tensor, int] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time
torch.Tensor: positional encoding
"""
# NOTE(Mddct): Paraformer starts from 1
if isinstance(offset, torch.Tensor):
offset = torch.add(offset, 1)
else:
offset = offset + 1
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return self.pos_enc.position_encoding(offset + 1, size)