mgbam commited on
Commit
8fa07b2
·
verified ·
1 Parent(s): a4956fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -89
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import ast
 
3
  import requests
4
  import streamlit as st
5
  import pdfplumber
@@ -7,6 +8,16 @@ import pandas as pd
7
  import sqlalchemy
8
  from typing import Any, Dict, List, Callable
9
 
 
 
 
 
 
 
 
 
 
 
10
  # Provider clients – ensure these libraries are installed
11
  try:
12
  from openai import OpenAI
@@ -18,7 +29,7 @@ try:
18
  except ImportError:
19
  groq = None
20
 
21
- # Hugging Face inference endpoint and defaults
22
  HF_API_URL: str = "https://api-inference.huggingface.co/models/"
23
  DEFAULT_TEMPERATURE: float = 0.1
24
  GROQ_MODEL: str = "mixtral-8x7b-32768"
@@ -26,28 +37,27 @@ GROQ_MODEL: str = "mixtral-8x7b-32768"
26
 
27
  class SyntheticDataGenerator:
28
  """
29
- An advanced Synthetic Data Generator for creating training examples for fine-tuning.
30
 
31
- The generator accepts various input sources and then uses an LLM provider to create
32
- synthetic examples in JSON format. Each example is a dictionary with 'input' and 'output' keys.
33
  """
34
  def __init__(self) -> None:
35
  self._setup_providers()
36
  self._setup_input_handlers()
37
  self._initialize_session_state()
38
- # Prompt template with dynamic {num_examples} parameter and escaped curly braces.
39
  self.custom_prompt_template: str = (
40
  "You are an expert in generating synthetic training data for fine-tuning. "
41
  "Generate {num_examples} training examples from the following data, formatted as a JSON list of dictionaries. "
42
  "Each dictionary must have keys 'input' and 'output'. "
43
- "The examples should be clear, diverse, and based solely on the provided data. "
44
- "Do not add any external information. \n\n"
45
  "Example JSON Output:\n"
46
  "[{{'input': 'sample input text 1', 'output': 'sample output text 1'}}, "
47
  "{{'input': 'sample input text 2', 'output': 'sample output text 2'}}]\n\n"
48
  "Now, generate {num_examples} training examples from this data:\n{data}"
49
  )
50
-
51
  def _setup_providers(self) -> None:
52
  """Configure available LLM providers and their client initialization routines."""
53
  self.providers: Dict[str, Dict[str, Any]] = {
@@ -68,9 +78,9 @@ class SyntheticDataGenerator:
68
  "models": ["gpt2", "llama-2"],
69
  },
70
  }
71
-
72
  def _setup_input_handlers(self) -> None:
73
- """Register handlers for different input data types."""
74
  self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
75
  "text": self.handle_text,
76
  "pdf": self.handle_pdf,
@@ -78,20 +88,23 @@ class SyntheticDataGenerator:
78
  "api": self.handle_api,
79
  "db": self.handle_db,
80
  }
81
-
82
  def _initialize_session_state(self) -> None:
83
- """Initialize Streamlit session state with default configuration."""
 
 
 
84
  defaults: Dict[str, Any] = {
85
  "config": {
86
  "provider": "OpenAI",
87
  "model": "gpt-4-turbo",
88
  "temperature": DEFAULT_TEMPERATURE,
89
- "num_examples": 3, # Default number of synthetic examples
90
  },
91
  "api_key": "",
92
- "inputs": [], # List to store input sources
93
- "synthetic_data": None, # Generated synthetic data output
94
- "error_logs": [], # To store error messages
95
  }
96
  for key, value in defaults.items():
97
  if key not in st.session_state:
@@ -113,18 +126,19 @@ class SyntheticDataGenerator:
113
  st.session_state.config["num_examples"] = int(params["num_examples"][0])
114
  except ValueError:
115
  pass
116
-
117
  def log_error(self, message: str) -> None:
118
- """Log an error message to session state and display it."""
119
  st.session_state.error_logs.append(message)
