Artples committed · verified
Commit f99d488 · 1 Parent(s): ee2b52a

Upload Finetuning_NoteBook(2).ipynb

Files changed (1)
  1. Finetuning_NoteBook(2).ipynb +624 -0
Finetuning_NoteBook(2).ipynb ADDED
@@ -0,0 +1,624 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "25db6bb6",
+ "metadata": {},
+ "source": [
+ "# Installing Required Libraries!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0378e956",
+ "metadata": {},
+ "source": [
+ "Installing required libraries, including trl, transformers, accelerate, peft, datasets, and bitsandbytes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bfdba870",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Checks if PyTorch is installed and installs it if not.\n",
+ "try:\n",
+ "    import torch\n",
+ "    print(\"PyTorch is installed!\")\n",
+ "except ImportError:\n",
+ "    print(\"PyTorch is not installed.\")\n",
+ "    !pip install -q torch\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "538a911b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "!pip install -q --upgrade \"transformers==4.38.2\"\n",
+ "!pip install -q --upgrade \"datasets==2.16.1\"\n",
+ "!pip install -q --upgrade \"accelerate==0.26.1\"\n",
+ "!pip install -q --upgrade \"evaluate==0.4.1\"\n",
+ "!pip install -q --upgrade \"bitsandbytes==0.42.0\"\n",
+ "!pip install -q --upgrade \"trl==0.7.11\"\n",
+ "!pip install -q --upgrade \"peft==0.8.2\"\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cb6eeaf2",
+ "metadata": {},
+ "source": [
+ "## Installing Flash Attention"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cdd64478",
+ "metadata": {},
+ "source": [
+ "Installing Flash Attention to reduce the memory and runtime cost of the attention layer and improve the performance of model training. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main). Building flash-attn from source can take quite a bit of time (several minutes)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9d59ace4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Flash Attention 2 requires an Ampere GPU or newer (compute capability >= 8.0).\n",
+ "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'\n",
+ "\n",
+ "!pip install ninja packaging\n",
+ "!MAX_JOBS=4 pip install -q flash-attn --no-build-isolation --upgrade\n",
+ " "
+ ]
+ },
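+ {
+ "cell_type": "markdown",
+ "id": "a1f00001",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check (an optional sketch, not part of the original generated notebook), we can confirm that `flash_attn` built and imports correctly before loading the model with `attn_implementation='flash_attention_2'`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Optional sanity check: confirm flash-attn built and imports correctly.\n",
+ "try:\n",
+ "    import flash_attn\n",
+ "    print(f\"flash-attn version: {flash_attn.__version__}\")\n",
+ "except ImportError:\n",
+ "    print(\"flash-attn missing; loading with attn_implementation='flash_attention_2' will fail.\")\n"
+ ]
+ },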
+ {
+ "cell_type": "markdown",
+ "id": "f9c1ff52",
+ "metadata": {},
+ "source": [
+ "# Load and Prepare the Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "100e0966",
+ "metadata": {},
+ "source": [
+ "The dataset is already formatted in a conversational format, which is supported by [trl](https://huggingface.co/docs/trl/index/), and ready for supervised finetuning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ca04a539",
+ "metadata": {},
+ "source": [
+ "\n",
+ "**Conversational format:**\n",
+ "\n",
+ "```json\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
+ "{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec40616b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "from datasets import load_dataset\n",
+ " \n",
+ "# Load dataset from the hub\n",
+ "dataset = load_dataset(\"HuggingFaceH4/ultrachat_200k\", split=\"train_sft\")\n",
+ " \n",
+ "dataset = dataset.shuffle(seed=42)\n",
+ " "
+ ]
+ },
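+ {
+ "cell_type": "markdown",
+ "id": "b2f00001",
+ "metadata": {},
+ "source": [
+ "To make the conversational format above concrete, here is a small illustrative cell (an addition, not part of the original generated notebook) that prints the roles and truncated contents of the first record's `messages`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b2f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Illustrative: inspect one training record to see the conversational format.\n",
+ "sample = dataset[0]\n",
+ "for message in sample[\"messages\"]:\n",
+ "    print(f\"{message['role']}: {message['content'][:80]}...\")\n"
+ ]
+ },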
+ {
+ "cell_type": "markdown",
+ "id": "805c2975",
+ "metadata": {},
+ "source": [
+ "# Load **mistralai/Mistral-7B-v0.1** for Finetuning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8248708e",
+ "metadata": {},
+ "source": [
+ "\n",
+ "This process involves two key steps:\n",
+ "\n",
+ "1. **LLM Quantization:**\n",
+ "   - We first load the selected large language model (LLM).\n",
+ "   - We then use the `bitsandbytes` library to quantize the model, which can significantly reduce its memory footprint.\n",
+ "\n",
+ "> **Note:** The memory requirements of the model scale with its size. For instance, a 7B parameter model may require \n",
+ "a 24GB GPU for fine-tuning. \n",
+ "\n",
+ "2. **Chat Model Preparation:**\n",
+ "   - To train a model for chat/conversational tasks, we need to prepare both the model and its tokenizer.\n",
+ "   \n",
+ "   - This involves adding special tokens to the tokenizer and the model itself. These tokens help the model \n",
+ "   understand the different roles within a conversation. \n",
+ "   \n",
+ "   - The **trl** library provides a convenient method called `setup_chat_format` for this purpose. This method performs the \n",
+ "   following actions: \n",
+ "   \n",
+ "     * Adds special tokens to the tokenizer, such as `<|im_start|>` and `<|im_end|>`, to mark the beginning and \n",
+ "     end of a conversation. \n",
+ "   \n",
+ "     * Resizes the model's embedding layer to accommodate the new tokens.\n",
+ "   \n",
+ "     * Sets the tokenizer's chat template, which defines the format used to convert input data into a chat-like \n",
+ "     structure. The default template is `chatml` from OpenAI.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5612b641",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import torch\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
+ "from trl import setup_chat_format\n",
+ "\n",
+ "# Hugging Face model id\n",
+ "model_id = \"mistralai/Mistral-7B-v0.1\"\n",
+ "\n",
+ "# BitsAndBytesConfig: 4-bit NF4 quantization. Note the bnb_4bit_* options\n",
+ "# only take effect with load_in_4bit=True, not load_in_8bit.\n",
+ "bnb_config = BitsAndBytesConfig(\n",
+ "    load_in_4bit=True, bnb_4bit_use_double_quant=True,\n",
+ "    bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16\n",
+ ")\n",
+ "\n",
+ "# Load model and tokenizer\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ "    model_id,\n",
+ "    device_map=\"auto\",\n",
+ "    trust_remote_code=True,\n",
+ "    attn_implementation='flash_attention_2',\n",
+ "    torch_dtype=torch.bfloat16,\n",
+ "    quantization_config=bnb_config\n",
+ ")\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+ "tokenizer.padding_side = \"left\"\n",
+ "\n",
+ "\n",
+ "# Set chat template to OAI chatML\n",
+ "model, tokenizer = setup_chat_format(model, tokenizer)\n",
+ "\n",
+ " "
+ ]
+ },
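+ {
+ "cell_type": "markdown",
+ "id": "c3f00001",
+ "metadata": {},
+ "source": [
+ "To see what `setup_chat_format` actually did, the optional sketch below (an addition; the toy messages are hypothetical) renders a conversation with the tokenizer's new `chatml` template."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c3f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Illustrative: render a toy conversation with the chatml template set above.\n",
+ "messages = [\n",
+ "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+ "    {\"role\": \"user\", \"content\": \"What is LoRA?\"}\n",
+ "]\n",
+ "print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))\n"
+ ]
+ },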
+ {
+ "cell_type": "markdown",
+ "id": "25713c3a",
+ "metadata": {},
+ "source": [
+ "## Setting LoRA Config"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5a990077",
+ "metadata": {},
+ "source": [
+ "The `SFTTrainer` provides native integration with `peft`, simplifying the process of efficiently tuning \n",
+ " Large Language Models (LLMs) using techniques such as [LoRA](\n",
+ " https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms). The only requirement is to create \n",
+ " the `LoraConfig` and pass it to the `SFTTrainer`. \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3aef033e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "from peft import LoraConfig\n",
+ "\n",
+ "peft_config = LoraConfig(\n",
+ "    lora_alpha=8,                 # scaling factor for the LoRA updates\n",
+ "    lora_dropout=0.05,            # dropout applied to the LoRA layers\n",
+ "    r=6,                          # rank of the low-rank update matrices\n",
+ "    bias=\"none\",                  # do not train bias terms\n",
+ "    target_modules=\"all-linear\",  # attach adapters to every linear layer\n",
+ "    task_type=\"CAUSAL_LM\"\n",
+ ")\n",
+ " "
+ ]
+ },
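+ {
+ "cell_type": "markdown",
+ "id": "d4f00001",
+ "metadata": {},
+ "source": [
+ "For a rough sense of why LoRA is cheap, the sketch below (an illustrative addition; the 4096x4096 layer shape is an assumption matching Mistral-7B's hidden size) counts the parameters LoRA adds to a single linear layer: two factors A (r x k) and B (d x r), i.e. r*(d + k) trainable weights versus d*k frozen ones."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Illustrative arithmetic: trainable LoRA weights for one d x k linear layer.\n",
+ "r = peft_config.r   # rank, 6 here\n",
+ "d, k = 4096, 4096   # assumed layer shape (Mistral-7B hidden size)\n",
+ "print(f\"frozen: {d * k:,} weights, LoRA-trainable: {r * (d + k):,} weights\")\n"
+ ]
+ },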
+ {
+ "cell_type": "markdown",
+ "id": "78dc9315",
+ "metadata": {},
+ "source": [
+ "## Setting the TrainingArguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02e9452a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Installing tensorboard to report the metrics\n",
+ "!pip install -q tensorboard\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4cb748d1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "from transformers import TrainingArguments\n",
+ "\n",
+ "args = TrainingArguments(\n",
+ "    output_dir=\"temp_/tmp/model\",\n",
+ "    num_train_epochs=100,\n",
+ "    per_device_train_batch_size=3,\n",
+ "    gradient_accumulation_steps=2,\n",
+ "    gradient_checkpointing=True,\n",
+ "    gradient_checkpointing_kwargs={'use_reentrant': False},\n",
+ "    optim=\"adamw_torch_fused\",\n",
+ "    logging_steps=10,\n",
+ "    save_strategy='epoch',\n",
+ "    learning_rate=0.075,\n",
+ "    bf16=True,\n",
+ "    max_grad_norm=0.3,\n",
+ "    warmup_ratio=0.1,\n",
+ "    lr_scheduler_type='cosine',\n",
+ "    report_to='tensorboard',\n",
+ "    max_steps=-1,\n",
+ "    seed=42,\n",
+ "    overwrite_output_dir=True,\n",
+ "    remove_unused_columns=True\n",
+ ")\n",
+ " "
+ ]
+ },
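+ {
+ "cell_type": "markdown",
+ "id": "e5f00001",
+ "metadata": {},
+ "source": [
+ "One derived quantity worth checking before training (an illustrative addition, assuming a single GPU): the effective batch size per optimizer step is `per_device_train_batch_size * gradient_accumulation_steps`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e5f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Effective batch size per optimizer step on a single GPU: 3 * 2 = 6.\n",
+ "print(args.per_device_train_batch_size * args.gradient_accumulation_steps)\n"
+ ]
+ },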
+ {
+ "cell_type": "markdown",
+ "id": "afad0f24",
+ "metadata": {},
+ "source": [
+ "## Setting the Supervised Finetuning Trainer (`SFTTrainer`)\n",
+ " \n",
+ "The `SFTTrainer` is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods.\n",
+ "The trainer takes care of properly initializing the `PeftModel`. \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4786995f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "from trl import SFTTrainer\n",
+ "\n",
+ "trainer = SFTTrainer(\n",
+ "    model=model,\n",
+ "    args=args,\n",
+ "    train_dataset=dataset,\n",
+ "    peft_config=peft_config,\n",
+ "    max_seq_length=2048,\n",
+ "    tokenizer=tokenizer,\n",
+ "    packing=True,\n",
+ "    dataset_kwargs={'add_special_tokens': False, 'append_concat_token': False}\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5a32f64b",
+ "metadata": {},
+ "source": [
+ "### Starting Training and Saving Model/Tokenizer\n",
+ "\n",
+ "We start training the model by calling the `train()` method on the trainer instance. This will start the training \n",
+ "loop and train the model for `100 epochs`. The model will be automatically saved to the output directory (**'temp_/tmp/model'**);\n",
+ "later cells push it to the hub under **'User//tmp/model'**. \n",
+ " \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1a722966",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "# Disable the KV cache; it is incompatible with gradient checkpointing\n",
+ "model.config.use_cache = False\n",
+ "\n",
+ "# start training\n",
+ "trainer.train()\n",
+ "\n",
+ "# save the peft model\n",
+ "trainer.save_model()\n"
+ ]
+ },
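+ {
+ "cell_type": "markdown",
+ "id": "f6f00001",
+ "metadata": {},
+ "source": [
+ "Since `report_to='tensorboard'` is set, training metrics are written under the output directory (the Trainer default places them in a `runs/` subfolder). This optional cell (an addition, assuming a Jupyter environment with the TensorBoard extension) lets you browse them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f6f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Browse the training metrics logged by the Trainer.\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir temp_/tmp/model\n"
+ ]
+ },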
+ {
+ "cell_type": "markdown",
+ "id": "5d72635c",
+ "metadata": {},
+ "source": [
+ "### Free the GPU Memory to Prepare for Merging the `LoRA` Adapters with the Base Model\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "131b1b16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "# Free the GPU memory\n",
+ "del model\n",
+ "del trainer\n",
+ "torch.cuda.empty_cache()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2ea238ed",
+ "metadata": {},
+ "source": [
+ "## Merging LoRA Adapters into the Original Model\n",
+ "\n",
+ "While utilizing `LoRA`, we focus on training the adapters rather than the entire model. Consequently, during the \n",
+ "model saving process, only the `adapter weights` are preserved, not the complete model. If we wish to save the \n",
+ "entire model for easier usage with Text Generation Inference, we can incorporate the adapter weights into the model \n",
+ "weights. This can be achieved using the `merge_and_unload` method. Following this, the model can be saved using the \n",
+ "`save_pretrained` method. The result is a standalone model that is ready for inference.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0f1dc2a9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import torch\n",
+ "from peft import AutoPeftModelForCausalLM\n",
+ "\n",
+ "# Load Peft model on CPU\n",
+ "model = AutoPeftModelForCausalLM.from_pretrained(\n",
+ "    \"temp_/tmp/model\",\n",
+ "    torch_dtype=torch.float16,\n",
+ "    low_cpu_mem_usage=True\n",
+ ")\n",
+ " \n",
+ "# Merge LoRA with the base model and save\n",
+ "merged_model = model.merge_and_unload()\n",
+ "merged_model.save_pretrained(\"/tmp/model\", safe_serialization=True, max_shard_size=\"2GB\")\n",
+ "tokenizer.save_pretrained(\"/tmp/model\")\n"
+ ]
+ },
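+ {
+ "cell_type": "markdown",
+ "id": "a7f00001",
+ "metadata": {},
+ "source": [
+ "Before pushing, it can be worth smoke-testing the merged checkpoint. This optional sketch (an addition; the prompt is hypothetical, and it assumes enough free GPU or CPU memory to reload the model) runs one generation through a `text-generation` pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a7f00002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Optional smoke test: one short generation from the merged model.\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "pipe = pipeline(\"text-generation\", model=\"/tmp/model\", torch_dtype=torch.float16, device_map=\"auto\")\n",
+ "chat = [{\"role\": \"user\", \"content\": \"Say hello in one sentence.\"}]\n",
+ "prompt = pipe.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)\n",
+ "print(pipe(prompt, max_new_tokens=64)[0][\"generated_text\"])\n"
+ ]
+ },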
+ {
+ "cell_type": "markdown",
+ "id": "41cfdd2c",
+ "metadata": {},
+ "source": [
+ "### Copy all result folders from 'temp_/tmp/model' to '/tmp/model'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a115b861",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import os\n",
+ "import shutil\n",
+ "\n",
+ "source_folder = \"temp_/tmp/model\"\n",
+ "destination_folder = \"/tmp/model\"\n",
+ "os.makedirs(destination_folder, exist_ok=True)\n",
+ "for item in os.listdir(source_folder):\n",
+ "    item_path = os.path.join(source_folder, item)\n",
+ "    if os.path.isdir(item_path):\n",
+ "        destination_path = os.path.join(destination_folder, item)\n",
+ "        shutil.copytree(item_path, destination_path)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "427b8a54",
+ "metadata": {},
+ "source": [
+ "### Generating a model card (README.md)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bb89c11b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "card = '''\n",
+ "---\n",
+ "license: apache-2.0\n",
+ "tags:\n",
+ "- generated_from_trainer\n",
+ "- mistralai/Mistral\n",
+ "- PyTorch\n",
+ "- transformers\n",
+ "- trl\n",
+ "- peft\n",
+ "- tensorboard\n",
+ "base_model: mistralai/Mistral-7B-v0.1\n",
+ "widget:\n",
+ "  - example_title: Pirate!\n",
+ "    messages:\n",
+ "      - role: system\n",
+ "        content: You are a pirate chatbot who always responds with Arr!\n",
+ "      - role: user\n",
+ "        content: \"There's a llama on my lawn, how can I get rid of him?\"\n",
+ "    output:\n",
+ "      text: >-\n",
+ "        Arr! 'Tis a puzzlin' matter, me hearty! A llama on yer lawn be a rare\n",
+ "        sight, but I've got a plan that might help ye get rid of 'im. Ye'll need\n",
+ "        to gather some carrots and hay, and then lure the llama away with the\n",
+ "        promise of a tasty treat. Once he's gone, ye can clean up yer lawn and\n",
+ "        enjoy the peace and quiet once again. But beware, me hearty, for there\n",
+ "        may be more llamas where that one came from! Arr!\n",
+ "model-index:\n",
+ "- name: /tmp/model\n",
+ "  results: []\n",
+ "datasets:\n",
+ "- HuggingFaceH4/ultrachat_200k\n",
+ "language:\n",
+ "- en\n",
+ "pipeline_tag: text-generation\n",
+ "---\n",
+ "\n",
+ "# Model Card for /tmp/model:\n",
+ "\n",
+ "**/tmp/model** is a language model trained to act as a helpful assistant. It is a finetuned version of [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) that was trained using `SFTTrainer` on the publicly available dataset [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k).\n",
+ "\n",
+ "## Training Procedure:\n",
+ "\n",
+ "The training code used to create this model was generated by [Menouar/LLM-FineTuning-Notebook-Generator](https://huggingface.co/spaces/Menouar/LLM-FineTuning-Notebook-Generator).\n",
+ "\n",
+ "## Training hyperparameters\n",
+ "\n",
+ "The following hyperparameters were used during training:\n",
+ "\n",
+ "'''\n",
+ "\n",
+ "with open(\"/tmp/model/README.md\", \"w\") as f:\n",
+ "    f.write(card)\n",
+ "\n",
+ "args_dict = vars(args)\n",
+ "\n",
+ "with open(\"/tmp/model/README.md\", \"a\") as f:\n",
+ "    for k, v in args_dict.items():\n",
+ "        f.write(f\"- {k}: {v}\")\n",
+ "        f.write(\"\\n \\n\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "12c5ab30",
+ "metadata": {},
+ "source": [
+ "## Login to HF"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10117bb9",
+ "metadata": {},
+ "source": [
+ "Replace `HF_TOKEN` with a valid token in order to push **'/tmp/model'** to `huggingface_hub`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8e0697a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Install huggingface_hub\n",
+ "!pip install -q huggingface_hub\n",
+ " \n",
+ "from huggingface_hub import login\n",
+ " \n",
+ "login(\n",
+ "    token='HF_TOKEN',\n",
+ "    add_to_git_credential=True\n",
+ ")\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f176ddac",
+ "metadata": {},
+ "source": [
+ "## Pushing '/tmp/model' to the Hugging Face Hub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a6b3c9f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "from huggingface_hub import HfApi, HfFolder\n",
+ "\n",
+ "# Instantiate the HfApi class\n",
+ "api = HfApi()\n",
+ "\n",
+ "# Our Hugging Face repository\n",
+ "repo_name = \"/tmp/model\"\n",
+ "\n",
+ "# Create a repository on the Hugging Face Hub\n",
+ "repo = api.create_repo(token=HfFolder.get_token(), repo_type=\"model\", repo_id=repo_name)\n",
+ "\n",
+ "api.upload_folder(\n",
+ "    folder_path=\"/tmp/model\",\n",
+ "    repo_id=repo.repo_id\n",
+ ")\n"
+ ]
+ }
+ ],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }