import os
import time

import torch
import torch.multiprocessing
from torch.nn.utils.rnn import pad_sequence
from torch.optim import RAdam
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from Modules.Aligner.Aligner import Aligner
from Modules.Aligner.Reconstructor import Reconstructor
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor


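# Collate helper: pads the text tensors of a batch, stacks the text lengths and speaker embeddings,
# and keeps the codec-encoded audio as a list of variable-length tensors (index 3 is an unused placeholder).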
def collate_and_pad(batch):
    return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
            [datapoint[2] for datapoint in batch],
            None,
            torch.stack([datapoint[4] for datapoint in batch]).squeeze())


def train_loop(train_dataset,
               device,
               save_directory,
               batch_size,
               steps,
               path_to_checkpoint=None,
               fine_tune=False,
               resume=False,
               debug_img_path=None,
               use_reconstruction=True,
               gpu_count=1,
               rank=0,
               steps_per_checkpoint=None):
    """
    Args:
        train_dataset: PyTorch Dataset object for train data
        device: Device to put the loaded tensors on
        save_directory: Where to save the checkpoints
        batch_size: How many elements should be loaded at once
        steps: How many steps to train
        path_to_checkpoint: reloads a checkpoint to continue training from there
        fine_tune: whether to load everything from a checkpoint, or only the model parameters
        resume: whether to resume from the most recent checkpoint
        debug_img_path: where to put images of the training progress if desired
        use_reconstruction: whether to use the auxiliary reconstruction procedure/loss, which can make the alignment sharper
        gpu_count: how many GPUs are used; values above 1 trigger distributed training
        rank: rank of this process when training on multiple GPUs; checkpoints are only written by rank 0
        steps_per_checkpoint: how many steps to train between checkpoints; defaults to roughly one pass over the dataset
    """
    os.makedirs(save_directory, exist_ok=True)
    torch.multiprocessing.set_sharing_strategy('file_system')
    torch.multiprocessing.set_start_method('spawn', force=True)

    if steps_per_checkpoint is None:
        steps_per_checkpoint = len(train_dataset) // batch_size
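    # the codec preprocessor decodes the stored EnCodec indexes back into waveforms,
    # which the spectrogram extractor then turns into the mel features the aligner is trained on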
    ap = CodecAudioPreprocessor(input_sr=-1, device=device)
    spectrogram_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)

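    # the aligner and the tiny reconstruction decoder are optimized separately with RAdam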
    asr_model = Aligner().to(device)
    optim_asr = RAdam(asr_model.parameters(), lr=0.0001)

    tiny_tts = Reconstructor().to(device)
    optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001)

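    # with more than one GPU, move both models to this process's device and wrap them for distributed training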
    if gpu_count > 1:
        asr_model.to(rank)
        tiny_tts.to(rank)
        asr_model = torch.nn.parallel.DistributedDataParallel(
            asr_model,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=True,
        ).module
        tiny_tts = torch.nn.parallel.DistributedDataParallel(
            tiny_tts,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=True,
        ).module
        torch.distributed.barrier()
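    # random batching; the final incomplete batch is dropped and padding is done by collate_and_pad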
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=0,
                              batch_sampler=batch_sampler_train,
                              prefetch_factor=None,
                              collate_fn=collate_and_pad)

    step_counter = 0
    loss_sum = list()

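    # resuming always continues from the most recent checkpoint in save_directory with full optimizer state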
    if resume:
        previous_checkpoint = os.path.join(save_directory, "aligner.pt")
        path_to_checkpoint = previous_checkpoint
        fine_tune = False

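    # reload the model weights; unless fine-tuning, also restore both optimizers and the step counter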
    if path_to_checkpoint is not None:
        check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
        asr_model.load_state_dict(check_dict["asr_model"])
        tiny_tts.load_state_dict(check_dict["tts_model"])
        if not fine_tune:
            optim_asr.load_state_dict(check_dict["optimizer"])
            optim_tts.load_state_dict(check_dict["tts_optimizer"])
            step_counter = check_dict["step_counter"]
            if step_counter > steps:
                print("Desired steps already reached in loaded checkpoint.")
                return

    start_time = time.time()

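    # train until the step budget is reached; progress is reported and checkpointed every steps_per_checkpoint steps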
    while True:
        asr_model.train()
        tiny_tts.train()
        for batch in tqdm(train_loader):
            tokens = batch[0].to(device)
            tokens_len = batch[1].to(device)
            speaker_embeddings = batch[4].to(device)

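            # decode the codec indexes of each datapoint back to audio and compute its mel spectrogram on the fly;
            # no gradients are needed for this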
            mels = list()
            mel_lengths = list()
            for datapoint in batch[2]:
                with torch.inference_mode():
                    speech = ap.indexes_to_audio(datapoint.int().to(device))
                    mel = spectrogram_extractor.audio_to_mel_spec_tensor(speech, explicit_sampling_rate=16000).transpose(0, 1).cpu()
                speech_len = torch.LongTensor([len(mel)])
                mels.append(mel.clone())
                mel_lengths.append(speech_len)
            mel = pad_sequence(mels, batch_first=True).to(device)
            mel_len = torch.stack(mel_lengths).squeeze(1).to(device)

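            # the aligner predicts a distribution over text tokens for every spectrogram frame;
            # the CTC loss aligns these frame-wise predictions with the token sequence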
            pred = asr_model(mel, mel_len)

            ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2),
                                          tokens,
                                          mel_len,
                                          tokens_len)

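            # the auxiliary reconstruction loss is ramped up linearly to a cap of 0.1 over the first 1000 steps
            # and encourages sharper alignments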
            if use_reconstruction:
                speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1)
                tts_lambda = min([0.1, step_counter / 10000])
                reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1),
                                               lens=mel_len,
                                               ys=mel) * tts_lambda
                loss = ctc_loss + reconstruction_loss
            else:
                loss = ctc_loss

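            # zero gradients, backpropagate, clip to a max norm of 1.0 and step;
            # the reconstruction optimizer only participates when the auxiliary loss is enabled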
            optim_asr.zero_grad()
            if use_reconstruction:
                optim_tts.zero_grad()
            if gpu_count > 1:
                torch.distributed.barrier()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
            if use_reconstruction:
                torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0)
            optim_asr.step()
            if use_reconstruction:
                optim_tts.step()

            loss_sum.append(loss.item())
            step_counter += 1

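            # only rank 0 writes the checkpoint, prints progress and optionally saves a debug image of the current alignment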
            if step_counter % steps_per_checkpoint == 0 and rank == 0:
                asr_model.eval()
                torch.save({
                    "asr_model"    : asr_model.state_dict(),
                    "optimizer"    : optim_asr.state_dict(),
                    "tts_model"    : tiny_tts.state_dict(),
                    "tts_optimizer": optim_tts.state_dict(),
                    "step_counter" : step_counter,
                }, os.path.join(save_directory, "aligner.pt"))
                print("Total Loss: {}".format(round(sum(loss_sum) / len(loss_sum), 3)))
                print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
                print("Steps: {}".format(step_counter))
                if debug_img_path is not None:
                    asr_model.inference(features=mel[0][:mel_len[0]],
                                        tokens=tokens[0][:tokens_len[0]],
                                        save_img_for_debug=debug_img_path + f"/{step_counter}.png",
                                        train=True)
                asr_model.train()
                loss_sum = list()

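            # stop only on a checkpoint boundary, so the most recent training progress has already been saved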
            if step_counter > steps and step_counter % steps_per_checkpoint == 0:
                return