Last commit not found
import os | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
import torchaudio | |
from generator import Segment, load_csm_1b | |
from huggingface_hub import hf_hub_download, login | |
from watermarking import watermark | |
import whisperx | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Authentication and Configuration
# All required settings are validated before any network call (login), so a
# misconfigured Space fails fast with a clear, logged message.
try:
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN not found in environment variables.")

    watermark_key = os.getenv("WATERMARK_KEY")
    # os.getenv returns None when unset; calling .split() on it would raise an
    # AttributeError that escapes the (ValueError, TypeError) handler below.
    if not watermark_key or not watermark_key.strip():
        raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
    try:
        # split() with no argument tolerates repeated/leading/trailing whitespace.
        CSM_1B_HF_WATERMARK = [int(part) for part in watermark_key.split()]
    except ValueError as err:
        raise ValueError(
            "WATERMARK_KEY not found or invalid in environment variables."
        ) from err

    gpu_timeout = int(os.getenv("GPU_TIMEOUT", 180))

    login(token=api_key)
except (ValueError, TypeError) as e:
    logging.error(f"Configuration error: {e}")
    raise
# Markdown shown at the top of the Gradio page.
SPACE_INTRO_TEXT = """\
# Sesame CSM 1B - Conversational Demo
This demo allows you to have a conversation with Sesame CSM 1B, leveraging WhisperX for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.
*Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
"""

# --- Tunables ---
SPEAKER_ID = 0  # speaker index handed to the CSM generator / Segment objects
MAX_CONTEXT_SEGMENTS = 5  # intended cap on the length of conversation_history
MAX_GEMMA_LENGTH = 300  # length limit passed to Gemma's generate()

# Target device is pinned to CUDA. A CPU fallback
# ("cuda" if torch.cuda.is_available() else "cpu") is intentionally disabled.
device = "cuda"

# --- Mutable module-level state ---
# Conversation context (Segment objects) fed back into the CSM generator.
# It is a single module-level list, shared by every request this process serves.
conversation_history = []

# Model handles; infer() fills these in lazily on first use.
global_generator = None
global_whisper_model = None
global_model_a = None  # not assigned anywhere else; only forwarded to transcribe_audio
global_tokenizer_gemma = None
global_model_gemma = None
# --- HELPER FUNCTIONS ---
# WhisperX alignment models are per language and expensive to load; keep one
# (model, metadata) pair per detected language code so repeated requests in the
# same language do not reload it every call. The key space is bounded by the
# languages WhisperX can detect.
_ALIGN_MODEL_CACHE = {}


def transcribe_audio(audio_path: str, whisper_model, model_a) -> str:
    """Transcribe an audio file with WhisperX and return the full aligned text.

    Args:
        audio_path: Path to the audio file to transcribe.
        whisper_model: A loaded WhisperX ASR model (as returned by
            ``whisperx.load_model``).
        model_a: Unused. Kept for call compatibility; the alignment model is
            selected from the language detected in the transcription.

    Returns:
        The text of all aligned segments joined with single spaces, or the
        string ``"Error: Could not transcribe audio."`` if transcription fails
        or no speech is recognized.
    """
    try:
        audio = whisperx.load_audio(audio_path)
        result = whisper_model.transcribe(audio, batch_size=16)
        # Use the language WhisperX detected so the matching aligner is used.
        language = result["language"]

        if language not in _ALIGN_MODEL_CACHE:
            _ALIGN_MODEL_CACHE[language] = whisperx.load_align_model(
                language_code=language, device=device
            )
        align_model, metadata = _ALIGN_MODEL_CACHE[language]

        result_aligned = whisperx.align(
            result["segments"], align_model, metadata, audio, device,
            return_char_alignments=False,
        )
        # Join every segment: returning only segments[0] truncated longer utterances.
        text = " ".join(
            segment["text"].strip() for segment in result_aligned["segments"]
        ).strip()
        if not text:
            raise ValueError("no speech segments were recognized")
        return text
    except Exception as e:
        logging.error(f"WhisperX transcription error: {e}")
        return "Error: Could not transcribe audio."
