Update src/streamlit_app.py
Browse files — src/streamlit_app.py (+28 −52)
src/streamlit_app.py
CHANGED
@@ -1,19 +1,14 @@
|
|
1 |
import os, pathlib
|
2 |
|
3 |
|
4 |
-
home = "/tmp"
|
5 |
-
os.environ["HOME"] = home
|
6 |
-
cfg_dir = pathlib.Path(home) / ".streamlit"
|
7 |
-
cfg_dir.mkdir(parents=True, exist_ok=True)
|
8 |
|
|
|
|
|
|
|
9 |
|
10 |
(cfg_dir / "config.toml").write_text(
|
11 |
-
"[server]\n"
|
12 |
-
"
|
13 |
-
"port = 7860\n"
|
14 |
-
"address = \"0.0.0.0\"\n\n"
|
15 |
-
"[browser]\n"
|
16 |
-
"gatherUsageStats = false\n",
|
17 |
encoding="utf-8",
|
18 |
)
|
19 |
import streamlit as st
|
@@ -62,6 +57,7 @@ class ConfigManager:
|
|
62 |
self.hf_token = os.getenv("HF_TOKEN")
|
63 |
self.google_creds_json = os.getenv("GOOGLE_SHEETS_CREDENTIALS")
|
64 |
self.google_sheets_id = os.getenv("GOOGLE_SHEETS_ID")
|
|
|
65 |
|
66 |
|
67 |
missing_vars = []
|
@@ -153,7 +149,7 @@ class AIAssistant:
|
|
153 |
|
154 |
self.client = InferenceClient(
|
155 |
model=self.model,
|
156 |
-
token=
|
157 |
timeout=60.0,
|
158 |
)
|
159 |
|
@@ -186,59 +182,39 @@ Please follow these guidelines:
|
|
186 |
|
187 |
def generate_response(self, button_name: str, question: str, retry_count: int = 0) -> str:
|
188 |
try:
|
189 |
-
# Build prompts
|
190 |
system_text = self.base_prompt + self.prompt_templates.get(button_name, "")
|
191 |
if retry_count > 0:
|
192 |
system_text += f"\nPlease provide a different explanation. This is attempt {retry_count + 1}."
|
193 |
user_text = f"Question:\n{question}"
|
194 |
|
195 |
-
full_prompt = f"{system_text}\n\n{user_text}" # still used for text_generation
|
196 |
-
|
197 |
try:
|
198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
text = self.client.text_generation(
|
200 |
prompt=full_prompt,
|
201 |
max_new_tokens=300,
|
202 |
temperature=0.7,
|
203 |
repetition_penalty=1.1,
|
204 |
-
model=self.model, # explicit even though client is bound
|
205 |
-
)
|
206 |
-
except (HfHubHTTPError, ValueError) as e:
|
207 |
-
# If the provider/model doesn't support text-generation, fall back to chat
|
208 |
-
msg = str(e)
|
209 |
-
unsupported = (
|
210 |
-
"Task 'text-generation' not supported" in msg
|
211 |
-
or "doesn't support task 'text-generation'" in msg
|
212 |
-
or "Available tasks: ['conversational']" in msg
|
213 |
)
|
214 |
-
if unsupported:
|
215 |
-
# OpenAI-style chat interface
|
216 |
-
# Build messages: system + user
|
217 |
-
messages = [
|
218 |
-
{"role": "system", "content": system_text},
|
219 |
-
{"role": "user", "content": user_text},
|
220 |
-
]
|
221 |
-
chat = self.client.chat_completion(
|
222 |
-
messages=messages,
|
223 |
-
max_tokens=350,
|
224 |
-
temperature=0.7,
|
225 |
-
model=self.model,
|
226 |
-
)
|
227 |
-
# Robust extraction
|
228 |
-
text = ""
|
229 |
-
try:
|
230 |
-
# chat.choices[0].message.content (OpenAI-like)
|
231 |
-
choices = getattr(chat, "choices", None) or chat.get("choices", [])
|
232 |
-
if choices:
|
233 |
-
msg0 = choices[0].get("message") or {}
|
234 |
-
text = msg0.get("content") or ""
|
235 |
-
if not text:
|
236 |
-
# Some providers return 'generated_text'
|
237 |
-
text = getattr(chat, "generated_text", None) or chat.get("generated_text", "") or ""
|
238 |
-
except Exception:
|
239 |
-
text = str(chat)
|
240 |
-
else:
|
241 |
-
raise
|
242 |
except (httpx.ReadTimeout, httpx.ConnectTimeout):
|
243 |
return "The model request timed out. Please try again."
|
244 |
|
|
|
1 |
# --- Streamlit bootstrap: must execute BEFORE `import streamlit` ---
# Streamlit resolves its config from ~/.streamlit at import time, and the
# deployment environment (presumably a container such as a HF Space — TODO
# confirm) appears to have a non-writable $HOME, so HOME is redirected to
# /tmp and a config.toml is pre-written there.
import os
import pathlib

