Update app.py
app.py CHANGED
@@ -4,11 +4,11 @@ import numpy as np
 import spaces
 import torch
 import torchaudio
-from generator import Segment, load_csm_1b
-from huggingface_hub import hf_hub_download, login
+from generator import Segment, load_csm_1b # We'll use load_csm_1b *later*
+from huggingface_hub import hf_hub_download, login, HfApi
 from watermarking import watermark
-import whisper
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import whisper # We'll use whisper.load_model *later*
+from transformers import AutoTokenizer, AutoModelForCausalLM # We'll use these *later*
 import logging
 from transformers import GenerationConfig
 
@@ -47,14 +47,14 @@ MAX_GEMMA_LENGTH = 128
 # --- Global Conversation History ---
 conversation_history = []
 
-# --- Model Downloading (PRE-DOWNLOAD) ---
+# --- Model Downloading (PRE-DOWNLOAD, NO LOADING) ---
 
-# Download Sesame CSM 1B
-csm_1b_model_path = "csm_1b_ckpt.pt" # Local path
+# 1. Download Sesame CSM 1B
+csm_1b_model_path = "csm_1b_ckpt.pt" # Local path for the downloaded model
 try:
     if not os.path.exists(csm_1b_model_path):
         hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt", local_dir=".", local_dir_use_symlinks=False)
-        os.rename("ckpt.pt", csm_1b_model_path)
+        os.rename("ckpt.pt", csm_1b_model_path)
         logging.info("Sesame CSM 1B model downloaded.")
     else:
         logging.info("Sesame CSM 1B model already downloaded.")
@@ -62,25 +62,40 @@ except Exception as e:
     logging.error(f"Error downloading Sesame CSM 1B: {e}")
     raise
 
-# Download Whisper (using
+# 2. Download Whisper (using hf_hub_download for consistency)
 whisper_model_name = "small.en"
+whisper_local_dir = "whisper_model" # Local directory for Whisper
 try:
-
-
+    if not os.path.exists(whisper_local_dir):
+        os.makedirs(whisper_local_dir, exist_ok=True) # Create if not exist
+        # Whisper uses a specific download method. This command should pre download everything needed
+        whisper.load_model(whisper_model_name, download_root=whisper_local_dir)
+    else:
+        logging.info("Whisper model already downloaded.")
 except Exception as e:
-
-    raise
+    logging.error(f"Whisper model download failed with exception: {e}")
 
-# Download Gemma 3 1B (
+# 3. Download Gemma 3 1B (using hf_hub_download, individual files)
 gemma_repo_id = "google/gemma-3-1b-it"
-gemma_local_path = "gemma_model"
+gemma_local_path = os.path.abspath("gemma_model") # Absolute path
 try:
     if not os.path.exists(gemma_local_path):
-
-
-
+        os.makedirs(gemma_local_path, exist_ok=True) # Create the directory
+        api = HfApi()
+        # List all files in the repository
+        repo_files = api.list_repo_files(gemma_repo_id)
+
+        # Download each file individually
+        for file in repo_files:
+            hf_hub_download(
+                repo_id=gemma_repo_id,
+                filename=file,
+                local_dir=gemma_local_path,
+                local_dir_use_symlinks=False, # Ensure files are copied, not linked
+            )
+        logging.info("Gemma 3 1B model and tokenizer files downloaded.")
     else:
-        logging.info("Gemma 3 1B model and tokenizer already downloaded.")
+        logging.info("Gemma 3 1B model and tokenizer files already downloaded.")
 except Exception as e:
     logging.error(f"Error downloading Gemma 3 1B: {e}")
     raise
@@ -88,7 +103,7 @@ except Exception as e:
 
 # --- Helper Functions ---
 
-def transcribe_audio(audio_path: str, whisper_model) -> str: # Pass whisper_mod
+def transcribe_audio(audio_path: str, whisper_model) -> str:
     try:
         audio = whisper.load_audio(audio_path)
         audio = whisper.pad_or_trim(audio)
@@ -98,23 +113,19 @@ def transcribe_audio(audio_path: str, whisper_model) -> str: # Pass whisper_mod
         logging.error(f"Whisper transcription error: {e}")
         return "Error: Could not transcribe audio."
 
-def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: #
+def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:
     try:
-        # Gemma 3 chat template format
         messages = [{"role": "user", "content": text}]
         input = tokenizer_gemma.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
         generation_config = GenerationConfig(
             max_new_tokens=MAX_GEMMA_LENGTH,
             early_stopping=True,
         )
-
         generated_output = model_gemma.generate(input, generation_config=generation_config)
         decoded_output = tokenizer_gemma.decode(generated_output[0], skip_special_tokens=False)
 
-        # Extract the assistant's response (Gemma specific)
         start_token = "<start_of_turn>model"
         end_token = "<end_of_turn>"
-
         start_index = decoded_output.find(start_token)
         if start_index != -1:
             start_index += len(start_token)
@@ -122,11 +133,12 @@ def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: #
             assistant_response = decoded_output[start_index:].strip()
             return assistant_response
         return decoded_output
+
     except Exception as e:
         logging.error(f"Gemma response generation error: {e}")
         return "I'm sorry, I encountered an error generating a response."
 
-def load_audio(audio_path: str, generator) -> torch.Tensor:
+def load_audio(audio_path: str, generator) -> torch.Tensor:
     try:
         audio_tensor, sample_rate = torchaudio.load(audio_path)
         audio_tensor = audio_tensor.mean(dim=0)
@@ -145,35 +157,31 @@ def clear_history():
 
 # --- Main Inference Function ---
 
-@spaces.GPU(duration=gpu_timeout) #
+@spaces.GPU(duration=gpu_timeout) # GPU decorator
 def infer(user_audio) -> tuple[int, np.ndarray]:
-    # --- CUDA Availability Check (INSIDE infer) ---
     if torch.cuda.is_available():
         device = "cuda"
         logging.info(f"CUDA is available! Using device: {torch.cuda.get_device_name(0)}")
     else:
         device = "cpu"
-        logging.info("CUDA is NOT available.
-
+        logging.info("CUDA is NOT available. Using CPU.")
 
     try:
-
-        # Load Sesame CSM 1B (from local file)
+        # --- Model Loading (ONLY inside infer, after GPU is available) ---
        generator = load_csm_1b(csm_1b_model_path, device)
         logging.info("Sesame CSM 1B loaded successfully.")
 
-
-        whisper_model = whisper.load_model(whisper_model_name, device=device)
+        whisper_model = whisper.load_model(whisper_model_name, device=device, download_root=whisper_local_dir)
         logging.info(f"Whisper model '{whisper_model_name}' loaded successfully.")
 
-        # Load Gemma (from local cache)
         tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_local_path)
         model_gemma = AutoModelForCausalLM.from_pretrained(gemma_local_path).to(device)
         logging.info("Gemma 3 1B pt model loaded successfully.")
 
         if not user_audio:
             raise ValueError("No audio input received.")
-        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)
+        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)
+
     except Exception as e:
         logging.exception(f"Inference error: {e}")
         raise gr.Error(f"An error occurred during processing: {e}")
@@ -182,10 +190,10 @@ def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, d
     global conversation_history
 
     try:
-        user_text = transcribe_audio(user_audio, whisper_model)
+        user_text = transcribe_audio(user_audio, whisper_model)
         logging.info(f"User: {user_text}")
 
-        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)
+        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)
         logging.info(f"AI: {ai_text}")
 
         try:
@@ -201,7 +209,7 @@ def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, d
             raise gr.Error(f"Sesame response generation error: {e}")
 
 
-        user_segment = Segment(speaker = 1, text = user_text, audio = load_audio(user_audio, generator))
+        user_segment = Segment(speaker = 1, text = user_text, audio = load_audio(user_audio, generator))
         ai_segment = Segment(speaker = SPEAKER_ID, text = ai_text, audio = ai_audio)
         conversation_history.append(user_segment)
         conversation_history.append(ai_segment)
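The shape of this change is the usual ZeroGPU split: model weights are fetched to local disk at import time, and everything that needs CUDA runs inside the `@spaces.GPU`-decorated `infer`. As written, all three models are reloaded on every request. A common alternative on ZeroGPU Spaces is to instantiate models on CPU at import time and move them to the GPU inside the decorated function; below is a minimal sketch of that pattern, assuming the `gemma_model` directory downloaded above (the `generate` helper is hypothetical, not part of this commit).

```python
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load once at import time, on CPU: under ZeroGPU, CUDA is only
# guaranteed to exist inside a @spaces.GPU-decorated function.
tokenizer = AutoTokenizer.from_pretrained("gemma_model")
model = AutoModelForCausalLM.from_pretrained("gemma_model")

@spaces.GPU(duration=60)
def generate(prompt: str) -> str:
    model.to("cuda")  # move weights onto the GPU granted for this call
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```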
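The per-file Gemma loop (`HfApi().list_repo_files` plus one `hf_hub_download` per file) reproduces what `huggingface_hub.snapshot_download` does in a single call, and that helper also spares the file-listing round trip; note too that recent `huggingface_hub` releases deprecate `local_dir_use_symlinks` and ignore it. A sketch of the same pre-download step using `snapshot_download`, with the same repo id and target directory as above:

```python
from huggingface_hub import snapshot_download

# Mirror the whole google/gemma-3-1b-it repo into gemma_model/ in one call;
# files that are already present and up to date should be skipped.
snapshot_download(repo_id="google/gemma-3-1b-it", local_dir="gemma_model")
```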
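`generate_response` recovers the reply by searching the decoded string for `<start_of_turn>model`, which is fragile if the chat template or special-token rendering ever changes. Slicing off the prompt tokens by position before decoding avoids the string parsing entirely; a sketch reusing the variable names from the function above:

```python
# Decode only the tokens generated after the prompt, instead of
# searching the decoded string for Gemma's turn markers.
prompt_len = input.shape[-1]  # number of prompt tokens fed to generate()
new_tokens = generated_output[0][prompt_len:]
assistant_response = tokenizer_gemma.decode(new_tokens, skip_special_tokens=True).strip()
```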