mgbam committed on
Commit
edd1d90
·
verified ·
1 Parent(s): 3db2361

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -317
app.py CHANGED
@@ -1,24 +1,12 @@
1
  import json
2
- import logging
3
  import requests
4
  import streamlit as st
5
  import pdfplumber
6
  import pandas as pd
7
  import sqlalchemy
8
- from typing import Any, Dict, List, Optional, Union, Callable
9
- from functools import lru_cache
10
 
11
- # --- Logging Configuration ---
12
- logger = logging.getLogger("SyntheticDataGenerator")
13
- logger.setLevel(logging.DEBUG)
14
- if not logger.handlers:
15
- ch = logging.StreamHandler()
16
- ch.setLevel(logging.DEBUG)
17
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
18
- ch.setFormatter(formatter)
19
- logger.addHandler(ch)
20
-
21
- # --- Provider Clients with Import Guards ---
22
  try:
23
  from openai import OpenAI
24
  except ImportError:
@@ -29,351 +17,387 @@ try:
29
  except ImportError:
30
  groq = None
31
 
 
 
 
 
32
 
33
- # --- Custom Exceptions ---
34
- class ProviderClientError(Exception):
35
- """Custom exception for provider client issues."""
36
- pass
37
-
38
-
39
- # --- Core Synthetic Data Generator ---
40
- class SyntheticDataGenerator:
41
- """World's Most Advanced Synthetic Data Generation System"""
42
-
43
- PROVIDER_CONFIG: Dict[str, Dict[str, Union[str, List[str], Optional[str]]]] = {
44
- "Deepseek": {
45
- "base_url": "https://api.deepseek.com/v1",
46
- "models": ["deepseek-chat"],
47
- "requires_library": "openai",
48
- },
49
- "OpenAI": {
50
- "base_url": "https://api.openai.com/v1",
51
- "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
52
- "requires_library": "openai",
53
- },
54
- "Groq": {
55
- "base_url": "https://api.groq.com/openai/v1",
56
- "models": ["mixtral-8x7b-32768", "llama2-70b-4096"],
57
- "requires_library": "groq",
58
- },
59
- "HuggingFace": {
60
- "base_url": "https://api-inference.huggingface.co/models/",
61
- "models": ["gpt2", "llama-2-13b-chat"],
62
- "requires_library": None,
63
- },
64
- }
65
 
 
 
 
 
 
66
  def __init__(self) -> None:
67
- self._init_session_state()
68
  self._setup_providers()
69
  self._setup_input_handlers()
