Spaces:
Running
Running
File size: 8,542 Bytes
2996fd9 66dcc60 2996fd9 d73ef8b 2996fd9 d73ef8b 2996fd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
"""The UI file for the SynthGenAI package."""
import os
import asyncio
import gradio as gr
from synthgenai import DatasetConfig, DatasetGeneratorConfig, LLMConfig, InstructionDatasetGenerator, PreferenceDatasetGenerator,RawDatasetGenerator,SentimentAnalysisDatasetGenerator, SummarizationDatasetGenerator, TextClassificationDatasetGenerator
def validate_inputs(*args):
    """
    Validate that all required inputs are filled.

    A value counts as missing when it is ``None``, a string that is empty
    or whitespace-only, or an empty container. Numbers are always treated
    as filled, so a legitimate ``0`` / ``0.0`` (for example a temperature
    of zero) is not mistaken for an empty field.

    Args:
        *args: The input values to validate.

    Returns:
        bool: True if all required inputs are filled, False otherwise.
    """
    for arg in args:
        if arg is None:
            return False
        if isinstance(arg, str):
            # Whitespace-only text is as good as an empty field.
            if not arg.strip():
                return False
        elif isinstance(arg, (int, float)):
            # Zero is a real value, not a missing one.
            continue
        elif not arg:
            # Empty list/tuple/dict/etc.
            return False
    return True
def _parse_env_assignments(raw):
    """Parse a comma-separated ``KEY=VALUE`` string into a dict of stripped names and values."""
    parsed = {}
    for chunk in raw.split(","):
        chunk = chunk.strip()
        if not chunk:
            # Tolerate trailing commas and blank entries.
            continue
        # partition splits on the FIRST "=" only, so values that themselves
        # contain "=" (e.g. base64-padded keys) stay intact.
        name, separator, value = chunk.partition("=")
        name = name.strip()
        if not separator or not name:
            raise ValueError(f"expected KEY=VALUE, got {chunk!r}")
        parsed[name] = value.strip()
    return parsed


def generate_synthetic_dataset(
    llm_model,
    temperature,
    top_p,
    max_tokens,
    api_base,
    api_key,
    dataset_type,
    topic,
    domains,
    language,
    additional_description,
    num_entries,
    hf_token,
    hf_repo_name,
    llm_env_vars,
):
    """
    Generate a dataset based on the provided parameters.

    Inputs are validated before anything else happens, so a missing field
    produces a message instead of an exception and leaves ``os.environ``
    untouched. On success the Hugging Face token and the parsed LLM
    environment variables are exported to ``os.environ``, the matching
    generator is run, and the result is saved under ``hf_repo_name``.

    Args:
        llm_model (str): The LLM model to use.
        temperature (float): The temperature for the LLM.
        top_p (float): The top_p value for the LLM.
        max_tokens (int): The maximum number of tokens for the LLM.
        api_base (str): The API base URL (optional).
        api_key (str): The API key (optional).
        dataset_type (str): The type of dataset to generate.
        topic (str): The topic of the dataset.
        domains (str): Comma-separated domains for the dataset.
        language (str): The language of the dataset.
        additional_description (str): Additional description for the dataset.
        num_entries (int): The number of entries in the dataset.
        hf_token (str): The Hugging Face token.
        hf_repo_name (str): The Hugging Face repository name.
        llm_env_vars (str): Comma-separated ``KEY=VALUE`` environment
            variables for the LLM.

    Returns:
        str: A message indicating the result of the dataset generation.
    """
    missing_fields_message = "All fields except API Base and API Key must be filled."

    # Validate BEFORE touching os.environ or parsing llm_env_vars; otherwise
    # an empty field blows up during parsing and this message is never shown.
    if not validate_inputs(
        llm_model,
        temperature,
        top_p,
        max_tokens,
        dataset_type,
        topic,
        domains,
        language,
        num_entries,
        hf_token,
        hf_repo_name,
        llm_env_vars,
    ):
        return missing_fields_message

    try:
        env_assignments = _parse_env_assignments(llm_env_vars)
    except ValueError as err:
        return f"Invalid LLM environment variables: {err}"

    # Strip surrounding whitespace so "a, b" yields ["a", "b"], and drop
    # blank entries produced by stray commas.
    domain_list = [part.strip() for part in domains.split(",") if part.strip()]
    if not domain_list:
        return missing_fields_message

    generator_classes = {
        "Raw": RawDatasetGenerator,
        "Instruction": InstructionDatasetGenerator,
        "Preference": PreferenceDatasetGenerator,
        "Sentiment Analysis": SentimentAnalysisDatasetGenerator,
        "Summarization": SummarizationDatasetGenerator,
        "Text Classification": TextClassificationDatasetGenerator,
    }
    generator_class = generator_classes.get(dataset_type)
    if generator_class is None:
        return "Invalid dataset type"

    # Only mutate the process environment once every input has been accepted.
    os.environ["HF_TOKEN"] = hf_token
    os.environ.update(env_assignments)

    llm_kwargs = {
        "model": llm_model,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
    }
    # The custom endpoint is only used when BOTH values are supplied.
    # NOTE(review): an api_base given without an api_key is ignored here;
    # confirm whether LLMConfig accepts one without the other.
    if api_base and api_key:
        llm_kwargs["api_base"] = api_base
        llm_kwargs["api_key"] = api_key
    llm_config = LLMConfig(**llm_kwargs)

    dataset_config = DatasetConfig(
        topic=topic,
        domains=domain_list,
        language=language,
        additional_description=additional_description,
        num_entries=num_entries,
    )
    dataset_generator_config = DatasetGeneratorConfig(
        llm_config=llm_config,
        dataset_config=dataset_config,
    )

    generator = generator_class(dataset_generator_config)
    # The generator exposes an async API; drive it to completion here.
    dataset = asyncio.run(generator.agenerate_dataset())
    dataset.save_dataset(hf_repo_name=hf_repo_name)
    return "Dataset generated and saved successfully."
