File size: 5,455 Bytes
aed64b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
import os

# pylint: disable=E0611,E0401
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
# pylint: disable=E0611,E0401
from tensorflow.keras.optimizers import SGD
from tqdm import tqdm

from deep_speaker.batcher import KerasFormatConverter, LazyTripletBatcher
from deep_speaker.constants import BATCH_SIZE, CHECKPOINTS_SOFTMAX_DIR, CHECKPOINTS_TRIPLET_DIR, NUM_FRAMES, NUM_FBANKS
from deep_speaker.conv_models import DeepSpeakerModel
from deep_speaker.triplet_loss import deep_speaker_loss
from deep_speaker.utils import load_best_checkpoint, ensures_dir

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Otherwise it's just too much logging from Tensorflow...
# NOTE(review): this assignment runs after the tensorflow imports at the top of the
# file; confirm that TensorFlow still honours the variable when it is set post-import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def fit_model(dsm: DeepSpeakerModel, working_dir: str, max_length: int = NUM_FRAMES, batch_size=BATCH_SIZE,
              num_test_batches: int = 200, steps_per_epoch: int = 2000, max_epochs: int = 1000):
    """Train ``dsm.m`` on batches produced by a ``LazyTripletBatcher``.

    A fixed validation set of ``num_test_batches`` batches is built once up front and
    replayed at every epoch, while training batches are drawn on demand with
    ``get_random_batch``. A checkpoint is written under ``CHECKPOINTS_TRIPLET_DIR``
    each time ``val_loss`` improves.

    Args:
        dsm: Model wrapper; the Keras model is read from ``dsm.m`` and must already be compiled.
        working_dir: Directory handed to ``LazyTripletBatcher``.
        max_length: Forwarded to ``LazyTripletBatcher`` as its second argument.
        batch_size: Size requested for every training and validation batch.
        num_test_batches: Number of validation batches to pre-build (previously fixed at 200).
        steps_per_epoch: Training batches consumed per epoch (previously fixed at 2000).
        max_epochs: Total number of epochs (previously fixed at 1000).

    Raises:
        ValueError: If ``num_test_batches`` is smaller than 1.
    """
    if num_test_batches < 1:
        # With no batches the validation generator below would loop forever without yielding.
        raise ValueError(f'num_test_batches must be >= 1, got {num_test_batches}.')

    batcher = LazyTripletBatcher(working_dir, max_length, dsm)

    # build small test set.
    test_batches = [batcher.get_batch_test(batch_size)
                    for _ in tqdm(range(num_test_batches), desc='Build test set')]

    def test_generator():
        # Replay the same pre-built batches endlessly so every epoch validates on identical data.
        while True:
            for bb in test_batches:
                yield bb

    def train_generator():
        # Endless stream of freshly sampled training batches.
        while True:
            yield batcher.get_random_batch(batch_size, is_test=False)

    checkpoint_name = dsm.m.name + '_checkpoint'
    # '{epoch}' is left as a placeholder in the file name passed to ModelCheckpoint.
    checkpoint_filename = os.path.join(CHECKPOINTS_TRIPLET_DIR, checkpoint_name + '_{epoch}.h5')
    checkpoint = ModelCheckpoint(monitor='val_loss', filepath=checkpoint_filename, save_best_only=True)
    dsm.m.fit(x=train_generator(), y=None, steps_per_epoch=steps_per_epoch, shuffle=False,
              epochs=max_epochs, validation_data=test_generator(), validation_steps=len(test_batches),
              callbacks=[checkpoint])