120
  st.error(message)
121
-
 
122
  # ----- Input Handlers -----
123
  def handle_text(self, text: str) -> Dict[str, Any]:
124
- """Process plain text input."""
125
  return {"data": text, "source": "text"}
126
-
127
- def handle_pdf(self, file) -> Dict[str, Any]:
128
  """Extract text from a PDF file."""
129
  try:
130
  with pdfplumber.open(file) as pdf:
@@ -133,17 +147,16 @@ class SyntheticDataGenerator:
133
  except Exception as e:
134
  self.log_error(f"PDF Processing Error: {e}")
135
  return {"data": "", "source": "pdf"}
136
-
137
- def handle_csv(self, file) -> Dict[str, Any]:
138
- """Process a CSV file by converting it to JSON."""
139
  try:
140
  df = pd.read_csv(file)
141
- json_data = df.to_json(orient="records")
142
- return {"data": json_data, "source": "csv"}
143
  except Exception as e:
144
  self.log_error(f"CSV Processing Error: {e}")
145
  return {"data": "", "source": "csv"}
146
-
147
  def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
148
  """Fetch data from an API endpoint."""
149
  try:
@@ -153,9 +166,9 @@ class SyntheticDataGenerator:
153
  except Exception as e:
154
  self.log_error(f"API Processing Error: {e}")
155
  return {"data": "", "source": "api"}
156
-
157
  def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
158
- """Query a database using the provided connection string and SQL query."""
159
  try:
160
  engine = sqlalchemy.create_engine(config["connection"])
161
  with engine.connect() as conn:
@@ -165,19 +178,18 @@ class SyntheticDataGenerator:
165
  except Exception as e:
166
  self.log_error(f"Database Processing Error: {e}")
167
  return {"data": "", "source": "db"}
168
-
169
  def aggregate_inputs(self) -> str:
170
- """Combine all input sources into a single aggregated string."""
171
- aggregated_data = ""
172
  for item in st.session_state.inputs:
173
- aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
174
- aggregated_data += item.get("data", "") + "\n\n"
175
- return aggregated_data.strip()
176
-
177
  def build_prompt(self) -> str:
178
  """
179
- Build the complete prompt using the custom template, aggregated inputs,
180
- and the number of examples.
181
  """
182
  data = self.aggregate_inputs()
183
  num_examples = st.session_state.config.get("num_examples", 3)
@@ -185,50 +197,52 @@ class SyntheticDataGenerator:
185
  st.write("### Built Prompt")
186
  st.write(prompt)
187
  return prompt
188
-
189
  def generate_synthetic_data(self) -> bool:
190
  """
191
- Generate synthetic training examples by sending the built prompt to the selected LLM provider.
192
  """
193
  api_key: str = st.session_state.api_key
194
  if not api_key:
195
  self.log_error("API key is missing!")
196
  return False
197
-
198
  provider_name: str = st.session_state.config["provider"]
199
  provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
200
  if not provider_cfg:
201
  self.log_error(f"Provider {provider_name} is not configured.")
202
  return False
203
-
204
  client_initializer: Callable[[str], Any] = provider_cfg["client"]
205
  client = client_initializer(api_key)
206
  model: str = st.session_state.config["model"]
207
  temperature: float = st.session_state.config["temperature"]
208
  prompt: str = self.build_prompt()
209
-
210
  st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
211
  try:
212
  if provider_name == "HuggingFace":
213
  response = self._huggingface_inference(client, prompt, model)
214
  else:
215
  response = self._standard_inference(client, prompt, model, temperature)
216
-
217
  st.write("### Raw API Response")
218
  st.write(response)
219
-
220
  synthetic_examples = self._parse_response(response, provider_name)
221
  st.write("### Parsed Synthetic Data")
222
  st.write(synthetic_examples)
223
-
224
  st.session_state.synthetic_data = synthetic_examples
225
  return True
226
  except Exception as e:
227
  self.log_error(f"Generation failed: {e}")
