mgbam committed on
Commit
5f0d3d6
·
verified ·
1 Parent(s): 1de53dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +257 -23
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import json
2
  import requests
3
  import streamlit as st
4
  import pdfplumber
@@ -6,6 +5,7 @@ import pandas as pd
6
  import sqlalchemy
7
  from typing import Any, Dict, List, Optional
8
  from functools import lru_cache
 
9
 
10
  # Provider clients with import guards
11
  try:
@@ -18,6 +18,16 @@ try:
18
  except ImportError:
19
  groq = None
20
 
 
 
 
 
 
 
 
 
 
 
21
  class SyntheticDataGenerator:
22
  """World's Most Advanced Synthetic Data Generation System"""
23
 
@@ -41,6 +51,10 @@ class SyntheticDataGenerator:
41
  "base_url": "https://api-inference.huggingface.co/models/",
42
  "models": ["gpt2", "llama-2-13b-chat"],
43
  "requires_library": None
 
 
 
 
44
  }
45
  }
46
 
@@ -61,7 +75,15 @@ class SyntheticDataGenerator:
61
  "tokens_used": 0,
62
  "error_count": 0
63
  },
64
- "debug_mode": False
 
 
 
 
 
 
 
 
65
  }
66
  for key, val in defaults.items():
67
  if key not in st.session_state:
@@ -71,7 +93,7 @@ class SyntheticDataGenerator:
71
  """Configure available providers with health checks"""
72
  self.available_providers = []
73
  for provider, config in self.PROVIDER_CONFIG.items():
74
- if config["requires_library"] and not globals().get(config["requires_library"].title()):
75
  continue # Skip providers with missing dependencies
76
  self.available_providers.append(provider)
77
 
@@ -83,7 +105,8 @@ class SyntheticDataGenerator:
83
  "csv": self._process_csv,
84
  "api": self._process_api,
85
  "database": self._process_database,
86
- "web": self._process_web
 
87
  }
88
 
89
  # --- Core Generation Engine ---
@@ -108,8 +131,8 @@ class SyntheticDataGenerator:
108
  """Secure client initialization with connection pooling"""
109
  config = self.PROVIDER_CONFIG[provider]
110
  api_key = st.session_state.api_keys.get(provider, "")
111
-
112
- if not api_key:
113
  raise ValueError("API key required")
114
 
115
  try:
@@ -117,6 +140,26 @@ class SyntheticDataGenerator:
117
  return groq.Groq(api_key=api_key)
118
  elif provider == "HuggingFace":
