Spaces:
Sleeping
Sleeping
Update model_utils.py
Browse filesAdd new load_models() model
- model_utils.py +12 -34
model_utils.py
CHANGED
@@ -64,47 +64,25 @@ class BiLSTMForTokenClassification(nn.Module):
|
|
64 |
return {'loss': loss, 'logits': logits}
|
65 |
|
66 |
def load_models():
|
67 |
-
""
|
68 |
-
|
69 |
-
|
70 |
-
Returns:
|
71 |
-
bilstm: The loaded bilstm model.
|
72 |
-
"""
|
73 |
-
with open('models/bilstm-model.pkl', 'rb') as f:
|
74 |
-
bilstm_model = pickle.load(f)
|
75 |
-
bilstm_model.eval()
|
76 |
-
|
77 |
-
return bilstm_model
|
78 |
-
|
79 |
-
def load_custom_model(model_dir, tokenizer_dir, id2label):
|
80 |
-
"""
|
81 |
-
Loads a custom BiLSTM model and tokenizer from local files.
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
87 |
|
88 |
-
Returns:
|
89 |
-
model: Loaded BiLSTMForTokenClassification model.
|
90 |
-
tokenizer: Loaded AutoTokenizer.
|
91 |
-
id2label: Input id2label dictionary.
|
92 |
-
"""
|
93 |
config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
|
94 |
-
config.id2label =
|
95 |
-
config.num_labels = len(
|
96 |
|
97 |
model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
|
98 |
-
model.config.id2label =
|
99 |
model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu')))
|
100 |
-
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only
|
101 |
-
|
102 |
-
return model, tokenizer, id2label
|
103 |
|
104 |
-
|
105 |
-
tokenizer_dir = "./models/tokenizer"
|
106 |
-
id2label_ner = {0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve', 6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe', 12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve'}
|
107 |
-
ner_model, ner_tokenizer, id2label_ner = load_custom_model(ner_model_dir, tokenizer_dir, id2label_ner)
|
108 |
|
109 |
# QA model
|
110 |
qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')
|
|
|
64 |
return {'loss': loss, 'logits': logits}
|
65 |
|
66 |
def load_models():
|
67 |
+
model_dir = "./models/bilstm_ner"
|
68 |
+
tokenizer_dir = "./models/tokenizer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
+
id2label_ner = {
|
71 |
+
0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve',
|
72 |
+
6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe',
|
73 |
+
12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve'
|
74 |
+
}
|
75 |
|
|
|
|
|
|
|
|
|
|
|
76 |
config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
|
77 |
+
config.id2label = id2label_ner
|
78 |
+
config.num_labels = len(id2label_ner)
|
79 |
|
80 |
model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
|
81 |
+
model.config.id2label = id2label_ner
|
82 |
model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu')))
|
83 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
|
|
|
|
|
84 |
|
85 |
+
return model, tokenizer, id2label_ner
|
|
|
|
|
|
|
86 |
|
87 |
# QA model
|
88 |
qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')
|