228
  return False
229
-
230
  def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
231
- """Inference method for providers using an OpenAI-compatible API."""
 
 
232
  try:
233
  st.write("Sending prompt via standard inference...")
234
  result = client.chat.completions.create(
@@ -241,9 +255,11 @@ class SyntheticDataGenerator:
241
  except Exception as e:
242
  self.log_error(f"Standard Inference Error: {e}")
243
  return None
244
-
245
  def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
246
- """Inference method for the Hugging Face Inference API."""
 
 
247
  try:
248
  st.write("Sending prompt to HuggingFace API...")
249
  response = requests.post(
@@ -258,38 +274,39 @@ class SyntheticDataGenerator:
258
  except Exception as e:
259
  self.log_error(f"HuggingFace Inference Error: {e}")
260
  return None
261
-
262
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
263
  """
264
  Parse the LLM response and return a list of synthetic training examples.
265
- Expects the response to be JSON formatted; if JSON decoding fails,
266
- uses ast.literal_eval as a fallback.
267
  """
268
  st.write("Parsing response for provider:", provider)
269
  try:
270
  if provider == "HuggingFace":
 
271
  if isinstance(response, list) and response and "generated_text" in response[0]:
272
  raw_text = response[0]["generated_text"]
273
  else:
274
  self.log_error("Unexpected HuggingFace response format.")
275
  return []
276
  else:
 
277
  if response and hasattr(response, "choices") and response.choices:
278
  raw_text = response.choices[0].message.content
279
  else:
280
  self.log_error("Unexpected response format from provider.")
281
  return []
282
-
283
  try:
284
  examples = json.loads(raw_text)
285
  except json.JSONDecodeError as e:
286
- self.log_error(f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}")
287
  try:
288
  examples = ast.literal_eval(raw_text)
289
  except Exception as e2:
290
  self.log_error(f"ast.literal_eval failed: {e2}")
291
  return []
292
-
293
  if isinstance(examples, list):
294
  return examples
295
  else:
@@ -300,41 +317,45 @@ class SyntheticDataGenerator:
300
  return []
301
 
302
 
303
- # ============ UI Components ============
304
 
305
  def config_ui(generator: SyntheticDataGenerator) -> None:
306
- """Display configuration options in the sidebar and update URL query parameters."""
 
 
 
307
  with st.sidebar:
308
  st.header("Configuration")
309
- # Retrieve query parameters (if any)
310
  params = st.experimental_get_query_params()
311
  default_provider = params.get("provider", ["OpenAI"])[0]
312
  default_model = params.get("model", ["gpt-4-turbo"])[0]
313
  default_temperature = float(params.get("temperature", [DEFAULT_TEMPERATURE])[0])
314
  default_num_examples = int(params.get("num_examples", [3])[0])
315
-
316
  provider_options = list(generator.providers.keys())
317
- provider = st.selectbox("Select Provider", provider_options,
318
- index=provider_options.index(default_provider) if default_provider in provider_options else 0)
 
319
  st.session_state.config["provider"] = provider
320
  provider_cfg = generator.providers[provider]
321
-
322
  model_options = provider_cfg["models"]
323
  model = st.selectbox("Select Model", model_options,
324
- index=model_options.index(default_model) if default_model in model_options else 0)
 
325
  st.session_state.config["model"] = model
326
-
327
  temperature = st.slider("Temperature", 0.0, 1.0, default_temperature)
328
  st.session_state.config["temperature"] = temperature
329
-
330
- num_examples = st.number_input("Number of Training Examples", min_value=1, max_value=10,
331
  value=default_num_examples, step=1)
332
  st.session_state.config["num_examples"] = num_examples
333
-
334
  api_key = st.text_input(f"{provider} API Key", type="password")
335
  st.session_state.api_key = api_key
336
-
337
- # Update URL query parameters using the new API (st.set_query_params)
338
  st.set_query_params(
339
  provider=st.session_state.config["provider"],
340
  model=st.session_state.config["model"],
@@ -343,10 +364,10 @@ def config_ui(generator: SyntheticDataGenerator) -> None:
343
  )
344
 
345
  def input_ui(generator: SyntheticDataGenerator) -> None:
346
- """Display input data source options using tabs."""
347
  st.subheader("Input Data Sources")
348
  tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
349
-
350
  with tabs[0]:
351
  text_input = st.text_area("Enter text input", height=150)
352
  if st.button("Add Text Input", key="text_input"):
@@ -355,19 +376,19 @@ def input_ui(generator: SyntheticDataGenerator) -> None:
355
  st.success("Text input added!")
356
  else:
357
  st.warning("Empty text input.")
358
-
359
  with tabs[1]:
360
  pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
361
  if pdf_file is not None:
362
  st.session_state.inputs.append(generator.handle_pdf(pdf_file))
363
  st.success("PDF input added!")
364
-
365
  with tabs[2]:
366
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
367
  if csv_file is not None:
368
  st.session_state.inputs.append(generator.handle_csv(csv_file))
369
  st.success("CSV input added!")
370
-
371
  with tabs[3]:
372
  api_url = st.text_input("API Endpoint URL")
373
  api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
@@ -380,7 +401,7 @@ def input_ui(generator: SyntheticDataGenerator) -> None:
380
  generator.log_error(f"Invalid JSON for API Headers: {e}")
381
  st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
382
  st.success("API input added!")
383
-
384
  with tabs[4]:
385
  db_conn = st.text_input("Database Connection String")
386
  db_query = st.text_area("Database Query", height=100)
@@ -389,12 +410,12 @@ def input_ui(generator: SyntheticDataGenerator) -> None:
389
  st.success("Database input added!")
390
 
391
  def output_ui(generator: SyntheticDataGenerator) -> None:
392
- """Display the generated synthetic data and provide download options (JSON and CSV)."""
393
  st.subheader("Synthetic Data Output")
394
  if st.session_state.synthetic_data:
395
  st.write("### Generated Training Examples")
396
  st.write(st.session_state.synthetic_data)
397
-
398
  # Download as JSON
399
  st.download_button(
400
  "Download as JSON",
@@ -402,7 +423,7 @@ def output_ui(generator: SyntheticDataGenerator) -> None:
402
  file_name="synthetic_data.json",
403
  mime="application/json"
404
  )
405
-
406
  # Download as CSV
407
  try:
408
  df = pd.DataFrame(st.session_state.synthetic_data)
@@ -419,7 +440,7 @@ def output_ui(generator: SyntheticDataGenerator) -> None:
419
  st.info("No synthetic data generated yet.")
420
 
421
  def logs_ui() -> None:
422
- """Display error logs and debugging information in an expandable section."""
423
  with st.expander("Error Logs & Debug Info", expanded=False):
424
  if st.session_state.error_logs:
425
  for log in st.session_state.error_logs:
@@ -434,21 +455,20 @@ def main() -> None:
434
  st.markdown(
435
  """
436
  Welcome to the Advanced Synthetic Data Generator. This tool creates synthetic training examples
437
- for fine-tuning models. Configure your provider in the sidebar, add input data, and click the button
438
- below to generate synthetic data.
439
  """
440
  )
441
-
442
- # Initialize generator and display configuration UI
443
  generator = SyntheticDataGenerator()
444
  config_ui(generator)
445
-
446
  st.header("1. Input Data")
447
  input_ui(generator)
448
  if st.button("Clear All Inputs"):
449
  st.session_state.inputs = []
450
  st.success("All inputs have been cleared!")
451
-
452
  st.header("2. Generate Synthetic Data")
453
  if st.button("Generate Synthetic Data", key="generate_data"):
454
  with st.spinner("Generating synthetic data..."):
@@ -456,10 +476,10 @@ def main() -> None:
456
  st.success("Synthetic data generated successfully!")
457
  else:
458
  st.error("Data generation failed. Check logs for details.")
459
-
460
  st.header("3. Output")
461
  output_ui(generator)
462
-
463
  st.header("4. Logs & Debug Information")
464
  logs_ui()
465
 
 
1
  import json
2
  import ast
3
+ import logging
4
  import requests
5
  import streamlit as st
6
  import pdfplumber
 
8
  import sqlalchemy
9
  from typing import Any, Dict, List, Callable
10
 
11
# Module-level logger for production diagnostics; configured once at import.
logger = logging.getLogger("SyntheticDataGenerator")
logger.setLevel(logging.INFO)
# Guard against duplicate handlers when Streamlit re-executes the script.
if not logger.handlers:
    _stream = logging.StreamHandler()
    _stream.setLevel(logging.INFO)
    _stream.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    logger.addHandler(_stream)
21
  # Provider clients – ensure these libraries are installed
22
  try:
23
  from openai import OpenAI
 
29
  except ImportError:
30
  groq = None
31
 
32
+ # Constants for external APIs
33
  HF_API_URL: str = "https://api-inference.huggingface.co/models/"
34
  DEFAULT_TEMPERATURE: float = 0.1
35
  GROQ_MODEL: str = "mixtral-8x7b-32768"
 
37
 
38
  class SyntheticDataGenerator:
39
  """
40
+ An advanced synthetic data generator for creating fine-tuning training examples.
41
 
42
+ This generator uses various input sources and an LLM provider to create synthetic data.
43
+ Each generated example is a dictionary with 'input' and 'output' keys.
44
  """
45
  def __init__(self) -> None:
46
  self._setup_providers()
47
  self._setup_input_handlers()
48
  self._initialize_session_state()
49
+ # Prompt template: note the use of escaped curly braces so that literal braces are kept.
50
  self.custom_prompt_template: str = (
51
  "You are an expert in generating synthetic training data for fine-tuning. "
52
  "Generate {num_examples} training examples from the following data, formatted as a JSON list of dictionaries. "
53
  "Each dictionary must have keys 'input' and 'output'. "
54
+ "The examples should be clear, diverse, and based solely on the provided data. Do not add any external information.\n\n"
 
55
  "Example JSON Output:\n"
56
  "[{{'input': 'sample input text 1', 'output': 'sample output text 1'}}, "
57
  "{{'input': 'sample input text 2', 'output': 'sample output text 2'}}]\n\n"
58
  "Now, generate {num_examples} training examples from this data:\n{data}"
59
  )
60
+
61
  def _setup_providers(self) -> None:
62
  """Configure available LLM providers and their client initialization routines."""
63
  self.providers: Dict[str, Dict[str, Any]] = {
 
78
  "models": ["gpt2", "llama-2"],
79
  },
80
  }
81
+
82
  def _setup_input_handlers(self) -> None:
83
+ """Register input handlers for various data types."""
84
  self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
85
  "text": self.handle_text,
86
  "pdf": self.handle_pdf,
 
88
  "api": self.handle_api,
89
  "db": self.handle_db,
90
  }
