ollieollie committed
Commit 3dab9c0 · verified · 1 Parent(s): 31b5538

Update app.py

Files changed (1):
  1. app.py +49 -13
app.py CHANGED
@@ -3,29 +3,48 @@ import numpy as np
 import torch
 from chatterbox.src.chatterbox.tts import ChatterboxTTS
 import gradio as gr
+import spaces  # <<< IMPORT THIS
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"🚀 Running on device: {DEVICE}")  # Good to log this
 
+# Global model variable to load only once if not using gr.State for model object
+# global_model = None
 
 def set_seed(seed: int):
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if DEVICE == "cuda":  # Only seed cuda if available
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
     random.seed(seed)
     np.random.seed(seed)
 
+# Optional: Decorate model loading if it's done on first use within a GPU function
+# However, it's often better to load the model once globally or manage with gr.State
+# and ensure the function CALLING the model is decorated.
 
-def load_model():
-    return ChatterboxTTS.from_pretrained(DEVICE)
+@spaces.GPU  # <<< ADD THIS DECORATOR
+def generate(model_obj, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
+    # It's better to load the model once, perhaps when the gr.State is initialized
+    # or globally, rather than checking `model_obj is None` on every call.
+    # For ZeroGPU, the decorated function handles the GPU context.
+    # Let's assume model_obj is passed correctly and is already on DEVICE
+    # or will be moved to DEVICE by ChatterboxTTS internally.
+
+    if model_obj is None:
+        print("Model is None, attempting to load...")
+        # This load should ideally happen on DEVICE and be efficient.
+        # If ChatterboxTTS.from_pretrained(DEVICE) is slow,
+        # this will happen inside the GPU-allocated time.
+        model_obj = ChatterboxTTS.from_pretrained(DEVICE)
+        print(f"Model loaded on device: {model_obj.device if hasattr(model_obj, 'device') else 'unknown'}")
 
-def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
-    if model is None:
-        model = ChatterboxTTS.from_pretrained(DEVICE)
 
     if seed_num != 0:
         set_seed(int(seed_num))
 
-    wav = model.generate(
+    print(f"Generating audio for text: '{text}' on device: {DEVICE}")
+    wav = model_obj.generate(
         text,
         audio_prompt_path=audio_prompt_path,
         exaggeration=exaggeration,
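
Note on the hunk above: on ZeroGPU Spaces, only code inside a function decorated with `@spaces.GPU` runs with a GPU attached, so the lazy `from_pretrained` inside `generate` spends part of the allocated window on loading. A minimal sketch of the pattern, assuming the Hugging Face `spaces` package; the `duration` argument and the module-level `_model` cache are illustrative, not part of this commit:

```python
import spaces
import torch

from chatterbox.src.chatterbox.tts import ChatterboxTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_model = None  # hypothetical cache so the load cost is paid only once

@spaces.GPU(duration=120)  # request a longer window than the default
def synthesize(text: str):
    # Runs with a GPU attached on ZeroGPU; plain CUDA elsewhere.
    global _model
    if _model is None:
        # The first call pays the model load inside the GPU-allocated
        # time, which is the trade-off the comments above describe.
        _model = ChatterboxTTS.from_pretrained(DEVICE)
    wav = _model.generate(text)
    return _model.sr, wav.squeeze(0).numpy()
```

The decorator can also be applied bare, as in the hunk, which requests the default time window.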
@@ -33,13 +52,29 @@ def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, se
         temperature=temperature,
         cfg_weight=cfgw,
     )
-    return (model, (model.sr, wav.squeeze(0).numpy()))
+    print("Audio generation complete.")
+    # The model state is passed back out, which is correct for gr.State
+    return (model_obj, (model_obj.sr, wav.squeeze(0).numpy()))
 
 
 with gr.Blocks() as demo:
-    model_state = gr.State(None)  # Loaded once per session/user
+    # To ensure model loads on app start and uses DEVICE correctly:
+    # Pre-load the model here if you want it loaded once globally for the Space instance.
+    # However, with gr.State(None) and loading in `generate`,
+    # the first user hitting "Generate" will trigger the load.
+    # This is fine if `ChatterboxTTS.from_pretrained(DEVICE)` correctly uses the GPU
+    # within the @spaces.GPU decorated `generate` function.
+
+    # For better clarity on model loading with ZeroGPU:
+    # Consider a dedicated function for loading the model that's called to initialize gr.State,
+    # or ensure the first call to `generate` handles it robustly within the GPU context.
+    # The current approach of loading if model_state is None within `generate` is okay
+    # as long as `generate` itself is decorated.
+
+    model_state = gr.State(None)
 
     with gr.Row():
+        # ... (rest of your UI code is fine) ...
         with gr.Column():
             text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
             ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/wav7604828.wav")
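
The comment block above suggests a dedicated loader that initializes the `gr.State` instead of checking for `None` on every call. One way to express that, sketched with Gradio's standard `demo.load` event, which fires once per browser session when the page opens:

```python
import gradio as gr
import torch

from chatterbox.src.chatterbox.tts import ChatterboxTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_model():
    # Called once per session via demo.load, so the first "Generate"
    # click never sees model_state == None.
    return ChatterboxTTS.from_pretrained(DEVICE)

with gr.Blocks() as demo:
    model_state = gr.State(None)
    # ... text box, reference audio, sliders, run button as in app.py ...
    demo.load(fn=load_model, inputs=[], outputs=[model_state])
```

One caveat, consistent with the comments above: on ZeroGPU the loader runs without a GPU attached, so this only helps if the model can initialize on CPU and move to the device inside the decorated `generate`; otherwise the lazy load in the hunk above remains the simpler option.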
@@ -72,8 +107,9 @@ with gr.Blocks() as demo:
         outputs=[model_state, audio_output],
     )
 
+# The share=True in launch() will give a UserWarning on Spaces, it's not needed.
+# Hugging Face Spaces provides the public link automatically.
 demo.queue(
     max_size=50,
-    default_concurrency_limit=1,
-).launch(share=True)
-
+    default_concurrency_limit=1,  # Good for single model instance on GPU
+).launch()  # Removed share=True
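
On the `launch()` change: `share=True` only creates a temporary public tunnel for local runs, and Spaces already serves a public URL, hence the UserWarning. If the same app.py is also run locally, the flag can be gated on the environment; the `SPACE_ID` check below is an assumption about the Spaces runtime, not something this commit uses:

```python
import os

import gradio as gr

# SPACE_ID is set in the Hugging Face Spaces runtime (assumption);
# it is absent when the app runs on a local machine.
ON_SPACES = os.environ.get("SPACE_ID") is not None

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

demo.queue(
    max_size=50,                  # bound the waiting queue
    default_concurrency_limit=1,  # serialize access to the one model
).launch(share=not ON_SPACES)     # tunnel only when running locally
```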
 