sunny-annie commited on
Commit
6b9e663
·
1 Parent(s): 707768a

upload bert_func.py

Browse files
Files changed (1) hide show
  1. bert_func.py +37 -0
bert_func.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import transformers
4
+ import joblib
5
+
6
+ def load_bert_lr_model(path):
7
+
8
+ model_class = transformers.BertModel
9
+ tokenizer_class = transformers.BertTokenizer
10
+ pretrained_weights = 'bert-base-uncased'
11
+ model = model_class.from_pretrained(pretrained_weights)
12
+ tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
13
+ lr = joblib.load(path)
14
+ return model, tokenizer, lr
15
+
16
+ def prediction(text, model, tokenizer, lr, max_len=256):
17
+ input = tokenizer.encode(text,
18
+ add_special_tokens=True,
19
+ padding='max_length',
20
+ truncation=True,
21
+ return_tensors='np',
22
+ max_length=max_len)
23
+
24
+ att_mask = np.where(input != 0, 1, 0)
25
+ input = torch.tensor(input)
26
+ att_mask = torch.tensor(att_mask)
27
+
28
+ last_hidden_states = model(input, attention_mask=att_mask)
29
+ vector = last_hidden_states[0][:,0,:].detach().numpy()
30
+ pred = lr.predict(vector)[0]
31
+
32
+ if pred == 1:
33
+ result = 'Positive review'
34
+ else:
35
+ result = 'Negative review'
36
+
37
+ return result