mgbam committed on
Commit
6bba837
·
verified ·
1 Parent(s): e9a68df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -148
app.py CHANGED
@@ -1,33 +1,36 @@
 
 
1
  import streamlit as st
2
  import pdfplumber
3
  import pandas as pd
4
- import requests
5
- import json
6
  from PIL import Image
 
 
 
7
  from openai import OpenAI
8
  import google.generative_ai as genai
9
  import groq
10
- import sqlalchemy
11
- from typing import Dict, Any
12
 
13
  # --- CONSTANTS ---
14
  HF_API_URL = "https://api-inference.huggingface.co/models/"
15
  DEFAULT_TEMPERATURE = 0.1
16
- MODEL = "mixtral-8x7b-32768" # Groq model
17
- API_HEADERS_HEIGHT = 70 # Minimum height for st.text_area
18
 
19
 
20
  class SyntheticDataGenerator:
21
- """Generates synthetic Q&A data from various input sources using LLMs."""
22
-
23
- def __init__(self):
 
24
  self._setup_providers()
25
  self._setup_input_handlers()
26
  self._initialize_session_state()
27
 
28
- def _setup_providers(self):
29
- """Defines the available LLM providers and their configurations."""
30
- self.providers = {
31
  "Deepseek": {
32
  "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
33
  "models": ["deepseek-chat"],
@@ -38,7 +41,7 @@ class SyntheticDataGenerator:
38
  },
39
  "Groq": {
40
  "client": lambda key: groq.Groq(api_key=key),
41
- "models": [MODEL],
42
  },
43
  "HuggingFace": {
44
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
@@ -50,9 +53,9 @@ class SyntheticDataGenerator:
50
  },
51
  }
52
 
53
- def _setup_input_handlers(self):
54
- """Defines handlers for different input data types."""
55
- self.input_handlers = {
56
  "pdf": self.handle_pdf,
57
  "text": self.handle_text,
58
  "csv": self.handle_csv,
@@ -60,21 +63,25 @@ class SyntheticDataGenerator:
60
  "db": self.handle_db,
61
  }
62
 
63
- def _initialize_session_state(self):
64
- """Initializes Streamlit session state variables."""
65
  session_defaults = {
66
  "inputs": [],
67
  "qa_data": [],
68
  "processing": {"stage": "idle", "progress": 0, "errors": []},
69
- "config": {"provider": "Groq", "model": MODEL, "temperature": DEFAULT_TEMPERATURE},
70
- "api_key": "", # Explicitly initialize api_key in session state
 
 
 
 
71
  }
72
  for key, value in session_defaults.items():
73
  if key not in st.session_state:
74
  st.session_state[key] = value
75
 
76
- def _configure_google_genai(self, api_key: str):
77
- """Configures the Google Generative AI client."""
78
  try:
79
  genai.configure(api_key=api_key)
80
  return genai.GenerativeModel
@@ -83,50 +90,63 @@ class SyntheticDataGenerator:
83
  return None
84
 
85
  # --- INPUT HANDLERS ---
86
- def handle_pdf(self, file):
87
- """Extracts text and images from a PDF file."""
 
 
 
 
 
88
  try:
89
  with pdfplumber.open(file) as pdf:
90
  extracted_data = []
91
  for i, page in enumerate(pdf.pages):
92
  page_text = page.extract_text() or ""
93
  page_images = self.process_images(page)
94
- extracted_data.append(
95
- {"text": page_text, "images": page_images, "meta": {"type": "pdf", "page": i + 1}}
96
- )
 
 
97
  return extracted_data
98
  except Exception as e:
99
- self._log_error(f"PDF Error: {str(e)}")
100
  return []
101
 
102
- def handle_text(self, text):
103
- """Handles manual text input."""
104
  return [{"text": text, "meta": {"type": "domain", "source": "manual"}}]
105
 
106
- def handle_csv(self, file):
107
- """Reads a CSV file and prepares data for Q&A generation."""
108
  try:
109
  df = pd.read_csv(file)
