mgbam committed on
Commit
d2b7530
·
verified ·
1 Parent(s): 8fa07b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -150
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import json
2
  import ast
3
- import logging
4
  import requests
5
  import streamlit as st
6
  import pdfplumber
@@ -8,16 +7,6 @@ import pandas as pd
8
  import sqlalchemy
9
  from typing import Any, Dict, List, Callable
10
 
11
- # Configure Python logging for production diagnostics.
12
- logger = logging.getLogger("SyntheticDataGenerator")
13
- logger.setLevel(logging.INFO)
14
- if not logger.handlers:
15
- handler = logging.StreamHandler()
16
- handler.setLevel(logging.INFO)
17
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
18
- handler.setFormatter(formatter)
19
- logger.addHandler(handler)
20
-
21
  # Provider clients – ensure these libraries are installed
22
  try:
23
  from openai import OpenAI
@@ -29,35 +18,35 @@ try:
29
  except ImportError:
30
  groq = None
31
 
32
- # Constants for external APIs
33
  HF_API_URL: str = "https://api-inference.huggingface.co/models/"
34
  DEFAULT_TEMPERATURE: float = 0.1
35
  GROQ_MODEL: str = "mixtral-8x7b-32768"
36
 
37
 
38
- class SyntheticDataGenerator:
39
  """
40
- An advanced synthetic data generator for creating fine-tuning training examples.
41
-
42
- This generator uses various input sources and an LLM provider to create synthetic data.
43
- Each generated example is a dictionary with 'input' and 'output' keys.
44
  """
45
  def __init__(self) -> None:
46
  self._setup_providers()
47
  self._setup_input_handlers()
48
  self._initialize_session_state()
49
- # Prompt template: note the use of escaped curly braces so that literal braces are kept.
50
  self.custom_prompt_template: str = (
51
- "You are an expert in generating synthetic training data for fine-tuning. "
52
- "Generate {num_examples} training examples from the following data, formatted as a JSON list of dictionaries. "
53
- "Each dictionary must have keys 'input' and 'output'. "
54
- "The examples should be clear, diverse, and based solely on the provided data. Do not add any external information.\n\n"
 
55
  "Example JSON Output:\n"
56
- "[{{'input': 'sample input text 1', 'output': 'sample output text 1'}}, "
57
- "{{'input': 'sample input text 2', 'output': 'sample output text 2'}}]\n\n"
58
- "Now, generate {num_examples} training examples from this data:\n{data}"
 
59
  )
60
-
61
  def _setup_providers(self) -> None:
62
  """Configure available LLM providers and their client initialization routines."""
63
  self.providers: Dict[str, Dict[str, Any]] = {
@@ -78,9 +67,9 @@ class SyntheticDataGenerator:
78
  "models": ["gpt2", "llama-2"],
79
  },
80
  }
81
-
82
  def _setup_input_handlers(self) -> None:
83
- """Register input handlers for various data types."""
84
  self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
85
  "text": self.handle_text,
86
  "pdf": self.handle_pdf,
@@ -88,23 +77,20 @@ class SyntheticDataGenerator:
88
  "api": self.handle_api,
89
  "db": self.handle_db,
90
  }
91
-
92
  def _initialize_session_state(self) -> None:
93
- """
94
- Initialize the Streamlit session state with default configuration.
95
- Also pre-populate configuration from URL query parameters.
96
- """
97
  defaults: Dict[str, Any] = {
98
  "config": {
99
  "provider": "OpenAI",
100
  "model": "gpt-4-turbo",
101
  "temperature": DEFAULT_TEMPERATURE,
102
- "num_examples": 3,
103
  },
104
  "api_key": "",
105
- "inputs": [], # List to store input sources
106
- "synthetic_data": None, # Generated synthetic training examples
107
- "error_logs": [], # To store error messages
108
  }
109
  for key, value in defaults.items():
110
  if key not in st.session_state:
@@ -126,19 +112,18 @@ class SyntheticDataGenerator:
126
  st.session_state.config["num_examples"] = int(params["num_examples"][0])
