File size: 13,563 Bytes
c5982f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import gradio as gr
from utils.check_dataset import validate_dataset, generate_dataset_report
from utils.sample_dataset import generate_sample_datasets
from utils.model import GemmaFineTuning

class GemmaUI:
    """Gradio front-end for fine-tuning Gemma models via ``GemmaFineTuning``.

    Wraps a single ``GemmaFineTuning`` instance and exposes a four-tab
    workflow: dataset upload/preprocessing, hyperparameter configuration,
    training control, and evaluation/export.
    """

    def __init__(self):
        # Backing model wrapper; holds the dataset, trainer and fine-tuned
        # model state that the UI callbacks read and mutate.
        self.model_instance = GemmaFineTuning()
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Build and return the Gradio Blocks application (not launched)."""
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )

                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # NOTE: `visible` must be a bool, not a callable;
                            # initial visibility mirrors the checkbox default and
                            # is kept in sync by the use_lora.change handler below.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=bool(self.default_params["use_lora"])
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=bool(self.default_params["use_lora"])
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)

                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # --- Callback functions -------------------------------------

            def preprocess_data(file, format_type):
                """Load the uploaded file into a dataset and return a text summary."""
                try:
                    if file is None:
                        return "Please upload a file first."

                    # Process the uploaded file
                    dataset = self.model_instance.prepare_dataset(file.name, format_type)
                    self.model_instance.dataset = dataset

                    # Create a summary of the dataset
                    num_samples = len(dataset["train"])

                    # Sample a few examples; prefer a "text" column when present,
                    # otherwise fall back to the first available column.
                    examples = dataset["train"].select(range(min(3, num_samples)))
                    sample_text = []
                    for ex in examples:
                        text_key = list(ex.keys())[0] if "text" not in ex else "text"
                        sample = ex[text_key]
                        if isinstance(sample, str):
                            sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)

                    info = f"Dataset loaded successfully!\n"
                    info += f"Number of training examples: {num_samples}\n"
                    info += f"Sample data:\n" + "\n---\n".join(sample_text)

                    return info
                except Exception as e:
                    return f"Error preprocessing data: {str(e)}"

            def start_training(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            ):
                """Validate inputs and launch training on a background thread."""
                try:
                    if self.model_instance.dataset is None:
                        return "Please preprocess a dataset first."

                    # Validate parameters
                    if not model_name:
                        return "Please select a model."

                    # Prepare training parameters with proper type conversion
                    # (Gradio sliders may deliver floats for integer fields).
                    training_params = {
                        "model_name": str(model_name),
                        "learning_rate": float(learning_rate),
                        "batch_size": int(batch_size),
                        "epochs": int(epochs),
                        "max_length": int(max_length),
                        "use_lora": bool(use_lora),
                        "lora_r": int(lora_r) if use_lora else None,
                        "lora_alpha": int(lora_alpha) if use_lora else None,
                        "eval_ratio": float(eval_ratio),
                        "weight_decay": float(self.default_params["weight_decay"]),
                        "warmup_ratio": float(self.default_params["warmup_ratio"]),
                        "lora_dropout": float(self.default_params["lora_dropout"])
                    }

                    # Run training in a daemon thread so the UI stays responsive
                    # and a running job does not block interpreter shutdown.
                    import threading

                    def train_thread():
                        self.model_instance.train(training_params)

                    thread = threading.Thread(target=train_thread, daemon=True)
                    thread.start()

                    return "Training started! Monitor the progress in the Training tab."
                except Exception as e:
                    return f"Error starting training: {str(e)}"

            def stop_training():
                """Signal the active trainer to stop, if one exists."""
                if self.model_instance.trainer is not None:
                    # Attempt to stop the trainer
                    self.model_instance.trainer.stop_training = True
                    return "Training stop signal sent. It may take a moment to complete the current step."
                return "No active training to stop."

            def update_progress_plot():
                """Return the latest training-progress figure (None on failure)."""
                try:
                    return self.model_instance.plot_training_progress()
                except Exception as e:
                    return None

            def run_text_generation(prompt, max_length):
                """Generate text from the fine-tuned model for a given prompt."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.generate_text(prompt, int(max_length))
                except Exception as e:
                    return f"Error generating text: {str(e)}"

            def export_model_fn(format_type):
                """Export the fine-tuned model in the requested format."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.export_model(format_type)
                except Exception as e:
                    return f"Error exporting model: {str(e)}"

            # --- Wire UI components to callbacks ------------------------

            # Toggle LoRA slider visibility with the checkbox (the original
            # `visible=lambda` form is silently ignored by Gradio).
            use_lora.change(
                lambda flag: (gr.update(visible=bool(flag)), gr.update(visible=bool(flag))),
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )

            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )

            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )

            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status
            )

            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )

            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )

            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status
            )

        return app

if __name__ == '__main__':
    # Build the Gradio app and start serving it.
    interface = GemmaUI().create_ui()
    interface.launch()