Duplicated from Plachta/VALL-E-X (commit b1e1a76).
import torch
import torch.nn as nn


class Transpose(nn.Identity):
    """Module that swaps dimensions 1 and 2 of its input: (N, T, D) -> (N, D, T).

    It subclasses ``nn.Identity``, presumably so it can stand in wherever an
    identity/no-op layer is expected (NOTE(review): confirm intent with callers).
    No ``__init__`` override is defined, so construction is inherited unchanged.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` with its dimensions 1 and 2 exchanged.

        Args:
            input: tensor with at least 3 dimensions (dims 1 and 2 must exist).
                The parameter name ``input`` is kept for keyword callers.

        Returns:
            The tensor produced by ``torch.transpose(input, 1, 2)``. Any leading
            dim 0 and trailing dims beyond 2 keep their positions.
        """
        swapped = torch.transpose(input, 1, 2)
        return swapped