Spaces:
Sleeping
Sleeping
File size: 6,154 Bytes
ea41881 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================
"""
@author: nl8590687
一些常用操作函数的定义
"""
import wave
import difflib
import matplotlib.pyplot as plt
import numpy as np
import os
def read_wav_data(filename: str) -> tuple:
"""
读取一个wav文件,返回声音信号的时域谱矩阵和播放时间
"""
# if os.path.exists(filename):
# print("文件存在")
# else:
# print("文件不存在",filename)
wav = wave.open(filename,"rb") # 打开一个wav格式的声音文件流
num_frame = wav.getnframes() # 获取帧数
num_channel=wav.getnchannels() # 获取声道数
framerate=wav.getframerate() # 获取帧速率
num_sample_width=wav.getsampwidth() # 获取实例的比特宽度,即每一帧的字节数
str_data = wav.readframes(num_frame) # 读取全部的帧
wav.close() # 关闭流
wave_data = np.fromstring(str_data, dtype = np.short) # 将声音文件数据转换为数组矩阵形式
wave_data.shape = -1, num_channel # 按照声道数将数组整形,单声道时候是一列数组,双声道时候是两列的矩阵
wave_data = wave_data.T # 将矩阵转置
return wave_data, framerate, num_channel, num_sample_width
def read_wav_bytes(filename: str) -> tuple:
"""
读取一个wav文件,返回声音信号的时域谱矩阵和播放时间
"""
wav = wave.open(filename,"rb") # 打开一个wav格式的声音文件流
num_frame = wav.getnframes() # 获取帧数
num_channel=wav.getnchannels() # 获取声道数
framerate=wav.getframerate() # 获取帧速率
num_sample_width=wav.getsampwidth() # 获取实例的比特宽度,即每一帧的字节数
str_data = wav.readframes(num_frame) # 读取全部的帧
wav.close() # 关闭流
return str_data, framerate, num_channel, num_sample_width
def get_edit_distance(str1, str2) -> int:
"""
计算两个串的编辑距离,支持str和list类型
"""
leven_cost = 0
sequence_match = difflib.SequenceMatcher(None, str1, str2)
for tag, index_1, index_2, index_j1, index_j2 in sequence_match.get_opcodes():
if tag == 'replace':
leven_cost += max(index_2-index_1, index_j2-index_j1)
elif tag == 'insert':
leven_cost += (index_j2-index_j1)
elif tag == 'delete':
leven_cost += (index_2-index_1)
return leven_cost
def ctc_decode_delete_tail_blank(ctc_decode_list):
"""
处理CTC解码后序列末尾余留的空白元素,删除掉
"""
p = 0
while p < len(ctc_decode_list) and ctc_decode_list[p] != -1:
p += 1
return ctc_decode_list[0:p]
def visual_1D(points_list, frequency=1):
"""
可视化1D数据
"""
# 首先创建绘图网格,1个子图
fig, ax = plt.subplots(1)
x = np.linspace(0, len(points_list)-1, len(points_list)) / frequency
# 在对应对象上调用 plot() 方法
ax.plot(x, points_list)
fig.show()
def visual_2D(img):
"""
可视化2D数据
"""
plt.subplot(111)
plt.imshow(img)
plt.colorbar(cax=None, ax=None, shrink=0.5)
plt.show()
def decode_wav_bytes(samples_data: bytes, channels: int = 1, byte_width: int = 2) -> list:
"""
解码wav格式样本点字节流,得到numpy数组
"""
numpy_type = np.short
if byte_width == 4:
numpy_type = np.int
elif byte_width != 2:
raise Exception('error: unsurpport byte width `' + str(byte_width) + '`')
wave_data = np.fromstring(samples_data, dtype=numpy_type) # 将声音文件数据转换为数组矩阵形式
wave_data.shape = -1, channels # 按照声道数将数组整形,单声道时候是一列数组,双声道时候是两列的矩阵
wave_data = wave_data.T # 将矩阵转置
return wave_data
def get_symbol_dict(dict_filename):
"""
读取拼音汉字的字典文件
返回读取后的字典
"""
txt_obj = open(dict_filename, 'r', encoding='UTF-8') # 打开文件并读入
txt_text = txt_obj.read()
txt_obj.close()
txt_lines = txt_text.split('\n') # 文本分割
dic_symbol = {} # 初始化符号字典
for i in txt_lines:
list_symbol = [] # 初始化符号列表
if i != '':
txt_l=i.split('\t')
pinyin = txt_l[0]
for word in txt_l[1]:
list_symbol.append(word)
dic_symbol[pinyin] = list_symbol
return dic_symbol
def get_language_model(model_language_filename):
"""
读取语言模型的文件
返回读取后的模型
"""
txt_obj = open(model_language_filename, 'r', encoding='UTF-8') # 打开文件并读入
txt_text = txt_obj.read()
txt_obj.close()
txt_lines = txt_text.split('\n') # 文本分割
dic_model = {} # 初始化符号字典
for i in txt_lines:
if i != '':
txt_l = i.split('\t')
if len(txt_l) == 1:
continue
dic_model[txt_l[0]] = txt_l[1]
return dic_model
def ctc_decode_stream(tokens):
i = 0
while i < len(tokens):
while i+1 < len(tokens) and tokens[i] == tokens[i+1]:
i += 1
if i+1 == len(tokens) and tokens[i] != -1:
return tokens[0], []
if tokens[i] != -1:
return tokens[i], tokens[i+1:]
i += 1
return -1, []
|