import os
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
from hparams import Hparams
from model_cnn import Model
from dataset import MyDataset
args = Hparams.args
device = args['device']
split = 'train'
tone_class = 5  # Mandarin tones 1-4 plus the neutral tone
NUM_EPOCHS = 100
# num_class = len(train_loader.dataset.pinyin) * tone_class + 1
# model = Model(syllable_class = num_class)
# model.to(device)
def move_data_to_device(data, device):
    """Move every tensor in `data` to `device`; non-tensor items pass through unchanged."""
    ret = []
    for i in data:
        if isinstance(i, torch.Tensor):
            ret.append(i.to(device))
        else:
            ret.append(i)
    return ret
def collate_fn(batch):
    # Pad every sample in the batch to a common length so the batch can be stacked.
    # Each sample is (mel, f0, word, tone): mel and f0 are padded along the time
    # axis, while the word/tone label sequences are padded to a fixed length of 50.
    inp = []
    f0 = []
    word = []
    tone = []
    max_frame_num = 1600
    for sample in batch:
        # Only mel (sample[0]) and f0 (sample[1]) live on the time axis;
        # the label sequences do not contribute to the frame count.
        max_frame_num = max(max_frame_num, sample[0].shape[0], sample[1].shape[0])
    for sample in batch:
        inp.append(
            torch.nn.functional.pad(sample[0], (0, 0, 0, max_frame_num - sample[0].shape[0]), mode='constant', value=0))
        f0.append(
            torch.nn.functional.pad(sample[1], (0, max_frame_num - sample[1].shape[0]), mode='constant', value=0))
        word.append(
            torch.nn.functional.pad(sample[2], (0, 50 - sample[2].shape[0]), mode='constant', value=0))
        tone.append(
            torch.nn.functional.pad(sample[3], (0, 50 - sample[3].shape[0]), mode='constant', value=0))
    inp = torch.stack(inp)
    f0 = torch.stack(f0)
    word = torch.stack(word)
    tone = torch.stack(tone)
    return inp, f0, word, tone
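# Shape sketch for one collated batch (batch_size=8 and n_mels=256 are taken
# from the debug notes elsewhere in this file and may differ in your Hparams;
# assumes no clip exceeds 1600 frames):
#   inp:  (8, 1600, 256)  padded mel spectrograms
#   f0:   (8, 1600)       padded f0 contours
#   word: (8, 50)         padded syllable-index sequences
#   tone: (8, 50)         padded tone-index sequences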
def get_data_loader(split, args):
    dataset = MyDataset(  # renamed from `Dataset` to avoid shadowing torch.utils.data.Dataset
        dataset_root=args['dataset_root'],
        split=split,
        sampling_rate=args['sampling_rate'],
        sample_length=args['sample_length'],
        frame_size=args['frame_size'],
    )
    # Debugging aid: truncate the dataset to its first 32 samples.
    dataset.dataset_index = dataset.dataset_index[:32]
    dataset.index = dataset.index[:32]
    data_loader = DataLoader(
        dataset,
        batch_size=args['batch_size'],
        num_workers=args['num_workers'],
        pin_memory=True,
        shuffle=True,  # shuffle because audio files recorded by the same speaker are stored in the same folder
        collate_fn=collate_fn,
    )
    return data_loader
# train_loader = get_data_loader(split='train', args=Hparams.args)
# idx2char = { idx:char for char,idx in train_loader.dataset.pinyin.items()}
# def to_pinyin(num):
# if num==0:
# return
# pinyin,tone = idx2char[(num-1)//5],(num-1)%5+1
# return pinyin,tone
def process_sequence(seq):
    """Collapse consecutive repeated symbols, as in greedy CTC decoding."""
    ret = []
    for w in seq:
        if len(ret) == 0 or ret[-1] != w:
            ret.append(w)
    return ret
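# e.g. process_sequence([3, 3, 0, 7, 7]) -> [3, 0, 7]; the remaining CTC
# blanks (0) are dropped later, in ASR_Model.predict().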
# def train(NUM_EPOCHS = 100):
# optimizer = optim.Adam(model.parameters(), lr=0.002)
# criterion = nn.CrossEntropyLoss()#(ignore_index=0)
# device = Hparams.args['device']
# for epoch in range(NUM_EPOCHS):
# for idx, data in enumerate(train_loader):
# mel, target, len_mel, len_tag = move_data_to_device(data, device)
# # break
# # input_length = (mel[:,:,0]!=0.0).sum(axis=1)
# # print(mel.shape, f0.shape, word.shape, tone.shape) # torch.Size([8, 1600, 256])
# mel = mel.unsqueeze(1)
# output = model(mel)#[32, 400, 1000]
# # target[:,:len_tag].view(-1)
# # output[:,:len_tag,:].view(-1, num_classes)
# # output_len = input_length//4
# # move_data_to_device(output_len, Hparams.args['device'])
# loss = criterion(output.view(-1, num_class), target.view(-1).long())
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# # if(idx%100==0):
# # print(f'Epoch {epoch+1},Iteration {idx+1}, Loss: {loss.item()}')
# print(f'Epoch {epoch+1}, Loss: {loss.item()}')
class ASR_Model:
    '''
    Main class for training the model and making predictions.
    '''
    def __init__(self, device="cpu", model_path=None, pinyin_path='pinyin.txt'):
        # Initialize model
        self.device = device
        self.pinyin = {}  # pinyin syllable -> integer index, read from pinyin_path
        with open(pinyin_path, 'r') as f:
            for i, l in enumerate(f.readlines()):
                self.pinyin[l.replace('\n', '')] = i
        self.idx2char = {idx: char for char, idx in self.pinyin.items()}
        # Presumably len(self.pinyin) * tone_class + 1, i.e. 407 syllables * 5 tones + 1 CTC blank.
        num_class = 2036
        self.model = Model(syllable_class=num_class).to(self.device)
        self.sampling_rate = args['sampling_rate']
        if model_path is not None:
            self.model = torch.load(model_path, map_location=self.device)
            print('Model loaded.')
        else:
            print('Model initialized.')
        self.model.to(device)
    def fit(self, args, NUM_EPOCHS=100):
        # Set paths
        save_model_dir = args['save_model_dir']
        if not os.path.exists(save_model_dir):
            os.makedirs(save_model_dir)
        loss_fn = nn.CTCLoss()  # expects log-probabilities shaped (T, N, C); the blank index defaults to 0
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        train_loader = get_data_loader(split='train', args=args)
        valid_loader = get_data_loader(split='train', args=args)  # NOTE: currently validates on the train split
        # Start training
        print('Start training...')
        min_valid_loss = float('inf')
        for epoch in range(NUM_EPOCHS):
            self.model.train()  # re-enable training mode after the previous epoch's validation
            for idx, data in enumerate(train_loader):
                mel, f0, word, tone = move_data_to_device(data, self.device)
                # Recover the true (unpadded) frame count per clip from the zero padding.
                input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
                mel = mel.unsqueeze(1)  # add a channel dimension for the CNN
                output = self.model(mel)
                output = output.permute(1, 0, 2)  # (N, T, C) -> (T, N, C) for CTC
                output_len = (input_length // 4).to(self.device)  # the model downsamples time by a factor of 4
                target_len = (tone != 0).sum(axis=1)
                # Joint syllable-and-tone label; 0 is reserved for padding/blank.
                target = word * 5 + tone
                loss = loss_fn(output, target, output_len, target_len)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if idx % 100 == 0:
                    print(f'Epoch {epoch+1}, Iteration {idx+1}, Loss: {loss.item()}')
            # Validation
            self.model.eval()
            with torch.no_grad():
                losses = []
                for idx, data in enumerate(valid_loader):
                    mel, f0, word, tone = move_data_to_device(data, self.device)
                    input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
                    mel = mel.unsqueeze(1)
                    out = self.model(mel)
                    out = out.permute(1, 0, 2)
                    output_len = (input_length // 4).to(self.device)
                    target_len = (tone != 0).sum(axis=1)
                    target = word * 5 + tone
                    loss = loss_fn(out, target, output_len, target_len)
                    losses.append(loss.item())
                loss = np.mean(losses)
                print(f'Epoch {epoch+1}, Validation Loss: {loss}')
            # Save the best model
            if loss < min_valid_loss:
                min_valid_loss = loss
                target_model_path = os.path.join(save_model_dir, 'best_model.pth')
                torch.save(self.model, target_model_path)
    def to_pinyin(self, num):
        """Decode a joint class index back to (pinyin, tone); index 0 is the CTC blank."""
        if num == 0:
            return
        pinyin, tone = self.idx2char[(num - 1) // 5], (num - 1) % 5 + 1
        return pinyin, tone
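    # Decoding sketch: assuming pinyin.txt begins with 'a', 'ai', ..., so that
    # idx2char = {0: 'a', 1: 'ai', ...}, to_pinyin(7) returns ('ai', 2):
    # syllable index (7-1)//5 = 1 and tone (7-1)%5 + 1 = 2.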
    def getsentence(self, words):
        """Map a tensor of syllable indices back to pinyin strings."""
        words = words.tolist()
        return [self.idx2char[int(word)] for word in words]
    def predict(self, audio_fp):
        """Predict the (pinyin, tone) sequence for a given audio file."""
        waveform, sample_rate = torchaudio.load(audio_fp)
        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
        mel_spec = torchaudio.transforms.MelSpectrogram(sample_rate=self.sampling_rate, n_fft=2048, hop_length=100, n_mels=256)(waveform)
        mel_spec = torch.mean(mel_spec, 0)  # average the channels down to mono
        # f0 is extracted for parity with training but is not consumed by the CNN model.
        waveform, sr = librosa.load(audio_fp, sr=self.sampling_rate)
        f0 = torch.from_numpy(librosa.yin(waveform, fmin=50, fmax=550, hop_length=100))
        mel = mel_spec.T.unsqueeze(0).unsqueeze(0)  # (1, 1, T, n_mels)
        self.model.eval()
        with torch.no_grad():
            output = self.model(mel.to(self.device))
        # Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks.
        seq = process_sequence(output[0].cpu().numpy().argmax(-1))
        result = [self.to_pinyin(c) for c in seq if c != 0]
        return result
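# Example usage (paths and filenames below are hypothetical):
# asr = ASR_Model(device=args['device'])
# asr.fit(args)  # trains and saves args['save_model_dir'] + '/best_model.pth'
# print(asr.predict('example.wav'))  # -> e.g. [('ni', 3), ('hao', 3)]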