119
  return {"headers": {"Authorization": f"Bearer {api_key}"}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
  return OpenAI(
122
  base_url=config["base_url"],
@@ -130,7 +173,7 @@ class SyntheticDataGenerator:
130
  def _execute_generation(self, client, provider: str, model: str, prompt: str) -> Dict[str, Any]:
131
  """Execute provider-specific generation with circuit breaker"""
132
  st.session_state.system_metrics["api_calls"] += 1
133
-
134
  if provider == "HuggingFace":
135
  response = requests.post(
136
  self.PROVIDER_CONFIG[provider]["base_url"] + model,
@@ -140,22 +183,41 @@ class SyntheticDataGenerator:
140
  )
141
  response.raise_for_status()
142
  return response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  else:
144
  completion = client.chat.completions.create(
145
  model=model,
146
  messages=[{"role": "user", "content": prompt}],
147
- temperature=0.1,
148
- max_tokens=2000
149
  )
150
  st.session_state.system_metrics["tokens_used"] += completion.usage.total_tokens
151
- return json.loads(completion.choices[0].message.content)
 
 
 
152
 
153
  def _failover_generation(self, prompt: str) -> Dict[str, Any]:
154
  """Enterprise failover to secondary providers"""
155
  for backup_provider in self.available_providers:
156
  if backup_provider != st.session_state.active_provider:
157
  try:
158
- return self.generate(backup_provider, ...)
159
  except Exception:
160
  continue
161
  raise RuntimeError("All generation providers unavailable")
@@ -181,14 +243,76 @@ class SyntheticDataGenerator:
181
  self._log_error(f"Web Extraction Error: {str(e)}")
182
  return ""
183
 
184
- # Additional processors follow similar patterns...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  # --- Enterprise Features ---
187
  def _log_error(self, message: str) -> None:
188
  """Centralized error logging with telemetry"""
189
  st.session_state.system_metrics["error_count"] += 1
190
  st.session_state.error_logs = st.session_state.get("error_logs", []) + [message]
191
-
192
  if st.session_state.debug_mode:
193
  st.error(f"[DEBUG] {message}")
194
 
@@ -214,6 +338,26 @@ class SyntheticDataGenerator:
214
  timeout=5
215
  )
216
  return response.status_code == 200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  else:
218
  client.models.list()
219
  return True
@@ -225,14 +369,15 @@ def provider_config_ui(gen: SyntheticDataGenerator):
225
  """Advanced provider configuration interface"""
226
  with st.sidebar:
227
  st.header("⚙️ AI Engine Configuration")
228
-
229
  # Provider selection with availability checks
230
  provider = st.selectbox(
231
  "AI Provider",
232
  gen.available_providers,
233
  help="Available providers based on system configuration"
234
  )
235
-
 
236
  # API key management
237
  api_key = st.text_input(
238
  f"{provider} API Key",
@@ -241,19 +386,69 @@ def provider_config_ui(gen: SyntheticDataGenerator):
241
  help=f"Obtain API key from {provider} portal"
242
  )
243
  st.session_state.api_keys[provider] = api_key
244
-
245
  # Model selection
246
  model = st.selectbox(
247
  "Model",
248
  gen.PROVIDER_CONFIG[provider]["models"],
249
  help="Select model version based on your API plan"
250
  )
251
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  # System monitoring
253
  if st.button("Run Health Check"):
254
  report = gen.health_check()
255
  st.json(report)
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  def main():
258
  """Enterprise-grade user interface"""
259
  st.set_page_config(
@@ -261,18 +456,57 @@ def main():
261
  page_icon="🏭",
262
  layout="wide"
263
  )
264
-
265
  gen = SyntheticDataGenerator()
266
-
267
  st.title("🏭 Synthetic Data Factory Pro")
268
  st.markdown("""
269
- **World's Most Advanced Synthetic Data Generation Platform**
270
  *Multi-provider AI Engine | Enterprise Input Processors | Real-time Monitoring*
271
  """)
272
-
273
  provider_config_ui(gen)
274
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  # Input management and generation UI components...
276
-
277
  if __name__ == "__main__":
278
  main()
 
 
1
  import requests
2
  import streamlit as st
3
  import pdfplumber
 
5
  import sqlalchemy
6
  from typing import Any, Dict, List, Optional
7
  from functools import lru_cache
8
+ import os # Import the 'os' module
9
 
10
  # Provider clients with import guards
11
  try:
 
18
  except ImportError:
19
  groq = None
20
 
21
+ try:
22
+ import google.generativeai as genai
23
+ from google.generativeai import GenerativeModel, configure
24
+ except ImportError:
25
+ GenerativeModel = None
26
+ configure = None
27
+ genai = None #Also set this to none
28
+
29
+ import json # Ensure json is explicitly imported for enhanced use
30
+
31
  class SyntheticDataGenerator:
32
  """World's Most Advanced Synthetic Data Generation System"""
33
 
 
51
  "base_url": "https://api-inference.huggingface.co/models/",
52
  "models": ["gpt2", "llama-2-13b-chat"],
53
  "requires_library": None
54
+ },
55
+ "Google": {
56
+ "models": ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"], # Include Gemini 2.0 Flash
57
+ "requires_library": "google.generativeai"
58
  }
59
  }
60
 
 
75
  "tokens_used": 0,
76
  "error_count": 0
77
  },
78
+ "debug_mode": False,
79
+ "google_configured": False, # Track if Google API is configured
80
+ "advanced_options": { # Store advanced generation options
81
+ "temperature": 0.7, # Default temperature
82
+ "top_p": 0.95, # Default top_p
83
+ "top_k": 40, # Default top_k
84
+ "max_output_tokens": 2000 # Default max_output_tokens
85
+ },
86
+ "generation_format": "json" # Default output format (json or text)
87
  }
