hHoai committed on
Commit
5ca192d
·
verified ·
1 Parent(s): 68c0ef8

Update bartpho/utils.py

Browse files
Files changed (1) hide show
  1. bartpho/utils.py +9 -10
bartpho/utils.py CHANGED
@@ -51,28 +51,27 @@ def predict(model, text, tokenizer, model_tokenize=None, processed=True, printou
51
  for i in range(no_tag):
52
  tag = tags[i]
53
  score_list = []
54
- input_ids = tokenizer([text] * no_polarity, return_tensors='pt')['input_ids']
 
55
  target_list = ["Nhận_xét " + tag.lower() + " " + polarity.lower() + " ." for polarity in polarity_list]
56
- output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
57
 
58
  with torch.no_grad():
59
  output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids.to(device))[0]
60
  logits = output.softmax(dim=-1).to('cpu').numpy()
61
  for m in range(no_polarity):
62
- score = 1
63
- for n in range(logits[m].shape[0] - 2):
64
- score *= logits[m][n][output_ids[m][n+1]]
65
  score_list.append(score)
66
- predict = np.argmax(score_list)
67
  predicts.append(predict)
68
 
69
  if printout:
70
  result = {}
71
  for i in range(no_tag):
72
- if predicts[i] != 0:
73
- result[eng_tags[i]] = eng_polarity[predicts[i]]
74
- print(result)
75
- return predicts
76
 
77
  def predict_df(model, df, tokenizer=None, model_tokenize=None, tokenizer_name='vinai/bartpho-word-base', processed=True, printout=True):
78
  model.eval()
 
51
  for i in range(no_tag):
52
  tag = tags[i]
53
  score_list = []
54
+
55
+ input_ids = tokenizer([text] * no_polarity, return_tensors='pt')['input_ids'].to(device)
56
  target_list = ["Nhận_xét " + tag.lower() + " " + polarity.lower() + " ." for polarity in polarity_list]
57
+ output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids'].to(device)
58
 
59
  with torch.no_grad():
60
  output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids.to(device))[0]
61
  logits = output.softmax(dim=-1).to('cpu').numpy()
62
  for m in range(no_polarity):
63
+ score = np.sum(np.log(logits[m][range(len(output_ids[m]) - 2), output_ids[m][1:-1]]))
 
 
64
  score_list.append(score)
65
+ predict = int(np.argmax(score_list)) # Ép kiểu sang int
66
  predicts.append(predict)
67
 
68
  if printout:
69
  result = {}
70
  for i in range(no_tag):
71
+ if predicts[i] != 0: # Bỏ qua các nhãn không có cảm xúc (mặc định 0)
72
+ result[tags[i]] = polarity_list[predicts[i]] # Ánh xạ nhãn
73
+ # print(result)
74
+ return result
75
 
76
  def predict_df(model, df, tokenizer=None, model_tokenize=None, tokenizer_name='vinai/bartpho-word-base', processed=True, printout=True):
77
  model.eval()