Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,750 Bytes
2da45ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import math
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from torch import nn
#.............................................
class VitsConfig(PretrainedConfig):
    """Configuration object for a VITS model.

    Every named constructor argument is stored on the instance under an
    attribute of the same name. List-valued arguments are copied before being
    stored, so an instance never aliases the list literals used as defaults in
    the signature (or a list still owned by the caller). Any extra keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``upsample_kernel_sizes`` and ``upsample_rates`` do not
            have the same length.
    """

    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_act="relu",
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        discriminator_kernel_size=5,
        discriminator_stride=3,
        discriminator_periods=[2, 3, 5, 7, 11],
        discriminator_period_channels=[1, 32, 128, 512, 1024],
        discriminator_scale_channels=[1, 16, 64, 256, 1024],
        segment_size=8192,
        hop_length=256,
        **kwargs,
    ):
        # Validate before touching any state so a bad call fails fast.
        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        # The list defaults in the signature are created once and shared by
        # every call; store copies so that mutating one config's list in place
        # cannot leak into the defaults of configs created afterwards.
        self.upsample_rates = list(upsample_rates)
        self.upsample_kernel_sizes = list(upsample_kernel_sizes)
        self.resblock_kernel_sizes = list(resblock_kernel_sizes)
        # Nested list: copy the inner lists as well.
        self.resblock_dilation_sizes = [list(dilations) for dilations in resblock_dilation_sizes]
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_dropout = duration_predictor_dropout
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.wavenet_dropout = wavenet_dropout
        self.speaking_rate = speaking_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        # used for training
        self.discriminator_kernel_size = discriminator_kernel_size
        self.discriminator_stride = discriminator_stride
        self.discriminator_periods = list(discriminator_periods)
        self.discriminator_period_channels = list(discriminator_period_channels)
        self.discriminator_scale_channels = list(discriminator_scale_channels)
        self.segment_size = segment_size
        self.hop_length = hop_length

        super().__init__(**kwargs)
#.............................................................................................
class VitsPreTrainedModel(PreTrainedModel):
    """
    Abstract base that supplies weight initialization and the common interface for downloading and loading
    pretrained models.
    """

    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the parameters of a single submodule according to its type.

        - ``nn.Linear`` / ``nn.Conv2d``: weights drawn from a normal distribution with mean 0 and standard
          deviation ``config.initializer_range``; bias (if present) set to zero.
        - ``nn.LayerNorm``: bias set to zero, weight set to one.
        - ``nn.Conv1d``: Kaiming-normal weights; bias (if present) drawn uniformly from ``[-k, k]`` with
          ``k = sqrt(groups / (in_channels * kernel_size[0]))``.
        - ``nn.Embedding``: weights drawn from the same normal distribution as linear layers; the row at
          ``padding_idx`` (if set) is zeroed.

        Modules of any other type are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                # Width of the uniform range used for the bias.
                denominator = module.in_channels * module.kernel_size[0]
                bound = math.sqrt(module.groups / denominator)
                nn.init.uniform_(module.bias, a=-bound, b=bound)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            pad_index = module.padding_idx
            if pad_index is not None:
                # Keep the padding token's embedding at zero.
                module.weight.data[pad_index].zero_()
#............................................................................................. |