Spaces:
Running
Running
File size: 4,379 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Originally ported from: https://github.com/neonbjb/tortoise-tts
import math
import torch
from torch import nn
from torch.nn import functional as F
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and returns the caller's dtype."""

    def forward(self, x):
        """Normalize ``x`` in float32, then cast the result back to ``x.dtype``.

        :param x: input tensor accepted by ``nn.GroupNorm``.
        :return: the normalized tensor, with the same dtype as ``x``.
        """
        input_dtype = x.dtype
        # Upcast first so the statistics are computed in full precision.
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
def conv_nd(dims, *args, **kwargs):
    """Build a 1-D, 2-D or 3-D convolution layer.

    :param dims: dimensionality of the convolution (1, 2 or 3).
    :param args: positional arguments forwarded to the layer constructor.
    :param kwargs: keyword arguments forwarded to the layer constructor.
    :return: an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` instance.
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    conv_classes = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in conv_classes:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def normalization(channels):
    """Create a ``GroupNorm32`` layer with a group count that fits ``channels``.

    The starting group count is 8 for up to 16 channels, 16 for up to 64
    channels and 32 otherwise; it is then halved until it divides
    ``channels`` evenly.

    :param channels: number of input channels.
    :return: a ``GroupNorm32`` module for ``channels`` channels.
    :raises AssertionError: if the resulting group count is 2 or less.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32

    # Halve until the group count divides the channel count.
    while channels % groups:
        groups //= 2

    assert groups > 2
    return GroupNorm32(groups, channels)
def zero_module(module):
    """Set every parameter of ``module`` to zero in place.

    :param module: the module whose parameters are zeroed.
    :return: the same ``module`` object, for call chaining.
    """
    for param in module.parameters():
        # Work on a detached view so the in-place write is not tracked by autograd.
        detached = param.detach()
        detached.zero_()
    return module
class QKVAttention(nn.Module):
    """Multi-head dot-product attention over a packed QKV tensor."""

    def __init__(self, n_heads):
        """
        :param n_heads: number of attention heads packed into the channel axis.
        """
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, qk_bias=0):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional [N x T x T] (or [T x T], shared by the whole
                     batch) tensor. Positions where the mask is zero/False
                     are set to -inf before the softmax.
        :param qk_bias: value added to the attention logits before masking.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Heads are folded into the batch axis; row ``n * H + h`` holds head
        # ``h`` of sample ``n``.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = weight + qk_bias
        if mask is not None:
            if mask.dim() == 2:
                # A single [T x T] mask applies to every sample.
                mask = mask.unsqueeze(0).expand(bs, -1, -1)
            # repeat_interleave keeps each sample's mask next to that sample's
            # heads, matching the batch-major layout produced by the reshape
            # above. Tiling with ``repeat`` would hand heads another sample's
            # mask whenever the masks differ across the batch.
            mask = mask.repeat_interleave(self.n_heads, dim=0)
            weight[mask.logical_not()] = -torch.inf
        # Softmax in float32, then return to the working dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
class AttentionBlock(nn.Module):
    """An attention block that allows spatial positions to attend to each other."""

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        out_channels=None,
        do_activation=False,
    ):
        """
        :param channels: number of input channels.
        :param num_heads: number of attention heads; used only when
                          ``num_head_channels`` is -1.
        :param num_head_channels: channels per head; when not -1 the head
                                  count is ``channels // num_head_channels``.
        :param out_channels: number of output channels; defaults to ``channels``.
        :param do_activation: apply SiLU to the normalized input before the
                              QKV projection.
        """
        super().__init__()
        self.channels = channels
        out_channels = channels if out_channels is None else out_channels
        self.do_activation = do_activation
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, out_channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)
        # The skip path needs a projection only when the channel count changes.
        self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
        # Output projection starts at zero, so the attention branch adds
        # nothing to the skip path at initialization.
        self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))

    def forward(self, x, mask=None, qk_bias=0):
        """
        Run self-attention over the flattened spatial positions of ``x``.

        :param x: an [N x C x ...] tensor; trailing axes are flattened into
                  one sequence axis.
        :param mask: optional [T x T] or [N x T x T] attention mask. A mask
                     larger than the flattened sequence length is cropped to
                     its leading square.
        :param qk_bias: value added to the attention logits.
        :return: an [N x out_channels x ...] tensor with the input's spatial shape.
        """
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        # Attention runs over the flattened positions, so the mask is sized
        # against this length rather than the last spatial axis alone.
        seq_len = x.shape[-1]
        if mask is not None:
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0).repeat(b, 1, 1)
            if mask.shape[1] != seq_len:
                mask = mask[:, :seq_len, :seq_len]
        x = self.norm(x)
        if self.do_activation:
            x = F.silu(x, inplace=True)
        qkv = self.qkv(x)
        h = self.attention(qkv, mask=mask, qk_bias=qk_bias)
        h = self.proj_out(h)
        # The skip connection uses the normalized (and optionally activated)
        # input, because ``x`` was rebound above.
        xp = self.x_proj(x)
        return (xp + h).reshape(b, xp.shape[1], *spatial)
class ConditioningEncoder(nn.Module):
    """A 1x1 convolution followed by a stack of attention blocks."""

    def __init__(
        self,
        spec_dim,
        embedding_dim,
        attn_blocks=6,
        num_attn_heads=4,
    ):
        """
        :param spec_dim: number of input channels.
        :param embedding_dim: number of channels after the input projection.
        :param attn_blocks: number of ``AttentionBlock`` layers in the stack.
        :param num_attn_heads: number of heads in each attention block.
        """
        super().__init__()
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        blocks = [AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]
        self.attn = nn.Sequential(*blocks)
        self.dim = embedding_dim

    def forward(self, x):
        """
        Project the input and run it through the attention stack.

        :param x: a (b, spec_dim, s) tensor.
        :return: a (b, embedding_dim, s) tensor.
        """
        projected = self.init(x)
        return self.attn(projected)
|