codewithdark committed on
Commit
c5982f9
·
verified ·
1 Parent(s): bc3f1ca

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +295 -0
main.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils.check_dataset import validate_dataset, generate_dataset_report
3
+ from utils.sample_dataset import generate_sample_datasets
4
+ from utils.model import GemmaFineTuning
5
+
6
class GemmaUI:
    """Gradio front-end for fine-tuning Gemma models.

    Wraps a single :class:`GemmaFineTuning` instance and exposes dataset
    upload/preprocessing, hyperparameter configuration, training control,
    text-generation evaluation, and model export through a tabbed
    ``gr.Blocks`` interface.
    """

    def __init__(self):
        # Backing fine-tuning engine; it also supplies the default
        # hyperparameter values used to seed the UI widgets.
        self.model_instance = GemmaFineTuning()
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Build and return the Gradio Blocks app (call ``.launch()`` on it)."""
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                # --- Tab 1: dataset upload & preprocessing -------------------
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                # --- Tab 2: model choice & hyperparameters -------------------
                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )

                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # BUG FIX: `visible=` expects a bool, not a callable.
                            # The original passed `lambda: use_lora.value`, which is
                            # always truthy, so the sliders never tracked the
                            # checkbox. Visibility is now driven by a
                            # `use_lora.change` event wired up below.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=bool(self.default_params["use_lora"])
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=bool(self.default_params["use_lora"])
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                # --- Tab 3: training control ---------------------------------
                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)

                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                # --- Tab 4: evaluation & export ------------------------------
                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # Functionality
            def preprocess_data(file, format_type):
                """Load the uploaded file via the model wrapper and summarize it."""
                try:
                    if file is None:
                        return "Please upload a file first."

                    # Process the uploaded file
                    dataset = self.model_instance.prepare_dataset(file.name, format_type)
                    self.model_instance.dataset = dataset

                    # Create a summary of the dataset
                    num_samples = len(dataset["train"])

                    # Sample a few examples (truncate long texts for display)
                    examples = dataset["train"].select(range(min(3, num_samples)))
                    sample_text = []
                    for ex in examples:
                        # Prefer a "text" column; otherwise fall back to the first key.
                        text_key = list(ex.keys())[0] if "text" not in ex else "text"
                        sample = ex[text_key]
                        if isinstance(sample, str):
                            sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)

                    info = "Dataset loaded successfully!\n"
                    info += f"Number of training examples: {num_samples}\n"
                    info += "Sample data:\n" + "\n---\n".join(sample_text)

                    return info
                except Exception as e:
                    return f"Error preprocessing data: {str(e)}"

            def start_training(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            ):
                """Validate inputs and launch training on a background thread."""
                try:
                    if self.model_instance.dataset is None:
                        return "Please preprocess a dataset first."

                    # Validate parameters
                    if not model_name:
                        return "Please select a model."

                    # Prepare training parameters with proper type conversion
                    # (Gradio sliders may deliver floats for integer fields).
                    training_params = {
                        "model_name": str(model_name),
                        "learning_rate": float(learning_rate),
                        "batch_size": int(batch_size),
                        "epochs": int(epochs),
                        "max_length": int(max_length),
                        "use_lora": bool(use_lora),
                        "lora_r": int(lora_r) if use_lora else None,
                        "lora_alpha": int(lora_alpha) if use_lora else None,
                        "eval_ratio": float(eval_ratio),
                        "weight_decay": float(self.default_params["weight_decay"]),
                        "warmup_ratio": float(self.default_params["warmup_ratio"]),
                        "lora_dropout": float(self.default_params["lora_dropout"])
                    }

                    # Start training in a separate thread so the UI stays
                    # responsive; daemon=True so a hung run cannot keep the
                    # process alive after the app exits.
                    import threading

                    def train_thread():
                        self.model_instance.train(training_params)

                    thread = threading.Thread(target=train_thread, daemon=True)
                    thread.start()

                    return "Training started! Monitor the progress in the Training tab."
                except Exception as e:
                    return f"Error starting training: {str(e)}"

            def stop_training():
                """Signal the active trainer (if any) to stop after the current step."""
                if self.model_instance.trainer is not None:
                    # Attempt to stop the trainer
                    self.model_instance.trainer.stop_training = True
                    return "Training stop signal sent. It may take a moment to complete the current step."
                return "No active training to stop."

            def update_progress_plot():
                """Return the latest training-progress figure (None on failure)."""
                try:
                    return self.model_instance.plot_training_progress()
                except Exception:
                    return None

            def run_text_generation(prompt, max_length):
                """Generate text from the fine-tuned model for a test prompt."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.generate_text(prompt, int(max_length))
                except Exception as e:
                    return f"Error generating text: {str(e)}"

            def export_model_fn(format_type):
                """Export the fine-tuned model in the chosen format."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.export_model(format_type)
                except Exception as e:
                    return f"Error exporting model: {str(e)}"

            def toggle_lora_visibility(enabled):
                # Show/hide the LoRA sliders to match the checkbox state.
                return gr.update(visible=enabled), gr.update(visible=enabled)

            # Connect UI components to functions
            use_lora.change(
                toggle_lora_visibility,
                inputs=[use_lora],
                outputs=[lora_r, lora_alpha]
            )

            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )

            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )

            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status
            )

            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )

            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )

            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status
            )

        return app
291
+
292
if __name__ == '__main__':
    # Build the Gradio interface and serve it.
    GemmaUI().create_ui().launch()