88
  for key, val in defaults.items():
89
  if key not in st.session_state:
 
93
  """Configure available providers with health checks"""
94
  self.available_providers = []
95
  for provider, config in self.PROVIDER_CONFIG.items():
96
+ if config["requires_library"] and not globals().get(config["requires_library"].split('.')[0].title()):
97
  continue # Skip providers with missing dependencies
98
  self.available_providers.append(provider)
99
 
 
105
  "csv": self._process_csv,
106
  "api": self._process_api,
107
  "database": self._process_database,
108
+ "web": self._process_web,
109
+ "image": self._process_image #Add Image
110
  }
111
 
112
  # --- Core Generation Engine ---
 
131
  """Secure client initialization with connection pooling"""
132
  config = self.PROVIDER_CONFIG[provider]
133
  api_key = st.session_state.api_keys.get(provider, "")
134
+
135
+ if not api_key and provider != "Google": #Google API key is configured by configure()
136
  raise ValueError("API key required")
137
 
138
  try:
 
140
  return groq.Groq(api_key=api_key)
141
  elif provider == "HuggingFace":
142
  return {"headers": {"Authorization": f"Bearer {api_key}"}}
143
+ elif provider == "Google":
144
+ if not st.session_state.google_configured:
145
+ # Check if the API key is set as an environment variable
146
+ if "GOOGLE_API_KEY" in os.environ:
147
+ api_key = os.environ["GOOGLE_API_KEY"]
148
+ else:
149
+ # Use the API key from session state if available
150
+ api_key = st.session_state.api_keys.get("Google", "")
151
+ if not api_key:
152
+ raise ValueError("Google API key is required. Please set it in the app or as the GOOGLE_API_KEY environment variable.")
153
+ configure(api_key=api_key) #Configure the Google API key. Only do once
154
+ st.session_state.google_configured = True
155
+
156
+ generation_config = genai.GenerationConfig(
157
+ temperature=st.session_state.advanced_options["temperature"],
158
+ top_p=st.session_state.advanced_options["top_p"],
159
+ top_k=st.session_state.advanced_options["top_k"],
160
+ max_output_tokens=st.session_state.advanced_options["max_output_tokens"]
161
+ )
162
+ return GenerativeModel(model_name=model, generation_config=generation_config) # Create the GenerativeModel with generation config
163
  else:
