wasmdashai commited on
Commit
4687b44
·
verified ·
1 Parent(s): ecc082c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -422,7 +422,7 @@ class TrinerModelVITS:
422
 
423
 
424
  self.initialize_training_components()
425
- self.epoch_count=0
426
 
427
 
428
  def load_model(self):
@@ -594,13 +594,14 @@ pro=TrinerModelVITS(dir_model=dir_model,
594
 
595
  @spaces.GPU(duration=120)
596
  def run_train_epoch(num):
597
- pro.init_training()
598
- for i in range(num):
 
599
  # model.train(True)
600
  yield pro.run_train_epoch()
601
-
602
- pro.save_pretrained(pro.dir_model)
603
- pro.load_model()
604
  return 'save model '
605
 
606
  @spaces.GPU
 
422
 
423
 
424
  self.initialize_training_components()
425
+ # self.epoch_count=0
426
 
427
 
428
  def load_model(self):
 
594
 
595
  @spaces.GPU(duration=120)
596
  def run_train_epoch(num):
597
+ if num >0:
598
+ pro.init_training()
599
+ for i in range(num):
600
  # model.train(True)
601
  yield pro.run_train_epoch()
602
+ else:
603
+ pro.save_pretrained(pro.dir_model)
604
+ pro.load_model()
605
  return 'save model '
606
 
607
  @spaces.GPU