File size: 12,399 Bytes
d5ee97c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 |
# -*- coding: utf-8 -*-
# Copyright 2020 The FastSpeech2 Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Model modules for FastSpeech2."""
import tensorflow as tf
from tensorflow_tts.models.fastspeech import TFFastSpeech, get_initializer
class TFFastSpeechVariantPredictor(tf.keras.layers.Layer):
    """FastSpeech2 variance predictor (instantiated for duration, F0 and energy).

    A stack of ``config.variant_prediction_num_conv_layers`` blocks, each one
    Conv1D -> ReLU -> LayerNormalization -> Dropout, followed by a ``Dense(1)``
    projection that yields one scalar per input position. When
    ``config.n_speakers > 1`` a learned speaker embedding (projected by a Dense
    layer and passed through softplus) is added to the encoder hidden states
    before the convolution stack.
    """

    def __init__(self, config, **kwargs):
        """Create the predictor layers.

        Args:
            config: model config. Reads ``variant_prediction_num_conv_layers``,
                ``variant_predictor_filter``, ``variant_predictor_kernel_size``,
                ``layer_norm_eps``, ``variant_predictor_dropout_rate``,
                ``n_speakers``, ``encoder_self_attention_params.hidden_size``
                and ``initializer_range``.
            **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
                ``dtype``).
        """
        super().__init__(**kwargs)
        # NOTE: attribute names, layer names and creation order are kept stable
        # so previously saved checkpoints / weight files keep loading.
        self.conv_layers = []
        for i in range(config.variant_prediction_num_conv_layers):
            self.conv_layers.append(
                tf.keras.layers.Conv1D(
                    config.variant_predictor_filter,
                    config.variant_predictor_kernel_size,
                    padding="same",
                    name="conv_._{}".format(i),
                )
            )
            self.conv_layers.append(tf.keras.layers.Activation(tf.nn.relu))
            self.conv_layers.append(
                tf.keras.layers.LayerNormalization(
                    epsilon=config.layer_norm_eps, name="LayerNorm_._{}".format(i)
                )
            )
            self.conv_layers.append(
                tf.keras.layers.Dropout(config.variant_predictor_dropout_rate)
            )
        self.conv_layers_sequence = tf.keras.Sequential(self.conv_layers)
        self.output_layer = tf.keras.layers.Dense(1)

        if config.n_speakers > 1:
            self.decoder_speaker_embeddings = tf.keras.layers.Embedding(
                config.n_speakers,
                config.encoder_self_attention_params.hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.encoder_self_attention_params.hidden_size,
                name="speaker_fc",
            )

        self.config = config

    def call(self, inputs, training=False):
        """Predict one scalar per input position.

        Args:
            inputs: list ``[encoder_hidden_states, speaker_ids, attention_mask]``.
                ``attention_mask`` is expanded on axis 2 and multiplied into the
                hidden states, so it is presumably ``[batch, length]`` with the
                hidden states ``[batch, length, hidden]`` — confirm with callers.
            training: whether the Dropout layers of the conv stack are active.

        Returns:
            The ``Dense(1)`` output with its trailing size-1 axis squeezed away.
            Positions where ``attention_mask`` is 0/False are zeroed.
        """
        encoder_hidden_states, speaker_ids, attention_mask = inputs
        attention_mask = tf.cast(
            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
        )

        if self.config.n_speakers > 1:
            speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_embeddings))
            # Broadcast the per-utterance speaker vector over every time step.
            encoder_hidden_states += speaker_features[:, tf.newaxis, :]

        # Zero out padded positions before the convolutions see them.
        masked_encoder_hidden_states = encoder_hidden_states * attention_mask

        # Forward `training` explicitly so the Dropout layers inside the stack do
        # not depend on Keras implicitly propagating the outer training flag.
        outputs = self.conv_layers_sequence(
            masked_encoder_hidden_states, training=training
        )
        outputs = self.output_layer(outputs)
        masked_outputs = outputs * attention_mask
        return tf.squeeze(masked_outputs, -1)
