joyinning commited on
Commit
85b50eb
·
verified ·
1 Parent(s): 6eb0b14

Update model_utils.py

Browse files

Add load_model() function without bert_model

Files changed (1) hide show
  1. model_utils.py +26 -0
model_utils.py CHANGED
@@ -63,7 +63,33 @@ class BiLSTMForTokenClassification(nn.Module):
63
 
64
  return {'loss': loss, 'logits': logits}
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def load_custom_model(model_dir, tokenizer_dir, id2label):
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
68
  config.id2label = id2label
69
  config.num_labels = len(id2label)
 
63
 
64
  return {'loss': loss, 'logits': logits}
65
 
66
+ def load_models():
67
+ """
68
+ Loads the custom BiLSTM model
69
+
70
+ Returns:
71
+ bilstm: The loaded bilstm model.
72
+ """
73
+ with open('models/bilstm-model.pkl', 'rb') as f:
74
+ bilstm_model = pickle.load(f)
75
+ bilstm_model.eval()
76
+
77
+ return bilstm_model
78
+
79
  def load_custom_model(model_dir, tokenizer_dir, id2label):
80
+ """
81
+ Loads a custom BiLSTM model and tokenizer from local files.
82
+
83
+ Args:
84
+ model_dir (str): Path to the directory containing the model's files.
85
+ tokenizer_dir (str): Path to the directory containing the tokenizer's files.
86
+ id2label (dict): Dictionary mapping label IDs to their names.
87
+
88
+ Returns:
89
+ model: Loaded BiLSTMForTokenClassification model.
90
+ tokenizer: Loaded AutoTokenizer.
91
+ id2label: Input id2label dictionary.
92
+ """
93
  config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
94
  config.id2label = id2label
95
  config.num_labels = len(id2label)