Update README.md
Browse files
README.md
CHANGED
@@ -172,12 +172,41 @@ To be anounced...
|
|
172 |
| Classification Accuracy (Test) | 0.7161 |
|
173 |
| Weighted F1 Score | [More Information Needed] |
|
174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
### Single Prediction Example
|
176 |
|
177 |
```python
|
178 |
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
def predict_group(sentence):
|
183 |
classifier_model.eval()
|
@@ -192,17 +221,17 @@ def predict_group(sentence):
|
|
192 |
attention_mask = encoding["attention_mask"].to(device)
|
193 |
|
194 |
with torch.no_grad():
|
195 |
-
|
|
|
196 |
predicted_label = torch.argmax(logits, dim=1).cpu().item()
|
197 |
|
198 |
-
|
199 |
-
# Explicitly convert numeric label to original GroupID
|
200 |
-
predicted_groupid = label_to_groupid[predicted_label]
|
201 |
return predicted_groupid
|
202 |
|
|
|
203 |
sentence = "APT38 has used phishing emails with malicious links to distribute malware."
|
204 |
predicted_class = predict_group(sentence)
|
205 |
-
print(f"Predicted GroupID: {predicted_class}")
|
206 |
```
|
207 |
|
208 |
## Environmental Impact
|
|
|
172 |
| Classification Accuracy (Test) | 0.7161 |
|
173 |
| Weighted F1 Score | [More Information Needed] |
|
174 |
|
175 |
+
|
176 |
+
Embedding Variability Accuracy
|
177 |
+
Original MPNet 0.092721 0.998611
|
178 |
+
MLM Fine-tuned MPNet 0.034983 0.653611
|
179 |
+
Classification Fine-tuned MPNet 0.193065 0.950833
|
180 |
+
SecBERT 0.591303 0.988611
|
181 |
+
ATTACK-BERT 0.096108 0.967778
|
182 |
+
|
183 |
### Single Prediction Example
|
184 |
|
185 |
```python
|
186 |
|
187 |
+
import torch
|
188 |
+
import torch.nn as nn
|
189 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
190 |
+
import torch.optim as optim
|
191 |
+
import numpy as np
|
192 |
+
from huggingface_hub import hf_hub_download
|
193 |
+
import json
|
194 |
+
|
195 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
196 |
+
# Load explicitly your fine-tuned MPNet model
|
197 |
+
classifier_model = AutoModelForSequenceClassification.from_pretrained("selfconstruct3d/AttackGroup-MPNET").to(device)
|
198 |
+
|
199 |
+
# Load explicitly your tokenizer
|
200 |
+
tokenizer = AutoTokenizer.from_pretrained("selfconstruct3d/AttackGroup-MPNET")
|
201 |
+
|
202 |
+
|
203 |
+
label_to_groupid_file = hf_hub_download(
|
204 |
+
repo_id="selfconstruct3d/AttackGroup-MPNET",
|
205 |
+
filename="label_to_groupid.json"
|
206 |
+
)
|
207 |
+
|
208 |
+
with open(label_to_groupid_file, "r") as f:
|
209 |
+
label_to_groupid = json.load(f)
|
210 |
|
211 |
def predict_group(sentence):
|
212 |
classifier_model.eval()
|
|
|
221 |
attention_mask = encoding["attention_mask"].to(device)
|
222 |
|
223 |
with torch.no_grad():
|
224 |
+
outputs = classifier_model(input_ids=input_ids, attention_mask=attention_mask)
|
225 |
+
logits = outputs.logits
|
226 |
predicted_label = torch.argmax(logits, dim=1).cpu().item()
|
227 |
|
228 |
+
predicted_groupid = label_to_groupid[str(predicted_label)]
|
|
|
|
|
229 |
return predicted_groupid
|
230 |
|
231 |
+
# Example usage explicitly:
|
232 |
sentence = "APT38 has used phishing emails with malicious links to distribute malware."
|
233 |
predicted_class = predict_group(sentence)
|
234 |
+
print(f"Predicted GroupID: {predicted_class}")
|
235 |
```
|
236 |
|
237 |
## Environmental Impact
|