# -*- coding: utf-8 -*-

# Copyright 2024 Yiwei Guo
# Derived mostly from fairseq (https://github.com/facebookresearch/fairseq)

"""Prompt Pre-net Modules."""

import math

import torch
import torch.nn as nn

from vec2wav2.models.fairseq_modules.fp32_group_norm import Fp32GroupNorm
from vec2wav2.models.fairseq_modules.layer_norm import Fp32LayerNorm
from vec2wav2.models.fairseq_modules.transpose_last import TransposeLast


def norm_block(is_layer_norm, dim, affine=True):
    """Normalize the channel dimension of a (B, C, T) tensor in FP32.

    With layer norm, the input is transposed to (B, T, C) so normalization runs
    over the channels and then transposed back; otherwise a single-group group
    norm is applied directly.
    """
    if is_layer_norm:
        mod = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(dim, elementwise_affine=affine),
            TransposeLast(),
        )
    else:
        mod = Fp32GroupNorm(1, dim, affine=affine)
    return mod


class ZeroPad1d(nn.Module):
    """Zero-pad the last (time) dimension by pad_left and pad_right."""

    def __init__(self, pad_left, pad_right):
        super().__init__()
        self.pad_left = pad_left
        self.pad_right = pad_right

    def forward(self, x):
        return nn.functional.pad(x, (self.pad_left, self.pad_right))


class ConvPromptPrenet(nn.Module):
    """Stack of 1-D convolution blocks with optional scaled residual connections."""

    def __init__(
        self,
        conv_layers,
        embed,
        dropout,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        conv_bias,
        activation,
    ):
        super().__init__()

        def block(n_in, n_out, k, stride, pad):
            # Conv1d -> dropout -> FP32 group norm -> activation
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, padding=pad),
                nn.Dropout(p=dropout),
                norm_block(False, n_out, affine=not non_affine_group_norm),
                activation,
            )

        in_d = embed
        self.conv_layers = nn.ModuleList()
        self.residual_proj = nn.ModuleList()
        for dim, k, stride, pad in conv_layers:
            if in_d != dim and skip_connections:
                # 1x1 convolution to match channel counts on the residual branch
                self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False))
            else:
                self.residual_proj.append(None)

            self.conv_layers.append(block(in_d, dim, k, stride, pad))
            in_d = dim
        self.conv_layers = nn.Sequential(*self.conv_layers)
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # x: (batch, channels, time)
        for rproj, conv in zip(self.residual_proj, self.conv_layers):
            residual = x
            x = conv(x)
            if self.skip_connections:
                if rproj is not None:
                    residual = rproj(residual)
                x = (x + residual) * self.residual_scale
        return x
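

# Minimal usage sketch (illustrative only): the layer sizes and hyperparameters
# below are assumptions for demonstration, not values taken from vec2wav2.
# Each entry of `conv_layers` is a (dim, kernel_size, stride, padding) tuple.
if __name__ == "__main__":
    prenet = ConvPromptPrenet(
        conv_layers=[(256, 3, 1, 1), (256, 3, 1, 1)],  # assumed example configuration
        embed=80,
        dropout=0.1,
        skip_connections=True,
        residual_scale=0.5,
        non_affine_group_norm=False,
        conv_bias=False,
        activation=nn.ReLU(),
    )
    x = torch.randn(4, 80, 100)  # (batch, channels=embed, time)
    y = prenet(x)
    print(y.shape)  # torch.Size([4, 256, 100]) with these settings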