def fit_model_softmax(dsm: DeepSpeakerModel, kx_train, ky_train, kx_test, ky_test,
                      batch_size=BATCH_SIZE, max_epochs=1000, initial_epoch=0):
    """Fit ``dsm.m`` on in-memory train/test arrays for the softmax phase.

    Both data sets are trimmed to a whole number of batches before fitting. Training
    runs from ``initial_epoch`` for at most ``max_epochs`` further epochs, with three
    callbacks that all monitor ``val_accuracy``: early stopping, learning-rate
    reduction, and best-only checkpointing under ``CHECKPOINTS_SOFTMAX_DIR``.
    """
    # '{epoch}' is left as a placeholder in the file name passed to ModelCheckpoint.
    weights_path = os.path.join(CHECKPOINTS_SOFTMAX_DIR, dsm.m.name + '_checkpoint_{epoch}.h5')
    best_model_saver = ModelCheckpoint(filepath=weights_path, monitor='val_accuracy', save_best_only=True)

    # Stop once val_accuracy has failed to gain at least 0.001 (0.1%) for 20 epochs.
    stop_on_plateau = EarlyStopping(monitor='val_accuracy', mode='max', min_delta=0.001,
                                    patience=20, verbose=1)

    # Halve the learning rate after 10 epochs without improvement, never going below 0.0001.
    halve_lr_on_plateau = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=10,
                                            min_lr=0.0001, verbose=1)

    # Drop the trailing samples that do not fill a complete batch.
    usable_train = (len(kx_train) // batch_size) * batch_size
    train_inputs, train_labels = kx_train[:usable_train], ky_train[:usable_train]
    usable_test = (len(kx_test) // batch_size) * batch_size
    test_inputs, test_labels = kx_test[:usable_test], ky_test[:usable_test]

    fit_options = {
        'x': train_inputs,
        'y': train_labels,
        'batch_size': batch_size,
        # Keras counts epochs absolutely, so the end epoch is offset by the starting one.
        'epochs': initial_epoch + max_epochs,
        'initial_epoch': initial_epoch,
        'verbose': 1,
        'shuffle': True,
        'validation_data': (test_inputs, test_labels),
        'callbacks': [stop_on_plateau, halve_lr_on_plateau, best_model_saver],
    }
    dsm.m.fit(**fit_options)


def start_training(working_dir, pre_training_phase=True):
    """Run one training phase, resuming from an existing checkpoint when one is found.

    Both checkpoint directories are created first. Then:

    * ``pre_training_phase=True``: build a model with a softmax head sized to the number
      of speakers reported by ``KerasFormatConverter``, compile it with Adam and sparse
      categorical cross-entropy, reload the checkpoint returned by
      ``load_best_checkpoint`` (if any) and continue from the epoch encoded in its
      file name, then call ``fit_model_softmax``.
    * ``pre_training_phase=False``: build a model without the softmax head, load a
      triplet checkpoint if present, otherwise fall back to a softmax checkpoint loaded
      by layer name, compile with SGD and ``deep_speaker_loss``, then call ``fit_model``.

    Args:
        working_dir: Directory passed to ``KerasFormatConverter`` / ``fit_model``.
        pre_training_phase: Selects the softmax phase (True) or the triplet phase (False).
    """
    ensures_dir(CHECKPOINTS_SOFTMAX_DIR)
    ensures_dir(CHECKPOINTS_TRIPLET_DIR)
    batch_input_shape = [None, NUM_FRAMES, NUM_FBANKS, 1]
    if pre_training_phase:
        logger.info('Softmax pre-training.')
        kc = KerasFormatConverter(working_dir)
        num_speakers_softmax = len(kc.categorical_speakers.speaker_ids)
        dsm = DeepSpeakerModel(batch_input_shape, include_softmax=True, num_speakers_softmax=num_speakers_softmax)
        dsm.m.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        pre_training_checkpoint = load_best_checkpoint(CHECKPOINTS_SOFTMAX_DIR)
        if pre_training_checkpoint is not None:
            # Checkpoints are written as '<model name>_checkpoint_<epoch>.h5' (see
            # fit_model_softmax). Use os.path so the parse does not depend on the path
            # separator or on dots appearing elsewhere in the path.
            checkpoint_stem = os.path.splitext(os.path.basename(pre_training_checkpoint))[0]
            initial_epoch = int(checkpoint_stem.rsplit('_', 1)[-1])
            logger.info(f'Initial epoch is {initial_epoch}.')
            logger.info(f'Loading softmax checkpoint: {pre_training_checkpoint}.')
            dsm.m.load_weights(pre_training_checkpoint)  # latest one.
        else:
            initial_epoch = 0
        fit_model_softmax(dsm, kc.kx_train, kc.ky_train, kc.kx_test, kc.ky_test, initial_epoch=initial_epoch)
    else:
        logger.info('Training with the triplet loss.')
        dsm = DeepSpeakerModel(batch_input_shape, include_softmax=False)
        triplet_checkpoint = load_best_checkpoint(CHECKPOINTS_TRIPLET_DIR)
        pre_training_checkpoint = load_best_checkpoint(CHECKPOINTS_SOFTMAX_DIR)
        if triplet_checkpoint is not None:
            # A triplet checkpoint takes precedence over a softmax one.
            logger.info(f'Loading triplet checkpoint: {triplet_checkpoint}.')
            dsm.m.load_weights(triplet_checkpoint)
        elif pre_training_checkpoint is not None:
            logger.info(f'Loading pre-training checkpoint: {pre_training_checkpoint}.')
            # If `by_name` is True, weights are loaded into layers only if they share the
            # same name. This is useful for fine-tuning or transfer-learning models where
            # some of the layers have changed.
            dsm.m.load_weights(pre_training_checkpoint, by_name=True)
        dsm.m.compile(optimizer=SGD(), loss=deep_speaker_loss)
        fit_model(dsm, working_dir, NUM_FRAMES)