# -*- coding: utf-8 -*-

# Copyright 2024 Yiwei Guo
# Derived mostly from fairseq (https://github.com/facebookresearch/fairseq)

"""Prompt Pre-net Modules."""

import math

import torch
import torch.nn as nn

from vec2wav2.models.fairseq_modules.fp32_group_norm import Fp32GroupNorm
from vec2wav2.models.fairseq_modules.layer_norm import Fp32LayerNorm
from vec2wav2.models.fairseq_modules.transpose_last import TransposeLast


def norm_block(is_layer_norm, dim, affine=True):
    """Return a normalization block computed in fp32.

    If ``is_layer_norm`` is True, the (B, C, T) input is transposed so that
    LayerNorm normalizes over the channel dimension and then transposed back;
    otherwise a single-group GroupNorm is applied directly.
    """
    if is_layer_norm:
        mod = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(dim, elementwise_affine=affine),
            TransposeLast(),
        )
    else:
        mod = Fp32GroupNorm(1, dim, affine=affine)
    return mod


class ZeroPad1d(nn.Module):
    """Zero-pad the last dimension by a fixed amount on the left and right."""

    def __init__(self, pad_left, pad_right):
        super().__init__()
        self.pad_left = pad_left
        self.pad_right = pad_right

    def forward(self, x):
        return nn.functional.pad(x, (self.pad_left, self.pad_right))


class ConvPromptPrenet(nn.Module):
    """Convolutional prompt pre-net: a stack of Conv1d blocks with optional scaled residual connections."""

    def __init__(
        self,
        conv_layers,
        embed,
        dropout,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        conv_bias,
        activation,
    ):
        """
        Args:
            conv_layers: list of (dim, kernel_size, stride, padding) tuples, one per block.
            embed: number of input channels.
            dropout: dropout probability applied after each convolution.
            skip_connections: if True, add a (projected) residual around each block.
            residual_scale: the residual sum is scaled by sqrt(residual_scale).
            non_affine_group_norm: if True, the group norm has no learnable affine parameters.
            conv_bias: whether the convolutions use a bias term.
            activation: activation module appended to each block (e.g. nn.ReLU()).
        """
        super().__init__()

        def block(n_in, n_out, k, stride, pad):
            # Conv1d -> Dropout -> fp32 GroupNorm -> activation
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, padding=pad),
                nn.Dropout(p=dropout),
                norm_block(False, n_out, affine=not non_affine_group_norm),
                activation,
            )

        in_d = embed
        self.conv_layers = nn.ModuleList()
        self.residual_proj = nn.ModuleList()
        for dim, k, stride, pad in conv_layers:
            if in_d != dim and skip_connections:
                # 1x1 convolution to match channel counts on the residual branch
                self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False))
            else:
                self.residual_proj.append(None)
            self.conv_layers.append(block(in_d, dim, k, stride, pad))
            in_d = dim
        self.conv_layers = nn.Sequential(*self.conv_layers)
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # x: (B, C, T). The residual addition assumes each block preserves the time length.
        for rproj, conv in zip(self.residual_proj, self.conv_layers):
            residual = x
            x = conv(x)
            if self.skip_connections:
                if rproj is not None:
                    residual = rproj(residual)
                x = (x + residual) * self.residual_scale
        return x
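

# ---------------------------------------------------------------------------
# Minimal usage sketch. The hyperparameters below (channel sizes, kernel
# widths, strides, paddings, dropout) are illustrative assumptions, NOT values
# taken from the vec2wav2 configs; they only demonstrate the expected
# (dim, kernel, stride, pad) format of ``conv_layers``.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    prenet = ConvPromptPrenet(
        conv_layers=[(256, 3, 1, 1), (256, 3, 1, 1), (512, 3, 1, 1)],  # hypothetical stack
        embed=80,                        # e.g. an 80-bin mel-spectrogram prompt
        dropout=0.1,
        skip_connections=True,
        residual_scale=0.5,
        non_affine_group_norm=False,
        conv_bias=True,
        activation=nn.ReLU(),
    )
    prompt = torch.randn(2, 80, 100)     # (batch, channels, frames)
    out = prenet(prompt)
    print(out.shape)                     # torch.Size([2, 512, 100]) with these stride-1, same-length blocks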