joyinning committed on
Commit
c9f64e3
·
verified ·
1 Parent(s): 85b50eb

Update model_utils.py

Browse files

Add new load_models() function

Files changed (1) hide show
  1. model_utils.py +12 -34
model_utils.py CHANGED
@@ -64,47 +64,25 @@ class BiLSTMForTokenClassification(nn.Module):
64
  return {'loss': loss, 'logits': logits}
65
 
66
  def load_models():
67
- """
68
- Loads the custom BiLSTM model
69
-
70
- Returns:
71
- bilstm: The loaded bilstm model.
72
- """
73
- with open('models/bilstm-model.pkl', 'rb') as f:
74
- bilstm_model = pickle.load(f)
75
- bilstm_model.eval()
76
-
77
- return bilstm_model
78
-
79
- def load_custom_model(model_dir, tokenizer_dir, id2label):
80
- """
81
- Loads a custom BiLSTM model and tokenizer from local files.
82
 
83
- Args:
84
- model_dir (str): Path to the directory containing the model's files.
85
- tokenizer_dir (str): Path to the directory containing the tokenizer's files.
86
- id2label (dict): Dictionary mapping label IDs to their names.
 
87
 
88
- Returns:
89
- model: Loaded BiLSTMForTokenClassification model.
90
- tokenizer: Loaded AutoTokenizer.
91
- id2label: Input id2label dictionary.
92
- """
93
  config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
94
- config.id2label = id2label
95
- config.num_labels = len(id2label)
96
 
97
  model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
98
- model.config.id2label = id2label
99
  model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu')))
100
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only = True)
101
-
102
- return model, tokenizer, id2label
103
 
104
- ner_model_dir = "./models/bilstm_ner"
105
- tokenizer_dir = "./models/tokenizer"
106
- id2label_ner = {0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve', 6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe', 12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve'}
107
- ner_model, ner_tokenizer, id2label_ner = load_custom_model(ner_model_dir, tokenizer_dir, id2label_ner)
108
 
109
  # QA model
110
  qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')
 
64
  return {'loss': loss, 'logits': logits}
65
 
66
  def load_models():
67
+ model_dir = "./models/bilstm_ner"
68
+ tokenizer_dir = "./models/tokenizer"
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ id2label_ner = {
71
+ 0: 'O', 1: 'I-art', 2: 'B-org', 3: 'B-geo', 4: 'I-per', 5: 'B-eve',
72
+ 6: 'I-geo', 7: 'B-per', 8: 'I-nat', 9: 'B-art', 10: 'B-tim', 11: 'I-gpe',
73
+ 12: 'I-tim', 13: 'B-nat', 14: 'B-gpe', 15: 'I-org', 16: 'I-eve'
74
+ }
75
 
 
 
 
 
 
76
  config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
77
+ config.id2label = id2label_ner
78
+ config.num_labels = len(id2label_ner)
79
 
80
  model = BiLSTMForTokenClassification(model_name=config._name_or_path, num_labels=config.num_labels)
81
+ model.config.id2label = id2label_ner
82
  model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location=torch.device('cpu')))
83
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
 
 
84
 
85
+ return model, tokenizer, id2label_ner
 
 
 
86
 
87
  # QA model
88
  qa_model = pipeline('question-answering', model='deepset/bert-base-cased-squad2')