|
""" |
|
Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP |
|
|
|
https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/dit.py |
|
|
|
Code is under MIT License |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class FFN(nn.Module):
    """Masked feed-forward network built from two 1-D convolutions.

    The input is masked, expanded to ``filter_channels`` by the first
    convolution, passed through a tanh-approximated GELU and dropout, masked
    again, projected to ``out_channels`` by the second convolution and masked
    one final time.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
        filter_channels: channel count between the two convolutions.
        kernel_size: kernel width of both convolutions; each is padded with
            ``kernel_size // 2`` on both sides.
        p_dropout: dropout probability applied after the activation.
        gin_channels: stored on the instance; not used by this class itself.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()

        # Hyper-parameters are kept as attributes for introspection.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.filter_channels, self.kernel_size = filter_channels, kernel_size
        self.p_dropout, self.gin_channels = p_dropout, gin_channels

        pad = kernel_size // 2
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=pad)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=pad)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        """Run the network on ``x``, multiplying by ``x_mask`` before each conv and on output."""
        hidden = self.conv_1(x * x_mask)
        hidden = self.drop(self.act1(hidden))
        projected = self.conv_2(hidden * x_mask)
        return projected * x_mask
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over ``[batch, channels, time]`` tensors.

    Queries, keys and values come from separate 1x1 convolutions of the same
    input. Rotary positional embeddings are applied to the first
    ``int(k_channels * 0.5)`` features of every query/key head before
    ``F.scaled_dot_product_attention`` is evaluated, and a final 1x1
    convolution maps the result to ``out_channels``.

    Args:
        channels: input channel count; must be divisible by ``n_heads``.
        out_channels: channel count produced by the output projection.
        n_heads: number of attention heads.
        p_dropout: attention dropout probability (only used in training mode).
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout

        head_dim = channels // n_heads
        self.k_channels = head_dim

        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)

        # Only half of each head's features are rotated; the rest pass through.
        rotary_dim = head_dim * 0.5
        self.query_rotary_pe = RotaryPositionalEmbeddings(rotary_dim)
        self.key_rotary_pe = RotaryPositionalEmbeddings(rotary_dim)

        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        # Registered but not called in this class; attention dropout is passed
        # directly to scaled_dot_product_attention instead.
        self.drop = nn.Dropout(p_dropout)

        for projection in (self.conv_q, self.conv_k, self.conv_v):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x, attn_mask=None):
        """Self-attend over ``x`` and apply the output projection."""
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)
        attended = self.attention(q, k, v, mask=attn_mask)
        return self.conv_o(attended)

    def attention(self, query, key, value, mask=None):
        """Split into heads, apply rotary embeddings and attend.

        ``query`` is ``[b, d, t_t]``; ``key`` and ``value`` are ``[b, d, t_s]``.
        Returns a ``[b, d, t_t]`` tensor.
        """
        b, d, t_s = key.size()
        t_t = query.size(2)

        def split_heads(tensor, length):
            # [b, d, length] -> [b, n_heads, length, k_channels]
            return tensor.view(b, self.n_heads, self.k_channels, length).transpose(2, 3)

        query = split_heads(query, t_t)
        key = split_heads(key, t_s)
        value = split_heads(value, t_s)

        query = self.query_rotary_pe(query)
        key = self.key_rotary_pe(key)

        # NOTE(review): scaled_dot_product_attention treats boolean masks as
        # keep/drop and floating-point masks as an additive bias -- confirm
        # which dtype callers pass here.
        dropout_p = self.p_dropout if self.training else 0
        output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=dropout_p)

        # Merge heads back: [b, n_heads, t_t, k_channels] -> [b, d, t_t]
        return output.transpose(2, 3).contiguous().view(b, d, t_t)