# Single source of truth for the writable home directory (was previously
# duplicated as the magic string "/tmp" in two places).
home = pathlib.Path("/tmp")
os.environ["HOME"] = str(home)

cfg_dir = home / ".streamlit"
cfg_dir.mkdir(parents=True, exist_ok=True)

# headless + fixed port/address: required for a containerized deployment;
# gatherUsageStats = false disables telemetry calls at startup.
(cfg_dir / "config.toml").write_text(
    "[server]\nheadless = true\nport = 7860\naddress = \"0.0.0.0\"\n\n"
    "[browser]\ngatherUsageStats = false\n",
    encoding="utf-8",
)
|
14 |
import streamlit as st
|
|
|
57 |
self.hf_token = os.getenv("HF_TOKEN")
|
58 |
self.google_creds_json = os.getenv("GOOGLE_SHEETS_CREDENTIALS")
|
59 |
self.google_sheets_id = os.getenv("GOOGLE_SHEETS_ID")
|
60 |
+
|
61 |
|
62 |
|
63 |
missing_vars = []
|
|
|
149 |
|
150 |
self.client = InferenceClient(
|
151 |
model=self.model,
|
152 |
+
token=token,
|
153 |
timeout=60.0,
|
154 |
)
|
155 |
|
|
|
182 |
|
183 |
def generate_response(self, button_name: str, question: str, retry_count: int = 0) -> str:
|
184 |
try:
|
|
|
185 |
system_text = self.base_prompt + self.prompt_templates.get(button_name, "")
|
186 |
if retry_count > 0:
|
187 |
system_text += f"\nPlease provide a different explanation. This is attempt {retry_count + 1}."
|
188 |
user_text = f"Question:\n{question}"
|
189 |
|
|
|
|
|
190 |
try:
|
191 |
+
messages = [
|
192 |
+
{"role": "system", "content": system_text},
|
193 |
+
{"role": "user", "content": user_text},
|
194 |
+
]
|
195 |
+
chat = self.client.chat_completion(
|
196 |
+
messages=messages,
|
197 |
+
max_tokens=350,
|
198 |
+
temperature=0.7,
|
199 |
+
)
|
200 |
+
text = ""
|
201 |
+
try:
|
202 |
+
choices = getattr(chat, "choices", None) or chat.get("choices", [])
|
203 |
+
if choices:
|
204 |
+
msg0 = choices[0].get("message") or {}
|
205 |
+
text = msg0.get("content") or ""
|
206 |
+
if not text:
|
207 |
+
text = getattr(chat, "generated_text", None) or chat.get("generated_text", "") or ""
|
208 |
+
except Exception:
|
209 |
+
text = str(chat)
|
210 |
+
except (HfHubHTTPError, ValueError, AttributeError):
|
211 |
+
full_prompt = f"{system_text}\n\n{user_text}"
|
212 |
text = self.client.text_generation(
|
213 |
prompt=full_prompt,
|
214 |
max_new_tokens=300,
|
215 |
temperature=0.7,
|
216 |
repetition_penalty=1.1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
except (httpx.ReadTimeout, httpx.ConnectTimeout):
|
219 |
return "The model request timed out. Please try again."
|
220 |
|