110
  return [
111
- {"text": "\n".join([f"{col}: {row[col]}" for col in df.columns]), "meta": {"type": "csv", "columns": list(df.columns)}}
 
 
 
112
  for _, row in df.iterrows()
113
  ]
114
  except Exception as e:
115
- self._log_error(f"CSV Error: {str(e)}")
116
  return []
117
 
118
- def handle_api(self, config):
119
- """Fetches data from an API endpoint."""
120
  try:
121
- response = requests.get(config["url"], headers=config["headers"], timeout=10) # Add timeout
122
- response.raise_for_status() # Raise HTTPError for bad responses
123
- return [{"text": json.dumps(response.json()), "meta": {"type": "api", "endpoint": config["url"]}}]
 
 
 
124
  except requests.exceptions.RequestException as e:
125
- self._log_error(f"API Error: {str(e)}")
126
  return []
127
 
128
- def handle_db(self, config):
129
- """Connects to a database and executes a query."""
130
  try:
131
  engine = sqlalchemy.create_engine(config["connection"])
132
  with engine.connect() as conn:
@@ -139,11 +159,11 @@ class SyntheticDataGenerator:
139
  for row in result
140
  ]
141
  except Exception as e:
142
- self._log_error(f"DB Error: {str(e)}")
143
  return []
144
 
145
- def process_images(self, page):
146
- """Extracts and processes images from a PDF page."""
147
  images = []
148
  for img in page.images:
149
  try:
@@ -151,69 +171,70 @@ class SyntheticDataGenerator:
151
  width = int(stream.get("Width", 0))
152
  height = int(stream.get("Height", 0))
153
  image_data = stream.get_data()
154
-
155
  if width > 0 and height > 0 and image_data:
156
  try:
157
  image = Image.frombytes("RGB", (width, height), image_data)
158
  images.append({"data": image, "meta": {"dims": (width, height)}})
159
  except Exception as e:
160
- self._log_error(f"Image Creation Error: {str(e)}. Width: {width}, Height: {height}")
161
  else:
162
- self._log_error(
163
- f"Image Error: Insufficient data or invalid dimensions (w={width}, h={height})"
164
- )
165
-
166
  except Exception as e:
167
- self._log_error(f"Image Extraction Error: {str(e)}")
168
  return images
169
 
170
  # --- LLM INFERENCE ---
171
  def generate(self, api_key: str) -> bool:
172
- """Generates Q&A pairs using the selected LLM provider."""
173
- try:
174
- if not api_key:
175
- st.error("API Key cannot be empty.")
176
- return False
 
 
 
 
177
 
178
- provider_cfg = self.providers[st.session_state.config["provider"]]
 
 
179
  client_initializer = provider_cfg["client"]
180
 
181
- if st.session_state.config["provider"] == "Google":
 
182
  client = client_initializer(api_key)
183
  if not client:
184
- return False # Google config failed
185
  else:
186
  client = client_initializer(api_key)
187
 
188
  for i, input_data in enumerate(st.session_state.inputs):
189
  st.session_state.processing["progress"] = (i + 1) / len(st.session_state.inputs)
190
-
191
- # Debugging: Display input data
192
  st.write("--- Input Data ---")
193
  st.write(input_data["text"])
194
 
195
- if st.session_state.config["provider"] == "HuggingFace":
196
  response = self._huggingface_inference(client, input_data)
197
- elif st.session_state.config["provider"] == "Google":
198
  response = self._google_inference(client, input_data)
199
  else:
200
  response = self._standard_inference(client, input_data)
201
 
202
  if response:
203
- # Debugging: Display raw response
204
  st.write("--- Raw Response ---")
205
  st.write(response)
206
-
207
- st.session_state.qa_data.extend(self._parse_response(response, st.session_state.config["provider"]))
 
208
 
209
  return True
210
 
211
  except Exception as e:
212
- self._log_error(f"Generation Error: {str(e)}")
213
  return False
214
 
215
- def _standard_inference(self, client, input_data):
216
- """Performs inference using OpenAI-compatible API."""
217
  try:
