Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import json
|
|
|
2 |
import requests
|
3 |
import streamlit as st
|
4 |
import pdfplumber
|
5 |
import pandas as pd
|
6 |
import sqlalchemy
|
7 |
-
from typing import Any, Dict, List
|
8 |
|
9 |
# Provider clients – ensure these libraries are installed
|
10 |
try:
|
@@ -17,10 +18,10 @@ try:
|
|
17 |
except ImportError:
|
18 |
groq = None
|
19 |
|
20 |
-
# Hugging Face inference endpoint
|
21 |
-
HF_API_URL = "https://api-inference.huggingface.co/models/"
|
22 |
-
DEFAULT_TEMPERATURE = 0.1
|
23 |
-
GROQ_MODEL = "mixtral-8x7b-32768"
|
24 |
|
25 |
|
26 |
class QADataGenerator:
|
@@ -32,8 +33,8 @@ class QADataGenerator:
|
|
32 |
self._setup_providers()
|
33 |
self._setup_input_handlers()
|
34 |
self._initialize_session_state()
|
35 |
-
# Updated prompt template with escaped curly braces
|
36 |
-
self.custom_prompt_template = (
|
37 |
"You are an expert in extracting question and answer pairs from documents. "
|
38 |
"Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
|
39 |
"Each dictionary must have keys 'question' and 'answer'. "
|
@@ -69,7 +70,7 @@ class QADataGenerator:
|
|
69 |
|
70 |
def _setup_input_handlers(self) -> None:
|
71 |
"""Register handlers for different input data types."""
|
72 |
-
self.input_handlers: Dict[str, Any] = {
|
73 |
"text": self.handle_text,
|
74 |
"pdf": self.handle_pdf,
|
75 |
"csv": self.handle_csv,
|
@@ -79,7 +80,7 @@ class QADataGenerator:
|
|
79 |
|
80 |
def _initialize_session_state(self) -> None:
|
81 |
"""Initialize Streamlit session state with default configuration."""
|
82 |
-
defaults = {
|
83 |
"config": {
|
84 |
"provider": "OpenAI",
|
85 |
"model": "gpt-4-turbo",
|
@@ -101,30 +102,31 @@ class QADataGenerator:
|
|
101 |
|
102 |
# ----- Input Handlers -----
|
103 |
def handle_text(self, text: str) -> Dict[str, Any]:
|
|
|
104 |
return {"data": text, "source": "text"}
|
105 |
|
106 |
def handle_pdf(self, file) -> Dict[str, Any]:
|
|
|
107 |
try:
|
108 |
with pdfplumber.open(file) as pdf:
|
109 |
-
full_text = ""
|
110 |
-
|
111 |
-
page_text = page.extract_text() or ""
|
112 |
-
full_text += page_text + "\n"
|
113 |
-
return {"data": full_text, "source": "pdf"}
|
114 |
except Exception as e:
|
115 |
self.log_error(f"PDF Processing Error: {e}")
|
116 |
return {"data": "", "source": "pdf"}
|
117 |
|
118 |
def handle_csv(self, file) -> Dict[str, Any]:
|
|
|
119 |
try:
|
120 |
df = pd.read_csv(file)
|
121 |
-
|
122 |
-
return {"data":
|
123 |
except Exception as e:
|
124 |
self.log_error(f"CSV Processing Error: {e}")
|
125 |
return {"data": "", "source": "csv"}
|
126 |
|
127 |
def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
|
|
|
128 |
try:
|
129 |
response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
|
130 |
response.raise_for_status()
|
@@ -134,6 +136,7 @@ class QADataGenerator:
|
|
134 |
return {"data": "", "source": "api"}
|
135 |
|
136 |
def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
|
|
|
137 |
try:
|
138 |
engine = sqlalchemy.create_engine(config["connection"])
|
139 |
with engine.connect() as conn:
|
@@ -166,22 +169,22 @@ class QADataGenerator:
|
|
166 |
"""
|
167 |
Generate Q&A pairs by sending the built prompt to the selected LLM provider.
|
168 |
"""
|
169 |
-
api_key = st.session_state.api_key
|
170 |
if not api_key:
|
171 |
self.log_error("API key is missing!")
|
172 |
return False
|
173 |
|
174 |
-
provider_name = st.session_state.config["provider"]
|
175 |
-
provider_cfg = self.providers.get(provider_name)
|
176 |
if not provider_cfg:
|
177 |
self.log_error(f"Provider {provider_name} is not configured.")
|
178 |
return False
|
179 |
|
180 |
-
client_initializer = provider_cfg["client"]
|
181 |
client = client_initializer(api_key)
|
182 |
-
model = st.session_state.config["model"]
|
183 |
-
temperature = st.session_state.config["temperature"]
|
184 |
-
prompt = self.build_prompt()
|
185 |
|
186 |
st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
|
187 |
try:
|
@@ -238,36 +241,40 @@ class QADataGenerator:
|
|
238 |
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
|
239 |
"""
|
240 |
Parse the LLM response and return a list of Q&A pairs.
|
241 |
-
Expects the response to be JSON formatted
|
|
|
242 |
"""
|
243 |
st.write("Parsing response for provider:", provider)
|
244 |
try:
|
|
|
245 |
if provider == "HuggingFace":
|
246 |
-
# For HuggingFace, assume the generated text is under "generated_text"
|
247 |
if isinstance(response, list) and response and "generated_text" in response[0]:
|
248 |
raw_text = response[0]["generated_text"]
|
249 |
else:
|
250 |
self.log_error("Unexpected HuggingFace response format.")
|
251 |
return []
|
252 |
else:
|
253 |
-
# For OpenAI (and similar providers) assume the response is similar to:
|
254 |
-
# response.choices[0].message.content
|
255 |
if response and hasattr(response, "choices") and response.choices:
|
256 |
raw_text = response.choices[0].message.content
|
257 |
else:
|
258 |
self.log_error("Unexpected response format from provider.")
|
259 |
return []
|
260 |
|
261 |
-
#
|
262 |
try:
|
263 |
qa_list = json.loads(raw_text)
|
264 |
-
if isinstance(qa_list, list):
|
265 |
-
return qa_list
|
266 |
-
else:
|
267 |
-
self.log_error("Parsed output is not a list.")
|
268 |
-
return []
|
269 |
except json.JSONDecodeError as e:
|
270 |
-
self.log_error(f"JSON Parsing Error: {e}. Raw output: {raw_text}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
return []
|
272 |
except Exception as e:
|
273 |
self.log_error(f"Response Parsing Error: {e}")
|
@@ -276,7 +283,7 @@ class QADataGenerator:
|
|
276 |
|
277 |
# ============ UI Components ============
|
278 |
|
279 |
-
def config_ui(generator: QADataGenerator):
|
280 |
"""Display configuration options in the sidebar."""
|
281 |
with st.sidebar:
|
282 |
st.header("Configuration")
|
@@ -293,7 +300,7 @@ def config_ui(generator: QADataGenerator):
|
|
293 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
294 |
st.session_state.api_key = api_key
|
295 |
|
296 |
-
def input_ui(generator: QADataGenerator):
|
297 |
"""Display input data source options using tabs."""
|
298 |
st.subheader("Input Data Sources")
|
299 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
@@ -339,7 +346,7 @@ def input_ui(generator: QADataGenerator):
|
|
339 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
340 |
st.success("Database input added!")
|
341 |
|
342 |
-
def output_ui(generator: QADataGenerator):
|
343 |
"""Display the generated Q&A pairs and provide a download option."""
|
344 |
st.subheader("Q&A Pairs Output")
|
345 |
if st.session_state.qa_pairs:
|
@@ -354,7 +361,7 @@ def output_ui(generator: QADataGenerator):
|
|
354 |
else:
|
355 |
st.info("No Q&A pairs generated yet.")
|
356 |
|
357 |
-
def logs_ui():
|
358 |
"""Display error logs and debugging information in an expandable section."""
|
359 |
with st.expander("Error Logs & Debug Info", expanded=False):
|
360 |
if st.session_state.error_logs:
|
@@ -363,7 +370,8 @@ def logs_ui():
|
|
363 |
else:
|
364 |
st.write("No logs yet.")
|
365 |
|
366 |
-
def main():
|
|
|
367 |
st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
|
368 |
st.title("Advanced Q&A Synthetic Generator")
|
369 |
st.markdown(
|
|
|
1 |
import json
|
2 |
+
import ast
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
import pdfplumber
|
6 |
import pandas as pd
|
7 |
import sqlalchemy
|
8 |
+
from typing import Any, Dict, List, Callable
|
9 |
|
10 |
# Provider clients – ensure these libraries are installed
|
11 |
try:
|
|
|
18 |
except ImportError:
|
19 |
groq = None
|
20 |
|
21 |
+
# Hugging Face inference endpoint and defaults
|
22 |
+
HF_API_URL: str = "https://api-inference.huggingface.co/models/"
|
23 |
+
DEFAULT_TEMPERATURE: float = 0.1
|
24 |
+
GROQ_MODEL: str = "mixtral-8x7b-32768"
|
25 |
|
26 |
|
27 |
class QADataGenerator:
|
|
|
33 |
self._setup_providers()
|
34 |
self._setup_input_handlers()
|
35 |
self._initialize_session_state()
|
36 |
+
# Updated prompt template with escaped curly braces for literal output
|
37 |
+
self.custom_prompt_template: str = (
|
38 |
"You are an expert in extracting question and answer pairs from documents. "
|
39 |
"Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
|
40 |
"Each dictionary must have keys 'question' and 'answer'. "
|
|
|
70 |
|
71 |
def _setup_input_handlers(self) -> None:
|
72 |
"""Register handlers for different input data types."""
|
73 |
+
self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
|
74 |
"text": self.handle_text,
|
75 |
"pdf": self.handle_pdf,
|
76 |
"csv": self.handle_csv,
|
|
|
80 |
|
81 |
def _initialize_session_state(self) -> None:
|
82 |
"""Initialize Streamlit session state with default configuration."""
|
83 |
+
defaults: Dict[str, Any] = {
|
84 |
"config": {
|
85 |
"provider": "OpenAI",
|
86 |
"model": "gpt-4-turbo",
|
|
|
102 |
|
103 |
# ----- Input Handlers -----
|
104 |
def handle_text(self, text: str) -> Dict[str, Any]:
|
105 |
+
"""Process plain text input."""
|
106 |
return {"data": text, "source": "text"}
|
107 |
|
108 |
def handle_pdf(self, file) -> Dict[str, Any]:
|
109 |
+
"""Extract text from a PDF file."""
|
110 |
try:
|
111 |
with pdfplumber.open(file) as pdf:
|
112 |
+
full_text = "\n".join(page.extract_text() or "" for page in pdf.pages)
|
113 |
+
return {"data": full_text, "source": "pdf"}
|
|
|
|
|
|
|
114 |
except Exception as e:
|
115 |
self.log_error(f"PDF Processing Error: {e}")
|
116 |
return {"data": "", "source": "pdf"}
|
117 |
|
118 |
def handle_csv(self, file) -> Dict[str, Any]:
|
119 |
+
"""Process a CSV file by converting it to JSON."""
|
120 |
try:
|
121 |
df = pd.read_csv(file)
|
122 |
+
json_data = df.to_json(orient="records")
|
123 |
+
return {"data": json_data, "source": "csv"}
|
124 |
except Exception as e:
|
125 |
self.log_error(f"CSV Processing Error: {e}")
|
126 |
return {"data": "", "source": "csv"}
|
127 |
|
128 |
def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
|
129 |
+
"""Fetch data from an API endpoint."""
|
130 |
try:
|
131 |
response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
|
132 |
response.raise_for_status()
|
|
|
136 |
return {"data": "", "source": "api"}
|
137 |
|
138 |
def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
|
139 |
+
"""Query a database using the provided connection string and SQL query."""
|
140 |
try:
|
141 |
engine = sqlalchemy.create_engine(config["connection"])
|
142 |
with engine.connect() as conn:
|
|
|
169 |
"""
|
170 |
Generate Q&A pairs by sending the built prompt to the selected LLM provider.
|
171 |
"""
|
172 |
+
api_key: str = st.session_state.api_key
|
173 |
if not api_key:
|
174 |
self.log_error("API key is missing!")
|
175 |
return False
|
176 |
|
177 |
+
provider_name: str = st.session_state.config["provider"]
|
178 |
+
provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
|
179 |
if not provider_cfg:
|
180 |
self.log_error(f"Provider {provider_name} is not configured.")
|
181 |
return False
|
182 |
|
183 |
+
client_initializer: Callable[[str], Any] = provider_cfg["client"]
|
184 |
client = client_initializer(api_key)
|
185 |
+
model: str = st.session_state.config["model"]
|
186 |
+
temperature: float = st.session_state.config["temperature"]
|
187 |
+
prompt: str = self.build_prompt()
|
188 |
|
189 |
st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
|
190 |
try:
|
|
|
241 |
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
|
242 |
"""
|
243 |
Parse the LLM response and return a list of Q&A pairs.
|
244 |
+
Expects the response to be JSON formatted; if JSON decoding fails,
|
245 |
+
tries to use ast.literal_eval as a fallback.
|
246 |
"""
|
247 |
st.write("Parsing response for provider:", provider)
|
248 |
try:
|
249 |
+
# For non-HuggingFace providers, extract the raw text from the response.
|
250 |
if provider == "HuggingFace":
|
|
|
251 |
if isinstance(response, list) and response and "generated_text" in response[0]:
|
252 |
raw_text = response[0]["generated_text"]
|
253 |
else:
|
254 |
self.log_error("Unexpected HuggingFace response format.")
|
255 |
return []
|
256 |
else:
|
|
|
|
|
257 |
if response and hasattr(response, "choices") and response.choices:
|
258 |
raw_text = response.choices[0].message.content
|
259 |
else:
|
260 |
self.log_error("Unexpected response format from provider.")
|
261 |
return []
|
262 |
|
263 |
+
# Attempt to parse using json.loads first.
|
264 |
try:
|
265 |
qa_list = json.loads(raw_text)
|
|
|
|
|
|
|
|
|
|
|
266 |
except json.JSONDecodeError as e:
|
267 |
+
self.log_error(f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}")
|
268 |
+
try:
|
269 |
+
qa_list = ast.literal_eval(raw_text)
|
270 |
+
except Exception as e2:
|
271 |
+
self.log_error(f"ast.literal_eval failed: {e2}")
|
272 |
+
return []
|
273 |
+
|
274 |
+
if isinstance(qa_list, list):
|
275 |
+
return qa_list
|
276 |
+
else:
|
277 |
+
self.log_error("Parsed output is not a list.")
|
278 |
return []
|
279 |
except Exception as e:
|
280 |
self.log_error(f"Response Parsing Error: {e}")
|
|
|
283 |
|
284 |
# ============ UI Components ============
|
285 |
|
286 |
+
def config_ui(generator: QADataGenerator) -> None:
|
287 |
"""Display configuration options in the sidebar."""
|
288 |
with st.sidebar:
|
289 |
st.header("Configuration")
|
|
|
300 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
301 |
st.session_state.api_key = api_key
|
302 |
|
303 |
+
def input_ui(generator: QADataGenerator) -> None:
|
304 |
"""Display input data source options using tabs."""
|
305 |
st.subheader("Input Data Sources")
|
306 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
|
|
346 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
347 |
st.success("Database input added!")
|
348 |
|
349 |
+
def output_ui(generator: QADataGenerator) -> None:
|
350 |
"""Display the generated Q&A pairs and provide a download option."""
|
351 |
st.subheader("Q&A Pairs Output")
|
352 |
if st.session_state.qa_pairs:
|
|
|
361 |
else:
|
362 |
st.info("No Q&A pairs generated yet.")
|
363 |
|
364 |
+
def logs_ui() -> None:
|
365 |
"""Display error logs and debugging information in an expandable section."""
|
366 |
with st.expander("Error Logs & Debug Info", expanded=False):
|
367 |
if st.session_state.error_logs:
|
|
|
370 |
else:
|
371 |
st.write("No logs yet.")
|
372 |
|
373 |
+
def main() -> None:
|
374 |
+
"""Main Streamlit application entry point."""
|
375 |
st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
|
376 |
st.title("Advanced Q&A Synthetic Generator")
|
377 |
st.markdown(
|