codewithdark committed on
Commit
bc3f1ca
·
verified ·
1 Parent(s): 183d0e9

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -295
app.py DELETED
@@ -1,295 +0,0 @@
1
- import gradio as gr
2
- from utils.check_dataset import validate_dataset, generate_dataset_report
3
- from utils.sample_dataset import generate_sample_datasets
4
- from utils.model import GemmaFineTuning
5
-
6
class GemmaUI:
    """Gradio front-end for configuring, running, and exporting a Gemma fine-tune.

    Wraps a single ``GemmaFineTuning`` instance; every tab callback reads and
    mutates that shared instance (its ``dataset``, ``trainer`` and ``model``
    attributes — see the callbacks below for which attribute each one touches).
    """

    def __init__(self):
        # One shared model/trainer holder for all UI callbacks.
        self.model_instance = GemmaFineTuning()
        # Defaults come from the model wrapper so the UI and backend stay in sync.
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Create the Gradio interface.

        Returns:
            The assembled ``gr.Blocks`` app. The caller is responsible for
            calling ``.launch()`` on it.
        """
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )

                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # BUG FIX: the original passed ``visible=lambda: use_lora.value``.
                            # Gradio's ``visible`` takes a plain bool, so a lambda is always
                            # truthy and the sliders were permanently visible. Start from the
                            # checkbox default and toggle via the ``use_lora.change`` listener
                            # wired up below.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=bool(self.default_params["use_lora"])
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=bool(self.default_params["use_lora"])
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)

                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # ---------------- Callbacks ----------------

            def preprocess_data(file, format_type):
                """Load the uploaded file into ``self.model_instance.dataset``.

                Returns a human-readable summary string (or an error message —
                errors are reported in the UI rather than raised).
                """
                try:
                    if file is None:
                        return "Please upload a file first."

                    # Process the uploaded file.
                    dataset = self.model_instance.prepare_dataset(file.name, format_type)
                    self.model_instance.dataset = dataset

                    num_samples = len(dataset["train"])

                    # Preview up to three examples, truncated to 100 chars each.
                    examples = dataset["train"].select(range(min(3, num_samples)))
                    sample_text = []
                    for ex in examples:
                        # Prefer a "text" column; otherwise fall back to the first key.
                        text_key = "text" if "text" in ex else list(ex.keys())[0]
                        sample = ex[text_key]
                        if isinstance(sample, str):
                            sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)

                    info = "Dataset loaded successfully!\n"
                    info += f"Number of training examples: {num_samples}\n"
                    info += "Sample data:\n" + "\n---\n".join(sample_text)

                    return info
                except Exception as e:
                    return f"Error preprocessing data: {str(e)}"

            def start_training(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            ):
                """Validate parameters and launch training on a background thread."""
                try:
                    if self.model_instance.dataset is None:
                        return "Please preprocess a dataset first."

                    if not model_name:
                        return "Please select a model."

                    # Gradio may hand back numbers as floats/strings; coerce
                    # everything to the types the trainer expects.
                    training_params = {
                        "model_name": str(model_name),
                        "learning_rate": float(learning_rate),
                        "batch_size": int(batch_size),
                        "epochs": int(epochs),
                        "max_length": int(max_length),
                        "use_lora": bool(use_lora),
                        "lora_r": int(lora_r) if use_lora else None,
                        "lora_alpha": int(lora_alpha) if use_lora else None,
                        "eval_ratio": float(eval_ratio),
                        # Not exposed in the UI; taken from the model defaults.
                        "weight_decay": float(self.default_params["weight_decay"]),
                        "warmup_ratio": float(self.default_params["warmup_ratio"]),
                        "lora_dropout": float(self.default_params["lora_dropout"])
                    }

                    import threading

                    def train_thread():
                        # Return value is discarded; progress is surfaced through
                        # the model instance (Training tab plot / stop button).
                        self.model_instance.train(training_params)

                    # daemon=True so a stuck training run cannot block process exit.
                    thread = threading.Thread(target=train_thread, daemon=True)
                    thread.start()

                    return "Training started! Monitor the progress in the Training tab."
                except Exception as e:
                    return f"Error starting training: {str(e)}"

            def stop_training():
                """Ask the active trainer (if any) to stop after the current step."""
                if self.model_instance.trainer is not None:
                    self.model_instance.trainer.stop_training = True
                    return "Training stop signal sent. It may take a moment to complete the current step."
                return "No active training to stop."

            def update_progress_plot():
                """Return the latest training-progress figure, or None on failure."""
                try:
                    return self.model_instance.plot_training_progress()
                except Exception:
                    # A missing/partial plot is not fatal; just show nothing.
                    return None

            def run_text_generation(prompt, max_length):
                """Generate text from the fine-tuned model for the given prompt."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.generate_text(prompt, int(max_length))
                except Exception as e:
                    return f"Error generating text: {str(e)}"

            def export_model_fn(format_type):
                """Export the fine-tuned model in the requested format."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.export_model(format_type)
                except Exception as e:
                    return f"Error exporting model: {str(e)}"

            # ---------------- Wiring ----------------

            # Show/hide the LoRA sliders when the checkbox is toggled
            # (replaces the broken ``visible=lambda`` from the original).
            use_lora.change(
                lambda enabled: [gr.update(visible=enabled), gr.update(visible=enabled)],
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )

            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )

            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )

            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status
            )

            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )

            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )

            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status
            )

        return app
291
-
292
def _main() -> None:
    """Build the Gradio interface and serve it with a public share link."""
    interface = GemmaUI().create_ui()
    interface.launch(share=True)


if __name__ == '__main__':
    _main()