127
  except ValueError:
128
  pass
129
-
130
  def log_error(self, message: str) -> None:
131
- """Log an error message to both Streamlit and the production logger."""
132
  st.session_state.error_logs.append(message)
133
  st.error(message)
134
- logger.error(message)
135
-
136
  # ----- Input Handlers -----
137
  def handle_text(self, text: str) -> Dict[str, Any]:
138
- """Return plain text input."""
139
  return {"data": text, "source": "text"}
140
-
141
- def handle_pdf(self, file: Any) -> Dict[str, Any]:
142
  """Extract text from a PDF file."""
143
  try:
144
  with pdfplumber.open(file) as pdf:
@@ -147,16 +132,17 @@ class SyntheticDataGenerator:
147
  except Exception as e:
148
  self.log_error(f"PDF Processing Error: {e}")
149
  return {"data": "", "source": "pdf"}
150
-
151
- def handle_csv(self, file: Any) -> Dict[str, Any]:
152
- """Process CSV file by converting it to JSON."""
153
  try:
154
  df = pd.read_csv(file)
155
- return {"data": df.to_json(orient="records"), "source": "csv"}
 
156
  except Exception as e:
157
  self.log_error(f"CSV Processing Error: {e}")
158
  return {"data": "", "source": "csv"}
159
-
160
  def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
161
  """Fetch data from an API endpoint."""
162
  try:
@@ -166,9 +152,9 @@ class SyntheticDataGenerator:
166
  except Exception as e:
167
  self.log_error(f"API Processing Error: {e}")
168
  return {"data": "", "source": "api"}
169
-
170
  def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
171
- """Query a database using a connection string and SQL query."""
172
  try:
173
  engine = sqlalchemy.create_engine(config["connection"])
174
  with engine.connect() as conn:
@@ -178,18 +164,19 @@ class SyntheticDataGenerator:
178
  except Exception as e:
179
  self.log_error(f"Database Processing Error: {e}")
180
  return {"data": "", "source": "db"}
181
-
182
  def aggregate_inputs(self) -> str:
183
- """Aggregate all input data sources into a single string."""
184
- aggregated = ""
185
  for item in st.session_state.inputs:
186
- aggregated += f"Source: {item.get('source', 'unknown')}\n{item.get('data', '')}\n\n"
187
- return aggregated.strip()
188
-
 
189
  def build_prompt(self) -> str:
190
  """
191
- Build the complete prompt using the custom template, aggregated inputs,
192
- and the configured number of examples.
193
  """
194
  data = self.aggregate_inputs()
195
  num_examples = st.session_state.config.get("num_examples", 3)
@@ -197,52 +184,50 @@ class SyntheticDataGenerator:
197
  st.write("### Built Prompt")
198
  st.write(prompt)
199
  return prompt
200
-
201
- def generate_synthetic_data(self) -> bool:
202
  """
203
- Generate synthetic training examples by sending the prompt to the selected LLM provider.
204
  """
205
  api_key: str = st.session_state.api_key
206
  if not api_key:
207
  self.log_error("API key is missing!")
208
  return False
209
-
210
  provider_name: str = st.session_state.config["provider"]
211
  provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
212
  if not provider_cfg:
213
  self.log_error(f"Provider {provider_name} is not configured.")
214
  return False
215
-
216
  client_initializer: Callable[[str], Any] = provider_cfg["client"]
217
  client = client_initializer(api_key)
218
  model: str = st.session_state.config["model"]
219
  temperature: float = st.session_state.config["temperature"]
220
  prompt: str = self.build_prompt()
221
-
222
  st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
223
  try:
224
  if provider_name == "HuggingFace":
225
  response = self._huggingface_inference(client, prompt, model)
226
  else:
227
  response = self._standard_inference(client, prompt, model, temperature)
228
-
229
  st.write("### Raw API Response")
230
  st.write(response)