70
-
71
- def _init_session_state(self) -> None:
72
- """Initialize session state with default values."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  defaults = {
74
- "active_provider": "OpenAI",
75
- "api_keys": {},
76
- "system_metrics": {"api_calls": 0, "tokens_used": 0, "error_count": 0},
77
- "error_logs": [],
78
- "debug_mode": False,
 
 
 
 
79
  }
80
  for key, value in defaults.items():
81
  if key not in st.session_state:
82
  st.session_state[key] = value
83
-
84
- def _setup_providers(self) -> None:
85
- """Configure available providers based on dependency availability."""
86
- self.available_providers: List[str] = []
87
- for provider, config in self.PROVIDER_CONFIG.items():
88
- required_lib = config.get("requires_library")
89
- if required_lib and not globals().get(required_lib.title()):
90
- logger.warning(f"Skipping provider {provider} due to missing dependency: {required_lib}")
91
- continue
92
- self.available_providers.append(provider)
93
-
94
- def _setup_input_handlers(self) -> None:
95
- """Register input processors."""
96
- self.input_processors: Dict[str, Callable[[Any], str]] = {
97
- "text": self._process_text,
98
- "pdf": self._process_pdf,
99
- "csv": self._process_csv,
100
- "api": self._process_api,
101
- "database": self._process_database,
102
- "web": self._process_web,
103
- }
104
-
105
- @lru_cache(maxsize=100)
106
- def generate(self, provider: str, model: str, prompt: str) -> Dict[str, Any]:
107
- """
108
- Unified generation endpoint with caching and failover support.
109
- """
110
  try:
111
- if provider not in self.available_providers:
112
- raise ProviderClientError(f"Provider {provider} is not available.")
113
- client = self._get_client(provider)
114
- if not client:
115
- raise ProviderClientError(f"Client initialization failed for provider {provider}.")
116
- return self._execute_generation(client, provider, model, prompt)
117
  except Exception as e:
118
- self._log_error(f"Generation error using provider '{provider}': {e}")
119
- return self._failover_generation(provider, model, prompt)
120
-
121
- def _get_client(self, provider: str) -> Any:
122
- """
123
- Initialize and return a client for the specified provider.
124
- Raises ProviderClientError if API key or dependency issues occur.
125
- """
126
- config = self.PROVIDER_CONFIG[provider]
127
- api_key = st.session_state["api_keys"].get(provider, "")
128
- if not api_key:
129
- raise ProviderClientError(f"Missing API key for {provider}.")
130
  try:
131
- if provider == "Groq":
132
- return groq.Groq(api_key=api_key)
133
- elif provider == "HuggingFace":
134
- return {"headers": {"Authorization": f"Bearer {api_key}"}}
135
- else:
136
- return OpenAI(
137
- base_url=config["base_url"],
138
- api_key=api_key,
139
- timeout=30,
140
- )
141
  except Exception as e:
142
- self._log_error(f"Error initializing client for {provider}: {e}")
143
- raise ProviderClientError(f"Client init error for {provider}")
144
-
145
- def _execute_generation(self, client: Any, provider: str, model: str, prompt: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  """
147
- Execute the generation request for the given provider.
148
- Updates system metrics and returns the result.
149
  """
150
- st.session_state["system_metrics"]["api_calls"] += 1
151
-
152
- if provider == "HuggingFace":
153
- url = self.PROVIDER_CONFIG[provider]["base_url"] + model
154
- response = requests.post(url, headers=client["headers"], json={"inputs": prompt}, timeout=30)
155
- response.raise_for_status()
156
- return response.json()
157
- else:
158
- completion = client.chat.completions.create(
159
- model=model,
160
- messages=[{"role": "user", "content": prompt}],
161
- temperature=0.1,
162
- max_tokens=2000,
163
- )
164
- if hasattr(completion.usage, "total_tokens"):
165
- st.session_state["system_metrics"]["tokens_used"] += completion.usage.total_tokens
166
- try:
167
- return json.loads(completion.choices[0].message.content)
168
- except Exception:
169
- return {"response": completion.choices[0].message.content}
170
-
171
- def _failover_generation(self, original_provider: str, model: str, prompt: str) -> Dict[str, Any]:
172
  """
173
- Attempt to generate synthetic data using alternative providers.
174
  """
175
- for backup_provider in self.available_providers:
176
- if backup_provider == original_provider:
177
- continue
178
- backup_models = self.PROVIDER_CONFIG[backup_provider]["models"]
179
- backup_model = model if model in backup_models else backup_models[0]
180
- try:
181
- st.session_state["active_provider"] = backup_provider
182
- result = self.generate(backup_provider, backup_model, prompt)
183
- self._log_error(f"Failover succeeded: provider '{backup_provider}' with model '{backup_model}'")
184
- return result
185
- except Exception as e:
186
- self._log_error(f"Failover attempt with {backup_provider} failed: {e}")
187
- raise ProviderClientError("All generation providers failed.")
188
-
189
- # --- Input Processors ---
190
- def _process_text(self, text: str) -> str:
191
- """Strip and return plain text input."""
192
- return text.strip()
193
-
194
- def _process_pdf(self, file) -> str:
195
- """Extract and return text from a PDF file."""
196
  try:
197
- with pdfplumber.open(file) as pdf:
198
- return "\n".join((page.extract_text() or "") for page in pdf.pages)
 
 
 
 
 
 
 
 
 
 
 
 
199
  except Exception as e:
200
- self._log_error(f"PDF processing error: {e}")
201
- return ""
202
-
203
- def _process_csv(self, file) -> str:
204
- """Convert CSV file to string via DataFrame conversion."""
205
  try:
206
- df = pd.read_csv(file)
207
- return df.to_csv(index=False)
 
 
 
 
 
 
208
  except Exception as e:
209
- self._log_error(f"CSV processing error: {e}")
210
- return ""
211
-
212
- def _process_api(self, api_url: str) -> str:
213
- """Fetch and return JSON data from the provided API URL."""
214
  try:
215
- response = requests.get(api_url, timeout=10)
 
 
 
 
 
 
216
  response.raise_for_status()
217
- return json.dumps(response.json(), indent=2)
 
218
  except Exception as e:
219
- self._log_error(f"API processing error: {e}")
220
- return ""
221
-
222
- def _process_database(self, config: Dict[str, str]) -> str:
223
  """