91
+
92
  def _initialize_session_state(self) -> None:
93
+ """
94
+ Initialize the Streamlit session state with default configuration.
95
+ Also pre-populate configuration from URL query parameters.
96
+ """
97
  defaults: Dict[str, Any] = {
98
  "config": {
99
  "provider": "OpenAI",
100
  "model": "gpt-4-turbo",
101
  "temperature": DEFAULT_TEMPERATURE,
102
+ "num_examples": 3,
103
  },
104
  "api_key": "",
105
+ "inputs": [], # List to store input sources
106
+ "synthetic_data": None, # Generated synthetic training examples
107
+ "error_logs": [], # To store error messages
108
  }
109
  for key, value in defaults.items():
110
  if key not in st.session_state:
 
126
  st.session_state.config["num_examples"] = int(params["num_examples"][0])
127
  except ValueError:
128
  pass
129
+
130
def log_error(self, message: str) -> None:
    """Record *message* for the logs panel, show it in the UI, and mirror it to the logger.

    Args:
        message: Human-readable error description.
    """
    # Persist for the "Error Logs & Debug Info" expander.
    st.session_state.error_logs.append(message)
    # Surface immediately on the Streamlit page.
    st.error(message)
    # Mirror to the module logger for production diagnostics.
    logger.error(message)
