mjlee
commited on
Commit
·
8966d80
1
Parent(s):
297680e
0708_9
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ from models import *
|
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import os
|
5 |
from config import *
|
6 |
-
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
7 |
|
8 |
ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
|
9 |
ENTITY_FILENAME = "entity_model.pt"
|
@@ -19,18 +18,18 @@ base_model = base_model
|
|
19 |
|
20 |
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
21 |
|
22 |
-
sen_model = Classifier(base_model, num_labels=2, device=
|
23 |
sen_model.load_state_dict(torch.load(sen_model_file))
|
24 |
|
25 |
-
entity_model = Classifier(base_model, num_labels=2, device=
|
26 |
entity_model.load_state_dict(torch.load(entity_model_file))
|
27 |
|
28 |
|
29 |
def infer(test_sentence):
|
30 |
-
entity_model.to(device)
|
31 |
-
entity_model.eval()
|
32 |
-
sen_model.to(device)
|
33 |
-
sen_model.eval()
|
34 |
|
35 |
form = test_sentence
|
36 |
annotation = []
|
|
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import os
|
5 |
from config import *
|
|
|
6 |
|
7 |
ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
|
8 |
ENTITY_FILENAME = "entity_model.pt"
|
|
|
18 |
|
19 |
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
20 |
|
21 |
+
sen_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
|
22 |
sen_model.load_state_dict(torch.load(sen_model_file))
|
23 |
|
24 |
+
entity_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
|
25 |
entity_model.load_state_dict(torch.load(entity_model_file))
|
26 |
|
27 |
|
28 |
def infer(test_sentence):
|
29 |
+
# entity_model.to(device)
|
30 |
+
# entity_model.eval()
|
31 |
+
# sen_model.to(device)
|
32 |
+
# sen_model.eval()
|
33 |
|
34 |
form = test_sentence
|
35 |
annotation = []
|