ollieollie commited on
Commit
2e214c5
·
verified ·
1 Parent(s): 866a959

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -15,9 +15,13 @@ def set_seed(seed: int):
15
  np.random.seed(seed)
16
 
17
 
18
- model = ChatterboxTTS.from_pretrained(DEVICE)
 
 
 
 
 
19
 
20
- def generate(text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfg_weight):
21
  if seed_num != 0:
22
  set_seed(int(seed_num))
23
 
@@ -27,12 +31,14 @@ def generate(text, audio_prompt_path, exaggeration, pace, temperature, seed_num,
27
  exaggeration=exaggeration,
28
  pace=pace,
29
  temperature=temperature,
30
- cfg_weight=cfg_weight,
31
  )
32
- return model.sr, wav.squeeze(0).numpy()
33
 
34
 
35
  with gr.Blocks() as demo:
 
 
36
  with gr.Row():
37
  with gr.Column():
38
  text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
@@ -54,6 +60,7 @@ with gr.Blocks() as demo:
54
  run_btn.click(
55
  fn=generate,
56
  inputs=[
 
57
  text,
58
  ref_wav,
59
  exaggeration,
@@ -62,9 +69,8 @@ with gr.Blocks() as demo:
62
  seed_num,
63
  cfg_weight,
64
  ],
65
- outputs=audio_output,
66
  )
67
 
68
  if __name__ == "__main__":
69
- demo.queue()
70
- demo.launch()
 
15
  np.random.seed(seed)
16
 
17
 
18
+ def load_model():
19
+ return ChatterboxTTS.from_pretrained(DEVICE)
20
+
21
+ def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
22
+ if model is None:
23
+ model = ChatterboxTTS.from_pretrained(DEVICE)
24
 
 
25
  if seed_num != 0:
26
  set_seed(int(seed_num))
27
 
 
31
  exaggeration=exaggeration,
32
  pace=pace,
33
  temperature=temperature,
34
+ cfg_weight=cfgw,
35
  )
36
+ return (model, (model.sr, wav.squeeze(0).numpy()))
37
 
38
 
39
  with gr.Blocks() as demo:
40
+ model_state = gr.State(None) # Loaded once per session/user
41
+
42
  with gr.Row():
43
  with gr.Column():
44
  text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
 
60
  run_btn.click(
61
  fn=generate,
62
  inputs=[
63
+ model_state,
64
  text,
65
  ref_wav,
66
  exaggeration,
 
69
  seed_num,
70
  cfg_weight,
71
  ],
72
+ outputs=[model_state, audio_output],
73
  )
74
 
75
  if __name__ == "__main__":
76
+ demo.queue().launch()