import threading

import gradio as gr

from utils.check_dataset import validate_dataset, generate_dataset_report
from utils.sample_dataset import generate_sample_datasets
from utils.model import GemmaFineTuning


class GemmaUI:
    def __init__(self):
        self.model_instance = GemmaFineTuning()
        # Expected keys (all read below): model_name, learning_rate, batch_size,
        # epochs, max_length, use_lora, lora_r, lora_alpha, eval_ratio,
        # weight_decay, warmup_ratio, lora_dropout.
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Create the Gradio interface."""
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv",
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it",
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune",
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer",
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch",
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs",
                            )
                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs",
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage",
                            )
                            # `visible` must be a plain bool, not a callable; the
                            # checkbox toggles these sliders via use_lora.change() below.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=self.default_params["use_lora"],
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=self.default_params["use_lora"],
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation",
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)
                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
Evaluation & Export"): with gr.Row(): with gr.Column(): test_prompt = gr.Textbox( label="Test Prompt", placeholder="Enter a prompt to test the model...", lines=3 ) max_gen_length = gr.Slider( minimum=10, maximum=500, step=10, value=100, label="Max Generation Length" ) generate_button = gr.Button("Generate Text") generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False) with gr.Column(): export_format = gr.Radio( ["pytorch", "tensorflow", "gguf"], label="Export Format", value="pytorch" ) export_button = gr.Button("Export Model") export_status = gr.Textbox(label="Export Status", interactive=False) # Functionality def preprocess_data(file, format_type): try: if file is None: return "Please upload a file first." # Process the uploaded file dataset = self.model_instance.prepare_dataset(file.name, format_type) self.model_instance.dataset = dataset # Create a summary of the dataset num_samples = len(dataset["train"]) # Sample a few examples examples = dataset["train"].select(range(min(3, num_samples))) sample_text = [] for ex in examples: text_key = list(ex.keys())[0] if "text" not in ex else "text" sample = ex[text_key] if isinstance(sample, str): sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample) info = f"Dataset loaded successfully!\n" info += f"Number of training examples: {num_samples}\n" info += f"Sample data:\n" + "\n---\n".join(sample_text) return info except Exception as e: return f"Error preprocessing data: {str(e)}" def start_training( model_name, learning_rate, batch_size, epochs, max_length, use_lora, lora_r, lora_alpha, eval_ratio ): try: if self.model_instance.dataset is None: return "Please preprocess a dataset first." # Validate parameters if not model_name: return "Please select a model." # Prepare training parameters with proper type conversion training_params = { "model_name": str(model_name), "learning_rate": float(learning_rate), "batch_size": int(batch_size), "epochs": int(epochs), "max_length": int(max_length), "use_lora": bool(use_lora), "lora_r": int(lora_r) if use_lora else None, "lora_alpha": int(lora_alpha) if use_lora else None, "eval_ratio": float(eval_ratio), "weight_decay": float(self.default_params["weight_decay"]), "warmup_ratio": float(self.default_params["warmup_ratio"]), "lora_dropout": float(self.default_params["lora_dropout"]) } # Start training in a separate thread import threading def train_thread(): status = self.model_instance.train(training_params) return status thread = threading.Thread(target=train_thread) thread.start() return "Training started! Monitor the progress in the Training tab." except Exception as e: return f"Error starting training: {str(e)}" def stop_training(): if self.model_instance.trainer is not None: # Attempt to stop the trainer self.model_instance.trainer.stop_training = True return "Training stop signal sent. It may take a moment to complete the current step." return "No active training to stop." def update_progress_plot(): try: return self.model_instance.plot_training_progress() except Exception as e: return None def run_text_generation(prompt, max_length): try: if self.model_instance.model is None: return "Please fine-tune a model first." return self.model_instance.generate_text(prompt, int(max_length)) except Exception as e: return f"Error generating text: {str(e)}" def export_model_fn(format_type): try: if self.model_instance.model is None: return "Please fine-tune a model first." 
            # Connect UI components to functions
            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info,
            )

            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio,
                ],
                outputs=training_status,
            )

            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status,
            )

            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot,
            )

            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output,
            )

            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status,
            )

            # Show or hide the LoRA sliders when the checkbox changes; this
            # replaces the non-functional `visible=lambda: ...` kwargs, since
            # component visibility can only be updated through an event handler.
            def toggle_lora_options(enabled):
                return gr.update(visible=enabled), gr.update(visible=enabled)

            use_lora.change(
                toggle_lora_options,
                inputs=[use_lora],
                outputs=[lora_r, lora_alpha],
            )

        return app


if __name__ == "__main__":
    ui = GemmaUI()
    app = ui.create_ui()
    app.launch()
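    # launch() takes the standard Gradio server options when the UI needs to be
    # reachable beyond localhost, for example:
    #     app.launch(server_name="0.0.0.0", server_port=7860, share=True)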