# VALL-E-X / utils/__init__.py
# Author: Plachta — "initial commit" (b1e1a76), 198 bytes.
import torch
import torch.nn as nn
class Transpose(nn.Identity):
    """Swap two dimensions of the input tensor: (N, T, D) -> (N, D, T) by default.

    Useful as a parameter-free layer inside ``nn.Sequential`` to move between
    channel-last and channel-first layouts (e.g. around ``nn.Conv1d``).

    Subclasses ``nn.Identity`` so the module stays parameter-free and remains
    an ``nn.Identity`` instance for any code that checks for it.

    Args:
        dim0: First dimension to swap. Defaults to 1.
        dim1: Second dimension to swap. Defaults to 2.
            Negative indices are accepted, as with ``torch.Tensor.transpose``.
    """

    def __init__(self, dim0: int = 1, dim1: int = 2) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` with ``dim0`` and ``dim1`` swapped.

        The result is a view sharing storage with ``input`` (see
        ``torch.Tensor.transpose``). It is generally non-contiguous, so call
        ``.contiguous()`` downstream if a consumer requires it.
        """
        return input.transpose(self.dim0, self.dim1)

    def extra_repr(self) -> str:
        # Shown in print(module), e.g. "Transpose(dim0=1, dim1=2)".
        return f"dim0={self.dim0}, dim1={self.dim1}"