selfconstruct3d commited on
Commit
f2c6a2d
·
verified ·
1 Parent(s): 9af3e83

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -1
README.md CHANGED
@@ -50,7 +50,7 @@ This model specializes in cybersecurity contexts. Predictions for unrelated cont
50
 
51
  Always verify predictions with cybersecurity analysts before using in critical decision-making scenarios.
52
 
53
- ## How to Get Started with the Model
54
 
55
  ```python
56
  import torch
@@ -104,6 +104,46 @@ print(f"Predicted GroupID: {predicted_class}")
104
  ```
105
  Predicted GroupID: G0001
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  ## Training Details
109
 
 
50
 
51
  Always verify predictions with cybersecurity analysts before using in critical decision-making scenarios.
52
 
53
+ ## How to Get Started with the Model (Classification)
54
 
55
  ```python
56
  import torch
 
104
  ```
105
  Predicted GroupID: G0001
106
 
107
+ ## How to Get Started with the Model (Embeddings)
108
+
109
+ ```python
110
+ import torch
111
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
112
+
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+
115
+ # Load your fine-tuned classification model
116
+ model_name = "selfconstruct3d/AttackGroup-MPNET"
117
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
118
+ classifier_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
119
+
120
+ def get_embedding(sentence):
121
+ classifier_model.eval()
122
+
123
+ encoding = tokenizer(
124
+ sentence,
125
+ truncation=True,
126
+ padding="max_length",
127
+ max_length=128,
128
+ return_tensors="pt"
129
+ )
130
+ input_ids = encoding["input_ids"].to(device)
131
+ attention_mask = encoding["attention_mask"].to(device)
132
+
133
+ with torch.no_grad():
134
+ outputs = classifier_model.mpnet(input_ids=input_ids, attention_mask=attention_mask)
135
+ cls_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy().flatten()
136
+
137
+ return cls_embedding
138
+
139
+ # Example explicitly:
140
+ sentence = "APT38 has used phishing emails with malicious links to distribute malware."
141
+ embedding = get_embedding(sentence)
142
+ print("Embedding shape:", embedding.shape)
143
+ print("Embedding values:", embedding)
144
+ ```
145
+
146
+
147
 
148
  ## Training Details
149