import logging
import os

import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
import whisperx
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer

from generator import Segment, load_csm_1b
from watermarking import watermark

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
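
# Required configuration, read from environment variables:
#   HF_TOKEN      - Hugging Face token used to authenticate with the Hub
#   WATERMARK_KEY - space-separated integers forming the CSM watermark key
#   GPU_TIMEOUT   - optional GPU allocation duration in seconds passed to @spaces.GPU (default 180)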
try:
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN not found in environment variables.")
    login(token=api_key)

    # Validate the key before parsing so a missing variable raises a clear error
    # instead of an uncaught AttributeError on None.split().
    watermark_key = os.getenv("WATERMARK_KEY")
    if not watermark_key:
        raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
    CSM_1B_HF_WATERMARK = list(map(int, watermark_key.split(" ")))

    gpu_timeout = int(os.getenv("GPU_TIMEOUT", 180))
except (ValueError, TypeError) as e:
    logging.error(f"Configuration error: {e}")
    raise

SPACE_INTRO_TEXT = """\
# Sesame CSM 1B - Conversational Demo

This demo allows you to have a conversation with Sesame CSM 1B, leveraging WhisperX for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.

*Disclaimer: This demo relies on several large models. Expect longer processing times and potential resource limitations.*
"""
SPEAKER_ID = 0
MAX_CONTEXT_SEGMENTS = 5  # maximum number of Segment objects kept as CSM context
MAX_GEMMA_LENGTH = 300
device = "cuda"

conversation_history = []

# Model handles are created lazily inside the GPU-decorated infer() call below.
global_generator = None
global_whisper_model = None
global_model_a = None  # WhisperX alignment model; actually loaded per detected language in transcribe_audio()
global_tokenizer_gemma = None
global_model_gemma = None


def transcribe_audio(audio_path: str, whisper_model, model_a) -> str:
    """Transcribes audio using WhisperX and aligns it."""
    try:
        audio = whisperx.load_audio(audio_path)
        result = whisper_model.transcribe(audio, batch_size=16)
        language = result["language"]

        # The alignment model is loaded for the detected language, replacing the model_a argument.
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
        result_aligned = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

        # Join all aligned segments so longer utterances are not truncated to the first segment.
        return " ".join(segment["text"].strip() for segment in result_aligned["segments"])
    except Exception as e:
        logging.error(f"WhisperX transcription error: {e}")
        return "Error: Could not transcribe audio."


def generate_response(text: str, tokenizer_gemma, model_gemma) -> str:
    """Generates a response using Gemma."""
    try:
        input_text = "Here is a response for the user. " + text
        inputs = tokenizer_gemma(input_text, return_tensors="pt").to(device)
        generated_output = model_gemma.generate(**inputs, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
        # Decode only the newly generated tokens so the prompt is not echoed back to the user.
        new_tokens = generated_output[0][inputs["input_ids"].shape[-1]:]
        return tokenizer_gemma.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        logging.error(f"Gemma response generation error: {e}")
        return "I'm sorry, I encountered an error generating a response."


def load_audio(audio_path: str) -> torch.Tensor:
    """Loads audio from file and returns a torch tensor."""
    try:
        audio_tensor, sample_rate = torchaudio.load(audio_path)
        audio_tensor = audio_tensor.mean(dim=0)  # downmix to mono
        if sample_rate != global_generator.sample_rate:
            audio_tensor = torchaudio.functional.resample(
                audio_tensor, orig_freq=sample_rate, new_freq=global_generator.sample_rate
            )
        return audio_tensor
    except Exception as e:
        logging.error(f"Audio loading error: {e}")
        raise gr.Error("Could not load or process the audio file.") from e


def clear_history():
    """Clears the conversation history."""
    global conversation_history
    conversation_history = []
    logging.info("Conversation history cleared.")
    return "Conversation history cleared."
@spaces.GPU(duration=gpu_timeout)
def infer(user_audio) -> tuple:
    """Infers a response from the user audio."""
    global global_generator, global_whisper_model, global_model_a, global_tokenizer_gemma, global_model_gemma

    try:
        if not user_audio:
            raise ValueError("No audio input received.")

        if global_generator is None:
            model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
            global_generator = load_csm_1b(model_path, device)
            logging.info("Sesame CSM 1B loaded successfully on GPU.")

        if global_whisper_model is None:
            global_whisper_model = whisperx.load_model("large-v2", device)
            logging.info("WhisperX model loaded successfully on GPU.")

        if global_tokenizer_gemma is None:
            global_tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
            global_model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt").to(device)
            logging.info("Gemma 3 1B pt model loaded successfully on GPU.")

        return _infer(user_audio, global_generator, global_whisper_model, global_model_a, global_tokenizer_gemma, global_model_gemma)
    except Exception as e:
        logging.exception(f"Inference error: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")
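
# Core pipeline: transcribe the user audio (WhisperX), generate a text reply (Gemma),
# synthesize speech with CSM 1B conditioned on recent conversation segments, then watermark the audio.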
def _infer(user_audio, generator, whisper_model, model_a, tokenizer_gemma, model_gemma) -> tuple:
    """Processes the user input, generates a response, and returns audio."""
    global conversation_history

    try:
        user_text = transcribe_audio(user_audio, whisper_model, model_a)
        logging.info(f"User: {user_text}")

        ai_text = generate_response(user_text, tokenizer_gemma, model_gemma)
        logging.info(f"AI: {ai_text}")

        ai_audio = generator.generate(
            text=ai_text,
            speaker=SPEAKER_ID,
            context=conversation_history,
            max_audio_length_ms=30_000,
        )
        logging.info("Audio generated successfully.")

        # Store the actual transcripts with their audio so the context segments condition CSM on real text.
        user_segment = Segment(speaker=SPEAKER_ID, text=user_text, audio=load_audio(user_audio))
        ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
        conversation_history.append(user_segment)
        conversation_history.append(ai_segment)

        # Two segments are appended per turn, so trim in a loop to keep the context bounded.
        while len(conversation_history) > MAX_CONTEXT_SEGMENTS:
            conversation_history.pop(0)

        audio_tensor, wm_sample_rate = watermark(
            generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
        )
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
        )

        # Clamp before converting to int16 to avoid overflow on full-scale samples.
        ai_audio_array = (audio_tensor.clamp(-1.0, 1.0) * 32767).to(torch.int16).cpu().numpy()
        return generator.sample_rate, ai_audio_array

    except Exception as e:
        logging.exception(f"Error in _infer: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")
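
# Gradio UI: record or upload audio, generate a spoken response, and optionally clear the conversation history.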
with gr.Blocks() as app:
    gr.Markdown(SPACE_INTRO_TEXT)
    audio_input = gr.Audio(label="Your Input", type="filepath")
    audio_output = gr.Audio(label="AI Response")
    clear_button = gr.Button("Clear Conversation History")
    status_display = gr.Textbox(label="Status", visible=False)

    btn = gr.Button("Generate Response")
    btn.click(infer, inputs=[audio_input], outputs=[audio_output])
    clear_button.click(clear_history, outputs=[status_display])

app.launch(share=False)