MohamedRashad commited on
Commit
0053e1b
·
1 Parent(s): eb30231

Add deepspeed import and print device information in app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import spaces
 
2
  import torch
3
  from TTS.tts.configs.xtts_config import XttsConfig
4
  from TTS.tts.models.xtts import Xtts
@@ -36,6 +37,7 @@ config.load_json(config_path)
36
 
37
  print("Loading model...")
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
39
  model = Xtts.init_from_config(config)
40
  model.load_checkpoint(config, checkpoint_path=model_path, use_deepspeed=True, vocab_path=vocab_path, eval=True)
41
  model.to(device)
 
1
  import spaces
2
+ import deepspeed
3
  import torch
4
  from TTS.tts.configs.xtts_config import XttsConfig
5
  from TTS.tts.models.xtts import Xtts
 
37
 
38
  print("Loading model...")
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ print(device)
41
  model = Xtts.init_from_config(config)
42
  model.load_checkpoint(config, checkpoint_path=model_path, use_deepspeed=True, vocab_path=vocab_path, eval=True)
43
  model.to(device)