231
-
232
- synthetic_examples = self._parse_response(response, provider_name)
233
- st.write("### Parsed Synthetic Data")
234
- st.write(synthetic_examples)
235
-
236
- st.session_state.synthetic_data = synthetic_examples
237
  return True
238
  except Exception as e:
239
  self.log_error(f"Generation failed: {e}")
240
  return False
241
-
242
  def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
243
- """
244
- Inference method for providers with an OpenAI-compatible API.
245
- """
246
  try:
247
  st.write("Sending prompt via standard inference...")
248
  result = client.chat.completions.create(
@@ -255,11 +240,9 @@ class SyntheticDataGenerator:
255
  except Exception as e:
256
  self.log_error(f"Standard Inference Error: {e}")
257
  return None
258
-
259
  def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
260
- """
261
- Inference method for the Hugging Face Inference API.
262
- """
263
  try:
264
  st.write("Sending prompt to HuggingFace API...")
265
  response = requests.post(
@@ -274,41 +257,40 @@ class SyntheticDataGenerator:
274
  except Exception as e:
275
  self.log_error(f"HuggingFace Inference Error: {e}")
276
  return None
277
-
278
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
279
  """
280
- Parse the LLM response and return a list of synthetic training examples.
281
- Attempts JSON decoding first and falls back to ast.literal_eval.
 
282
  """
283
  st.write("Parsing response for provider:", provider)
284
  try:
285
  if provider == "HuggingFace":
286
- # Expect response to be a list with a key "generated_text"
287
  if isinstance(response, list) and response and "generated_text" in response[0]:
288
  raw_text = response[0]["generated_text"]
289
  else:
290
  self.log_error("Unexpected HuggingFace response format.")
291
  return []
292
  else:
293
- # For OpenAI/Groq, look for choices[0].message.content
294
  if response and hasattr(response, "choices") and response.choices:
295
  raw_text = response.choices[0].message.content
296
  else:
297
  self.log_error("Unexpected response format from provider.")
298
  return []
299
-
300
  try:
301
- examples = json.loads(raw_text)
302
  except json.JSONDecodeError as e:
303
- self.log_error(f"JSON Parsing Error: {e}. Fallback with ast.literal_eval. Raw output: {raw_text}")
304
  try:
305
- examples = ast.literal_eval(raw_text)
306
  except Exception as e2:
307
  self.log_error(f"ast.literal_eval failed: {e2}")
308
  return []
309
-
310
- if isinstance(examples, list):
311
- return examples
312
  else:
313
  self.log_error("Parsed output is not a list.")
314
  return []
@@ -317,57 +299,53 @@ class SyntheticDataGenerator:
317
  return []
318
 
319
 
320
- # =================== UI Components ===================
321
 
322
- def config_ui(generator: SyntheticDataGenerator) -> None:
323
- """
324
- Display configuration options in the sidebar.
325
- Updates URL query parameters using st.set_query_params.
326
- """
327
  with st.sidebar:
328
  st.header("Configuration")
 
329
  params = st.experimental_get_query_params()
330
  default_provider = params.get("provider", ["OpenAI"])[0]
331
  default_model = params.get("model", ["gpt-4-turbo"])[0]
332
  default_temperature = float(params.get("temperature", [DEFAULT_TEMPERATURE])[0])
333
  default_num_examples = int(params.get("num_examples", [3])[0])
334
-
335
  provider_options = list(generator.providers.keys())
336
- provider = st.selectbox("Select Provider", provider_options,
337
- index=provider_options.index(default_provider)
338
- if default_provider in provider_options else 0)
339
  st.session_state.config["provider"] = provider
340
  provider_cfg = generator.providers[provider]
341
-
342
  model_options = provider_cfg["models"]
343
  model = st.selectbox("Select Model", model_options,
344
- index=model_options.index(default_model)
345
- if default_model in model_options else 0)
346
  st.session_state.config["model"] = model
347
-
348
  temperature = st.slider("Temperature", 0.0, 1.0, default_temperature)
349
  st.session_state.config["temperature"] = temperature
350
-
351
- num_examples = st.number_input("Number of Training Examples", min_value=1, max_value=10,
352
  value=default_num_examples, step=1)
353
  st.session_state.config["num_examples"] = num_examples
354
-
355
  api_key = st.text_input(f"{provider} API Key", type="password")
356
  st.session_state.api_key = api_key
357
-
358
- # Update URL query parameters (shareable configuration)
359
- st.set_query_params(
360
  provider=st.session_state.config["provider"],
361
  model=st.session_state.config["model"],
362
  temperature=st.session_state.config["temperature"],
363
  num_examples=st.session_state.config["num_examples"],
364
  )
365
 
366
- def input_ui(generator: SyntheticDataGenerator) -> None:
367
- """Display input data source options in tabs."""
368
  st.subheader("Input Data Sources")
369
  tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
370
-
371
  with tabs[0]:
372
  text_input = st.text_area("Enter text input", height=150)
373
  if st.button("Add Text Input", key="text_input"):
@@ -376,19 +354,19 @@ def input_ui(generator: SyntheticDataGenerator) -> None:
376
  st.success("Text input added!")
377
  else:
378
  st.warning("Empty text input.")
379
-
380
  with tabs[1]:
381
  pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
382
  if pdf_file is not None:
383
  st.session_state.inputs.append(generator.handle_pdf(pdf_file))
384
  st.success("PDF input added!")
385
-
386
  with tabs[2]:
387
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
388
  if csv_file is not None:
389
  st.session_state.inputs.append(generator.handle_csv(csv_file))
390
  st.success("CSV input added!")
391
-
392
  with tabs[3]:
393
  api_url = st.text_input("API Endpoint URL")
394
  api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
@@ -401,7 +379,7 @@ def input_ui(generator: SyntheticDataGenerator) -> None:
401
  generator.log_error(f"Invalid JSON for API Headers: {e}")
402
  st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
403
  st.success("API input added!")
404
-
405
  with tabs[4]:
406
  db_conn = st.text_input("Database Connection String")
407
  db_query = st.text_area("Database Query", height=100)
@@ -409,38 +387,38 @@ def input_ui(generator: SyntheticDataGenerator) -> None:
409
  st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
410
  st.success("Database input added!")
411
 
412
- def output_ui(generator: SyntheticDataGenerator) -> None:
413
- """Display the generated synthetic data and download options (JSON and CSV)."""
414
- st.subheader("Synthetic Data Output")
415
- if st.session_state.synthetic_data:
416
- st.write("### Generated Training Examples")
417
- st.write(st.session_state.synthetic_data)
418
-
419
  # Download as JSON
420
  st.download_button(
421
  "Download as JSON",
422
- json.dumps(st.session_state.synthetic_data, indent=2),
423
- file_name="synthetic_data.json",
424
  mime="application/json"
425
  )
426
-
427
  # Download as CSV
428
  try:
429
- df = pd.DataFrame(st.session_state.synthetic_data)
430
  csv_data = df.to_csv(index=False)
431
  st.download_button(
432
  "Download as CSV",
433
  csv_data,
434
- file_name="synthetic_data.csv",
435
  mime="text/csv"
436
  )
437
  except Exception as e:
438
  st.error(f"Error generating CSV: {e}")
439
  else:
440
- st.info("No synthetic data generated yet.")
441
 
442
  def logs_ui() -> None:
443
- """Display error logs and debug information in an expandable section."""
444
  with st.expander("Error Logs & Debug Info", expanded=False):
445
  if st.session_state.error_logs:
446
  for log in st.session_state.error_logs:
@@ -450,36 +428,36 @@ def logs_ui() -> None:
450
 
451
  def main() -> None:
452
  """Main Streamlit application entry point."""
453
- st.set_page_config(page_title="Advanced Synthetic Data Generator", layout="wide")
454
- st.title("Advanced Synthetic Data Generator")
455
  st.markdown(
456
  """
457
- Welcome to the Advanced Synthetic Data Generator. This tool creates synthetic training examples
458
- for fine-tuning models. Configure your provider in the sidebar, add input data, and generate synthetic data.
459
  """
460
  )
461
-
462
- # Initialize generator and UI
463
- generator = SyntheticDataGenerator()
464
  config_ui(generator)
465
-
466
  st.header("1. Input Data")
467
  input_ui(generator)
468
  if st.button("Clear All Inputs"):
469
  st.session_state.inputs = []
470
  st.success("All inputs have been cleared!")
471
-
472
- st.header("2. Generate Synthetic Data")
473
- if st.button("Generate Synthetic Data", key="generate_data"):
474
- with st.spinner("Generating synthetic data..."):
475
- if generator.generate_synthetic_data():
476
- st.success("Synthetic data generated successfully!")
477
  else:
478
- st.error("Data generation failed. Check logs for details.")
479
-
480
  st.header("3. Output")
481
  output_ui(generator)
482
-
483
  st.header("4. Logs & Debug Information")
484
  logs_ui()
485
 
 
1
  import json
2
  import ast
 
3
  import requests
4
  import streamlit as st
5
  import pdfplumber
 
7
  import sqlalchemy
8
  from typing import Any, Dict, List, Callable
9
 
 
 
 
 
 
 
 
 
 
 
10
  # Provider clients – ensure these libraries are installed
11
  try:
12
  from openai import OpenAI
 
18
  except ImportError:
19
  groq = None
20
 
21
+ # Hugging Face inference endpoint and defaults
22
  HF_API_URL: str = "https://api-inference.huggingface.co/models/"
23
  DEFAULT_TEMPERATURE: float = 0.1
24
  GROQ_MODEL: str = "mixtral-8x7b-32768"
25
 
26
 
27
+ class QADataGenerator:
28
  """
29
+ A Q&A Synthetic Generator that extracts and generates question-answer pairs
30
+ from various input sources using an LLM provider.
 
 
31
  """
32
  def __init__(self) -> None:
33
  self._setup_providers()
34
  self._setup_input_handlers()
35
  self._initialize_session_state()
36
+ # Prompt template with a dynamic {num_examples} parameter and escaped curly braces
37
  self.custom_prompt_template: str = (
38
+ "You are an expert in extracting question and answer pairs from documents. "
39
+ "Generate {num_examples} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
40
+ "Each dictionary must have keys 'question' and 'answer'. "
41
+ "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
42
+ "Do not hallucinate. \n\n"
43
  "Example JSON Output:\n"
44
+ "[{{'question': 'What is the capital of France?', 'answer': 'Paris'}}, "
45
+ "{{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}}, "
46
+ "{{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}}]\n\n"
47
+ "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
48
  )
49
+
50
  def _setup_providers(self) -> None:
51
  """Configure available LLM providers and their client initialization routines."""
52
  self.providers: Dict[str, Dict[str, Any]] = {
 
67
  "models": ["gpt2", "llama-2"],
68
  },
69
  }
70
+
71
  def _setup_input_handlers(self) -> None:
72
+ """Register handlers for different input data types."""
73
  self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
74
  "text": self.handle_text,
75
  "pdf": self.handle_pdf,
 
77
  "api": self.handle_api,
78
  "db": self.handle_db,
79
  }
