update
Browse files
.gitignore
CHANGED
@@ -8,11 +8,11 @@
|
|
8 |
**/logs/
|
9 |
**/__pycache__/
|
10 |
|
11 |
-
data/
|
12 |
-
docs/
|
13 |
-
dotenv/
|
14 |
-
trained_models/
|
15 |
-
temp/
|
16 |
|
17 |
#**/*.wav
|
18 |
**/*.xlsx
|
|
|
8 |
**/logs/
|
9 |
**/__pycache__/
|
10 |
|
11 |
+
/data/
|
12 |
+
/docs/
|
13 |
+
/dotenv/
|
14 |
+
/trained_models/
|
15 |
+
/temp/
|
16 |
|
17 |
#**/*.wav
|
18 |
**/*.xlsx
|
examples/vm_sound_classification/run.sh
CHANGED
@@ -13,7 +13,7 @@ E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
|
|
13 |
sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
|
14 |
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
|
15 |
|
16 |
-
sh run.sh --stage
|
17 |
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
18 |
|
19 |
|
|
|
13 |
sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
|
14 |
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
|
15 |
|
16 |
+
sh run.sh --stage 0 --stop_stage 1 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8-ch16 \
|
17 |
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
18 |
|
19 |
|
toolbox/torch/utils/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torch/utils/data/dataset/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torch/utils/data/dataset/wave_classifier_excel_dataset.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
|
5 |
+
import librosa
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from scipy.io import wavfile
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
15 |
+
|
16 |
+
|
17 |
+
class WaveClassifierExcelDataset(Dataset):
    """
    Map-style dataset of (waveform, label) pairs listed in an Excel sheet.

    The sheet is read once in ``__init__`` with ``pandas.read_excel``. It must contain a
    ``filename`` column and a column named ``label_field``; when ``category`` is given it must
    also contain a column named ``category_field``, and only rows whose value in that column
    equals ``category`` are kept. Audio is loaded lazily in ``__getitem__``.

    Labels are converted to ids with the vocabulary namespace ``label_field`` when ``category``
    is None, otherwise with the namespace named after ``category``.
    """

    def __init__(self,
                 vocab: "Vocabulary",
                 excel_file: str,
                 expected_sample_rate: int,
                 resample: bool = False,
                 root_path: str = None,
                 category: str = None,
                 category_field: str = "category",
                 label_field: str = "labels",
                 max_wave_value: float = 1.0,
                 ) -> None:
        """
        :param vocab: vocabulary providing ``get_token_to_index_vocabulary(namespace=...)``.
        :param excel_file: path of the Excel sheet listing the samples.
        :param expected_sample_rate: sample rate every loaded waveform must have.
        :param resample: if True, load with ``librosa.load(..., sr=expected_sample_rate)``;
            otherwise read the file with ``scipy.io.wavfile.read``.
        :param root_path: optional directory joined in front of every ``filename`` value.
        :param category: optional value used to filter rows by ``category_field``.
        :param category_field: name of the column compared against ``category``.
        :param label_field: name of the column holding the label token.
        :param max_wave_value: divisor applied to samples read by ``wavfile.read``
            (not applied on the ``resample`` path).
        """
        self.vocab = vocab
        self.excel_file = excel_file

        self.expected_sample_rate = expected_sample_rate
        self.resample = resample
        self.root_path = root_path
        self.category = category
        self.category_field = category_field
        self.label_field = label_field
        self.max_wave_value = max_wave_value

        df = pd.read_excel(excel_file)

        samples = list()
        for _, row in tqdm(df.iterrows(), total=len(df)):
            # keep only rows of the requested category (no filtering when category is None).
            if self.category is not None and self.category != row[self.category_field]:
                continue

            samples.append({
                "filename": row["filename"],
                "label": row[self.label_field],
            })
        self.samples = samples

    def __getitem__(self, index):
        """
        Load one sample.

        :return: dict with ``waveform`` (float32 tensor) and ``label`` (int64 scalar tensor).
        :raises KeyError: if the label token is missing from the selected vocabulary namespace.
        """
        sample = self.samples[index]
        filename = sample["filename"]
        label = sample["label"]

        if self.root_path is not None:
            filename = os.path.join(self.root_path, filename)

        waveform = self.filename_to_waveform(filename)

        namespace = self.label_field if self.category is None else self.category
        token_to_index = self.vocab.get_token_to_index_vocabulary(namespace=namespace)
        label: int = token_to_index[label]

        return {
            "waveform": waveform,
            "label": torch.tensor(label, dtype=torch.int64),
        }

    def __len__(self):
        """Number of rows kept from the Excel sheet."""
        return len(self.samples)

    def filename_to_waveform(self, filename: str):
        """
        Read an audio file and return it as a float32 tensor.

        :raises ValueError: if the reader rejects the file; the message names the file and the
            reader's own error is chained as ``__cause__``.
        :raises AssertionError: if the file's sample rate differs from ``expected_sample_rate``.
        """
        try:
            if self.resample:
                waveform, sample_rate = librosa.load(filename, sr=self.expected_sample_rate)
            else:
                sample_rate, waveform = wavfile.read(filename)
                waveform = waveform / self.max_wave_value
        except ValueError as e:
            # carry the offending path in the exception itself instead of printing it.
            raise ValueError("failed to read audio file: {}".format(filename)) from e

        if sample_rate != self.expected_sample_rate:
            raise AssertionError(
                "unexpected sample rate for {}: expected {}, got {}".format(
                    filename, self.expected_sample_rate, sample_rate
                )
            )

        return torch.tensor(waveform, dtype=torch.float32)
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
pass
|
toolbox/torch/utils/data/vocabulary.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from collections import defaultdict, OrderedDict
|
4 |
+
import os
|
5 |
+
from typing import Any, Callable, Dict, Iterable, List, Set
|
6 |
+
|
7 |
+
|
8 |
+
def namespace_match(pattern: str, namespace: str) -> bool:
    """
    Match a namespace pattern against a namespace string.

    A pattern starting with ``*`` is a suffix wildcard: ``*tags`` matches ``passage_tags`` and
    ``question_tags``. Any other pattern must equal the namespace exactly: ``tokens`` matches
    ``tokens`` but not ``stemmed_tokens``.

    An empty pattern is not a wildcard; it only matches an empty namespace.
    """
    # startswith() rather than pattern[0]: an empty pattern (e.g. a blank line read from a
    # namespace file) must not raise IndexError.
    if pattern.startswith('*') and namespace.endswith(pattern[1:]):
        return True
    return pattern == namespace
|
19 |
+
|
20 |
+
|
21 |
+
class _NamespaceDependentDefaultDict(defaultdict):
|
22 |
+
def __init__(self,
|
23 |
+
non_padded_namespaces: Set[str],
|
24 |
+
padded_function: Callable[[], Any],
|
25 |
+
non_padded_function: Callable[[], Any]) -> None:
|
26 |
+
self._non_padded_namespaces = set(non_padded_namespaces)
|
27 |
+
self._padded_function = padded_function
|
28 |
+
self._non_padded_function = non_padded_function
|
29 |
+
super(_NamespaceDependentDefaultDict, self).__init__()
|
30 |
+
|
31 |
+
def __missing__(self, key: str):
|
32 |
+
if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
|
33 |
+
value = self._non_padded_function()
|
34 |
+
else:
|
35 |
+
value = self._padded_function()
|
36 |
+
dict.__setitem__(self, key, value)
|
37 |
+
return value
|
38 |
+
|
39 |
+
def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
|
40 |
+
# add non_padded_namespaces which weren't already present
|
41 |
+
self._non_padded_namespaces.update(non_padded_namespaces)
|
42 |
+
|
43 |
+
|
44 |
+
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    """
    namespace -> {token: index} mapping.

    A new padded namespace starts with the padding token at index 0 and the OOV token at
    index 1; a new non-padded namespace starts empty.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        def new_padded_mapping():
            return {padding_token: 0, oov_token: 1}

        super().__init__(non_padded_namespaces, new_padded_mapping, dict)
|
49 |
+
|
50 |
+
|
51 |
+
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    """
    namespace -> {index: token} mapping.

    A new padded namespace starts with index 0 mapped to the padding token and index 1 mapped
    to the OOV token; a new non-padded namespace starts empty.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        def new_padded_mapping():
            return {0: padding_token, 1: oov_token}

        super().__init__(non_padded_namespaces, new_padded_mapping, dict)
|
56 |
+
|
57 |
+
|
58 |
+
# Namespace patterns (see namespace_match) whose vocabularies start empty,
# i.e. without padding / OOV entries.
DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
# Token stored at index 0 of every padded namespace.
DEFAULT_PADDING_TOKEN = "[PAD]"
# Token stored at index 1 of every padded namespace; stands in for unknown tokens.
DEFAULT_OOV_TOKEN = "[UNK]"
# File, inside a saved vocabulary directory, listing the non-padded namespace patterns.
NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
|
62 |
+
|
63 |
+
|
64 |
+
class Vocabulary(object):
    """
    Per-namespace, bidirectional mapping between tokens and integer ids.

    Each namespace is created on first access. A padded namespace starts with the padding token
    at index 0 and the OOV token at index 1; a namespace matching one of the non-padded patterns
    (see ``namespace_match``) starts empty and therefore has neither token.
    """

    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
        """
        :param non_padded_namespaces: namespace patterns which get no padding / OOV entries.
        """
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padding_token = DEFAULT_PADDING_TOKEN
        self._oov_token = DEFAULT_OOV_TOKEN
        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)

    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
        """Add ``token`` to ``namespace`` if it is new and return its index either way."""
        token_to_index = self._token_to_index[namespace]
        if token in token_to_index:
            return token_to_index[token]

        # indices are dense, so the next free index equals the current size.
        index = len(token_to_index)
        token_to_index[token] = index
        self._index_to_token[namespace][index] = token
        return index

    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
        """Return the live ``{index: token}`` dict of ``namespace`` (not a copy)."""
        return self._index_to_token[namespace]

    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
        """Return the live ``{token: index}`` dict of ``namespace`` (not a copy)."""
        return self._token_to_index[namespace]

    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
        """
        Return the index of ``token``, falling back to the index of the OOV token.

        :raises KeyError: if ``token`` is unknown and the namespace holds no OOV token
            (a non-padded namespace).
        """
        token_to_index = self._token_to_index[namespace]
        if token in token_to_index:
            return token_to_index[token]
        return token_to_index[self._oov_token]

    def get_token_from_index(self, index: int, namespace: str = 'tokens'):
        """Return the token stored at ``index``; raises KeyError for an unknown index."""
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = 'tokens') -> int:
        """Number of tokens in ``namespace``, padding / OOV entries included."""
        return len(self._token_to_index[namespace])

    def save_to_files(self, directory: str):
        """
        Write the vocabulary into ``directory`` (created if missing).

        ``non_padded_namespaces.txt`` receives one non-padded pattern per line; every namespace
        is written to ``<namespace>.txt`` with one token per line, in index order. Padding and
        OOV tokens of padded namespaces are written like any other token.
        """
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
            for namespace_str in self._non_padded_namespaces:
                f.write('{}\n'.format(namespace_str))

        for namespace, token_to_index in self._token_to_index.items():
            filename = os.path.join(directory, '{}.txt'.format(namespace))
            with open(filename, 'w', encoding='utf-8') as f:
                # the loader assigns indices by line number, so write in index order.
                for token, _ in sorted(token_to_index.items(), key=lambda item: item[1]):
                    f.write('{}\n'.format(token))

    @classmethod
    def from_files(cls, directory: str) -> 'Vocabulary':
        """
        Load a vocabulary from a directory written by ``save_to_files``.

        Hidden entries (names starting with ``.``) and sub-directories are ignored; every other
        file is loaded as the namespace named after the file without its ``.txt`` suffix.
        """
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
            # a blank line is not a pattern; keeping it would register an empty pattern.
            non_padded_namespaces = [line.strip() for line in f if line.strip()]

        vocab = cls(non_padded_namespaces=non_padded_namespaces)

        for namespace_filename in os.listdir(directory):
            if namespace_filename == NAMESPACE_PADDING_FILE:
                continue
            if namespace_filename.startswith("."):
                continue
            filename = os.path.join(directory, namespace_filename)
            if not os.path.isfile(filename):
                continue

            # strip only the trailing extension; str.replace would also remove a '.txt' that
            # is part of the namespace name itself.
            if namespace_filename.endswith('.txt'):
                namespace = namespace_filename[:-len('.txt')]
            else:
                namespace = namespace_filename

            is_padded = not any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces)
            vocab.set_from_file(filename, is_padded, namespace=namespace)

        return vocab

    def set_from_file(self,
                      filename: str,
                      is_padded: bool = True,
                      oov_token: str = DEFAULT_OOV_TOKEN,
                      namespace: str = "tokens"
                      ):
        """
        Replace the content of ``namespace`` with the tokens listed in ``filename``.

        The file holds one token per line; surrounding whitespace is stripped. For a padded
        namespace the padding token always gets index 0 and file tokens are numbered from 1;
        for a non-padded namespace file tokens are numbered from 0.

        :param filename: path of the token file.
        :param is_padded: whether the namespace reserves index 0 for the padding token.
        :param oov_token: spelling of the OOV token inside the file; it is stored as this
            vocabulary's own OOV token.
        :param namespace: namespace to fill.
        """
        if is_padded:
            token_to_index = {self._padding_token: 0}
            index_to_token = {0: self._padding_token}
            index = 1
        else:
            token_to_index = {}
            index_to_token = {}
            index = 0

        with open(filename, 'r', encoding='utf-8') as f:
            for row in f:
                token = str(row).strip()
                if is_padded and token == self._padding_token:
                    # save_to_files writes the padding token too. It already sits at index 0;
                    # numbering it again would shift every following token by one.
                    continue
                if token == oov_token:
                    token = self._oov_token
                token_to_index[token] = index
                index_to_token[index] = token
                index += 1

        # install both mappings only once the file has been read completely.
        self._token_to_index[namespace] = token_to_index
        self._index_to_token[namespace] = index_to_token

    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"):
        """
        Convert tokens to ids; unknown tokens become the id of the OOV token.

        :raises KeyError: on an unknown token when the namespace holds no OOV token.
        """
        return [self.get_token_index(token, namespace) for token in tokens]

    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"):
        """Convert ids back to tokens; raises KeyError for an unknown id."""
        return [self._index_to_token[namespace][idx] for idx in ids]

    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"):
        """
        Return a new list of exactly ``max_length`` ids.

        Longer inputs are cut at the end; shorter ones are filled at the end with the id of the
        padding token.

        :raises KeyError: if the namespace holds no padding token (a non-padded namespace).
        """
        pad_idx = self._token_to_index[namespace][self._padding_token]

        length = len(ids)
        if length > max_length:
            return ids[:max_length]
        return ids + [pad_idx] * (max_length - length)
|
185 |
+
|
186 |
+
|
187 |
+
def demo1():
    """
    Small walk-through: build a two-word vocabulary, segment a sentence with jieba, then
    print the words, their ids, the ids padded to length 10 and the tokens decoded back.
    """
    import jieba

    vocab = Vocabulary()
    for word in ('白天', '晚上'):
        vocab.add_token_to_namespace(word, 'tokens')

    words = jieba.lcut('不是在白天, 就是在晚上')
    print(words)

    word_ids = vocab.convert_tokens_to_ids(words)
    print(word_ids)

    fixed_length_ids = vocab.pad_or_truncate_ids_by_max_length(word_ids, 10)
    print(fixed_length_ids)

    decoded = vocab.convert_ids_to_tokens(fixed_length_ids)
    print(decoded)
    return
|
208 |
+
|
209 |
+
|
210 |
+
if __name__ == '__main__':
|
211 |
+
demo1()
|