HoneyTian committed on
Commit
a0cbcda
·
1 Parent(s): db3e977
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -42,7 +42,7 @@ def get_args():
42
  parser.add_argument("--max_epochs", default=200, type=int)
43
 
44
  parser.add_argument("--batch_size", default=16, type=int)
45
- parser.add_argument("--learning_rate", default=1e-4, type=float)
46
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
47
  parser.add_argument("--patience", default=5, type=int)
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
@@ -243,8 +243,8 @@ def main():
243
  progress_bar = tqdm(
244
  desc="Training; epoch-{}".format(idx_epoch),
245
  )
246
- for batch in train_data_loader:
247
- clean_audios, noisy_audios = batch
248
  clean_audios = clean_audios.to(device)
249
  noisy_audios = noisy_audios.to(device)
250
 
@@ -298,7 +298,7 @@ def main():
298
 
299
  # evaluation
300
  total_steps += 1
301
- if total_steps % args.eval_steps:
302
  model.eval()
303
  torch.cuda.empty_cache()
304
 
@@ -313,8 +313,8 @@ def main():
313
  desc="Evaluation; step-{}".format(total_steps),
314
  )
315
  with torch.no_grad():
316
- for batch in valid_data_loader:
317
- clean_audios, noisy_audios = batch
318
  clean_audios = clean_audios.to(device)
319
  noisy_audios = noisy_audios.to(device)
320
 
 
42
  parser.add_argument("--max_epochs", default=200, type=int)
43
 
44
  parser.add_argument("--batch_size", default=16, type=int)
45
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
46
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
47
  parser.add_argument("--patience", default=5, type=int)
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
243
  progress_bar = tqdm(
244
  desc="Training; epoch-{}".format(idx_epoch),
245
  )
246
+ for train_batch in train_data_loader:
247
+ clean_audios, noisy_audios = train_batch
248
  clean_audios = clean_audios.to(device)
249
  noisy_audios = noisy_audios.to(device)
250
 
 
298
 
299
  # evaluation
300
  total_steps += 1
301
+ if total_steps % args.eval_steps == 0:
302
  model.eval()
303
  torch.cuda.empty_cache()
304
 
 
313
  desc="Evaluation; step-{}".format(total_steps),
314
  )
315
  with torch.no_grad():
316
+ for eval_batch in valid_data_loader:
317
+ clean_audios, noisy_audios = eval_batch
318
  clean_audios = clean_audios.to(device)
319
  noisy_audios = noisy_audios.to(device)
320
 
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py CHANGED
@@ -49,15 +49,16 @@ class DenoiseJsonlDataset(IterableDataset):
49
  item = next(iterable_source)
50
  # 随机替换缓冲区元素
51
  replace_idx = random.randint(0, len(self.buffer_samples) - 1)
52
- yield self.buffer_samples[replace_idx]
53
  self.buffer_samples[replace_idx] = item
 
54
  except StopIteration:
55
  break
56
 
57
  # 清空剩余元素
58
  random.shuffle(self.buffer_samples)
59
  for sample in self.buffer_samples:
60
- yield sample
61
 
62
  def iterable_source(self):
63
  with open(self.jsonl_file, "r", encoding="utf-8") as f:
@@ -75,7 +76,7 @@ class DenoiseJsonlDataset(IterableDataset):
75
 
76
  snr_db = row["snr_db"]
77
 
78
- row = {
79
  "noise_filename": noise_filename,
80
  "noise_raw_duration": noise_raw_duration,
81
  "noise_offset": noise_offset,
@@ -88,7 +89,6 @@ class DenoiseJsonlDataset(IterableDataset):
88
 
89
  "snr_db": snr_db,
90
  }
91
- sample = self.convert_sample(row)
92
  yield sample
93
 
94
  def convert_sample(self, sample: dict):
 
49
  item = next(iterable_source)
50
  # 随机替换缓冲区元素
51
  replace_idx = random.randint(0, len(self.buffer_samples) - 1)
52
+ sample = self.buffer_samples[replace_idx]
53
  self.buffer_samples[replace_idx] = item
54
+ yield self.convert_sample(sample)
55
  except StopIteration:
56
  break
57
 
58
  # 清空剩余元素
59
  random.shuffle(self.buffer_samples)
60
  for sample in self.buffer_samples:
61
+ yield self.convert_sample(sample)
62
 
63
  def iterable_source(self):
64
  with open(self.jsonl_file, "r", encoding="utf-8") as f:
 
76
 
77
  snr_db = row["snr_db"]
78
 
79
+ sample = {
80
  "noise_filename": noise_filename,
81
  "noise_raw_duration": noise_raw_duration,
82
  "noise_offset": noise_offset,
 
89
 
90
  "snr_db": snr_db,
91
  }
 
92
  yield sample
93
 
94
  def convert_sample(self, sample: dict):