def ui_main():
    """
    Launch the Gradio UI for the SynthGenAI dataset generator.

    Builds a Blocks layout with three rows of inputs (LLM settings, dataset
    settings, Hugging Face / environment settings), wires the "Generate
    Dataset" button to ``generate_synthetic_dataset`` and starts the app,
    opening it in the default browser.
    """
    header_html = (
        "\n"
        '<div style="text-align: center;">\n'
        '<img src="./assets/logo_header.png" alt="Header Image" style="display: block; margin-left: auto; margin-right: auto; width: 50%;"/>\n'
        "<h1>SynthGenAI Dataset Generator</h1>\n"
        "</div>\n"
    )
    # The two heading emoji were previously mis-encoded ("π§" / "π€");
    # they are the UTF-8 characters U+1F9D0 and U+1F914.
    overview_markdown = (
        "\n"
        "## Overview 🧐\n"
        "SynthGenAI is designed to be modular and can be easily extended to include different API providers for LLMs and new features.\n"
        "## Why SynthGenAI? 🤔\n"
        "Interest in synthetic data generation has surged recently, driven by the growing recognition of data as a critical asset in AI development. Synthetic data generation addresses challenges by allowing us to create diverse and useful datasets using current pre-trained Large Language Models (LLMs).\n"
        "[GitHub Repository](https://github.com/Shekswess/synthgenai/tree/main) | [Documentation](https://shekswess.github.io/synthgenai/)\n"
    )

    with gr.Blocks(
        title="SynthGenAI Dataset Generator",
        css="footer {visibility: hidden}",
        theme="ParityError/Interstellar",
    ) as demo:
        gr.Markdown(header_html)
        gr.Markdown(overview_markdown)

        # Row 1: LLM settings.
        with gr.Row():
            llm_model = gr.Textbox(
                label="LLM Model", placeholder="model_provider/model_name"
            )
            temperature = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.5
            )
            top_p = gr.Slider(
                label="Top P", minimum=0.0, maximum=1.0, step=0.1, value=0.9
            )
            # precision=0 makes the component round to and return an int,
            # matching the integer token count the handler documents.
            max_tokens = gr.Number(label="Max Tokens", value=2048, precision=0)
            api_base = gr.Textbox(label="API Base", placeholder="API Base - Optional")
            api_key = gr.Textbox(
                label="API Key", placeholder="Your API Key - Optional", type="password"
            )

        # Row 2: dataset settings.
        with gr.Row():
            dataset_type = gr.Dropdown(
                label="Dataset Type",
                choices=[
                    "Raw",
                    "Instruction",
                    "Preference",
                    "Sentiment Analysis",
                    "Summarization",
                    "Text Classification",
                ],
            )
            topic = gr.Textbox(label="Topic", placeholder="Dataset topic")
            domains = gr.Textbox(label="Domains", placeholder="Comma-separated domains")
            language = gr.Textbox(
                label="Language", placeholder="Language", value="English"
            )
            additional_description = gr.Textbox(
                label="Additional Description",
                placeholder="Additional description",
                value="",
            )
            # precision=0: the entry count is a whole number.
            num_entries = gr.Number(
                label="Number of Entries", value=1000, precision=0
            )

        # Row 3: Hugging Face destination and provider environment variables.
        with gr.Row():
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="Your HF Token",
                type="password",
                value=None,
            )
            hf_repo_name = gr.Textbox(
                label="Hugging Face Repo Name",
                placeholder="organization_or_user_name/dataset_name",
                value=None,
            )
            llm_env_vars = gr.Textbox(
                label="LLM Environment Variables",
                placeholder="Comma-separated environment variables (e.g., KEY1=VALUE1, KEY2=VALUE2)",
                value=None,
            )

        generate_button = gr.Button("Generate Dataset")
        output = gr.Textbox(label="Operation Result", value="")

        # The input order must match the positional parameters of
        # generate_synthetic_dataset.
        generate_button.click(
            generate_synthetic_dataset,
            inputs=[
                llm_model,
                temperature,
                top_p,
                max_tokens,
                api_base,
                api_key,
                dataset_type,
                topic,
                domains,
                language,
                additional_description,
                num_entries,
                hf_token,
                hf_repo_name,
                llm_env_vars,
            ],
            outputs=output,
        )

    demo.launch(inbrowser=True, favicon_path=None)
if __name__ == "__main__":
    # Start the UI only when the file is executed as a script, not on import.
    ui_main()