218
  return client.chat.completions.create(
219
  model=st.session_state.config["model"],
@@ -224,8 +245,8 @@ class SyntheticDataGenerator:
224
  self._log_error(f"OpenAI Inference Error: {e}")
225
  return None
226
 
227
- def _huggingface_inference(self, client, input_data):
228
- """Performs inference using Hugging Face Inference API."""
229
  try:
230
  response = requests.post(
231
  HF_API_URL + st.session_state.config["model"],
@@ -238,13 +259,15 @@ class SyntheticDataGenerator:
238
  self._log_error(f"Hugging Face Inference Error: {e}")
239
  return None
240
 
241
- def _google_inference(self, client, input_data):
242
- """Performs inference using Google Generative AI API."""
243
  try:
244
  model = client(st.session_state.config["model"])
245
  response = model.generate_content(
246
  self._build_prompt(input_data),
247
- generation_config=genai.types.GenerationConfig(temperature=st.session_state.config["temperature"]),
 
 
248
  )
249
  return response
250
  except Exception as e:
@@ -252,172 +275,175 @@ class SyntheticDataGenerator:
252
  return None
253
 
254
  # --- PROMPT ENGINEERING ---
255
- def _build_prompt(self, input_data):
256
- """Builds the prompt for the LLM based on the input data type."""
257
- base = (
 
 
 
 
258
  "You are an expert in extracting question and answer pairs from documents. "
259
  "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries.\n"
260
  "Each dictionary must have the keys 'question' and 'answer'.\n"
261
- "The 'question' should be clear and concise, and the 'answer' should directly answer the question using only "
262
- "information from the data. Do not hallucinate or invent information.\n"
263
- "Answer from the exact same document, not outside from the document\n"
264
  "Example JSON Output:\n"
265
  '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}, '
266
  '{"question": "What is the highest mountain in the world?", "answer": "The highest mountain in the world is Mount Everest."}, '
267
  '{"question": "What is the chemical symbol for gold?", "answer": "The chemical symbol for gold is Au."}]\n'
268
  "Now, generate 3 Q&A pairs from this data:\n"
269
  )
270
-
271
- if input_data["meta"]["type"] == "csv":
272
- return base + "Data:\n" + input_data["text"]
273
- elif input_data["meta"]["type"] == "api":
274
- return base + "API response:\n" + input_data["text"]
275
- return base + input_data["text"]
276
 
277
  # --- RESPONSE PARSING ---
278
- def _parse_response(self, response: Any, provider: str) -> list[dict[str, str]]:
279
- """Parses the LLM response into a list of Q&A pairs."""
 
 
 
 
280
  try:
281
  response_text = ""
282
-
283
  if provider == "HuggingFace":
284
- response_text = response[0]["generated_text"]
285
- return response_text
286
  elif provider == "Google":
287
  response_text = response.text.strip()
288
-
289
  else: # OpenAI, Deepseek, Groq
290
  if not response or not response.choices or not response.choices[0].message.content:
291
  self._log_error("Empty or malformed response from LLM.")
292
  return []
293
-
294
  response_text = response.choices[0].message.content
295
 
296
  try:
297
  json_output = json.loads(response_text)
 
 
 
298
 
299
- if isinstance(json_output, list):
300
- qa_pairs = json_output
301
- elif isinstance(json_output, dict) and "questionList" in json_output:
302
- qa_pairs = json_output["questionList"]
303
- else:
304
- self._log_error(f"Unexpected JSON structure: {response_text}")
305
- return []
306
-
307
- if not isinstance(qa_pairs, list):
308
- self._log_error(f"Expected a list of QA pairs, but got: {type(qa_pairs)}")
309
- return []
310
 
311
- for pair in qa_pairs:
312
- if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
313
- self._log_error(f"Invalid QA pair structure: {pair}")
314
- return []
315
 
316
- return qa_pairs
 
 
 
317
 
318
- except json.JSONDecodeError as e:
319
- self._log_error(f"JSON Parse Error: {e}. Raw Response: {response_text}")
320
- return []
321
 
322
  except Exception as e:
323
  self._log_error(f"Parse Error: {e}. Raw Response: {response}")
324
  return []
325
 
326
- def _log_error(self, message):
327
- """Logs an error message to Streamlit session state and displays it."""
328
  st.session_state.processing["errors"].append(message)
329
  st.error(message)
330
 
331
 
332
  # --- STREAMLIT UI COMPONENTS ---
333
- def input_sidebar(gen: SyntheticDataGenerator) -> str:
334
- """Creates the input sidebar in the Streamlit UI."""
335
  with st.sidebar:
336
  st.header("⚙️ Configuration")
337
-
338
- provider = st.selectbox("Provider", list(gen.providers.keys()))
339
- st.session_state.config["provider"] = provider # Update session state immediately
340
- provider_cfg = gen.providers[provider]
341
 
342
  api_key = st.text_input(f"{provider} API Key", type="password")
343
  st.session_state["api_key"] = api_key
344
 
345
  model = st.selectbox("Model", provider_cfg["models"])
346
- st.session_state.config["model"] = model # Update model selection
347
 
348
- temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
349
- st.session_state.config["temperature"] = temp # Update temperature
350
 
351
- # Input Source Selection
352
  st.header("🔗 Data Sources")
353
- input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))
354
 
