menouar commited on
Commit
592b663
·
1 Parent(s): bc9310c

Update the generated Notebook to log in properly to HF

Browse files
Files changed (2) hide show
  1. app.py +13 -5
  2. utils/notebook_generator.py +14 -7
app.py CHANGED
@@ -2,7 +2,6 @@ from typing import Any
2
 
3
  from nbconvert import HTMLExporter
4
 
5
-
6
  from utils.notebook_generator import *
7
  from utils.components_creator import *
8
 
@@ -44,6 +43,10 @@ def centered_column():
44
  return gr.Column(elem_classes=["container"])
45
 
46
 
 
 
 
 
47
  def change_model_selection(model_id):
48
  if model_id == gemma.name:
49
  gr.Warning("""
@@ -107,10 +110,6 @@ def generate_code(components: dict[Component, Any]):
107
  if flash_attention_value:
108
  create_install_flash_attention(notebook['cells'])
109
 
110
- push_to_hub = get_value(components, PUSH_TO_HUB_ID)
111
- if push_to_hub:
112
- create_login_hf_cells(notebook['cells'])
113
-
114
  dataset_value = get_value(components, DATASET_SELECTION_ID)
115
  seed_value = get_value(components, DATASET_SHUFFLING_SEED)
116
  if not check_valid_input(dataset_value):
@@ -119,6 +118,8 @@ def generate_code(components: dict[Component, Any]):
119
  create_datasets_cells(notebook['cells'], get_dataset(dataset_value), seed_value)
120
 
121
  model_value = get_value(components, MODEL_SELECTION_ID)
 
 
122
  if not check_valid_input(model_value):
123
  gr.Warning("No model is selected!")
124
  else:
@@ -126,6 +127,9 @@ def generate_code(components: dict[Component, Any]):
126
  if not check_valid_input(version_value):
127
  gr.Warning("No version of the model is selected")
128
  else:
 
 
 
129
  load_in_4bit = get_value(components, LOAD_IN_4_BIT_ID)
130
  bnb_4bit_use_double_quant = get_value(components, BNB_4BIT_USE_DOUBLE_QUANT)
131
  bnb_4bit_quant_type = get_value(components, BNB_4BIT_QUANT_TYPE)
@@ -174,6 +178,8 @@ def generate_code(components: dict[Component, Any]):
174
  packing = get_value(components, PACKING_ID)
175
  create_sft_trainer_cells(notebook['cells'], max_seq_length, packing)
176
 
 
 
177
  create_start_training_cells(notebook['cells'], epochs, max_steps, push_to_hub, output_dir)
178
 
179
  create_free_gpu_cells(notebook['cells'])
@@ -181,6 +187,8 @@ def generate_code(components: dict[Component, Any]):
181
  create_merge_lora_cells(notebook['cells'], output_dir)
182
 
183
  if push_to_hub:
 
 
184
  push_merged_model_cells(notebook['cells'], output_dir)
185
 
186
  file_name = f"{finetuning_notebook}.ipynb"
 
2
 
3
  from nbconvert import HTMLExporter
4
 
 
5
  from utils.notebook_generator import *
6
  from utils.components_creator import *
7
 
 
43
  return gr.Column(elem_classes=["container"])
44
 
45
 
46
+ def should_login_to_hf_model(model_id: str):
47
+ return model_id == gemma.name or model_id == llama.name
48
+
49
+
50
  def change_model_selection(model_id):
51
  if model_id == gemma.name:
52
  gr.Warning("""
 
110
  if flash_attention_value:
111
  create_install_flash_attention(notebook['cells'])
112
 
 
 
 
 
113
  dataset_value = get_value(components, DATASET_SELECTION_ID)
114
  seed_value = get_value(components, DATASET_SHUFFLING_SEED)
115
  if not check_valid_input(dataset_value):
 
118
  create_datasets_cells(notebook['cells'], get_dataset(dataset_value), seed_value)
119
 
120
  model_value = get_value(components, MODEL_SELECTION_ID)
121
+ should_login = should_login_to_hf_model(model_value)
122
+
123
  if not check_valid_input(model_value):
124
  gr.Warning("No model is selected!")
125
  else:
 
127
  if not check_valid_input(version_value):
128
  gr.Warning("No version of the model is selected")
129
  else:
130
+ if should_login:
131
+ create_login_hf_cells(notebook['cells'], should_login=True, model_name=model_value)
132
+
133
  load_in_4bit = get_value(components, LOAD_IN_4_BIT_ID)
134
  bnb_4bit_use_double_quant = get_value(components, BNB_4BIT_USE_DOUBLE_QUANT)
135
  bnb_4bit_quant_type = get_value(components, BNB_4BIT_QUANT_TYPE)
 
178
  packing = get_value(components, PACKING_ID)
179
  create_sft_trainer_cells(notebook['cells'], max_seq_length, packing)
180
 
181
+ push_to_hub = get_value(components, PUSH_TO_HUB_ID)
182
+
183
  create_start_training_cells(notebook['cells'], epochs, max_steps, push_to_hub, output_dir)
184
 
185
  create_free_gpu_cells(notebook['cells'])
 
187
  create_merge_lora_cells(notebook['cells'], output_dir)
188
 
189
  if push_to_hub:
190
+ if not should_login:
191
+ create_login_hf_cells(notebook['cells'])
192
  push_merged_model_cells(notebook['cells'], output_dir)
193
 
194
  file_name = f"{finetuning_notebook}.ipynb"
utils/notebook_generator.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import nbformat as nbf
2
 
3
  from utils import FTDataSet
@@ -47,7 +49,7 @@ def create_install_flash_attention(cells: list):
47
  import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'
48
 
49
  !pip install ninja packaging
50
- !MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade
51
  """
52
  code_cell = nbf.v4.new_code_cell(code)
53
  cells.append(text_cell)
@@ -55,13 +57,16 @@ import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not s
55
  cells.append(code_cell)
56
 
57
 
58
- def create_login_hf_cells(cells: list):
59
  text_cell = nbf.v4.new_markdown_cell(
60
  "### Login to HF")
61
- text_cell1 = nbf.v4.new_markdown_cell("Installing **huggingface_hub** to use as a remote "
62
- "model versioning service. This means that your model, logs, and information "
63
- "will be automatically pushed to the Hub during training. You should have "
64
- "'HF_TOKEN'")
 
 
 
65
  code = """
66
  # Install huggingface_hub
67
  !pip install -q huggingface_hub
@@ -229,6 +234,8 @@ def create_training_args_cells(cells: list, epochs, max_steps, logging_steps, pe
229
  elif report_to != "none":
230
  to_install = report_to
231
 
 
 
232
  code_report = f"""
233
  # Installing {to_install} to report the metrics
234
 
@@ -244,12 +251,12 @@ args = TrainingArguments(
244
  per_device_train_batch_size={per_device_train_batch_size},
245
  gradient_accumulation_steps={gradient_accumulation_steps},
246
  gradient_checkpointing={gradient_checkpointing},
 
247
  optim="adamw_torch_fused",
248
  logging_steps={logging_steps},
249
  save_strategy='{save_strategy}',
250
  learning_rate={learning_rate},
251
  bf16=True,
252
- tf32=True,
253
  max_grad_norm={max_grad_norm},
254
  warmup_ratio={warmup_ratio},
255
  lr_scheduler_type='{lr_scheduler_type}',
 
1
+ from typing import Optional
2
+
3
  import nbformat as nbf
4
 
5
  from utils import FTDataSet
 
49
  import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'
50
 
51
  !pip install ninja packaging
52
+ !MAX_JOBS=4 pip install -q flash-attn --no-build-isolation --upgrade
53
  """
54
  code_cell = nbf.v4.new_code_cell(code)
55
  cells.append(text_cell)
 
57
  cells.append(code_cell)
58
 
59
 
60
+ def create_login_hf_cells(cells: list, should_login: bool = False, model_name: Optional[str] = None):
61
  text_cell = nbf.v4.new_markdown_cell(
62
  "### Login to HF")
63
+
64
+ text_1 = "Login with your `HF_TOKEN` in order to push the finetuned model to `huggingface_hub`."
65
+
66
+ if should_login:
67
+ text_1 = f"Login with your `HF_TOKEN` in order to load **{model_name}** from `huggingface_hub`."
68
+
69
+ text_cell1 = nbf.v4.new_markdown_cell(text_1)
70
  code = """
71
  # Install huggingface_hub
72
  !pip install -q huggingface_hub
 
234
  elif report_to != "none":
235
  to_install = report_to
236
 
237
+ gradient_checkpointing_kwargs = {"use_reentrant": False}
238
+
239
  code_report = f"""
240
  # Installing {to_install} to report the metrics
241
 
 
251
  per_device_train_batch_size={per_device_train_batch_size},
252
  gradient_accumulation_steps={gradient_accumulation_steps},
253
  gradient_checkpointing={gradient_checkpointing},
254
+ gradient_checkpointing_kwargs={gradient_checkpointing_kwargs},
255
  optim="adamw_torch_fused",
256
  logging_steps={logging_steps},
257
  save_strategy='{save_strategy}',
258
  learning_rate={learning_rate},
259
  bf16=True,
 
260
  max_grad_norm={max_grad_norm},
261
  warmup_ratio={warmup_ratio},
262
  lr_scheduler_type='{lr_scheduler_type}',