lmx0 committed on
Commit 2be48c4 · 1 Parent(s): 7eb99fe

Upload 4 files

Files changed (4)
  1. dataset.py +341 -0
  2. model_cnn.py +57 -0
  3. test.py +6 -0
  4. train.py +274 -0
dataset.py ADDED
@@ -0,0 +1,341 @@
+import math
+import json
+import torch
+import librosa
+import torchaudio
+import os
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from torch.utils.data import Dataset, DataLoader
+import time
+
+
+def move_data_to_device(data, device):
+    ret = []
+    for i in data:
+        if isinstance(i, torch.Tensor):
+            ret.append(i.to(device))
+    return ret
+
+def read_content(filepath):
+    '''
+    Read the content file for characters, pinyin and tones.
+
+    return:
+        dict: {index: [characters, pinyin, tones]}
+        exp. {'SS00050001': ['你 好 ', 'ni3 hao3 ', '3 3 ']}
+    '''
+    res = {}
+    with open(filepath, 'r') as f:
+        lines = f.readlines()
+        for l in lines:
+            l = l.replace('\n', ' ').replace('\t', ' ')
+            tmp = l.split(' ')
+            if len(tmp) == 0:
+                break
+            number = tmp[0][0:len(tmp[0])-4]
+            s = ''
+            pinyin = ''
+            tones = ''
+            for i in range(1, len(tmp)):
+                if len(tmp[i]) == 0:
+                    continue
+                # need blank space or not?
+                if i % 2 == 0:
+                    pinyin += tmp[i] + ' '
+                    tones += tmp[i][-1] + ' '
+                else:
+                    s += tmp[i] + ' '
+            res[number] = [s, pinyin, tones]
+    return res
+
+def read_dataset_index(filepath='/kaggle/input/paddle-speech/AISHELL-3/train'):
+    '''
+    get all audio files' index and file paths
+    read content.txt to get corresponding words, pinyin, tones, duration
+
+    return dataframe:
+        ['index', 'filepath', 'word', 'pinyin', 'tone', 'duration']
+
+    5 tones in total, 5 represents neutral tone
+    '''
+    features = read_content(os.path.join(filepath, 'content.txt'))
+
+    start_time = time.time()
+    count = 0
+
+    durations = {}
+    with open('/kaggle/input/durations/durations.txt', 'r') as f:
+        lines = f.readlines()
+        for l in lines:
+            tmp = (l.replace('\n', '')).split(' ')
+            if len(tmp) != 0:
+                durations[tmp[0]] = float(tmp[1])
+
+    audio_path = os.path.join(filepath, 'wav')  # this should be removed
+    indexes = []
+    for root, dirs, files in os.walk(audio_path):
+        for f in files:
+            if f.endswith('.wav'):
+                count += 1
+                index = f[0:len(f)-4]
+                filepath = os.path.join(audio_path, index[0:len(index)-4], f)
+                word, py, tone = features[index]
+                # du = librosa.get_duration(filename=filepath)
+                du = durations[index]
+                indexes.append((index, filepath, word, py, tone, du))
+
+    end_time = time.time()
+    print('#wav file read:', count)
+    print('read dataset index time: ', end_time - start_time)
+
+    '''indexes = sorted(indexes, key=lambda x: x[0])
+    with open('./durations.txt', 'w') as f:
+        for i in indexes:
+            f.write(i[0]+ ' ' + str(i[5]) + '\n')'''
+
+    return pd.DataFrame.from_records(indexes, columns=['index', 'filepath', 'word', 'pinyin', 'tone', 'duration'])
+
+
+def read_dataset_index(filepath='/kaggle/input/paddle-speech/AISHELL-3/train'):
+    '''
+    get all audio files' index and file paths
+    read content.txt to get corresponding words, pinyin, tones, duration
+
+    return dataframe:
+        ['index', 'filepath', 'word', 'pinyin', 'tone', 'duration']
+
+    5 tones in total, 5 represents neutral tone
+    '''
+    features = read_content(os.path.join(filepath, 'content.txt'))
+
+    start_time = time.time()
+    count = 0
+
+    durations = {}
+    with open('/kaggle/input/durations/durations.txt', 'r') as f:
+        lines = f.readlines()
+        for l in lines:
+            tmp = (l.replace('\n', '')).split(' ')
+            if len(tmp) != 0:
+                durations[tmp[0]] = float(tmp[1])
+
+    audio_path = os.path.join(filepath, 'wav')  # this should be removed
+    indexes = []
+    for root, dirs, files in os.walk(audio_path):
+        for f in files:
+            if f.endswith('.wav'):
+                count += 1
+                index = f[0:len(f)-4]
+                filepath = os.path.join(audio_path, index[0:len(index)-4], f)
+                word, py, tone = features[index]
+                # du = librosa.get_duration(filename=filepath)
+                du = durations[index]
+                indexes.append((index, filepath, word, py, tone, du))
+
+    end_time = time.time()
+    print('#wav file read:', count)
+    print('read dataset index time: ', end_time - start_time)
+
+    '''indexes = sorted(indexes, key=lambda x: x[0])
+    with open('./durations.txt', 'w') as f:
+        for i in indexes:
+            f.write(i[0]+ ' ' + str(i[5]) + '\n')'''
+
+    return pd.DataFrame.from_records(indexes, columns=['index', 'filepath', 'word', 'pinyin', 'tone', 'duration'])
+
+
+def collate_fn(batch):
+    inp = []
+    f0 = []
+    word = []
+    tone = []
+    max_frame_num = 1600
+    for sample in batch:
+        max_frame_num = max(max_frame_num, sample[0].shape[0], sample[1].shape[0], sample[2].shape[0], sample[3].shape[0])
+    for sample in batch:
+        inp.append(
+            torch.nn.functional.pad(sample[0], (0, 0, 0, max_frame_num - sample[0].shape[0]), mode='constant', value=0))
+        f0.append(
+            torch.nn.functional.pad(sample[1], (0, max_frame_num - sample[1].shape[0]), mode='constant', value=0))
+        word.append(
+            torch.nn.functional.pad(sample[2], (0, 50 - sample[2].shape[0]), mode='constant', value=0))
+        tone.append(
+            torch.nn.functional.pad(sample[3], (0, 50 - sample[3].shape[0]), mode='constant', value=0))
+    inp = torch.stack(inp)
+    f0 = torch.stack(f0)
+    word = torch.stack(word)
+    tone = torch.stack(tone)
+
+    return inp, f0, word, tone
+
+
+def get_data_loader(split, args):
+    Dataset = MyDataset(
+        dataset_root=args['dataset_root'],
+        split=split,
+        sampling_rate=args['sampling_rate'],
+        sample_length=args['sample_length'],
+        frame_size=args['frame_size'],
+    )
+    Dataset.dataset_index = Dataset.dataset_index[:32]
+    Dataset.index = Dataset.index[:32]
+    data_loader = DataLoader(
+        Dataset,
+        batch_size=args['batch_size'],
+        num_workers=args['num_workers'],
+        pin_memory=True,
+        shuffle=True,  # changed to True because audio files recorded by the same speaker are stored in the same folder
+        collate_fn=collate_fn,
+    )
+    return data_loader
+
+class MyDataset(Dataset):
+    def __init__(self, dataset_root, split, sampling_rate, sample_length, frame_size):
+        self.dataset_root = dataset_root
+        self.split = split  # train or test
+        self.sampling_rate = sampling_rate
+        self.sample_length = sample_length
+        self.frame_size = frame_size
+        self.frame_per_sec = int(1 / self.frame_size)
+
+        # self.annotations = get_annotations(get_all_file_names(os.path.join(self.dataset_root, 'AISHELL-3', split)), level='word')
+
+        self.dataset_index = read_dataset_index(os.path.join(self.dataset_root, 'AISHELL-3', split))  # maybe can be removed
+
+        self.duration = {}
+        self.index = self.index_data()
+        # print(len(self.index))
+        self.dataset_index = self.dataset_index[:10]
+        self.index = self.index[:10]
+
+        self.pinyin = {}  # read encoded pinyin
+        with open('/kaggle/input/pinyin-encode/pinyin.txt', 'r') as f:
+            lines = f.readlines()
+            i = 0
+            for l in lines:
+                self.pinyin[l.replace('\n', '')] = i
+                i += 1
+
+
+    def index_data(self):
+        '''
+        Prepare the index for the dataset, i.e., the audio file name and starting time of each sample
+
+        go through self.dataset_index to get duration and then calculate
+        '''
+        # duration already in dataset_index
+        # TODO
+        # pass
+        index = []
+        for indexs, row in self.dataset_index.iterrows():
+            duration = row['duration']
+            num_seg = math.ceil(duration / self.sample_length)
+            for i in range(num_seg):
+                # index.append([row['index'], i * self.sample_length])
+                index.append([indexs, i * self.sample_length])
+            self.duration[row['index']] = row['duration']
+
+        return index
+
+
+    def __len__(self):
+        return len(self.index)
+
+    def __getitem__(self, idx):
+        '''
+        int idx: index of the audio file (not exp. SSB00050001)
+
+        return mel spectrogram, fundamental frequency (crepe/pyin), words, tones
+        '''
+        audio_fn, start_sec = self.index[idx]
+        end_sec = start_sec + self.sample_length
+        # print(start_sec, end_sec)
+        # ???
+        audio_fp = self.dataset_index.loc[audio_fn, 'filepath']
+        # audio_fp = jpath('./dataset/AISHELL-3/train/wav/SSB0005/SSB0005',audio_fp,'.wav')
+        # /kaggle/input/paddle-speech/AISHELL-3/train/wav/SSB0005/SSB00050001.wav
+
+        # TODO: calculate mel spectrogram
+        mel = None
+        # load data from file
+        waveform, sample_rate = torchaudio.load(audio_fp)
+        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
+        mel_spec = torchaudio.transforms.MelSpectrogram(sample_rate=self.sampling_rate, n_fft=2048, hop_length=100, n_mels=256)(waveform)
+        mel_spec = torch.mean(mel_spec, 0)
+        # print(mel_spec.shape)
+
+        # TODO: calculate fundamental frequency
+        f0 = None
+        waveform, sr = librosa.load(audio_fp, sr=self.sampling_rate)
+        f0 = torch.from_numpy(librosa.yin(waveform, fmin=50, fmax=550, hop_length=100))
+
+        # get labels???
+        # word_roll, tone_roll = self.get_labels(self.annotations[self.dataset_index.loc[audio_fn, 'index']], self.dataset_index.loc[audio_fn,'duration'])
+        words = self.dataset_index.loc[audio_fn, 'pinyin']
+        w = words.split(' ')
+        word_roll = []
+        for i in range(0, len(w)):
+
+            if len(w[i]) != 0:
+                if self.pinyin.get(w[i][0:-1]) == None:
+                    self.pinyin[w[i][0:-1]] = len(self.pinyin)
+                word_roll.append(self.pinyin[w[i][0:-1]])
+        tones = self.dataset_index.loc[audio_fn, 'tone']
+        t = tones.split(' ')
+        tone_roll = []
+        for tone in t:
+            if len(tone) != 0:
+                tone_roll.append(int(tone))
+
+        spectrogram_clip = None
+        f0_clip = None
+        onset_clip = None
+        offset_clip = None
+        word_clip = None
+        tone_clip = None
+
+        # TODO: create clips
+        start_frame = int(start_sec * self.frame_per_sec)
+        end_frame = start_frame + 1600  # int(end_sec * self.frame_per_sec)
+        # print(start_frame, end_frame)
+        spectrogram_clip = mel_spec[:, start_frame:end_frame].T
+        f0_clip = f0[start_sec:end_sec]
+        # word_clip = word_roll[start_frame:end_frame]
+        # tone_clip = tone_roll[start_frame:end_frame]
+
+        # print(tone_roll)
+        # return spectrogram_clip, f0_clip, onset_clip, offset_clip, pinyin_clip, tone_clip
+        return spectrogram_clip, f0_clip, torch.Tensor(word_roll), torch.Tensor(tone_roll)  # word_clip, tone_clip
+
+    def get_labels(self, annotation_data, duration):
+        '''
+        This function reads the annotation from file and converts it from note-level to frame-level,
+        because we will be using frame-level labels in training.
+        '''
+        # TODO
+        # pass
+        frame_num = math.ceil(duration * self.frame_per_sec)
+
+        word_roll = torch.zeros(size=(frame_num + 1,), dtype=torch.long)
+        tone_roll = torch.zeros(size=(frame_num + 1,), dtype=torch.long)
+        # f0_roll = torch.zeros(size=(frame_num + 1,), dtype=torch.long)
+        # mel_roll = torch.zeros(size=(frame_num + 1,), dtype=torch.long)
+        for note in annotation_data:
+            start_time, end_time, mark = note  # Assuming annotation format: (start_time, end_time, pitch)
+
+            # Convert note start and end times to frame indices
+            start_frame = int(start_time * self.frame_per_sec)
+            end_frame = int(end_time * self.frame_per_sec)
+
+            # Clip frame indices to be within the valid range, not needed in this task
+            start_frame = max(0, min(frame_num, start_frame))
+            end_frame = max(0, min(frame_num, end_frame))
+            # print(start_frame, end_frame)
+
+            # WORD LEVEL: mark the frames corresponding to the note
+            word_roll[start_frame:end_frame+1] = self.pinyin[mark[:-1]]  # mark[:-1]
+            tone_roll[start_frame:end_frame+1] = int(mark[-1])
+        # print(tone_roll)
+        return word_roll, tone_roll
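
Not part of the commit: a quick sketch of the content.txt line format that read_content() above expects, inferred from its parser — a wav file name followed by alternating characters and tone-suffixed pinyin. The example line and the temp-file handling are illustrative only.

# Illustrative only: exercises dataset.read_content() on a single made-up line.
import os
import tempfile

from dataset import read_content

line = 'SSB00050001.wav\t你 ni3 好 hao3\n'   # file name, then character/pinyin pairs (assumed example)
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf-8') as tmp:
    tmp.write(line)
print(read_content(tmp.name))
# expected: {'SSB00050001': ['你 好 ', 'ni3 hao3 ', '3 3 ']}
os.remove(tmp.name)
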
model_cnn.py ADDED
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import random
+import numpy as np
+from scipy.fftpack import fft
+import wave
+
+
+class Model(nn.Module):
+    def __init__(self, input_dim=1, hidden_dim=256, tone_class=5, syllable_class=1000):
+
+        super().__init__()
+        self.input_dim = input_dim
+        self.tone_class = tone_class
+        self.syllable_class = syllable_class
+        # hidden_size = 128*hidden_dim//16
+        conv_layers = []
+        in_channels = input_dim  # Input channels for the first layer
+        channel_list = [16, 16, 'p2', 32, 32, 'p2', 64, 64, 'p1', 64]
+        # channel_list = [32,'p','p',128]
+        # channel_list = [32,32,64,64,128]
+
+        for out_channels in channel_list:
+            if out_channels == 'p2':
+                conv_layers.append(nn.MaxPool2d(kernel_size=2))
+                continue
+            elif out_channels == 'p1':
+                conv_layers.append(nn.MaxPool2d(kernel_size=1))
+                continue
+            conv_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
+            conv_layers.append(nn.BatchNorm2d(out_channels))
+            conv_layers.append(nn.ReLU(inplace=True))
+            conv_layers.append(nn.Dropout(0.1))
+            in_channels = out_channels
+
+        self.conv = nn.Sequential(*conv_layers)
+
+
+
+        self.output = nn.Sequential(
+            nn.Linear(4096, 128),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+
+            nn.Linear(128, syllable_class)
+        )
+
+    def forward(self, x):
+
+        x = self.conv(x)             # [batch_size, channel, length (input_length//4), hidden_dim]
+        x = x.permute((0, 2, 1, 3))  # [batch_size, length, channel, hidden_dim]
+        x = x.reshape(x.shape[0], x.shape[1], -1)
+
+        return self.output(x)
+
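
A sanity check on the flattened feature size: the two MaxPool2d(2) stages halve both axes twice, so 256 mel bins become 64, and with 64 output channels the per-frame feature is 64 × 64 = 4096, which matches nn.Linear(4096, 128). A hedged shape check follows; the batch size, frame count and class count are arbitrary example values.

# Illustrative shape check for model_cnn.Model.
import torch
from model_cnn import Model

model = Model(syllable_class=2036)      # 2036 mirrors num_class in train.py
dummy = torch.randn(2, 1, 1600, 256)    # [batch, channel, frames, n_mels]
with torch.no_grad():
    out = model(dummy)
print(out.shape)                        # expected: torch.Size([2, 400, 2036])
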
test.py ADDED
@@ -0,0 +1,6 @@
+from train import ASR_Model
+from model_cnn import Model
+
+model = ASR_Model(device='cuda', model_path='model/model.pth')
+result = model.predict('Examples_中原石化加油站.wav')
+print(result)
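
For orientation: predict() collapses repeated CTC classes with process_sequence() and maps each class through to_pinyin(), so result is a list of (pinyin, tone) tuples. A hedged way to pretty-print it; the syllables in the comment are invented, not real model output.

# e.g. result might look like [('zhong', 1), ('yuan', 2), ('shi', 2), ('hua', 4)]
for syllable, tone in result:
    print(f'{syllable}{tone}', end=' ')
print()
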
train.py ADDED
@@ -0,0 +1,274 @@
+import os
+import librosa
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchaudio
+from torch.utils.data import Dataset, DataLoader
+
+
+from hparams import Hparams
+from model_cnn import Model
+from dataset import MyDataset
+
+
+args = Hparams.args
+device = args['device']
+split = 'train'
+
+tone_class = 5
+NUM_EPOCHS = 100
+
+
+
+
+# num_class = len(train_loader.dataset.pinyin) * tone_class + 1
+
+# model = Model(syllable_class = num_class)
+# model.to(device)
+
+
+
+def move_data_to_device(data, device):
+    ret = []
+    for i in data:
+        if isinstance(i, torch.Tensor):
+            ret.append(i.to(device))
+    return ret
+
+def collate_fn(batch):
+    # TODO
+    inp = []
+    f0 = []
+    word = []
+    tone = []
+    max_frame_num = 1600
+    for sample in batch:
+        max_frame_num = max(max_frame_num, sample[0].shape[0], sample[1].shape[0], sample[2].shape[0], sample[3].shape[0])
+    for sample in batch:
+        inp.append(
+            torch.nn.functional.pad(sample[0], (0, 0, 0, max_frame_num - sample[0].shape[0]), mode='constant', value=0))
+        f0.append(
+            torch.nn.functional.pad(sample[1], (0, max_frame_num - sample[1].shape[0]), mode='constant', value=0))
+        word.append(
+            torch.nn.functional.pad(sample[2], (0, 50 - sample[2].shape[0]), mode='constant', value=0))
+        tone.append(
+            torch.nn.functional.pad(sample[3], (0, 50 - sample[3].shape[0]), mode='constant', value=0))
+    inp = torch.stack(inp)
+    f0 = torch.stack(f0)
+    word = torch.stack(word)
+    tone = torch.stack(tone)
+
+    return inp, f0, word, tone
+
+def get_data_loader(split, args):
+    Dataset = MyDataset(
+        dataset_root=args['dataset_root'],
+        split=split,
+        sampling_rate=args['sampling_rate'],
+        sample_length=args['sample_length'],
+        frame_size=args['frame_size'],
+    )
+    Dataset.dataset_index = Dataset.dataset_index[:32]
+    Dataset.index = Dataset.index[:32]
+    data_loader = DataLoader(
+        Dataset,
+        batch_size=args['batch_size'],
+        num_workers=args['num_workers'],
+        pin_memory=True,
+        shuffle=True,  # changed to True because audio files recorded by the same speaker are stored in the same folder
+        collate_fn=collate_fn,
+    )
+
+    return data_loader
+
+
+# train_loader = get_data_loader(split='train', args=Hparams.args)
+# idx2char = { idx:char for char,idx in train_loader.dataset.pinyin.items()}
+
+
+# def to_pinyin(num):
+#     if num==0:
+#         return
+#     pinyin,tone = idx2char[(num-1)//5],(num-1)%5+1
+#     return pinyin,tone
+
+def process_sequence(seq):
+    ret = []
+    for w in seq:
+        if len(ret) == 0 or ret[-1] != w:
+            ret.append(w)
+    return ret
+
+
+# def train(NUM_EPOCHS = 100):
+#     optimizer = optim.Adam(model.parameters(), lr=0.002)
+#     criterion = nn.CrossEntropyLoss()  # (ignore_index=0)
+
+#     device = Hparams.args['device']
+
+#     for epoch in range(NUM_EPOCHS):
+#         for idx, data in enumerate(train_loader):
+#             mel, target, len_mel, len_tag = move_data_to_device(data, device)
+#             # break
+#             # input_length = (mel[:,:,0]!=0.0).sum(axis=1)
+#             # print(mel.shape, f0.shape, word.shape, tone.shape)  # torch.Size([8, 1600, 256])
+
+#             mel = mel.unsqueeze(1)
+
+#             output = model(mel)  # [32, 400, 1000]
+#             # target[:,:len_tag].view(-1)
+#             # output[:,:len_tag,:].view(-1, num_classes)
+
+
+#             # output_len = input_length//4
+#             # move_data_to_device(output_len, Hparams.args['device'])
+#             loss = criterion(output.view(-1, num_class), target.view(-1).long())
+#             optimizer.zero_grad()
+#             loss.backward()
+#             optimizer.step()
+
+#             # if(idx%100==0):
+#             #     print(f'Epoch {epoch+1},Iteration {idx+1}, Loss: {loss.item()}')
+
+#         print(f'Epoch {epoch+1}, Loss: {loss.item()}')
+
+
+class ASR_Model:
+    '''
+    This is the main class for training the model and making predictions.
+    '''
+    def __init__(self, device="cpu", model_path=None, pinyin_path='pinyin.txt'):
+        # Initialize model
+        self.device = device
+
+        self.pinyin = {}  # read encoded pinyin
+
+        with open(pinyin_path, 'r') as f:
+            lines = f.readlines()
+            i = 0
+            for l in lines:
+                self.pinyin[l.replace('\n', '')] = i
+                i += 1
+
+        self.idx2char = {idx: char for char, idx in self.pinyin.items()}
+        num_class = 2036  # len(train_loader.dataset.pinyin) * tone_class + 1
+
+        self.model = Model(syllable_class=num_class).to(self.device)
+        self.sampling_rate = args['sampling_rate']
+        if model_path is not None:
+            self.model = torch.load(model_path)
+            print('Model loaded.')
+        else:
+            print('Model initialized.')
+        self.model.to(device)
+
+
+    def fit(self, args, NUM_EPOCHS=100):
+        # Set paths
+        save_model_dir = args['save_model_dir']
+        if not os.path.exists(save_model_dir):
+            os.mkdir(save_model_dir)
+        loss_fn = nn.CTCLoss()
+        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
+
+
+        train_loader = get_data_loader(split='train', args=args)
+        valid_loader = get_data_loader(split='train', args=args)
+
+        # Start training
+        print('Start training...')
+        min_valid_loss = 10000
+
+        self.model.train()
+        for epoch in range(NUM_EPOCHS):
+            for idx, data in enumerate(train_loader):
+                mel, f0, word, tone = move_data_to_device(data, device)
+                input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
+                # print(mel.shape)
+                mel = mel.unsqueeze(1)
+                # print(mel.shape)
+
+
+                output = self.model(mel)
+                output = output.permute(1, 0, 2)
+                output = output.log_softmax(2)  # nn.CTCLoss expects log-probabilities over the class dimension
+
+
+                output_len = input_length // 4
+                move_data_to_device(output_len, Hparams.args['device'])
+                # print(tone.shape)
+                target_len = (tone != 0).sum(axis=1)
+
+                target = word * 5 + tone
+
+                loss = loss_fn(output, target, output_len, target_len)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                if idx % 100 == 0:
+                    print(f'Epoch {epoch+1}, Iteration {idx+1}, Loss: {loss.item()}')
+
+            # Validation
+            self.model.eval()
+            with torch.no_grad():
+                losses = []
+                for idx, data in enumerate(valid_loader):
+                    mel, f0, word, tone = move_data_to_device(data, device)
+                    input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
+                    mel = mel.unsqueeze(1)
+
+                    out = self.model(mel)
+                    out = out.permute(1, 0, 2)
+                    out = out.log_softmax(2)  # CTC expects log-probabilities here as well
+
+                    output_len = input_length // 4
+                    move_data_to_device(output_len, Hparams.args['device'])
+                    target_len = (tone != 0).sum(axis=1)
+                    target = word * 5 + tone
+
+                    loss = loss_fn(out, target, output_len, target_len)
+                    losses.append(loss.item())
+                loss = np.mean(losses)
+
+            # Save the best model
+            if loss < min_valid_loss:
+                min_valid_loss = loss
+                target_model_path = save_model_dir + '/best_model.pth'
+                torch.save(self.model, target_model_path)
+
+    def to_pinyin(self, num):
+        if num == 0:
+            return
+        pinyin, tone = self.idx2char[(num-1)//5], (num-1) % 5 + 1
+        return pinyin, tone
+
+    def getsentence(self, words):
+        words = words.tolist()
+        return [self.idx2char[int(word)] for word in words]
+
+    def predict(self, audio_fp):
+        """Predict the pinyin and tone sequence for a given audio file."""
+
+
+        waveform, sample_rate = torchaudio.load(audio_fp)
+        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
+        mel_spec = torchaudio.transforms.MelSpectrogram(sample_rate=self.sampling_rate, n_fft=2048, hop_length=100, n_mels=256)(waveform)
+        mel_spec = torch.mean(mel_spec, 0)
+
+        waveform, sr = librosa.load(audio_fp, sr=self.sampling_rate)
+        f0 = torch.from_numpy(librosa.yin(waveform, fmin=50, fmax=550, hop_length=100))
+        mel = torch.tensor(mel_spec.T).unsqueeze(0).unsqueeze(0)
+
+        # print(mel.shape)
+        self.model.eval()
+        with torch.no_grad():
+            output = self.model(mel.to(self.device))
+            # print(output.shape)
+            seq = process_sequence(output[0].cpu().numpy().argmax(-1))
+            result = [self.to_pinyin(c) for c in seq if c != 0]
+
+        return result
+
+
+
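
Note that train.py imports Hparams from an hparams.py module that is not among the four uploaded files. A minimal sketch of what that module plausibly provides, using only the keys the scripts above read from Hparams.args; every value here is an assumption, not the project's real configuration.

# hparams.py -- hypothetical sketch; keys taken from the scripts above, values assumed.
import torch


class Hparams:
    args = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'dataset_root': '/kaggle/input/paddle-speech',  # folder containing AISHELL-3/
        'sampling_rate': 16000,
        'sample_length': 10,        # seconds per training clip
        'frame_size': 0.00625,      # seconds per frame
        'batch_size': 8,
        'num_workers': 2,
        'save_model_dir': './model',
    }
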