Spaces:
Runtime error
Runtime error
Upload main.py
Browse files
main.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.check_dataset import validate_dataset, generate_dataset_report
|
3 |
+
from utils.sample_dataset import generate_sample_datasets
|
4 |
+
from utils.model import GemmaFineTuning
|
5 |
+
|
6 |
+
class GemmaUI:
    """Gradio front end for fine-tuning a Gemma model.

    The class owns a ``GemmaFineTuning`` backend and exposes ``create_ui``,
    which builds a four-tab ``gr.Blocks`` application (data upload,
    hyperparameters, training, evaluation/export). Event handlers live in
    private methods so the layout code only does wiring.
    """

    def __init__(self):
        """Create the fine-tuning backend and cache its default parameters."""
        self.model_instance = GemmaFineTuning()
        self.default_params = self.model_instance.default_params

    def create_ui(self) -> "gr.Blocks":
        """Build the Gradio interface and wire its events.

        Returns:
            The assembled ``gr.Blocks`` app; the caller is expected to
            ``launch()`` it.
        """
        # Initial visibility of the LoRA-specific sliders must be a plain
        # bool; it is kept in sync with the checkbox by the change event
        # registered further down.
        lora_enabled = bool(self.default_params["use_lora"])

        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )

                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=lora_enabled
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=lora_enabled
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)

                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # Connect UI components to their handlers. Event registration
            # has to happen while the Blocks context is still open.
            use_lora.change(
                self._toggle_lora_controls,
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )

            preprocess_button.click(
                self._preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )

            start_training_button.click(
                self._start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )

            stop_training_button.click(
                self._stop_training,
                inputs=[],
                outputs=training_status
            )

            refresh_plot_button.click(
                self._update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )

            generate_button.click(
                self._run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )

            export_button.click(
                self._export_model,
                inputs=[export_format],
                outputs=export_status
            )

        return app

    @staticmethod
    def _toggle_lora_controls(enabled):
        """Show or hide the two LoRA sliders to match the checkbox state.

        Returns one ``gr.update`` per output component (rank, alpha).
        """
        shown = bool(enabled)
        return gr.update(visible=shown), gr.update(visible=shown)

    @staticmethod
    def _preview_example(example, limit=100):
        """Return a short preview of one dataset row.

        The ``"text"`` field is preferred; otherwise the row's first key is
        used. Values longer than ``limit`` characters are cut and suffixed
        with ``"..."``. Returns ``None`` when the row is empty or the chosen
        value is not a string, so the caller can skip it.
        """
        if not example:
            return None
        key = "text" if "text" in example else next(iter(example))
        value = example[key]
        if not isinstance(value, str):
            return None
        return value[:limit] + "..." if len(value) > limit else value

    def _build_training_params(
        self, model_name, learning_rate, batch_size, epochs, max_length,
        use_lora, lora_r, lora_alpha, eval_ratio
    ):
        """Coerce raw widget values into the dict handed to the backend.

        Slider values are converted explicitly to ``int``/``float``. The LoRA
        rank and alpha are ``None`` when LoRA is disabled. Weight decay,
        warmup ratio and LoRA dropout are not exposed in the UI and are
        taken from ``self.default_params``.
        """
        return {
            "model_name": str(model_name),
            "learning_rate": float(learning_rate),
            "batch_size": int(batch_size),
            "epochs": int(epochs),
            "max_length": int(max_length),
            "use_lora": bool(use_lora),
            "lora_r": int(lora_r) if use_lora else None,
            "lora_alpha": int(lora_alpha) if use_lora else None,
            "eval_ratio": float(eval_ratio),
            "weight_decay": float(self.default_params["weight_decay"]),
            "warmup_ratio": float(self.default_params["warmup_ratio"]),
            "lora_dropout": float(self.default_params["lora_dropout"]),
        }

    def _preprocess_data(self, file, format_type):
        """Load the uploaded dataset and return a human-readable summary.

        Args:
            file: Value of the upload widget; either an object with a
                ``name`` attribute holding the path or the path itself.
            format_type: One of the format radio choices.

        Returns:
            A status string: a prompt to upload a file, a dataset summary
            with up to three sample rows, or an error message.
        """
        try:
            if file is None:
                return "Please upload a file first."

            # Accept both file-like upload objects and plain path strings.
            path = getattr(file, "name", file)
            dataset = self.model_instance.prepare_dataset(path, format_type)
            self.model_instance.dataset = dataset

            train_split = dataset["train"]
            num_samples = len(train_split)

            # Preview at most the first three training rows.
            previews = []
            for row in train_split.select(range(min(3, num_samples))):
                preview = self._preview_example(row)
                if preview is not None:
                    previews.append(preview)

            info = "Dataset loaded successfully!\n"
            info += f"Number of training examples: {num_samples}\n"
            info += "Sample data:\n" + "\n---\n".join(previews)
            return info
        except Exception as e:
            # UI boundary: report the failure in the status box instead of
            # letting the event handler crash.
            return f"Error preprocessing data: {str(e)}"

    def _start_training(
        self, model_name, learning_rate, batch_size, epochs, max_length,
        use_lora, lora_r, lora_alpha, eval_ratio
    ):
        """Validate inputs and start fine-tuning on a background thread.

        Returns:
            A status string. The call returns as soon as the thread has
            started; the value returned by the backend's ``train`` is not
            reported here.
        """
        import threading

        try:
            if self.model_instance.dataset is None:
                return "Please preprocess a dataset first."

            if not model_name:
                return "Please select a model."

            training_params = self._build_training_params(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            )

            # Run training off the request thread so the UI stays responsive.
            worker = threading.Thread(
                target=self.model_instance.train, args=(training_params,)
            )
            worker.start()

            return "Training started! Monitor the progress in the Training tab."
        except Exception as e:
            return f"Error starting training: {str(e)}"

    def _stop_training(self):
        """Ask the active trainer, if any, to stop and return a status string."""
        if self.model_instance.trainer is not None:
            # NOTE(review): this only sets a flag on the trainer object;
            # whether the trainer actually checks `stop_training` depends on
            # utils.model, which is not visible here - confirm.
            self.model_instance.trainer.stop_training = True
            return "Training stop signal sent. It may take a moment to complete the current step."
        return "No active training to stop."

    def _update_progress_plot(self):
        """Return the backend's training-progress figure, or ``None`` on failure."""
        try:
            return self.model_instance.plot_training_progress()
        except Exception:
            # Deliberate best effort: an empty plot is preferable to an
            # error popup when no progress data is available yet.
            return None

    def _run_text_generation(self, prompt, max_length):
        """Generate text from ``prompt`` with the fine-tuned model.

        Returns the generated text, a hint to fine-tune first when no model
        is loaded, or an error message.
        """
        try:
            if self.model_instance.model is None:
                return "Please fine-tune a model first."

            return self.model_instance.generate_text(prompt, int(max_length))
        except Exception as e:
            return f"Error generating text: {str(e)}"

    def _export_model(self, format_type):
        """Export the fine-tuned model in ``format_type`` and return a status string."""
        try:
            if self.model_instance.model is None:
                return "Please fine-tune a model first."

            return self.model_instance.export_model(format_type)
        except Exception as e:
            return f"Error exporting model: {str(e)}"
|
291 |
+
|
292 |
+
if __name__ == "__main__":
    # Script entry point: build the interface and start the Gradio server.
    ui = GemmaUI()
    app = ui.create_ui()
    app.launch()
|