feministmystique committed on
Commit
5c39a0d
·
verified ·
1 Parent(s): bdc1106

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +28 -52
src/streamlit_app.py CHANGED
@@ -1,19 +1,14 @@
1
  import os, pathlib
2
 
3
 
4
- home = "/tmp"
5
- os.environ["HOME"] = home
6
- cfg_dir = pathlib.Path(home) / ".streamlit"
7
- cfg_dir.mkdir(parents=True, exist_ok=True)
8
 
 
 
 
9
 
10
  (cfg_dir / "config.toml").write_text(
11
- "[server]\n"
12
- "headless = true\n"
13
- "port = 7860\n"
14
- "address = \"0.0.0.0\"\n\n"
15
- "[browser]\n"
16
- "gatherUsageStats = false\n",
17
  encoding="utf-8",
18
  )
19
  import streamlit as st
@@ -62,6 +57,7 @@ class ConfigManager:
62
  self.hf_token = os.getenv("HF_TOKEN")
63
  self.google_creds_json = os.getenv("GOOGLE_SHEETS_CREDENTIALS")
64
  self.google_sheets_id = os.getenv("GOOGLE_SHEETS_ID")
 
65
 
66
 
67
  missing_vars = []
@@ -153,7 +149,7 @@ class AIAssistant:
153
 
154
  self.client = InferenceClient(
155
  model=self.model,
156
- token=None,
157
  timeout=60.0,
158
  )
159
 
@@ -186,59 +182,39 @@ Please follow these guidelines:
186
 
187
  def generate_response(self, button_name: str, question: str, retry_count: int = 0) -> str:
188
  try:
189
- # Build prompts
190
  system_text = self.base_prompt + self.prompt_templates.get(button_name, "")
191
  if retry_count > 0:
192
  system_text += f"\nPlease provide a different explanation. This is attempt {retry_count + 1}."
193
  user_text = f"Question:\n{question}"
194
 
195
- full_prompt = f"{system_text}\n\n{user_text}" # still used for text_generation
196
-
197
  try:
198
- # Try classic text-generation first
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  text = self.client.text_generation(
200
  prompt=full_prompt,
201
  max_new_tokens=300,
202
  temperature=0.7,
203
  repetition_penalty=1.1,
204
- model=self.model, # explicit even though client is bound
205
- )
206
- except (HfHubHTTPError, ValueError) as e:
207
- # If the provider/model doesn't support text-generation, fall back to chat
208
- msg = str(e)
209
- unsupported = (
210
- "Task 'text-generation' not supported" in msg
211
- or "doesn't support task 'text-generation'" in msg
212
- or "Available tasks: ['conversational']" in msg
213
  )
214
- if unsupported:
215
- # OpenAI-style chat interface
216
- # Build messages: system + user
217
- messages = [
218
- {"role": "system", "content": system_text},
219
- {"role": "user", "content": user_text},
220
- ]
221
- chat = self.client.chat_completion(
222
- messages=messages,
223
- max_tokens=350,
224
- temperature=0.7,
225
- model=self.model,
226
- )
227
- # Robust extraction
228
- text = ""
229
- try:
230
- # chat.choices[0].message.content (OpenAI-like)
231
- choices = getattr(chat, "choices", None) or chat.get("choices", [])
232
- if choices:
233
- msg0 = choices[0].get("message") or {}
234
- text = msg0.get("content") or ""
235
- if not text:
236
- # Some providers return 'generated_text'
237
- text = getattr(chat, "generated_text", None) or chat.get("generated_text", "") or ""
238
- except Exception:
239
- text = str(chat)
240
- else:
241
- raise
242
  except (httpx.ReadTimeout, httpx.ConnectTimeout):
243
  return "The model request timed out. Please try again."
244
 
 
1
  import os, pathlib
2
 
3
 
 
 
 
 
4
 
5
+ os.environ["HOME"] = "/tmp"
6
+ cfg_dir = pathlib.Path("/tmp/.streamlit")
7
+ cfg_dir.mkdir(parents=True, exist_ok=True)
8
 
9
  (cfg_dir / "config.toml").write_text(
10
+ "[server]\nheadless = true\nport = 7860\naddress = \"0.0.0.0\"\n\n"
11
+ "[browser]\ngatherUsageStats = false\n",
 
 
 
 
12
  encoding="utf-8",
13
  )
14
  import streamlit as st
 
57
  self.hf_token = os.getenv("HF_TOKEN")
58
  self.google_creds_json = os.getenv("GOOGLE_SHEETS_CREDENTIALS")
59
  self.google_sheets_id = os.getenv("GOOGLE_SHEETS_ID")
60
+
61
 
62
 
63
  missing_vars = []
 
149
 
150
  self.client = InferenceClient(
151
  model=self.model,
152
+ token=token,
153
  timeout=60.0,
154
  )
155
 
 
182
 
183
  def generate_response(self, button_name: str, question: str, retry_count: int = 0) -> str:
184
  try:
 
185
  system_text = self.base_prompt + self.prompt_templates.get(button_name, "")
186
  if retry_count > 0:
187
  system_text += f"\nPlease provide a different explanation. This is attempt {retry_count + 1}."
188
  user_text = f"Question:\n{question}"
189
 
 
 
190
  try:
191
+ messages = [
192
+ {"role": "system", "content": system_text},
193
+ {"role": "user", "content": user_text},
194
+ ]
195
+ chat = self.client.chat_completion(
196
+ messages=messages,
197
+ max_tokens=350,
198
+ temperature=0.7,
199
+ )
200
+ text = ""
201
+ try:
202
+ choices = getattr(chat, "choices", None) or chat.get("choices", [])
203
+ if choices:
204
+ msg0 = choices[0].get("message") or {}
205
+ text = msg0.get("content") or ""
206
+ if not text:
207
+ text = getattr(chat, "generated_text", None) or chat.get("generated_text", "") or ""
208
+ except Exception:
209
+ text = str(chat)
210
+ except (HfHubHTTPError, ValueError, AttributeError):
211
+ full_prompt = f"{system_text}\n\n{user_text}"
212
  text = self.client.text_generation(
213
  prompt=full_prompt,
214
  max_new_tokens=300,
215
  temperature=0.7,
216
  repetition_penalty=1.1,
 
 
 
 
 
 
 
 
 
217
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  except (httpx.ReadTimeout, httpx.ConnectTimeout):
219
  return "The model request timed out. Please try again."
220