135
+
136
  # ----- Input Handlers -----
137
def handle_text(self, text: str) -> Dict[str, Any]:
    """Wrap raw text input in the standard {'data', 'source'} record."""
    record: Dict[str, Any] = {"data": text, "source": "text"}
    return record
140
+
141
+ def handle_pdf(self, file: Any) -> Dict[str, Any]:
142
  """Extract text from a PDF file."""
143
  try:
144
  with pdfplumber.open(file) as pdf:
 
147
  except Exception as e:
148
  self.log_error(f"PDF Processing Error: {e}")
149
  return {"data": "", "source": "pdf"}
150
+
151
def handle_csv(self, file: Any) -> Dict[str, Any]:
    """Load a CSV file and return its rows serialized as a JSON records string.

    On any read/parse failure the error is logged and an empty-data record
    is returned instead of raising.
    """
    try:
        frame = pd.read_csv(file)
    except Exception as e:
        self.log_error(f"CSV Processing Error: {e}")
        return {"data": "", "source": "csv"}
    return {"data": frame.to_json(orient="records"), "source": "csv"}
159
+
160
  def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
161
  """Fetch data from an API endpoint."""
162
  try:
 
166
  except Exception as e:
167
  self.log_error(f"API Processing Error: {e}")
168
  return {"data": "", "source": "api"}
