import json
import math
import os
import time

import librosa
import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


def move_data_to_device(data, device):
    '''Move every tensor in `data` to `device`; non-tensor items are skipped.'''
    ret = []
    for i in data:
        if isinstance(i, torch.Tensor):
            ret.append(i.to(device))
    return ret
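
# Usage sketch (illustrative only; the device choice below is an assumption,
# not taken from the original training loop):
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   inp, f0, word, tone = move_data_to_device(batch, device)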


def read_content(filepath):
    '''
    Read the content file for characters, pinyin and tones.

    return:
        dict: {index: [characters, pinyin, tones]}
        exp. {'SSB00050001': ['你 好 ', 'ni3 hao3 ', '3 3 ']}
    '''
    res = {}
    with open(filepath, 'r') as f:
        lines = f.readlines()
        for l in lines:
            l = l.replace('\n', ' ').replace('\t', ' ')
            if len(l.strip()) == 0:
                continue
            tmp = l.split(' ')
            # The first field is the wav file name, e.g. 'SSB00050001.wav';
            # strip the '.wav' extension to get the utterance index.
            number = tmp[0][:-4]
            s = ''
            pinyin = ''
            tones = ''
            # The remaining fields alternate between characters (odd positions)
            # and tone-annotated pinyin (even positions).
            for i in range(1, len(tmp)):
                if len(tmp[i]) == 0:
                    continue
                if i % 2 == 0:
                    pinyin += tmp[i] + ' '
                    tones += tmp[i][-1] + ' '
                else:
                    s += tmp[i] + ' '
            res[number] = [s, pinyin, tones]
    return res
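
# Example (hedged sketch): a content.txt line in the AISHELL-3 layout looks like
#   'SSB00050001.wav\t你 ni3 好 hao3'
# and read_content() would map it to
#   {'SSB00050001': ['你 好 ', 'ni3 hao3 ', '3 3 ']}
# (the utterance id and text above are made up for illustration).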


def read_dataset_index(filepath='/kaggle/input/paddle-speech/AISHELL-3/train'):
    '''
    Get every audio file's index and file path, and read content.txt for the
    corresponding words, pinyin, tones and durations.

    return dataframe:
        ['index', 'filepath', 'word', 'pinyin', 'tone', 'duration']

    5 tones in total, 5 represents the neutral tone.
    '''
    features = read_content(os.path.join(filepath, 'content.txt'))

    start_time = time.time()
    count = 0

    # Pre-computed utterance durations (seconds), one 'index duration' pair per line.
    durations = {}
    with open('/kaggle/input/durations/durations.txt', 'r') as f:
        lines = f.readlines()
        for l in lines:
            tmp = (l.replace('\n', '')).split(' ')
            if len(tmp) >= 2:
                durations[tmp[0]] = float(tmp[1])

    audio_path = os.path.join(filepath, 'wav')
    indexes = []
    for root, dirs, files in os.walk(audio_path):
        for f in files:
            if f.endswith('.wav'):
                count += 1
                index = f[:-4]
                # The speaker directory (e.g. 'SSB0005') is the utterance id minus its last 4 digits.
                wav_path = os.path.join(audio_path, index[:-4], f)
                word, py, tone = features[index]
                du = durations[index]
                indexes.append((index, wav_path, word, py, tone, du))

    end_time = time.time()
    print('#wav files read:', count)
    print('read dataset index time:', end_time - start_time)

    return pd.DataFrame.from_records(
        indexes, columns=['index', 'filepath', 'word', 'pinyin', 'tone', 'duration'])
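
# Usage sketch (paths are the Kaggle defaults above; adjust them if the dataset
# is mounted elsewhere):
#   df = read_dataset_index('/kaggle/input/paddle-speech/AISHELL-3/train')
#   print(df.head())   # columns: index, filepath, word, pinyin, tone, duration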


def collate_fn(batch):
    '''
    Pad every sample in the batch to a common length and stack:
    mel spectrograms and f0 are padded along the frame axis,
    pinyin/tone id sequences are padded to a fixed length of 50.
    '''
    inp = []
    f0 = []
    word = []
    tone = []
    max_frame_num = 1600
    for sample in batch:
        max_frame_num = max(max_frame_num, sample[0].shape[0], sample[1].shape[0])
    for sample in batch:
        inp.append(
            torch.nn.functional.pad(sample[0], (0, 0, 0, max_frame_num - sample[0].shape[0]), mode='constant', value=0))
        f0.append(
            torch.nn.functional.pad(sample[1], (0, max_frame_num - sample[1].shape[0]), mode='constant', value=0))
        word.append(
            torch.nn.functional.pad(sample[2], (0, 50 - sample[2].shape[0]), mode='constant', value=0))
        tone.append(
            torch.nn.functional.pad(sample[3], (0, 50 - sample[3].shape[0]), mode='constant', value=0))
    inp = torch.stack(inp)
    f0 = torch.stack(f0)
    word = torch.stack(word)
    tone = torch.stack(tone)

    return inp, f0, word, tone
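
# Shape sketch (hedged; the lengths below are illustrative): each sample is
# (mel [T, 256], f0 [T], word ids [<=50], tone ids [<=50]), so for two samples
# with T <= 1600,
#   inp, f0, word, tone = collate_fn([sample_a, sample_b])
# gives inp [2, 1600, 256], f0 [2, 1600], word [2, 50], tone [2, 50].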


def get_data_loader(split, args):
    dataset = MyDataset(
        dataset_root=args['dataset_root'],
        split=split,
        sampling_rate=args['sampling_rate'],
        sample_length=args['sample_length'],
        frame_size=args['frame_size'],
    )
    # Keep only the first 32 utterances/segments (small debugging subset).
    dataset.dataset_index = dataset.dataset_index[:32]
    dataset.index = dataset.index[:32]
    data_loader = DataLoader(
        dataset,
        batch_size=args['batch_size'],
        num_workers=args['num_workers'],
        pin_memory=True,
        shuffle=True,
        collate_fn=collate_fn,
    )
    return data_loader
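
# Example args dict (a sketch: the numeric values are assumptions chosen to be
# consistent with the hop length of 100 samples used below, not the original
# training config):
#   args = {
#       'dataset_root': '/kaggle/input/paddle-speech',
#       'sampling_rate': 16000,
#       'sample_length': 10,        # seconds per training segment
#       'frame_size': 100 / 16000,  # seconds per frame (hop_length / sampling_rate)
#       'batch_size': 4,
#       'num_workers': 2,
#   }
#   train_loader = get_data_loader('train', args)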


class MyDataset(Dataset):
    def __init__(self, dataset_root, split, sampling_rate, sample_length, frame_size):
        self.dataset_root = dataset_root
        self.split = split
        self.sampling_rate = sampling_rate
        self.sample_length = sample_length
        self.frame_size = frame_size
        self.frame_per_sec = int(1 / self.frame_size)

        self.dataset_index = read_dataset_index(os.path.join(self.dataset_root, 'AISHELL-3', split))

        self.duration = {}
        self.index = self.index_data()

        # Map each pinyin syllable (without its tone digit) to an integer id,
        # one syllable per line in pinyin.txt.
        self.pinyin = {}
        with open('/kaggle/input/pinyin-encode/pinyin.txt', 'r') as f:
            lines = f.readlines()
            for i, l in enumerate(lines):
                self.pinyin[l.replace('\n', '')] = i

    def index_data(self):
        '''
        Prepare the index for the dataset, i.e., the audio file name and starting time of each sample.

        Go through self.dataset_index to get each utterance's duration, then split the
        utterance into segments of self.sample_length seconds.
        '''
        index = []
        for idx, row in self.dataset_index.iterrows():
            duration = row['duration']
            num_seg = math.ceil(duration / self.sample_length)
            for i in range(num_seg):
                index.append([idx, i * self.sample_length])
            self.duration[row['index']] = row['duration']

        return index
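
    # Worked example (hedged; the numbers are illustrative): with sample_length = 10 s,
    # a 23.5 s utterance yields ceil(23.5 / 10) = 3 segments starting at 0, 10 and
    # 20 seconds, each later clipped/padded to a fixed number of frames.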

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        '''
        int idx: index into self.index (not the utterance id, e.g. SSB00050001)

        return mel spectrogram, fundamental frequency (librosa.yin), word ids, tone ids
        '''
        audio_fn, start_sec = self.index[idx]
        end_sec = start_sec + self.sample_length

        audio_fp = self.dataset_index.loc[audio_fn, 'filepath']

        # Mel spectrogram: resample, compute a 256-bin mel spectrogram with a hop
        # of 100 samples, then average over channels.
        waveform, sample_rate = torchaudio.load(audio_fp)
        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
        mel_spec = torchaudio.transforms.MelSpectrogram(sample_rate=self.sampling_rate, n_fft=2048, hop_length=100, n_mels=256)(waveform)
        mel_spec = torch.mean(mel_spec, 0)

        # Fundamental frequency with the same hop length, so f0 frames align with mel frames.
        waveform, sr = librosa.load(audio_fp, sr=self.sampling_rate)
        f0 = torch.from_numpy(librosa.yin(waveform, fmin=50, fmax=550, hop_length=100))

        # Encode pinyin syllables (tone digit stripped) as integer ids, adding
        # unseen syllables to the vocabulary on the fly.
        words = self.dataset_index.loc[audio_fn, 'pinyin']
        w = words.split(' ')
        word_roll = []
        for i in range(0, len(w)):
            if len(w[i]) != 0:
                if self.pinyin.get(w[i][0:-1]) is None:
                    self.pinyin[w[i][0:-1]] = len(self.pinyin)
                word_roll.append(self.pinyin[w[i][0:-1]])
        tones = self.dataset_index.loc[audio_fn, 'tone']
        t = tones.split(' ')
        tone_roll = []
        for tone in t:
            if len(tone) != 0:
                tone_roll.append(int(tone))

        # Clip both features to the segment's frame range; f0 is indexed in
        # frames, not seconds, so convert the start time first.
        start_frame = int(start_sec * self.frame_per_sec)
        end_frame = start_frame + 1600  # fixed clip length in frames (matches collate_fn's default)

        spectrogram_clip = mel_spec[:, start_frame:end_frame].T
        f0_clip = f0[start_frame:end_frame]

        return spectrogram_clip, f0_clip, torch.Tensor(word_roll), torch.Tensor(tone_roll)
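
    # Return-shape sketch (hedged; assumes sampling_rate = 16000 and
    # frame_size = 100 / 16000, so 10 s of audio is 1600 frames):
    #   mel, f0, word, tone = dataset[0]
    #   mel: [<=1600, 256]   f0: [<=1600]   word/tone: 1-D, one id per syllable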

    def get_labels(self, annotation_data, duration):
        '''
        Read note-level annotations and convert them to frame-level labels,
        because frame-level labels are what training uses.
        '''
        frame_num = math.ceil(duration * self.frame_per_sec)

        word_roll = torch.zeros(size=(frame_num + 1,), dtype=torch.long)
        tone_roll = torch.zeros(size=(frame_num + 1,), dtype=torch.long)

        for note in annotation_data:
            start_time, end_time, mark = note

            start_frame = int(start_time * self.frame_per_sec)
            end_frame = int(end_time * self.frame_per_sec)

            # Clamp to the valid frame range.
            start_frame = max(0, min(frame_num, start_frame))
            end_frame = max(0, min(frame_num, end_frame))

            # 'mark' is a tone-annotated pinyin syllable, e.g. 'hao3': the syllable
            # id fills the word roll and the tone digit fills the tone roll.
            word_roll[start_frame:end_frame + 1] = self.pinyin[mark[:-1]]
            tone_roll[start_frame:end_frame + 1] = int(mark[-1])

        return word_roll, tone_roll
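
    # Example (hedged sketch; the annotation tuples are made up): with
    # frame_per_sec = 160, duration = 1.0 and
    #   annotations = [(0.0, 0.5, 'ni3'), (0.5, 1.0, 'hao3')]
    # get_labels fills roughly frames 0..80 with the id of 'ni' / tone 3 and
    # frames 80..160 with the id of 'hao' / tone 3:
    #   word_roll, tone_roll = dataset.get_labels(annotations, 1.0)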