cpi-connect commited on
Commit
4fa0a53
·
1 Parent(s): 74f26df

Upload model

Browse files
Files changed (3) hide show
  1. config.json +3 -3
  2. configuration.py +1 -1
  3. model.py +9 -9
config.json CHANGED
@@ -146,7 +146,7 @@
146
  "B-GPE",
147
  "I-GPE"
148
  ],
149
- "event_argument_model_path": "cybersecurity_knowledge_graph/argument_model_state_dict.pth",
150
  "event_nugget_list": [
151
  "O",
152
  "B-Ransom",
@@ -160,8 +160,8 @@
160
  "B-Phishing",
161
  "I-Phishing"
162
  ],
163
- "event_nugget_model_path": "cybersecurity_knowledge_graph/nugget_model_state_dict.pth",
164
- "event_realis_model_path": "cybersecurity_knowledge_graph/realis_model_state_dict.pth",
165
  "model_type": "cybersecurity_knowledge_graph",
166
  "realis_list": [
167
  "O",
 
146
  "B-GPE",
147
  "I-GPE"
148
  ],
149
+ "event_argument_model_path": "argument_model_state_dict.pth",
150
  "event_nugget_list": [
151
  "O",
152
  "B-Ransom",
 
160
  "B-Phishing",
161
  "I-Phishing"
162
  ],
163
+ "event_nugget_model_path": "nugget_model_state_dict.pth",
164
+ "event_realis_model_path": "realis_model_state_dict.pth",
165
  "model_type": "cybersecurity_knowledge_graph",
166
  "realis_list": [
167
  "O",
configuration.py CHANGED
@@ -1,7 +1,7 @@
1
  from transformers import PretrainedConfig
2
  import torch
3
 
4
- from cybersecurity_knowledge_graph.utils import event_args_list, event_nugget_list, realis_list, arg_2_role
5
 
6
 
7
  class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
 
1
  from transformers import PretrainedConfig
2
  import torch
3
 
4
+ from utils import event_args_list, event_nugget_list, realis_list, arg_2_role
5
 
6
 
7
  class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
model.py CHANGED
@@ -6,15 +6,15 @@ from sentence_transformers import SentenceTransformer
6
  from transformers import AutoTokenizer
7
 
8
 
9
- from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
10
- from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
11
- from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
12
 
13
- from .configuration import CybersecurityKnowledgeGraphConfig
14
 
15
- from cybersecurity_knowledge_graph.event_nugget_predict import create_dataloader as event_nugget_dataloader
16
- from cybersecurity_knowledge_graph.event_realis_predict import create_dataloader as event_realis_dataloader
17
- from cybersecurity_knowledge_graph.event_arg_predict import create_dataloader as event_argument_dataloader
18
 
19
  class CybersecurityKnowledgeGraphModel(PreTrainedModel):
20
  config_class = CybersecurityKnowledgeGraphConfig
@@ -40,7 +40,7 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
40
  self.event_argument_model.load_state_dict(torch.load(self.event_argument_model_path))
41
 
42
  role_classifiers = {}
43
- folder_path = '/cybersecurity_knowledge_graph/arg_role_models'
44
 
45
  for filename in os.listdir(os.getcwd() + folder_path):
46
  if filename.endswith('.joblib'):
@@ -50,7 +50,7 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
50
  role_classifiers[arg] = clf
51
 
52
  self.role_classifiers = role_classifiers
53
- self.embed_model = SentenceTransformer('sentence_transformer')
54
 
55
 
56
  self.event_nugget_list = config.event_nugget_list
 
6
  from transformers import AutoTokenizer
7
 
8
 
9
+ from nugget_model_utils import CustomRobertaWithPOS as NuggetModel
10
+ from args_model_utils import CustomRobertaWithPOS as ArgumentModel
11
+ from realis_model_utils import CustomRobertaWithPOS as RealisModel
12
 
13
+ from configuration import CybersecurityKnowledgeGraphConfig
14
 
15
+ from event_nugget_predict import create_dataloader as event_nugget_dataloader
16
+ from event_realis_predict import create_dataloader as event_realis_dataloader
17
+ from event_arg_predict import create_dataloader as event_argument_dataloader
18
 
19
  class CybersecurityKnowledgeGraphModel(PreTrainedModel):
20
  config_class = CybersecurityKnowledgeGraphConfig
 
40
  self.event_argument_model.load_state_dict(torch.load(self.event_argument_model_path))
41
 
42
  role_classifiers = {}
43
+ folder_path = '/arg_role_models'
44
 
45
  for filename in os.listdir(os.getcwd() + folder_path):
46
  if filename.endswith('.joblib'):
 
50
  role_classifiers[arg] = clf
51
 
52
  self.role_classifiers = role_classifiers
53
+ self.embed_model = SentenceTransformer('all-MiniLM-L6-v2')
54
 
55
 
56
  self.event_nugget_list = config.event_nugget_list