import logging
import os

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import layers, regularizers
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Input,
    Lambda,
    Reshape,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from deep_speaker.constants import NUM_FBANKS, SAMPLE_RATE, NUM_FRAMES
from deep_speaker.triplet_loss import deep_speaker_loss

logger = logging.getLogger(__name__)


@tf.function
def tf_normalize(data, ndims, eps=0, adjusted=False):
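    """Zero-mean, unit-variance normalization over the last `ndims` axes.

    Similar in spirit to tf.image.per_image_standardization: with adjusted=True
    the standard deviation is floored at 1/sqrt(N), N being the number of
    normalized elements, and `eps` provides an additional floor that avoids
    division by zero.
    """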
    data = tf.convert_to_tensor(data, name='data')
    reduce_dims = [-i - 1 for i in range(ndims)]

    data = tf.cast(data, dtype=tf.dtypes.float32)
    data_num = tf.reduce_prod(data.shape[-ndims:])
    data_mean = tf.reduce_mean(data, axis=reduce_dims, keepdims=True)

    stddev = tf.math.reduce_std(data, axis=reduce_dims, keepdims=True)
    adjusted_stddev = stddev
    if adjusted:
        min_stddev = tf.math.rsqrt(tf.cast(data_num, tf.dtypes.float32))
        eps = tf.maximum(eps, min_stddev)
    if eps > 0:
        adjusted_stddev = tf.maximum(adjusted_stddev, eps)

    return (data - data_mean) / adjusted_stddev


@tf.function
def tf_fbank(samples):
    """
    Compute Mel-filterbank energy features from an audio signal.
    See python_speech_features.fbank.
    """
    frame_length = int(0.025 * SAMPLE_RATE)  # 25 ms window
    frame_step = int(0.01 * SAMPLE_RATE)  # 10 ms hop
    fft_length = 512
    fft_bins = fft_length // 2 + 1

    # Pre-emphasis filter (coefficient 0.97) boosts the high frequencies.
    pre_emphasis = samples[:, 1:] - 0.97 * samples[:, :-1]
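
    # Framing + FFT via tf.signal.stft with a rectangular window (tf.ones),
    # matching python_speech_features' default of no window function.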
    spec = tf.abs(tf.signal.stft(pre_emphasis, frame_length, frame_step, fft_length, window_fn=tf.ones))
    powspec = tf.square(spec) / fft_length
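
    # Warp the linear-frequency power spectrogram onto NUM_FBANKS mel bands
    # spanning 0 Hz up to the Nyquist frequency (SAMPLE_RATE / 2).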
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=NUM_FBANKS,
        num_spectrogram_bins=fft_bins,
        sample_rate=SAMPLE_RATE,
        lower_edge_hertz=0,
        upper_edge_hertz=SAMPLE_RATE / 2,
    )

    feat = tf.matmul(powspec, linear_to_mel_weight_matrix)
    return feat


class DeepSpeakerModel:
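    """Residual CNN ("ResCNN") speaker embedding model, in the spirit of the
    Deep Speaker paper.

    Four conv-and-residual stages (64, 128, 256 and 512 filters) produce
    frame-level features that are averaged over time and projected to a
    512-dimensional, L2-normalized embedding. With include_softmax=True,
    dropout is inserted before the affine layer and a softmax classification
    layer is appended for speaker-classification pre-training. With
    pcm_input=True the model takes raw waveforms and computes the fbank
    features in-graph.
    """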

    def __init__(
        self,
        batch_input_shape=(None, NUM_FRAMES, NUM_FBANKS, 1),
        include_softmax=False,
        num_speakers_softmax=None,
        pcm_input=False,
    ):
        if pcm_input:
            batch_input_shape = None
        self.include_softmax = include_softmax
        if self.include_softmax:
            assert num_speakers_softmax > 0
        self.clipped_relu_count = 0
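
        # Two input modes: raw PCM waveforms, with fbank extraction and
        # normalization done in-graph, or precomputed fbank frames of shape
        # (batch, frames, NUM_FBANKS, 1).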
        if pcm_input:
            batch_input_shape = batch_input_shape or (None, None)
            inputs = Input(batch_shape=batch_input_shape, name='raw_inputs')
            x = inputs
            x = Lambda(tf_fbank)(x)
            x = Lambda(lambda x_: tf_normalize(x_, 1, 1e-12))(x)
            x = Lambda(lambda x_: tf.expand_dims(x_, axis=-1))(x)
        else:
            batch_input_shape = batch_input_shape or (None, None, NUM_FBANKS, 1)
            inputs = Input(batch_shape=batch_input_shape, name='input')
            x = inputs

        x = self.cnn_component(x)
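
        # Collapse the frequency and channel axes (4 * 512 = 2048 when
        # NUM_FBANKS is 64) and average the frame-level features over time to
        # obtain one utterance-level vector.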
        x = Reshape((-1, 2048))(x)
        x = Lambda(lambda y: K.mean(y, axis=1), name='average')(x)

        if include_softmax:
            logger.info('Including a Dropout layer to reduce overfitting.')
            x = Dropout(0.5)(x)
        x = Dense(512, name='affine')(x)
        if include_softmax:
            x = Dense(num_speakers_softmax, activation='softmax')(x)
        else:
            x = Lambda(lambda y: K.l2_normalize(y, axis=1), name='ln')(x)
        self.m = Model(inputs, x, name='ResCNN')

    def keras_model(self):
        return self.m

    def get_weights(self):
        w = self.m.get_weights()
        if self.include_softmax:
            # The last two entries are the softmax layer's kernel and bias;
            # drop them so the weights match the embedding-only model.
            w.pop()
            w.pop()
        return w

    def clipped_relu(self, inputs):
        # ReLU clipped at 20; the counter keeps the Lambda layer names unique.
        relu = Lambda(lambda y: K.minimum(K.maximum(y, 0), 20),
                      name=f'clipped_relu_{self.clipped_relu_count}')(inputs)
        self.clipped_relu_count += 1
        return relu

    def identity_block(self, input_tensor, kernel_size, filters, stage, block):
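        """Two same-padding convolutions with an identity shortcut and clipped ReLUs."""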
        conv_name_base = f'res{stage}_{block}_branch'

        x = Conv2D(filters,
                   kernel_size=kernel_size,
                   strides=1,
                   activation=None,
                   padding='same',
                   kernel_initializer='glorot_uniform',
                   kernel_regularizer=regularizers.l2(l=0.0001),
                   name=conv_name_base + '_2a')(input_tensor)
        x = BatchNormalization(name=conv_name_base + '_2a_bn')(x)
        x = self.clipped_relu(x)

        x = Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=1,
            activation=None,
            padding='same',
            kernel_initializer='glorot_uniform',
            kernel_regularizer=regularizers.l2(l=0.0001),
            name=conv_name_base + '_2b',
        )(x)
        x = BatchNormalization(name=conv_name_base + '_2b_bn')(x)
        x = self.clipped_relu(x)

        x = layers.add([x, input_tensor])
        x = self.clipped_relu(x)
        return x

    def conv_and_res_block(self, inp, filters, stage):
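        """Strided 5x5 convolution (stride 2, downsampling) followed by three identity blocks."""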
        conv_name = f'conv{filters}-s'

        o = Conv2D(filters,
                   kernel_size=5,
                   strides=2,
                   activation=None,
                   padding='same',
                   kernel_initializer='glorot_uniform',
                   kernel_regularizer=regularizers.l2(l=0.0001),
                   name=conv_name)(inp)
        o = BatchNormalization(name=conv_name + '_bn')(o)
        o = self.clipped_relu(o)
        for i in range(3):
            o = self.identity_block(o, kernel_size=3, filters=filters, stage=stage, block=i)
        return o

    def cnn_component(self, inp):
        x = self.conv_and_res_block(inp, 64, stage=1)
        x = self.conv_and_res_block(x, 128, stage=2)
        x = self.conv_and_res_block(x, 256, stage=3)
        x = self.conv_and_res_block(x, 512, stage=4)
        return x

    def set_weights(self, w):
        for layer, layer_w in zip(self.m.layers, w):
            layer.set_weights(layer_w)
            logger.info(f'Setting weights for [{layer.name}]...')


def main():
    dsm = DeepSpeakerModel()
    dsm.m.summary()


def _train():
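    # Smoke test: build the embedding model and repeatedly fit the triplet
    # loss on a constant toy batch.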
    dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=False)
    dsm.m.compile(optimizer=Adam(learning_rate=0.01), loss=deep_speaker_loss)
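
    # deep_speaker_loss expects each batch stacked as three equal groups
    # (anchors, positives, negatives), so the toy batch is built from three
    # identical blocks of unit_batch_size examples. The y targets act only as
    # placeholders for the Keras API.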
    unit_batch_size = 20
    negative = np.ones(shape=(unit_batch_size, 32, 64, 4)) * (-1)
    batch = np.vstack((negative, negative, negative))
    x = batch
    y = np.zeros(shape=(len(batch), 512))
    print('Starting to fit...')
    while True:
        print(dsm.m.train_on_batch(x, y))


def _test_checkpoint_compatibility():
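    # Weights saved from the softmax variant should load by name into the
    # embedding-only variant, since the shared layers keep the same names.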
    dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=True, num_speakers_softmax=10)
    dsm.m.save_weights('test.h5')
    dsm = DeepSpeakerModel(batch_input_shape=(None, 32, 64, 4), include_softmax=False)
    dsm.m.load_weights('test.h5', by_name=True)
    os.remove('test.h5')


if __name__ == '__main__':
    _test_checkpoint_compatibility()