HoneyTian committed on
Commit
88b2fbf
·
1 Parent(s): 7f9e32d
examples/conv_tasnet/step_1_prepare_data.py CHANGED
@@ -1,22 +1,18 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
 
4
  import os
5
  from pathlib import Path
6
  import random
7
  import sys
8
- import shutil
9
 
10
  pwd = os.path.abspath(os.path.dirname(__file__))
11
  sys.path.append(os.path.join(pwd, "../../"))
12
 
13
- import pandas as pd
14
- from scipy.io import wavfile
15
  from tqdm import tqdm
16
  import librosa
17
 
18
- from project_settings import project_path
19
-
20
 
21
  def get_args():
22
  parser = argparse.ArgumentParser()
@@ -33,8 +29,8 @@ def get_args():
33
  type=str
34
  )
35
 
36
- parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
- parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
 
39
  parser.add_argument("--duration", default=2.0, type=float)
40
  parser.add_argument("--min_snr_db", default=-10, type=float)
@@ -80,7 +76,9 @@ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate
80
  yield row
81
 
82
 
83
- def get_dataset(args):
 
 
84
  file_dir = Path(args.file_dir)
85
  file_dir.mkdir(exist_ok=True)
86
 
@@ -104,99 +102,56 @@ def get_dataset(args):
104
 
105
  count = 0
106
  process_bar = tqdm(desc="build dataset excel")
107
- for noise, speech in zip(noise_generator, speech_generator):
108
- if count >= args.max_count:
109
- break
110
-
111
- noise_filename = noise["filename"]
112
- noise_raw_duration = noise["raw_duration"]
113
- noise_offset = noise["offset"]
114
- noise_duration = noise["duration"]
115
-
116
- speech_filename = speech["filename"]
117
- speech_raw_duration = speech["raw_duration"]
118
- speech_offset = speech["offset"]
119
- speech_duration = speech["duration"]
120
-
121
- random1 = random.random()
122
- random2 = random.random()
123
-
124
- row = {
125
- "noise_filename": noise_filename,
126
- "noise_raw_duration": noise_raw_duration,
127
- "noise_offset": noise_offset,
128
- "noise_duration": noise_duration,
129
-
130
- "speech_filename": speech_filename,
131
- "speech_raw_duration": speech_raw_duration,
132
- "speech_offset": speech_offset,
133
- "speech_duration": speech_duration,
134
-
135
- "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
136
-
137
- "random1": random1,
138
- "random2": random2,
139
- "flag": "TRAIN" if random2 < 0.8 else "TEST",
140
- }
141
- dataset.append(row)
142
- count += 1
143
- duration_seconds = count * args.duration
144
- duration_hours = duration_seconds / 3600
145
-
146
- process_bar.update(n=1)
147
- process_bar.set_postfix({
148
- # "duration_seconds": round(duration_seconds, 4),
149
- "duration_hours": round(duration_hours, 4),
150
-
151
- })
152
-
153
- dataset = pd.DataFrame(dataset)
154
- dataset = dataset.sort_values(by=["random1"], ascending=False)
155
- dataset.to_excel(
156
- file_dir / "dataset.xlsx",
157
- index=False,
158
- )
159
- return
160
-
161
-
162
- def split_dataset(args):
163
- """分割训练集, 测试集"""
164
- file_dir = Path(args.file_dir)
165
- file_dir.mkdir(exist_ok=True)
166
-
167
- df = pd.read_excel(file_dir / "dataset.xlsx")
168
-
169
- train = list()
170
- test = list()
171
-
172
- for i, row in df.iterrows():
173
- flag = row["flag"]
174
- if flag == "TRAIN":
175
- train.append(row)
176
- else:
177
- test.append(row)
178
-
179
- train = pd.DataFrame(train)
180
- train.to_excel(
181
- args.train_dataset,
182
- index=False,
183
- # encoding="utf_8_sig"
184
- )
185
- test = pd.DataFrame(test)
186
- test.to_excel(
187
- args.valid_dataset,
188
- index=False,
189
- # encoding="utf_8_sig"
190
- )
191
-
192
- return
193
-
194
-
195
- def main():
196
- args = get_args()
197
 
198
- get_dataset(args)
199
- split_dataset(args)
200
  return
201
 
202
 
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
+ import json
5
  import os
6
  from pathlib import Path
7
  import random
8
  import sys
 
9
 
10
  pwd = os.path.abspath(os.path.dirname(__file__))
11
  sys.path.append(os.path.join(pwd, "../../"))
12
 
 
 
13
  from tqdm import tqdm
