mjlee commited on
Commit
8966d80
·
1 Parent(s): 297680e
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -3,7 +3,6 @@ from models import *
3
  from huggingface_hub import hf_hub_download
4
  import os
5
  from config import *
6
- device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
7
 
8
  ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
9
  ENTITY_FILENAME = "entity_model.pt"
@@ -19,18 +18,18 @@ base_model = base_model
19
 
20
  tokenizer = AutoTokenizer.from_pretrained(base_model)
21
 
22
- sen_model = Classifier(base_model, num_labels=2, device=device, tokenizer=tokenizer)
23
  sen_model.load_state_dict(torch.load(sen_model_file))
24
 
25
- entity_model = Classifier(base_model, num_labels=2, device=device, tokenizer=tokenizer)
26
  entity_model.load_state_dict(torch.load(entity_model_file))
27
 
28
 
29
  def infer(test_sentence):
30
- entity_model.to(device)
31
- entity_model.eval()
32
- sen_model.to(device)
33
- sen_model.eval()
34
 
35
  form = test_sentence
36
  annotation = []
 
3
  from huggingface_hub import hf_hub_download
4
  import os
5
  from config import *
 
6
 
7
  ENTITY_REPO_ID = 'vaivTA/absa_v2_entity'
8
  ENTITY_FILENAME = "entity_model.pt"
 
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(base_model)
20
 
21
+ sen_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
22
  sen_model.load_state_dict(torch.load(sen_model_file))
23
 
24
+ entity_model = Classifier(base_model, num_labels=2, device='cpu', tokenizer=tokenizer)
25
  entity_model.load_state_dict(torch.load(entity_model_file))
26
 
27
 
28
  def infer(test_sentence):
29
+ # entity_model.to(device)
30
+ # entity_model.eval()
31
+ # sen_model.to(device)
32
+ # sen_model.eval()
33
 
34
  form = test_sentence
35
  annotation = []