Update app.py
app.py
CHANGED
@@ -1,10 +1,11 @@
 import json
+import logging
 import requests
 import streamlit as st
 import pdfplumber
 import pandas as pd
 import sqlalchemy
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Callable
 
 # Provider clients – ensure these libraries are installed
 try:
@@ -17,10 +18,20 @@ try:
 except ImportError:
     groq = None
 
-# Hugging Face inference endpoint
-HF_API_URL = "https://api-inference.huggingface.co/models/"
-DEFAULT_TEMPERATURE = 0.1
-GROQ_MODEL = "mixtral-8x7b-32768"
+# Hugging Face inference endpoint and defaults
+HF_API_URL: str = "https://api-inference.huggingface.co/models/"
+DEFAULT_TEMPERATURE: float = 0.1
+GROQ_MODEL: str = "mixtral-8x7b-32768"
+
+# Optional: Configure a basic logger for debugging purposes
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+if not logger.handlers:
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
 
 
 class QADataGenerator:
@@ -28,12 +39,13 @@ class QADataGenerator:
     A Q&A Synthetic Generator that extracts and generates question-answer pairs
     from various input sources using an LLM provider.
     """
+
     def __init__(self) -> None:
         self._setup_providers()
         self._setup_input_handlers()
         self._initialize_session_state()
-        #
-        self.custom_prompt_template = (
+        # Prompt template for generating 3 Q&A pairs.
+        self.custom_prompt_template: str = (
             "You are an expert in extracting question and answer pairs from documents. "
             "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
             "Each dictionary must have keys 'question' and 'answer'. "
@@ -45,9 +57,11 @@ class QADataGenerator:
             "{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}]\n\n"
             "Now, generate 3 Q&A pairs from this data:\n{data}"
         )
-
+
     def _setup_providers(self) -> None:
-        """
+        """
+        Configure available LLM providers and their client initialization routines.
+        """
         self.providers: Dict[str, Dict[str, Any]] = {
             "Deepseek": {
                 "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
@@ -66,20 +80,24 @@ class QADataGenerator:
                 "models": ["gpt2", "llama-2"],
             },
         }
-
+
     def _setup_input_handlers(self) -> None:
-        """
-        self.input_handlers = {
+        """
+        Register handlers for different input data types.
+        """
+        self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
             "text": self.handle_text,
             "pdf": self.handle_pdf,
             "csv": self.handle_csv,
             "api": self.handle_api,
             "db": self.handle_db,
         }
-
+
     def _initialize_session_state(self) -> None:
-        """
-        defaults = {
+        """
+        Initialize Streamlit session state with default configuration.
+        """
+        defaults: Dict[str, Any] = {
             "config": {
                 "provider": "OpenAI",
                 "model": "gpt-4-turbo",
@@ -93,38 +111,50 @@ class QADataGenerator:
         for key, value in defaults.items():
             if key not in st.session_state:
                 st.session_state[key] = value
-
+
     def log_error(self, message: str) -> None:
-        """
+        """
+        Log an error message to session state, display it, and send it to the logger.
+        """
         st.session_state.error_logs.append(message)
         st.error(message)
-
+        logger.error(message)
+
     # ----- Input Handlers -----
     def handle_text(self, text: str) -> Dict[str, Any]:
+        """
+        Process plain text input.
+        """
         return {"data": text, "source": "text"}
-
+
     def handle_pdf(self, file) -> Dict[str, Any]:
+        """
+        Extract text from a PDF file.
+        """
         try:
             with pdfplumber.open(file) as pdf:
-                full_text = ""
-                for page in pdf.pages:
-                    page_text = page.extract_text() or ""
-                    full_text += page_text + "\n"
-                return {"data": full_text, "source": "pdf"}
+                full_text = "\n".join(page.extract_text() or "" for page in pdf.pages)
+                return {"data": full_text, "source": "pdf"}
         except Exception as e:
             self.log_error(f"PDF Processing Error: {e}")
             return {"data": "", "source": "pdf"}
-
+
     def handle_csv(self, file) -> Dict[str, Any]:
+        """
+        Process a CSV file by reading it into a DataFrame and converting it to JSON.
+        """
         try:
             df = pd.read_csv(file)
-
-            return {"data":
+            json_data = df.to_json(orient="records")
+            return {"data": json_data, "source": "csv"}
         except Exception as e:
             self.log_error(f"CSV Processing Error: {e}")
             return {"data": "", "source": "csv"}
-
+
     def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
+        """
+        Fetch data from an API endpoint.
+        """
         try:
             response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
             response.raise_for_status()
@@ -132,8 +162,11 @@ class QADataGenerator:
         except Exception as e:
             self.log_error(f"API Processing Error: {e}")
             return {"data": "", "source": "api"}
-
+
     def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
+        """
+        Query a database using the provided connection string and SQL query.
+        """
         try:
             engine = sqlalchemy.create_engine(config["connection"])
             with engine.connect() as conn:
@@ -143,15 +176,17 @@ class QADataGenerator:
         except Exception as e:
             self.log_error(f"Database Processing Error: {e}")
             return {"data": "", "source": "db"}
-
+
     def aggregate_inputs(self) -> str:
-        """
+        """
+        Combine all input sources into a single aggregated string.
+        """
         aggregated_data = ""
         for item in st.session_state.inputs:
             aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
             aggregated_data += item.get("data", "") + "\n\n"
         return aggregated_data.strip()
-
+
     def build_prompt(self) -> str:
         """
         Build the complete prompt using the custom template and aggregated inputs.
@@ -161,50 +196,52 @@ class QADataGenerator:
         st.write("### Built Prompt")
         st.write(prompt)
         return prompt
-
+
     def generate_qa_pairs(self) -> bool:
         """
         Generate Q&A pairs by sending the built prompt to the selected LLM provider.
         """
-        api_key = st.session_state.api_key
+        api_key: str = st.session_state.api_key
         if not api_key:
             self.log_error("API key is missing!")
             return False
-
-        provider_name = st.session_state.config["provider"]
-        provider_cfg = self.providers.get(provider_name)
+
+        provider_name: str = st.session_state.config["provider"]
+        provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
         if not provider_cfg:
             self.log_error(f"Provider {provider_name} is not configured.")
             return False
-
-        client_initializer = provider_cfg["client"]
+
+        client_initializer: Callable[[str], Any] = provider_cfg["client"]
         client = client_initializer(api_key)
-        model = st.session_state.config["model"]
-        temperature = st.session_state.config["temperature"]
-        prompt = self.build_prompt()
-
+        model: str = st.session_state.config["model"]
+        temperature: float = st.session_state.config["temperature"]
+        prompt: str = self.build_prompt()
+
         st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
         try:
             if provider_name == "HuggingFace":
                 response = self._huggingface_inference(client, prompt, model)
             else:
                 response = self._standard_inference(client, prompt, model, temperature)
-
+
             st.write("### Raw API Response")
             st.write(response)
-
+
             qa_pairs = self._parse_response(response, provider_name)
             st.write("### Parsed Q&A Pairs")
            st.write(qa_pairs)
-
+
             st.session_state.qa_pairs = qa_pairs
             return True
         except Exception as e:
             self.log_error(f"Generation failed: {e}")
             return False
-
+
     def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
-        """
+        """
+        Inference method for providers using an OpenAI-compatible API.
+        """
         try:
             st.write("Sending prompt via standard inference...")
             result = client.chat.completions.create(
@@ -217,9 +254,11 @@ class QADataGenerator:
         except Exception as e:
             self.log_error(f"Standard Inference Error: {e}")
             return None
-
+
     def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
-        """
+        """
+        Inference method for the Hugging Face Inference API.
+        """
         try:
             st.write("Sending prompt to HuggingFace API...")
             response = requests.post(
@@ -234,7 +273,7 @@ class QADataGenerator:
         except Exception as e:
             self.log_error(f"HuggingFace Inference Error: {e}")
             return None
-
+
     def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
         """
         Parse the LLM response and return a list of Q&A pairs.
@@ -250,15 +289,14 @@ class QADataGenerator:
                 self.log_error("Unexpected HuggingFace response format.")
                 return []
         else:
-            # For OpenAI
-            # response.choices[0].message.content
+            # For OpenAI and similar providers, expect response.choices[0].message.content
             if response and hasattr(response, "choices") and response.choices:
                 raw_text = response.choices[0].message.content
             else:
                 self.log_error("Unexpected response format from provider.")
                 return []
-
-            #
+
+        # Parse the raw text as JSON
         try:
             qa_list = json.loads(raw_text)
             if isinstance(qa_list, list):
@@ -276,28 +314,33 @@ class QADataGenerator:
 
 # ============ UI Components ============
 
-def config_ui(generator: QADataGenerator):
-    """
+def config_ui(generator: QADataGenerator) -> None:
+    """
+    Display configuration options in the sidebar.
+    """
     with st.sidebar:
         st.header("Configuration")
         provider = st.selectbox("Select Provider", list(generator.providers.keys()))
         st.session_state.config["provider"] = provider
         provider_cfg = generator.providers[provider]
-
+
         model = st.selectbox("Select Model", provider_cfg["models"])
         st.session_state.config["model"] = model
-
+
         temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
         st.session_state.config["temperature"] = temperature
-
+
         api_key = st.text_input(f"{provider} API Key", type="password")
         st.session_state.api_key = api_key
 
-
-def input_ui(generator: QADataGenerator):
+
+def input_ui(generator: QADataGenerator) -> None:
+    """
+    Display input data source options using tabs.
+    """
     st.subheader("Input Data Sources")
     tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
-
+
     with tabs[0]:
         text_input = st.text_area("Enter text input", height=150)
         if st.button("Add Text Input", key="text_input"):
@@ -306,19 +349,19 @@ def input_ui(generator: QADataGenerator):
                 st.success("Text input added!")
             else:
                 st.warning("Empty text input.")
-
+
     with tabs[1]:
         pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
         if pdf_file is not None:
             st.session_state.inputs.append(generator.handle_pdf(pdf_file))
             st.success("PDF input added!")
-
+
     with tabs[2]:
         csv_file = st.file_uploader("Upload CSV", type=["csv"])
         if csv_file is not None:
             st.session_state.inputs.append(generator.handle_csv(csv_file))
             st.success("CSV input added!")
-
+
     with tabs[3]:
         api_url = st.text_input("API Endpoint URL")
         api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
@@ -331,7 +374,7 @@ def input_ui(generator: QADataGenerator):
                 generator.log_error(f"Invalid JSON for API Headers: {e}")
             st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
             st.success("API input added!")
-
+
     with tabs[4]:
         db_conn = st.text_input("Database Connection String")
         db_query = st.text_area("Database Query", height=100)
@@ -339,8 +382,11 @@ def input_ui(generator: QADataGenerator):
         st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
         st.success("Database input added!")
 
-
-def output_ui(generator: QADataGenerator):
+
+def output_ui(generator: QADataGenerator) -> None:
+    """
+    Display the generated Q&A pairs and provide a download option.
+    """
     st.subheader("Q&A Pairs Output")
     if st.session_state.qa_pairs:
         st.write("### Generated Q&A Pairs")
@@ -349,13 +395,16 @@ def output_ui(generator: QADataGenerator):
             "Download Output",
             json.dumps(st.session_state.qa_pairs, indent=2),
             file_name="qa_pairs.json",
-            mime="application/json"
+            mime="application/json",
         )
     else:
         st.info("No Q&A pairs generated yet.")
 
-
-def logs_ui():
+
+def logs_ui() -> None:
+    """
+    Display error logs and debugging information in an expandable section.
+    """
     with st.expander("Error Logs & Debug Info", expanded=False):
         if st.session_state.error_logs:
             for log in st.session_state.error_logs:
@@ -364,7 +413,10 @@ def logs_ui():
             st.write("No logs yet.")
 
 
-def main():
+def main() -> None:
+    """
+    Main Streamlit application entry point.
+    """
     st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
     st.title("Advanced Q&A Synthetic Generator")
     st.markdown(
@@ -373,17 +425,17 @@ def main():
         from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
         """
     )
-
+
     # Initialize generator and display configuration UI
     generator = QADataGenerator()
     config_ui(generator)
-
+
     st.header("1. Input Data")
     input_ui(generator)
     if st.button("Clear All Inputs"):
         st.session_state.inputs = []
         st.success("All inputs have been cleared!")
-
+
     st.header("2. Generate Q&A Pairs")
     if st.button("Generate Q&A Pairs", key="generate_qa"):
         with st.spinner("Generating Q&A pairs..."):
@@ -391,13 +443,13 @@ def main():
                 st.success("Q&A pairs generated successfully!")
             else:
                 st.error("Q&A generation failed. Check logs for details.")
-
+
     st.header("3. Output")
     output_ui(generator)
-
+
     st.header("4. Logs & Debug Information")
     logs_ui()
 
 
 if __name__ == "__main__":
-    main()
+    main()
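
For reference, a minimal sketch of the prompt/parse round trip that custom_prompt_template and _parse_response rely on, runnable outside Streamlit. The abridged template and sample_response below are illustrative stand-ins, not the app's exact strings:

import json

# Abridged stand-in for custom_prompt_template; build_prompt fills the {data}
# placeholder with the aggregated input text.
template = (
    "Generate 3 Q&A pairs from the following data, formatted as a JSON list of "
    "dictionaries with keys 'question' and 'answer'.\n\nData:\n{data}"
)
prompt = template.format(data="Gold's chemical symbol is Au.")

# Hypothetical model reply; _parse_response feeds text like this to json.loads
# and keeps the result only if it is a list of dicts.
sample_response = '[{"question": "What is the chemical symbol for gold?", "answer": "Au"}]'
qa_pairs = json.loads(sample_response)
assert isinstance(qa_pairs, list) and all({"question", "answer"} <= set(d) for d in qa_pairs)
print(qa_pairs)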