Spaces:
Runtime error
Runtime error
Delete app.py
Browse files
app.py
DELETED
@@ -1,295 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from utils.check_dataset import validate_dataset, generate_dataset_report
|
3 |
-
from utils.sample_dataset import generate_sample_datasets
|
4 |
-
from utils.model import GemmaFineTuning
|
5 |
-
|
6 |
-
class GemmaUI:
    """Gradio front-end for fine-tuning Gemma models.

    Wraps a ``GemmaFineTuning`` backend and exposes dataset upload,
    hyperparameter configuration, training control, text generation,
    and model export through a tabbed Gradio Blocks interface.
    """

    def __init__(self):
        # Backend object that owns the dataset, model, trainer, and the
        # default hyperparameter values shown in the UI widgets.
        self.model_instance = GemmaFineTuning()
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Create the Gradio interface.

        Returns:
            gr.Blocks: the assembled (but not yet launched) application.
        """
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )

                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # FIX: `visible` takes a bool, not a callable — the
                            # original `visible=lambda: use_lora.value` was never
                            # evaluated by Gradio. Start from the default and
                            # toggle via the use_lora.change() handler below.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=bool(self.default_params["use_lora"])
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=bool(self.default_params["use_lora"])
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)

                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # Functionality
            def preprocess_data(file, format_type):
                """Load the uploaded file into the backend and summarise it."""
                try:
                    if file is None:
                        return "Please upload a file first."

                    # Process the uploaded file
                    dataset = self.model_instance.prepare_dataset(file.name, format_type)
                    self.model_instance.dataset = dataset

                    # Create a summary of the dataset
                    num_samples = len(dataset["train"])

                    # Sample a few examples (prefer a "text" column if present)
                    examples = dataset["train"].select(range(min(3, num_samples)))
                    sample_text = []
                    for ex in examples:
                        text_key = list(ex.keys())[0] if "text" not in ex else "text"
                        sample = ex[text_key]
                        if isinstance(sample, str):
                            sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)

                    info = f"Dataset loaded successfully!\n"
                    info += f"Number of training examples: {num_samples}\n"
                    info += f"Sample data:\n" + "\n---\n".join(sample_text)

                    return info
                except Exception as e:
                    return f"Error preprocessing data: {str(e)}"

            def start_training(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            ):
                """Validate inputs and kick off training on a background thread."""
                try:
                    if self.model_instance.dataset is None:
                        return "Please preprocess a dataset first."

                    # Validate parameters
                    if not model_name:
                        return "Please select a model."

                    # Prepare training parameters with proper type conversion
                    # (Gradio may hand back numbers as float/str).
                    training_params = {
                        "model_name": str(model_name),
                        "learning_rate": float(learning_rate),
                        "batch_size": int(batch_size),
                        "epochs": int(epochs),
                        "max_length": int(max_length),
                        "use_lora": bool(use_lora),
                        "lora_r": int(lora_r) if use_lora else None,
                        "lora_alpha": int(lora_alpha) if use_lora else None,
                        "eval_ratio": float(eval_ratio),
                        "weight_decay": float(self.default_params["weight_decay"]),
                        "warmup_ratio": float(self.default_params["warmup_ratio"]),
                        "lora_dropout": float(self.default_params["lora_dropout"])
                    }

                    # FIX: run training on a daemon thread so a long-running
                    # job never blocks interpreter shutdown; the original
                    # non-daemon thread (whose wrapper's return value was
                    # silently discarded) would keep the process alive.
                    import threading
                    thread = threading.Thread(
                        target=self.model_instance.train,
                        args=(training_params,),
                        daemon=True,
                    )
                    thread.start()

                    return "Training started! Monitor the progress in the Training tab."
                except Exception as e:
                    return f"Error starting training: {str(e)}"

            def stop_training():
                """Signal the active trainer to stop after the current step."""
                if self.model_instance.trainer is not None:
                    # Attempt to stop the trainer
                    self.model_instance.trainer.stop_training = True
                    return "Training stop signal sent. It may take a moment to complete the current step."
                return "No active training to stop."

            def update_progress_plot():
                """Return the latest training-progress figure (None on error)."""
                try:
                    return self.model_instance.plot_training_progress()
                except Exception:
                    return None

            def run_text_generation(prompt, max_length):
                """Generate text from the fine-tuned model for a test prompt."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.generate_text(prompt, int(max_length))
                except Exception as e:
                    return f"Error generating text: {str(e)}"

            def export_model_fn(format_type):
                """Export the fine-tuned model in the requested format."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.export_model(format_type)
                except Exception as e:
                    return f"Error exporting model: {str(e)}"

            # Connect UI components to functions
            # Toggle LoRA-specific sliders with the checkbox (replaces the
            # non-functional `visible=lambda` of the original).
            use_lora.change(
                lambda enabled: [gr.update(visible=enabled), gr.update(visible=enabled)],
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )

            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )

            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )

            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status
            )

            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )

            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )

            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status
            )

        return app
291 |
-
|
292 |
-
if __name__ == '__main__':
    # Build the interface and serve it with a public share link.
    interface = GemmaUI().create_ui()
    interface.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|