Hugging Face Spaces (status: Running) — commit: "update" — Browse files
File changed: examples/conv_tasnet/step_2_train_model.py
@@ -42,7 +42,7 @@ def get_args():
|
|
42 |
parser.add_argument("--max_epochs", default=200, type=int)
|
43 |
|
44 |
parser.add_argument("--batch_size", default=16, type=int)
|
45 |
-
parser.add_argument("--learning_rate", default=1e-
|
46 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
47 |
parser.add_argument("--patience", default=5, type=int)
|
48 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
@@ -243,8 +243,8 @@ def main():
|
|
243 |
progress_bar = tqdm(
|
244 |
desc="Training; epoch-{}".format(idx_epoch),
|
245 |
)
|
246 |
-
for
|
247 |
-
clean_audios, noisy_audios =
|
248 |
clean_audios = clean_audios.to(device)
|
249 |
noisy_audios = noisy_audios.to(device)
|
250 |
|
@@ -298,7 +298,7 @@ def main():
|
|
298 |
|
299 |
# evaluation
|
300 |
total_steps += 1
|
301 |
-
if total_steps % args.eval_steps:
|
302 |
model.eval()
|
303 |
torch.cuda.empty_cache()
|
304 |
|
@@ -313,8 +313,8 @@ def main():
|
|
313 |
desc="Evaluation; step-{}".format(total_steps),
|
314 |
)
|
315 |
with torch.no_grad():
|
316 |
-
for
|
317 |
-
clean_audios, noisy_audios =
|
318 |
clean_audios = clean_audios.to(device)
|
319 |
noisy_audios = noisy_audios.to(device)
|
320 |
|
|
|
42 |
parser.add_argument("--max_epochs", default=200, type=int)
|
43 |
|
44 |
parser.add_argument("--batch_size", default=16, type=int)
|
45 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
46 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
47 |
parser.add_argument("--patience", default=5, type=int)
|
48 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
|
|
243 |
progress_bar = tqdm(
|
244 |
desc="Training; epoch-{}".format(idx_epoch),
|
245 |
)
|
246 |
+
for train_batch in train_data_loader:
|
247 |
+
clean_audios, noisy_audios = train_batch
|
248 |
clean_audios = clean_audios.to(device)
|
249 |
noisy_audios = noisy_audios.to(device)
|
250 |
|
|
|
298 |
|
299 |
# evaluation
|
300 |
total_steps += 1
|
301 |
+
if total_steps % args.eval_steps == 0:
|
302 |
model.eval()
|
303 |
torch.cuda.empty_cache()
|
304 |
|
|
|
313 |
desc="Evaluation; step-{}".format(total_steps),
|
314 |
)
|
315 |
with torch.no_grad():
|
316 |
+
for eval_batch in valid_data_loader:
|
317 |
+
clean_audios, noisy_audios = eval_batch
|
318 |
clean_audios = clean_audios.to(device)
|
319 |
noisy_audios = noisy_audios.to(device)
|
320 |
|
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py
CHANGED
@@ -49,15 +49,16 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
49 |
item = next(iterable_source)
|
50 |
# 随机替换缓冲区元素
|
51 |
replace_idx = random.randint(0, len(self.buffer_samples) - 1)
|
52 |
-
|
53 |
self.buffer_samples[replace_idx] = item
|
|
|
54 |
except StopIteration:
|
55 |
break
|
56 |
|
57 |
# 清空剩余元素
|
58 |
random.shuffle(self.buffer_samples)
|
59 |
for sample in self.buffer_samples:
|
60 |
-
yield sample
|
61 |
|
62 |
def iterable_source(self):
|
63 |
with open(self.jsonl_file, "r", encoding="utf-8") as f:
|
@@ -75,7 +76,7 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
75 |
|
76 |
snr_db = row["snr_db"]
|
77 |
|
78 |
-
|
79 |
"noise_filename": noise_filename,
|
80 |
"noise_raw_duration": noise_raw_duration,
|
81 |
"noise_offset": noise_offset,
|
@@ -88,7 +89,6 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
88 |
|
89 |
"snr_db": snr_db,
|
90 |
}
|
91 |
-
sample = self.convert_sample(row)
|
92 |
yield sample
|
93 |
|
94 |
def convert_sample(self, sample: dict):
|
|
|
49 |
item = next(iterable_source)
|
50 |
# 随机替换缓冲区元素
|
51 |
replace_idx = random.randint(0, len(self.buffer_samples) - 1)
|
52 |
+
sample = self.buffer_samples[replace_idx]
|
53 |
self.buffer_samples[replace_idx] = item
|
54 |
+
yield self.convert_sample(sample)
|
55 |
except StopIteration:
|
56 |
break
|
57 |
|
58 |
# 清空剩余元素
|
59 |
random.shuffle(self.buffer_samples)
|
60 |
for sample in self.buffer_samples:
|
61 |
+
yield self.convert_sample(sample)
|
62 |
|
63 |
def iterable_source(self):
|
64 |
with open(self.jsonl_file, "r", encoding="utf-8") as f:
|
|
|
76 |
|
77 |
snr_db = row["snr_db"]
|
78 |
|
79 |
+
sample = {
|
80 |
"noise_filename": noise_filename,
|
81 |
"noise_raw_duration": noise_raw_duration,
|
82 |
"noise_offset": noise_offset,
|
|
|
89 |
|
90 |
"snr_db": snr_db,
|
91 |
}
|
|
|
92 |
yield sample
|
93 |
|
94 |
def convert_sample(self, sample: dict):
|