80
+
81
  def _initialize_session_state(self) -> None:
82
+ """Initialize Streamlit session state with default configuration."""
 
 
 
83
  defaults: Dict[str, Any] = {
84
  "config": {
85
  "provider": "OpenAI",
86
  "model": "gpt-4-turbo",
87
  "temperature": DEFAULT_TEMPERATURE,
88
+ "num_examples": 3, # Default number of Q&A pairs
89
  },
90
  "api_key": "",
91
+ "inputs": [], # List to store input sources
92
+ "qa_pairs": None, # Generated Q&A pairs output
93
+ "error_logs": [], # To store error messages
94
  }
95
  for key, value in defaults.items():
96
  if key not in st.session_state:
 
112
  st.session_state.config["num_examples"] = int(params["num_examples"][0])
113
  except ValueError:
114
  pass
115
+
116
  def log_error(self, message: str) -> None:
117
+ """Log an error message to session state and display it."""
118
  st.session_state.error_logs.append(message)
119
  st.error(message)
120
+
 
121
  # ----- Input Handlers -----
122
  def handle_text(self, text: str) -> Dict[str, Any]:
123
+ """Process plain text input."""
124
  return {"data": text, "source": "text"}
125
+
126
+ def handle_pdf(self, file) -> Dict[str, Any]:
127
  """Extract text from a PDF file."""