164
  return OpenAI(
165
  base_url=config["base_url"],
 
173
  def _execute_generation(self, client, provider: str, model: str, prompt: str) -> Dict[str, Any]:
174
  """Execute provider-specific generation with circuit breaker"""
175
  st.session_state.system_metrics["api_calls"] += 1
176
+
177
  if provider == "HuggingFace":
178
  response = requests.post(
179
  self.PROVIDER_CONFIG[provider]["base_url"] + model,
 
183
  )
184
  response.raise_for_status()
185
  return response.json()
186
+ elif provider == "Google":
187
+ try:
188
+ response = client.generate_content(prompt)
189
+ content = response.text
190
+
191
+ if st.session_state.generation_format == "json": # Check requested format
192
+ try:
193
+ return json.loads(content) # Attempt to parse as JSON
194
+ except json.JSONDecodeError:
195
+ return {"content": content, "warning": "Could not parse response as valid JSON. Returning raw text."} #Return raw content with warning
196
+ else:
197
+ return {"content": content} # Return raw content
198
+
199
+ except Exception as e:
200
+ self._log_error(f"Google Generation Error: {str(e)}")
201
+ return {"error": str(e), "content": ""}
202
  else:
203
  completion = client.chat.completions.create(
204
  model=model,
205
  messages=[{"role": "user", "content": prompt}],
206
+ temperature=st.session_state.advanced_options["temperature"], #Use temp from session
207
+ max_tokens=st.session_state.advanced_options["max_output_tokens"]
208
  )
209
  st.session_state.system_metrics["tokens_used"] += completion.usage.total_tokens
210
+ try:
211
+ return json.loads(completion.choices[0].message.content)
212
+ except json.JSONDecodeError:
213
+ return {"content": completion.choices[0].message.content, "warning": "Could not parse response as valid JSON. Returning raw text."}
214
 
215
  def _failover_generation(self, prompt: str) -> Dict[str, Any]:
216
  """Enterprise failover to secondary providers"""
217
  for backup_provider in self.available_providers:
218
  if backup_provider != st.session_state.active_provider:
219
  try:
220
+ return self.generate(backup_provider, ..., prompt=prompt) # Corrected: include prompt
221
  except Exception:
222
  continue
223
  raise RuntimeError("All generation providers unavailable")
 
243
  self._log_error(f"Web Extraction Error: {str(e)}")
244
  return ""
245
 
246
def _process_csv(self, file) -> str:
    """Read an uploaded CSV, record its inferred schema, and return the table as text.

    Side effect: stores a human-readable schema summary in
    ``st.session_state.csv_schema`` so the UI can surface it later.
    """
    try:
        frame = pd.read_csv(file)
        # Infer a lightweight schema (column names + dtypes) to help the LLM
        # reproduce the table's structure in generated synthetic data.
        names = frame.columns.tolist()
        dtypes = [str(dtype) for dtype in frame.dtypes]
        st.session_state.csv_schema = f"Column Names: {names}\nData Types: {dtypes}"
        return frame.to_string()
    except Exception as e:
        self._log_error(f"CSV Processing Error: {str(e)}")
        return ""
261
+
262
+ def _process_text(self, text: str) -> str:
263
+ """Simple text passthrough processor"""
264
+ return text
265
+
266
def _process_api(self, url: str, method="GET", headers: Optional[Dict[str, str]] = None, data: Optional[Dict[str, Any]] = None, timeout: int = 10) -> str:
    """Generic API endpoint processor with configurable methods and headers.

    Args:
        url: Endpoint to call.
        method: HTTP verb; only "GET" and "POST" are supported (case-insensitive).
        headers: Optional request headers.
        data: JSON body for POST requests; ignored for GET.
        timeout: Request timeout in seconds (was previously hard-coded to 10).

    Returns:
        Pretty-printed JSON when the response parses as JSON, otherwise the raw
        response text; empty string on any network/HTTP failure.

    Raises:
        ValueError: For unsupported HTTP methods. This is a programming error,
            not a network failure, so it deliberately escapes the
            RequestException handler below.
    """
    verb = method.upper()  # hoisted: computed once instead of per-branch
    try:
        if verb == "GET":
            response = requests.get(url, headers=headers or {}, timeout=timeout)
        elif verb == "POST":
            response = requests.post(url, headers=headers or {}, json=data, timeout=timeout)
        else:
            raise ValueError("Unsupported HTTP method.")
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
        try:
            return json.dumps(response.json(), indent=2)  # Pretty print JSON if possible
        except json.JSONDecodeError:
            return response.text  # Otherwise, return raw text
    except requests.exceptions.RequestException as e:
        self._log_error(f"API Processing Error: {str(e)}")
        return ""
284
+
285
def _process_database(self, connection_string: str, query: str) -> str:
    """Execute a SQL query through SQLAlchemy and return the result set as text.

    Returns an empty string (and logs the error) on any failure.
    """
    try:
        engine = sqlalchemy.create_engine(connection_string)
        with engine.connect() as conn:
            cursor = conn.execute(sqlalchemy.text(query))
            # Materialize before the connection closes.
            table = pd.DataFrame(cursor.fetchall(), columns=cursor.keys())
        return table.to_string()
    except Exception as e:
        self._log_error(f"Database Processing Error: {str(e)}")
        return ""
296
+
297
def _process_image(self, image_file) -> str:
    """Stage an uploaded image for multimodal generation.

    Side effect: stores a Gemini-style inline image part
    ({"mime_type", "data"}) in ``st.session_state.image_part`` so the
    generation step can attach it to the prompt.
    """
    try:
        # NOTE(review): mime_type comes from the uploader's reported file type —
        # presumably accurate for Streamlit uploads; verify for other callers.
        st.session_state.image_part = {
            "mime_type": image_file.type,
            "data": image_file.read(),
        }
        return "Image uploaded. Include instructions for processing the image in your prompt."
    except Exception as e:
        self._log_error(f"Image Processing Error: {str(e)}")
        return ""
309
 
310
  # --- Enterprise Features ---
311
  def _log_error(self, message: str) -> None:
312
  """Centralized error logging with telemetry"""
313
  st.session_state.system_metrics["error_count"] += 1
314
  st.session_state.error_logs = st.session_state.get("error_logs", []) + [message]
315
+
316
  if st.session_state.debug_mode:
317
  st.error(f"[DEBUG] {message}")
318
 
 
338
  timeout=5
339
  )
