Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -24,27 +24,28 @@ DEFAULT_TEMPERATURE: float = 0.1
|
|
24 |
GROQ_MODEL: str = "mixtral-8x7b-32768"
|
25 |
|
26 |
|
27 |
-
class
|
28 |
"""
|
29 |
-
|
30 |
-
|
|
|
|
|
31 |
"""
|
32 |
def __init__(self) -> None:
|
33 |
self._setup_providers()
|
34 |
self._setup_input_handlers()
|
35 |
self._initialize_session_state()
|
36 |
-
# Prompt template with
|
37 |
self.custom_prompt_template: str = (
|
38 |
-
"You are an expert in
|
39 |
-
"Generate {num_examples}
|
40 |
-
"Each dictionary must have keys '
|
41 |
-
"The
|
42 |
-
"Do not
|
43 |
"Example JSON Output:\n"
|
44 |
-
"[{{'
|
45 |
-
"{{'
|
46 |
-
"
|
47 |
-
"Now, generate {num_examples} Q&A pairs from this data:\n{data}"
|
48 |
)
|
49 |
|
50 |
def _setup_providers(self) -> None:
|
@@ -85,12 +86,12 @@ class QADataGenerator:
|
|
85 |
"provider": "OpenAI",
|
86 |
"model": "gpt-4-turbo",
|
87 |
"temperature": DEFAULT_TEMPERATURE,
|
88 |
-
"num_examples": 3, # Default number of
|
89 |
},
|
90 |
"api_key": "",
|
91 |
-
"inputs": [],
|
92 |
-
"
|
93 |
-
"error_logs": [],
|
94 |
}
|
95 |
for key, value in defaults.items():
|
96 |
if key not in st.session_state:
|
@@ -185,9 +186,9 @@ class QADataGenerator:
|
|
185 |
st.write(prompt)
|
186 |
return prompt
|
187 |
|
188 |
-
def
|
189 |
"""
|
190 |
-
Generate
|
191 |
"""
|
192 |
api_key: str = st.session_state.api_key
|
193 |
if not api_key:
|
@@ -216,11 +217,11 @@ class QADataGenerator:
|
|
216 |
st.write("### Raw API Response")
|
217 |
st.write(response)
|
218 |
|
219 |
-
|
220 |
-
st.write("### Parsed
|
221 |
-
st.write(
|
222 |
|
223 |
-
st.session_state.
|
224 |
return True
|
225 |
except Exception as e:
|
226 |
self.log_error(f"Generation failed: {e}")
|
@@ -260,7 +261,7 @@ class QADataGenerator:
|
|
260 |
|
261 |
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
|
262 |
"""
|
263 |
-
Parse the LLM response and return a list of
|
264 |
Expects the response to be JSON formatted; if JSON decoding fails,
|
265 |
uses ast.literal_eval as a fallback.
|
266 |
"""
|
@@ -280,17 +281,17 @@ class QADataGenerator:
|
|
280 |
return []
|
281 |
|
282 |
try:
|
283 |
-
|
284 |
except json.JSONDecodeError as e:
|
285 |
self.log_error(f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}")
|
286 |
try:
|
287 |
-
|
288 |
except Exception as e2:
|
289 |
self.log_error(f"ast.literal_eval failed: {e2}")
|
290 |
return []
|
291 |
|
292 |
-
if isinstance(
|
293 |
-
return
|
294 |
else:
|
295 |
self.log_error("Parsed output is not a list.")
|
296 |
return []
|
@@ -301,11 +302,11 @@ class QADataGenerator:
|
|
301 |
|
302 |
# ============ UI Components ============
|
303 |
|
304 |
-
def config_ui(generator:
|
305 |
"""Display configuration options in the sidebar and update URL query parameters."""
|
306 |
with st.sidebar:
|
307 |
st.header("Configuration")
|
308 |
-
# Retrieve
|
309 |
params = st.experimental_get_query_params()
|
310 |
default_provider = params.get("provider", ["OpenAI"])[0]
|
311 |
default_model = params.get("model", ["gpt-4-turbo"])[0]
|
@@ -326,22 +327,22 @@ def config_ui(generator: QADataGenerator) -> None:
|
|
326 |
temperature = st.slider("Temperature", 0.0, 1.0, default_temperature)
|
327 |
st.session_state.config["temperature"] = temperature
|
328 |
|
329 |
-
num_examples = st.number_input("Number of
|
330 |
value=default_num_examples, step=1)
|
331 |
st.session_state.config["num_examples"] = num_examples
|
332 |
|
333 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
334 |
st.session_state.api_key = api_key
|
335 |
|
336 |
-
# Update
|
337 |
-
st.
|
338 |
provider=st.session_state.config["provider"],
|
339 |
model=st.session_state.config["model"],
|
340 |
temperature=st.session_state.config["temperature"],
|
341 |
num_examples=st.session_state.config["num_examples"],
|
342 |
)
|
343 |
|
344 |
-
def input_ui(generator:
|
345 |
"""Display input data source options using tabs."""
|
346 |
st.subheader("Input Data Sources")
|
347 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
@@ -387,35 +388,35 @@ def input_ui(generator: QADataGenerator) -> None:
|
|
387 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
388 |
st.success("Database input added!")
|
389 |
|
390 |
-
def output_ui(generator:
|
391 |
-
"""Display the generated
|
392 |
-
st.subheader("
|
393 |
-
if st.session_state.
|
394 |
-
st.write("### Generated
|
395 |
-
st.write(st.session_state.
|
396 |
|
397 |
# Download as JSON
|
398 |
st.download_button(
|
399 |
"Download as JSON",
|
400 |
-
json.dumps(st.session_state.
|
401 |
-
file_name="
|
402 |
mime="application/json"
|
403 |
)
|
404 |
|
405 |
# Download as CSV
|
406 |
try:
|
407 |
-
df = pd.DataFrame(st.session_state.
|
408 |
csv_data = df.to_csv(index=False)
|
409 |
st.download_button(
|
410 |
"Download as CSV",
|
411 |
csv_data,
|
412 |
-
file_name="
|
413 |
mime="text/csv"
|
414 |
)
|
415 |
except Exception as e:
|
416 |
st.error(f"Error generating CSV: {e}")
|
417 |
else:
|
418 |
-
st.info("No
|
419 |
|
420 |
def logs_ui() -> None:
|
421 |
"""Display error logs and debugging information in an expandable section."""
|
@@ -428,17 +429,18 @@ def logs_ui() -> None:
|
|
428 |
|
429 |
def main() -> None:
|
430 |
"""Main Streamlit application entry point."""
|
431 |
-
st.set_page_config(page_title="Advanced
|
432 |
-
st.title("Advanced
|
433 |
st.markdown(
|
434 |
"""
|
435 |
-
Welcome to the Advanced
|
436 |
-
|
|
|
437 |
"""
|
438 |
)
|
439 |
|
440 |
# Initialize generator and display configuration UI
|
441 |
-
generator =
|
442 |
config_ui(generator)
|
443 |
|
444 |
st.header("1. Input Data")
|
@@ -447,13 +449,13 @@ def main() -> None:
|
|
447 |
st.session_state.inputs = []
|
448 |
st.success("All inputs have been cleared!")
|
449 |
|
450 |
-
st.header("2. Generate
|
451 |
-
if st.button("Generate
|
452 |
-
with st.spinner("Generating
|
453 |
-
if generator.
|
454 |
-
st.success("
|
455 |
else:
|
456 |
-
st.error("
|
457 |
|
458 |
st.header("3. Output")
|
459 |
output_ui(generator)
|
|
|
24 |
GROQ_MODEL: str = "mixtral-8x7b-32768"
|
25 |
|
26 |
|
27 |
+
class SyntheticDataGenerator:
|
28 |
"""
|
29 |
+
An advanced Synthetic Data Generator for creating training examples for fine-tuning.
|
30 |
+
|
31 |
+
The generator accepts various input sources and then uses an LLM provider to create
|
32 |
+
synthetic examples in JSON format. Each example is a dictionary with 'input' and 'output' keys.
|
33 |
"""
|
34 |
def __init__(self) -> None:
|
35 |
self._setup_providers()
|
36 |
self._setup_input_handlers()
|
37 |
self._initialize_session_state()
|
38 |
+
# Prompt template with dynamic {num_examples} parameter and escaped curly braces.
|
39 |
self.custom_prompt_template: str = (
|
40 |
+
"You are an expert in generating synthetic training data for fine-tuning. "
|
41 |
+
"Generate {num_examples} training examples from the following data, formatted as a JSON list of dictionaries. "
|
42 |
+
"Each dictionary must have keys 'input' and 'output'. "
|
43 |
+
"The examples should be clear, diverse, and based solely on the provided data. "
|
44 |
+
"Do not add any external information. \n\n"
|
45 |
"Example JSON Output:\n"
|
46 |
+
"[{{'input': 'sample input text 1', 'output': 'sample output text 1'}}, "
|
47 |
+
"{{'input': 'sample input text 2', 'output': 'sample output text 2'}}]\n\n"
|
48 |
+
"Now, generate {num_examples} training examples from this data:\n{data}"
|
|
|
49 |
)
|
50 |
|
51 |
def _setup_providers(self) -> None:
|
|
|
86 |
"provider": "OpenAI",
|
87 |
"model": "gpt-4-turbo",
|
88 |
"temperature": DEFAULT_TEMPERATURE,
|
89 |
+
"num_examples": 3, # Default number of synthetic examples
|
90 |
},
|
91 |
"api_key": "",
|
92 |
+
"inputs": [], # List to store input sources
|
93 |
+
"synthetic_data": None, # Generated synthetic data output
|
94 |
+
"error_logs": [], # To store error messages
|
95 |
}
|
96 |
for key, value in defaults.items():
|
97 |
if key not in st.session_state:
|
|
|
186 |
st.write(prompt)
|
187 |
return prompt
|
188 |
|
189 |
+
def generate_synthetic_data(self) -> bool:
|
190 |
"""
|
191 |
+
Generate synthetic training examples by sending the built prompt to the selected LLM provider.
|
192 |
"""
|
193 |
api_key: str = st.session_state.api_key
|
194 |
if not api_key:
|
|
|
217 |
st.write("### Raw API Response")
|
218 |
st.write(response)
|
219 |
|
220 |
+
synthetic_examples = self._parse_response(response, provider_name)
|
221 |
+
st.write("### Parsed Synthetic Data")
|
222 |
+
st.write(synthetic_examples)
|
223 |
|
224 |
+
st.session_state.synthetic_data = synthetic_examples
|
225 |
return True
|
226 |
except Exception as e:
|
227 |
self.log_error(f"Generation failed: {e}")
|
|
|
261 |
|
262 |
def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
|
263 |
"""
|
264 |
+
Parse the LLM response and return a list of synthetic training examples.
|
265 |
Expects the response to be JSON formatted; if JSON decoding fails,
|
266 |
uses ast.literal_eval as a fallback.
|
267 |
"""
|
|
|
281 |
return []
|
282 |
|
283 |
try:
|
284 |
+
examples = json.loads(raw_text)
|
285 |
except json.JSONDecodeError as e:
|
286 |
self.log_error(f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}")
|
287 |
try:
|
288 |
+
examples = ast.literal_eval(raw_text)
|
289 |
except Exception as e2:
|
290 |
self.log_error(f"ast.literal_eval failed: {e2}")
|
291 |
return []
|
292 |
|
293 |
+
if isinstance(examples, list):
|
294 |
+
return examples
|
295 |
else:
|
296 |
self.log_error("Parsed output is not a list.")
|
297 |
return []
|
|
|
302 |
|
303 |
# ============ UI Components ============
|
304 |
|
305 |
+
def config_ui(generator: SyntheticDataGenerator) -> None:
|
306 |
"""Display configuration options in the sidebar and update URL query parameters."""
|
307 |
with st.sidebar:
|
308 |
st.header("Configuration")
|
309 |
+
# Retrieve query parameters (if any)
|
310 |
params = st.experimental_get_query_params()
|
311 |
default_provider = params.get("provider", ["OpenAI"])[0]
|
312 |
default_model = params.get("model", ["gpt-4-turbo"])[0]
|
|
|
327 |
temperature = st.slider("Temperature", 0.0, 1.0, default_temperature)
|
328 |
st.session_state.config["temperature"] = temperature
|
329 |
|
330 |
+
num_examples = st.number_input("Number of Training Examples", min_value=1, max_value=10,
|
331 |
value=default_num_examples, step=1)
|
332 |
st.session_state.config["num_examples"] = num_examples
|
333 |
|
334 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
335 |
st.session_state.api_key = api_key
|
336 |
|
337 |
+
# Update URL query parameters using the new API (st.set_query_params)
|
338 |
+
st.set_query_params(
|
339 |
provider=st.session_state.config["provider"],
|
340 |
model=st.session_state.config["model"],
|
341 |
temperature=st.session_state.config["temperature"],
|
342 |
num_examples=st.session_state.config["num_examples"],
|
343 |
)
|
344 |
|
345 |
+
def input_ui(generator: SyntheticDataGenerator) -> None:
|
346 |
"""Display input data source options using tabs."""
|
347 |
st.subheader("Input Data Sources")
|
348 |
tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
|
|
|
388 |
st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
|
389 |
st.success("Database input added!")
|
390 |
|
391 |
+
def output_ui(generator: SyntheticDataGenerator) -> None:
|
392 |
+
"""Display the generated synthetic data and provide download options (JSON and CSV)."""
|
393 |
+
st.subheader("Synthetic Data Output")
|
394 |
+
if st.session_state.synthetic_data:
|
395 |
+
st.write("### Generated Training Examples")
|
396 |
+
st.write(st.session_state.synthetic_data)
|
397 |
|
398 |
# Download as JSON
|
399 |
st.download_button(
|
400 |
"Download as JSON",
|
401 |
+
json.dumps(st.session_state.synthetic_data, indent=2),
|
402 |
+
file_name="synthetic_data.json",
|
403 |
mime="application/json"
|
404 |
)
|
405 |
|
406 |
# Download as CSV
|
407 |
try:
|
408 |
+
df = pd.DataFrame(st.session_state.synthetic_data)
|
409 |
csv_data = df.to_csv(index=False)
|
410 |
st.download_button(
|
411 |
"Download as CSV",
|
412 |
csv_data,
|
413 |
+
file_name="synthetic_data.csv",
|
414 |
mime="text/csv"
|
415 |
)
|
416 |
except Exception as e:
|
417 |
st.error(f"Error generating CSV: {e}")
|
418 |
else:
|
419 |
+
st.info("No synthetic data generated yet.")
|
420 |
|
421 |
def logs_ui() -> None:
|
422 |
"""Display error logs and debugging information in an expandable section."""
|
|
|
429 |
|
430 |
def main() -> None:
|
431 |
"""Main Streamlit application entry point."""
|
432 |
+
st.set_page_config(page_title="Advanced Synthetic Data Generator", layout="wide")
|
433 |
+
st.title("Advanced Synthetic Data Generator")
|
434 |
st.markdown(
|
435 |
"""
|
436 |
+
Welcome to the Advanced Synthetic Data Generator. This tool creates synthetic training examples
|
437 |
+
for fine-tuning models. Configure your provider in the sidebar, add input data, and click the button
|
438 |
+
below to generate synthetic data.
|
439 |
"""
|
440 |
)
|
441 |
|
442 |
# Initialize generator and display configuration UI
|
443 |
+
generator = SyntheticDataGenerator()
|
444 |
config_ui(generator)
|
445 |
|
446 |
st.header("1. Input Data")
|
|
|
449 |
st.session_state.inputs = []
|
450 |
st.success("All inputs have been cleared!")
|
451 |
|
452 |
+
st.header("2. Generate Synthetic Data")
|
453 |
+
if st.button("Generate Synthetic Data", key="generate_data"):
|
454 |
+
with st.spinner("Generating synthetic data..."):
|
455 |
+
if generator.generate_synthetic_data():
|
456 |
+
st.success("Synthetic data generated successfully!")
|
457 |
else:
|
458 |
+
st.error("Data generation failed. Check logs for details.")
|
459 |
|
460 |
st.header("3. Output")
|
461 |
output_ui(generator)
|