Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import pandas as pd
|
|
6 |
import sqlalchemy
|
7 |
from typing import Any, Dict, List
|
8 |
|
9 |
-
# Provider clients
|
10 |
try:
|
11 |
from openai import OpenAI
|
12 |
except ImportError:
|
@@ -17,35 +17,37 @@ try:
|
|
17 |
except ImportError:
|
18 |
groq = None
|
19 |
|
20 |
-
# Hugging Face
|
21 |
HF_API_URL = "https://api-inference.huggingface.co/models/"
|
22 |
DEFAULT_TEMPERATURE = 0.1
|
23 |
GROQ_MODEL = "mixtral-8x7b-32768"
|
24 |
|
25 |
|
26 |
-
class
|
27 |
"""
|
28 |
-
|
29 |
-
|
30 |
"""
|
31 |
def __init__(self) -> None:
|
32 |
self._setup_providers()
|
33 |
self._setup_input_handlers()
|
34 |
self._initialize_session_state()
|
35 |
-
#
|
36 |
self.custom_prompt_template = (
|
37 |
-
"You are an expert
|
38 |
-
"
|
39 |
-
"
|
40 |
-
"
|
41 |
-
"
|
42 |
-
"
|
43 |
-
"
|
44 |
-
"
|
|
|
|
|
45 |
)
|
46 |
-
|
47 |
def _setup_providers(self) -> None:
|
48 |
-
"""Configure available LLM providers and their initialization routines."""
|
49 |
self.providers: Dict[str, Dict[str, Any]] = {
|
50 |
"Deepseek": {
|
51 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
|
@@ -64,7 +66,7 @@ class AdvancedSyntheticDataGenerator:
|
|
64 |
"models": ["gpt2", "llama-2"],
|
65 |
},
|
66 |
}
|
67 |
-
|
68 |
def _setup_input_handlers(self) -> None:
|
69 |
"""Register handlers for different input data types."""
|
70 |
self.input_handlers: Dict[str, Any] = {
|
@@ -74,7 +76,7 @@ class AdvancedSyntheticDataGenerator:
|
|
74 |
"api": self.handle_api,
|
75 |
"db": self.handle_db,
|
76 |
}
|
77 |
-
|
78 |
def _initialize_session_state(self) -> None:
|
79 |
"""Initialize Streamlit session state with default configuration."""
|
80 |
defaults = {
|
@@ -82,27 +84,25 @@ class AdvancedSyntheticDataGenerator:
|
|
82 |
"provider": "OpenAI",
|
83 |
"model": "gpt-4-turbo",
|
84 |
"temperature": DEFAULT_TEMPERATURE,
|
85 |
-
"output_format": "plain_text", # Options: plain_text, json, csv
|
86 |
},
|
87 |
"api_key": "",
|
88 |
-
"inputs": [],
|
89 |
-
"
|
90 |
-
"
|
91 |
-
"error_logs": [], # Logs for any errors during processing
|
92 |
}
|
93 |
for key, value in defaults.items():
|
94 |
if key not in st.session_state:
|
95 |
st.session_state[key] = value
|
96 |
-
|
97 |
def log_error(self, message: str) -> None:
|
98 |
-
"""Log an error message
|
99 |
st.session_state.error_logs.append(message)
|
100 |
st.error(message)
|
101 |
-
|
102 |
-
#
|
103 |
def handle_text(self, text: str) -> Dict[str, Any]:
|
104 |
return {"data": text, "source": "text"}
|
105 |
-
|
106 |
def handle_pdf(self, file) -> Dict[str, Any]:
|
107 |
try:
|
108 |
with pdfplumber.open(file) as pdf:
|
@@ -114,16 +114,16 @@ class AdvancedSyntheticDataGenerator:
|
|
114 |
except Exception as e:
|
115 |
self.log_error(f"PDF Processing Error: {e}")
|
116 |
return {"data": "", "source": "pdf"}
|
117 |
-
|
118 |
def handle_csv(self, file) -> Dict[str, Any]:
|
119 |
try:
|
120 |
df = pd.read_csv(file)
|
121 |
-
# Convert the DataFrame to JSON
|
122 |
return {"data": df.to_json(orient="records"), "source": "csv"}
|
123 |
except Exception as e:
|
124 |
self.log_error(f"CSV Processing Error: {e}")
|
125 |
return {"data": "", "source": "csv"}
|
126 |
-
|
127 |
def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
|
128 |
try:
|
129 |
response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
|
@@ -132,7 +132,7 @@ class AdvancedSyntheticDataGenerator:
|
|
132 |
except Exception as e:
|
133 |
self.log_error(f"API Processing Error: {e}")
|
134 |
return {"data": "", "source": "api"}
|
135 |
-
|
136 |
def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
|
137 |
try:
|
138 |
engine = sqlalchemy.create_engine(config["connection"])
|
@@ -143,7 +143,7 @@ class AdvancedSyntheticDataGenerator:
|
|
143 |
except Exception as e:
|
144 |
self.log_error(f"Database Processing Error: {e}")
|
145 |
return {"data": "", "source": "db"}
|
146 |
-
|
147 |
def aggregate_inputs(self) -> str:
|
148 |
"""Combine all input sources into a single aggregated string."""
|
149 |
aggregated_data = ""
|
@@ -151,44 +151,38 @@ class AdvancedSyntheticDataGenerator:
|
|
151 |
aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
|
152 |
aggregated_data += item.get("data", "") + "\n\n"
|
153 |
return aggregated_data.strip()
|
154 |
-
|
155 |
def build_prompt(self) -> str:
|
156 |
"""
|
157 |
-
Build the complete prompt using
|
158 |
-
and the desired output format.
|
159 |
"""
|
160 |
-
|
161 |
-
|
162 |
-
output_format = st.session_state.config.get("output_format", "plain_text")
|
163 |
-
prompt = self.custom_prompt_template.format(
|
164 |
-
data=aggregated_data, instructions=instructions, format=output_format
|
165 |
-
)
|
166 |
st.write("### Built Prompt")
|
167 |
st.write(prompt)
|
168 |
return prompt
|
169 |
-
|
170 |
-
def
|
171 |
"""
|
172 |
-
Generate
|
173 |
-
Returns True if generation succeeds.
|
174 |
"""
|
175 |
api_key = st.session_state.api_key
|
176 |
if not api_key:
|
177 |
self.log_error("API key is missing!")
|
178 |
return False
|
179 |
-
|
180 |
provider_name = st.session_state.config["provider"]
|
181 |
provider_cfg = self.providers.get(provider_name)
|
182 |
if not provider_cfg:
|
183 |
self.log_error(f"Provider {provider_name} is not configured.")
|
184 |
return False
|
185 |
-
|
186 |
client_initializer = provider_cfg["client"]
|
187 |
client = client_initializer(api_key)
|
188 |
model = st.session_state.config["model"]
|
189 |
temperature = st.session_state.config["temperature"]
|
190 |
prompt = self.build_prompt()
|
191 |
-
|
192 |
st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
|
193 |
try:
|
194 |
if provider_name == "HuggingFace":
|
@@ -199,20 +193,18 @@ class AdvancedSyntheticDataGenerator:
|
|
199 |
st.write("### Raw API Response")
|
200 |
st.write(response)
|
201 |
|
202 |
-
|
203 |
-
st.write("### Parsed
|
204 |
-
st.write(
|
205 |
|
206 |
-
st.session_state.
|
207 |
return True
|
208 |
except Exception as e:
|
209 |
self.log_error(f"Generation failed: {e}")
|
210 |
return False
|
211 |
-
|
212 |
def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
|
213 |
-
"""
|
214 |
-
Inference for providers using an OpenAI-compatible API.
|
215 |
-
"""
|
216 |
try:
|
217 |
st.write("Sending prompt via standard inference...")
|
218 |
result = client.chat.completions.create(
|
@@ -225,11 +217,9 @@ class AdvancedSyntheticDataGenerator:
|
|
225 |
except Exception as e:
|
226 |
self.log_error(f"Standard Inference Error: {e}")
|
227 |
return None
|
228 |
-
|
229 |
def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
|
230 |
-
"""
|
231 |
-
Inference for the Hugging Face Inference API.
|
232 |
-
"""
|
233 |
try:
|
234 |
st.write("Sending prompt to HuggingFace API...")
|
235 |
response = requests.post(
|
@@ -244,62 +234,68 @@ class AdvancedSyntheticDataGenerator:
|
|
244 |
except Exception as e:
|
245 |
self.log_error(f"HuggingFace Inference Error: {e}")
|
246 |
return None
|
247 |
-
|
248 |
-
def _parse_response(self, response: Any, provider: str) -> str:
|
249 |
"""
|
250 |
-
Parse the LLM response
|
|
|
251 |
"""
|
252 |
st.write("Parsing response for provider:", provider)
|
253 |
try:
|
254 |
if provider == "HuggingFace":
|
|
|
255 |
if isinstance(response, list) and response and "generated_text" in response[0]:
|
256 |
-
|
257 |
else:
|
258 |
self.log_error("Unexpected HuggingFace response format.")
|
259 |
-
return
|
260 |
else:
|
261 |
-
#
|
|
|
262 |
if response and hasattr(response, "choices") and response.choices:
|
263 |
-
|
264 |
else:
|
265 |
self.log_error("Unexpected response format from provider.")
|
266 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
except Exception as e:
|
268 |
self.log_error(f"Response Parsing Error: {e}")
|
269 |
-
return
|
270 |
|
271 |
|
272 |
-
#
|
273 |
|
274 |
-
def
|
275 |
-
"""Display
|
276 |
with st.sidebar:
|
277 |
-
st.header("
|
278 |
provider = st.selectbox("Select Provider", list(generator.providers.keys()))
|
279 |
st.session_state.config["provider"] = provider
|
280 |
provider_cfg = generator.providers[provider]
|
281 |
-
|
282 |
model = st.selectbox("Select Model", provider_cfg["models"])
|
283 |
st.session_state.config["model"] = model
|
284 |
-
|
285 |
temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
|
286 |
st.session_state.config["temperature"] = temperature
|
287 |
-
|
288 |
-
output_format = st.radio("Output Format", ["plain_text", "json", "csv"])
|
289 |
-
st.session_state.config["output_format"] = output_format
|
290 |
-
|
291 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
292 |
st.session_state.api_key = api_key
|
293 |
|
294 |
-
|
295 |
-
"Generate diverse, coherent synthetic data based on the input sources.",
|
296 |
-
height=100)
|
297 |
-
st.session_state.instructions = instructions
|
298 |
-
|
299 |
-
|
300 |
-
def advanced_input_ui(generator: AdvancedSyntheticDataGenerator):
|
301 |
"""Display input data source options using tabs."""
|
302 |
-
st.subheader("
|
303 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
304 |
|
305 |
with tabs[0]:
|
@@ -343,28 +339,23 @@ def advanced_input_ui(generator: AdvancedSyntheticDataGenerator):
|
|
343 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
344 |
st.success("Database input added!")
|
345 |
|
346 |
-
|
347 |
-
|
348 |
-
"
|
349 |
-
st.
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
else:
|
359 |
-
st.text_area("Output", st.session_state.synthetic_data, height=300)
|
360 |
-
st.download_button("Download Output", st.session_state.synthetic_data,
|
361 |
-
file_name="synthetic_data.txt", mime="text/plain")
|
362 |
else:
|
363 |
-
st.info("No
|
364 |
-
|
365 |
|
366 |
-
def
|
367 |
-
"""Display error logs and
|
368 |
with st.expander("Error Logs & Debug Info", expanded=False):
|
369 |
if st.session_state.error_logs:
|
370 |
for log in st.session_state.error_logs:
|
@@ -373,50 +364,39 @@ def advanced_logs_ui():
|
|
373 |
st.write("No logs yet.")
|
374 |
|
375 |
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
st.set_page_config(page_title="Advanced Synthetic Data Generator", layout="wide")
|
380 |
-
|
381 |
-
# Sidebar for advanced configuration
|
382 |
-
generator = AdvancedSyntheticDataGenerator()
|
383 |
-
advanced_config_ui(generator)
|
384 |
-
|
385 |
-
st.title("Advanced Synthetic Data Generator")
|
386 |
st.markdown(
|
387 |
"""
|
388 |
-
Welcome
|
389 |
-
|
390 |
"""
|
391 |
)
|
392 |
-
|
393 |
-
#
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
# Logs Section
|
417 |
-
with st.container():
|
418 |
-
st.header("4. Error Logs & Debug Information")
|
419 |
-
advanced_logs_ui()
|
420 |
|
421 |
|
422 |
if __name__ == "__main__":
|
|
|
6 |
import sqlalchemy
|
7 |
from typing import Any, Dict, List
|
8 |
|
9 |
+
# Provider clients – ensure these libraries are installed
|
10 |
try:
|
11 |
from openai import OpenAI
|
12 |
except ImportError:
|
|
|
17 |
except ImportError:
|
18 |
groq = None
|
19 |
|
20 |
+
# Hugging Face inference endpoint
|
21 |
HF_API_URL = "https://api-inference.huggingface.co/models/"
|
22 |
DEFAULT_TEMPERATURE = 0.1
|
23 |
GROQ_MODEL = "mixtral-8x7b-32768"
|
24 |
|
25 |
|
26 |
+
class QADataGenerator:
|
27 |
"""
|
28 |
+
A Q&A Synthetic Generator that extracts and generates question-answer pairs
|
29 |
+
from various input sources using an LLM provider.
|
30 |
"""
|
31 |
def __init__(self) -> None:
|
32 |
self._setup_providers()
|
33 |
self._setup_input_handlers()
|
34 |
self._initialize_session_state()
|
35 |
+
# This prompt instructs the LLM to generate three Q&A pairs.
|
36 |
self.custom_prompt_template = (
|
37 |
+
"You are an expert in extracting question and answer pairs from documents. "
|
38 |
+
"Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
|
39 |
+
"Each dictionary must have keys 'question' and 'answer'. "
|
40 |
+
"The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
|
41 |
+
"Do not hallucinate. \n\n"
|
42 |
+
"Example JSON Output:\n"
|
43 |
+
"[{'question': 'What is the capital of France?', 'answer': 'Paris'}, "
|
44 |
+
"{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}, "
|
45 |
+
"{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}]\n\n"
|
46 |
+
"Now, generate 3 Q&A pairs from this data:\n{data}"
|
47 |
)
|
48 |
+
|
49 |
def _setup_providers(self) -> None:
|
50 |
+
"""Configure available LLM providers and their client initialization routines."""
|
51 |
self.providers: Dict[str, Dict[str, Any]] = {
|
52 |
"Deepseek": {
|
53 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
|
|
|
66 |
"models": ["gpt2", "llama-2"],
|
67 |
},
|
68 |
}
|
69 |
+
|
70 |
def _setup_input_handlers(self) -> None:
|
71 |
"""Register handlers for different input data types."""
|
72 |
self.input_handlers: Dict[str, Any] = {
|
|
|
76 |
"api": self.handle_api,
|
77 |
"db": self.handle_db,
|
78 |
}
|
79 |
+
|
80 |
def _initialize_session_state(self) -> None:
|
81 |
"""Initialize Streamlit session state with default configuration."""
|
82 |
defaults = {
|
|
|
84 |
"provider": "OpenAI",
|
85 |
"model": "gpt-4-turbo",
|
86 |
"temperature": DEFAULT_TEMPERATURE,
|
|
|
87 |
},
|
88 |
"api_key": "",
|
89 |
+
"inputs": [], # List to store input sources
|
90 |
+
"qa_pairs": "", # Generated Q&A pairs output
|
91 |
+
"error_logs": [], # To store any error messages
|
|
|
92 |
}
|
93 |
for key, value in defaults.items():
|
94 |
if key not in st.session_state:
|
95 |
st.session_state[key] = value
|
96 |
+
|
97 |
def log_error(self, message: str) -> None:
    """Append *message* to the session's error log, then show it as a Streamlit error."""
    error_log = st.session_state.error_logs
    error_log.append(message)
    st.error(message)
|
101 |
+
|
102 |
+
# ----- Input Handlers -----
|
103 |
def handle_text(self, text: str) -> Dict[str, Any]:
    """Wrap raw text in an input record tagged with the source ``"text"``."""
    record: Dict[str, Any] = dict(data=text, source="text")
    return record
|
105 |
+
|
106 |
def handle_pdf(self, file) -> Dict[str, Any]:
|
107 |
try:
|
108 |
with pdfplumber.open(file) as pdf:
|
|
|
114 |
except Exception as e:
|
115 |
self.log_error(f"PDF Processing Error: {e}")
|
116 |
return {"data": "", "source": "pdf"}
|
117 |
+
|
118 |
def handle_csv(self, file) -> Dict[str, Any]:
    """
    Read a CSV file and return its rows as a JSON string of records.

    Args:
        file: Anything ``pandas.read_csv`` accepts (path or file-like object).

    Returns:
        ``{"data": <JSON string>, "source": "csv"}``. If reading or
        serialising fails, the error is logged and ``"data"`` is ``""``.
    """
    payload = ""
    try:
        # One JSON object per row, collected in a JSON array.
        payload = pd.read_csv(file).to_json(orient="records")
    except Exception as e:
        self.log_error(f"CSV Processing Error: {e}")
    return {"data": payload, "source": "csv"}
|
126 |
+
|
127 |
def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
|
128 |
try:
|
129 |
response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
|
|
|
132 |
except Exception as e:
|
133 |
self.log_error(f"API Processing Error: {e}")
|
134 |
return {"data": "", "source": "api"}
|
135 |
+
|
136 |
def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
|
137 |
try:
|
138 |
engine = sqlalchemy.create_engine(config["connection"])
|
|
|
143 |
except Exception as e:
|
144 |
self.log_error(f"Database Processing Error: {e}")
|
145 |
return {"data": "", "source": "db"}
|
146 |
+
|
147 |
def aggregate_inputs(self) -> str:
|
148 |
"""Combine all input sources into a single aggregated string."""
|
149 |
aggregated_data = ""
|
|
|
151 |
aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
|
152 |
aggregated_data += item.get("data", "") + "\n\n"
|
153 |
return aggregated_data.strip()
|
154 |
+
|
155 |
def build_prompt(self) -> str:
    """
    Render the prompt template with the aggregated inputs and echo it to the page.

    Returns:
        The template with ``{data}`` replaced by the result of
        ``aggregate_inputs()``.
    """
    prompt = self.custom_prompt_template.format(data=self.aggregate_inputs())
    for chunk in ("### Built Prompt", prompt):
        st.write(chunk)
    return prompt
|
164 |
+
|
165 |
+
def generate_qa_pairs(self) -> bool:
|
166 |
"""
|
167 |
+
Generate Q&A pairs by sending the built prompt to the selected LLM provider.
|
|
|
168 |
"""
|
169 |
api_key = st.session_state.api_key
|
170 |
if not api_key:
|
171 |
self.log_error("API key is missing!")
|
172 |
return False
|
173 |
+
|
174 |
provider_name = st.session_state.config["provider"]
|
175 |
provider_cfg = self.providers.get(provider_name)
|
176 |
if not provider_cfg:
|
177 |
self.log_error(f"Provider {provider_name} is not configured.")
|
178 |
return False
|
179 |
+
|
180 |
client_initializer = provider_cfg["client"]
|
181 |
client = client_initializer(api_key)
|
182 |
model = st.session_state.config["model"]
|
183 |
temperature = st.session_state.config["temperature"]
|
184 |
prompt = self.build_prompt()
|
185 |
+
|
186 |
st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
|
187 |
try:
|
188 |
if provider_name == "HuggingFace":
|
|
|
193 |
st.write("### Raw API Response")
|
194 |
st.write(response)
|
195 |
|
196 |
+
qa_pairs = self._parse_response(response, provider_name)
|
197 |
+
st.write("### Parsed Q&A Pairs")
|
198 |
+
st.write(qa_pairs)
|
199 |
|
200 |
+
st.session_state.qa_pairs = qa_pairs
|
201 |
return True
|
202 |
except Exception as e:
|
203 |
self.log_error(f"Generation failed: {e}")
|
204 |
return False
|
205 |
+
|
206 |
def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
|
207 |
+
"""Inference method for providers using an OpenAI-compatible API."""
|
|
|
|
|
208 |
try:
|
209 |
st.write("Sending prompt via standard inference...")
|
210 |
result = client.chat.completions.create(
|
|
|
217 |
except Exception as e:
|
218 |
self.log_error(f"Standard Inference Error: {e}")
|
219 |
return None
|
220 |
+
|
221 |
def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
|
222 |
+
"""Inference method for the Hugging Face Inference API."""
|
|
|
|
|
223 |
try:
|
224 |
st.write("Sending prompt to HuggingFace API...")
|
225 |
response = requests.post(
|
|
|
234 |
except Exception as e:
|
235 |
self.log_error(f"HuggingFace Inference Error: {e}")
|
236 |
return None
|
237 |
+
|
238 |
+
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
    """
    Extract the model's text from a provider response and decode it as Q&A pairs.

    Args:
        response: For ``"HuggingFace"``, a non-empty list whose first item
            contains a ``"generated_text"`` key. For every other provider,
            an object exposing ``choices[0].message.content``.
        provider: Provider name; only ``"HuggingFace"`` selects the list format.

    Returns:
        The decoded JSON list. An error is logged and ``[]`` is returned when
        the response shape is unexpected, the text is missing, the text is
        not valid JSON, or the decoded value is not a list.
    """
    st.write("Parsing response for provider:", provider)
    try:
        if provider == "HuggingFace":
            if not (isinstance(response, list) and response and "generated_text" in response[0]):
                self.log_error("Unexpected HuggingFace response format.")
                return []
            raw_text = response[0]["generated_text"]
        else:
            if not (response and hasattr(response, "choices") and response.choices):
                self.log_error("Unexpected response format from provider.")
                return []
            raw_text = response.choices[0].message.content

        # Fail with a clear message rather than an opaque TypeError from
        # json.loads when the provider returned no text (e.g. content is None).
        if not isinstance(raw_text, str):
            self.log_error("Provider response contained no text to parse.")
            return []

        # Models may wrap JSON in a Markdown fence (``` or ```json), which
        # json.loads rejects; peel the fence off before decoding.
        payload = raw_text.strip()
        if payload.startswith("```"):
            _, _, payload = payload.partition("\n")  # drop the opening fence line
            payload = payload.rsplit("```", 1)[0].strip()  # drop the closing fence

        try:
            qa_list = json.loads(payload)
        except json.JSONDecodeError as e:
            self.log_error(f"JSON Parsing Error: {e}. Raw output: {raw_text}")
            return []

        if not isinstance(qa_list, list):
            self.log_error("Parsed output is not a list.")
            return []
        return qa_list
    except Exception as e:
        # Last-resort guard so a malformed response never aborts the page run.
        self.log_error(f"Response Parsing Error: {e}")
        return []
|
275 |
|
276 |
|
277 |
+
# ============ UI Components ============
|
278 |
|
279 |
+
def config_ui(generator: QADataGenerator):
    """Render the sidebar controls and copy each selection into session state."""
    settings = st.session_state.config
    with st.sidebar:
        st.header("Configuration")

        # The provider choice decides which model list is offered below.
        provider = st.selectbox("Select Provider", list(generator.providers.keys()))
        settings["provider"] = provider
        available_models = generator.providers[provider]["models"]

        settings["model"] = st.selectbox("Select Model", available_models)

        settings["temperature"] = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)

        st.session_state.api_key = st.text_input(f"{provider} API Key", type="password")
|
295 |
|
296 |
+
def input_ui(generator: QADataGenerator):
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
"""Display input data source options using tabs."""
|
298 |
+
st.subheader("Input Data Sources")
|
299 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
300 |
|
301 |
with tabs[0]:
|
|
|
339 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
340 |
st.success("Database input added!")
|
341 |
|
342 |
+
def output_ui(generator: QADataGenerator):
    """Show the stored Q&A pairs with a JSON download button, or a placeholder note."""
    st.subheader("Q&A Pairs Output")
    pairs = st.session_state.qa_pairs
    if not pairs:
        st.info("No Q&A pairs generated yet.")
        return

    st.write("### Generated Q&A Pairs")
    st.write(pairs)
    serialized = json.dumps(pairs, indent=2)
    st.download_button(
        "Download Output",
        serialized,
        file_name="qa_pairs.json",
        mime="application/json",
    )
|
|
|
356 |
|
357 |
+
def logs_ui():
|
358 |
+
"""Display error logs and debugging information in an expandable section."""
|
359 |
with st.expander("Error Logs & Debug Info", expanded=False):
|
360 |
if st.session_state.error_logs:
|
361 |
for log in st.session_state.error_logs:
|
|
|
364 |
st.write("No logs yet.")
|
365 |
|
366 |
|
367 |
+
def main():
    """Lay out the page top to bottom: intro, sidebar config, inputs, generation, output, logs."""
    app_title = "Advanced Q&A Synthetic Generator"
    st.set_page_config(page_title=app_title, layout="wide")
    st.title(app_title)
    st.markdown(
        """
        Welcome to the Advanced Q&A Synthetic Generator. This tool extracts and generates question-answer pairs
        from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
        """
    )

    # Initialize generator and display configuration UI
    generator = QADataGenerator()
    config_ui(generator)

    st.header("1. Input Data")
    input_ui(generator)
    clear_clicked = st.button("Clear All Inputs")
    if clear_clicked:
        st.session_state.inputs = []
        st.success("All inputs have been cleared!")

    st.header("2. Generate Q&A Pairs")
    generate_clicked = st.button("Generate Q&A Pairs", key="generate_qa")
    if generate_clicked:
        with st.spinner("Generating Q&A pairs..."):
            succeeded = generator.generate_qa_pairs()
            if not succeeded:
                st.error("Q&A generation failed. Check logs for details.")
            else:
                st.success("Q&A pairs generated successfully!")

    st.header("3. Output")
    output_ui(generator)

    st.header("4. Logs & Debug Information")
    logs_ui()
|
|
|
|
|
|
|
|
|
400 |
|
401 |
|
402 |
if __name__ == "__main__":
|