340
  return response.status_code == 200
341
+ elif provider == "Google":
342
+ try:
343
+ #Need to initialize before listing models
344
+ if not st.session_state.google_configured:
345
+ api_key = st.session_state.api_keys.get("Google", "")
346
+ if not api_key:
347
+ api_key = os.environ.get("GOOGLE_API_KEY") #Check env variables
348
+ if not api_key:
349
+ return False
350
+
351
+ configure(api_key=api_key) #Configure API Key
352
+ st.session_state.google_configured = True
353
+
354
+ genai.GenerativeModel(model_name=self.PROVIDER_CONFIG["Google"]["models"][0]).generate_content("test") #Send a test query
355
+ return True #Connected if made it this far
356
+
357
+ except Exception as e:
358
+ print(e)
359
+ return False
360
+
361
  else:
362
  client.models.list()
363
  return True
 
369
  """Advanced provider configuration interface"""
370
  with st.sidebar:
371
  st.header("⚙️ AI Engine Configuration")
372
+
373
  # Provider selection with availability checks
374
  provider = st.selectbox(
375
  "AI Provider",
376
  gen.available_providers,
377
  help="Available providers based on system configuration"
378
  )
379
+ st.session_state.active_provider = provider
380
+
381
  # API key management
382
  api_key = st.text_input(
383
  f"{provider} API Key",
 
386
  help=f"Obtain API key from {provider} portal"
387
  )
388
  st.session_state.api_keys[provider] = api_key
389
+
390
  # Model selection
391
  model = st.selectbox(
392
  "Model",
393
  gen.PROVIDER_CONFIG[provider]["models"],
394
  help="Select model version based on your API plan"
395
  )
396
+ st.session_state.active_model = model
397
+
398
+ # Advanced Options (for providers that support it)
399
+ if provider == "Google" or provider == "OpenAI": #Only add if OpenAI
400
+ st.subheader("Advanced Generation Options")
401
+ st.session_state.advanced_options["temperature"] = st.slider("Temperature", min_value=0.0, max_value=1.0, value=st.session_state.advanced_options["temperature"], step=0.05, help="Controls randomness. Lower values = more deterministic.")
402
+
403
+ if provider == "Google":
404
+ st.session_state.advanced_options["top_p"] = st.slider("Top P", min_value=0.0, max_value=1.0, value=st.session_state.advanced_options["top_p"], step=0.05, help="Nucleus sampling: Considers the most probable tokens.")
405
+ st.session_state.advanced_options["top_k"] = st.slider("Top K", min_value=1, max_value=100, value=st.session_state.advanced_options["top_k"], step=1, help="Considers the top K most probable tokens.")
406
+
407
+ st.session_state.advanced_options["max_output_tokens"] = st.number_input("Max Output Tokens", min_value=50, max_value=4096, value=st.session_state.advanced_options["max_output_tokens"], step=50, help="Maximum number of tokens in the generated output.")
408
+
409
+ # Output format
410
+ st.session_state.generation_format = st.selectbox("Output Format", ["json", "text"], help="Choose the desired output format.")
411
+
412
  # System monitoring
413
  if st.button("Run Health Check"):
414
  report = gen.health_check()
415
  st.json(report)
416
 
