admincybers2 commited on
Commit
635b1f2
·
verified ·
1 Parent(s): eda60d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from unsloth import FastLanguageModel, is_bfloat16_supported
4
+ from trl import SFTTrainer
5
+ from transformers import TrainingArguments
6
+ from datasets import load_dataset
7
+ import gradio as gr
8
+ import json
9
+ from huggingface_hub import HfApi
10
+
11
+ max_seq_length = 4096
12
+ dtype = None
13
+ load_in_4bit = True
14
+ hf_token = os.getenv("HF_TOKEN")
15
+ current_num = os.getenv("NUM")
16
+
17
+ print(f"stage ${current_num}")
18
+
19
+ api = HfApi(token=hf_token)
20
+ models = "unsloth/Meta-Llama-3.1-70B-bnb-4bit"
21
+
22
+ print("Starting model and tokenizer loading...")
23
+
24
+ # Load the model and tokenizer
25
+ model, tokenizer = FastLanguageModel.from_pretrained(
26
+ model_name=model_base,
27
+ max_seq_length=max_seq_length,
28
+ dtype=dtype,
29
+ load_in_4bit=load_in_4bit,
30
+ token=hf_token
31
+ )
32
+
33
+ print("Model and tokenizer loaded successfully.")
34
+
35
+ print("Configuring PEFT model...")
36
+ model = FastLanguageModel.get_peft_model(
37
+ model,
38
+ r=16,
39
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
40
+ lora_alpha=16,
41
+ lora_dropout=0,
42
+ bias="none",
43
+ use_gradient_checkpointing="unsloth",
44
+ random_state=3407,
45
+ use_rslora=False,
46
+ loftq_config=None,
47
+ )
48
+ print("PEFT model configured.")
49
+
50
+ # Updated alpaca_prompt for different types
51
+ alpaca_prompt = {
52
+ "learning_from": """Below is a CVE definition.
53
+ ### CVE definition:
54
+ {}
55
+ ### detail CVE:
56
+ {}""",
57
+ "definition": """Below is a definition about software vulnerability. Explain it.
58
+ ### Definition:
59
+ {}
60
+ ### Explanation:
61
+ {}""",
62
+ "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
63
+ ### Code Snippet:
64
+ {}
65
+ ### Vulnerability solution:
66
+ {}"""
67
+ }
68
+
69
+ EOS_TOKEN = tokenizer.eos_token
70
+
71
+ def detect_prompt_type(instruction):
72
+ if instruction.startswith("what is code vulnerable of this code:"):
73
+ return "code_vulnerability"
74
+ elif instruction.startswith("Learning from"):
75
+ return "learning_from"
76
+ elif instruction.startswith("what is"):
77
+ return "definition"
78
+ else:
79
+ return "unknown"
80
+
81
+ def formatting_prompts_func(examples):
82
+ instructions = examples["instruction"]
83
+ outputs = examples["output"]
84
+ texts = []
85
+
86
+ for instruction, output in zip(instructions, outputs):
87
+ prompt_type = detect_prompt_type(instruction)
88
+ if prompt_type in alpaca_prompt:
89
+ prompt = alpaca_prompt[prompt_type].format(instruction, output)
90
+ else:
91
+ prompt = instruction + "\n\n" + output
92
+ text = prompt + EOS_TOKEN
93
+ texts.append(text)
94
+
95
+ return {"text": texts}
96
+
97
+ print("Loading dataset...")
98
+ dataset = load_dataset("admincybers2/DSV", split="train")
99
+ print("Dataset loaded successfully.")
100
+
101
+ print("Applying formatting function to the dataset...")
102
+ dataset = dataset.map(formatting_prompts_func, batched=True)
103
+ print("Formatting function applied.")
104
+
105
+ print("Initializing trainer...")
106
+ trainer = SFTTrainer(
107
+ model=model,
108
+ tokenizer=tokenizer,
109
+ train_dataset=dataset,
110
+ dataset_text_field="text",
111
+ max_seq_length=max_seq_length,
112
+ dataset_num_proc=2,
113
+ packing=False,
114
+ args=TrainingArguments(
115
+ per_device_train_batch_size=1,
116
+ gradient_accumulation_steps=1,
117
+ learning_rate=2e-4,
118
+ fp16=not is_bfloat16_supported(),
119
+ bf16=is_bfloat16_supported(),
120
+ warmup_steps=5,
121
+ logging_steps=10,
122
+ optim="adamw_8bit",
123
+ weight_decay=0.01,
124
+ lr_scheduler_type="linear",
125
+ seed=3407,
126
+ output_dir="outputs"
127
+ ),
128
+ )
129
+ print("Trainer initialized.")
130
+
131
+ print("Starting training...")
132
+ trainer_stats = trainer.train()
133
+ print("Training completed.")
134
+
135
+ num = int(current_num)
136
+ num += 1
137
+
138
+ uploads_models = f"cybersentinal-2.0-{str(num)}"
139
+
140
+ up = "sentinal-3.1-70B"
141
+
142
+ print("Saving the trained model...")
143
+ model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
144
+ print("Model saved successfully.")
145
+
146
+ print("Pushing the model to the hub...")
147
+ model.push_to_hub_merged(
148
+ up,
149
+ tokenizer,
150
+ save_method="merged_16bit",
151
+ token=hf_token
152
+ )
153
+ print("Model pushed to hub successfully.")
154
+
155
+ api.delete_space_variable(repo_id="admincybers2/CyberController", key="NUM")
156
+ api.add_space_variable(repo_id="admincybers2/CyberController", key="NUM", value=str(num))