Sebastien De Greef committed on
Commit 909b9b6 · 1 Parent(s): 6baccb3

handle push_to_hub_gguf and inference

Files changed (1)
  1. app.py +57 -21
app.py CHANGED
@@ -92,7 +92,20 @@ def load_data(dataset_name, data_template_style, data_template):
     dataset = dataset.map(lambda examples: formatting_prompts_func(examples, data_template), batched=True)
     return f"Data loaded {len(dataset)} records loaded.", gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True)
 
-
+def inference(prompt, input_text):
+    FastLanguageModel.for_inference(model)  # Enable native 2x faster inference
+    inputs = tokenizer(
+        [
+            prompt.format(
+                "Continue the fibonnaci sequence.",  # instruction
+                "1, 1, 2, 3, 5, 8",  # input
+                "",  # output - leave this blank for generation!
+            )
+        ], return_tensors="pt").to("cuda")
+
+    outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
+    result = tokenizer.batch_decode(outputs)
+    return result[0], gr.update(visible=True, interactive=True)
 
 
 async def train_model(model_name: str, lora_r: int, lora_alpha: int, lora_dropout: float, per_device_train_batch_size: int, warmup_steps: int, max_steps: int,
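Note: the new `inference` helper takes `prompt` and `input_text` but, as committed, formats the hard-coded fibonacci example rather than the text box contents. Below is a minimal sketch of a variant that actually parses `input_text`, assuming the `# instruction` / `# input` markers from the Inference tab's default value (`inference_from_text` and its parsing are illustrative, not part of this commit):

```python
# Illustrative only: a variant of inference() that consumes input_text.
# Relies on the module-level `model` and `tokenizer`, like the rest of app.py;
# the "# instruction" / "# input" markers are an assumed convention taken from
# the default Input Text value in the Inference tab.
def inference_from_text(prompt: str, input_text: str) -> str:
    FastLanguageModel.for_inference(model)  # unsloth's fast-inference mode
    instruction, _, rest = input_text.partition("# instruction")
    input_part = rest.split("# input")[0] if rest else ""
    inputs = tokenizer(
        [prompt.format(instruction.strip(), input_part.strip(), "")],
        return_tensors="pt",
    ).to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
    return tokenizer.batch_decode(outputs)[0]
```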
@@ -143,9 +156,35 @@ async def train_model(model_name: str, lora_r: int, lora_alpha: int, lora_dropou
     trainer.train()
     return "Model training", gr.update(visible=True, interactive=False), gr.update(visible=True, interactive=True), gr.update(interactive=True)
 
-def save_model():
-    return "Model saved", gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=False), gr.update(interactive=False)
-
+def save_model(model_name, hub_model_name, hub_token, gguf_16bit, gguf_8bit, gguf_4bit, gguf_custom, gguf_custom_value, merge_16bit, merge_4bit, just_lora, push_to_hub):
+    global model, tokenizer
+    if gguf_custom:
+        gguf_custom_value = gguf_custom_value
+    else:
+        gguf_custom_value = None
+
+    if gguf_16bit:
+        gguf = "f16"
+    elif gguf_8bit:
+        gguf = "Q8_0"
+    elif gguf_4bit:
+        gguf = "q4_k_m"
+    else:
+        gguf = None
+
+    if merge_16bit:
+        merge = "16bit"
+    elif merge_4bit:
+        merge = "4bit"
+    elif just_lora:
+        merge = "lora"
+    else:
+        merge = None
+
+    # model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")
+    if push_to_hub:
+        model.push_to_hub_gguf(hub_model_name, tokenizer, quantization_method=gguf, token=hub_token)
+    return "Model saved", gr.update(visible=True, interactive=True)
 
 # Create the Gradio interface
 with gr.Blocks() as demo:
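Note: `save_model` computes `merge` ("16bit" / "4bit" / "lora") and `gguf_custom_value`, but only the `push_to_hub` branch with a preset GGUF quantization is acted on. A hedged sketch of how the unused branches could map onto unsloth's saving API (`save_pretrained_merged` and the custom quantization string come from unsloth's documentation, not from this commit):

```python
# Sketch only: wiring up the values save_model computes but does not use.
# push_to_hub_gguf is already called by the commit; save_pretrained_merged is
# unsloth's documented local merged-save API and is an assumption here.
def save_model_sketch(hub_model_name, hub_token, gguf, gguf_custom_value, merge, push_to_hub):
    quant = gguf_custom_value or gguf  # prefer the custom string, e.g. "Q5_K"
    if push_to_hub and quant is not None:
        model.push_to_hub_gguf(hub_model_name, tokenizer, quantization_method=quant, token=hub_token)
    if merge == "16bit":
        model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
    elif merge == "4bit":
        model.save_pretrained_merged("model", tokenizer, save_method="merged_4bit")
    elif merge == "lora":
        model.save_pretrained_merged("model", tokenizer, save_method="lora")
    return "Model saved", gr.update(visible=True, interactive=True)
```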
@@ -171,7 +210,7 @@ with gr.Blocks() as demo:
             dataset_name = gr.Textbox(label="Dataset Name", value="yahma/alpaca-cleaned")
             data_template_style = gr.Dropdown(label="Template", choices=["alpaca","custom"], value="alpaca", allow_custom_value=True)
             with gr.Row():
-                data_tempalte = gr.TextArea(label="Data Template", value="""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+                data_template = gr.TextArea(label="Data Template", value="""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
 ### Instruction:
 {}
@@ -184,7 +223,7 @@ with gr.Blocks() as demo:
             gr.Markdown("---")
             output_load_data = gr.Textbox(label="Data Load Status", value="Data not loaded", interactive=False)
             load_data_btn = gr.Button("Load Dataset", interactive=True)
-            load_data_btn.click(load_data, inputs=[dataset_name, data_template_style, data_tempalte], outputs=[output_load_data, load_data_btn])
+            load_data_btn.click(load_data, inputs=[dataset_name, data_template_style, data_template], outputs=[output_load_data, load_data_btn])
 
         with gr.Tab("Fine-Tuning"):
             gr.Markdown("""### Fine-Tuned Model Parameters""")
@@ -238,18 +277,18 @@
                 with gr.Column():
                     merge_16bit = gr.Checkbox(label="Merge to 16bit", value=False, interactive=True)
                     merge_4bit = gr.Checkbox(label="Merge to 4bit", value=False, interactive=True)
-                    just_lora = gr.Checkbox(label="Just LoRA Adapter", value=False, interactive=True)
+                    just_lora = gr.Checkbox(label="Just LoRA Adapter", value=False, interactive=True)
             gr.Markdown("---")
 
             with gr.Row():
                 gr.Markdown("### GGUF Options")
                 with gr.Column():
-                    merge_16bit = gr.Checkbox(label="Quantize to f16", value=False, interactive=True)
-                    merge_16bit = gr.Checkbox(label="Quantize to 8bit (Q8_0)", value=False, interactive=True)
-                    merge_16bit = gr.Checkbox(label="Quantize to 4bit (q4_k_m)", value=False, interactive=True)
+                    gguf_16bit = gr.Checkbox(label="Quantize to f16", value=False, interactive=True)
+                    gguf_8bit = gr.Checkbox(label="Quantize to 8bit (Q8_0)", value=False, interactive=True)
+                    gguf_4bit = gr.Checkbox(label="Quantize to 4bit (q4_k_m)", value=False, interactive=True)
                 with gr.Column():
-                    merge_custom = gr.Checkbox(label="Custom", value=False, interactive=True)
-                    merge_custom_value = gr.Textbox(label="", value="Q5_K", interactive=True)
+                    gguf_custom = gr.Checkbox(label="Custom", value=False, interactive=True)
+                    gguf_custom_value = gr.Textbox(label="", value="Q5_K", interactive=True)
             gr.Markdown("---")
 
             with gr.Row():
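Note: the checkbox labels correspond directly to llama.cpp quantization method names ("f16", "Q8_0", "q4_k_m", plus a custom string such as "Q5_K"). For reference, the same strings also drive unsloth's local GGUF export, the on-disk counterpart of `push_to_hub_gguf` (assumed from unsloth's docs; this app only calls the hub variant):

```python
# Sketch only: local GGUF export with the same quantization strings the
# checkboxes produce. This app calls push_to_hub_gguf instead.
model.save_pretrained_gguf("model", tokenizer, quantization_method="q4_k_m")
```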
@@ -258,7 +297,6 @@
                 with gr.Column():
                     hub_model_name = gr.Textbox(label="Hub Model Name", value=f"username/model_name", interactive=True)
                     hub_token = gr.Textbox(label="Hub Token", interactive=True, type="password")
-                    ollama_pub_key = gr.Button("HuggingFace Access Token")
             gr.Markdown("---")
 
             with gr.Row():
@@ -270,23 +308,21 @@
                 ollama_model_name = gr.Textbox(label="Ollama Model Name", value="user/model_name")
                 ollama_pub_key = gr.Button("Ollama Pub Key")
             gr.Markdown("---")
-
+            save_button = gr.Button("Save Model", visible=True, interactive=True)
+            save_button.click(save_model, inputs=[model_name, hub_model_name, hub_token, gguf_16bit, gguf_8bit, gguf_4bit, gguf_custom, gguf_custom_value, merge_16bit, merge_4bit, just_lora, push_to_hub], outputs=[save_button])
 
         with gr.Tab("Inference"):
             with gr.Row():
-                gr.Textbox(label="Input Text", lines=4, value="""\
+                input_text = gr.Textbox(label="Input Text", lines=4, value="""\
 Continue the fibonnaci sequence.
 # instruction
 1, 1, 2, 3, 5, 8
 # input
 """, interactive=True)
-                gr.Textbox(label="Output Text", lines=4, value="""\
-                """, interactive=False)
-
-                inference_button = gr.Button("Inference", visible=False, interactive=False)
-                # Output
+                output_text = gr.Textbox(label="Output Text", lines=4, value="", interactive=False)
 
-                # Button click events
+                inference_button = gr.Button("Inference", visible=True, interactive=True)
+                inference_button.click(inference, inputs=[data_template, input_text], outputs=[output_text, inference_button])
     load_btn.click(load_model, inputs=[initial_model_name, load_in_4bit, max_sequence_length], outputs=[output, load_btn, train_btn, initial_model_name, load_in_4bit, max_sequence_length])
 
 demo.launch()
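Note on the wiring pattern used throughout this file: Gradio maps the `inputs` list positionally onto the handler's parameters and the handler's returned tuple onto the `outputs` list, which is why `inference` returns `(text, gr.update(...))` for `outputs=[output_text, inference_button]`. A minimal self-contained illustration with generic names (not from app.py):

```python
import gradio as gr

def greet(name):
    # One input component -> one parameter; two outputs -> two return values.
    return f"Hello, {name}!", gr.update(interactive=False)

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    out = gr.Textbox(label="Greeting")
    btn = gr.Button("Greet")
    btn.click(greet, inputs=[name], outputs=[out, btn])

# demo.launch()
```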
 