#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================
"""
@author: nl8590687
声学模型基础功能模板定义
"""
import os
import time
import random
import numpy as np
from utils.ops import get_edit_distance, read_wav_data
from utils.config import load_config_file, DEFAULT_CONFIG_FILENAME, load_pinyin_dict
from utils.thread import threadsafe_generator


class ModelSpeech:
    """
    Speech model class

    Parameters:
        speech_model: instance of the acoustic model type (a BaseModel subclass)
        speech_features: instance of the acoustic feature type (a SpeechFeatureMeta subclass)
    """
def __init__(self, speech_model, speech_features, max_label_length=64):
self.data_loader = None
self.speech_model = speech_model
self.trained_model, self.base_model = speech_model.get_model()
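        # get_model() is expected to return a pair: the trainable model with
        # the CTC loss attached, and the base model used for inference/decoding.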
# self.load_model("save_models\SpeechModel251bn\SpeechModel251bn.model.h5")
self.speech_features = speech_features
self.max_label_length = max_label_length

    @threadsafe_generator
    def _data_generator(self, batch_size, data_loader):
        """
        Data generator for Keras generator-based training.

        batch_size: number of samples produced per batch
        """
labels = np.zeros((batch_size, 1), dtype=np.float64)
data_count = data_loader.get_data_count()
index = 0
while True:
X = np.zeros((batch_size,) + self.speech_model.input_shape, dtype=np.float64)
y = np.zeros((batch_size, self.max_label_length), dtype=np.int16)
input_length = []
label_length = []
for i in range(batch_size):
wavdata, sample_rate, data_labels = data_loader.get_data(index)
data_input = self.speech_features.run(wavdata, sample_rate)
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
                # The input length must include the value obtained modulo pool_size,
                # otherwise the CTC loss can produce inf and fail with
                # "No valid path found." At the same time it must not exceed the
                # model's maximum output length (e.g. 200), hence the min().
                pool_size = self.speech_model.input_shape[0] // self.speech_model.output_shape[0]
                inlen = min(data_input.shape[0] // pool_size + data_input.shape[0] % pool_size,
                            self.speech_model.output_shape[0])
input_length.append(inlen)
X[i, 0:len(data_input)] = data_input
y[i, 0:len(data_labels)] = data_labels
label_length.append([len(data_labels)])
index = (index + 1) % data_count
            label_length = np.array(label_length)      # shape (batch_size, 1); np.matrix is deprecated
            input_length = np.array([input_length]).T  # shape (batch_size, 1)
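            # Keras CTC training consumes [padded features, padded labels,
            # feature lengths, label lengths] as the model inputs; `labels` is
            # a dummy target required by the Keras API and ignored by the CTC loss.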
yield [X, y, input_length, label_length], labels

    def train_model(self, optimizer, data_loader, epochs=1, save_step=1, batch_size=16, last_epoch=0, call_back=None):
        """
        Train the model.

        Parameters:
            optimizer: a tensorflow.keras.optimizers optimizer instance
            data_loader: a SpeechData data loader instance
            epochs: number of training epochs
            save_step: save the model every `save_step` epochs
            batch_size: mini-batch size
            last_epoch: epoch number of the previous run, so that numbering does
                not clash when resuming training from a checkpoint
            call_back: Keras callback(s)
        """
save_filename = os.path.join('save_models', self.speech_model.get_model_name(),
self.speech_model.get_model_name())
self.trained_model.compile(loss=self.speech_model.get_loss_function(), optimizer=optimizer)
        print('[ASRT] Model compiled successfully.')
yielddatas = self._data_generator(batch_size, data_loader)
        data_count = data_loader.get_data_count()  # total number of training samples
        # number of iterations per epoch
        num_iterate = data_count // batch_size
        iter_start = last_epoch
        iter_end = last_epoch + epochs
        for epoch in range(iter_start, iter_end):  # training epochs
            try:
                epoch += 1
                print('[ASRT Training] train epoch %d/%d .' % (epoch, iter_end))
                data_loader.shuffle()
                # Model.fit accepts a Python generator directly; fit_generator
                # is deprecated in TensorFlow 2
                self.trained_model.fit(yielddatas, steps_per_epoch=num_iterate, callbacks=call_back)
            except StopIteration:
                print('[ASRT Error] Generator error. Please check the data format.')
                break
            if epoch % save_step == 0:
                # make sure the save directory exists, to avoid a crash when saving
                os.makedirs(os.path.join('save_models', self.speech_model.get_model_name()),
                            exist_ok=True)
                self.save_model(save_filename + '_epoch' + str(epoch))
        print('[ASRT Info] Model training complete.')

    def load_model(self, filename):
        """
        Load model weights from the given file.
        """
        print(f"loading pre-trained model from {filename}")
        self.speech_model.load_weights(filename)

    def save_model(self, filename):
        """
        Save model weights to the given file.
        """
        self.speech_model.save_weights(filename)

    def evaluate_model(self, data_loader, data_count=-1, out_report=False, show_ratio=True, show_per_step=100):
        """
        Evaluate the recognition accuracy of the model on a dataset.
        """
        data_nums = data_loader.get_data_count()
        # when data_count is <= 0 or exceeds the dataset size, use the full dataset
        if data_count <= 0 or data_count > data_nums:
            data_count = data_nums
        try:
            ran_num = random.randint(0, data_nums - 1)  # pick a random start index
            words_num = 0
            word_error_num = 0
            nowtime = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
            if out_report:
                txt_obj = open('Test_Report_' + data_loader.dataset_type + '_' + nowtime + '.txt', 'w',
                               encoding='UTF-8')  # open the report file for writing
                # pre-allocate disk space so the file is not relocated repeatedly
                # as it grows, which would make writes slower and slower
                txt_obj.truncate((data_count + 1) * 300)
                txt_obj.seek(0)  # rewind to the start of the file
            txt = ''
            i = 0
            while i < data_count:
                # take samples sequentially, starting from the random index
                wavdata, fs, data_labels = data_loader.get_data((ran_num + i) % data_nums)
                data_input = self.speech_features.run(wavdata, fs)
                data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
                # malformed data handling: when an input wav file is too long,
                # skip it automatically and move on to the next file
                if data_input.shape[0] > self.speech_model.input_shape[0]:
                    print('*[Error] wave data length of num', (ran_num + i) % data_nums, 'is too long.',
                          'this data\'s length is', data_input.shape[0],
                          'expected <=', self.speech_model.input_shape[0],
                          '\nAn exception was raised while testing the speech model.')
                    i += 1
                    continue
                pre = self.predict(data_input)
                words_n = data_labels.shape[0]  # number of characters in this sentence
                words_num += words_n  # accumulate the total character count
                edit_distance = get_edit_distance(data_labels, pre)  # compute the edit distance
                if edit_distance <= words_n:
                    # when the edit distance is at most the sentence length,
                    # use it directly as the error count
                    word_error_num += edit_distance
                else:
                    # otherwise the prediction inserted a pile of garbage
                    # characters, so cap the error count at the sentence length
                    word_error_num += words_n
                if i % show_per_step == 0 and show_ratio:
                    print('[ASRT Info] Testing: ', i, '/', data_count)
                txt = ''
                if out_report:
                    txt += str(i) + '\n'
                    txt += 'True:\t' + str(data_labels) + '\n'
                    txt += 'Pred:\t' + str(pre) + '\n'
                    txt += '\n'
                    txt_obj.write(txt)
                i += 1
            # print('*[Test Result] Speech recognition ' + str_dataset + ' set character error rate:', word_error_num / words_num * 100, '%')
            print('*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ',
                  word_error_num / words_num * 100, '%')
            if out_report:
                txt = '*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ' + str(
                    word_error_num / words_num * 100) + ' %'
                txt_obj.write(txt)
                txt_obj.truncate()  # trim the unused pre-allocated bytes at the end of the file
                txt_obj.close()
        except StopIteration:
            print('[ASRT Error] Model testing raised an error. Please check the data format.')

    def predict(self, data_input):
        """
        Run the forward pass of the acoustic model and return the recognition result.
        """
        return self.speech_model.forward(data_input)

    def recognize_speech(self, wavsignal, fs):
        """
        Top-level speech recognition function: recognize one wav sample sequence.
        """
        # extract the input features
        data_input = self.speech_features.run(wavsignal, fs)
        data_input = np.array(data_input, dtype=np.float64)
        # print(data_input, data_input.shape)
        data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
        r1 = self.predict(data_input)
        # load the pinyin symbol list used to decode the predicted indexes
        list_symbol_dic, _ = load_pinyin_dict(load_config_file(DEFAULT_CONFIG_FILENAME)['dict_filename'])
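        # note: the pinyin dictionary is reloaded from disk on every call; a
        # caller recognizing many utterances may want to cache it instead.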
r_str = []
for i in r1:
r_str.append(list_symbol_dic[i])
return r_str

    def recognize_speech_from_file(self, filename):
        """
        Top-level speech recognition function: recognize the speech in the wav
        file with the given filename.
        """
        wavsignal, sample_rate, _, _ = read_wav_data(filename)
        r = self.recognize_speech(wavsignal, sample_rate)
        return r

    @property
    def model(self):
        """
        Return the underlying tf.keras model.
        """
        return self.trained_model
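

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module. The names
    # SpeechModel251BN and Spectrogram are assumptions based on the rest of
    # the ASRT repository (speech_model_zoo / speech_features); substitute
    # whichever BaseModel subclass and feature extractor you actually use,
    # and point load_model at a real checkpoint.
    from speech_model_zoo import SpeechModel251BN
    from speech_features import Spectrogram

    sm251bn = SpeechModel251BN(input_shape=(1600, 200, 1), output_size=1428)
    feat = Spectrogram()
    ms = ModelSpeech(sm251bn, feat, max_label_length=64)
    ms.load_model(os.path.join('save_models', sm251bn.get_model_name(),
                               sm251bn.get_model_name() + '.model.h5'))
    print(ms.recognize_speech_from_file('speech.wav'))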