|
|
|
|
|
|
|
class DiTConVBlock(nn.Module):
    """A DiT block with adaptive layer norm (adaLN) conditioning.

    The block works on ``hidden_channels + out_channels`` channels and holds
    two residual branches: multi-head self-attention and a convolutional
    feed-forward network. Each branch is preceded by an affine-free LayerNorm
    over the channel axis. When a conditioning vector is supplied it is mapped
    to per-channel shift/scale/gate triples for both branches.

    Within this class the modulation layers keep PyTorch's default
    initialisation (no explicit zero-init is performed here).

    Args:
        hidden_channels, out_channels: their sum is the working channel count.
        filter_channels: inner width of the feed-forward network.
        num_heads: number of attention heads.
        kernel_size: kernel width of the feed-forward convolutions.
        p_dropout: dropout probability for attention and feed-forward.
        gin_channels: width of the conditioning vector. When it already equals
            the working channel count, the first modulation layer is an
            ``nn.Identity`` instead of an ``nn.Linear``.
    """

    def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        channels = hidden_channels + out_channels

        self.norm1 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(channels, channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(channels, channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, channels) if gin_channels != channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(channels, 6 * channels, bias=True),
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel], or None to run without conditioning
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        # [batch_size, 1, time, time]: outer product of the frame mask with itself.
        # NOTE(review): the product keeps x_mask's dtype, and
        # F.scaled_dot_product_attention adds floating-point masks to the
        # scores instead of using them as a keep/drop mask -- confirm the
        # dtype callers pass.
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)

        if c is not None:
            # Six [batch_size, channel, 1] tensors: shift/scale/gate per branch.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)
            )
            attn_in = self.modulate(self._channel_norm(self.norm1, x), shift_msa, scale_msa)
            x = x + gate_msa * self.attn(attn_in, attn_mask) * x_mask

            mlp_in = self.modulate(self._channel_norm(self.norm2, x), shift_mlp, scale_mlp)
            x = x + gate_mlp * self.mlp(mlp_in, x_mask) * x_mask
        else:
            # Unconditioned path. Each residual is masked exactly like in the
            # conditioned path so padded frames stay zero, and the MLP branch
            # uses its own norm (norm2) rather than reusing norm1.
            x = x + self.attn(self._channel_norm(self.norm1, x), attn_mask) * x_mask
            x = x + self.mlp(self._channel_norm(self.norm2, x), x_mask) * x_mask
        return x

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm`` over the channel axis of a [batch, channel, time] tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    @staticmethod
    def modulate(x, shift, scale):
        """Affine modulation ``x * (1 + scale) + shift``."""
        return x * (1 + scale) + shift
|
|
|
|
|
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.

    Feature $i$ is paired with feature $i + \frac{d}{2}$. Only the first $d$
    features of the last axis are rotated; any remaining features are passed
    through unchanged.

    The $\cos$/$\sin$ tables are cached as plain attributes (not buffers), so
    `Module.to()` does not move them. They are rebuilt whenever a longer
    sequence arrives or the input lives on a different device.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ to rotate (truncated with `int()`);
          it should be even, because the cached table duplicates
          `arange(0, d, 2)` frequencies and the features are split at `d // 2`
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor) -> None:
        r"""
        Cache $\cos$ and $\sin$ values for `x` of shape `[seq_len, batch_size, n_heads, features]`
        """
        seq_len = x.shape[0]
        cached = self.cos_cached

        # Reuse the cache only when it lives on the input's device and is long
        # enough; a cache built on another device would otherwise make the
        # multiplication in `forward` fail with a device mismatch.
        if cached is not None and cached.device == x.device and seq_len <= cached.shape[0]:
            return

        # $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Outer product of positions and frequencies: `[seq_len, d / 2]`
        idx_theta = seq_idx[:, None] * theta[None, :]

        # Duplicate so that feature $i$ and $i + \frac{d}{2}$ share an angle: `[seq_len, d]`
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Add singleton batch/head axes for broadcasting: `[seq_len, 1, 1, d]`
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Return $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        along the last axis of a 4-D tensor
        """
        d_2 = self.d // 2
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        * `x` is the key or query tensor with shape `[batch_size, n_heads, seq_len, features]`

        It is permuted internally to `[seq_len, batch_size, n_heads, features]`
        and the result is permuted back, so the output has the same layout as the input.
        """
        x = x.permute(2, 0, 1, 3)
        seq_len = x.shape[0]

        self._build_cache(x)

        # Split into the rotated part and the pass-through part
        x_rope, x_pass = x[..., : self.d], x[..., self.d:]

        neg_half_x = self._neg_half(x_rope)

        # Rotate: $x \cos(m\theta) + \text{neg\_half}(x) \sin(m\theta)$
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)
|
|
|
|
|
class Transpose(nn.Identity):
    """Module wrapper that swaps dimensions 1 and 2: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        swapped = torch.transpose(input, 1, 2)
        return swapped
|
|