Spaces:
Runtime error
Runtime error
import gradio as gr | |
from utils.check_dataset import validate_dataset, generate_dataset_report | |
from utils.sample_dataset import generate_sample_datasets | |
from utils.model import GemmaFineTuning | |
class GemmaUI:
    """Gradio front-end for fine-tuning a Gemma model.

    The interface has four tabs (data upload, hyper-parameters, training,
    evaluation/export). All real work is delegated to ``self.model_instance``
    (a ``GemmaFineTuning``); this class only builds widgets and routes events.
    """

    def __init__(self):
        # Backend that prepares datasets, trains, generates text and exports.
        self.model_instance = GemmaFineTuning()
        # Defaults used to seed the controls. Keys read by this class:
        # model_name, learning_rate, batch_size, epochs, max_length, use_lora,
        # lora_r, lora_alpha, eval_ratio, weight_decay, warmup_ratio, lora_dropout.
        self.default_params = self.model_instance.default_params
        # Background thread running the current training job (None until started).
        self._training_thread = None

    # ------------------------------------------------------------------ #
    # Pure helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _upload_path(file):
        """Return the filesystem path for a value produced by ``gr.File``.

        Depending on the Gradio version and the component's ``type`` setting,
        ``gr.File`` yields either a path string or a tempfile-like object that
        exposes the path as ``.name``; both are accepted here.
        """
        if isinstance(file, str):
            return file
        return file.name

    @staticmethod
    def _summarize_dataset(dataset, num_examples=3, preview_chars=100):
        """Build the human-readable summary shown after preprocessing.

        Args:
            dataset: mapping with a ``"train"`` split supporting ``len()`` and
                ``.select(indices)`` (presumably a ``datasets.DatasetDict`` —
                confirm against ``GemmaFineTuning.prepare_dataset``).
            num_examples: how many leading training rows to preview.
            preview_chars: string previews longer than this are truncated
                and suffixed with ``"..."``.

        Returns:
            str: summary with the training-row count and sample previews.
        """
        train = dataset["train"]
        num_samples = len(train)
        previews = []
        for example in train.select(range(min(num_examples, num_samples))):
            # Prefer a "text" column; otherwise fall back to the first column.
            if "text" in example:
                sample = example["text"]
            elif example:
                sample = example[next(iter(example))]
            else:
                continue  # a row with no columns has nothing to preview
            if isinstance(sample, str):
                if len(sample) > preview_chars:
                    sample = sample[:preview_chars] + "..."
                previews.append(sample)

        info = "Dataset loaded successfully!\n"
        info += f"Number of training examples: {num_samples}\n"
        info += "Sample data:\n" + "\n---\n".join(previews)
        return info

    # ------------------------------------------------------------------ #
    # Event callbacks (each returns the text/plot shown in its output widget)
    # ------------------------------------------------------------------ #
    def _preprocess_data(self, file, format_type):
        """Load the uploaded file into ``model_instance.dataset`` and summarise it."""
        if file is None:
            return "Please upload a file first."
        try:
            dataset = self.model_instance.prepare_dataset(
                self._upload_path(file), format_type
            )
            self.model_instance.dataset = dataset
            return self._summarize_dataset(dataset)
        except Exception as e:
            return f"Error preprocessing data: {e}"

    def _start_training(
        self, model_name, learning_rate, batch_size, epochs, max_length,
        use_lora, lora_r, lora_alpha, eval_ratio
    ):
        """Validate inputs and launch ``model_instance.train`` on a background thread.

        Returns immediately with a status message; at most one training thread
        is allowed to run at a time.
        """
        import threading

        try:
            if getattr(self.model_instance, "dataset", None) is None:
                return "Please preprocess a dataset first."
            if not model_name:
                return "Please select a model."
            running = self._training_thread
            if running is not None and running.is_alive():
                return "Training is already running. Stop it or wait for it to finish."

            # Slider values may arrive as floats/strings; coerce explicitly.
            training_params = {
                "model_name": str(model_name),
                "learning_rate": float(learning_rate),
                "batch_size": int(batch_size),
                "epochs": int(epochs),
                "max_length": int(max_length),
                "use_lora": bool(use_lora),
                "lora_r": int(lora_r) if use_lora else None,
                "lora_alpha": int(lora_alpha) if use_lora else None,
                "eval_ratio": float(eval_ratio),
                "weight_decay": float(self.default_params["weight_decay"]),
                "warmup_ratio": float(self.default_params["warmup_ratio"]),
                "lora_dropout": float(self.default_params["lora_dropout"]),
            }

            # Run in the background so the Gradio request returns immediately.
            # The return value of train() is not surfaced in the UI.
            self._training_thread = threading.Thread(
                target=self.model_instance.train, args=(training_params,)
            )
            self._training_thread.start()
            return "Training started! Monitor the progress in the Training tab."
        except Exception as e:
            return f"Error starting training: {e}"

    def _stop_training(self):
        """Ask the running trainer to stop after its current step."""
        trainer = getattr(self.model_instance, "trainer", None)
        running = self._training_thread
        if trainer is None or running is None or not running.is_alive():
            return "No active training to stop."
        # Flag kept from the original implementation; presumably polled by the
        # backend. TODO(review): confirm GemmaFineTuning honours it.
        trainer.stop_training = True
        # NOTE(review): a Hugging Face Trainer ignores the flag above but checks
        # `control.should_training_stop` after each step; set it when present.
        control = getattr(trainer, "control", None)
        if control is not None and hasattr(control, "should_training_stop"):
            control.should_training_stop = True
        return "Training stop signal sent. It may take a moment to complete the current step."

    def _update_progress_plot(self):
        """Return the backend's training-progress figure, or None if unavailable."""
        import logging

        try:
            return self.model_instance.plot_training_progress()
        except Exception:
            # Best-effort refresh: leave the plot empty but keep the reason visible.
            logging.getLogger(__name__).exception("Could not render training progress plot")
            return None

    def _run_text_generation(self, prompt, max_length):
        """Generate up to ``max_length`` tokens from ``prompt`` with the fine-tuned model."""
        if getattr(self.model_instance, "model", None) is None:
            return "Please fine-tune a model first."
        if not prompt or not prompt.strip():
            return "Please enter a prompt."
        try:
            return self.model_instance.generate_text(prompt, int(max_length))
        except Exception as e:
            return f"Error generating text: {e}"

    def _export_model(self, format_type):
        """Export the fine-tuned model in ``format_type`` and return the backend's status."""
        if getattr(self.model_instance, "model", None) is None:
            return "Please fine-tune a model first."
        try:
            return self.model_instance.export_model(format_type)
        except Exception as e:
            return f"Error exporting model: {e}"

    # ------------------------------------------------------------------ #
    # Layout
    # ------------------------------------------------------------------ #
    def create_ui(self):
        """Create the Gradio interface.

        Returns:
            gr.Blocks: the assembled (not yet launched) app.
        """
        defaults = self.default_params
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=defaults["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=defaults["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=defaults["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=defaults["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )
                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=defaults["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=defaults["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # `visible` must be a bool; toggling is wired via use_lora.change below.
                            lora_visible = bool(defaults["use_lora"])
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=defaults["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=lora_visible
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=defaults["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=lora_visible
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=defaults["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)
                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)
                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # Event wiring must happen inside the Blocks context.
            use_lora.change(
                lambda enabled: (gr.update(visible=enabled), gr.update(visible=enabled)),
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )
            preprocess_button.click(
                self._preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )
            start_training_button.click(
                self._start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )
            stop_training_button.click(
                self._stop_training,
                inputs=[],
                outputs=training_status
            )
            refresh_plot_button.click(
                self._update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )
            generate_button.click(
                self._run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )
            export_button.click(
                self._export_model,
                inputs=[export_format],
                outputs=export_status
            )
        return app
if __name__ == "__main__":
    # Script entry point: build the interface and serve it with default launch options.
    app = GemmaUI().create_ui()
    app.launch()