"""The UI file for the SynthGenAI package.""" import os import asyncio from huggingface_hub import HfFolder import gradio as gr from synthgenai import DatasetConfig, DatasetGeneratorConfig, LLMConfig, InstructionDatasetGenerator, PreferenceDatasetGenerator,RawDatasetGenerator,SentimentAnalysisDatasetGenerator, SummarizationDatasetGenerator, TextClassificationDatasetGenerator def validate_inputs(*args): """ Validate that all required inputs are filled. Args: *args: The input values to validate. Returns: bool: True if all required inputs are filled, False otherwise. """ for arg in args: if not arg: return False return True stop_event = asyncio.Event() def stop_generation(): """ Stop the dataset generation process. """ stop_event.set() def get_hf_token(): """ Retrieve the Hugging Face token from the huggingface_hub. Returns: str: The Hugging Face token. """ token = HfFolder.get_token() if not token: raise ValueError("Hugging Face token not found. Please login using the LoginButton.") return token def generate_synthetic_dataset( llm_model, temperature, top_p, max_tokens, dataset_type, topic, domains, language, additional_description, num_entries, hf_repo_name, llm_env_vars, ): """ Generate a dataset based on the provided parameters. Args: llm_model (str): The LLM model to use. temperature (float): The temperature for the LLM. top_p (float): The top_p value for the LLM. max_tokens (int): The maximum number of tokens for the LLM. dataset_type (str): The type of dataset to generate. topic (str): The topic of the dataset. domains (str): The domains for the dataset. language (str): The language of the dataset. additional_description (str): Additional description for the dataset. num_entries (int): The number of entries in the dataset. hf_repo_name (str): The Hugging Face repository name. llm_env_vars (str): Comma-separated environment variables for the LLM. Returns: str: A message indicating the result of the dataset generation. """ hf_token = get_hf_token() os.environ["HF_TOKEN"] = hf_token for var in llm_env_vars.split(","): if "=" in var: key, value = var.split("=", 1) os.environ[key.strip()] = value.strip() # Validate inputs if not validate_inputs( llm_model, temperature, top_p, max_tokens, dataset_type, topic, domains, language, num_entries, hf_repo_name, llm_env_vars, ): return "All fields except API Base and API Key must be filled." llm_config = LLMConfig( model=llm_model, temperature=temperature, top_p=top_p, max_tokens=max_tokens, ) dataset_config = DatasetConfig( topic=topic, domains=domains.split(","), language=language, additional_description=additional_description, num_entries=num_entries, ) dataset_generator_config = DatasetGeneratorConfig( llm_config=llm_config, dataset_config=dataset_config, ) if dataset_type == "Raw": generator = RawDatasetGenerator(dataset_generator_config) elif dataset_type == "Instruction": generator = InstructionDatasetGenerator(dataset_generator_config) elif dataset_type == "Preference": generator = PreferenceDatasetGenerator(dataset_generator_config) elif dataset_type == "Sentiment Analysis": generator = SentimentAnalysisDatasetGenerator(dataset_generator_config) elif dataset_type == "Summarization": generator = SummarizationDatasetGenerator(dataset_generator_config) elif dataset_type == "Text Classification": generator = TextClassificationDatasetGenerator(dataset_generator_config) else: return "Invalid dataset type" stop_event.clear() async def generate(): if stop_event.is_set(): return "Dataset generation stopped." dataset = await generator.agenerate_dataset() if stop_event.is_set(): return "Dataset generation stopped." dataset.save_dataset(hf_repo_name=hf_repo_name) return "Dataset generated and saved successfully." return asyncio.run(generate()) def ui_main(): """ Launch the Gradio UI for the SynthGenAI dataset generator. """ with gr.Blocks( title="SynthGenAI Dataset Generator", css=""" .gradio-container .gr-block { margin-bottom: 10px; } """, theme="ParityError/Interstellar", ) as demo: gr.Markdown( """