128
  try:
129
  with pdfplumber.open(file) as pdf:
 
132
  except Exception as e:
133
  self.log_error(f"PDF Processing Error: {e}")
134
  return {"data": "", "source": "pdf"}
135
+
136
+ def handle_csv(self, file) -> Dict[str, Any]:
137
+ """Process a CSV file by converting it to JSON."""
138
  try:
139
  df = pd.read_csv(file)
140
+ json_data = df.to_json(orient="records")
141
+ return {"data": json_data, "source": "csv"}
142
  except Exception as e:
143
  self.log_error(f"CSV Processing Error: {e}")
144
  return {"data": "", "source": "csv"}
145
+
146
  def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
147
  """Fetch data from an API endpoint."""
148
  try:
 
152
  except Exception as e:
153
  self.log_error(f"API Processing Error: {e}")
154
  return {"data": "", "source": "api"}
155
+
156
  def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
157
+ """Query a database using the provided connection string and SQL query."""
158
  try:
159
  engine = sqlalchemy.create_engine(config["connection"])
160
  with engine.connect() as conn:
 
164
  except Exception as e:
165
  self.log_error(f"Database Processing Error: {e}")
166
  return {"data": "", "source": "db"}
167
+
168
  def aggregate_inputs(self) -> str:
169
+ """Combine all input sources into a single aggregated string."""
170
+ aggregated_data = ""
171
  for item in st.session_state.inputs:
