# Gemma_Finetuner / app.py
# Hugging Face Space by codewithdark — "Rename main.py to app.py" (commit ef47783, verified)
import gradio as gr
from utils.check_dataset import validate_dataset, generate_dataset_report
from utils.sample_dataset import generate_sample_datasets
from utils.model import GemmaFineTuning
class GemmaUI:
    """Gradio front-end for fine-tuning Gemma models.

    Wraps a single :class:`GemmaFineTuning` backend instance whose dataset,
    model and trainer state is shared by every UI callback, and builds a
    tabbed interface covering preprocessing, hyperparameters, training and
    evaluation/export.
    """

    def __init__(self):
        # One backend instance is shared by all callbacks so the
        # preprocessing, training and evaluation tabs see the same state.
        self.model_instance = GemmaFineTuning()
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Create the Gradio interface and return the Blocks app."""
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")
            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)
                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )
                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # BUG FIX: `visible` expects a bool, not a callable
                            # (`visible=lambda: ...` is always truthy, so the
                            # sliders never toggled). Initial visibility follows
                            # the checkbox default; the `use_lora.change` event
                            # below keeps it in sync.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=bool(self.default_params["use_lora"])
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=bool(self.default_params["use_lora"])
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )
                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)
                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")
                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)
                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # ---- Callbacks ---------------------------------------------

            def preprocess_data(file, format_type):
                """Load the uploaded file and return a short dataset summary.

                Returns a human-readable status string (shown in the
                "Dataset Information" box), including up to three truncated
                sample rows, or an error message on failure.
                """
                try:
                    if file is None:
                        return "Please upload a file first."
                    # Parse the uploaded file and keep it on the backend so
                    # the training tab can reach it later.
                    dataset = self.model_instance.prepare_dataset(file.name, format_type)
                    self.model_instance.dataset = dataset
                    num_samples = len(dataset["train"])
                    # Preview at most three examples, truncated to 100 chars.
                    examples = dataset["train"].select(range(min(3, num_samples)))
                    sample_text = []
                    for ex in examples:
                        # Prefer a "text" column; otherwise fall back to the
                        # example's first column.
                        text_key = "text" if "text" in ex else next(iter(ex))
                        sample = ex[text_key]
                        if isinstance(sample, str):
                            sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)
                    info = "Dataset loaded successfully!\n"
                    info += f"Number of training examples: {num_samples}\n"
                    info += "Sample data:\n" + "\n---\n".join(sample_text)
                    return info
                except Exception as e:
                    return f"Error preprocessing data: {str(e)}"

            def start_training(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            ):
                """Validate inputs and launch training on a background thread."""
                try:
                    if self.model_instance.dataset is None:
                        return "Please preprocess a dataset first."
                    if not model_name:
                        return "Please select a model."
                    # Gradio delivers numeric widget values as floats; coerce
                    # each to the type the trainer expects.
                    training_params = {
                        "model_name": str(model_name),
                        "learning_rate": float(learning_rate),
                        "batch_size": int(batch_size),
                        "epochs": int(epochs),
                        "max_length": int(max_length),
                        "use_lora": bool(use_lora),
                        "lora_r": int(lora_r) if use_lora else None,
                        "lora_alpha": int(lora_alpha) if use_lora else None,
                        "eval_ratio": float(eval_ratio),
                        "weight_decay": float(self.default_params["weight_decay"]),
                        "warmup_ratio": float(self.default_params["warmup_ratio"]),
                        "lora_dropout": float(self.default_params["lora_dropout"])
                    }
                    # Train off the request thread so the UI stays responsive.
                    # daemon=True so a hung training run cannot block
                    # interpreter shutdown.
                    import threading
                    thread = threading.Thread(
                        target=self.model_instance.train,
                        args=(training_params,),
                        daemon=True
                    )
                    thread.start()
                    return "Training started! Monitor the progress in the Training tab."
                except Exception as e:
                    return f"Error starting training: {str(e)}"

            def stop_training():
                """Signal the active trainer (if any) to stop after the current step."""
                if self.model_instance.trainer is not None:
                    self.model_instance.trainer.stop_training = True
                    return "Training stop signal sent. It may take a moment to complete the current step."
                return "No active training to stop."

            def update_progress_plot():
                """Return the latest training-progress figure, or None if unavailable."""
                try:
                    return self.model_instance.plot_training_progress()
                except Exception:
                    # No metrics yet (e.g. before the first logged step).
                    return None

            def run_text_generation(prompt, max_length):
                """Generate text from the fine-tuned model for a test prompt."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."
                    return self.model_instance.generate_text(prompt, int(max_length))
                except Exception as e:
                    return f"Error generating text: {str(e)}"

            def export_model_fn(format_type):
                """Export the fine-tuned model in the requested format."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."
                    return self.model_instance.export_model(format_type)
                except Exception as e:
                    return f"Error exporting model: {str(e)}"

            # ---- Wiring ------------------------------------------------

            # Show/hide the LoRA hyperparameter sliders with the checkbox.
            use_lora.change(
                lambda enabled: (gr.update(visible=enabled), gr.update(visible=enabled)),
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )
            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )
            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )
            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status
            )
            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )
            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )
            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status
            )
        return app
def main() -> None:
    """Build the Gradio interface and launch the web server."""
    interface = GemmaUI().create_ui()
    interface.launch()


if __name__ == "__main__":
    main()