def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
    """Generate a reply to the user's utterance with the Gemma causal LM.

    Args:
        text: The user's transcribed utterance.
        tokenizer_gemma: Hugging Face tokenizer matching ``model_gemma``.
        model_gemma: Hugging Face causal language model used for generation.

    Returns:
        Only the newly generated continuation; the prompt is not echoed back.
        If generation fails, returns a fixed apology string. If generation
        yields no text, returns a different fixed fallback string.
    """
    try:
        prompt = "Here is a response for the user. " + text
        inputs = tokenizer_gemma(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]
        # Inference only: skip autograd bookkeeping.
        # max_new_tokens bounds the reply itself; max_length would also count
        # the prompt and can leave no room for a reply on long inputs.
        # early_stopping only affects beam search, so it is not passed.
        with torch.no_grad():
            generated = model_gemma.generate(**inputs, max_new_tokens=MAX_GEMMA_LENGTH)
        # Decoder-only generate() returns prompt + continuation; decode just the
        # new tokens, otherwise the prompt and user text get spoken back.
        reply = tokenizer_gemma.decode(
            generated[0][prompt_length:], skip_special_tokens=True
        ).strip()
    except Exception as e:
        logging.error(f"Gemma response generation error: {e}")
        return "I'm sorry, I encountered an error generating a response."
    if not reply:
        logging.warning("Gemma produced an empty response.")
        return "I'm sorry, I don't have a response for that."
    return reply
def load_audio(audio_path: str) -> torch.Tensor:
    """Read an audio file as a mono waveform at the CSM generator's sample rate.

    Channels are averaged together. The result is resampled only when the
    file's rate differs from ``global_generator.sample_rate``, so the CSM
    generator must already be loaded.

    Raises:
        gr.Error: if the file cannot be read or processed (original exception chained).
    """
    try:
        waveform, source_rate = torchaudio.load(audio_path)
        mono = waveform.mean(dim=0)  # collapse the channel axis
        target_rate = global_generator.sample_rate
        if source_rate == target_rate:
            return mono
        return torchaudio.functional.resample(
            mono, orig_freq=source_rate, new_freq=target_rate
        )
    except Exception as exc:
        logging.error(f"Audio loading error: {exc}")
        raise gr.Error("Could not load or process the audio file.") from exc
def clear_history():
    """Reset the module-level conversation context and report it to the UI.

    The history is rebound to a fresh empty list rather than cleared in place.
    Because it is a single module-level list, this resets the context for every
    user served by this process.

    Returns:
        The status message that is also written to the log.
    """
    global conversation_history
    conversation_history = []
    status = "Conversation history cleared."
    logging.info(status)
    return status
# --- MAIN INFERENCE FUNCTION ---
def _ensure_models_loaded() -> None:
    """Lazily load CSM, WhisperX and Gemma into their module-level globals."""
    global global_generator, global_whisper_model, global_tokenizer_gemma, global_model_gemma
    if global_generator is None:
        model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
        global_generator = load_csm_1b(model_path, device)
        logging.info("Sesame CSM 1B loaded successfully on GPU.")
    if global_whisper_model is None:
        global_whisper_model = whisperx.load_model("large-v2", device)
        logging.info("WhisperX model loaded successfully on GPU.")
    # Check both halves: if the model load failed after the tokenizer was set,
    # a tokenizer-only guard would never retry and leave the model as None.
    if global_tokenizer_gemma is None or global_model_gemma is None:
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
        model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
        global_tokenizer_gemma, global_model_gemma = tokenizer, model
        logging.info("Gemma 3 1B pt model loaded successfully on GPU.")


# NOTE(review): `spaces` is imported and `gpu_timeout` is read at module level,
# but neither is used. If this Space runs on ZeroGPU hardware, infer probably
# needs `@spaces.GPU(duration=gpu_timeout)` (and eager model loading) — confirm.
def infer(user_audio) -> tuple:
    """Gradio entry point: produce a spoken AI reply to the user's recording.

    Args:
        user_audio: Filepath of the recorded/uploaded audio (``gr.Audio`` with
            ``type="filepath"``), or a falsy value when nothing was provided.

    Returns:
        ``(sample_rate, int16 numpy array)`` as produced by ``_infer``.

    Raises:
        gr.Error: if no audio was given, if model loading fails, or if any
            processing step fails.
    """
    if not user_audio:
        raise gr.Error("No audio input received.")
    try:
        _ensure_models_loaded()
        return _infer(user_audio, global_generator, global_whisper_model, global_model_a,
                      global_tokenizer_gemma, global_model_gemma)
    except gr.Error:
        # Already user-facing (e.g. raised by _infer); re-wrapping would double
        # the "An error occurred during processing:" prefix.
        raise
    except Exception as e:
        logging.exception(f"Inference error: {e}")
        raise gr.Error(f"An error occurred during processing: {e}") from e
