Update README.md

README.md CHANGED
@@ -50,7 +50,7 @@ This model specializes in cybersecurity contexts. Predictions for unrelated cont
 
 Always verify predictions with cybersecurity analysts before using in critical decision-making scenarios.
 
-## How to Get Started with the Model
+## How to Get Started with the Model (Classification)
 
 ```python
 import torch
@@ -104,6 +104,46 @@ print(f"Predicted GroupID: {predicted_class}")
 ```
 Predicted GroupID: G0001
 
+## How to Get Started with the Model (Embeddings)
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load your fine-tuned classification model
+model_name = "selfconstruct3d/AttackGroup-MPNET"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+classifier_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
+
+def get_embedding(sentence):
+    classifier_model.eval()
+
+    encoding = tokenizer(
+        sentence,
+        truncation=True,
+        padding="max_length",
+        max_length=128,
+        return_tensors="pt"
+    )
+    input_ids = encoding["input_ids"].to(device)
+    attention_mask = encoding["attention_mask"].to(device)
+
+    with torch.no_grad():
+        outputs = classifier_model.mpnet(input_ids=input_ids, attention_mask=attention_mask)
+        cls_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy().flatten()
+
+    return cls_embedding
+
+# Example usage:
+sentence = "APT38 has used phishing emails with malicious links to distribute malware."
+embedding = get_embedding(sentence)
+print("Embedding shape:", embedding.shape)
+print("Embedding values:", embedding)
+```
+
+
 
 ## Training Details
 
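As context for the new embeddings snippet, a minimal sketch of one possible downstream use follows. It is not part of the commit: it reuses the `get_embedding` helper defined in the added block above, assumes numpy is installed, and compares two technique descriptions (the second sentence is an invented example) by cosine similarity over their [CLS] embeddings.

```python
import numpy as np

def cosine_similarity(a, b):
    # Cosine similarity between two 1-D embedding vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Assumes get_embedding() from the snippet above is already defined;
# the second sentence is a hypothetical example for illustration.
emb_a = get_embedding("APT38 has used phishing emails with malicious links to distribute malware.")
emb_b = get_embedding("The group delivered malware through spearphishing links.")

print("Cosine similarity:", cosine_similarity(emb_a, emb_b))
```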