355
  if input_type == "text":
356
  domain_input = st.text_area("Domain Knowledge", height=150)
357
  if st.button("Add Domain Input"):
358
- st.session_state.inputs.append(gen.input_handlers["text"](domain_input)[0])
359
 
360
  elif input_type == "csv":
361
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
362
  if csv_file:
363
- st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))
364
 
365
  elif input_type == "api":
366
  api_url = st.text_input("API Endpoint")
367
  api_headers = st.text_area("API Headers (JSON format, optional)", height=API_HEADERS_HEIGHT)
368
  headers = {}
369
- try:
370
- if api_headers:
371
  headers = json.loads(api_headers)
372
- except json.JSONDecodeError:
373
- st.error("Invalid JSON format for API headers.")
374
  if st.button("Add API Input"):
375
- st.session_state.inputs.extend(gen.input_handlers["api"]({"url": api_url, "headers": headers}))
376
 
377
  elif input_type == "db":
378
  db_connection = st.text_input("Database Connection String")
379
  db_query = st.text_area("Database Query")
380
  db_table = st.text_input("Table Name (optional)")
381
  if st.button("Add DB Input"):
382
- st.session_state.inputs.extend(
383
- gen.input_handlers["db"]({"connection": db_connection, "query": db_query, "table": db_table})
384
- )
 
 
385
 
386
- return api_key
387
 
388
 
389
- def main_display(gen: SyntheticDataGenerator):
390
- """Creates the main display area in the Streamlit UI."""
391
  st.title("🚀 Enterprise Synthetic Data Factory")
392
 
393
  col1, col2 = st.columns([3, 1])
394
  with col1:
395
  pdf_file = st.file_uploader("Upload Document", type=["pdf"])
396
  if pdf_file:
397
- st.session_state.inputs.extend(gen.input_handlers["pdf"](pdf_file))
398
 
399
  with col2:
400
  if st.button("Start Generation"):
401
- with st.status("Processing..."):
402
  if not st.session_state["api_key"]:
403
  st.error("Please provide an API Key.")
404
  else:
405
- gen.generate(st.session_state["api_key"])
406
 
407
  if st.session_state.qa_data:
408
  st.header("Generated Data")
409
  df = pd.DataFrame(st.session_state.qa_data)
410
  st.dataframe(df)
411
-
412
  st.download_button("Export CSV", df.to_csv(index=False), "synthetic_data.csv")
413
 
414
 
415
- def main():
416
  """Main function to run the Streamlit application."""
417
- gen = SyntheticDataGenerator()
418
- api_key = input_sidebar(gen)
419
- main_display(gen)
420
 
421
 
422
  if __name__ == "__main__":
423
- main()
 