169
+
170
  def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
171
+ """Query a database using a connection string and SQL query."""
172
  try:
173
  engine = sqlalchemy.create_engine(config["connection"])
174
  with engine.connect() as conn:
 
178
  except Exception as e:
179
  self.log_error(f"Database Processing Error: {e}")
180
  return {"data": "", "source": "db"}
181
+
182
def aggregate_inputs(self) -> str:
    """Concatenate every stored input source into one prompt-ready string."""
    # One "Source: <name>\n<data>\n\n" section per stored input, in order.
    sections = [
        f"Source: {entry.get('source', 'unknown')}\n{entry.get('data', '')}\n\n"
        for entry in st.session_state.inputs
    ]
    return "".join(sections).strip()
188
+
 
189
  def build_prompt(self) -> str:
190
  """
191
+ Build the complete prompt using the custom template, aggregated inputs,
192
+ and the configured number of examples.
193
  """
194
  data = self.aggregate_inputs()
195
  num_examples = st.session_state.config.get("num_examples", 3)
 
197
  st.write("### Built Prompt")
198
  st.write(prompt)
199
  return prompt
200
+
201
def generate_synthetic_data(self) -> bool:
    """
    Generate synthetic training examples by sending the prompt to the selected LLM provider.

    Reads the provider/model/temperature configuration and API key from
    Streamlit session state, builds the prompt from all aggregated inputs,
    dispatches to the appropriate inference method, parses the result, and
    stores it in ``st.session_state.synthetic_data``.

    Returns:
        bool: True when examples were generated and stored; False on any
        failure (missing API key, unconfigured provider, or inference error).
    """
    api_key: str = st.session_state.api_key
    if not api_key:
        self.log_error("API key is missing!")
        return False

    provider_name: str = st.session_state.config["provider"]
    provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
    if not provider_cfg:
        self.log_error(f"Provider {provider_name} is not configured.")
        return False

    # Build a fresh client for the chosen provider from the supplied key.
    client_initializer: Callable[[str], Any] = provider_cfg["client"]
    client = client_initializer(api_key)
    model: str = st.session_state.config["model"]
    temperature: float = st.session_state.config["temperature"]
    prompt: str = self.build_prompt()

    st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
    try:
        # HuggingFace goes through its raw HTTP inference endpoint; every
        # other provider uses an OpenAI-compatible chat-completions client.
        if provider_name == "HuggingFace":
            response = self._huggingface_inference(client, prompt, model)
        else:
            response = self._standard_inference(client, prompt, model, temperature)

        # Echo the raw response for debugging before attempting to parse it.
        st.write("### Raw API Response")
        st.write(response)

        synthetic_examples = self._parse_response(response, provider_name)
        st.write("### Parsed Synthetic Data")
        st.write(synthetic_examples)

        st.session_state.synthetic_data = synthetic_examples
        return True
    except Exception as e:
        self.log_error(f"Generation failed: {e}")
        return False
