Spaces:
Runtime error
Update bartpho/utils.py
Browse files
- bartpho/utils.py +9 -10
bartpho/utils.py
CHANGED
@@ -51,28 +51,27 @@ def predict(model, text, tokenizer, model_tokenize=None, processed=True, printou
|
|
51 |
for i in range(no_tag):
|
52 |
tag = tags[i]
|
53 |
score_list = []
|
54 |
-
|
|
|
55 |
target_list = ["Nhận_xét " + tag.lower() + " " + polarity.lower() + " ." for polarity in polarity_list]
|
56 |
-
output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids']
|
57 |
|
58 |
with torch.no_grad():
|
59 |
output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids.to(device))[0]
|
60 |
logits = output.softmax(dim=-1).to('cpu').numpy()
|
61 |
for m in range(no_polarity):
|
62 |
-
score = 1
|
63 |
-
for n in range(logits[m].shape[0] - 2):
|
64 |
-
score *= logits[m][n][output_ids[m][n+1]]
|
65 |
score_list.append(score)
|
66 |
-
predict = np.argmax(score_list)
|
67 |
predicts.append(predict)
|
68 |
|
69 |
if printout:
|
70 |
result = {}
|
71 |
for i in range(no_tag):
|
72 |
-
if predicts[i] != 0:
|
73 |
-
result[
|
74 |
-
print(result)
|
75 |
-
return
|
76 |
|
77 |
def predict_df(model, df, tokenizer=None, model_tokenize=None, tokenizer_name='vinai/bartpho-word-base', processed=True, printout=True):
|
78 |
model.eval()
|
|
|
51 |
for i in range(no_tag):
|
52 |
tag = tags[i]
|
53 |
score_list = []
|
54 |
+
|
55 |
+
input_ids = tokenizer([text] * no_polarity, return_tensors='pt')['input_ids'].to(device)
|
56 |
target_list = ["Nhận_xét " + tag.lower() + " " + polarity.lower() + " ." for polarity in polarity_list]
|
57 |
+
output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids'].to(device)
|
58 |
|
59 |
with torch.no_grad():
|
60 |
output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids.to(device))[0]
|
61 |
logits = output.softmax(dim=-1).to('cpu').numpy()
|
62 |
for m in range(no_polarity):
|
63 |
+
score = np.sum(np.log(logits[m][range(len(output_ids[m]) - 2), output_ids[m][1:-1]]))
|
|
|
|
|
64 |
score_list.append(score)
|
65 |
+
predict = int(np.argmax(score_list)) # Ép kiểu sang int
|
66 |
predicts.append(predict)
|
67 |
|
68 |
if printout:
|
69 |
result = {}
|
70 |
for i in range(no_tag):
|
71 |
+
if predicts[i] != 0: # Bỏ qua các nhãn không có cảm xúc (mặc định 0)
|
72 |
+
result[tags[i]] = polarity_list[predicts[i]] # Ánh xạ nhãn
|
73 |
+
# print(result)
|
74 |
+
return result
|
75 |
|
76 |
def predict_df(model, df, tokenizer=None, model_tokenize=None, tokenizer_name='vinai/bartpho-word-base', processed=True, printout=True):
|
77 |
model.eval()
|