224
- Execute a database query using a provided configuration.
225
- Expects a dict with 'connection_string' and 'query' keys.
226
  """
 
227
  try:
228
- connection_string = config.get("connection_string", "")
229
- query = config.get("query", "")
230
- if not connection_string or not query:
231
- raise ValueError("Missing connection string or query.")
232
- engine = sqlalchemy.create_engine(connection_string)
233
- with engine.connect() as connection:
234
- df = pd.read_sql(query, connection)
235
- return df.to_csv(index=False)
236
- except Exception as e:
237
- self._log_error(f"Database processing error: {e}")
238
- return ""
239
-
240
- def _process_web(self, url: str) -> str:
241
- """Fetch and return webpage content using anti-bot headers."""
242
- try:
243
- response = requests.get(url, headers={"User-Agent": "Mozilla/5.0 (SyntheticBot/1.0)"}, timeout=10)
244
- response.raise_for_status()
245
- return response.text
246
- except Exception as e:
247
- self._log_error(f"Web extraction error: {e}")
248
- return ""
249
-
250
- # --- Logging & Diagnostics ---
251
- def _log_error(self, message: str) -> None:
252
- """Log errors centrally and update system metrics."""
253
- st.session_state["system_metrics"]["error_count"] += 1
254
- st.session_state["error_logs"].append(message)
255
- logger.error(message)
256
- if st.session_state.get("debug_mode"):
257
- st.error(f"[DEBUG] {message}")
258
-
259
- def health_check(self) -> Dict[str, Any]:
260
- """Return diagnostics including provider connectivity and system metrics."""
261
- connectivity = {provider: self._test_provider_connectivity(provider)
262
- for provider in self.available_providers}
263
- return {
264
- "providers_available": self.available_providers,
265
- "api_connectivity": connectivity,
266
- "system_metrics": st.session_state["system_metrics"],
267
- }
268
-
269
- def _test_provider_connectivity(self, provider: str) -> bool:
270
- """Test connectivity for a given provider."""
271
- try:
272
- client = self._get_client(provider)
273
  if provider == "HuggingFace":
274
- url = self.PROVIDER_CONFIG[provider]["base_url"]
275
- response = requests.get(url, headers=client["headers"], timeout=5)
276
- return response.status_code == 200
 
 
 
277
  else:
278
- client.models.list()
279
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  except Exception as e:
281
- self._log_error(f"Connectivity test failed for {provider}: {e}")
282
- return False
283
 
284
 
285
- # --- Streamlit UI Components ---
286
- def provider_config_ui(generator: SyntheticDataGenerator) -> None:
287
- """Provider configuration and health check UI."""
288
- with st.sidebar:
289
- st.header("⚙️ AI Engine Configuration")
290
- provider = st.selectbox(
291
- "AI Provider",
292
- generator.available_providers,
293
- index=generator.available_providers.index(st.session_state.get("active_provider", "OpenAI")),
294
- help="Select your preferred AI provider."
295
- )
296
- st.session_state["active_provider"] = provider
297
-
298
- api_key = st.text_input(
299
- f"{provider} API Key",
300
- type="password",
301
- value=st.session_state["api_keys"].get(provider, ""),
302
- help=f"Enter your API key for {provider}."
303
- )
304
- st.session_state["api_keys"][provider] = api_key
305
 
306
- model = st.selectbox(
307
- "Model",
308
- generator.PROVIDER_CONFIG[provider]["models"],
309
- help="Select the model to use."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  )
 
 
 
 
 
 
 
 
 
 
 
311
 
312
- if st.button("Run Health Check"):
313
- report = generator.health_check()
314
- st.json(report)
315
-
316
-
317
- def main() -> None:
318
- """Main Streamlit UI entry point."""
319
- st.set_page_config(page_title="Synthetic Data Factory Pro", page_icon="🏭", layout="wide")
320
- generator = SyntheticDataGenerator()
321
 
322
- st.title("🏭 Synthetic Data Factory Pro")
 
 
323
  st.markdown(
324
  """