1
+ import json
2
+ import requests
3
  import streamlit as st
4
  import pdfplumber
5
  import pandas as pd
6
+ import sqlalchemy
 
7
  from PIL import Image
8
+ from typing import Any, Dict, List
9
+
10
+ # Provider clients
11
  from openai import OpenAI
12
  import google.generative_ai as genai
13
  import groq
 
 
14
 
15
  # --- CONSTANTS ---
16
  HF_API_URL = "https://api-inference.huggingface.co/models/"
17
  DEFAULT_TEMPERATURE = 0.1
18
+ GROQ_MODEL = "mixtral-8x7b-32768" # Groq model
19
+ API_HEADERS_HEIGHT = 70 # Height for the API headers text area
20
 
21
 
22
  class SyntheticDataGenerator:
23
+ """
24
+ Generates synthetic Q&A data from various input sources using multiple LLM providers.
25
+ """
26
+ def __init__(self) -> None:
27
  self._setup_providers()
28
  self._setup_input_handlers()
29
  self._initialize_session_state()
30
 
31
+ def _setup_providers(self) -> None:
32
+ """Configure available LLM providers and their client initializations."""
33
+ self.providers: Dict[str, Dict[str, Any]] = {
34
  "Deepseek": {
35
  "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
36
  "models": ["deepseek-chat"],
 
41
  },
42
  "Groq": {
43
  "client": lambda key: groq.Groq(api_key=key),
44
+ "models": [GROQ_MODEL],
45
  },
46
  "HuggingFace": {
47
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
 
53
  },
54
  }
55
 
56
+ def _setup_input_handlers(self) -> None:
57
+ """Define handlers for different input data types."""
58
+ self.input_handlers: Dict[str, Any] = {
59
  "pdf": self.handle_pdf,
60
  "text": self.handle_text,
61
  "csv": self.handle_csv,
 
63
  "db": self.handle_db,
64
  }
65
 
66
+ def _initialize_session_state(self) -> None:
67
+ """Initialize Streamlit session state with default configurations."""
68
  session_defaults = {
69
  "inputs": [],
70
  "qa_data": [],
71
  "processing": {"stage": "idle", "progress": 0, "errors": []},
72
+ "config": {
73
+ "provider": "Groq",
74
+ "model": GROQ_MODEL,
75
+ "temperature": DEFAULT_TEMPERATURE,
76
+ },
77
+ "api_key": "", # Explicitly initialize the API key
78
  }
79
  for key, value in session_defaults.items():
80
  if key not in st.session_state:
81
  st.session_state[key] = value
82
 
83
+ def _configure_google_genai(self, api_key: str) -> Any:
84
+ """Configure and return the Google Generative AI client."""
85
  try:
86
  genai.configure(api_key=api_key)
87
  return genai.GenerativeModel
 
90
  return None
91
 
92
  # --- INPUT HANDLERS ---
93
+ def handle_pdf(self, file) -> List[Dict[str, Any]]:
94
+ """
95
+ Extract text and images from a PDF file.
96
+
97
+ Returns:
98
+ A list of dictionaries containing text, images, and metadata.
99
+ """
100
  try:
101
  with pdfplumber.open(file) as pdf:
102
  extracted_data = []
103
  for i, page in enumerate(pdf.pages):
104
  page_text = page.extract_text() or ""
105
  page_images = self.process_images(page)
106
+ extracted_data.append({
107
+ "text": page_text,
108
+ "images": page_images,
109
+ "meta": {"type": "pdf", "page": i + 1},
110
+ })
111
  return extracted_data
112
  except Exception as e:
113
+ self._log_error(f"PDF Error: {e}")
114
  return []
115
 
116
+ def handle_text(self, text: str) -> List[Dict[str, Any]]:
117
+ """Handle manual text input."""
118
  return [{"text": text, "meta": {"type": "domain", "source": "manual"}}]
119
 
120
+ def handle_csv(self, file) -> List[Dict[str, Any]]:
121
+ """Process a CSV file and format the data for Q&A generation."""
122
  try:
