
Bradarr committed
Commit 6989477 · verified · 1 Parent(s): 2e50dda

Update app.py

Files changed (1)
  1. app.py +47 -39
app.py CHANGED
@@ -4,11 +4,11 @@ import numpy as np
 import spaces
 import torch
 import torchaudio
-from generator import Segment, load_csm_1b
-from huggingface_hub import hf_hub_download, login
+from generator import Segment, load_csm_1b  # We'll use load_csm_1b *later*
+from huggingface_hub import hf_hub_download, login, HfApi
 from watermarking import watermark
-import whisper
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import whisper  # We'll use whisper.load_model *later*
+from transformers import AutoTokenizer, AutoModelForCausalLM  # We'll use these *later*
 import logging
 from transformers import GenerationConfig
 
@@ -47,14 +47,14 @@ MAX_GEMMA_LENGTH = 128
 # --- Global Conversation History ---
 conversation_history = []
 
-# --- Model Downloading (PRE-DOWNLOAD) ---
+# --- Model Downloading (PRE-DOWNLOAD, NO LOADING) ---
 
-# Download Sesame CSM 1B
-csm_1b_model_path = "csm_1b_ckpt.pt"  # Local path
+# 1. Download Sesame CSM 1B
+csm_1b_model_path = "csm_1b_ckpt.pt"  # Local path for the downloaded model
 try:
     if not os.path.exists(csm_1b_model_path):
         hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt", local_dir=".", local_dir_use_symlinks=False)
-        os.rename("ckpt.pt", csm_1b_model_path)  # Rename to avoid confusion
+        os.rename("ckpt.pt", csm_1b_model_path)
         logging.info("Sesame CSM 1B model downloaded.")
     else:
         logging.info("Sesame CSM 1B model already downloaded.")
@@ -62,25 +62,40 @@ except Exception as e:
     logging.error(f"Error downloading Sesame CSM 1B: {e}")
     raise
 
-# Download Whisper (using the built-in download mechanism)
+# 2. Download Whisper (pre-download into a local directory)
 whisper_model_name = "small.en"
+whisper_local_dir = "whisper_model"  # Local directory for Whisper
 try:
-    whisper.load_model(whisper_model_name)  # This downloads if not already present
-    logging.info(f"Whisper model '{whisper_model_name}' downloaded/loaded.")
+    if not os.path.exists(whisper_local_dir):
+        os.makedirs(whisper_local_dir, exist_ok=True)  # Create if it does not exist
+        # Whisper uses its own download mechanism; this call pre-downloads everything needed
+        whisper.load_model(whisper_model_name, download_root=whisper_local_dir)
+    else:
+        logging.info("Whisper model already downloaded.")
 except Exception as e:
-    logging.error(f"Error downloading Whisper model: {e}")
-    raise
+    logging.error(f"Whisper model download failed with exception: {e}")
 
-# Download Gemma 3 1B (Tokenizer and Model)
+# 3. Download Gemma 3 1B (using hf_hub_download, individual files)
 gemma_repo_id = "google/gemma-3-1b-it"
-gemma_local_path = "gemma_model"  # Using a directory
+gemma_local_path = os.path.abspath("gemma_model")  # Absolute path
 try:
     if not os.path.exists(gemma_local_path):
-        tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)  # downloads
-        model_gemma = AutoModelForCausalLM.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)  # downloads
-        logging.info("Gemma 3 1B model and tokenizer downloaded.")
+        os.makedirs(gemma_local_path, exist_ok=True)  # Create the directory
+        api = HfApi()
+        # List all files in the repository
+        repo_files = api.list_repo_files(gemma_repo_id)
+
+        # Download each file individually
+        for file in repo_files:
+            hf_hub_download(
+                repo_id=gemma_repo_id,
+                filename=file,
+                local_dir=gemma_local_path,
+                local_dir_use_symlinks=False,  # Ensure files are copied, not linked
+            )
+        logging.info("Gemma 3 1B model and tokenizer files downloaded.")
     else:
-        logging.info("Gemma 3 1B model and tokenizer already downloaded.")
+        logging.info("Gemma 3 1B model and tokenizer files already downloaded.")
 except Exception as e:
     logging.error(f"Error downloading Gemma 3 1B: {e}")
     raise
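Two observations on the download hunk above. First, the new Whisper except block logs but no longer re-raises, so a failed download will only surface later, when infer tries to load the model. Second, the per-file loop over list_repo_files reimplements what huggingface_hub already ships as snapshot_download; an equivalent one-call sketch (assuming the Space is authenticated for the gated Gemma repo, which the login import suggests):

import os
from huggingface_hub import snapshot_download

gemma_local_path = os.path.abspath("gemma_model")
# Fetches every file in the repo, mirroring the list_repo_files + hf_hub_download loop
snapshot_download(repo_id="google/gemma-3-1b-it", local_dir=gemma_local_path)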
@@ -88,7 +103,7 @@ except Exception as e:
 
 # --- Helper Functions ---
 
-def transcribe_audio(audio_path: str, whisper_model) -> str:  # Pass whisper_model
+def transcribe_audio(audio_path: str, whisper_model) -> str:
     try:
         audio = whisper.load_audio(audio_path)
         audio = whisper.pad_or_trim(audio)
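The rest of transcribe_audio is unchanged and therefore hidden between this hunk and the next. For context, the usual low-level openai-whisper flow after pad_or_trim looks like this (a sketch of the standard API, not the commit's exact lines):

import whisper

def transcribe(whisper_model, audio_path: str) -> str:
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    # Compute the mel spectrogram on the same device as the model, then decode
    mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
    result = whisper.decode(whisper_model, mel, whisper.DecodingOptions())
    return result.text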
@@ -98,23 +113,19 @@ def transcribe_audio(audio_path: str, whisper_model) -> str:
         logging.error(f"Whisper transcription error: {e}")
         return "Error: Could not transcribe audio."
 
-def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:  # Pass model and tokenizer
+def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:
     try:
-        # Gemma 3 chat template format
         messages = [{"role": "user", "content": text}]
         input = tokenizer_gemma.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
         generation_config = GenerationConfig(
             max_new_tokens=MAX_GEMMA_LENGTH,
             early_stopping=True,
         )
-
         generated_output = model_gemma.generate(input, generation_config=generation_config)
         decoded_output = tokenizer_gemma.decode(generated_output[0], skip_special_tokens=False)
 
-        # Extract the assistant's response (Gemma specific)
         start_token = "<start_of_turn>model"
         end_token = "<end_of_turn>"
-
         start_index = decoded_output.find(start_token)
         if start_index != -1:
             start_index += len(start_token)
@@ -122,11 +133,12 @@ def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:
             assistant_response = decoded_output[start_index:].strip()
             return assistant_response
         return decoded_output
+
     except Exception as e:
         logging.error(f"Gemma response generation error: {e}")
         return "I'm sorry, I encountered an error generating a response."
 
-def load_audio(audio_path: str, generator) -> torch.Tensor:  # Pass generator
+def load_audio(audio_path: str, generator) -> torch.Tensor:
     try:
         audio_tensor, sample_rate = torchaudio.load(audio_path)
         audio_tensor = audio_tensor.mean(dim=0)
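The response extraction in generate_response keys off Gemma's chat-template turn markers. A self-contained sketch of that parsing on a made-up decoded string (the end_token trimming happens on a line that falls between the hunks shown, so the exact handling here is an assumption):

start_token = "<start_of_turn>model"
end_token = "<end_of_turn>"

# Hypothetical decoded output, for illustration only
decoded_output = "<start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\nHello! How can I help?<end_of_turn>"

start_index = decoded_output.find(start_token)
if start_index != -1:
    start_index += len(start_token)
    end_index = decoded_output.find(end_token, start_index)  # trim the closing marker
    print(decoded_output[start_index:end_index].strip())  # Hello! How can I help?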
@@ -145,35 +157,31 @@ def clear_history():
 
 # --- Main Inference Function ---
 
-@spaces.GPU(duration=gpu_timeout)  # Decorator FIRST
+@spaces.GPU(duration=gpu_timeout)  # GPU decorator
 def infer(user_audio) -> tuple[int, np.ndarray]:
-    # --- CUDA Availability Check (INSIDE infer) ---
     if torch.cuda.is_available():
         device = "cuda"
         logging.info(f"CUDA is available! Using device: {torch.cuda.get_device_name(0)}")
     else:
         device = "cpu"
-        logging.info("CUDA is NOT available. Using CPU.")
-
+        logging.info("CUDA is NOT available. Using CPU.")
 
     try:
-        # --- Model Loading (INSIDE infer, after device is set) ---
-        # Load Sesame CSM 1B (from local file)
+        # --- Model Loading (ONLY inside infer, after GPU is available) ---
         generator = load_csm_1b(csm_1b_model_path, device)
         logging.info("Sesame CSM 1B loaded successfully.")
 
-        # Load Whisper (from local cache or downloaded)
-        whisper_model = whisper.load_model(whisper_model_name, device=device)
+        whisper_model = whisper.load_model(whisper_model_name, device=device, download_root=whisper_local_dir)
         logging.info(f"Whisper model '{whisper_model_name}' loaded successfully.")
 
-        # Load Gemma (from local cache)
         tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_local_path)
         model_gemma = AutoModelForCausalLM.from_pretrained(gemma_local_path).to(device)
         logging.info("Gemma 3 1B pt model loaded successfully.")
 
         if not user_audio:
             raise ValueError("No audio input received.")
-        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)  # Pass all models
+        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)
+
     except Exception as e:
         logging.exception(f"Inference error: {e}")
         raise gr.Error(f"An error occurred during processing: {e}")
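The loading moves inside infer because of how ZeroGPU Spaces work: CUDA is attached only while a @spaces.GPU-decorated function runs, so GPU work at import time fails. The shape of the pattern, as a minimal sketch (MyModel is a hypothetical stand-in, not from this app):

import spaces
import torch

@spaces.GPU(duration=60)  # the GPU exists only for the duration of this call
def infer(x):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MyModel().to(device)  # instantiate/move to the GPU inside the decorated function
    return model(x)

The trade-off is latency: every request re-loads three models from disk. Caching the CPU-side objects in module globals and only moving them to the device inside infer would likely cut per-request startup time.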
@@ -182,10 +190,10 @@ def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device):
     global conversation_history
 
     try:
-        user_text = transcribe_audio(user_audio, whisper_model)  # Pass whisper_model
+        user_text = transcribe_audio(user_audio, whisper_model)
         logging.info(f"User: {user_text}")
 
-        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)  # Pass model and tokenizer
+        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)
         logging.info(f"AI: {ai_text}")
 
         try:
@@ -201,7 +209,7 @@ def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device):
             raise gr.Error(f"Sesame response generation error: {e}")
 
 
-        user_segment = Segment(speaker=1, text=user_text, audio=load_audio(user_audio, generator))  # Pass Generator
+        user_segment = Segment(speaker=1, text=user_text, audio=load_audio(user_audio, generator))
         ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
         conversation_history.append(user_segment)
         conversation_history.append(ai_segment)
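One more note on the history handling at the end: conversation_history is a module-level list that grows by two Segments per call and is only reset by clear_history, so long sessions keep feeding an ever-longer context into generation. A hypothetical cap (MAX_HISTORY_SEGMENTS and append_turns are not in the commit):

MAX_HISTORY_SEGMENTS = 10  # hypothetical cap on retained context

def append_turns(user_segment, ai_segment):
    # Keep only the most recent segments so context cannot grow without bound
    conversation_history.extend([user_segment, ai_segment])
    del conversation_history[:-MAX_HISTORY_SEGMENTS]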
 