325
- **World's Most Advanced Synthetic Data Generation Platform**
326
- *Multi-provider AI Engine | Enterprise Input Processors | Real-time Monitoring*
327
  """
328
  )
329
-
330
- provider_config_ui(generator)
331
-
332
- # --- Input Data Section ---
333
- st.subheader("Input Data")
334
- input_type = st.selectbox("Select Input Type", list(generator.input_processors.keys()))
335
- if input_type == "text":
336
- user_input = st.text_area("Enter your text:")
337
- elif input_type == "pdf":
338
- user_input = st.file_uploader("Upload a PDF file", type=["pdf"])
339
- elif input_type == "csv":
340
- user_input = st.file_uploader("Upload a CSV file", type=["csv"])
341
- elif input_type == "api":
342
- user_input = st.text_input("Enter API URL:")
343
- elif input_type == "database":
344
- user_input = st.text_area("Enter Database Config as JSON (with 'connection_string' and 'query'):")
345
- elif input_type == "web":
346
- user_input = st.text_input("Enter Website URL:")
347
-
348
- processed_input = ""
349
- if st.button("Process Input"):
350
- processor = generator.input_processors.get(input_type)
351
- if processor:
352
- if input_type in ("pdf", "csv"):
353
- processed_input = processor(user_input)
354
- elif input_type == "database":
355
- try:
356
- config = json.loads(user_input)
357
- processed_input = processor(config)
358
- except Exception as e:
359
- st.error("Invalid JSON configuration for database.")
360
- processed_input = ""
361
  else:
362
- processed_input = processor(user_input)
363
- st.text_area("Processed Input", value=processed_input, height=200)
364
-
365
- # --- Data Generation Section ---
366
- st.subheader("Generate Synthetic Data")
367
- prompt = st.text_area("Enter your generation prompt:")
368
- if st.button("Generate"):
369
- active_provider = st.session_state.get("active_provider", "OpenAI")
370
- model = st.selectbox("Select Generation Model", generator.PROVIDER_CONFIG[active_provider]["models"])
371
- try:
372
- result = generator.generate(active_provider, model, prompt)
373
- st.json(result)
374
- except Exception as e:
375
- st.error(f"Data generation failed: {e}")
376
 
377
 
378
  if __name__ == "__main__":
379
- main()
 
1
  import json
 
2
  import requests
3
  import streamlit as st
4
  import pdfplumber
5
  import pandas as pd
6
  import sqlalchemy
7
+ from typing import Any, Dict, List
 
8
 
9
+ # Provider clients ensure these libraries are installed
 
 
 
 
 
 
 
 
 
 
10
  try:
11
  from openai import OpenAI
12
  except ImportError:
 
17
  except ImportError:
18
  groq = None
19
 
20
+ # Hugging Face inference endpoint
21
+ HF_API_URL = "https://api-inference.huggingface.co/models/"
22
+ DEFAULT_TEMPERATURE = 0.1
23
+ GROQ_MODEL = "mixtral-8x7b-32768"
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
class QADataGenerator:
    """Q&A synthetic-data generator.

    Collects raw data from several input sources (text, PDF, CSV, HTTP API,
    database), aggregates it into one prompt, and asks the configured LLM
    provider to produce question/answer pairs as a JSON list of dicts with
    keys "question" and "answer".
    """

    # Prompt template sent to the LLM. Literal braces in the JSON example are
    # doubled ({{ }}) so that str.format() substitutes only the {data}
    # placeholder — unescaped braces would make .format() raise KeyError.
    # The example uses double quotes so it is valid JSON, matching the
    # json.loads() parsing done in _parse_response.
    custom_prompt_template = (
        "You are an expert in extracting question and answer pairs from documents. "
        "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
        "Each dictionary must have keys 'question' and 'answer'. "
        "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
        "Do not hallucinate. \n\n"
        "Example JSON Output:\n"
        '[{{"question": "What is the capital of France?", "answer": "Paris"}}, '
        '{{"question": "What is the highest mountain in the world?", "answer": "Mount Everest"}}, '
        '{{"question": "What is the chemical symbol for gold?", "answer": "Au"}}]\n\n'
        "Now, generate 3 Q&A pairs from this data:\n{data}"
    )

    def __init__(self) -> None:
        self._setup_providers()
        self._setup_input_handlers()
        self._initialize_session_state()

    def _setup_providers(self) -> None:
        """Configure available LLM providers and their client factories.

        Each "client" entry is a callable taking the API key. Factories for
        optional SDKs return None when the import failed at module load, so
        callers must check for None before use.
        """
        self.providers: Dict[str, Dict[str, Any]] = {
            "Deepseek": {
                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
                "models": ["deepseek-chat"],
            },
            "OpenAI": {
                "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
                "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
            },
            "Groq": {
                "client": lambda key: groq.Groq(api_key=key) if groq else None,
                "models": [GROQ_MODEL],
            },
            "HuggingFace": {
                # Plain HTTP headers; requests are issued manually in
                # _huggingface_inference rather than through an SDK client.
                "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                "models": ["gpt2", "llama-2"],
            },
        }

    def _setup_input_handlers(self) -> None:
        """Register one handler per supported input source type."""
        self.input_handlers: Dict[str, Any] = {
            "text": self.handle_text,
            "pdf": self.handle_pdf,
            "csv": self.handle_csv,
            "api": self.handle_api,
            "db": self.handle_db,
        }

    def _initialize_session_state(self) -> None:
        """Seed Streamlit session state with defaults (idempotent across reruns)."""
        defaults = {
            "config": {
                "provider": "OpenAI",
                "model": "gpt-4-turbo",
                "temperature": DEFAULT_TEMPERATURE,
            },
            "api_key": "",
            "inputs": [],       # list of {"data": ..., "source": ...} records
            "qa_pairs": "",     # generated Q&A pairs output
            "error_logs": [],   # accumulated error messages
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def log_error(self, message: str) -> None:
        """Record an error in session state and surface it in the UI."""
        st.session_state.error_logs.append(message)
        st.error(message)

    # ----- Input Handlers -----
    def handle_text(self, text: str) -> Dict[str, Any]:
        """Wrap raw text as an input record."""
        return {"data": text, "source": "text"}

    def handle_pdf(self, file) -> Dict[str, Any]:
        """Extract text from every page of a PDF upload; empty data on failure."""
        try:
            with pdfplumber.open(file) as pdf:
                # extract_text() may return None for image-only pages.
                full_text = "".join((page.extract_text() or "") + "\n" for page in pdf.pages)
            return {"data": full_text, "source": "pdf"}
        except Exception as e:
            self.log_error(f"PDF Processing Error: {e}")
            return {"data": "", "source": "pdf"}

    def handle_csv(self, file) -> Dict[str, Any]:
        """Load a CSV upload and serialize it as a JSON records string."""
        try:
            df = pd.read_csv(file)
            return {"data": df.to_json(orient="records"), "source": "csv"}
        except Exception as e:
            self.log_error(f"CSV Processing Error: {e}")
            return {"data": "", "source": "csv"}

    def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
        """GET JSON from config["url"] with optional config["headers"]; empty on failure."""
        try:
            response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
            response.raise_for_status()
            return {"data": json.dumps(response.json()), "source": "api"}
        except Exception as e:
            self.log_error(f"API Processing Error: {e}")
            return {"data": "", "source": "api"}

    def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Run config["query"] against config["connection"] and JSON-serialize rows.

        Uses Result.mappings() so each row converts cleanly to a dict on
        SQLAlchemy 1.4+/2.0 (plain dict(row) on a Row object fails there).
        """
        try:
            engine = sqlalchemy.create_engine(config["connection"])
            with engine.connect() as conn:
                result = conn.execute(sqlalchemy.text(config["query"]))
                rows = [dict(mapping) for mapping in result.mappings()]
            return {"data": json.dumps(rows), "source": "db"}
        except Exception as e:
            self.log_error(f"Database Processing Error: {e}")
            return {"data": "", "source": "db"}

    def aggregate_inputs(self) -> str:
        """Concatenate all stored inputs, tagging each chunk with its source."""
        sections = []
        for item in st.session_state.inputs:
            sections.append(f"Source: {item.get('source', 'unknown')}\n" + item.get("data", "") + "\n\n")
        return "".join(sections).strip()

    def build_prompt(self) -> str:
        """Build the complete LLM prompt from the template and aggregated inputs."""
        prompt = self.custom_prompt_template.format(data=self.aggregate_inputs())
        st.write("### Built Prompt")
        st.write(prompt)
        return prompt

    def generate_qa_pairs(self) -> bool:
        """Generate Q&A pairs by sending the built prompt to the selected provider.

        Returns True on success (parsed pairs stored in st.session_state.qa_pairs),
        False on any configuration or inference failure (reported via log_error).
        """
        api_key = st.session_state.api_key
        if not api_key:
            self.log_error("API key is missing!")
            return False

        provider_name = st.session_state.config["provider"]
        provider_cfg = self.providers.get(provider_name)
        if not provider_cfg:
            self.log_error(f"Provider {provider_name} is not configured.")
            return False

        client = provider_cfg["client"](api_key)
        if client is None:
            # The provider's SDK failed to import at module load time.
            self.log_error(f"Client library for {provider_name} is not available.")
            return False

        model = st.session_state.config["model"]
        temperature = st.session_state.config["temperature"]
        prompt = self.build_prompt()

        st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
        try:
            if provider_name == "HuggingFace":
                response = self._huggingface_inference(client, prompt, model)
            else:
                response = self._standard_inference(client, prompt, model, temperature)

            st.write("### Raw API Response")
            st.write(response)

            qa_pairs = self._parse_response(response, provider_name)
            st.write("### Parsed Q&A Pairs")
            st.write(qa_pairs)

            st.session_state.qa_pairs = qa_pairs
            return True
        except Exception as e:
            self.log_error(f"Generation failed: {e}")
            return False

    def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
        """Call an OpenAI-compatible chat-completions endpoint; None on error."""
        try:
            st.write("Sending prompt via standard inference...")
            result = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
            st.write("Standard inference result received.")
            return result
        except Exception as e:
            self.log_error(f"Standard Inference Error: {e}")
            return None

    def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
        """POST the prompt to the Hugging Face Inference API; None on error."""
        try:
            st.write("Sending prompt to HuggingFace API...")
            response = requests.post(
                HF_API_URL + model,
                headers=client["headers"],
                json={"inputs": prompt},
                timeout=30,
            )
            response.raise_for_status()
            st.write("HuggingFace API response received.")
            return response.json()
        except Exception as e:
            self.log_error(f"HuggingFace Inference Error: {e}")
            return None

    def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
        """Parse the provider response into a list of Q&A dicts (empty on failure).

        Expects the generated text to be a JSON list of {"question", "answer"}
        dictionaries, as requested by custom_prompt_template.
        """
        st.write("Parsing response for provider:", provider)
        try:
            if provider == "HuggingFace":
                # HF text-generation responses look like [{"generated_text": ...}].
                if isinstance(response, list) and response and "generated_text" in response[0]:
                    raw_text = response[0]["generated_text"]
                else:
                    self.log_error("Unexpected HuggingFace response format.")
                    return []
            else:
                # OpenAI-style SDK object: response.choices[0].message.content
                if response and hasattr(response, "choices") and response.choices:
                    raw_text = response.choices[0].message.content
                else:
                    self.log_error("Unexpected response format from provider.")
                    return []

            try:
                qa_list = json.loads(raw_text)
            except json.JSONDecodeError as e:
                self.log_error(f"JSON Parsing Error: {e}. Raw output: {raw_text}")
                return []
            if isinstance(qa_list, list):
                return qa_list
            self.log_error("Parsed output is not a list.")
            return []
        except Exception as e:
            self.log_error(f"Response Parsing Error: {e}")
            return []