class TFFastSpeech2(TFFastSpeech):
    """TF FastSpeech2 model.

    Builds on ``TFFastSpeech`` and uses its ``embeddings``, ``encoder``,
    ``length_regulator``, ``decoder``, ``mel_dense`` and ``postnet`` layers. It
    adds duration / F0 / energy variance predictors plus Conv1D F0 / energy
    embeddings that are summed into the encoder output before length
    regulation.
    """

    def __init__(self, config, **kwargs):
        """Init layers for FastSpeech2.

        NOTE: layer creation order and attribute names are kept stable so
        existing checkpoints keep loading.
        """
        super().__init__(config, **kwargs)
        self.f0_predictor = TFFastSpeechVariantPredictor(
            config, dtype=tf.float32, name="f0_predictor"
        )
        self.energy_predictor = TFFastSpeechVariantPredictor(
            config, dtype=tf.float32, name="energy_predictor",
        )
        self.duration_predictor = TFFastSpeechVariantPredictor(
            config, dtype=tf.float32, name="duration_predictor"
        )

        # define f0_embeddings and energy_embeddings
        self.f0_embeddings = tf.keras.layers.Conv1D(
            filters=config.encoder_self_attention_params.hidden_size,
            kernel_size=9,
            padding="same",
            name="f0_embeddings",
        )
        self.f0_dropout = tf.keras.layers.Dropout(0.5)
        self.energy_embeddings = tf.keras.layers.Conv1D(
            filters=config.encoder_self_attention_params.hidden_size,
            kernel_size=9,
            padding="same",
            name="energy_embeddings",
        )
        self.energy_dropout = tf.keras.layers.Dropout(0.5)

    def _build(self):
        """Run one forward pass on dummy inputs so every layer gets built."""
        # fake inputs; f0/energy are per-token, so they share input_ids' length.
        input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
        speaker_ids = tf.convert_to_tensor([0], tf.int32)
        duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
        f0_gts = tf.convert_to_tensor(
            [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10]], tf.float32
        )
        energy_gts = tf.convert_to_tensor(
            [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10]], tf.float32
        )
        self(
            input_ids=input_ids,
            speaker_ids=speaker_ids,
            duration_gts=duration_gts,
            f0_gts=f0_gts,
            energy_gts=energy_gts,
        )

    def call(
        self,
        input_ids,
        speaker_ids,
        duration_gts,
        f0_gts,
        energy_gts,
        training=False,
        **kwargs,
    ):
        """Teacher-forced forward pass.

        Ground-truth durations drive the length regulator, and ground-truth
        F0 / energy feed the F0 / energy embeddings. The predictors' outputs
        are returned alongside the mel outputs.

        Args:
            input_ids: ``[batch, length]`` int token ids; id 0 is treated as
                padding.
            speaker_ids: ``[batch]`` speaker ids.
            duration_gts: ground-truth per-token durations ``[batch, length]``.
            f0_gts: ground-truth F0, one value per token ``[batch, length]``.
            energy_gts: ground-truth energy, one value per token
                ``[batch, length]``.
            training: Keras training flag, forwarded to every sub-layer.
            **kwargs: ignored.

        Returns:
            Tuple ``(mels_before, mels_after, duration_outputs, f0_outputs,
            energy_outputs)``.
        """
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=training)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=training
        )
        last_encoder_hidden_states = encoder_output[0]

        # Variance predictors. They only use last_encoder_hidden_states here; more
        # encoder hidden-state layers could be combined instead.
        # `training` is forwarded to all three predictors consistently.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask],
            training=training,
        )  # [batch_size, length]
        f0_outputs = self.f0_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=training
        )
        energy_outputs = self.energy_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=training
        )

        f0_embedding = self.f0_embeddings(
            tf.expand_dims(f0_gts, 2)
        )  # [batch_size, length, hidden_size]
        energy_embedding = self.energy_embeddings(
            tf.expand_dims(energy_gts, 2)
        )  # [batch_size, length, hidden_size]

        # Dropout is applied in both training and inference on purpose (training=True
        # is hard-coded). NOTE(review): presumably a regularization trick similar to
        # an always-on prenet dropout; confirm before changing.
        f0_embedding = self.f0_dropout(f0_embedding, training=True)
        energy_embedding = self.energy_dropout(energy_embedding, training=True)

        # sum features (token-level, before length regulation)
        last_encoder_hidden_states += f0_embedding + energy_embedding

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_gts], training=training
        )

        # create decoder positional embedding (1-based; padded frames masked to 0)
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=training,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mels_before = self.mel_dense(last_decoder_hidden_states)
        mels_after = (
            self.postnet([mels_before, encoder_masks], training=training) + mels_before
        )

        outputs = (
            mels_before,
            mels_after,
            duration_outputs,
            f0_outputs,
            energy_outputs,
        )
        return outputs

    def _inference(
        self, input_ids, speaker_ids, speed_ratios, f0_ratios, energy_ratios, **kwargs,
    ):
        """Free-running inference driven by the model's own predictions.

        Args:
            input_ids: ``[batch, length]`` int32 token ids; id 0 is padding.
            speaker_ids: ``[batch]`` int32 speaker ids.
            speed_ratios: ``[batch]`` multiplier applied to the predicted
                durations (larger values give more frames, i.e. slower speech).
            f0_ratios: ``[batch]`` multiplier applied to the predicted F0.
            energy_ratios: ``[batch]`` multiplier applied to the predicted
                energy.
            **kwargs: ignored.

        Returns:
            Tuple ``(mel_before, mel_after, duration_outputs, f0_outputs,
            energy_outputs)`` where ``duration_outputs`` is int32.
        """
        attention_mask = tf.math.not_equal(input_ids, 0)
        embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
        encoder_output = self.encoder(
            [embedding_output, attention_mask], training=False
        )
        last_encoder_hidden_states = encoder_output[0]

        # expand ratios so they broadcast over the length axis
        speed_ratios = tf.expand_dims(speed_ratios, 1)  # [B, 1]
        f0_ratios = tf.expand_dims(f0_ratios, 1)  # [B, 1]
        energy_ratios = tf.expand_dims(energy_ratios, 1)  # [B, 1]

        # Variance predictors; only last_encoder_hidden_states is used here.
        duration_outputs = self.duration_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=False
        )  # [batch_size, length]
        # exp(x) - 1 inverts a log(1 + d) target (NOTE(review): assumes the
        # duration loss is computed in that domain — confirm); relu clips negatives.
        duration_outputs = tf.nn.relu(tf.math.exp(duration_outputs) - 1.0)
        duration_outputs = tf.cast(
            tf.math.round(duration_outputs * speed_ratios), tf.int32
        )

        f0_outputs = self.f0_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=False
        )
        f0_outputs *= f0_ratios

        energy_outputs = self.energy_predictor(
            [last_encoder_hidden_states, speaker_ids, attention_mask], training=False
        )
        energy_outputs *= energy_ratios

        # Dropout deliberately stays on at inference (training=True), matching call().
        f0_embedding = self.f0_dropout(
            self.f0_embeddings(tf.expand_dims(f0_outputs, 2)), training=True
        )
        energy_embedding = self.energy_dropout(
            self.energy_embeddings(tf.expand_dims(energy_outputs, 2)), training=True
        )

        # sum features (token-level, before length regulation)
        last_encoder_hidden_states += f0_embedding + energy_embedding

        length_regulator_outputs, encoder_masks = self.length_regulator(
            [last_encoder_hidden_states, duration_outputs], training=False
        )

        # create decoder positional embedding (1-based; padded frames masked to 0)
        decoder_pos = tf.range(
            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
        )
        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

        decoder_output = self.decoder(
            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
            training=False,
        )
        last_decoder_hidden_states = decoder_output[0]

        # here u can use sum or concat more than 1 hidden states layers from decoder.
        mel_before = self.mel_dense(last_decoder_hidden_states)
        mel_after = (
            self.postnet([mel_before, encoder_masks], training=False) + mel_before
        )

        outputs = (mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs)
        return outputs

    def setup_inference_fn(self):
        """Wrap ``_inference`` in ``tf.function``s with fixed input signatures.

        Sets ``self.inference`` (dynamic batch size) and ``self.inference_tflite``
        (batch size fixed to 1).
        """
        self.inference = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="speed_ratios"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="f0_ratios"),
                tf.TensorSpec(shape=[None,], dtype=tf.float32, name="energy_ratios"),
            ],
        )

        self.inference_tflite = tf.function(
            self._inference,
            experimental_relax_shapes=True,
            input_signature=[
                tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="input_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.int32, name="speaker_ids"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="speed_ratios"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="f0_ratios"),
                tf.TensorSpec(shape=[1,], dtype=tf.float32, name="energy_ratios"),
            ],
        )
|