
Bradarr committed (verified)
Commit 2e50dda · Parent(s): cdae601

Update app.py

Files changed (1): app.py (+52 -18)
app.py CHANGED
@@ -15,7 +15,7 @@ from transformers import GenerationConfig
 # Configure logging
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# --- Authentication and Configuration --- (Moved BEFORE model loading)
+# --- Authentication and Configuration ---
 try:
     api_key = os.getenv("HF_TOKEN")
     if not api_key:
@@ -39,9 +39,7 @@ This demo allows you to have a conversation with Sesame CSM 1B, leveraging Whisp
 *Disclaimer: This demo relies on several large models. Expect longer processing times, and potential resource limitations.*
 """
 
-# --- Model Loading --- (Moved INSIDE infer function)
-
-# --- Constants --- (Constants can stay outside)
+# --- Constants ---
 SPEAKER_ID = 0
 MAX_CONTEXT_SEGMENTS = 3
 MAX_GEMMA_LENGTH = 128
@@ -49,6 +47,45 @@ MAX_GEMMA_LENGTH = 128
 # --- Global Conversation History ---
 conversation_history = []
 
+# --- Model Downloading (PRE-DOWNLOAD) ---
+
+# Download Sesame CSM 1B
+csm_1b_model_path = "csm_1b_ckpt.pt"  # Local path
+try:
+    if not os.path.exists(csm_1b_model_path):
+        hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt", local_dir=".", local_dir_use_symlinks=False)
+        os.rename("ckpt.pt", csm_1b_model_path)  # Rename to avoid confusion
+        logging.info("Sesame CSM 1B model downloaded.")
+    else:
+        logging.info("Sesame CSM 1B model already downloaded.")
+except Exception as e:
+    logging.error(f"Error downloading Sesame CSM 1B: {e}")
+    raise
+
+# Download Whisper (using the built-in download mechanism)
+whisper_model_name = "small.en"
+try:
+    whisper.load_model(whisper_model_name)  # This downloads if not already present
+    logging.info(f"Whisper model '{whisper_model_name}' downloaded/loaded.")
+except Exception as e:
+    logging.error(f"Error downloading Whisper model: {e}")
+    raise
+
+# Download Gemma 3 1B (Tokenizer and Model)
+gemma_repo_id = "google/gemma-3-1b-it"
+gemma_local_path = "gemma_model"  # Using a directory
+try:
+    if not os.path.exists(gemma_local_path):
+        tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)  # downloads
+        model_gemma = AutoModelForCausalLM.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)  # downloads
+        logging.info("Gemma 3 1B model and tokenizer downloaded.")
+    else:
+        logging.info("Gemma 3 1B model and tokenizer already downloaded.")
+except Exception as e:
+    logging.error(f"Error downloading Gemma 3 1B: {e}")
+    raise
+
+
 # --- Helper Functions ---
 
 def transcribe_audio(audio_path: str, whisper_model) -> str:  # Pass whisper_model
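A side note on the CSM download step added above: `hf_hub_download` returns the resolved local path of the fetched file, and `local_dir_use_symlinks` is deprecated in recent huggingface_hub releases. A minimal sketch of an equivalent pre-download without the rename (assuming a recent huggingface_hub; the `csm_1b_model_path` name is this commit's own convention):

    from huggingface_hub import hf_hub_download

    # hf_hub_download caches the checkpoint and returns its local path,
    # so no local_dir juggling or os.rename step is needed.
    csm_1b_model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")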
@@ -85,10 +122,6 @@ def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: #
             assistant_response = decoded_output[start_index:].strip()
             return assistant_response
         return decoded_output
-        #input_text = "Reapond to the users prompt: " + text
-        #input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
-        #generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
-        #return tokenizer_gemma.decode(generated_output[0], skip_special_tokens=True)
     except Exception as e:
         logging.error(f"Gemma response generation error: {e}")
         return "I'm sorry, I encountered an error generating a response."
@@ -116,25 +149,26 @@ def clear_history():
 def infer(user_audio) -> tuple[int, np.ndarray]:
     # --- CUDA Availability Check (INSIDE infer) ---
     if torch.cuda.is_available():
-        print(f"CUDA is available! Device count: {torch.cuda.device_count()}")
-        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
-        print(f"CUDA version: {torch.version.cuda}")
         device = "cuda"
+        logging.info(f"CUDA is available! Using device: {torch.cuda.get_device_name(0)}")
     else:
-        print("CUDA is NOT available. Using CPU.")  # Use CPU, don't raise
         device = "cpu"
+        logging.info("CUDA is NOT available. Using CPU.")
+
 
     try:
         # --- Model Loading (INSIDE infer, after device is set) ---
-        model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
-        generator = load_csm_1b(model_path, device)
+        # Load Sesame CSM 1B (from local file)
+        generator = load_csm_1b(csm_1b_model_path, device)
         logging.info("Sesame CSM 1B loaded successfully.")
 
-        whisper_model = whisper.load_model("small.en", device=device)
-        logging.info("Whisper model loaded successfully.")
+        # Load Whisper (from local cache or downloaded)
+        whisper_model = whisper.load_model(whisper_model_name, device=device)
+        logging.info(f"Whisper model '{whisper_model_name}' loaded successfully.")
 
-        tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
-        model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it").to(device)
+        # Load Gemma (from local cache)
+        tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_local_path)
+        model_gemma = AutoModelForCausalLM.from_pretrained(gemma_local_path).to(device)
         logging.info("Gemma 3 1B pt model loaded successfully.")
 
         if not user_audio:
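One caveat in the new Gemma loading path: `cache_dir` stores files in the Hub cache layout (`models--google--gemma-3-1b-it/snapshots/…`), not as a flat model directory, so `from_pretrained(gemma_local_path)` inside `infer` may fail to find a `config.json`. A layout-safe sketch is to load by repo id with the same `cache_dir`, which resolves to the pre-downloaded snapshot (names taken from this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    gemma_repo_id = "google/gemma-3-1b-it"
    gemma_local_path = "gemma_model"  # cache directory from the pre-download step

    # Repo id plus the same cache_dir hits the cached snapshot, so after the
    # pre-download this is a local read rather than a fresh download.
    tokenizer_gemma = AutoTokenizer.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)
    model_gemma = AutoModelForCausalLM.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)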
 
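More broadly, `infer` still reloads all three models on every request, which will dominate per-call latency even with the files pre-downloaded. A hedged sketch of a module-level, load-once cache (the `get_models` helper and `_models` dict are illustrative, not part of this commit; `load_csm_1b` is assumed to come from the Space's own generator module, as in upstream sesame/csm, and the other names reuse this commit's pre-download globals):

    import whisper
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from generator import load_csm_1b  # assumed Space-local module; import not shown in this diff

    _models = {}  # illustrative process-wide cache

    def get_models(device: str) -> dict:
        # Load the three models on first use, then reuse them across infer() calls.
        if not _models:
            _models["generator"] = load_csm_1b(csm_1b_model_path, device)
            _models["whisper"] = whisper.load_model(whisper_model_name, device=device)
            _models["tokenizer"] = AutoTokenizer.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path)
            _models["gemma"] = AutoModelForCausalLM.from_pretrained(gemma_repo_id, cache_dir=gemma_local_path).to(device)
        return _models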