275
 
276
 
277
+ # ============ UI Components ============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
def config_ui(generator: QADataGenerator):
    """Render sidebar widgets for provider, model, temperature, and API key."""
    with st.sidebar:
        st.header("Configuration")

        chosen_provider = st.selectbox("Select Provider", list(generator.providers.keys()))
        st.session_state.config["provider"] = chosen_provider
        provider_settings = generator.providers[chosen_provider]

        chosen_model = st.selectbox("Select Model", provider_settings["models"])
        st.session_state.config["model"] = chosen_model

        chosen_temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
        st.session_state.config["temperature"] = chosen_temperature

        entered_key = st.text_input(f"{chosen_provider} API Key", type="password")
        st.session_state.api_key = entered_key
295
+
296
def input_ui(generator: QADataGenerator):
    """Render tabbed widgets for adding input data from the supported sources."""
    st.subheader("Input Data Sources")
    text_tab, pdf_tab, csv_tab, api_tab, db_tab = st.tabs(
        ["Text", "PDF", "CSV", "API", "Database"]
    )

    with text_tab:
        raw_text = st.text_area("Enter text input", height=150)
        if st.button("Add Text Input", key="text_input"):
            if not raw_text.strip():
                st.warning("Empty text input.")
            else:
                st.session_state.inputs.append(generator.handle_text(raw_text))
                st.success("Text input added!")

    with pdf_tab:
        # NOTE(review): while a file stays uploaded, this branch runs on every
        # Streamlit rerun, so the same PDF may be appended repeatedly — confirm
        # whether that duplication is intended.
        uploaded_pdf = st.file_uploader("Upload PDF", type=["pdf"])
        if uploaded_pdf is not None:
            st.session_state.inputs.append(generator.handle_pdf(uploaded_pdf))
            st.success("PDF input added!")

    with csv_tab:
        uploaded_csv = st.file_uploader("Upload CSV", type=["csv"])
        if uploaded_csv is not None:
            st.session_state.inputs.append(generator.handle_csv(uploaded_csv))
            st.success("CSV input added!")

    with api_tab:
        endpoint_url = st.text_input("API Endpoint URL")
        header_text = st.text_area("API Headers (JSON format, optional)", height=100)
        if st.button("Add API Input", key="api_input"):
            parsed_headers = {}
            try:
                if header_text:
                    parsed_headers = json.loads(header_text)
            except Exception as e:
                # Best-effort: invalid headers are logged, request still queued.
                generator.log_error(f"Invalid JSON for API Headers: {e}")
            st.session_state.inputs.append(
                generator.handle_api({"url": endpoint_url, "headers": parsed_headers})
            )
            st.success("API input added!")

    with db_tab:
        connection_string = st.text_input("Database Connection String")
        query_text = st.text_area("Database Query", height=100)
        if st.button("Add Database Input", key="db_input"):
            st.session_state.inputs.append(
                generator.handle_db({"connection": connection_string, "query": query_text})
            )
            st.success("Database input added!")
