dad1909 committed on
Commit 493a1a4 · verified · 1 Parent(s): e427808

Update app.py

Files changed (1)
  1. app.py +79 -147
app.py CHANGED
@@ -1,158 +1,90 @@
  import os
- import torch
- from unsloth import FastLanguageModel, is_bfloat16_supported
- from trl import SFTTrainer
- from transformers import TrainingArguments
- from datasets import load_dataset
  import gradio as gr
- import json
- from huggingface_hub import HfApi
+ import torch
+ from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
+ import spaces
+
+ # Define the model configurations
+ model_configs = {
+     "CyberSentinel": {
+         "model_name": "dad1909/cybersentinal-2.0",
+         "max_seq_length": 1028,
+         "dtype": torch.float16,
+         "load_in_4bit": True
+     }
+ }

- max_seq_length = 4096
- dtype = None
- load_in_4bit = True
+ # Hugging Face token
  hf_token = os.getenv("HF_TOKEN")
- current_num = os.getenv("NUM")
-
- print(f"stage ${current_num}")
-
- api = HfApi(token=hf_token)
- # models = f"dad1909/cybersentinal-2.0-{current_num}"
- model_base = "unsloth/gemma-2-27b-bnb-4bit"
-
- print("Starting model and tokenizer loading...")
-
- # Load the model and tokenizer
- model, tokenizer = FastLanguageModel.from_pretrained(
-     model_name=model_base,
-     max_seq_length=max_seq_length,
-     dtype=dtype,
-     load_in_4bit=load_in_4bit,
-     token=hf_token
- )
-
- print("Model and tokenizer loaded successfully.")
-
- print("Configuring PEFT model...")
- model = FastLanguageModel.get_peft_model(
-     model,
-     r=16,
-     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
-     lora_alpha=16,
-     lora_dropout=0,
-     bias="none",
-     use_gradient_checkpointing="unsloth",
-     random_state=3407,
-     use_rslora=False,
-     loftq_config=None,
- )
- print("PEFT model configured.")

- # Updated alpaca_prompt for different types
- alpaca_prompt = {
-     "learning_from": """Below is a CVE definition.
- ### CVE definition:
- {}
- ### detail CVE:
- {}""",
-     "definition": """Below is a definition about software vulnerability. Explain it.
- ### Definition:
- {}
- ### Explanation:
- {}""",
-     "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
+ # Load the model when the application starts
+ loaded_models = {}
+
+ def load_model(selected_model):
+     if selected_model not in loaded_models:
+         config = model_configs[selected_model]
+         model = AutoModelForCausalLM.from_pretrained(
+             config["model_name"],
+             torch_dtype=config["dtype"],
+             device_map="auto",
+             use_auth_token=hf_token
+         )
+         tokenizer = AutoTokenizer.from_pretrained(
+             config["model_name"],
+             use_auth_token=hf_token
+         )
+         loaded_models[selected_model] = (model, tokenizer)
+     return loaded_models[selected_model]
+
+ alpaca_prompts = {
+     "information": "Give me information about the following topic: {}",
+     "vulnerable": """Identify the line of code that is vulnerable and describe the type of software vulnerability.
  ### Code Snippet:
  {}
- ### Vulnerability solution:
- {}"""
+ ### Vulnerability Description:""",
+     "Chat": "{}"
  }

- EOS_TOKEN = tokenizer.eos_token
-
- def detect_prompt_type(instruction):
-     if instruction.startswith("what is code vulnerable of this code:"):
-         return "code_vulnerability"
-     elif instruction.startswith("Learning from"):
-         return "learning_from"
-     elif instruction.startswith("what is"):
-         return "definition"
-     else:
-         return "unknown"
-
- def formatting_prompts_func(examples):
-     instructions = examples["instruction"]
-     outputs = examples["output"]
-     texts = []
-
-     for instruction, output in zip(instructions, outputs):
-         prompt_type = detect_prompt_type(instruction)
-         if prompt_type in alpaca_prompt:
-             prompt = alpaca_prompt[prompt_type].format(instruction, output)
-         else:
-             prompt = instruction + "\n\n" + output
-         text = prompt + EOS_TOKEN
-         texts.append(text)
-
-     return {"text": texts}
-
- print("Loading dataset...")
- dataset = load_dataset("dad1909/DCSV", split="train")
- print("Dataset loaded successfully.")
-
- print("Applying formatting function to the dataset...")
- dataset = dataset.map(formatting_prompts_func, batched=True)
- print("Formatting function applied.")
-
- print("Initializing trainer...")
- trainer = SFTTrainer(
-     model=model,
-     tokenizer=tokenizer,
-     train_dataset=dataset,
-     dataset_text_field="text",
-     max_seq_length=max_seq_length,
-     dataset_num_proc=2,
-     packing=False,
-     args=TrainingArguments(
-         per_device_train_batch_size=1,
-         gradient_accumulation_steps=1,
-         learning_rate=2e-4,
-         fp16=not is_bfloat16_supported(),
-         bf16=is_bfloat16_supported(),
-         warmup_steps=5,
-         logging_steps=10,
-         max_steps=100,
-         optim="adamw_8bit",
-         weight_decay=0.01,
-         lr_scheduler_type="linear",
-         seed=3407,
-         output_dir="outputs"
-     ),
- )
- print("Trainer initialized.")
-
- print("Starting training...")
- trainer_stats = trainer.train()
- print("Training completed.")
-
- num = int(current_num)
- num += 1
-
- uploads_models = f"cybersentinal-2.0-{str(num)}"
-
- up = "sentinal-3.1-70B"
-
- print("Saving the trained model...")
- model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
- print("Model saved successfully.")
-
- print("Pushing the model to the hub...")
- model.push_to_hub_merged(
-     up,
-     tokenizer,
-     save_method="merged_16bit",
-     token=hf_token
+ @spaces.GPU(duration=100)
+ def predict(selected_model, prompt, prompt_type, max_length=128):
+     model, tokenizer = load_model(selected_model)
+     selected_prompt = alpaca_prompts[prompt_type]
+     formatted_prompt = selected_prompt.format(prompt)
+     inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
+     text_streamer = TextStreamer(tokenizer)
+     output = model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_length)
+     return tokenizer.decode(output[0], skip_special_tokens=True)
+
+ theme = gr.themes.Default(
+     primary_hue=gr.themes.colors.rose,
+     secondary_hue=gr.themes.colors.blue,
+     font=gr.themes.GoogleFont("Source Sans Pro")
  )
- print("Model pushed to hub successfully.")

- api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
- api.add_space_variable(repo_id="dad1909/CyberCode", key="NUM", value=str(num))
+ load_model("CyberSentinel")
+
+ with gr.Blocks(theme=theme) as demo:
+     selected_model = gr.Dropdown(choices=list(model_configs.keys()), value="CyberSentinel", label="Model")
+     prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
+     prompt_type = gr.Dropdown(choices=list(alpaca_prompts.keys()), value="Chat", label="Prompt Type")
+     max_length = gr.Slider(minimum=128, maximum=512, step=128, value=128, label="Max Length")
+     generated_text = gr.Textbox(label="Generated Text")
+
+     generate_button = gr.Button("Generate")
+
+     generate_button.click(predict, inputs=[selected_model, prompt, prompt_type, max_length], outputs=generated_text)
+
+     gr.Examples(
+         examples=[
+             ["CyberSentinel", "What is SQL injection?", "information", 128],
+             ["CyberSentinel", "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);", "vulnerable", 128],
+             ["CyberSentinel", "Can you tell me a joke?", "Chat", 128]
+         ],
+         inputs=[selected_model, prompt, prompt_type, max_length]
+     )
+
+ demo.queue(default_concurrency_limit=20).launch(
+     server_name="0.0.0.0",
+     allowed_paths=["/"],
+     share=True
+ )
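
To sanity-check the new inference path outside the Space, a minimal sketch along these lines should reproduce what predict does. The spaces.GPU decorator only applies inside a ZeroGPU Space, so it is dropped here; the sketch assumes HF_TOKEN is exported, a CUDA-capable GPU is available, and the token can access dad1909/cybersentinal-2.0. The strcpy snippet is only an illustrative input, not from the commit.

# Local smoke test (sketch; not part of the commit).
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_token = os.getenv("HF_TOKEN")  # assumes HF_TOKEN is exported
model_name = "dad1909/cybersentinal-2.0"

# Same loading arguments the app's load_model uses.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto", use_auth_token=hf_token
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)

# Build the same "vulnerable" template the app defines in alpaca_prompts.
prompt = """Identify the line of code that is vulnerable and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability Description:""".format("strcpy(dst, user_input);")  # made-up input

inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))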