172
+ aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
173
+ aggregated_data += item.get("data", "") + "\n\n"
174
+ return aggregated_data.strip()
175
+
176
  def build_prompt(self) -> str:
177
  """
178
+ Build the complete prompt using the custom template, aggregated inputs,
179
+ and the number of examples.
180
  """
181
  data = self.aggregate_inputs()
182
  num_examples = st.session_state.config.get("num_examples", 3)
 
184
  st.write("### Built Prompt")
185
  st.write(prompt)
186
  return prompt
187
+
188
+ def generate_qa_pairs(self) -> bool:
189
  """
190
+ Generate Q&A pairs by sending the built prompt to the selected LLM provider.
191
  """
192
  api_key: str = st.session_state.api_key
193
  if not api_key:
194
  self.log_error("API key is missing!")
195
  return False
196
+
197
  provider_name: str = st.session_state.config["provider"]
198
  provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
199
  if not provider_cfg:
200
  self.log_error(f"Provider {provider_name} is not configured.")
201
  return False
202
+
203
  client_initializer: Callable[[str], Any] = provider_cfg["client"]
204
  client = client_initializer(api_key)
205
  model: str = st.session_state.config["model"]
206
  temperature: float = st.session_state.config["temperature"]
207
  prompt: str = self.build_prompt()
208
+
209
  st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
210
  try:
211
  if provider_name == "HuggingFace":
212
  response = self._huggingface_inference(client, prompt, model)
213
  else:
214
  response = self._standard_inference(client, prompt, model, temperature)
215
+
216
  st.write("### Raw API Response")
217
  st.write(response)
218
+
219
+ qa_pairs = self._parse_response(response, provider_name)
220
+ st.write("### Parsed Q&A Pairs")
221
+ st.write(qa_pairs)
222
+
223
+ st.session_state.qa_pairs = qa_pairs
224
  return True
225
  except Exception as e:
226
  self.log_error(f"Generation failed: {e}")
227
  return False
