# Source path: Comparative-Analysis-of-Speech-Synthesis-Models/TensorFlowTTS/tensorflow_tts/models/tacotron2.py
# -*- coding: utf-8 -*-
# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai), Eren Gölge (@erogol) and Jae Yoo (@jaeyoo)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tacotron-2 Modules."""
import collections | |
import numpy as np | |
import tensorflow as tf | |
# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
# uncomment the line below and drop the local `dynamic_decode` import.
# from tensorflow_addons.seq2seq import dynamic_decode
from tensorflow_addons.seq2seq import BahdanauAttention, Decoder, Sampler | |
from tensorflow_tts.utils import dynamic_decode | |
from tensorflow_tts.models import BaseModel | |
def get_initializer(initializer_range=0.02):
    """Build a truncated-normal weight initializer.

    Args:
        initializer_range (float): standard deviation handed to the
            initializer.

    Returns:
        A `tf.keras.initializers.TruncatedNormal` instance whose `stddev`
        equals `initializer_range`.
    """
    truncated_normal = tf.keras.initializers.TruncatedNormal
    return truncated_normal(stddev=initializer_range)
def gelu(x):
    """Gaussian Error Linear Unit, computed with the error function."""
    scaled = x / tf.math.sqrt(2.0)
    gaussian_cdf = 0.5 * (1.0 + tf.math.erf(scaled))
    return x * gaussian_cdf
def gelu_new(x):
    """Gaussian Error Linear Unit, tanh-based ("smoother") variant."""
    cubic_term = x + 0.044715 * tf.pow(x, 3)
    tanh_arg = np.sqrt(2 / np.pi) * cubic_term
    approx_cdf = 0.5 * (1.0 + tf.tanh(tanh_arg))
    return x * approx_cdf
def swish(x):
    """Swish activation; delegates to `tf.nn.swish`."""
    activated = tf.nn.swish(x)
    return activated
def mish(x):
    """Mish activation: `x * tanh(softplus(x))`."""
    gate = tf.math.tanh(tf.math.softplus(x))
    return x * gate
# Lookup table from an activation name (as written in model configs) to the
# callable applied by the layers in this module. Every entry is a Keras
# `Activation` layer except "relu", which is the bare Keras activation function.
ACT2FN = dict(
    identity=tf.keras.layers.Activation("linear"),
    tanh=tf.keras.layers.Activation("tanh"),
    gelu=tf.keras.layers.Activation(gelu),
    relu=tf.keras.activations.relu,
    swish=tf.keras.layers.Activation(swish),
    gelu_new=tf.keras.layers.Activation(gelu_new),
    mish=tf.keras.layers.Activation(mish),
)
class TFEmbedding(tf.keras.layers.Embedding):
    """Keras embedding layer whose lookup is done with `tf.gather_nd`.

    The original author describes this as a faster version of the stock
    embedding lookup.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to `tf.keras.layers.Embedding`."""
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        """Return the embedding rows selected by the integer ids in `inputs`."""
        # gather_nd wants one index vector per lookup, so add a trailing axis
        # of size 1 and make sure the ids are int32.
        indices = tf.expand_dims(inputs, -1)
        indices = tf.cast(indices, tf.int32)
        return tf.gather_nd(self.embeddings, indices)
class TFTacotronConvBatchNorm(tf.keras.layers.Layer):
    """Tacotron-2 block: Conv1D -> batch norm -> activation -> dropout."""

    def __init__(
        self, filters, kernel_size, dropout_rate, activation=None, name_idx=None
    ):
        """Create the four stages of the block.

        Args:
            filters: number of Conv1D output filters.
            kernel_size: Conv1D kernel size.
            dropout_rate: rate of the trailing dropout layer.
            activation: key into `ACT2FN` selecting the activation.
            name_idx: value formatted into the sub-layer names.
        """
        super().__init__()
        conv_name = "conv_._{}".format(name_idx)
        norm_name = "batch_norm_._{}".format(name_idx)
        dropout_name = "dropout_._{}".format(name_idx)

        self.conv1d = tf.keras.layers.Conv1D(
            filters,
            kernel_size,
            kernel_initializer=get_initializer(0.02),
            padding="same",
            name=conv_name,
        )
        self.norm = tf.keras.layers.experimental.SyncBatchNormalization(
            axis=-1, name=norm_name
        )
        self.dropout = tf.keras.layers.Dropout(rate=dropout_rate, name=dropout_name)
        self.act = ACT2FN[activation]

    def call(self, inputs, training=False):
        """Run the block; `training` is passed to batch norm and dropout only."""
        features = self.conv1d(inputs)
        features = self.norm(features, training=training)
        features = self.act(features)
        return self.dropout(features, training=training)
