HoneyTian commited on
Commit
1474235
·
1 Parent(s): 4d3fcad
examples/dtln/step_2_train_model.py CHANGED
@@ -1,7 +1,8 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  """
4
- https://github.com/Rikorose/DeepFilterNet
 
5
  """
6
  import argparse
7
  import json
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ https://github.com/breizhn/DTLN
5
+
6
  """
7
  import argparse
8
  import json
examples/rnnoise/step_2_train_model.py CHANGED
@@ -194,18 +194,12 @@ def main():
194
  if last_step_idx != -1:
195
  logger.info(f"resume from steps-{last_step_idx}.")
196
  model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
197
- optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
198
 
199
  logger.info(f"load state dict for model.")
200
  with open(model_pt.as_posix(), "rb") as f:
201
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
202
  model.load_state_dict(state_dict, strict=True)
203
 
204
- logger.info(f"load state dict for optimizer.")
205
- with open(optimizer_pth.as_posix(), "rb") as f:
206
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
207
- optimizer.load_state_dict(state_dict)
208
-
209
  if config.lr_scheduler == "CosineAnnealingLR":
210
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
211
  optimizer,
@@ -265,7 +259,7 @@ def main():
265
 
266
  progress_bar_train = tqdm(
267
  initial=step_idx,
268
- desc="Training; epoch: {}".format(epoch_idx),
269
  )
270
  for train_batch in train_data_loader:
271
  clean_audios, noisy_audios = train_batch
 
194
  if last_step_idx != -1:
195
  logger.info(f"resume from steps-{last_step_idx}.")
196
  model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
 
197
 
198
  logger.info(f"load state dict for model.")
199
  with open(model_pt.as_posix(), "rb") as f:
200
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
201
  model.load_state_dict(state_dict, strict=True)
202
 
 
 
 
 
 
203
  if config.lr_scheduler == "CosineAnnealingLR":
204
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
205
  optimizer,
 
259
 
260
  progress_bar_train = tqdm(
261
  initial=step_idx,
262
+ desc="Training; epoch-{}".format(epoch_idx),
263
  )
264
  for train_batch in train_data_loader:
265
  clean_audios, noisy_audios = train_batch