341
+
342
def output_ui(generator: QADataGenerator):
    """Show generated Q&A pairs, with a JSON download button when available."""
    st.subheader("Q&A Pairs Output")
    if not st.session_state.qa_pairs:
        st.info("No Q&A pairs generated yet.")
        return
    st.write("### Generated Q&A Pairs")
    st.write(st.session_state.qa_pairs)
    st.download_button(
        "Download Output",
        json.dumps(st.session_state.qa_pairs, indent=2),
        file_name="qa_pairs.json",
        mime="application/json"
    )
356
+
357
def logs_ui():
    """Render collected error logs inside a collapsed expander."""
    with st.expander("Error Logs & Debug Info", expanded=False):
        recorded = st.session_state.error_logs
        if not recorded:
            st.write("No logs yet.")
        else:
            for entry in recorded:
                st.write(entry)
365
 
 
 
 
 
 
 
 
 
 
366
 
367
def main():
    """Entry point: wire up configuration, input, generation, output, and logs."""
    st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
    st.title("Advanced Q&A Synthetic Generator")
    st.markdown(
        """
        Welcome to the Advanced Q&A Synthetic Generator. This tool extracts and generates question-answer pairs
        from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
        """
    )

    # Build the generator, then render each page section in order.
    generator = QADataGenerator()
    config_ui(generator)

    st.header("1. Input Data")
    input_ui(generator)
    if st.button("Clear All Inputs"):
        st.session_state.inputs = []
        st.success("All inputs have been cleared!")

    st.header("2. Generate Q&A Pairs")
    if st.button("Generate Q&A Pairs", key="generate_qa"):
        with st.spinner("Generating Q&A pairs..."):
            generated_ok = generator.generate_qa_pairs()
        if generated_ok:
            st.success("Q&A pairs generated successfully!")
        else:
            st.error("Q&A generation failed. Check logs for details.")

    st.header("3. Output")
    output_ui(generator)

    st.header("4. Logs & Debug Information")
    logs_ui()
 
 
 
 
 
 
 
400
 
401
 
402
  if __name__ == "__main__":
403
+ main()