Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,9 @@ import gradio as gr
|
|
8 |
import json
|
9 |
from huggingface_hub import HfApi
|
10 |
|
|
|
|
|
|
|
11 |
max_seq_length = 4096
|
12 |
dtype = None
|
13 |
load_in_4bit = True
|
@@ -17,7 +20,6 @@ current_num = os.getenv("NUM")
|
|
17 |
print(f"stage ${current_num}")
|
18 |
|
19 |
api = HfApi(token=hf_token)
|
20 |
-
# models = f"dad1909/cybersentinal-2.0-{current_num}"
|
21 |
|
22 |
model_base = "unsloth/llama-3-8b-Instruct-bnb-4bit"
|
23 |
|
@@ -31,13 +33,17 @@ model, tokenizer = FastLanguageModel.from_pretrained(
|
|
31 |
load_in_4bit=load_in_4bit,
|
32 |
token=hf_token
|
33 |
)
|
34 |
-
|
|
|
|
|
35 |
|
36 |
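# Wrap the model in DataParallel when more than one GPU is visible. NOTE(review): the PEFT setup below passes model.module (the unwrapped model) and rebinds `model`, so this wrapper is discarded — confirm multi-GPU training is actually in effect.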
|
37 |
if torch.cuda.device_count() > 1:
|
38 |
print(f"Using {torch.cuda.device_count()} GPUs!")
|
39 |
model = torch.nn.DataParallel(model)
|
40 |
|
|
|
|
|
41 |
print("Configuring PEFT model...")
|
42 |
model = FastLanguageModel.get_peft_model(
|
43 |
model.module if isinstance(model, torch.nn.DataParallel) else model,
|
@@ -118,7 +124,7 @@ trainer = SFTTrainer(
|
|
118 |
dataset_num_proc=2,
|
119 |
packing=False,
|
120 |
args=TrainingArguments(
|
121 |
-
per_device_train_batch_size=17,
|
122 |
gradient_accumulation_steps=17,
|
123 |
learning_rate=2e-4,
|
124 |
fp16=not is_bfloat16_supported(),
|
@@ -144,21 +150,27 @@ num += 1
|
|
144 |
uploads_models = f"cybersentinal-3.0"
|
145 |
|
146 |
print("Saving the trained model...")
|
147 |
-
|
|
|
|
|
|
|
148 |
print("Model saved successfully.")
|
149 |
|
150 |
print("Pushing the model to the hub...")
|
151 |
-
model.
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
162 |
print("Model pushed to hub successfully.")
|
163 |
|
164 |
api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
|
|
|
8 |
import json
|
9 |
from huggingface_hub import HfApi
|
10 |
|
11 |
+
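# Restrict PyTorch to GPU ids 0-3. NOTE(review): this only takes effect if it runs before CUDA is first initialized, and it does not guarantee four GPUs exist — confirm no earlier import initializes CUDA.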
|
12 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
13 |
+
|
14 |
max_seq_length = 4096
|
15 |
dtype = None
|
16 |
load_in_4bit = True
|
|
|
20 |
print(f"stage ${current_num}")
|
21 |
|
22 |
api = HfApi(token=hf_token)
|
|
|
23 |
|
24 |
model_base = "unsloth/llama-3-8b-Instruct-bnb-4bit"
|
25 |
|
|
|
33 |
load_in_4bit=load_in_4bit,
|
34 |
token=hf_token
|
35 |
)
|
36 |
+
|
37 |
+
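# Move the model to the default CUDA device. NOTE(review): the model is loaded with load_in_4bit=True — confirm the installed transformers/bitsandbytes versions accept .to() on a 4-bit quantized model.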
|
38 |
+
model = model.to('cuda')
|
39 |
|
40 |
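# Wrap the model in DataParallel when more than one GPU is visible. NOTE(review): the PEFT setup below passes model.module (the unwrapped model) and rebinds `model`, so this wrapper is discarded — confirm multi-GPU training is actually in effect.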
|
41 |
if torch.cuda.device_count() > 1:
|
42 |
print(f"Using {torch.cuda.device_count()} GPUs!")
|
43 |
model = torch.nn.DataParallel(model)
|
44 |
|
45 |
+
print("Model and tokenizer loaded successfully.")
|
46 |
+
|
47 |
print("Configuring PEFT model...")
|
48 |
model = FastLanguageModel.get_peft_model(
|
49 |
model.module if isinstance(model, torch.nn.DataParallel) else model,
|
|
|
124 |
dataset_num_proc=2,
|
125 |
packing=False,
|
126 |
args=TrainingArguments(
|
127 |
+
per_device_train_batch_size=17, # Adjust this based on GPU memory
|
128 |
gradient_accumulation_steps=17,
|
129 |
learning_rate=2e-4,
|
130 |
fp16=not is_bfloat16_supported(),
|
|
|
150 |
uploads_models = f"cybersentinal-3.0"
|
151 |
|
152 |
print("Saving the trained model...")
|
153 |
+
if isinstance(model, torch.nn.DataParallel):
|
154 |
+
model.module.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
155 |
+
else:
|
156 |
+
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
157 |
print("Model saved successfully.")
|
158 |
|
159 |
print("Pushing the model to the hub...")
|
160 |
+
if isinstance(model, torch.nn.DataParallel):
|
161 |
+
model.module.push_to_hub_merged(
|
162 |
+
uploads_models,
|
163 |
+
tokenizer,
|
164 |
+
save_method="merged_16bit",
|
165 |
+
token=hf_token
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
model.push_to_hub_merged(
|
169 |
+
uploads_models,
|
170 |
+
tokenizer,
|
171 |
+
save_method="merged_16bit",
|
172 |
+
token=hf_token
|
173 |
+
)
|
174 |
print("Model pushed to hub successfully.")
|
175 |
|
176 |
api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
|