wasmdashai commited on
Commit
6b42195
·
verified ·
1 Parent(s): 93b95e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -570,12 +570,13 @@ def greet(text,id):
570
  parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
571
  json_file = os.path.abspath('VitsModelSplit/finetune_config_ara.json')
572
  model_args, data_args, training_args = parser.parse_json_file(json_file = json_file)
 
573
  sgl=get_state_grad_loss(mel=True,
574
  # generator=False,
575
  # discriminator=False,
576
  duration=False)
577
 
578
-
579
  training_args.num_train_epochs=1000
580
  training_args.fp16=True
581
  training_args.eval_steps=300
@@ -585,9 +586,11 @@ def greet(text,id):
585
  training_args.weight_mel=45
586
  training_args.num_train_epochs=4
587
  training_args.eval_steps=1000
 
 
588
 
589
  b=int(id)
590
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
591
  ctrain_datasets,eval_dataset,full_generation_dataset=get_data_loader(train_dataset_dirs = train_dataset_dirs,
592
  eval_dataset_dir = os.path.join(dataset_dir,'eval'),
593
  full_generation_dir = os.path.join(dataset_dir,'full_generation'),
 
570
  parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
571
  json_file = os.path.abspath('VitsModelSplit/finetune_config_ara.json')
572
  model_args, data_args, training_args = parser.parse_json_file(json_file = json_file)
573
+ print('start')
574
  sgl=get_state_grad_loss(mel=True,
575
  # generator=False,
576
  # discriminator=False,
577
  duration=False)
578
 
579
+ print(training_args)
580
  training_args.num_train_epochs=1000
581
  training_args.fp16=True
582
  training_args.eval_steps=300
 
586
  training_args.weight_mel=45
587
  training_args.num_train_epochs=4
588
  training_args.eval_steps=1000
589
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
590
+ print(device)
591
 
592
  b=int(id)
593
+
594
  ctrain_datasets,eval_dataset,full_generation_dataset=get_data_loader(train_dataset_dirs = train_dataset_dirs,
595
  eval_dataset_dir = os.path.join(dataset_dir,'eval'),
596
  full_generation_dir = os.path.join(dataset_dir,'full_generation'),