cpi-connect
committed on
Commit
·
303b1b2
1
Parent(s):
44d9bc9
Upload model
Browse files- event_arg_predict.py +6 -6
- event_nugget_predict.py +6 -6
- event_realis_predict.py +6 -6
- model.py +2 -2
event_arg_predict.py
CHANGED
@@ -37,10 +37,10 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
37 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
39 |
|
40 |
-
from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
|
41 |
-
model_nugget = ArgumentModel(num_classes=43)
|
42 |
-
model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/argument_model_state_dict.pth", map_location=device))
|
43 |
-
model_nugget.eval()
|
44 |
|
45 |
"""
|
46 |
Function: create_dataloader(text_input)
|
@@ -51,9 +51,9 @@ Output:
|
|
51 |
- dataloader: A DataLoader for the tokenized and batched text data.
|
52 |
- tokenized_dataset_ner: The tokenized dataset used for training.
|
53 |
"""
|
54 |
-
def create_dataloader(text_input):
|
55 |
|
56 |
-
event_nuggets = get_event_nuggets(text_input)
|
57 |
doc = nlp(text_input)
|
58 |
|
59 |
content_as_words_emdash = [tok.text for tok in doc]
|
|
|
37 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
39 |
|
40 |
+
# from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
|
41 |
+
# model_nugget = ArgumentModel(num_classes=43)
|
42 |
+
# model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/argument_model_state_dict.pth", map_location=device))
|
43 |
+
# model_nugget.eval()
|
44 |
|
45 |
"""
|
46 |
Function: create_dataloader(text_input)
|
|
|
51 |
- dataloader: A DataLoader for the tokenized and batched text data.
|
52 |
- tokenized_dataset_ner: The tokenized dataset used for training.
|
53 |
"""
|
54 |
+
def create_dataloader(model_nugget, text_input):
|
55 |
|
56 |
+
event_nuggets = get_event_nuggets(model_nugget, text_input)
|
57 |
doc = nlp(text_input)
|
58 |
|
59 |
content_as_words_emdash = [tok.text for tok in doc]
|
event_nugget_predict.py
CHANGED
@@ -34,9 +34,9 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
34 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
35 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
36 |
|
37 |
-
model_nugget = NuggetModel(num_classes = 11)
|
38 |
-
model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/nugget_model_state_dict.pth", map_location=device))
|
39 |
-
model_nugget.eval()
|
40 |
|
41 |
"""
|
42 |
Function: create_dataloader(text_input)
|
@@ -133,7 +133,7 @@ Inputs:
|
|
133 |
Output:
|
134 |
- predicted_label: A tensor containing the predicted labels for the input data.
|
135 |
"""
|
136 |
-
def predict(dataloader):
|
137 |
predicted_label = []
|
138 |
for batch in dataloader:
|
139 |
with torch.no_grad():
|
@@ -202,9 +202,9 @@ Output:
|
|
202 |
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
|
203 |
subtype, and text content.
|
204 |
"""
|
205 |
-
def get_event_nuggets(text_input):
|
206 |
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
207 |
-
predicted_label = predict(dataloader)
|
208 |
|
209 |
predicted_event_nuggets = []
|
210 |
text_length = 0
|
|
|
34 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
35 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
36 |
|
37 |
+
# model_nugget = NuggetModel(num_classes = 11)
|
38 |
+
# model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/nugget_model_state_dict.pth", map_location=device))
|
39 |
+
# model_nugget.eval()
|
40 |
|
41 |
"""
|
42 |
Function: create_dataloader(text_input)
|
|
|
133 |
Output:
|
134 |
- predicted_label: A tensor containing the predicted labels for the input data.
|
135 |
"""
|
136 |
+
def predict(model_nugget, dataloader):
|
137 |
predicted_label = []
|
138 |
for batch in dataloader:
|
139 |
with torch.no_grad():
|
|
|
202 |
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
|
203 |
subtype, and text content.
|
204 |
"""
|
205 |
+
def get_event_nuggets(model_nugget, text_input):
|
206 |
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
207 |
+
predicted_label = predict(model_nugget, dataloader)
|
208 |
|
209 |
predicted_event_nuggets = []
|
210 |
text_length = 0
|
event_realis_predict.py
CHANGED
@@ -49,10 +49,10 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
49 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
50 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
51 |
|
52 |
-
from .realis_model_utils import CustomRobertaWithPOS as RealisModel
|
53 |
-
model_realis = RealisModel(num_classes_realis=4)
|
54 |
-
model_realis.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/realis_model_state_dict.pth", map_location=device))
|
55 |
-
model_realis.eval()
|
56 |
|
57 |
"""
|
58 |
Function: create_dataloader(text_input)
|
@@ -63,9 +63,9 @@ Output:
|
|
63 |
- dataloader: A DataLoader for the tokenized and batched text data.
|
64 |
- tokenized_dataset_ner: The tokenized dataset used for training.
|
65 |
"""
|
66 |
-
def create_dataloader(text_input):
|
67 |
|
68 |
-
event_nuggets = get_event_nuggets(text_input)
|
69 |
doc = nlp(text_input)
|
70 |
|
71 |
content_as_words_emdash = [tok.text for tok in doc]
|
|
|
49 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
50 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
51 |
|
52 |
+
# from .realis_model_utils import CustomRobertaWithPOS as RealisModel
|
53 |
+
# model_realis = RealisModel(num_classes_realis=4)
|
54 |
+
# model_realis.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/realis_model_state_dict.pth", map_location=device))
|
55 |
+
# model_realis.eval()
|
56 |
|
57 |
"""
|
58 |
Function: create_dataloader(text_input)
|
|
|
63 |
- dataloader: A DataLoader for the tokenized and batched text data.
|
64 |
- tokenized_dataset_ner: The tokenized dataset used for training.
|
65 |
"""
|
66 |
+
def create_dataloader(model_nugget, text_input):
|
67 |
|
68 |
+
event_nuggets = get_event_nuggets(model_nugget, text_input)
|
69 |
doc = nlp(text_input)
|
70 |
|
71 |
content_as_words_emdash = [tok.text for tok in doc]
|
model.py
CHANGED
@@ -61,8 +61,8 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
|
|
61 |
|
62 |
def forward(self, text):
|
63 |
nugget_dataloader, _ = self.event_nugget_dataloader(text)
|
64 |
-
argument_dataloader, _ = self.event_argument_dataloader(text)
|
65 |
-
realis_dataloader, _ = self.event_realis_dataloader(text)
|
66 |
|
67 |
nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
|
68 |
no_nuggets = torch.all(nugget_pred == 0, dim=1)
|
|
|
61 |
|
62 |
def forward(self, text):
|
63 |
nugget_dataloader, _ = self.event_nugget_dataloader(text)
|
64 |
+
argument_dataloader, _ = self.event_argument_dataloader(self.event_nugget_model, text)
|
65 |
+
realis_dataloader, _ = self.event_realis_dataloader(self.event_nugget_model, text)
|
66 |
|
67 |
nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
|
68 |
no_nuggets = torch.all(nugget_pred == 0, dim=1)
|