class TFTacotronEmbeddings(tf.keras.layers.Layer):
    """Character/phoneme embeddings with an optional additive speaker term."""

    def __init__(self, config, **kwargs):
        """Store the config and create the sub-layers.

        Speaker layers are only created when `config.n_speakers > 1`.
        """
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embedding_hidden_size = config.embedding_hidden_size
        self.initializer_range = config.initializer_range
        self.config = config

        multi_speaker = config.n_speakers > 1
        if multi_speaker:
            self.speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(self.initializer_range),
                name="speaker_embeddings",
            )
            self.speaker_fc = tf.keras.layers.Dense(
                units=config.embedding_hidden_size, name="speaker_fc"
            )

        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.embedding_dropout_prob)

    def build(self, input_shape):
        """Create the character/phoneme embedding table."""
        table_shape = [self.vocab_size, self.embedding_hidden_size]
        with tf.name_scope("character_embeddings"):
            self.character_embeddings = self.add_weight(
                "weight",
                shape=table_shape,
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Embed a `(character_ids, speaker_ids)` pair.

        Args:
            inputs: two-element sequence holding
                1. character ids, Tensor (int32) shape [batch_size, length];
                2. speaker ids, Tensor (int32) shape [batch_size].
            training: forwarded to the dropout layer.

        Returns:
            Tensor (float32) shape [batch_size, length, embedding_size].
        """
        return self._embedding(inputs, training=training)

    def _embedding(self, inputs, training=False):
        """Look up character embeddings, add speaker features, normalize."""
        input_ids, speaker_ids = inputs

        result = tf.gather(self.character_embeddings, input_ids)

        if self.config.n_speakers > 1:
            speaker_vectors = self.speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(self.speaker_fc(speaker_vectors))
            # Insert a length axis so the per-utterance speaker vector is
            # broadcast over every character position.
            result += speaker_features[:, tf.newaxis, :]

        result = self.LayerNorm(result)
        return self.dropout(result, training=training)
class TFTacotronEncoderConvs(tf.keras.layers.Layer):
    """Stack of `config.n_conv_encoder` conv/batch-norm blocks for the encoder."""

    def __init__(self, config, **kwargs):
        """Create one `TFTacotronConvBatchNorm` per encoder conv layer."""
        super().__init__(**kwargs)
        self.conv_batch_norm = [
            TFTacotronConvBatchNorm(
                filters=config.encoder_conv_filters,
                kernel_size=config.encoder_conv_kernel_sizes,
                activation=config.encoder_conv_activation,
                dropout_rate=config.encoder_conv_dropout_rate,
                name_idx=layer_idx,
            )
            for layer_idx in range(config.n_conv_encoder)
        ]

    def call(self, inputs, training=False):
        """Feed `inputs` through every block in order."""
        hidden = inputs
        for block in self.conv_batch_norm:
            hidden = block(hidden, training=training)
        return hidden
class TFTacotronEncoder(tf.keras.layers.Layer):
    """Tacotron-2 encoder: embeddings -> conv stack -> bidirectional LSTM."""

    def __init__(self, config, **kwargs):
        """Create the encoder sub-layers.

        When `config.n_speakers > 1` a second speaker embedding and projection
        are created; their output is added to the BiLSTM outputs in `call`.
        """
        super().__init__(**kwargs)
        self.embeddings = TFTacotronEmbeddings(config, name="embeddings")
        self.convbn = TFTacotronEncoderConvs(config, name="conv_batch_norm")

        forward_lstm = tf.keras.layers.LSTM(
            units=config.encoder_lstm_units, return_sequences=True
        )
        self.bilstm = tf.keras.layers.Bidirectional(forward_lstm, name="bilstm")

        if config.n_speakers > 1:
            self.encoder_speaker_embeddings = TFEmbedding(
                config.n_speakers,
                config.embedding_hidden_size,
                embeddings_initializer=get_initializer(config.initializer_range),
                name="encoder_speaker_embeddings",
            )
            # Projects to 2 * encoder_lstm_units so it can be summed with the
            # BiLSTM output.
            self.encoder_speaker_fc = tf.keras.layers.Dense(
                units=config.encoder_lstm_units * 2, name="encoder_speaker_fc"
            )

        self.config = config

    def call(self, inputs, training=False):
        """Encode a `(input_ids, speaker_ids, input_mask)` triple.

        Args:
            inputs: three-element sequence `(input_ids, speaker_ids, input_mask)`.
            training: forwarded to the embedding and convolution stages.

        Returns:
            The BiLSTM outputs, plus the projected speaker features when the
            model is multi-speaker.
        """
        input_ids, speaker_ids, input_mask = inputs

        embedded = self.embeddings([input_ids, speaker_ids], training=training)
        conv_features = self.convbn(embedded, training=training)
        encoded = self.bilstm(conv_features, mask=input_mask)

        if self.config.n_speakers > 1:
            speaker_vectors = self.encoder_speaker_embeddings(speaker_ids)
            speaker_features = tf.math.softplus(
                self.encoder_speaker_fc(speaker_vectors)
            )
            # Broadcast the per-utterance speaker features over the time axis.
            encoded += speaker_features[:, tf.newaxis, :]

        return encoded
class Tacotron2Sampler(Sampler):
    """Tacotron-2 sampler that produces the decoder's next input at each step.

    In training the next input is taken from the ground-truth mel frames
    (teacher forcing); otherwise it is the decoder's own last predicted frame.
    """

    def __init__(
        self, config,
    ):
        """Store the config and the teacher-forcing schedule factor."""
        super().__init__()
        self.config = config
        # Schedule factor. The next decoder input is computed as
        #   ratio * prev_groundtruth_outputs + (1.0 - ratio) * prev_predicted_outputs
        # so the constant 1.0 means the ground truth is always used.
        self._ratio = tf.constant(1.0, dtype=tf.float32)
        self._reduction_factor = self.config.reduction_factor

    def setup_target(self, targets, mel_lengths):
        """Register ground-truth mel outputs for teacher forcing.

        Args:
            targets: ground-truth mels; indexed as [batch, time, feature].
            mel_lengths: stored on the sampler as `self.mel_lengths`.
        """
        self.mel_lengths = mel_lengths
        self.set_batch_size(tf.shape(targets)[0])
        # Keep the last frame of every group of `reduction_factor` frames.
        self.targets = targets[
            :, self._reduction_factor - 1 :: self._reduction_factor, :
        ]
        self.max_lengths = tf.tile([tf.shape(self.targets)[1]], [self._batch_size])

    # The accessors below are properties: TFTacotronDecoder reads them as plain
    # attribute values (e.g. `tf.TensorShape(self.sampler.reduction_factor)`),
    # which does not work when they are ordinary methods.
    @property
    def batch_size(self):
        """Batch size set by `setup_target` or `set_batch_size`."""
        return self._batch_size

    @property
    def sample_ids_shape(self):
        """Shape of one sample id (scalar)."""
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        """Dtype of the sample ids."""
        return tf.int32

    @property
    def reduction_factor(self):
        """`config.reduction_factor` captured at construction time."""
        return self._reduction_factor

    def initialize(self):
        """Return `(finished, next_inputs)` for the first decoding step.

        Nothing is finished yet, and the first input is an all-zero frame of
        shape [batch_size, n_mels].
        """
        return (
            tf.tile([False], [self._batch_size]),
            tf.tile([[0.0]], [self._batch_size, self.config.n_mels]),
        )

    def sample(self, time, outputs, state):
        """Return placeholder sample ids (zeros, one per batch element)."""
        return tf.tile([0], [self._batch_size])

    def next_inputs(
        self,
        time,
        outputs,
        state,
        sample_ids,
        stop_token_prediction,
        training=False,
        **kwargs,
    ):
        """Compute `(finished, next_inputs, next_state)` for the next step.

        Args:
            time: current decoding step.
            outputs: decoder frame outputs of the current step; the last
                `n_mels` features are fed back as the predicted frame.
            state: decoder cell state, returned unchanged.
            sample_ids: unused.
            stop_token_prediction: stop-token logits (used outside training).
            training: selects teacher forcing vs. free-running decoding.
        """
        last_predicted_frame = outputs[:, -self.config.n_mels :]
        if training:
            # Stop once the (reduced) target sequence is exhausted.
            finished = time + 1 >= self.max_lengths
            next_inputs = (
                self._ratio * self.targets[:, time, :]
                + (1.0 - self._ratio) * last_predicted_frame
            )
        else:
            stop_probability = tf.nn.sigmoid(stop_token_prediction)
            finished = tf.cast(tf.round(stop_probability), tf.bool)
            # Decoding only stops when every stop prediction agrees.
            finished = tf.reduce_all(finished)
            next_inputs = last_predicted_frame
        return (finished, next_inputs, state)

    def set_batch_size(self, batch_size):
        """Set the batch size used by `initialize`, `sample` and `setup_target`."""
        self._batch_size = batch_size
class TFTacotronLocationSensitiveAttention(BahdanauAttention):
    """Tacotron-2 location-sensitive attention built on `BahdanauAttention`."""

    def __init__(
        self,
        config,
        memory,
        mask_encoder=True,
        memory_sequence_length=None,
        is_cumulate=True,
    ):
        """Create the attention mechanism and its location layers.

        Args:
            config: supplies `attention_dim`, `attention_filters`,
                `attention_kernel` and `encoder_lstm_units`.
            memory: forwarded to `BahdanauAttention`.
            mask_encoder: the sequence lengths are only forwarded to the base
                class when this is exactly `True`.
            memory_sequence_length: lengths used for masking the memory.
            is_cumulate: when true the attention state accumulates alignments;
                otherwise the state is just the latest alignment.
        """
        if mask_encoder is True:
            memory_length = memory_sequence_length
        else:
            memory_length = None

        super().__init__(
            units=config.attention_dim,
            memory=memory,
            memory_sequence_length=memory_length,
            probability_fn="softmax",
            name="LocationSensitiveAttention",
        )
        self.location_convolution = tf.keras.layers.Conv1D(
            filters=config.attention_filters,
            kernel_size=config.attention_kernel,
            padding="same",
            use_bias=False,
            name="location_conv",
        )
        self.location_layer = tf.keras.layers.Dense(
            units=config.attention_dim, use_bias=False, name="location_layer"
        )
        self.v = tf.keras.layers.Dense(1, use_bias=True, name="scores_attention")
        self.config = config
        self.is_cumulate = is_cumulate
        self.use_window = False

    def setup_window(self, win_front=2, win_back=4):
        """Enable the attention window (reads `self.keys`, so memory must be set).

        Args:
            win_front: subtracted from the previous max-alignment position to
                get the lower edge of the window.
            win_back: added to the previous max-alignment position to get the
                upper edge of the window.
        """
        self.win_front = tf.constant(win_front, tf.int32)
        self.win_back = tf.constant(win_back, tf.int32)

        positions = tf.expand_dims(tf.range(tf.shape(self.keys)[1]), 0)
        # One row of positions per batch element: [batch_size, max_time].
        self._indices = tf.tile(positions, [tf.shape(self.keys)[0], 1])
        self.use_window = True

    def _compute_window_mask(self, max_alignments):
        """Compute the window mask for inference.

        Args:
            max_alignments (int): [batch_size] previous max-alignment positions.

        Returns:
            Float mask [batch_size, max_length]; 1.0 outside the window and
            0.0 inside it.
        """
        centers = tf.expand_dims(max_alignments, 1)  # [batch_size, 1]
        before_window = self._indices < (centers - self.win_front)
        after_window = self._indices > (centers + self.win_back)
        return tf.cast(before_window, tf.float32) + tf.cast(after_window, tf.float32)

    def __call__(self, inputs, training=False):
        """Run one attention step.

        Args:
            inputs: `(query, state, prev_max_alignments)`.
            training: accepted for interface compatibility; not used here.

        Returns:
            `(context, alignments, state)` where `state` is the updated
            (optionally cumulative) alignment state.
        """
        query, state, prev_max_alignments = inputs

        if self.query_layer:
            query = self.query_layer(query)
        processed_query = tf.expand_dims(query, 1)

        # Location features are computed from the previous attention state.
        location_input = tf.expand_dims(state, axis=2)
        location_features = self.location_layer(
            self.location_convolution(location_input)
        )

        energy = self._location_sensitive_score(
            processed_query, location_features, self.keys
        )

        # Push the energy of out-of-window positions towards -inf.
        if self.use_window is True:
            window_mask = self._compute_window_mask(prev_max_alignments)
            energy = energy + window_mask * -1e20

        alignments = self.probability_fn(energy, state)

        state = alignments + state if self.is_cumulate else alignments

        weights = tf.expand_dims(alignments, 2)
        context = tf.reduce_sum(weights * self.values, 1)
        return context, alignments, state

    def _location_sensitive_score(self, W_query, W_fil, W_keys):
        """Calculate the location-sensitive energy."""
        combined = tf.nn.tanh(W_keys + W_query + W_fil)
        return tf.squeeze(self.v(combined), -1)

    def get_initial_state(self, batch_size, size):
        """Return all-zero initial alignments of shape [batch_size, size]."""
        return tf.zeros(shape=[batch_size, size], dtype=tf.float32)

    def get_initial_context(self, batch_size):
        """Return an all-zero context of shape [batch_size, 2 * encoder_lstm_units]."""
        context_dim = self.config.encoder_lstm_units * 2
        return tf.zeros(shape=[batch_size, context_dim], dtype=tf.float32)
class TFTacotronPrenet(tf.keras.layers.Layer):
    """Tacotron-2 prenet: a stack of dense layers, each followed by dropout."""

    def __init__(self, config, **kwargs):
        """Create `config.n_prenet_layers` dense layers and one shared dropout."""
        super().__init__(**kwargs)
        prenet_activation = ACT2FN[config.prenet_activation]
        self.prenet_dense = []
        for layer_idx in range(config.n_prenet_layers):
            self.prenet_dense.append(
                tf.keras.layers.Dense(
                    units=config.prenet_units,
                    activation=prenet_activation,
                    name="dense_._{}".format(layer_idx),
                )
            )
        self.dropout = tf.keras.layers.Dropout(
            rate=config.prenet_dropout_rate, name="dropout"
        )

    def call(self, inputs, training=False):
        """Apply every dense layer followed by dropout.

        NOTE(review): dropout runs with `training=True` regardless of the
        `training` argument, so it is also active at inference. This looks
        intentional for a Tacotron-2 prenet; confirm before changing.
        """
        hidden = inputs
        for dense in self.prenet_dense:
            hidden = self.dropout(dense(hidden), training=True)
        return hidden
class TFTacotronPostnet(tf.keras.layers.Layer):
    """Tacotron-2 postnet: a stack of conv/batch-norm blocks."""

    def __init__(self, config, **kwargs):
        """Create `config.n_conv_postnet` blocks.

        Every block uses "tanh" except the last one, which uses "identity".
        """
        super().__init__(**kwargs)
        last_idx = config.n_conv_postnet - 1
        self.conv_batch_norm = [
            TFTacotronConvBatchNorm(
                filters=config.postnet_conv_filters,
                kernel_size=config.postnet_conv_kernel_sizes,
                dropout_rate=config.postnet_dropout_rate,
                activation="identity" if layer_idx == last_idx else "tanh",
                name_idx=layer_idx,
            )
            for layer_idx in range(config.n_conv_postnet)
        ]

    def call(self, inputs, training=False):
        """Feed `inputs` through every block in order."""
        hidden = inputs
        for block in self.conv_batch_norm:
            hidden = block(hidden, training=training)
        return hidden
# State carried between decoder-cell steps: the two LSTM states, the previous
# attention context, the step counter, the attention state, the per-step
# alignment history and the previous max-alignment positions.
TFTacotronDecoderCellState = collections.namedtuple(
    "TFTacotronDecoderCellState",
    "attention_lstm_state decoder_lstms_state context time state "
    "alignment_history max_alignments",
)

# Per-step decoder output: mel frames, stop-token logits and sample ids.
TFDecoderOutput = collections.namedtuple(
    "TFDecoderOutput", ["mel_output", "token_output", "sample_id"]
)
class TFTacotronDecoderCell(tf.keras.layers.AbstractRNNCell):
    """Tacotron-2 custom decoder cell.

    One step runs: prenet -> attention LSTM -> location-sensitive attention
    -> decoder LSTM stack -> frame and stop-token projections.
    """

    def __init__(self, config, enable_tflite_convertible=False, **kwargs):
        """Create the cell's sub-layers.

        Args:
            config: model configuration.
            enable_tflite_convertible: when true, the alignment history is
                not recorded (an empty tuple is carried in the state instead
                of a `tf.TensorArray`).

        Raises:
            ValueError: if `config.attention_type` is not "lsa".
        """
        super().__init__(**kwargs)
        self.enable_tflite_convertible = enable_tflite_convertible
        self.prenet = TFTacotronPrenet(config, name="prenet")

        # define lstm cell on decoder.
        # TODO(@dathudeptrai) switch to zone-out lstm.
        self.attention_lstm = tf.keras.layers.LSTMCell(
            units=config.decoder_lstm_units, name="attention_lstm_cell"
        )
        lstm_cells = [
            tf.keras.layers.LSTMCell(
                units=config.decoder_lstm_units, name="lstm_cell_._{}".format(i)
            )
            for i in range(config.n_lstm_decoder)
        ]
        self.decoder_lstms = tf.keras.layers.StackedRNNCells(
            lstm_cells, name="decoder_lstms"
        )

        # define attention layer.
        if config.attention_type == "lsa":
            # The memory is attached later through `setup_memory`.
            self.attention_layer = TFTacotronLocationSensitiveAttention(
                config,
                memory=None,
                mask_encoder=True,
                memory_sequence_length=None,
                is_cumulate=True,
            )
        else:
            raise ValueError("Only lsa (location-sensitive attention) is supported")

        # frame, stop projection layer.
        self.frame_projection = tf.keras.layers.Dense(
            units=config.n_mels * config.reduction_factor, name="frame_projection"
        )
        self.stop_projection = tf.keras.layers.Dense(
            units=config.reduction_factor, name="stop_projection"
        )

        self.config = config

    def set_alignment_size(self, alignment_size):
        """Set the attention state width (the encoder sequence length)."""
        self.alignment_size = alignment_size

    # `output_size` and `state_size` are properties: TFTacotronDecoder reads
    # `self.cell.output_size` as a value, and the AbstractRNNCell base class
    # declares both as properties.
    @property
    def output_size(self):
        """Return output (mel) size: `n_mels * reduction_factor`."""
        return self.frame_projection.units

    @property
    def state_size(self):
        """Return the size of every field of the hidden state.

        Requires `set_alignment_size` to have been called first.
        """
        return TFTacotronDecoderCellState(
            attention_lstm_state=self.attention_lstm.state_size,
            decoder_lstms_state=self.decoder_lstms.state_size,
            time=tf.TensorShape([]),
            # The state tuple has a `context` field (there is no `attention`
            # field); its width matches the initial context built in
            # `get_initial_state`.
            context=self.config.encoder_lstm_units * 2,
            state=self.alignment_size,
            alignment_history=(),
            # NOTE(review): `get_initial_state` builds max_alignments with
            # shape [batch_size], i.e. a scalar per example; confirm whether
            # TensorShape([1]) is what consumers of state_size expect.
            max_alignments=tf.TensorShape([1]),
        )

    def get_initial_state(self, batch_size):
        """Build the all-zero initial `TFTacotronDecoderCellState`.

        Requires `set_alignment_size` to have been called first.
        """
        initial_attention_lstm_cell_states = self.attention_lstm.get_initial_state(
            None, batch_size, dtype=tf.float32
        )
        initial_decoder_lstms_cell_states = self.decoder_lstms.get_initial_state(
            None, batch_size, dtype=tf.float32
        )
        initial_context = tf.zeros(
            shape=[batch_size, self.config.encoder_lstm_units * 2], dtype=tf.float32
        )
        initial_state = self.attention_layer.get_initial_state(
            batch_size, size=self.alignment_size
        )
        if self.enable_tflite_convertible:
            initial_alignment_history = ()
        else:
            initial_alignment_history = tf.TensorArray(
                dtype=tf.float32, size=0, dynamic_size=True
            )
        return TFTacotronDecoderCellState(
            attention_lstm_state=initial_attention_lstm_cell_states,
            decoder_lstms_state=initial_decoder_lstms_cell_states,
            time=tf.zeros([], dtype=tf.int32),
            context=initial_context,
            state=initial_state,
            alignment_history=initial_alignment_history,
            max_alignments=tf.zeros([batch_size], dtype=tf.int32),
        )

    def call(self, inputs, states, training=False):
        """Run one decoding step.

        Args:
            inputs: decoder input frame for this step.
            states: the `TFTacotronDecoderCellState` of the previous step.
            training: forwarded to the prenet and the attention layer.

        Returns:
            `((decoder_outputs, stop_tokens), new_states)`.
        """
        decoder_input = inputs

        # 1. apply prenet for decoder_input.
        prenet_out = self.prenet(decoder_input, training=training)  # [batch_size, dim]

        # 2. concat prenet_out and prev context vector
        # then use it as input of attention lstm layer.
        attention_lstm_input = tf.concat([prenet_out, states.context], axis=-1)
        attention_lstm_output, next_attention_lstm_state = self.attention_lstm(
            attention_lstm_input, states.attention_lstm_state
        )

        # 3. compute context, alignment and cumulative alignment.
        prev_state = states.state
        if not self.enable_tflite_convertible:
            prev_alignment_history = states.alignment_history
        prev_max_alignments = states.max_alignments
        context, alignments, state = self.attention_layer(
            [attention_lstm_output, prev_state, prev_max_alignments],
            training=training,
        )

        # 4. run decoder lstm(s)
        decoder_lstms_input = tf.concat([attention_lstm_output, context], axis=-1)
        decoder_lstms_output, next_decoder_lstms_state = self.decoder_lstms(
            decoder_lstms_input, states.decoder_lstms_state
        )

        # 5. compute frame feature and stop token.
        projection_inputs = tf.concat([decoder_lstms_output, context], axis=-1)
        decoder_outputs = self.frame_projection(projection_inputs)

        stop_inputs = tf.concat([decoder_lstms_output, decoder_outputs], axis=-1)
        stop_tokens = self.stop_projection(stop_inputs)

        # 6. save alignment history to visualize.
        if self.enable_tflite_convertible:
            alignment_history = ()
        else:
            alignment_history = prev_alignment_history.write(states.time, alignments)

        # 7. return new states.
        new_states = TFTacotronDecoderCellState(
            attention_lstm_state=next_attention_lstm_state,
            decoder_lstms_state=next_decoder_lstms_state,
            time=states.time + 1,
            context=context,
            state=state,
            alignment_history=alignment_history,
            max_alignments=tf.argmax(alignments, -1, output_type=tf.int32),
        )

        return (decoder_outputs, stop_tokens), new_states
class TFTacotronDecoder(Decoder):
    """Tacotron-2 decoder: couples a decoder cell with a sampler."""

    def __init__(
        self,
        decoder_cell,
        decoder_sampler,
        output_layer=None,
        enable_tflite_convertible=False,
    ):
        """Store the collaborators.

        Args:
            decoder_cell: cell run at every step (see `step`).
            decoder_sampler: sampler that produces the next inputs.
            output_layer: optional layer applied to the mel outputs.
            enable_tflite_convertible: selects the `sample_id` shape
                reported by `output_size`.
        """
        self.cell = decoder_cell
        self.sampler = decoder_sampler
        self.output_layer = output_layer
        self.enable_tflite_convertible = enable_tflite_convertible

    def setup_decoder_init_state(self, decoder_init_state):
        """Set the cell state returned by `initialize`."""
        self.initial_state = decoder_init_state

    def initialize(self, **kwargs):
        """Return `(finished, first_inputs, initial_state)`."""
        return self.sampler.initialize() + (self.initial_state,)

    # `output_size`, `output_dtype` and `batch_size` are abstract properties
    # on the tfa `Decoder` base class, so they are defined as properties here.
    @property
    def output_size(self):
        """Shapes of one step's `TFDecoderOutput` fields."""
        if self.enable_tflite_convertible:
            sample_id_shape = tf.TensorShape([1])
        else:
            sample_id_shape = self.sampler.sample_ids_shape  # tf.TensorShape([])
        return TFDecoderOutput(
            mel_output=tf.nest.map_structure(
                lambda shape: tf.TensorShape(shape), self.cell.output_size
            ),
            token_output=tf.TensorShape(self.sampler.reduction_factor),
            sample_id=sample_id_shape,
        )

    @property
    def output_dtype(self):
        """Dtypes of one step's `TFDecoderOutput` fields."""
        return TFDecoderOutput(tf.float32, tf.float32, self.sampler.sample_ids_dtype)

    @property
    def batch_size(self):
        """Batch size currently set on the sampler."""
        return self.sampler._batch_size

    def step(self, time, inputs, state, training=False):
        """Run one decoding step.

        Returns:
            `(outputs, next_state, next_inputs, finished)` where `outputs`
            is a `TFDecoderOutput`.
        """
        (mel_outputs, stop_tokens), cell_state = self.cell(
            inputs, state, training=training
        )
        if self.output_layer is not None:
            mel_outputs = self.output_layer(mel_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=mel_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time,
            outputs=mel_outputs,
            state=cell_state,
            sample_ids=sample_ids,
            stop_token_prediction=stop_tokens,
            training=training,
        )

        outputs = TFDecoderOutput(mel_outputs, stop_tokens, sample_ids)
        return (outputs, next_state, next_inputs, finished)
class TFTacotron2(BaseModel):
    """Tensorflow tacotron-2 model."""

    def __init__(self, config, **kwargs):
        """Initialize tacotron-2 layers.

        Args:
            config: model configuration.
            **kwargs: `enable_tflite_convertible` (default False) is consumed
                here; everything else goes to the base-class constructor.
        """
        enable_tflite_convertible = kwargs.pop("enable_tflite_convertible", False)
        # `super().__init__` is already bound to this instance; `self` must
        # not be passed again as a positional argument.
        super().__init__(**kwargs)
        self.encoder = TFTacotronEncoder(config, name="encoder")
        self.decoder_cell = TFTacotronDecoderCell(
            config,
            name="decoder_cell",
            enable_tflite_convertible=enable_tflite_convertible,
        )
        self.decoder = TFTacotronDecoder(
            self.decoder_cell,
            Tacotron2Sampler(config),
            enable_tflite_convertible=enable_tflite_convertible,
        )
        self.postnet = TFTacotronPostnet(config, name="post_net")
        self.post_projection = tf.keras.layers.Dense(
            units=config.n_mels, name="residual_projection"
        )

        self.use_window_mask = False
        self.maximum_iterations = 4000
        self.enable_tflite_convertible = enable_tflite_convertible
        self.config = config

    def setup_window(self, win_front, win_back):
        """Call only for inference: enable the attention window."""
        self.use_window_mask = True
        self.win_front = win_front
        self.win_back = win_back

    def setup_maximum_iterations(self, maximum_iterations):
        """Call only for inference: cap the number of decoding steps."""
        self.maximum_iterations = maximum_iterations

    def _build(self):
        """Run one dummy training-mode forward pass through the model."""
        input_ids = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]])
        input_lengths = np.array([9])
        speaker_ids = np.array([0])
        # The dummy target must have `config.n_mels` features: the teacher-
        # forced frames are fed into the prenet, so a hard-coded width would
        # build it with the wrong input size for any other `n_mels`.
        mel_outputs = np.random.normal(size=(1, 50, self.config.n_mels)).astype(
            np.float32
        )
        mel_lengths = np.array([50])
        self(
            input_ids,
            input_lengths,
            speaker_ids,
            mel_outputs,
            mel_lengths,
            10,
            training=True,
        )

    def _encode(self, input_ids, input_lengths, speaker_ids, training):
        """Build the input mask from `input_lengths` and run the encoder."""
        input_mask = tf.sequence_mask(
            input_lengths,
            maxlen=tf.reduce_max(input_lengths),
            name="input_sequence_masks",
        )
        return self.encoder([input_ids, speaker_ids, input_mask], training=training)

    def _setup_decoder(self, encoder_hidden_states, input_lengths, batch_size):
        """Prepare the decoder cell for a new batch.

        Sets the attention size, the initial cell state and the attention
        memory (the encoder hidden states, masked by `input_lengths`).
        """
        alignment_size = tf.shape(encoder_hidden_states)[1]
        self.decoder.cell.set_alignment_size(alignment_size)
        self.decoder.setup_decoder_init_state(
            self.decoder.cell.get_initial_state(batch_size)
        )
        self.decoder.cell.attention_layer.setup_memory(
            memory=encoder_hidden_states,
            memory_sequence_length=input_lengths,  # use for mask attention.
        )

    def _apply_postnet(self, frames_prediction, batch_size, training):
        """Reshape decoder frames to mel frames and add the postnet residual.

        Returns:
            `(decoder_outputs, mel_outputs)`.
        """
        decoder_outputs = tf.reshape(
            frames_prediction, [batch_size, -1, self.config.n_mels]
        )
        residual = self.postnet(decoder_outputs, training=training)
        residual_projection = self.post_projection(residual)
        mel_outputs = decoder_outputs + residual_projection
        return decoder_outputs, mel_outputs

    @staticmethod
    def _drop_zero_frames(decoder_outputs, mel_outputs):
        """Drop frames whose absolute decoder-output sum truncates to zero.

        Used on the tflite path. The sum is cast to int32 before the
        comparison, so any frame whose absolute sum is below 1 is removed.
        Both results get a leading axis of size 1.
        """
        mask = tf.math.not_equal(
            tf.cast(tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32),
            0,
        )
        decoder_outputs = tf.expand_dims(
            tf.boolean_mask(decoder_outputs, mask), axis=0
        )
        mel_outputs = tf.expand_dims(tf.boolean_mask(mel_outputs, mask), axis=0)
        return decoder_outputs, mel_outputs

    @staticmethod
    def _stack_alignments(final_decoder_state):
        """Stack the per-step alignment history and move the step axis last."""
        return tf.transpose(final_decoder_state.alignment_history.stack(), [1, 2, 0])

    def call(
        self,
        input_ids,
        input_lengths,
        speaker_ids,
        mel_gts,
        mel_lengths,
        maximum_iterations=None,
        use_window_mask=False,
        win_front=2,
        win_back=3,
        training=False,
        **kwargs,
    ):
        """Forward pass with ground-truth mels available to the sampler.

        Args:
            input_ids: character/phoneme ids.
            input_lengths: length of every input sequence.
            speaker_ids: speaker id of every utterance.
            mel_gts: ground-truth mels, registered on the sampler.
            mel_lengths: length of every ground-truth mel sequence.
            maximum_iterations: forwarded to `dynamic_decode`.
            use_window_mask: enable the attention window for this call.
            win_front: window extent before the previous max alignment.
            win_back: window extent after the previous max alignment.
            training: training-mode flag for all sub-layers.

        Returns:
            `(decoder_outputs, mel_outputs, stop_token_prediction,
            alignment_history)`; `alignment_history` is `()` when the model
            is tflite-convertible.
        """
        # Encoder Step.
        encoder_hidden_states = self._encode(
            input_ids, input_lengths, speaker_ids, training
        )
        batch_size = tf.shape(encoder_hidden_states)[0]

        # Setup some initial placeholders for decoder step. Include:
        # 1. mel_gts, mel_lengths for teacher forcing mode.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        self.decoder.sampler.setup_target(targets=mel_gts, mel_lengths=mel_lengths)
        self._setup_decoder(encoder_hidden_states, input_lengths, batch_size)
        if use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=win_front, win_back=win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=training,
        )

        decoder_outputs, mel_outputs = self._apply_postnet(
            frames_prediction, batch_size, training
        )
        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])

        if self.enable_tflite_convertible:
            decoder_outputs, mel_outputs = self._drop_zero_frames(
                decoder_outputs, mel_outputs
            )
            alignment_history = ()
        else:
            alignment_history = self._stack_alignments(final_decoder_state)

        return decoder_outputs, mel_outputs, stop_token_prediction, alignment_history

    def inference(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Free-running inference (no ground-truth mels).

        Uses `self.maximum_iterations` and, if `setup_window` was called,
        the stored attention window.

        Returns:
            `(decoder_outputs, mel_outputs, stop_token_predictions,
            alignment_historys)`.
        """
        # Encoder Step.
        encoder_hidden_states = self._encode(
            input_ids, input_lengths, speaker_ids, False
        )
        batch_size = tf.shape(encoder_hidden_states)[0]

        # Setup some initial placeholders for decoder step. Include:
        # 1. batch_size for inference.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        # 5. window front/back to solve long sentence synthesize problems. (call after setup memory.)
        self.decoder.sampler.set_batch_size(batch_size)
        self._setup_decoder(encoder_hidden_states, input_lengths, batch_size)
        if self.use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=self.win_front, win_back=self.win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder, maximum_iterations=self.maximum_iterations, training=False
        )

        decoder_outputs, mel_outputs = self._apply_postnet(
            frames_prediction, batch_size, False
        )
        stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])
        alignment_historys = self._stack_alignments(final_decoder_state)

        return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys

    def inference_tflite(self, input_ids, input_lengths, speaker_ids, **kwargs):
        """Free-running inference for the tflite-convertible variant.

        Same as `inference`, but forwards `enable_tflite_convertible` to
        `dynamic_decode` and, when it is set, strips zero frames and
        returns `()` instead of the alignment history.
        """
        # Encoder Step.
        encoder_hidden_states = self._encode(
            input_ids, input_lengths, speaker_ids, False
        )
        batch_size = tf.shape(encoder_hidden_states)[0]

        # Setup some initial placeholders for decoder step. Include:
        # 1. batch_size for inference.
        # 2. alignment_size for attention size.
        # 3. initial state for decoder cell.
        # 4. memory (encoder hidden state) for attention mechanism.
        # 5. window front/back to solve long sentence synthesize problems. (call after setup memory.)
        self.decoder.sampler.set_batch_size(batch_size)
        self._setup_decoder(encoder_hidden_states, input_lengths, batch_size)
        if self.use_window_mask:
            self.decoder.cell.attention_layer.setup_window(
                win_front=self.win_front, win_back=self.win_back
            )

        # run decode step.
        (
            (frames_prediction, stop_token_prediction, _),
            final_decoder_state,
            _,
        ) = dynamic_decode(
            self.decoder,
            maximum_iterations=self.maximum_iterations,
            enable_tflite_convertible=self.enable_tflite_convertible,
            training=False,
        )

        decoder_outputs, mel_outputs = self._apply_postnet(
            frames_prediction, batch_size, False
        )
        stop_token_predictions = tf.reshape(stop_token_prediction, [batch_size, -1])

        if self.enable_tflite_convertible:
            decoder_outputs, mel_outputs = self._drop_zero_frames(
                decoder_outputs, mel_outputs
            )
            alignment_historys = ()
        else:
            alignment_historys = self._stack_alignments(final_decoder_state)

        return decoder_outputs, mel_outputs, stop_token_predictions, alignment_historys