# app.py - Sesame CSM 1B conversational demo (Hugging Face Space Bradarr/csm-1b, running on ZeroGPU)
import os
import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
from generator import Segment, load_csm_1b
from huggingface_hub import hf_hub_download, login
from watermarking import watermark
import whisper
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import logging
# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Authentication and Configuration --- (Moved BEFORE model loading)
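# Required environment variables (set in the Space's environment/secrets):
#   HF_TOKEN      - Hugging Face access token passed to login() for Hub downloads
#   WATERMARK_KEY - space-separated integers used as the audio watermark key
#   GPU_TIMEOUT   - optional ZeroGPU allocation duration in seconds (defaults to 120)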
try:
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN not found in environment variables.")
    login(token=api_key)

    watermark_key = os.getenv("WATERMARK_KEY")
    if not watermark_key:
        raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
    CSM_1B_HF_WATERMARK = list(map(int, watermark_key.split(" ")))

    gpu_timeout = int(os.getenv("GPU_TIMEOUT", 120))
except (ValueError, TypeError) as e:
    logging.error(f"Configuration error: {e}")
    raise
SPACE_INTRO_TEXT = """
# Sesame CSM 1B - Conversational Demo
This demo allows you to have a conversation with Sesame CSM 1B, leveraging Whisper for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
*Disclaimer: This demo relies on several large models. Expect longer processing times and potential resource limitations.*
"""
# --- Model Loading --- (Moved INSIDE infer function)
# --- Constants --- (Constants can stay outside)
SPEAKER_ID = 0
MAX_CONTEXT_SEGMENTS = 3
MAX_GEMMA_LENGTH = 128
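# SPEAKER_ID is the CSM voice used for the AI's replies (the user is tagged as speaker 1 below).
# MAX_CONTEXT_SEGMENTS caps how many past segments are kept as context for the generator.
# MAX_GEMMA_LENGTH caps the number of new tokens Gemma may generate per reply.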
# --- Global Conversation History ---
conversation_history = []
# --- Helper Functions ---
def transcribe_audio(audio_path: str, whisper_model) -> str:  # Pass whisper_model explicitly
    try:
        audio = whisper.load_audio(audio_path)
        audio = whisper.pad_or_trim(audio)
        result = whisper_model.transcribe(audio)
        return result["text"]
    except Exception as e:
        logging.error(f"Whisper transcription error: {e}")
        return "Error: Could not transcribe audio."
def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:  # Pass model and tokenizer explicitly
    try:
        # Gemma 3 chat template format
        messages = [{"role": "user", "content": text}]
        input_ids = tokenizer_gemma.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
        generation_config = GenerationConfig(
            max_new_tokens=MAX_GEMMA_LENGTH,
            early_stopping=True,
        )
        generated_output = model_gemma.generate(input_ids, generation_config=generation_config)
        decoded_output = tokenizer_gemma.decode(generated_output[0], skip_special_tokens=False)

        # Extract the assistant's turn (Gemma-specific turn markers)
        start_token = "<start_of_turn>model"
        end_token = "<end_of_turn>"
        start_index = decoded_output.find(start_token)
        if start_index != -1:
            start_index += len(start_token)
            end_index = decoded_output.find(end_token, start_index)
            if end_index == -1:
                end_index = len(decoded_output)
            return decoded_output[start_index:end_index].strip()
        return decoded_output
        # Previous plain-prompt approach, kept for reference:
        # input_text = "Respond to the user's prompt: " + text
        # inputs = tokenizer_gemma(input_text, return_tensors="pt").to(device)
        # generated_output = model_gemma.generate(**inputs, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
        # return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
    except Exception as e:
        logging.error(f"Gemma response generation error: {e}")
        return "I'm sorry, I encountered an error generating a response."
def load_audio(audio_path: str, generator) -> torch.Tensor:  # Pass generator for its sample rate
    try:
        audio_tensor, sample_rate = torchaudio.load(audio_path)
        audio_tensor = audio_tensor.mean(dim=0)  # Downmix to mono
        if sample_rate != generator.sample_rate:
            audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate)
        return audio_tensor
    except Exception as e:
        logging.error(f"Audio loading error: {e}")
        raise gr.Error("Could not load or process the audio file.") from e
def clear_history():
    global conversation_history
    conversation_history = []
    logging.info("Conversation history cleared.")
    return "Conversation history cleared."
# --- Main Inference Function ---
@spaces.GPU(duration=gpu_timeout)  # Decorator must come first
def infer(user_audio) -> tuple[int, np.ndarray]:
    # --- CUDA availability check (inside infer so it runs on the GPU worker) ---
    if torch.cuda.is_available():
        print(f"CUDA is available! Device count: {torch.cuda.device_count()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
        print(f"CUDA version: {torch.version.cuda}")
        device = "cuda"
    else:
        print("CUDA is NOT available. Using CPU.")  # Fall back to CPU instead of raising
        device = "cpu"

    try:
        if not user_audio:
            raise ValueError("No audio input received.")

        # --- Model loading (inside infer, after the device is known) ---
        model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
        generator = load_csm_1b(model_path, device)
        logging.info("Sesame CSM 1B loaded successfully.")

        whisper_model = whisper.load_model("small.en", device=device)
        logging.info("Whisper model loaded successfully.")

        tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
        model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it").to(device)
        logging.info("Gemma 3 1B instruction-tuned model loaded successfully.")

        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)  # Pass all models
    except Exception as e:
        logging.exception(f"Inference error: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")
def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device) -> tuple[int, np.ndarray]:
    global conversation_history
    try:
        user_text = transcribe_audio(user_audio, whisper_model)
        logging.info(f"User: {user_text}")

        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)
        logging.info(f"AI: {ai_text}")

        try:
            ai_audio = generator.generate(
                text=ai_text,
                speaker=SPEAKER_ID,
                context=conversation_history,
                max_audio_length_ms=10_000,
            )
            logging.info("Audio generated successfully.")
        except Exception as e:
            logging.error(f"Sesame response generation error: {e}")
            raise gr.Error(f"Sesame response generation error: {e}")

        # Update the conversation history with both sides of the exchange
        user_segment = Segment(speaker=1, text=user_text, audio=load_audio(user_audio, generator))
        ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
        conversation_history.append(user_segment)
        conversation_history.append(ai_segment)
        while len(conversation_history) > MAX_CONTEXT_SEGMENTS:
            conversation_history.pop(0)

        # Watermark the generated audio, then resample back to the generator's sample rate
        audio_tensor, wm_sample_rate = watermark(
            generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
        )
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
        )

        # Convert to 16-bit PCM for the Gradio audio component
        ai_audio_array = (audio_tensor * 32768).to(torch.int16).cpu().numpy()
        return generator.sample_rate, ai_audio_array
    except Exception as e:
        logging.exception(f"Error in _infer: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")
# --- Gradio Interface ---
with gr.Blocks() as app:
    gr.Markdown(SPACE_INTRO_TEXT)
    audio_input = gr.Audio(label="Your Input", type="filepath")
    audio_output = gr.Audio(label="AI Response")
    clear_button = gr.Button("Clear Conversation History")
    status_display = gr.Textbox(label="Status", visible=False)
    btn = gr.Button("Generate Response")

    btn.click(infer, inputs=[audio_input], outputs=[audio_output])
    clear_button.click(clear_history, outputs=[status_display])

app.launch(ssr_mode=False)