14
  import librosa
15
 
 
 
16
 
17
  def get_args():
18
  parser = argparse.ArgumentParser()
 
29
  type=str
30
  )
31
 
32
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
33
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
34
 
35
  parser.add_argument("--duration", default=2.0, type=float)
36
  parser.add_argument("--min_snr_db", default=-10, type=float)
 
76
  yield row
77
 
78
 
79
+ def main():
80
+ args = get_args()
81
+
82
  file_dir = Path(args.file_dir)
83
  file_dir.mkdir(exist_ok=True)
84
 
 
102
 
103
  count = 0
104
  process_bar = tqdm(desc="build dataset excel")
105
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
106
+ for noise, speech in zip(noise_generator, speech_generator):
107
+ if count >= args.max_count:
108
+ break
109
+
110
+ noise_filename = noise["filename"]
111
+ noise_raw_duration = noise["raw_duration"]
112
+ noise_offset = noise["offset"]
113
+ noise_duration = noise["duration"]
114
+
115
+ speech_filename = speech["filename"]
116
+ speech_raw_duration = speech["raw_duration"]
117
+ speech_offset = speech["offset"]
118
+ speech_duration = speech["duration"]
119
+
120
+ random1 = random.random()
121
+ random2 = random.random()
122
+
123
+ row = {
124
+ "noise_filename": noise_filename,
125
+ "noise_raw_duration": noise_raw_duration,
126
+ "noise_offset": noise_offset,
127
+ "noise_duration": noise_duration,
128
+
129
+ "speech_filename": speech_filename,
130
+ "speech_raw_duration": speech_raw_duration,
131
+ "speech_offset": speech_offset,
132
+ "speech_duration": speech_duration,
133
+
134
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
135
+
136
+ "random1": random1,
137
+ }
138
+ row = json.dumps(row, ensure_ascii=False)
139
+ if random2 < 0.8:
140
+ ftrain.write(f"{row}\n")
141
+ else:
142
+ fvalid.write(f"{row}\n")
143
+
144
+ count += 1
145
+ duration_seconds = count * args.duration
146
+ duration_hours = duration_seconds / 3600
147
+
148
+ process_bar.update(n=1)
149
+ process_bar.set_postfix({
150
+ # "duration_seconds": round(duration_seconds, 4),
151
+ "duration_hours": round(duration_hours, 4),
152
+
153
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
155
  return
156
 
157
 
main.py CHANGED
@@ -74,6 +74,13 @@ denoise_engines = {
74
  project_path / "trained_models/mpnet-nx-speech-20-epoch.zip").as_posix()
75
  }
76
  },
 
 
 
 
 
 
 
77
  "mpnet-aishell-1-epoch": {
78
  "infer_cls": InferenceMPNet,
79
  "kwargs": {
 
74
  project_path / "trained_models/mpnet-nx-speech-20-epoch.zip").as_posix()
75
  }
76
  },
