File size: 338 Bytes
b1e1a76 52e32c0 b1e1a76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
import torch.nn as nn
# from icefall.utils import make_pad_mask
from .symbol_table import SymbolTable
# make_pad_mask = make_pad_mask
# Re-export ``SymbolTable`` (imported from ``.symbol_table`` above) as part of
# this module's namespace. The self-assignment is presumably there so linters
# do not flag the import as unused.
SymbolTable = SymbolTable
class Transpose(nn.Identity):
    """Swap axes 1 and 2 of the input, e.g. ``(N, T, D) -> (N, D, T)``.

    Subclasses ``nn.Identity`` so that construction accepts (and ignores) any
    positional/keyword arguments, and the module has no parameters.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return a view of ``input`` with dimensions 1 and 2 exchanged.

        Args:
            input: tensor with at least three dimensions.

        Returns:
            ``input`` with axes 1 and 2 swapped (a view, no data copy).
        """
        swapped = torch.transpose(input, 1, 2)
        return swapped
|