228
+
229
  def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
230
+ """Inference method for providers using an OpenAI-compatible API."""
 
 
231
  try:
232
  st.write("Sending prompt via standard inference...")
233
  result = client.chat.completions.create(
 
240
  except Exception as e:
241
  self.log_error(f"Standard Inference Error: {e}")
242
  return None
243
+
244
  def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
245
+ """Inference method for the Hugging Face Inference API."""
 
 
246
  try:
247
  st.write("Sending prompt to HuggingFace API...")
248
  response = requests.post(
 
257
  except Exception as e:
258
  self.log_error(f"HuggingFace Inference Error: {e}")
259
  return None
260
+
261
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
262
  """
263
+ Parse the LLM response and return a list of Q&A pairs.
264
+ Expects the response to be JSON formatted; if JSON decoding fails,
265
+ uses ast.literal_eval as a fallback.
266
  """
267
  st.write("Parsing response for provider:", provider)
268
  try:
269
  if provider == "HuggingFace":
 
270
  if isinstance(response, list) and response and "generated_text" in response[0]:
271
  raw_text = response[0]["generated_text"]
272
  else:
273
  self.log_error("Unexpected HuggingFace response format.")
274
  return []
275
  else:
 
276
  if response and hasattr(response, "choices") and response.choices:
277
  raw_text = response.choices[0].message.content
278
  else:
279
  self.log_error("Unexpected response format from provider.")
280
  return []
281
+
282
  try:
283
+ qa_list = json.loads(raw_text)
284
  except json.JSONDecodeError as e:
285
+ self.log_error(f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}")
286
  try:
287
+ qa_list = ast.literal_eval(raw_text)
288
  except Exception as e2:
289
  self.log_error(f"ast.literal_eval failed: {e2}")
290
  return []
291
+
292
+ if isinstance(qa_list, list):
293
+ return qa_list
294
  else:
295
  self.log_error("Parsed output is not a list.")
296
  return []
 
299
  return []
300
 
301
 
302
+ # ============ UI Components ============
303
 
304
+ def config_ui(generator: QADataGenerator) -> None:
305
+ """Display configuration options in the sidebar and update URL query parameters."""
 
 
 
306
  with st.sidebar:
307
  st.header("Configuration")
308
+ # Retrieve any query parameters from the URL
309
  params = st.experimental_get_query_params()
310
  default_provider = params.get("provider", ["OpenAI"])[0]
311
  default_model = params.get("model", ["gpt-4-turbo"])[0]
312
  default_temperature = float(params.get("temperature", [DEFAULT_TEMPERATURE])[0])
313
  default_num_examples = int(params.get("num_examples", [3])[0])
314
+
315
  provider_options = list(generator.providers.keys())
316
+ provider = st.selectbox("Select Provider", provider_options,
317
+ index=provider_options.index(default_provider) if default_provider in provider_options else 0)
 
318
  st.session_state.config["provider"] = provider
319
  provider_cfg = generator.providers[provider]
320
+
321
  model_options = provider_cfg["models"]
322
  model = st.selectbox("Select Model", model_options,
323
+ index=model_options.index(default_model) if default_model in model_options else 0)
 
324
  st.session_state.config["model"] = model
325
+
326
  temperature = st.slider("Temperature", 0.0, 1.0, default_temperature)
327
  st.session_state.config["temperature"] = temperature
328
+
329
+ num_examples = st.number_input("Number of Q&A Pairs", min_value=1, max_value=10,
330
  value=default_num_examples, step=1)
331
  st.session_state.config["num_examples"] = num_examples
332
+
333
  api_key = st.text_input(f"{provider} API Key", type="password")
334
  st.session_state.api_key = api_key
335
+
336
+ # Update the URL query parameters for sharing/pre-populating configuration
337
+ st.experimental_set_query_params(
338
  provider=st.session_state.config["provider"],
339
  model=st.session_state.config["model"],
340
  temperature=st.session_state.config["temperature"],
341
  num_examples=st.session_state.config["num_examples"],
342
  )
343
 
344
+ def input_ui(generator: QADataGenerator) -> None:
345
+ """Display input data source options using tabs."""
346
  st.subheader("Input Data Sources")
347
  tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
348
+
349
  with tabs[0]:
350
  text_input = st.text_area("Enter text input", height=150)
351
  if st.button("Add Text Input", key="text_input"):
 
354
  st.success("Text input added!")
355
  else:
356
  st.warning("Empty text input.")
357
+
358
  with tabs[1]:
359
  pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
360
  if pdf_file is not None:
361
  st.session_state.inputs.append(generator.handle_pdf(pdf_file))
362
  st.success("PDF input added!")
363
+
364
  with tabs[2]:
365
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
366
  if csv_file is not None:
367
  st.session_state.inputs.append(generator.handle_csv(csv_file))
368
  st.success("CSV input added!")
369
+
370
  with tabs[3]:
371
  api_url = st.text_input("API Endpoint URL")
372
  api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
 
379
  generator.log_error(f"Invalid JSON for API Headers: {e}")
380
  st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
381
  st.success("API input added!")
382
+
383
  with tabs[4]:
384
  db_conn = st.text_input("Database Connection String")
385
  db_query = st.text_area("Database Query", height=100)
 
387
  st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
388
  st.success("Database input added!")
389
 
390
+ def output_ui(generator: QADataGenerator) -> None:
391
+ """Display the generated Q&A pairs and provide download options (JSON and CSV)."""
392
+ st.subheader("Q&A Pairs Output")
393
+ if st.session_state.qa_pairs:
394
+ st.write("### Generated Q&A Pairs")
395
+ st.write(st.session_state.qa_pairs)
396
+
397
  # Download as JSON
398
  st.download_button(
399
  "Download as JSON",
400
+ json.dumps(st.session_state.qa_pairs, indent=2),
401
+ file_name="qa_pairs.json",
402
  mime="application/json"
403
  )
404
+
405
  # Download as CSV
406
  try:
407
+ df = pd.DataFrame(st.session_state.qa_pairs)
408
  csv_data = df.to_csv(index=False)
409
  st.download_button(
410
  "Download as CSV",
411
  csv_data,
412
+ file_name="qa_pairs.csv",
413
  mime="text/csv"
414
  )
415
  except Exception as e:
416
  st.error(f"Error generating CSV: {e}")
417
  else:
418
+ st.info("No Q&A pairs generated yet.")
419
 
420
  def logs_ui() -> None:
421
+ """Display error logs and debugging information in an expandable section."""
422
  with st.expander("Error Logs & Debug Info", expanded=False):
423
  if st.session_state.error_logs:
424
  for log in st.session_state.error_logs:
 
428
 
429
  def main() -> None:
430
  """Main Streamlit application entry point."""
431
+ st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
432
+ st.title("Advanced Q&A Synthetic Generator")
433
  st.markdown(
434
  """
435
+ Welcome to the Advanced Q&A Synthetic Generator. This tool extracts and generates question-answer pairs
436
+ from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
437
  """
438
  )
439
+
440
+ # Initialize generator and display configuration UI
441
+ generator = QADataGenerator()
442
  config_ui(generator)
443
+
444
  st.header("1. Input Data")
445
  input_ui(generator)
446
  if st.button("Clear All Inputs"):
447
  st.session_state.inputs = []
448
  st.success("All inputs have been cleared!")
449
+
450
+ st.header("2. Generate Q&A Pairs")
451
+ if st.button("Generate Q&A Pairs", key="generate_qa"):
452
+ with st.spinner("Generating Q&A pairs..."):
453
+ if generator.generate_qa_pairs():
454
+ st.success("Q&A pairs generated successfully!")
455
  else:
456
+ st.error("Q&A generation failed. Check logs for details.")
457
+
458
  st.header("3. Output")
459
  output_ui(generator)
460
+
461
  st.header("4. Logs & Debug Information")
462
  logs_ui()
463