cpi-connect committed on
Commit
303b1b2
·
1 Parent(s): 44d9bc9

Upload model

Browse files
event_arg_predict.py CHANGED
@@ -37,10 +37,10 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
37
  model_checkpoint = "ehsanaghaei/SecureBERT"
38
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
39
 
40
- from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
- model_nugget = ArgumentModel(num_classes=43)
42
- model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/argument_model_state_dict.pth", map_location=device))
43
- model_nugget.eval()
44
 
45
  """
46
  Function: create_dataloader(text_input)
@@ -51,9 +51,9 @@ Output:
51
  - dataloader: A DataLoader for the tokenized and batched text data.
52
  - tokenized_dataset_ner: The tokenized dataset used for training.
53
  """
54
- def create_dataloader(text_input):
55
 
56
- event_nuggets = get_event_nuggets(text_input)
57
  doc = nlp(text_input)
58
 
59
  content_as_words_emdash = [tok.text for tok in doc]
 
37
  model_checkpoint = "ehsanaghaei/SecureBERT"
38
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
39
 
40
+ # from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
+ # model_nugget = ArgumentModel(num_classes=43)
42
+ # model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/argument_model_state_dict.pth", map_location=device))
43
+ # model_nugget.eval()
44
 
45
  """
46
  Function: create_dataloader(text_input)
 
51
  - dataloader: A DataLoader for the tokenized and batched text data.
52
  - tokenized_dataset_ner: The tokenized dataset used for training.
53
  """
54
+ def create_dataloader(model_nugget, text_input):
55
 
56
+ event_nuggets = get_event_nuggets(model_nugget, text_input)
57
  doc = nlp(text_input)
58
 
59
  content_as_words_emdash = [tok.text for tok in doc]
event_nugget_predict.py CHANGED
@@ -34,9 +34,9 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
34
  model_checkpoint = "ehsanaghaei/SecureBERT"
35
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
36
 
37
- model_nugget = NuggetModel(num_classes = 11)
38
- model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/nugget_model_state_dict.pth", map_location=device))
39
- model_nugget.eval()
40
 
41
  """
42
  Function: create_dataloader(text_input)
@@ -133,7 +133,7 @@ Inputs:
133
  Output:
134
  - predicted_label: A tensor containing the predicted labels for the input data.
135
  """
136
- def predict(dataloader):
137
  predicted_label = []
138
  for batch in dataloader:
139
  with torch.no_grad():
@@ -202,9 +202,9 @@ Output:
202
  - predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
203
  subtype, and text content.
204
  """
205
- def get_event_nuggets(text_input):
206
  dataloader, tokenized_dataset_ner = create_dataloader(text_input)
207
- predicted_label = predict(dataloader)
208
 
209
  predicted_event_nuggets = []
210
  text_length = 0
 
34
  model_checkpoint = "ehsanaghaei/SecureBERT"
35
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
36
 
37
+ # model_nugget = NuggetModel(num_classes = 11)
38
+ # model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/nugget_model_state_dict.pth", map_location=device))
39
+ # model_nugget.eval()
40
 
41
  """
42
  Function: create_dataloader(text_input)
 
133
  Output:
134
  - predicted_label: A tensor containing the predicted labels for the input data.
135
  """
136
+ def predict(model_nugget, dataloader):
137
  predicted_label = []
138
  for batch in dataloader:
139
  with torch.no_grad():
 
202
  - predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
203
  subtype, and text content.
204
  """
205
+ def get_event_nuggets(model_nugget, text_input):
206
  dataloader, tokenized_dataset_ner = create_dataloader(text_input)
207
+ predicted_label = predict(model_nugget, dataloader)
208
 
209
  predicted_event_nuggets = []
210
  text_length = 0
event_realis_predict.py CHANGED
@@ -49,10 +49,10 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
49
  model_checkpoint = "ehsanaghaei/SecureBERT"
50
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
51
 
52
- from .realis_model_utils import CustomRobertaWithPOS as RealisModel
53
- model_realis = RealisModel(num_classes_realis=4)
54
- model_realis.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/realis_model_state_dict.pth", map_location=device))
55
- model_realis.eval()
56
 
57
  """
58
  Function: create_dataloader(text_input)
@@ -63,9 +63,9 @@ Output:
63
  - dataloader: A DataLoader for the tokenized and batched text data.
64
  - tokenized_dataset_ner: The tokenized dataset used for training.
65
  """
66
- def create_dataloader(text_input):
67
 
68
- event_nuggets = get_event_nuggets(text_input)
69
  doc = nlp(text_input)
70
 
71
  content_as_words_emdash = [tok.text for tok in doc]
 
49
  model_checkpoint = "ehsanaghaei/SecureBERT"
50
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
51
 
52
+ # from .realis_model_utils import CustomRobertaWithPOS as RealisModel
53
+ # model_realis = RealisModel(num_classes_realis=4)
54
+ # model_realis.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/realis_model_state_dict.pth", map_location=device))
55
+ # model_realis.eval()
56
 
57
  """
58
  Function: create_dataloader(text_input)
 
63
  - dataloader: A DataLoader for the tokenized and batched text data.
64
  - tokenized_dataset_ner: The tokenized dataset used for training.
65
  """
66
+ def create_dataloader(model_nugget, text_input):
67
 
68
+ event_nuggets = get_event_nuggets(model_nugget, text_input)
69
  doc = nlp(text_input)
70
 
71
  content_as_words_emdash = [tok.text for tok in doc]
model.py CHANGED
@@ -61,8 +61,8 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
61
 
62
  def forward(self, text):
63
  nugget_dataloader, _ = self.event_nugget_dataloader(text)
64
- argument_dataloader, _ = self.event_argument_dataloader(text)
65
- realis_dataloader, _ = self.event_realis_dataloader(text)
66
 
67
  nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
68
  no_nuggets = torch.all(nugget_pred == 0, dim=1)
 
61
 
62
  def forward(self, text):
63
  nugget_dataloader, _ = self.event_nugget_dataloader(text)
64
+ argument_dataloader, _ = self.event_argument_dataloader(self.event_nugget_model, text)
65
+ realis_dataloader, _ = self.event_realis_dataloader(self.event_nugget_model, text)
66
 
67
  nugget_pred = self.forward_model(self.event_nugget_model, nugget_dataloader)
68
  no_nuggets = torch.all(nugget_pred == 0, dim=1)