123
  df = pd.read_csv(file)
124
  return [
125
+ {
126
+ "text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
127
+ "meta": {"type": "csv", "columns": list(df.columns)},
128
+ }
129
  for _, row in df.iterrows()
130
  ]
131
  except Exception as e:
132
+ self._log_error(f"CSV Error: {e}")
133
  return []
134
 
135
+ def handle_api(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
136
+ """Fetch data from an API endpoint and format it for processing."""
137
  try:
138
+ response = requests.get(config["url"], headers=config["headers"], timeout=10)
139
+ response.raise_for_status()
140
+ return [{
141
+ "text": json.dumps(response.json()),
142
+ "meta": {"type": "api", "endpoint": config["url"]},
143
+ }]
144
  except requests.exceptions.RequestException as e:
145
+ self._log_error(f"API Error: {e}")
146
  return []
147
 
148
+ def handle_db(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
149
+ """Connect to a database, execute a query, and format the results."""
150
  try:
151
  engine = sqlalchemy.create_engine(config["connection"])
152
  with engine.connect() as conn:
 
159
  for row in result
160
  ]
161
  except Exception as e:
162
+ self._log_error(f"DB Error: {e}")
163
  return []
164
 
165
+ def process_images(self, page) -> List[Dict[str, Any]]:
166
+ """Extract and process images from a PDF page."""
167
  images = []
168
  for img in page.images:
169
  try:
 
171
  width = int(stream.get("Width", 0))
172
  height = int(stream.get("Height", 0))
173
  image_data = stream.get_data()
 
174
  if width > 0 and height > 0 and image_data:
175
  try:
176
  image = Image.frombytes("RGB", (width, height), image_data)
177
  images.append({"data": image, "meta": {"dims": (width, height)}})
178
  except Exception as e:
179
+ self._log_error(f"Image Creation Error: {e} (Width: {width}, Height: {height})")
180
  else:
181
+ self._log_error(f"Image Error: Insufficient data or invalid dimensions (w={width}, h={height})")
 
 
 
182
  except Exception as e:
183
+ self._log_error(f"Image Extraction Error: {e}")
184
  return images
185
 
186
  # --- LLM INFERENCE ---
187
  def generate(self, api_key: str) -> bool:
188
+ """
189
+ Generate Q&A pairs using the selected LLM provider.
190
+
191
+ Iterates over all the input data, calls the appropriate inference method,
192
+ and aggregates the generated Q&A pairs into session state.
193
+ """
194
+ if not api_key:
195
+ st.error("API Key cannot be empty.")
196
+ return False
197
 
198
+ try:
199
+ provider_name = st.session_state.config["provider"]
200
+ provider_cfg = self.providers[provider_name]
201
  client_initializer = provider_cfg["client"]
202
 
203
+ # Initialize the client
204
+ if provider_name == "Google":
205
  client = client_initializer(api_key)
206
  if not client:
207
+ return False
208
  else:
209
  client = client_initializer(api_key)
210
 
211
  for i, input_data in enumerate(st.session_state.inputs):
212
  st.session_state.processing["progress"] = (i + 1) / len(st.session_state.inputs)
 
 
213
  st.write("--- Input Data ---")
214
  st.write(input_data["text"])
215
 
216
+ if provider_name == "HuggingFace":
217
  response = self._huggingface_inference(client, input_data)
218
+ elif provider_name == "Google":
219
  response = self._google_inference(client, input_data)
220
  else:
221
  response = self._standard_inference(client, input_data)
222
 
223
  if response:
 
224
  st.write("--- Raw Response ---")
225
  st.write(response)
226
+ parsed_response = self._parse_response(response, provider_name)
227
+ if parsed_response:
228
+ st.session_state.qa_data.extend(parsed_response)
229
 
230
  return True
231
 
232
  except Exception as e:
233
+ self._log_error(f"Generation Error: {e}")
234
  return False
235
 
236
+ def _standard_inference(self, client: Any, input_data: Dict[str, Any]) -> Any:
237
+ """Perform inference using an OpenAI-compatible API."""
238
  try:
239
  return client.chat.completions.create(
240
  model=st.session_state.config["model"],
 
245
  self._log_error(f"OpenAI Inference Error: {e}")
246
  return None
247
 
248
+ def _huggingface_inference(self, client: Dict[str, Any], input_data: Dict[str, Any]) -> Any:
249
+ """Perform inference using the Hugging Face Inference API."""
250
  try:
251
  response = requests.post(
252
  HF_API_URL + st.session_state.config["model"],
 
259
  self._log_error(f"Hugging Face Inference Error: {e}")
260
  return None
261
 
262
+ def _google_inference(self, client: Any, input_data: Dict[str, Any]) -> Any:
263
+ """Perform inference using the Google Generative AI API."""
264
  try:
265
  model = client(st.session_state.config["model"])
266
  response = model.generate_content(
267
  self._build_prompt(input_data),
268
+ generation_config=genai.types.GenerationConfig(
269
+ temperature=st.session_state.config["temperature"]
270
+ ),
271
  )
272
  return response
273
  except Exception as e:
 
275
  return None
276
 
277
  # --- PROMPT ENGINEERING ---
278
+ def _build_prompt(self, input_data: Dict[str, Any]) -> str:
279
+ """
280
+ Build the prompt for the LLM based on the input data.
281
+
282
+ The prompt instructs the LLM to extract 3 Q&A pairs in JSON format.
283
+ """
284
+ base_prompt = (
285
  "You are an expert in extracting question and answer pairs from documents. "
286
  "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries.\n"
287
  "Each dictionary must have the keys 'question' and 'answer'.\n"
288
+ "The 'question' should be clear and concise, and the 'answer' should directly answer the question "
289
+ "using only information from the provided data. Do not hallucinate or invent information.\n"
290
+ "Answer using the exact information from the document, not external knowledge.\n"
291
  "Example JSON Output:\n"
292
  '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}, '
293
  '{"question": "What is the highest mountain in the world?", "answer": "The highest mountain in the world is Mount Everest."}, '
294
  '{"question": "What is the chemical symbol for gold?", "answer": "The chemical symbol for gold is Au."}]\n'
295
  "Now, generate 3 Q&A pairs from this data:\n"
296
  )
297
+ data_type = input_data["meta"].get("type", "text")
298
+ if data_type == "csv":
299
+ return base_prompt + "Data:\n" + input_data["text"]
300
+ elif data_type == "api":
301
+ return base_prompt + "API response:\n" + input_data["text"]
302
+ return base_prompt + input_data["text"]
303
 
304
  # --- RESPONSE PARSING ---
305
+ def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
306
+ """
307
+ Parse the LLM response into a list of Q&A pairs.
308
+
309
+ Expects the response to be a JSON formatted string.
310
+ """
311
  try:
312
  response_text = ""
 
313
  if provider == "HuggingFace":
314
+ response_text = response[0].get("generated_text", "")
 
315
  elif provider == "Google":
316
  response_text = response.text.strip()
 
317
  else: # OpenAI, Deepseek, Groq
318
  if not response or not response.choices or not response.choices[0].message.content:
319
  self._log_error("Empty or malformed response from LLM.")
320
  return []
 
321
  response_text = response.choices[0].message.content
322
 
323
  try:
324
  json_output = json.loads(response_text)
325
+ except json.JSONDecodeError as e:
326
+ self._log_error(f"JSON Parse Error: {e}. Raw Response: {response_text}")
327
+ return []
328
 
329
+ if isinstance(json_output, list):
330
+ qa_pairs = json_output
331
+ elif isinstance(json_output, dict) and "questionList" in json_output:
332
+ qa_pairs = json_output["questionList"]
333
+ else:
334
+ self._log_error(f"Unexpected JSON structure: {response_text}")
335
+ return []
 
 
 
 
336
 
337
+ if not isinstance(qa_pairs, list):
338
+ self._log_error(f"Expected a list of QA pairs, but got: {type(qa_pairs)}")
339
+ return []
 
340
 
341
+ for pair in qa_pairs:
342
+ if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
343
+ self._log_error(f"Invalid QA pair structure: {pair}")
344
+ return []
345
 
346
+ return qa_pairs
 
 
347
 
348
  except Exception as e:
349
  self._log_error(f"Parse Error: {e}. Raw Response: {response}")
350
  return []
351
 
352
+ def _log_error(self, message: str) -> None:
353
+ """Log an error message to the session state and display it."""
354
  st.session_state.processing["errors"].append(message)
355
  st.error(message)
356
 
357
 
358
  # --- STREAMLIT UI COMPONENTS ---
359
def input_sidebar(generator: SyntheticDataGenerator) -> str:
    """Render the configuration and data-source sidebar; return the API key."""
    with st.sidebar:
        st.header("⚙️ Configuration")

        provider = st.selectbox("Provider", list(generator.providers.keys()))
        st.session_state.config["provider"] = provider  # keep config in sync
        provider_cfg = generator.providers[provider]

        api_key = st.text_input(f"{provider} API Key", type="password")
        st.session_state["api_key"] = api_key

        st.session_state.config["model"] = st.selectbox("Model", provider_cfg["models"])
        st.session_state.config["temperature"] = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)

        # Data source input
        st.header("🔗 Data Sources")
        input_type = st.selectbox("Input Type", list(generator.input_handlers.keys()))

        if input_type == "text":
            domain_input = st.text_area("Domain Knowledge", height=150)
            if st.button("Add Domain Input"):
                st.session_state.inputs.append(generator.input_handlers["text"](domain_input)[0])

        elif input_type == "csv":
            csv_file = st.file_uploader("Upload CSV", type=["csv"])
            if csv_file:
                st.session_state.inputs.extend(generator.input_handlers["csv"](csv_file))

        elif input_type == "api":
            api_url = st.text_input("API Endpoint")
            api_headers = st.text_area("API Headers (JSON format, optional)", height=API_HEADERS_HEIGHT)
            headers = {}
            if api_headers:
                try:
                    headers = json.loads(api_headers)
                except json.JSONDecodeError:
                    # Invalid headers fall back to {}; the button still works.
                    st.error("Invalid JSON format for API headers.")
            if st.button("Add API Input"):
                st.session_state.inputs.extend(
                    generator.input_handlers["api"]({"url": api_url, "headers": headers})
                )

        elif input_type == "db":
            db_connection = st.text_input("Database Connection String")
            db_query = st.text_area("Database Query")
            db_table = st.text_input("Table Name (optional)")
            if st.button("Add DB Input"):
                st.session_state.inputs.extend(
                    generator.input_handlers["db"](
                        {"connection": db_connection, "query": db_query, "table": db_table}
                    )
                )

    return api_key
414
 
415
 
416
def main_display(generator: SyntheticDataGenerator) -> None:
    """Render the main panel: PDF upload, generation trigger, and results table."""
    st.title("🚀 Enterprise Synthetic Data Factory")

    upload_col, action_col = st.columns([3, 1])
    with upload_col:
        pdf_file = st.file_uploader("Upload Document", type=["pdf"])
        if pdf_file:
            st.session_state.inputs.extend(generator.input_handlers["pdf"](pdf_file))

    with action_col:
        if st.button("Start Generation"):
            with st.spinner("Processing..."):
                if not st.session_state["api_key"]:
                    st.error("Please provide an API Key.")
                else:
                    generator.generate(st.session_state["api_key"])

    if st.session_state.qa_data:
        st.header("Generated Data")
        df = pd.DataFrame(st.session_state.qa_data)
        st.dataframe(df)
        st.download_button("Export CSV", df.to_csv(index=False), "synthetic_data.csv")
439
 
440
 
441
def main() -> None:
    """Main function to run the Streamlit application."""
    generator = SyntheticDataGenerator()
    # The returned key is already mirrored into session state by the sidebar.
    input_sidebar(generator)
    main_display(generator)


if __name__ == "__main__":
    main()