mgbam committed on
Commit
d6dd233
·
verified ·
1 Parent(s): ed85532

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -61
app.py CHANGED
@@ -7,12 +7,22 @@ import json
7
  from PIL import Image
8
  from io import BytesIO
9
  from openai import OpenAI
 
10
  import groq
11
  import sqlalchemy
12
  from typing import Dict, Any
13
 
 
 
 
 
14
  class SyntheticDataGenerator:
 
 
 
 
15
  def __init__(self):
 
16
  self.providers = {
17
  "Deepseek": {
18
  "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
@@ -29,9 +39,13 @@ class SyntheticDataGenerator:
29
  "HuggingFace": {
30
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
31
  "models": ["gpt2", "llama-2"]
32
- }
 
 
 
 
33
  }
34
-
35
  self.input_handlers = {
36
  "pdf": self.handle_pdf,
37
  "text": self.handle_text,
@@ -39,10 +53,20 @@ class SyntheticDataGenerator:
39
  "api": self.handle_api,
40
  "db": self.handle_db
41
  }
42
-
43
  self.init_session()
44
 
 
 
 
 
 
 
 
 
 
45
  def init_session(self):
 
46
  session_defaults = {
47
  'inputs': [],
48
  'qa_data': [],
@@ -54,34 +78,42 @@ class SyntheticDataGenerator:
54
  'config': {
55
  'provider': "Deepseek",
56
  'model': "deepseek-chat",
57
- 'temperature': 0.3
58
  }
59
  }
60
-
61
  for key, val in session_defaults.items():
62
  if key not in st.session_state:
63
  st.session_state[key] = val
64
 
65
  # Input Processors
66
  def handle_pdf(self, file):
67
- try:
 
68
  with pdfplumber.open(file) as pdf:
69
- return [{
70
- "text": page.extract_text() or "",
71
- "images": self.process_images(page),
72
- "meta": {"type": "pdf", "page": i+1}
73
- } for i, page in enumerate(pdf.pages)]
74
- except Exception as e:
75
- self.log_error(f"PDF Error: {str(e)}")
76
- return []
 
 
 
 
 
77
 
78
  def handle_text(self, text):
 
79
  return [{
80
  "text": text,
81
  "meta": {"type": "domain", "source": "manual"}
82
  }]
83
 
84
  def handle_csv(self, file):
 
85
  try:
86
  df = pd.read_csv(file)
87
  return [{
@@ -93,17 +125,21 @@ class SyntheticDataGenerator:
93
  return []
94
 
95
  def handle_api(self, config):
 
96
  try:
97
  response = requests.get(config['url'], headers=config['headers'])
 
98
  return [{
99
  "text": json.dumps(response.json()),
100
  "meta": {"type": "api", "endpoint": config['url']}
101
  }]
102
- except Exception as e:
103
  self.log_error(f"API Error: {str(e)}")
104
  return []
105
 
 
106
  def handle_db(self, config):
 
107
  try:
108
  engine = sqlalchemy.create_engine(config['connection'])
109
  with engine.connect() as conn:
@@ -117,6 +153,7 @@ class SyntheticDataGenerator:
117
  return []
118
 
119
  def process_images(self, page):
 
120
  images = []
121
  for img in page.images:
122
  try:
@@ -134,130 +171,237 @@ class SyntheticDataGenerator:
134
 
135
  # Core Generation Engine
136
  def generate(self, api_key: str) -> bool:
 
 
 
 
 
 
 
 
 
137
  try:
138
  provider_cfg = self.providers[st.session_state.config['provider']]
139
- client = provider_cfg["client"](api_key)
140
-
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  for i, input_data in enumerate(st.session_state.inputs):
142
  st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
143
-
144
  if st.session_state.config['provider'] == "HuggingFace":
145
  response = self._huggingface_inference(client, input_data)
 
 
146
  else:
147
  response = self._standard_inference(client, input_data)
148
-
149
  if response:
150
- st.session_state.qa_data.extend(self._parse_response(response))
151
-
 
152
  return True
153
  except Exception as e:
154
  self.log_error(f"Generation Error: {str(e)}")
155
  return False
156
 
157
  def _standard_inference(self, client, input_data):
158
- return client.chat.completions.create(
159
- model=st.session_state.config['model'],
160
- messages=[{
161
- "role": "user",
162
- "content": self._build_prompt(input_data)
163
- }],
164
- temperature=st.session_state.config['temperature'],
165
- response_format={"type": "json_object"}
166
- )
 
 
 
 
 
167
 
168
  def _huggingface_inference(self, client, input_data):
169
- API_URL = "https://api-inference.huggingface.co/models/"
170
- response = requests.post(
171
- API_URL + st.session_state.config['model'],
172
- headers=client["headers"],
173
- json={"inputs": self._build_prompt(input_data)}
174
- )
175
- return response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  def _build_prompt(self, input_data):
178
- base = "Generate 3 Q&A pairs from this financial content:\n"
 
179
  if input_data['meta']['type'] == 'csv':
180
  return base + "Structured data:\n" + input_data['text']
181
  elif input_data['meta']['type'] == 'api':
182
  return base + "API response:\n" + input_data['text']
183
  return base + input_data['text']
184
 
185
- def _parse_response(self, response):
 
186
  try:
187
- if st.session_state.config['provider'] == "HuggingFace":
188
  return response[0]['generated_text']
189
- return json.loads(response.choices[0].message.content).get("qa_pairs", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  except Exception as e:
191
- self.log_error(f"Parse Error: {str(e)}")
192
  return []
193
 
194
  def log_error(self, message):
 
195
  st.session_state.processing['errors'].append(message)
196
  st.error(message)
197
 
198
  # Streamlit UI Components
199
  def input_sidebar(gen: SyntheticDataGenerator):
 
 
 
 
 
 
 
 
 
200
  with st.sidebar:
201
  st.header("⚙️ Configuration")
202
-
203
  # AI Provider Settings
204
  provider = st.selectbox("Provider", list(gen.providers.keys()))
205
  provider_cfg = gen.providers[provider]
206
-
207
  api_key = st.text_input(f"{provider} API Key", type="password")
 
 
208
  model = st.selectbox("Model", provider_cfg["models"])
209
- temp = st.slider("Temperature", 0.0, 1.0, 0.3)
210
-
211
  # Update session config
212
  st.session_state.config.update({
213
  "provider": provider,
214
  "model": model,
215
  "temperature": temp
216
  })
217
-
218
  # Input Source Selection
219
  st.header("🔗 Data Sources")
220
  input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))
221
-
222
  if input_type == "text":
223
  domain_input = st.text_area("Domain Knowledge", height=150)
224
  if st.button("Add Domain Input"):
225
- gen.input_handlers["text"](domain_input)
226
-
227
  elif input_type == "csv":
228
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
229
  if csv_file:
230
- gen.input_handlers["csv"](csv_file)
231
-
232
  elif input_type == "api":
233
  api_url = st.text_input("API Endpoint")
234
- if st.button("Connect API"):
235
- gen.input_handlers["api"]({"url": api_url})
236
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  return api_key
238
 
239
  def main_display(gen: SyntheticDataGenerator):
 
 
 
 
 
 
240
  st.title("🚀 Enterprise Synthetic Data Factory")
241
-
242
  # Input Processing
243
  col1, col2 = st.columns([3, 1])
244
  with col1:
245
  pdf_file = st.file_uploader("Upload Document", type=["pdf"])
246
  if pdf_file:
247
- gen.input_handlers["pdf"](pdf_file)
248
-
249
  # Generation Controls
250
  with col2:
251
  if st.button("Start Generation"):
252
  with st.status("Processing..."):
253
- gen.generate(st.session_state.get('api_key'))
254
-
 
 
 
255
  # Results Display
256
  if st.session_state.qa_data:
257
  st.header("Generated Data")
258
  df = pd.DataFrame(st.session_state.qa_data)
259
  st.dataframe(df)
260
-
261
  # Export Options
262
  st.download_button(
263
  "Export CSV",
@@ -266,6 +410,7 @@ def main_display(gen: SyntheticDataGenerator):
266
  )
267
 
268
  def main():
 
269
  gen = SyntheticDataGenerator()
270
  api_key = input_sidebar(gen)
271
  main_display(gen)
 
7
  from PIL import Image
8
  from io import BytesIO
9
  from openai import OpenAI
10
+ import google.generativeai as genai # Added Google GenAI
11
  import groq
12
  import sqlalchemy
13
  from typing import Dict, Any
14
 
15
# Constants for default values and API URLs.
# Base Hugging Face Inference API endpoint; the model id is appended at request time.
HF_API_URL = "https://api-inference.huggingface.co/models/"
# Default sampling temperature, also used as the UI slider's initial value.
DEFAULT_TEMPERATURE = 0.3
18
+
19
  class SyntheticDataGenerator:
20
+ """
21
+ A class to generate synthetic Q&A data from various input sources using different LLM providers.
22
+ """
23
+
24
  def __init__(self):
25
+ """Initializes the SyntheticDataGenerator with supported providers, input handlers, and session state."""
26
  self.providers = {
27
  "Deepseek": {
28
  "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
 
39
  "HuggingFace": {
40
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
41
  "models": ["gpt2", "llama-2"]
42
+ },
43
+ "Google": {
44
+ "client": lambda key: self._configure_google_genai(key), # Using a custom configure function
45
+ "models": ["gemini-2.0-pro"] # Add supported Gemini models. Consider adding "gemini-1.5-pro" when released.
46
+ },
47
  }
48
+
49
  self.input_handlers = {
50
  "pdf": self.handle_pdf,
51
  "text": self.handle_text,
 
53
  "api": self.handle_api,
54
  "db": self.handle_db
55
  }
56
+
57
  self.init_session()
58
 
59
+ def _configure_google_genai(self, api_key: str):
60
+ """Configures the Google Generative AI client."""
61
+ try:
62
+ genai.configure(api_key=api_key)
63
+ return genai.GenerativeModel # return the model class, not an instantiation
64
+ except Exception as e:
65
+ st.error(f"Error configuring Google GenAI: {e}")
66
+ return None # Important: Handle the case where configuration fails
67
+
68
  def init_session(self):
69
+ """Initializes the Streamlit session state with default values."""
70
  session_defaults = {
71
  'inputs': [],
72
  'qa_data': [],
 
78
  'config': {
79
  'provider': "Deepseek",
80
  'model': "deepseek-chat",
81
+ 'temperature': DEFAULT_TEMPERATURE
82
  }
83
  }
84
+
85
  for key, val in session_defaults.items():
86
  if key not in st.session_state:
87
  st.session_state[key] = val
88
 
89
  # Input Processors
90
  def handle_pdf(self, file):
91
+ """Extracts text and images from a PDF file."""
92
+ try:
93
  with pdfplumber.open(file) as pdf:
94
+ extracted_data = []
95
+ for i, page in enumerate(pdf.pages):
96
+ page_text = page.extract_text() or ""
97
+ page_images = self.process_images(page)
98
+ extracted_data.append({
99
+ "text": page_text,
100
+ "images": page_images,
101
+ "meta": {"type": "pdf", "page": i + 1}
102
+ })
103
+ return extracted_data
104
+ except Exception as e:
105
+ self.log_error(f"PDF Error: {str(e)}")
106
+ return []
107
 
108
  def handle_text(self, text):
109
+ """Handles manual text input."""
110
  return [{
111
  "text": text,
112
  "meta": {"type": "domain", "source": "manual"}
113
  }]
114
 
115
  def handle_csv(self, file):
116
+ """Reads a CSV file and prepares data for Q&A generation."""
117
  try:
118
  df = pd.read_csv(file)
119
  return [{
 
125
  return []
126
 
127
  def handle_api(self, config):
128
+ """Fetches data from an API endpoint."""
129
  try:
130
  response = requests.get(config['url'], headers=config['headers'])
131
+ response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
132
  return [{
133
  "text": json.dumps(response.json()),
134
  "meta": {"type": "api", "endpoint": config['url']}
135
  }]
136
+ except requests.exceptions.RequestException as e:
137
  self.log_error(f"API Error: {str(e)}")
138
  return []
139
 
140
+
141
  def handle_db(self, config):
142
+ """Connects to a database and executes a query."""
143
  try:
144
  engine = sqlalchemy.create_engine(config['connection'])
145
  with engine.connect() as conn:
 
153
  return []
154
 
155
  def process_images(self, page):
156
+ """Extracts and processes images from a PDF page."""
157
  images = []
158
  for img in page.images:
159
  try:
 
171
 
172
  # Core Generation Engine
173
  def generate(self, api_key: str) -> bool:
174
+ """
175
+ Generates Q&A pairs using the selected LLM provider.
176
+
177
+ Args:
178
+ api_key (str): The API key for the selected LLM provider.
179
+
180
+ Returns:
181
+ bool: True if generation was successful, False otherwise.
182
+ """
183
  try:
184
  provider_cfg = self.providers[st.session_state.config['provider']]
185
+ client_initializer = provider_cfg["client"] #Get the client init function.
186
+
187
+ # Check that the key is not an empty string
188
+ if not api_key:
189
+ st.error("API Key cannot be empty.")
190
+ return False
191
+
192
+ # Initialize the client
193
+ if st.session_state.config['provider'] == "Google":
194
+ client = client_initializer(api_key) # Client is the class
195
+ if not client:
196
+ return False # Google config failed
197
+ else:
198
+ client = client_initializer(api_key)
199
+
200
  for i, input_data in enumerate(st.session_state.inputs):
201
  st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
202
+
203
  if st.session_state.config['provider'] == "HuggingFace":
204
  response = self._huggingface_inference(client, input_data)
205
+ elif st.session_state.config['provider'] == "Google":
206
+ response = self._google_inference(client, input_data)
207
  else:
208
  response = self._standard_inference(client, input_data)
209
+
210
  if response:
211
+ # Check if the parsing function needs access to the provider
212
+ st.session_state.qa_data.extend(self._parse_response(response, st.session_state.config['provider']))
213
+
214
  return True
215
  except Exception as e:
216
  self.log_error(f"Generation Error: {str(e)}")
217
  return False
218
 
219
  def _standard_inference(self, client, input_data):
220
+ """Performs inference using standard OpenAI-compatible API."""
221
+ try:
222
+ return client.chat.completions.create(
223
+ model=st.session_state.config['model'],
224
+ messages=[{
225
+ "role": "user",
226
+ "content": self._build_prompt(input_data)
227
+ }],
228
+ temperature=st.session_state.config['temperature'],
229
+ response_format={"type": "json_object"} #Request json
230
+ )
231
+ except Exception as e:
232
+ self.log_error(f"OpenAI Inference Error: {e}")
233
+ return None
234
 
235
  def _huggingface_inference(self, client, input_data):
236
+ """Performs inference using Hugging Face Inference API."""
237
+ try:
238
+ response = requests.post(
239
+ HF_API_URL + st.session_state.config['model'],
240
+ headers=client["headers"],
241
+ json={"inputs": self._build_prompt(input_data)}
242
+ )
243
+ response.raise_for_status() #Check for HTTP errors
244
+ return response.json()
245
+ except requests.exceptions.RequestException as e:
246
+ self.log_error(f"Hugging Face Inference Error: {e}")
247
+ return None
248
+
249
+ def _google_inference(self, client, input_data):
250
+ """Performs inference using Google Generative AI API."""
251
+ try:
252
+
253
+ model = client(st.session_state.config['model']) # Instantiate the model with the selected model name
254
+ response = model.generate_content(
255
+ self._build_prompt(input_data),
256
+ generation_config = genai.types.GenerationConfig(temperature=st.session_state.config['temperature'])
257
+
258
+ )
259
+ return response
260
+ except Exception as e:
261
+ self.log_error(f"Google GenAI Inference Error: {e}")
262
+ return None
263
 
264
  def _build_prompt(self, input_data):
265
+ """Builds the prompt for the LLM based on the input data type."""
266
+ base = "Generate 3 Q&A pairs from this financial content, formatted as a JSON list of dictionaries with 'question' and 'answer' keys:\n"
267
  if input_data['meta']['type'] == 'csv':
268
  return base + "Structured data:\n" + input_data['text']
269
  elif input_data['meta']['type'] == 'api':
270
  return base + "API response:\n" + input_data['text']
271
  return base + input_data['text']
272
 
273
+ def _parse_response(self, response, provider):
274
+ """Parses the response from the LLM into a list of Q&A pairs."""
275
  try:
276
+ if provider == "HuggingFace":
277
  return response[0]['generated_text']
278
+ elif provider == "Google":
279
+ # Expecting a text response from Gemini
280
+ try:
281
+ json_string = response.text.strip() # Removes surrounding whitespace that can cause errors
282
+ qa_pairs = json.loads(json_string).get("qa_pairs", []) # Extract the qa_pairs
283
+
284
+ # Validate the structure of qa_pairs
285
+ if not isinstance(qa_pairs, list):
286
+ raise ValueError("Expected a list of QA pairs.")
287
+
288
+ for pair in qa_pairs:
289
+ if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
290
+ raise ValueError("Each item in the list must be a dictionary with 'question' and 'answer' keys.")
291
+ return qa_pairs # Return the extracted and validated list
292
+ except (json.JSONDecodeError, ValueError) as e:
293
+ self.log_error(f"Google JSON Parse Error: {e}. Raw Response: {response.text}")
294
+ return [] # Return empty in case of parsing failure
295
+ else:
296
+ # Assuming JSON response from other providers (OpenAI, Deepseek, Groq)
297
+ json_output = json.loads(response.choices[0].message.content) # load the JSON data
298
+ return json_output.get("qa_pairs", []) # Return the qa_pairs
299
  except Exception as e:
300
+ self.log_error(f"Parse Error: {e}. Raw Response: {response}")
301
  return []
302
 
303
  def log_error(self, message):
304
+ """Logs an error message to the Streamlit session state and displays it in the UI."""
305
  st.session_state.processing['errors'].append(message)
306
  st.error(message)
307
 
308
  # Streamlit UI Components
309
def input_sidebar(gen: SyntheticDataGenerator):
    """
    Creates the input sidebar in the Streamlit UI.

    Renders provider/model/temperature configuration and the data-source
    widgets, mutating st.session_state ('api_key', 'config', 'inputs')
    as the user interacts with them.

    Args:
        gen (SyntheticDataGenerator): The SyntheticDataGenerator instance.

    Returns:
        str: The API key entered by the user.
    """
    with st.sidebar:
        st.header("⚙️ Configuration")

        # AI Provider Settings
        provider = st.selectbox("Provider", list(gen.providers.keys()))
        provider_cfg = gen.providers[provider]

        api_key = st.text_input(f"{provider} API Key", type="password")
        # Persist the key so generate() can read it from session state later.
        st.session_state['api_key'] = api_key #Store API Key

        model = st.selectbox("Model", provider_cfg["models"])
        temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)

        # Update session config with the current widget values on every rerun.
        st.session_state.config.update({
            "provider": provider,
            "model": model,
            "temperature": temp
        })

        # Input Source Selection
        st.header("🔗 Data Sources")
        input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))

        if input_type == "text":
            domain_input = st.text_area("Domain Knowledge", height=150)
            if st.button("Add Domain Input"):
                # handle_text always returns a single-element list, so [0]
                # appends one record (unlike the extend() used below).
                st.session_state.inputs.append(gen.input_handlers["text"](domain_input)[0])

        elif input_type == "csv":
            csv_file = st.file_uploader("Upload CSV", type=["csv"])
            if csv_file:
                st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))

        elif input_type == "api":
            api_url = st.text_input("API Endpoint")
            api_headers = st.text_area("API Headers (JSON format, optional)", height=50)
            headers = {}
            # Parse optional headers up front; invalid JSON leaves headers = {}.
            try:
                if api_headers:
                    headers = json.loads(api_headers)
            except json.JSONDecodeError:
                st.error("Invalid JSON format for API headers.")
            if st.button("Add API Input"):
                st.session_state.inputs.extend(gen.input_handlers["api"]({"url": api_url, "headers": headers}))

        elif input_type == "db":
            db_connection = st.text_input("Database Connection String")
            db_query = st.text_area("Database Query")
            db_table = st.text_input("Table Name (optional)")
            if st.button("Add DB Input"):
                st.session_state.inputs.extend(gen.input_handlers["db"]({"connection": db_connection, "query": db_query, "table": db_table}))

    return api_key