417
def input_ui():
    """Render the input-source selector and collection widgets.

    Returns:
        A (input_method, input_content, additional_instructions) tuple.
        ``additional_instructions`` is only populated by the structured-prompt
        path; it is an empty string otherwise.
    """
    input_method = st.selectbox("Input Method", ["Text", "PDF", "Web URL", "CSV", "Image", "Structured Prompt (Advanced)"])
    input_content = None
    additional_instructions = ""

    if input_method == "Text":
        input_content = st.text_area("Enter Text", height=200)
    elif input_method == "PDF":
        upload = st.file_uploader("Upload a PDF file", type=["pdf"])
        if upload is not None:
            input_content = upload
    elif input_method == "Web URL":
        input_content = st.text_input("Enter Web URL")
    elif input_method == "CSV":
        upload = st.file_uploader("Upload a CSV file", type=["csv"])
        if upload is not None:
            input_content = upload
            # Schema is populated by _process_csv on a previous run; show it
            # when available so the user can sanity-check the inference.
            if "csv_schema" in st.session_state:
                st.write("Inferred CSV Schema:")
                st.write(st.session_state.csv_schema)
    elif input_method == "Image":
        upload = st.file_uploader("Upload an Image file", type=["png", "jpg", "jpeg"])
        if upload is not None:
            input_content = upload
    elif input_method == "Structured Prompt (Advanced)":
        st.subheader("Structured Prompt")
        input_content = st.text_area("Enter the base prompt/instructions", height=100)
        additional_instructions = st.text_area("Specify constraints, data format, or other requirements:", height=100)

    return input_method, input_content, additional_instructions
451
+
452
  def main():
453
  """Enterprise-grade user interface"""
454
  st.set_page_config(
 
456
  page_icon="🏭",
457
  layout="wide"
458
  )
459
+
460
  gen = SyntheticDataGenerator()
461
+
462
  st.title("🏭 Synthetic Data Factory Pro")
463
  st.markdown("""
464
+ **World's Most Advanced Synthetic Data Generation Platform**
465
  *Multi-provider AI Engine | Enterprise Input Processors | Real-time Monitoring*
466
  """)
467
+
468
  provider_config_ui(gen)
469
+
470
+ input_method, input_content, additional_instructions = input_ui() #Get additonal instructions
471
+
472
+ if st.button("Generate Data"):
473
+ if input_content or input_method == "Structured Prompt (Advanced)": #Allow generation with *just* structured prompt
474
+ processed_input = None
475
+
476
+ if input_method == "Text":
477
+ processed_input = gen._process_text(input_content)
478
+ elif input_method == "PDF":
479
+ processed_input = gen._process_pdf(input_content)
480
+ elif input_method == "Web URL":
481
+ processed_input = gen._process_web(input_content)
482
+ elif input_method == "CSV":
483
+ processed_input = gen._process_csv(input_content)
484
+ elif input_method == "Image":
485
+ processed_input = gen._process_image(input_content)
486
+ elif input_method == "Structured Prompt (Advanced)":
487
+ processed_input = input_content + "\n" + additional_instructions #Combine instructions and constraints
488
+ #st.write("Combined Prompt:")
489
+ #st.write(processed_input) #Debug
490
+
491
+ if processed_input:
492
+ try:
493
+ #Handle Google image case - requires a list of content. Other providers just use the text
494
+ if st.session_state.active_provider == "Google" and input_method == "Image":
495
+ prompt_parts = [processed_input, st.session_state.image_part] # Image part already stored
496
+ result = gen.generate(st.session_state.active_provider, st.session_state.active_model, prompt_parts) # Process Google Images
497
+ else:
498
+ result = gen.generate(st.session_state.active_provider, st.session_state.active_model, processed_input) # Generic text case
499
+
500
+ st.subheader("Generated Output:")
501
+ st.json(result) # Display the JSON output
502
+ except Exception as e:
503
+ st.error(f"Error during generation: {e}")
504
+ else:
505
+ st.warning("No data to process. Please check your input.")
506
+ else:
507
+ st.warning("Please provide input data.")
508
+
509
  # Input management and generation UI components...
510
+
511
  if __name__ == "__main__":
512
  main()