Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -3,29 +3,48 @@ import numpy as np
 import torch
 from chatterbox.src.chatterbox.tts import ChatterboxTTS
 import gradio as gr
+import spaces  # <<< IMPORT THIS
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"🚀 Running on device: {DEVICE}")  # Good to log this
 
+# Global model variable to load only once if not using gr.State for model object
+# global_model = None
 
 def set_seed(seed: int):
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if DEVICE == "cuda":  # Only seed cuda if available
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
     random.seed(seed)
     np.random.seed(seed)
 
+# Optional: Decorate model loading if it's done on first use within a GPU function
+# However, it's often better to load the model once globally or manage with gr.State
+# and ensure the function CALLING the model is decorated.
 
-def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
-    if model is None:
-        model = ChatterboxTTS.from_pretrained(DEVICE)
+@spaces.GPU  # <<< ADD THIS DECORATOR
+def generate(model_obj, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
+    # It's better to load the model once, perhaps when the gr.State is initialized
+    # or globally, rather than checking `model_obj is None` on every call.
+    # For ZeroGPU, the decorated function handles the GPU context.
+    # Let's assume model_obj is passed correctly and is already on DEVICE
+    # or will be moved to DEVICE by ChatterboxTTS internally.
+
+    if model_obj is None:
+        print("Model is None, attempting to load...")
+        # This load should ideally happen on DEVICE and be efficient.
+        # If ChatterboxTTS.from_pretrained(DEVICE) is slow,
+        # this will happen inside the GPU-allocated time.
+        model_obj = ChatterboxTTS.from_pretrained(DEVICE)
+        print(f"Model loaded on device: {model_obj.device if hasattr(model_obj, 'device') else 'unknown'}")
 
     if seed_num != 0:
         set_seed(int(seed_num))
 
-    wav = model.generate(
+    print(f"Generating audio for text: '{text}' on device: {DEVICE}")
+    wav = model_obj.generate(
         text,
         audio_prompt_path=audio_prompt_path,
         exaggeration=exaggeration,
@@ -33,13 +52,29 @@ def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, se
         temperature=temperature,
         cfg_weight=cfgw,
     )
-    return (model, (model.sr, wav.squeeze(0).numpy()))
+    print("Audio generation complete.")
+    # The model state is passed back out, which is correct for gr.State
+    return (model_obj, (model_obj.sr, wav.squeeze(0).numpy()))
 
 
 with gr.Blocks() as demo:
-    model_state = gr.State(None)
+    # To ensure model loads on app start and uses DEVICE correctly:
+    # Pre-load the model here if you want it loaded once globally for the Space instance.
+    # However, with gr.State(None) and loading in `generate`,
+    # the first user hitting "Generate" will trigger the load.
+    # This is fine if `ChatterboxTTS.from_pretrained(DEVICE)` correctly uses the GPU
+    # within the @spaces.GPU decorated `generate` function.
+
+    # For better clarity on model loading with ZeroGPU:
+    # Consider a dedicated function for loading the model that's called to initialize gr.State,
+    # or ensure the first call to `generate` handles it robustly within the GPU context.
+    # The current approach of loading if model_state is None within `generate` is okay
+    # as long as `generate` itself is decorated.
+
+    model_state = gr.State(None)
 
     with gr.Row():
+        # ... (rest of your UI code is fine) ...
         with gr.Column():
             text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
             ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/wav7604828.wav")
@@ -72,8 +107,9 @@ with gr.Blocks() as demo:
         outputs=[model_state, audio_output],
     )
 
+# The share=True in launch() will give a UserWarning on Spaces, it's not needed.
+# Hugging Face Spaces provides the public link automatically.
 demo.queue(
     max_size=50,
-    default_concurrency_limit=1,
-).launch(share=True
-)
+    default_concurrency_limit=1,  # Good for single model instance on GPU
+).launch()  # Removed share=True
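
Note: the core ZeroGPU pattern this commit applies is small. A minimal sketch, assuming the `spaces` package that Hugging Face preinstalls on ZeroGPU Spaces; the `duration` argument is optional and caps the GPU allocation per call, and `gpu_work` is a made-up name for illustration:

import spaces
import torch

@spaces.GPU(duration=120)  # optional: request up to 120 s of GPU time per call
def gpu_work(x):
    # The Space only holds a GPU while a @spaces.GPU-decorated function runs,
    # so CUDA work belongs inside the decorated function.
    return torch.tensor([x], device="cuda").sum().item()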
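
Note: the inline comments above recommend loading the model once rather than lazily through gr.State. A sketch of that alternative, assuming ChatterboxTTS behaves as in the app code above (an illustrative simplification, not the code this commit ships):

import spaces
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load once at startup so every request reuses the same instance.
MODEL = ChatterboxTTS.from_pretrained(DEVICE)

@spaces.GPU
def generate(text, audio_prompt_path, exaggeration, temperature, cfgw):
    wav = MODEL.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        exaggeration=exaggeration,
        temperature=temperature,
        cfg_weight=cfgw,
    )
    # gr.Audio accepts a (sample_rate, ndarray) tuple directly.
    return (MODEL.sr, wav.squeeze(0).numpy())

With a module-level model there is no model_state to thread through the click handler's inputs and outputs, so the Gradio wiring gets simpler as well.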