|
import os |
|
import librosa |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torchaudio |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
from hparams import Hparams |
|
from model_cnn import Model |
|
from dataset import MyDataset |
|
|
|
|
|
args = Hparams.args  # shared hyper-parameter dict (see hparams.py)

device = args['device']  # compute device string, e.g. 'cpu' or 'cuda'

split = 'train'  # default dataset split name used by this script

# Number of tone classes; class targets are encoded as word*5 + tone below.
tone_class = 5

NUM_EPOCHS = 100  # default number of training epochs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def move_data_to_device(data, device):
    """Move every tensor in *data* to *device*, keeping list order.

    Args:
        data: iterable of items, typically a batch tuple of tensors.
        device: target torch device (string or torch.device).

    Returns:
        list the same length as *data* with each tensor moved to *device*.
        Non-tensor items are kept as-is. (The original silently dropped
        non-tensor items, which broke tuple unpacking at the call sites
        whenever a batch contained a non-tensor element.)
    """
    ret = []
    for item in data:
        if isinstance(item, torch.Tensor):
            ret.append(item.to(device))
        else:
            # Pass non-tensor items through unchanged instead of dropping them.
            ret.append(item)
    return ret
|
|
|
def collate_fn(batch):
    """Pad a list of (mel, f0, word, tone) samples and stack them into batch tensors.

    Mel spectrograms and f0 tracks are zero-padded along time to the longest
    length found in the batch (with a floor of 1600 frames); word and tone
    label sequences are zero-padded to 50 entries each.

    Returns:
        tuple (inp, f0, word, tone) of stacked, padded tensors.
    """
    pad = torch.nn.functional.pad

    # Longest first-axis length across every field of every sample,
    # floored at 1600 frames.
    frame_cap = 1600
    for mel_s, f0_s, word_s, tone_s in batch:
        frame_cap = max(frame_cap, mel_s.shape[0], f0_s.shape[0],
                        word_s.shape[0], tone_s.shape[0])

    mels, f0s, words, tones = [], [], [], []
    for mel_s, f0_s, word_s, tone_s in batch:
        # mel is 2-D (time, mels): pad only the time axis.
        mels.append(pad(mel_s, (0, 0, 0, frame_cap - mel_s.shape[0]),
                        mode='constant', value=0))
        # f0 is 1-D (time,): pad to the same frame cap.
        f0s.append(pad(f0_s, (0, frame_cap - f0_s.shape[0]),
                       mode='constant', value=0))
        # Label sequences are padded to a fixed length of 50.
        words.append(pad(word_s, (0, 50 - word_s.shape[0]),
                         mode='constant', value=0))
        tones.append(pad(tone_s, (0, 50 - tone_s.shape[0]),
                         mode='constant', value=0))

    return torch.stack(mels), torch.stack(f0s), torch.stack(words), torch.stack(tones)
