Spaces:
Runtime error
Runtime error
import os | |
import json | |
import argparse | |
import math | |
import torch | |
from torch import nn, optim | |
from torch.nn import functional as F | |
from torch.utils.data import DataLoader | |
from data_utils import TextMelLoader, TextMelCollate | |
import models | |
import commons | |
import utils | |
class FlowGenerator_DDI(models.FlowGenerator):
    """``models.FlowGenerator`` that starts in data-dependent-initialization mode.

    The generator is built exactly like ``models.FlowGenerator`` (all
    positional and keyword arguments are forwarded unchanged). Afterwards,
    every flow in ``self.decoder.flows`` that exposes a truthy ``set_ddi``
    attribute has ``set_ddi(True)`` called on it. Flows without such an
    attribute are left untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily pick out the flows that offer a DDI switch, then enable it.
        ddi_capable_flows = (
            flow for flow in self.decoder.flows if getattr(flow, "set_ddi", False)
        )
        for flow in ddi_capable_flows:
            flow.set_ddi(True)
def main():
    """Run one data-dependent-initialization (DDI) pass and save the result.

    Steps:
      * Load hyper-parameters via ``utils.get_hparams()`` and set up logging.
      * Seed torch with ``hps.train.seed``.
      * Build a ``FlowGenerator_DDI`` on the GPU.
      * Feed it a single shuffled training batch (forward pass with
        ``gen=False``) so DDI-enabled flows can initialize themselves.
      * Write the generator and optimizer state to
        ``<hps.model_dir>/ddi_G.pth`` via ``utils.save_checkpoint``.

    Raises:
        RuntimeError: if the training loader yields no batch at all, e.g.
            when the dataset holds fewer items than ``hps.train.batch_size``
            (``drop_last=True`` discards the partial batch). Saving in that
            case would produce a checkpoint that never saw any data.
    """
    hps = utils.get_hparams()
    logger = utils.get_logger(hps.log_dir)
    logger.info(hps)
    utils.check_git_hash(hps.log_dir)

    torch.manual_seed(hps.train.seed)

    train_dataset = TextMelLoader(hps.data.training_files, hps.data)
    collate_fn = TextMelCollate(1)
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=True,
        batch_size=hps.train.batch_size,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    symbols = hps.data.punc + hps.data.chars
    # ``add_blank`` (default False) is a bool, so it adds 0 or 1 to the
    # vocabulary size. Presumably the extra id is for an interleaved blank
    # token -- NOTE(review): confirm against data_utils.
    generator = FlowGenerator_DDI(
        len(symbols) + getattr(hps.data, "add_blank", False),
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).cuda()
    optimizer_g = commons.Adam(
        generator.parameters(),
        scheduler=hps.train.scheduler,
        dim_model=hps.model.hidden_channels,
        warmup_steps=hps.train.warmup_steps,
        lr=hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    generator.train()

    # A single batch is all the DDI pass consumes; fail loudly instead of
    # silently saving an un-initialized model when none is available.
    batch = next(iter(train_loader), None)
    if batch is None:
        raise RuntimeError(
            "DDI needs at least one training batch, but the data loader "
            "produced none; check hps.data.training_files and "
            "hps.train.batch_size (drop_last=True discards a partial batch)."
        )
    x, x_lengths, y, y_lengths = batch
    x, x_lengths = x.cuda(), x_lengths.cuda()
    y, y_lengths = y.cuda(), y_lengths.cuda()

    # The forward output is discarded and no backward pass follows, so skip
    # building the autograd graph to save GPU memory.
    with torch.no_grad():
        generator(x, x_lengths, y, y_lengths, gen=False)

    utils.save_checkpoint(
        generator,
        optimizer_g,
        hps.train.learning_rate,
        0,
        os.path.join(hps.model_dir, "ddi_G.pth"),
    )
if __name__ == "__main__":
    # Only run the DDI pass when executed as a script, not when imported.
    main()