Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
import json
|
2 |
-
import logging
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
import pdfplumber
|
6 |
import pandas as pd
|
7 |
import sqlalchemy
|
8 |
-
from typing import Any, Dict, List
|
9 |
|
10 |
# Provider clients – ensure these libraries are installed
|
11 |
try:
|
@@ -18,20 +17,10 @@ try:
|
|
18 |
except ImportError:
|
19 |
groq = None
|
20 |
|
21 |
-
# Hugging Face inference endpoint
|
22 |
-
HF_API_URL
|
23 |
-
DEFAULT_TEMPERATURE
|
24 |
-
GROQ_MODEL
|
25 |
-
|
26 |
-
# Optional: Configure a basic logger for debugging purposes
|
27 |
-
logger = logging.getLogger(__name__)
|
28 |
-
logger.setLevel(logging.DEBUG)
|
29 |
-
if not logger.handlers:
|
30 |
-
ch = logging.StreamHandler()
|
31 |
-
ch.setLevel(logging.DEBUG)
|
32 |
-
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
33 |
-
ch.setFormatter(formatter)
|
34 |
-
logger.addHandler(ch)
|
35 |
|
36 |
|
37 |
class QADataGenerator:
|
@@ -39,29 +28,26 @@ class QADataGenerator:
|
|
39 |
A Q&A Synthetic Generator that extracts and generates question-answer pairs
|
40 |
from various input sources using an LLM provider.
|
41 |
"""
|
42 |
-
|
43 |
def __init__(self) -> None:
|
44 |
self._setup_providers()
|
45 |
self._setup_input_handlers()
|
46 |
self._initialize_session_state()
|
47 |
-
#
|
48 |
-
self.custom_prompt_template
|
49 |
"You are an expert in extracting question and answer pairs from documents. "
|
50 |
"Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
|
51 |
"Each dictionary must have keys 'question' and 'answer'. "
|
52 |
"The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
|
53 |
"Do not hallucinate. \n\n"
|
54 |
"Example JSON Output:\n"
|
55 |
-
"[{'question': 'What is the capital of France?', 'answer': 'Paris'}, "
|
56 |
-
"{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}, "
|
57 |
-
"{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}]\n\n"
|
58 |
"Now, generate 3 Q&A pairs from this data:\n{data}"
|
59 |
)
|
60 |
-
|
61 |
def _setup_providers(self) -> None:
|
62 |
-
"""
|
63 |
-
Configure available LLM providers and their client initialization routines.
|
64 |
-
"""
|
65 |
self.providers: Dict[str, Dict[str, Any]] = {
|
66 |
"Deepseek": {
|
67 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
|
@@ -80,24 +66,20 @@ class QADataGenerator:
|
|
80 |
"models": ["gpt2", "llama-2"],
|
81 |
},
|
82 |
}
|
83 |
-
|
84 |
def _setup_input_handlers(self) -> None:
|
85 |
-
"""
|
86 |
-
|
87 |
-
"""
|
88 |
-
self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
|
89 |
"text": self.handle_text,
|
90 |
"pdf": self.handle_pdf,
|
91 |
"csv": self.handle_csv,
|
92 |
"api": self.handle_api,
|
93 |
"db": self.handle_db,
|
94 |
}
|
95 |
-
|
96 |
def _initialize_session_state(self) -> None:
|
97 |
-
"""
|
98 |
-
|
99 |
-
"""
|
100 |
-
defaults: Dict[str, Any] = {
|
101 |
"config": {
|
102 |
"provider": "OpenAI",
|
103 |
"model": "gpt-4-turbo",
|
@@ -111,50 +93,38 @@ class QADataGenerator:
|
|
111 |
for key, value in defaults.items():
|
112 |
if key not in st.session_state:
|
113 |
st.session_state[key] = value
|
114 |
-
|
115 |
def log_error(self, message: str) -> None:
|
116 |
-
"""
|
117 |
-
Log an error message to session state, display it, and send it to the logger.
|
118 |
-
"""
|
119 |
st.session_state.error_logs.append(message)
|
120 |
st.error(message)
|
121 |
-
|
122 |
-
|
123 |
# ----- Input Handlers -----
|
124 |
def handle_text(self, text: str) -> Dict[str, Any]:
|
125 |
-
"""
|
126 |
-
Process plain text input.
|
127 |
-
"""
|
128 |
return {"data": text, "source": "text"}
|
129 |
-
|
130 |
def handle_pdf(self, file) -> Dict[str, Any]:
|
131 |
-
"""
|
132 |
-
Extract text from a PDF file.
|
133 |
-
"""
|
134 |
try:
|
135 |
with pdfplumber.open(file) as pdf:
|
136 |
-
full_text = "
|
137 |
-
|
|
|
|
|
|
|
138 |
except Exception as e:
|
139 |
self.log_error(f"PDF Processing Error: {e}")
|
140 |
return {"data": "", "source": "pdf"}
|
141 |
-
|
142 |
def handle_csv(self, file) -> Dict[str, Any]:
|
143 |
-
"""
|
144 |
-
Process a CSV file by reading it into a DataFrame and converting it to JSON.
|
145 |
-
"""
|
146 |
try:
|
147 |
df = pd.read_csv(file)
|
148 |
-
|
149 |
-
return {"data":
|
150 |
except Exception as e:
|
151 |
self.log_error(f"CSV Processing Error: {e}")
|
152 |
return {"data": "", "source": "csv"}
|
153 |
-
|
154 |
def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
|
155 |
-
"""
|
156 |
-
Fetch data from an API endpoint.
|
157 |
-
"""
|
158 |
try:
|
159 |
response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
|
160 |
response.raise_for_status()
|
@@ -162,11 +132,8 @@ class QADataGenerator:
|
|
162 |
except Exception as e:
|
163 |
self.log_error(f"API Processing Error: {e}")
|
164 |
return {"data": "", "source": "api"}
|
165 |
-
|
166 |
def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
|
167 |
-
"""
|
168 |
-
Query a database using the provided connection string and SQL query.
|
169 |
-
"""
|
170 |
try:
|
171 |
engine = sqlalchemy.create_engine(config["connection"])
|
172 |
with engine.connect() as conn:
|
@@ -176,17 +143,15 @@ class QADataGenerator:
|
|
176 |
except Exception as e:
|
177 |
self.log_error(f"Database Processing Error: {e}")
|
178 |
return {"data": "", "source": "db"}
|
179 |
-
|
180 |
def aggregate_inputs(self) -> str:
|
181 |
-
"""
|
182 |
-
Combine all input sources into a single aggregated string.
|
183 |
-
"""
|
184 |
aggregated_data = ""
|
185 |
for item in st.session_state.inputs:
|
186 |
aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
|
187 |
aggregated_data += item.get("data", "") + "\n\n"
|
188 |
return aggregated_data.strip()
|
189 |
-
|
190 |
def build_prompt(self) -> str:
|
191 |
"""
|
192 |
Build the complete prompt using the custom template and aggregated inputs.
|
@@ -196,52 +161,50 @@ class QADataGenerator:
|
|
196 |
st.write("### Built Prompt")
|
197 |
st.write(prompt)
|
198 |
return prompt
|
199 |
-
|
200 |
def generate_qa_pairs(self) -> bool:
|
201 |
"""
|
202 |
Generate Q&A pairs by sending the built prompt to the selected LLM provider.
|
203 |
"""
|
204 |
-
api_key
|
205 |
if not api_key:
|
206 |
self.log_error("API key is missing!")
|
207 |
return False
|
208 |
-
|
209 |
-
provider_name
|
210 |
-
provider_cfg
|
211 |
if not provider_cfg:
|
212 |
self.log_error(f"Provider {provider_name} is not configured.")
|
213 |
return False
|
214 |
-
|
215 |
-
client_initializer
|
216 |
client = client_initializer(api_key)
|
217 |
-
model
|
218 |
-
temperature
|
219 |
-
prompt
|
220 |
-
|
221 |
st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
|
222 |
try:
|
223 |
if provider_name == "HuggingFace":
|
224 |
response = self._huggingface_inference(client, prompt, model)
|
225 |
else:
|
226 |
response = self._standard_inference(client, prompt, model, temperature)
|
227 |
-
|
228 |
st.write("### Raw API Response")
|
229 |
st.write(response)
|
230 |
-
|
231 |
qa_pairs = self._parse_response(response, provider_name)
|
232 |
st.write("### Parsed Q&A Pairs")
|
233 |
st.write(qa_pairs)
|
234 |
-
|
235 |
st.session_state.qa_pairs = qa_pairs
|
236 |
return True
|
237 |
except Exception as e:
|
238 |
self.log_error(f"Generation failed: {e}")
|
239 |
return False
|
240 |
-
|
241 |
def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
|
242 |
-
"""
|
243 |
-
Inference method for providers using an OpenAI-compatible API.
|
244 |
-
"""
|
245 |
try:
|
246 |
st.write("Sending prompt via standard inference...")
|
247 |
result = client.chat.completions.create(
|
@@ -254,11 +217,9 @@ class QADataGenerator:
|
|
254 |
except Exception as e:
|
255 |
self.log_error(f"Standard Inference Error: {e}")
|
256 |
return None
|
257 |
-
|
258 |
def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
|
259 |
-
"""
|
260 |
-
Inference method for the Hugging Face Inference API.
|
261 |
-
"""
|
262 |
try:
|
263 |
st.write("Sending prompt to HuggingFace API...")
|
264 |
response = requests.post(
|
@@ -273,7 +234,7 @@ class QADataGenerator:
|
|
273 |
except Exception as e:
|
274 |
self.log_error(f"HuggingFace Inference Error: {e}")
|
275 |
return None
|
276 |
-
|
277 |
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
|
278 |
"""
|
279 |
Parse the LLM response and return a list of Q&A pairs.
|
@@ -289,14 +250,15 @@ class QADataGenerator:
|
|
289 |
self.log_error("Unexpected HuggingFace response format.")
|
290 |
return []
|
291 |
else:
|
292 |
-
# For OpenAI and similar providers
|
|
|
293 |
if response and hasattr(response, "choices") and response.choices:
|
294 |
raw_text = response.choices[0].message.content
|
295 |
else:
|
296 |
self.log_error("Unexpected response format from provider.")
|
297 |
return []
|
298 |
-
|
299 |
-
#
|
300 |
try:
|
301 |
qa_list = json.loads(raw_text)
|
302 |
if isinstance(qa_list, list):
|
@@ -314,33 +276,28 @@ class QADataGenerator:
|
|
314 |
|
315 |
# ============ UI Components ============
|
316 |
|
317 |
-
def config_ui(generator: QADataGenerator)
|
318 |
-
"""
|
319 |
-
Display configuration options in the sidebar.
|
320 |
-
"""
|
321 |
with st.sidebar:
|
322 |
st.header("Configuration")
|
323 |
provider = st.selectbox("Select Provider", list(generator.providers.keys()))
|
324 |
st.session_state.config["provider"] = provider
|
325 |
provider_cfg = generator.providers[provider]
|
326 |
-
|
327 |
model = st.selectbox("Select Model", provider_cfg["models"])
|
328 |
st.session_state.config["model"] = model
|
329 |
-
|
330 |
temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
|
331 |
st.session_state.config["temperature"] = temperature
|
332 |
-
|
333 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
334 |
st.session_state.api_key = api_key
|
335 |
|
336 |
-
|
337 |
-
|
338 |
-
"""
|
339 |
-
Display input data source options using tabs.
|
340 |
-
"""
|
341 |
st.subheader("Input Data Sources")
|
342 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
343 |
-
|
344 |
with tabs[0]:
|
345 |
text_input = st.text_area("Enter text input", height=150)
|
346 |
if st.button("Add Text Input", key="text_input"):
|
@@ -349,19 +306,19 @@ def input_ui(generator: QADataGenerator) -> None:
|
|
349 |
st.success("Text input added!")
|
350 |
else:
|
351 |
st.warning("Empty text input.")
|
352 |
-
|
353 |
with tabs[1]:
|
354 |
pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
|
355 |
if pdf_file is not None:
|
356 |
st.session_state.inputs.append(generator.handle_pdf(pdf_file))
|
357 |
st.success("PDF input added!")
|
358 |
-
|
359 |
with tabs[2]:
|
360 |
csv_file = st.file_uploader("Upload CSV", type=["csv"])
|
361 |
if csv_file is not None:
|
362 |
st.session_state.inputs.append(generator.handle_csv(csv_file))
|
363 |
st.success("CSV input added!")
|
364 |
-
|
365 |
with tabs[3]:
|
366 |
api_url = st.text_input("API Endpoint URL")
|
367 |
api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
|
@@ -374,7 +331,7 @@ def input_ui(generator: QADataGenerator) -> None:
|
|
374 |
generator.log_error(f"Invalid JSON for API Headers: {e}")
|
375 |
st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
|
376 |
st.success("API input added!")
|
377 |
-
|
378 |
with tabs[4]:
|
379 |
db_conn = st.text_input("Database Connection String")
|
380 |
db_query = st.text_area("Database Query", height=100)
|
@@ -382,11 +339,8 @@ def input_ui(generator: QADataGenerator) -> None:
|
|
382 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
383 |
st.success("Database input added!")
|
384 |
|
385 |
-
|
386 |
-
|
387 |
-
"""
|
388 |
-
Display the generated Q&A pairs and provide a download option.
|
389 |
-
"""
|
390 |
st.subheader("Q&A Pairs Output")
|
391 |
if st.session_state.qa_pairs:
|
392 |
st.write("### Generated Q&A Pairs")
|
@@ -395,16 +349,13 @@ def output_ui(generator: QADataGenerator) -> None:
|
|
395 |
"Download Output",
|
396 |
json.dumps(st.session_state.qa_pairs, indent=2),
|
397 |
file_name="qa_pairs.json",
|
398 |
-
mime="application/json"
|
399 |
)
|
400 |
else:
|
401 |
st.info("No Q&A pairs generated yet.")
|
402 |
|
403 |
-
|
404 |
-
|
405 |
-
"""
|
406 |
-
Display error logs and debugging information in an expandable section.
|
407 |
-
"""
|
408 |
with st.expander("Error Logs & Debug Info", expanded=False):
|
409 |
if st.session_state.error_logs:
|
410 |
for log in st.session_state.error_logs:
|
@@ -412,11 +363,7 @@ def logs_ui() -> None:
|
|
412 |
else:
|
413 |
st.write("No logs yet.")
|
414 |
|
415 |
-
|
416 |
-
def main() -> None:
|
417 |
-
"""
|
418 |
-
Main Streamlit application entry point.
|
419 |
-
"""
|
420 |
st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
|
421 |
st.title("Advanced Q&A Synthetic Generator")
|
422 |
st.markdown(
|
@@ -425,17 +372,17 @@ def main() -> None:
|
|
425 |
from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
|
426 |
"""
|
427 |
)
|
428 |
-
|
429 |
# Initialize generator and display configuration UI
|
430 |
generator = QADataGenerator()
|
431 |
config_ui(generator)
|
432 |
-
|
433 |
st.header("1. Input Data")
|
434 |
input_ui(generator)
|
435 |
if st.button("Clear All Inputs"):
|
436 |
st.session_state.inputs = []
|
437 |
st.success("All inputs have been cleared!")
|
438 |
-
|
439 |
st.header("2. Generate Q&A Pairs")
|
440 |
if st.button("Generate Q&A Pairs", key="generate_qa"):
|
441 |
with st.spinner("Generating Q&A pairs..."):
|
@@ -443,13 +390,12 @@ def main() -> None:
|
|
443 |
st.success("Q&A pairs generated successfully!")
|
444 |
else:
|
445 |
st.error("Q&A generation failed. Check logs for details.")
|
446 |
-
|
447 |
st.header("3. Output")
|
448 |
output_ui(generator)
|
449 |
-
|
450 |
st.header("4. Logs & Debug Information")
|
451 |
logs_ui()
|
452 |
|
453 |
-
|
454 |
if __name__ == "__main__":
|
455 |
main()
|
|
|
1 |
import json
|
|
|
2 |
import requests
|
3 |
import streamlit as st
|
4 |
import pdfplumber
|
5 |
import pandas as pd
|
6 |
import sqlalchemy
|
7 |
+
from typing import Any, Dict, List
|
8 |
|
9 |
# Provider clients – ensure these libraries are installed
|
10 |
try:
|
|
|
17 |
except ImportError:
|
18 |
groq = None
|
19 |
|
20 |
+
# Hugging Face inference endpoint
|
21 |
+
HF_API_URL = "https://api-inference.huggingface.co/models/"
|
22 |
+
DEFAULT_TEMPERATURE = 0.1
|
23 |
+
GROQ_MODEL = "mixtral-8x7b-32768"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
class QADataGenerator:
|
|
|
28 |
A Q&A Synthetic Generator that extracts and generates question-answer pairs
|
29 |
from various input sources using an LLM provider.
|
30 |
"""
|
|
|
31 |
def __init__(self) -> None:
|
32 |
self._setup_providers()
|
33 |
self._setup_input_handlers()
|
34 |
self._initialize_session_state()
|
35 |
+
# Updated prompt template with escaped curly braces
|
36 |
+
self.custom_prompt_template = (
|
37 |
"You are an expert in extracting question and answer pairs from documents. "
|
38 |
"Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
|
39 |
"Each dictionary must have keys 'question' and 'answer'. "
|
40 |
"The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
|
41 |
"Do not hallucinate. \n\n"
|
42 |
"Example JSON Output:\n"
|
43 |
+
"[{{'question': 'What is the capital of France?', 'answer': 'Paris'}}, "
|
44 |
+
"{{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}}, "
|
45 |
+
"{{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}}]\n\n"
|
46 |
"Now, generate 3 Q&A pairs from this data:\n{data}"
|
47 |
)
|
48 |
+
|
49 |
def _setup_providers(self) -> None:
|
50 |
+
"""Configure available LLM providers and their client initialization routines."""
|
|
|
|
|
51 |
self.providers: Dict[str, Dict[str, Any]] = {
|
52 |
"Deepseek": {
|
53 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
|
|
|
66 |
"models": ["gpt2", "llama-2"],
|
67 |
},
|
68 |
}
|
69 |
+
|
70 |
def _setup_input_handlers(self) -> None:
|
71 |
+
"""Register handlers for different input data types."""
|
72 |
+
self.input_handlers: Dict[str, Any] = {
|
|
|
|
|
73 |
"text": self.handle_text,
|
74 |
"pdf": self.handle_pdf,
|
75 |
"csv": self.handle_csv,
|
76 |
"api": self.handle_api,
|
77 |
"db": self.handle_db,
|
78 |
}
|
79 |
+
|
80 |
def _initialize_session_state(self) -> None:
|
81 |
+
"""Initialize Streamlit session state with default configuration."""
|
82 |
+
defaults = {
|
|
|
|
|
83 |
"config": {
|
84 |
"provider": "OpenAI",
|
85 |
"model": "gpt-4-turbo",
|
|
|
93 |
for key, value in defaults.items():
|
94 |
if key not in st.session_state:
|
95 |
st.session_state[key] = value
|
96 |
+
|
97 |
def log_error(self, message: str) -> None:
|
98 |
+
"""Log an error message to session state and display it."""
|
|
|
|
|
99 |
st.session_state.error_logs.append(message)
|
100 |
st.error(message)
|
101 |
+
|
|
|
102 |
# ----- Input Handlers -----
|
103 |
def handle_text(self, text: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
104 |
return {"data": text, "source": "text"}
|
105 |
+
|
106 |
def handle_pdf(self, file) -> Dict[str, Any]:
|
|
|
|
|
|
|
107 |
try:
|
108 |
with pdfplumber.open(file) as pdf:
|
109 |
+
full_text = ""
|
110 |
+
for page in pdf.pages:
|
111 |
+
page_text = page.extract_text() or ""
|
112 |
+
full_text += page_text + "\n"
|
113 |
+
return {"data": full_text, "source": "pdf"}
|
114 |
except Exception as e:
|
115 |
self.log_error(f"PDF Processing Error: {e}")
|
116 |
return {"data": "", "source": "pdf"}
|
117 |
+
|
118 |
def handle_csv(self, file) -> Dict[str, Any]:
|
|
|
|
|
|
|
119 |
try:
|
120 |
df = pd.read_csv(file)
|
121 |
+
# Convert the DataFrame to a JSON string
|
122 |
+
return {"data": df.to_json(orient="records"), "source": "csv"}
|
123 |
except Exception as e:
|
124 |
self.log_error(f"CSV Processing Error: {e}")
|
125 |
return {"data": "", "source": "csv"}
|
126 |
+
|
127 |
def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
|
|
|
|
|
|
|
128 |
try:
|
129 |
response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
|
130 |
response.raise_for_status()
|
|
|
132 |
except Exception as e:
|
133 |
self.log_error(f"API Processing Error: {e}")
|
134 |
return {"data": "", "source": "api"}
|
135 |
+
|
136 |
def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
|
|
|
|
|
|
|
137 |
try:
|
138 |
engine = sqlalchemy.create_engine(config["connection"])
|
139 |
with engine.connect() as conn:
|
|
|
143 |
except Exception as e:
|
144 |
self.log_error(f"Database Processing Error: {e}")
|
145 |
return {"data": "", "source": "db"}
|
146 |
+
|
147 |
def aggregate_inputs(self) -> str:
|
148 |
+
"""Combine all input sources into a single aggregated string."""
|
|
|
|
|
149 |
aggregated_data = ""
|
150 |
for item in st.session_state.inputs:
|
151 |
aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
|
152 |
aggregated_data += item.get("data", "") + "\n\n"
|
153 |
return aggregated_data.strip()
|
154 |
+
|
155 |
def build_prompt(self) -> str:
|
156 |
"""
|
157 |
Build the complete prompt using the custom template and aggregated inputs.
|
|
|
161 |
st.write("### Built Prompt")
|
162 |
st.write(prompt)
|
163 |
return prompt
|
164 |
+
|
165 |
def generate_qa_pairs(self) -> bool:
|
166 |
"""
|
167 |
Generate Q&A pairs by sending the built prompt to the selected LLM provider.
|
168 |
"""
|
169 |
+
api_key = st.session_state.api_key
|
170 |
if not api_key:
|
171 |
self.log_error("API key is missing!")
|
172 |
return False
|
173 |
+
|
174 |
+
provider_name = st.session_state.config["provider"]
|
175 |
+
provider_cfg = self.providers.get(provider_name)
|
176 |
if not provider_cfg:
|
177 |
self.log_error(f"Provider {provider_name} is not configured.")
|
178 |
return False
|
179 |
+
|
180 |
+
client_initializer = provider_cfg["client"]
|
181 |
client = client_initializer(api_key)
|
182 |
+
model = st.session_state.config["model"]
|
183 |
+
temperature = st.session_state.config["temperature"]
|
184 |
+
prompt = self.build_prompt()
|
185 |
+
|
186 |
st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
|
187 |
try:
|
188 |
if provider_name == "HuggingFace":
|
189 |
response = self._huggingface_inference(client, prompt, model)
|
190 |
else:
|
191 |
response = self._standard_inference(client, prompt, model, temperature)
|
192 |
+
|
193 |
st.write("### Raw API Response")
|
194 |
st.write(response)
|
195 |
+
|
196 |
qa_pairs = self._parse_response(response, provider_name)
|
197 |
st.write("### Parsed Q&A Pairs")
|
198 |
st.write(qa_pairs)
|
199 |
+
|
200 |
st.session_state.qa_pairs = qa_pairs
|
201 |
return True
|
202 |
except Exception as e:
|
203 |
self.log_error(f"Generation failed: {e}")
|
204 |
return False
|
205 |
+
|
206 |
def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
|
207 |
+
"""Inference method for providers using an OpenAI-compatible API."""
|
|
|
|
|
208 |
try:
|
209 |
st.write("Sending prompt via standard inference...")
|
210 |
result = client.chat.completions.create(
|
|
|
217 |
except Exception as e:
|
218 |
self.log_error(f"Standard Inference Error: {e}")
|
219 |
return None
|
220 |
+
|
221 |
def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
|
222 |
+
"""Inference method for the Hugging Face Inference API."""
|
|
|
|
|
223 |
try:
|
224 |
st.write("Sending prompt to HuggingFace API...")
|
225 |
response = requests.post(
|
|
|
234 |
except Exception as e:
|
235 |
self.log_error(f"HuggingFace Inference Error: {e}")
|
236 |
return None
|
237 |
+
|
238 |
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
|
239 |
"""
|
240 |
Parse the LLM response and return a list of Q&A pairs.
|
|
|
250 |
self.log_error("Unexpected HuggingFace response format.")
|
251 |
return []
|
252 |
else:
|
253 |
+
# For OpenAI (and similar providers) assume the response is similar to:
|
254 |
+
# response.choices[0].message.content
|
255 |
if response and hasattr(response, "choices") and response.choices:
|
256 |
raw_text = response.choices[0].message.content
|
257 |
else:
|
258 |
self.log_error("Unexpected response format from provider.")
|
259 |
return []
|
260 |
+
|
261 |
+
# Try parsing the raw text as JSON
|
262 |
try:
|
263 |
qa_list = json.loads(raw_text)
|
264 |
if isinstance(qa_list, list):
|
|
|
276 |
|
277 |
# ============ UI Components ============
|
278 |
|
279 |
+
def config_ui(generator: QADataGenerator):
|
280 |
+
"""Display configuration options in the sidebar."""
|
|
|
|
|
281 |
with st.sidebar:
|
282 |
st.header("Configuration")
|
283 |
provider = st.selectbox("Select Provider", list(generator.providers.keys()))
|
284 |
st.session_state.config["provider"] = provider
|
285 |
provider_cfg = generator.providers[provider]
|
286 |
+
|
287 |
model = st.selectbox("Select Model", provider_cfg["models"])
|
288 |
st.session_state.config["model"] = model
|
289 |
+
|
290 |
temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
|
291 |
st.session_state.config["temperature"] = temperature
|
292 |
+
|
293 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
294 |
st.session_state.api_key = api_key
|
295 |
|
296 |
+
def input_ui(generator: QADataGenerator):
|
297 |
+
"""Display input data source options using tabs."""
|
|
|
|
|
|
|
298 |
st.subheader("Input Data Sources")
|
299 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
300 |
+
|
301 |
with tabs[0]:
|
302 |
text_input = st.text_area("Enter text input", height=150)
|
303 |
if st.button("Add Text Input", key="text_input"):
|
|
|
306 |
st.success("Text input added!")
|
307 |
else:
|
308 |
st.warning("Empty text input.")
|
309 |
+
|
310 |
with tabs[1]:
|
311 |
pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
|
312 |
if pdf_file is not None:
|
313 |
st.session_state.inputs.append(generator.handle_pdf(pdf_file))
|
314 |
st.success("PDF input added!")
|
315 |
+
|
316 |
with tabs[2]:
|
317 |
csv_file = st.file_uploader("Upload CSV", type=["csv"])
|
318 |
if csv_file is not None:
|
319 |
st.session_state.inputs.append(generator.handle_csv(csv_file))
|
320 |
st.success("CSV input added!")
|
321 |
+
|
322 |
with tabs[3]:
|
323 |
api_url = st.text_input("API Endpoint URL")
|
324 |
api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
|
|
|
331 |
generator.log_error(f"Invalid JSON for API Headers: {e}")
|
332 |
st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
|
333 |
st.success("API input added!")
|
334 |
+
|
335 |
with tabs[4]:
|
336 |
db_conn = st.text_input("Database Connection String")
|
337 |
db_query = st.text_area("Database Query", height=100)
|
|
|
339 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
340 |
st.success("Database input added!")
|
341 |
|
342 |
+
def output_ui(generator: QADataGenerator):
|
343 |
+
"""Display the generated Q&A pairs and provide a download option."""
|
|
|
|
|
|
|
344 |
st.subheader("Q&A Pairs Output")
|
345 |
if st.session_state.qa_pairs:
|
346 |
st.write("### Generated Q&A Pairs")
|
|
|
349 |
"Download Output",
|
350 |
json.dumps(st.session_state.qa_pairs, indent=2),
|
351 |
file_name="qa_pairs.json",
|
352 |
+
mime="application/json"
|
353 |
)
|
354 |
else:
|
355 |
st.info("No Q&A pairs generated yet.")
|
356 |
|
357 |
+
def logs_ui():
|
358 |
+
"""Display error logs and debugging information in an expandable section."""
|
|
|
|
|
|
|
359 |
with st.expander("Error Logs & Debug Info", expanded=False):
|
360 |
if st.session_state.error_logs:
|
361 |
for log in st.session_state.error_logs:
|
|
|
363 |
else:
|
364 |
st.write("No logs yet.")
|
365 |
|
366 |
+
def main():
|
|
|
|
|
|
|
|
|
367 |
st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
|
368 |
st.title("Advanced Q&A Synthetic Generator")
|
369 |
st.markdown(
|
|
|
372 |
from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
|
373 |
"""
|
374 |
)
|
375 |
+
|
376 |
# Initialize generator and display configuration UI
|
377 |
generator = QADataGenerator()
|
378 |
config_ui(generator)
|
379 |
+
|
380 |
st.header("1. Input Data")
|
381 |
input_ui(generator)
|
382 |
if st.button("Clear All Inputs"):
|
383 |
st.session_state.inputs = []
|
384 |
st.success("All inputs have been cleared!")
|
385 |
+
|
386 |
st.header("2. Generate Q&A Pairs")
|
387 |
if st.button("Generate Q&A Pairs", key="generate_qa"):
|
388 |
with st.spinner("Generating Q&A pairs..."):
|
|
|
390 |
st.success("Q&A pairs generated successfully!")
|
391 |
else:
|
392 |
st.error("Q&A generation failed. Check logs for details.")
|
393 |
+
|
394 |
st.header("3. Output")
|
395 |
output_ui(generator)
|
396 |
+
|
397 |
st.header("4. Logs & Debug Information")
|
398 |
logs_ui()
|
399 |
|
|
|
400 |
if __name__ == "__main__":
|
401 |
main()
|