373
 
374
  def main_display(gen: SyntheticDataGenerator):
375
+ """
376
+ Creates the main display area in the Streamlit UI.
377
+
378
+ Args:
379
+ gen (SyntheticDataGenerator): The SyntheticDataGenerator instance.
380
+ """
381
  st.title("🚀 Enterprise Synthetic Data Factory")
382
+
383
  # Input Processing
384
  col1, col2 = st.columns([3, 1])
385
  with col1:
386
  pdf_file = st.file_uploader("Upload Document", type=["pdf"])
387
  if pdf_file:
388
+ st.session_state.inputs.extend(gen.input_handlers["pdf"](pdf_file))
389
+
390
  # Generation Controls
391
  with col2:
392
  if st.button("Start Generation"):
393
  with st.status("Processing..."):
394
+ if not st.session_state.get('api_key'):
395
+ st.error("Please provide an API Key.")
396
+ else:
397
+ gen.generate(st.session_state.get('api_key'))
398
+
399
  # Results Display
400
  if st.session_state.qa_data:
401
  st.header("Generated Data")
402
  df = pd.DataFrame(st.session_state.qa_data)
403
  st.dataframe(df)
404
+
405
  # Export Options
406
  st.download_button(
407
  "Export CSV",
 
410
  )
411
 
412
def main():
    """Main function to run the Streamlit application."""
    gen = SyntheticDataGenerator()
    # The sidebar persists the API key in st.session_state, so the return
    # value is not needed here — don't keep an unused local.
    input_sidebar(gen)
    main_display(gen)