Commit: update
examples/dtln/step_2_train_model.py
CHANGED
@@ -1,7 +1,8 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
-https://github.com/
+https://github.com/breizhn/DTLN
+
 """
 import argparse
 import json
examples/rnnoise/step_2_train_model.py
CHANGED
@@ -194,18 +194,12 @@ def main():
     if last_step_idx != -1:
         logger.info(f"resume from steps-{last_step_idx}.")
         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

         logger.info(f"load state dict for model.")
         with open(model_pt.as_posix(), "rb") as f:
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
         model.load_state_dict(state_dict, strict=True)

-        logger.info(f"load state dict for optimizer.")
-        with open(optimizer_pth.as_posix(), "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        optimizer.load_state_dict(state_dict)
-
     if config.lr_scheduler == "CosineAnnealingLR":
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer,
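With the optimizer checkpoint no longer read, resuming from a steps-{last_step_idx} directory restores only the model weights; the optimizer (for example, Adam's moment estimates) starts from a fresh state. Below is a minimal sketch of the resulting resume path, reusing the names from the diff (serialization_dir, last_step_idx, model); the helper function and its wiring are illustrative assumptions, not the repository's exact code.

# Minimal sketch (assumed context, not the repository's exact code): resuming
# restores only the model weights; the optimizer is left in its fresh state.
import torch
from pathlib import Path

def resume_model_only(serialization_dir: Path, last_step_idx: int, model: torch.nn.Module) -> torch.nn.Module:
    if last_step_idx == -1:
        return model  # no checkpoint to resume from
    model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
    with open(model_pt.as_posix(), "rb") as f:
        # weights_only=True loads plain tensors and refuses arbitrary pickled objects
        state_dict = torch.load(f, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    return model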
@@ -265,7 +259,7 @@ def main():

         progress_bar_train = tqdm(
             initial=step_idx,
-            desc="Training; epoch
+            desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
             clean_audios, noisy_audios = train_batch
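The old desc line appears truncated in this view; the new line passes a complete str.format call so the progress-bar title carries the current epoch index. A small standalone illustration follows; epoch_idx and step_idx are placeholders here, while in the script they come from the surrounding training loop and the resume logic.

# Illustration only: placeholder values stand in for the script's loop variables.
from tqdm import tqdm

epoch_idx = 0   # placeholder; supplied by the outer epoch loop in the script
step_idx = 0    # placeholder; supplied by the resume logic in the script
progress_bar_train = tqdm(
    initial=step_idx,
    desc="Training; epoch-{}".format(epoch_idx),  # renders as "Training; epoch-0"
)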