def _infer(user_audio, generator, whisper_model, model_a, tokenizer_gemma, model_gemma) -> tuple:
    """Run one conversational turn: ASR -> LLM -> speech, then update the context.

    Args:
        user_audio: Filepath of the user's recording.
        generator: Loaded CSM generator (provides ``generate``, ``sample_rate``
            and ``_watermarker``).
        whisper_model: Loaded WhisperX ASR model.
        model_a: Forwarded to ``transcribe_audio`` unchanged.
        tokenizer_gemma: Gemma tokenizer.
        model_gemma: Gemma causal LM.

    Returns:
        ``(generator.sample_rate, int16 numpy array)`` for a ``gr.Audio`` output.

    Raises:
        gr.Error: on any failure. A ``gr.Error`` raised by a helper is
            propagated unchanged.
    """
    global conversation_history
    try:
        # Decode the user's clip first, so an unreadable file fails before
        # the expensive generation step.
        user_waveform = load_audio(user_audio)

        # 1. ASR: Transcribe user audio using WhisperX
        user_text = transcribe_audio(user_audio, whisper_model, model_a)
        logging.info(f"User: {user_text}")

        # 2. LLM: Generate a response using Gemma
        ai_text = generate_response(user_text, tokenizer_gemma, model_gemma)
        logging.info(f"AI: {ai_text}")

        # 3. Generate audio using the CSM model
        ai_audio = generator.generate(
            text=ai_text,
            speaker=SPEAKER_ID,
            context=conversation_history,
            max_audio_length_ms=30_000,
        )
        logging.info("Audio generated successfully.")

        # Store the real transcripts: each context segment's text should
        # describe its audio, not a placeholder label.
        # NOTE(review): user and AI segments share SPEAKER_ID, so the user's
        # voice is presented to CSM as the same speaker — confirm this is intended.
        conversation_history.append(Segment(speaker=SPEAKER_ID, text=user_text, audio=user_waveform))
        conversation_history.append(Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio))
        # Each turn adds two segments, so drop every oldest entry beyond the
        # cap (popping a single element let the history grow without bound).
        excess = len(conversation_history) - MAX_CONTEXT_SEGMENTS
        if excess > 0:
            del conversation_history[:excess]

        # 4. Watermarking and Audio Conversion
        audio_tensor, wm_sample_rate = watermark(
            generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
        )
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
        )
        # Assumes float samples nominally in [-1, 1]. Clamp, and scale by 32767,
        # so a full-scale sample cannot overflow int16 and wrap to -32768.
        ai_audio_array = (audio_tensor.clamp(-1.0, 1.0) * 32767).to(torch.int16).cpu().numpy()
        return generator.sample_rate, ai_audio_array
    except gr.Error:
        raise
    except Exception as e:
        logging.exception(f"Error in _infer: {e}")
        raise gr.Error(f"An error occurred during processing: {e}") from e
# --- GRADIO INTERFACE ---
with gr.Blocks() as app:
    gr.Markdown(SPACE_INTRO_TEXT)
    audio_input = gr.Audio(label="Your Input", type="filepath")
    audio_output = gr.Audio(label="AI Response")
    clear_button = gr.Button("Clear Conversation History")
    # Shown (read-only) so the confirmation returned by clear_history() actually
    # reaches the user; with visible=False it was silently discarded.
    status_display = gr.Textbox(label="Status", interactive=False)
    btn = gr.Button("Generate Response")

    # Event wiring: one spoken turn per click; the clear button resets the
    # shared conversation context and reports it in the status box.
    btn.click(infer, inputs=[audio_input], outputs=[audio_output])
    clear_button.click(clear_history, outputs=[status_display])

app.launch(share=False)