Spaces:
Running
Running
File size: 4,621 Bytes
9d593b2 0b3e025 9d593b2 0b3e025 9d593b2 0b3e025 9d593b2 0b3e025 3dab9c0 9d593b2 0b3e025 6273840 0b3e025 4ce9740 0b3e025 9d593b2 3dab9c0 0b3e025 9d593b2 0b3e025 2e214c5 9d593b2 4ce9740 7f26581 58ffee2 2bd4215 58ffee2 9d593b2 6273840 9d593b2 0b3e025 9d593b2 0b3e025 9d593b2 58ffee2 9d593b2 0b3e025 9d593b2 fae012e 0b3e025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import random
import traceback

import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS # Assuming this path is correct
import gradio as gr
import spaces
# Choose the compute device once, at import time; the rest of the module
# (model loading and seeding) keys off this value.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Process-wide cache for the loaded ChatterboxTTS instance. It starts empty and
# is filled by get_or_load_model(), which the @spaces.GPU handler calls to get
# the shared model instead of passing it through Gradio state.
MODEL = None
def get_or_load_model():
    """Return the shared ChatterboxTTS model, loading it on the first call.

    The loaded model is cached in the module-level ``MODEL`` global, so later
    calls return the same object without reloading.

    Returns:
        The cached ``ChatterboxTTS`` instance.

    Raises:
        Exception: whatever ``ChatterboxTTS.from_pretrained`` or the device
            move raises. The error is printed and re-raised, and ``MODEL`` is
            left as ``None`` so the next call retries the load.
    """
    global MODEL
    if MODEL is None:
        print("Global MODEL is None, loading...")
        try:
            # Build into a local first and only publish to MODEL once setup has
            # fully succeeded, so a failure part-way through never leaves a
            # half-initialised model cached for every later call.
            model = ChatterboxTTS.from_pretrained(DEVICE)
            # Ensure model is on the correct device if not handled by from_pretrained.
            # NOTE(review): the return value of .to() is discarded, which assumes
            # it moves the object in place (as torch.nn.Module.to does) — confirm
            # for ChatterboxTTS.
            if DEVICE == "cuda" and hasattr(model, "to"):
                model.to(DEVICE)
        except Exception as e:
            print(f"Error loading global model: {e}")
            raise
        MODEL = model
        print(f"Global MODEL loaded. Device: {DEVICE}")
        if hasattr(MODEL, "device"):  # Only some model objects expose .device
            print(f"Model internal device attribute: {MODEL.device}")
    return MODEL
# Attempt to load the model eagerly at startup so the first request does not
# pay the load cost. A failure here is reported but NOT re-raised: the app
# still launches, and generate_tts_audio() calls get_or_load_model() again on
# each request, which retries the load (and raises to that caller if it fails).
try:
    get_or_load_model()
except Exception as e:
    # Best-effort startup: print the full traceback (not just the message) so
    # the root cause is visible in the Space logs, then keep going so the UI
    # can still come up.
    print(f"CRITICAL: Failed to load model on startup. Error: {e}")
    traceback.print_exc()
def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs for reproducible output.

    ``torch.manual_seed`` seeds the generators of all devices (CPU and every
    CUDA device), so no separate CUDA seeding call is needed.

    Args:
        seed: Any integer. It is reduced modulo 2**32 because
            ``np.random.seed`` only accepts values in ``[0, 2**32)``. Without
            the reduction, a negative or very large seed entered in the UI
            would raise ``ValueError``. Seeds already in that range are used
            unchanged.
    """
    seed = int(seed) % (2**32)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
@spaces.GPU  # Your GPU-accelerated function
def generate_tts_audio(text_input, audio_prompt_path_input, exaggeration_input, temperature_input, seed_num_input, cfgw_input):
    """Synthesize speech for ``text_input`` with the shared Chatterbox model.

    Args:
        text_input: Text to speak; only the first 300 characters are used.
        audio_prompt_path_input: Reference-audio file path from the
            ``gr.Audio(type="filepath")`` input, forwarded as
            ``audio_prompt_path``.
        exaggeration_input: Forwarded to ``generate(exaggeration=...)``.
        temperature_input: Forwarded to ``generate(temperature=...)``.
        seed_num_input: RNG seed. 0 or an empty field (``None``) leaves the
            RNG state untouched.
        cfgw_input: Forwarded to ``generate(cfg_weight=...)``.

    Returns:
        A ``(sample_rate, samples)`` tuple for ``gr.Audio``, where ``samples``
        is a NumPy array on the CPU.

    Raises:
        RuntimeError: if the model could not be obtained.
    """
    current_model = get_or_load_model()  # Access the global model
    if current_model is None:
        # Defensive only: get_or_load_model() either returns a model or raises.
        raise RuntimeError("Model could not be loaded or accessed.")

    # gr.Number yields None when the field is cleared; treat that like 0
    # ("random") instead of crashing in int(None).
    if seed_num_input:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input}'")
    wav = current_model.generate(
        text_input[:300],
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    print("Audio generation complete.")
    # ONLY return pickleable data (presumably because @spaces.GPU runs this in
    # a separate worker). detach().cpu() makes .numpy() safe even if the model
    # returns a GPU tensor or one that tracks gradients; for a plain CPU tensor
    # both calls are no-ops.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
# --- Gradio UI ---
# The model lives in the module-level MODEL global, so no gr.State is needed to
# carry it between events.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: text, reference voice and generation controls.
        with gr.Column():
            text = gr.Textbox(
                value=(
                    "Now let's make my mum's favourite. "
                    "So three mars bars into the pan. "
                    "Then we add the tuna and just stir for a bit, "
                    "just let the chocolate and fish infuse. "
                    "A sprinkle of olive oil and some tomato ketchup. "
                    "Now smell that. Oh boy this is going to be incredible."
                ),
                label="Text to synthesize (max chars 300)",
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart.flac",
            )
            exaggeration = gr.Slider(
                minimum=0.25,
                maximum=2,
                step=.05,
                label="Exaggeration (Neutral = 0.5, extreme values can be unstable)",
                value=.5,
            )
            cfg_weight = gr.Slider(
                minimum=0.2,
                maximum=1,
                step=.05,
                label="CFG/Pace",
                value=0.5,
            )

            # Less commonly tuned settings, collapsed by default.
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(
                    minimum=0.05,
                    maximum=5,
                    step=.05,
                    label="temperature",
                    value=.8,
                )

            run_btn = gr.Button("Generate", variant="primary")

        # Right column: the synthesized audio.
        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    # The input order must line up with generate_tts_audio's positional
    # parameters: text, reference audio, exaggeration, temperature, seed, CFG.
    tts_inputs = [text, ref_wav, exaggeration, temp, seed_num, cfg_weight]
    run_btn.click(
        fn=generate_tts_audio,
        inputs=tts_inputs,
        outputs=[audio_output],  # The single audio component receives (sr, samples)
    )
# Queue requests (up to 50 waiting) and run one at a time: every request
# shares the single global model, so default_concurrency_limit=1 is important.
# share=True is not needed and causes a warning on Spaces.
demo.queue(max_size=50, default_concurrency_limit=1)
demo.launch()