241
+
242
  def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
243
+ """
244
+ Inference method for providers with an OpenAI-compatible API.
245
+ """
246
  try:
247
  st.write("Sending prompt via standard inference...")
248
  result = client.chat.completions.create(
 
255
  except Exception as e:
256
  self.log_error(f"Standard Inference Error: {e}")
257
  return None
258
+
259
  def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
260
+ """
261
+ Inference method for the Hugging Face Inference API.
262
+ """
263
  try:
264
  st.write("Sending prompt to HuggingFace API...")
265
  response = requests.post(
 
274
  except Exception as e:
275
  self.log_error(f"HuggingFace Inference Error: {e}")
276
  return None
277
+
278
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
279
  """
280
  Parse the LLM response and return a list of synthetic training examples.
281
+ Attempts JSON decoding first and falls back to ast.literal_eval.
 
282
  """
283
  st.write("Parsing response for provider:", provider)
284
  try:
285
  if provider == "HuggingFace":
286
+ # Expect response to be a list with a key "generated_text"
287
  if isinstance(response, list) and response and "generated_text" in response[0]:
288
  raw_text = response[0]["generated_text"]
289
  else:
290
  self.log_error("Unexpected HuggingFace response format.")
291
  return []
292
  else:
293
+ # For OpenAI/Groq, look for choices[0].message.content
294
  if response and hasattr(response, "choices") and response.choices:
295
  raw_text = response.choices[0].message.content
296
  else:
297
  self.log_error("Unexpected response format from provider.")
298
  return []
299
+
300
  try:
301
  examples = json.loads(raw_text)
302
  except json.JSONDecodeError as e:
303
+ self.log_error(f"JSON Parsing Error: {e}. Fallback with ast.literal_eval. Raw output: {raw_text}")
304
  try:
305
  examples = ast.literal_eval(raw_text)
306
  except Exception as e2:
307
  self.log_error(f"ast.literal_eval failed: {e2}")
308
  return []
309
+
310
  if isinstance(examples, list):
311
  return examples
312
  else:
 
317
  return []
318
 
319
 
320
+ # =================== UI Components ===================
321
 
322
  def config_ui(generator: SyntheticDataGenerator) -> None:
323
+ """
324
+ Display configuration options in the sidebar.
325
+ Updates URL query parameters using st.set_query_params.
326
+ """
327
  with st.sidebar:
328
  st.header("Configuration")
 
329
  params = st.experimental_get_query_params()
330
  default_provider = params.get("provider", ["OpenAI"])[0]
331
  default_model = params.get("model", ["gpt-4-turbo"])[0]
332
  default_temperature = float(params.get("temperature", [DEFAULT_TEMPERATURE])[0])
333
  default_num_examples = int(params.get("num_examples", [3])[0])
334
+
335
  provider_options = list(generator.providers.keys())
336
+ provider = st.selectbox("Select Provider", provider_options,
337
+ index=provider_options.index(default_provider)
338
+ if default_provider in provider_options else 0)
339
  st.session_state.config["provider"] = provider
340
  provider_cfg = generator.providers[provider]
341
+
342
  model_options = provider_cfg["models"]
343
  model = st.selectbox("Select Model", model_options,
344
+ index=model_options.index(default_model)
345
+ if default_model in model_options else 0)
346
  st.session_state.config["model"] = model
347
+
348
  temperature = st.slider("Temperature", 0.0, 1.0, default_temperature)
349
  st.session_state.config["temperature"] = temperature
350
+
351
+ num_examples = st.number_input("Number of Training Examples", min_value=1, max_value=10,
352
  value=default_num_examples, step=1)
353
  st.session_state.config["num_examples"] = num_examples
354
+
355
  api_key = st.text_input(f"{provider} API Key", type="password")
356
  st.session_state.api_key = api_key
357
+
358
+ # Update URL query parameters (shareable configuration)
359
  st.set_query_params(
360
  provider=st.session_state.config["provider"],
361
  model=st.session_state.config["model"],
 
364
  )
365
 
366
  def input_ui(generator: SyntheticDataGenerator) -> None:
367
+ """Display input data source options in tabs."""
368
  st.subheader("Input Data Sources")
369
  tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
370
+
371
  with tabs[0]:
372
  text_input = st.text_area("Enter text input", height=150)
373
  if st.button("Add Text Input", key="text_input"):
 
376
  st.success("Text input added!")
377
  else:
378
  st.warning("Empty text input.")
379
+
380
  with tabs[1]:
381
  pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
382
  if pdf_file is not None:
383
  st.session_state.inputs.append(generator.handle_pdf(pdf_file))
384
  st.success("PDF input added!")
385
+
386
  with tabs[2]:
387
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
388
  if csv_file is not None:
389
  st.session_state.inputs.append(generator.handle_csv(csv_file))
390
  st.success("CSV input added!")
391
+
392
  with tabs[3]:
393
  api_url = st.text_input("API Endpoint URL")
394
  api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
 
401
  generator.log_error(f"Invalid JSON for API Headers: {e}")
402
  st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
403
  st.success("API input added!")
404
+
405
  with tabs[4]:
406
  db_conn = st.text_input("Database Connection String")
407
  db_query = st.text_area("Database Query", height=100)
 
410
  st.success("Database input added!")
411
 
412
  def output_ui(generator: SyntheticDataGenerator) -> None:
413
+ """Display the generated synthetic data and download options (JSON and CSV)."""
414
  st.subheader("Synthetic Data Output")
415
  if st.session_state.synthetic_data:
416
  st.write("### Generated Training Examples")
417
  st.write(st.session_state.synthetic_data)
418
+
419
  # Download as JSON
420
  st.download_button(
421
  "Download as JSON",
 
423
  file_name="synthetic_data.json",
424
  mime="application/json"
425
  )
426
+
427
  # Download as CSV
428
  try:
429
  df = pd.DataFrame(st.session_state.synthetic_data)
 
440
  st.info("No synthetic data generated yet.")
441
 
442
  def logs_ui() -> None:
443
+ """Display error logs and debug information in an expandable section."""
444
  with st.expander("Error Logs & Debug Info", expanded=False):
445
  if st.session_state.error_logs:
446
  for log in st.session_state.error_logs:
 
455
  st.markdown(
456
  """
457
  Welcome to the Advanced Synthetic Data Generator. This tool creates synthetic training examples
458
+ for fine-tuning models. Configure your provider in the sidebar, add input data, and generate synthetic data.
 
459
  """
460
  )
461
+
462
+ # Initialize generator and UI
463
  generator = SyntheticDataGenerator()
464
  config_ui(generator)
465
+
466
  st.header("1. Input Data")
467
  input_ui(generator)
468
  if st.button("Clear All Inputs"):
469
  st.session_state.inputs = []
470
  st.success("All inputs have been cleared!")
471
+
472
  st.header("2. Generate Synthetic Data")
473
  if st.button("Generate Synthetic Data", key="generate_data"):
474
  with st.spinner("Generating synthetic data..."):
 
476
  st.success("Synthetic data generated successfully!")
477
  else:
478
  st.error("Data generation failed. Check logs for details.")
479
+
480
  st.header("3. Output")
481
  output_ui(generator)
482
+
483
  st.header("4. Logs & Debug Information")
484
  logs_ui()
485