|
|
|
def get_data_loader(split, args):
    """Build a shuffling DataLoader over MyDataset for the given split.

    Args:
        split: dataset split name (e.g. 'train').
        args: hyper-parameter dict supplying dataset_root, sampling_rate,
            sample_length, frame_size, batch_size and num_workers.

    Returns:
        torch DataLoader yielding batches assembled by collate_fn.
    """
    # Renamed from `Dataset`, which shadowed torch.utils.data.Dataset.
    dataset = MyDataset(
        dataset_root=args['dataset_root'],
        split=split,
        sampling_rate=args['sampling_rate'],
        sample_length=args['sample_length'],
        frame_size=args['frame_size'],
    )
    # HACK: truncate to the first 32 items — looks like a debugging leftover;
    # remove these two lines for real training runs.
    dataset.dataset_index = dataset.dataset_index[:32]
    dataset.index = dataset.index[:32]

    data_loader = DataLoader(
        dataset,
        batch_size=args['batch_size'],
        num_workers=args['num_workers'],
        pin_memory=True,
        shuffle=True,
        collate_fn=collate_fn,
    )
    return data_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_sequence(seq):
    """Collapse runs of consecutive equal elements into a single occurrence.

    E.g. [1, 1, 2, 2, 1] -> [1, 2, 1]. Used for greedy CTC decoding, where
    repeated frame predictions represent one emitted symbol.
    """
    collapsed = []
    previous = object()  # sentinel that never equals a real element
    for item in seq:
        if item != previous:
            collapsed.append(item)
            previous = item
    return collapsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ASR_Model:
    '''
    Train a CTC-based syllable/tone recognizer and run inference on audio.

    Each non-blank output class encodes a (pinyin syllable, tone) pair as
    ``word * 5 + tone``; class 0 is the CTC blank.
    '''

    def __init__(self, device="cpu", model_path=None, pinyin_path='pinyin.txt'):
        '''
        Args:
            device: torch device string the model should live on.
            model_path: optional path to a pickled model (as saved by fit);
                when given it is loaded instead of a fresh initialization.
            pinyin_path: text file with one pinyin syllable per line; the
                line index becomes that syllable's integer id.
        '''
        self.device = device

        # Map each pinyin syllable to its line index in the vocabulary file.
        self.pinyin = {}
        with open(pinyin_path, 'r') as f:
            for i, line in enumerate(f):
                self.pinyin[line.replace('\n', '')] = i

        # Inverse mapping: integer id -> syllable string.
        self.idx2char = {idx: char for char, idx in self.pinyin.items()}

        # Total output classes for the CTC head.
        # NOTE(review): presumably (syllables * 5 tones) + blank = 2036 — confirm.
        num_class = 2036

        self.sampling_rate = args['sampling_rate']
        if model_path is not None:
            # Load the full pickled model saved by fit(); skip the redundant
            # fresh construction the original did before overwriting it.
            self.model = torch.load(model_path)
            print('Model loaded.')
        else:
            self.model = Model(syllable_class=num_class)
            print('Model initialized.')
        self.model.to(device)

    def fit(self, args, NUM_EPOCHS=100):
        '''
        Train with CTC loss and checkpoint the model with the lowest
        validation loss to ``args['save_model_dir']/best_model.pth``.

        Args:
            args: hyper-parameter dict (save_model_dir, batch_size, ...).
            NUM_EPOCHS: number of training epochs.
        '''
        save_model_dir = args['save_model_dir']
        # makedirs(exist_ok=True) also creates missing parents and is
        # race-free, unlike the original exists()/mkdir() pair.
        os.makedirs(save_model_dir, exist_ok=True)

        loss_fn = nn.CTCLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)

        train_loader = get_data_loader(split='train', args=args)
        # NOTE(review): validation also reads the 'train' split — confirm
        # whether a separate 'valid' split was intended.
        valid_loader = get_data_loader(split='train', args=args)

        print('Start training...')
        # float('inf') so the first validation result always checkpoints
        # (the original 10000 sentinel could skip early saves).
        min_valid_loss = float('inf')

        for epoch in range(NUM_EPOCHS):
            # Re-enable train mode each epoch: validation below switches to
            # eval(), and the original never switched back, so every epoch
            # after the first silently trained in eval mode.
            self.model.train()
            for idx, data in enumerate(train_loader):
                mel, f0, word, tone = move_data_to_device(data, device)
                # Frames whose first mel bin is exactly 0 are padding;
                # count the real frames per sample.
                input_length = (mel[:, :, 0] != 0.0).sum(axis=1)

                mel = mel.unsqueeze(1)  # add channel dim for the CNN

                output = self.model(mel)
                # CTCLoss expects (time, batch, classes).
                output = output.permute(1, 0, 2)

                # The model downsamples time by a factor of 4.
                # (Dropped the original's no-op move_data_to_device call
                # whose return value was discarded.)
                output_len = input_length // 4

                # Padded label positions are 0; count the real labels.
                target_len = (tone != 0).sum(axis=1)

                # Joint (syllable, tone) target id; 0 stays the CTC blank.
                target = word * 5 + tone

                loss = loss_fn(output, target, output_len, target_len)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if idx % 100 == 0:
                    print(f'Epoch {epoch+1},Iteration {idx+1}, Loss: {loss.item()}')

            # Validation pass (mirrors the training forward path).
            self.model.eval()
            with torch.no_grad():
                losses = []
                for idx, data in enumerate(valid_loader):
                    mel, f0, word, tone = move_data_to_device(data, device)
                    input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
                    mel = mel.unsqueeze(1)

                    out = self.model(mel)
                    out = out.permute(1, 0, 2)

                    output_len = input_length // 4
                    target_len = (tone != 0).sum(axis=1)
                    target = word * 5 + tone

                    loss = loss_fn(out, target, output_len, target_len)
                    losses.append(loss.item())
                loss = np.mean(losses)

            # Checkpoint on improvement.
            if loss < min_valid_loss:
                min_valid_loss = loss
                target_model_path = save_model_dir + '/best_model.pth'
                torch.save(self.model, target_model_path)

    def to_pinyin(self, num):
        '''Decode a class id into a (pinyin, tone) pair; id 0 (blank) -> None.'''
        if num == 0:
            return None
        # Non-blank ids encode word*5 + tone with tone in 1..5.
        pinyin, tone = self.idx2char[(num - 1) // 5], (num - 1) % 5 + 1
        return pinyin, tone

    def getsentence(self, words):
        '''Map a tensor of syllable ids back to their pinyin strings.'''
        words = words.tolist()
        return [self.idx2char[int(word)] for word in words]

    def predict(self, audio_fp):
        """Predict (pinyin, tone) pairs for a given audio file."""
        # Build a mel spectrogram matching the training front-end.
        waveform, sample_rate = torchaudio.load(audio_fp)
        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
        mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sampling_rate, n_fft=2048, hop_length=100, n_mels=256)(waveform)
        mel_spec = torch.mean(mel_spec, 0)  # average channels to mono

        # NOTE(review): f0 is computed but never fed to the model — confirm
        # whether it was meant to be an input.
        waveform, sr = librosa.load(audio_fp, sr=self.sampling_rate)
        f0 = torch.from_numpy(librosa.yin(waveform, fmin=50, fmax=550, hop_length=100))

        # (frames, mels) -> (batch=1, channel=1, frames, mels). mel_spec.T is
        # already a tensor, so avoid the torch.tensor() copy-of-tensor warning.
        mel = mel_spec.T.unsqueeze(0).unsqueeze(0)

        self.model.eval()
        with torch.no_grad():
            output = self.model(mel.to(self.device))

        # Greedy CTC decode: argmax per frame, collapse repeats, drop blanks.
        seq = process_sequence(output[0].cpu().numpy().argmax(-1))
        result = [self.to_pinyin(c) for c in seq if c != 0]

        return result
|
|
|
|
|
|
|
|