77
+ "mpnet-nx-speech-33-epoch-best": {
78
+ "infer_cls": InferenceMPNet,
79
+ "kwargs": {
80
+ "pretrained_model_path_or_zip_file": (
81
+ project_path / "trained_models/mpnet-nx-speech-33-epoch-best.zip").as_posix()
82
+ }
83
+ },
84
  "mpnet-aishell-1-epoch": {
85
  "infer_cls": InferenceMPNet,
86
  "kwargs": {
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import pandas as pd
9
+ from scipy.io import wavfile
10
+ import torch
11
+ import torchaudio
12
+ from torch.utils.data import Dataset
13
+ from tqdm import tqdm
14
+
15
+
16
class DenoiseJsonlDataset(Dataset):
    """Dataset that builds noisy/clean pairs on the fly from a JSONL manifest.

    Each non-empty line of the manifest is a JSON object describing one
    (noise segment, speech segment, SNR) triple.  ``__getitem__`` loads both
    segments, scales the noise to the requested SNR and returns the mixture
    together with its two components.
    """

    # Manifest keys that are kept for every sample; any other key is dropped.
    _SAMPLE_KEYS = (
        "noise_filename",
        "noise_raw_duration",
        "noise_offset",
        "noise_duration",

        "speech_filename",
        "speech_raw_duration",
        "speech_offset",
        "speech_duration",

        "snr_db",
    )

    def __init__(self,
                 jsonl_file: str,
                 expected_sample_rate: int,
                 resample: bool = False,
                 max_wave_value: float = 1.0,
                 ):
        """Read the manifest eagerly and remember the loading options.

        :param jsonl_file: path of the JSONL manifest.
        :param expected_sample_rate: sample rate passed to ``librosa.load``.
        :param resample: stored on the instance; not read by this class.
        :param max_wave_value: stored on the instance; not read by this class.
        """
        self.jsonl_file = jsonl_file
        self.expected_sample_rate = expected_sample_rate
        self.resample = resample
        self.max_wave_value = max_wave_value

        self.samples = self.load_samples(jsonl_file)

    @staticmethod
    def load_samples(filename: str) -> list:
        """Parse the manifest into a list of sample dicts.

        Blank lines (for example a trailing empty line) are skipped instead
        of crashing ``json.loads``.

        :param filename: path of the JSONL manifest.
        :return: one dict per manifest row, restricted to ``_SAMPLE_KEYS``.
        :raises KeyError: if a row lacks one of the required keys.
        :raises json.JSONDecodeError: if a non-empty line is not valid JSON.
        """
        samples = list()
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                row = json.loads(line)
                samples.append({key: row[key] for key in DenoiseJsonlDataset._SAMPLE_KEYS})
        return samples

    def __getitem__(self, index):
        """Load, mix and return the sample at ``index``.

        :return: dict with ``noise_wave`` (noise scaled to the target SNR),
            ``speech_wave``, ``mix_wave`` (float32 tensors of equal length)
            and ``snr_db``.
        """
        sample = self.samples[index]
        snr_db = sample["snr_db"]

        noise_wave = self.filename_to_waveform(
            sample["noise_filename"], sample["noise_offset"], sample["noise_duration"]
        )
        speech_wave = self.filename_to_waveform(
            sample["speech_filename"], sample["speech_offset"], sample["speech_duration"]
        )

        mix_wave, noise_wave_adjusted = self.mix_speech_and_noise(
            speech=speech_wave.numpy(),
            noise=noise_wave.numpy(),
            snr_db=snr_db,
        )
        mix_wave = torch.tensor(mix_wave, dtype=torch.float32)
        noise_wave_adjusted = torch.tensor(noise_wave_adjusted, dtype=torch.float32)

        # The mixer truncates both inputs to the shorter one; truncate the
        # clean target as well so the three returned tensors always line up.
        speech_wave = speech_wave[: mix_wave.shape[0]]

        result = {
            "noise_wave": noise_wave_adjusted,
            "speech_wave": speech_wave,
            "mix_wave": mix_wave,
            "snr_db": snr_db,
        }
        return result

    def __len__(self):
        """Return the number of rows read from the manifest."""
        return len(self.samples)

    def filename_to_waveform(self, filename: str, offset: float, duration: float):
        """Load one audio segment as a float32 tensor.

        :param filename: audio file path.
        :param offset: ``offset`` argument forwarded to ``librosa.load``.
        :param duration: ``duration`` argument forwarded to ``librosa.load``.
        :return: the loaded waveform as a ``torch.float32`` tensor.
        :raises ValueError: re-raised from ``librosa.load`` after logging.
        """
        try:
            waveform, sample_rate = librosa.load(
                filename,
                sr=self.expected_sample_rate,
                offset=offset,
                duration=duration,
            )
        except ValueError as e:
            print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
            # Bare raise re-raises the active exception unchanged.
            raise
        waveform = torch.tensor(waveform, dtype=torch.float32)
        return waveform

    @staticmethod
    def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
        """Scale ``noise`` so that speech/noise power equals ``snr_db`` and add it.

        Both inputs are truncated to the shorter length first.  A silent
        (all-zero) noise clip cannot be rescaled, so it is mixed in as zeros
        instead of producing NaN through a division by zero.

        :param speech: 1-D clean signal.
        :param noise: 1-D noise signal.
        :param snr_db: target signal-to-noise ratio in decibels.
        :return: ``(noisy_signal, noise_adjusted)``, both of the common length.
        """
        length = min(len(speech), len(noise))
        speech = speech[:length]
        noise = noise[:length]

        # Nothing to mix; also avoids taking the mean of an empty slice.
        if length == 0:
            return speech, noise

        speech_power = np.mean(np.square(speech))
        noise_power = speech_power / (10 ** (snr_db / 10))

        noise_rms = np.sqrt(np.mean(np.square(noise)))
        if noise_rms > 0:
            noise_adjusted = np.sqrt(noise_power) * noise / noise_rms
        else:
            noise_adjusted = np.zeros_like(noise)

        noisy_signal = speech + noise_adjusted

        return noisy_signal, noise_adjusted
130
+
131
+
132
+ if __name__ == '__main__':
133
+ pass