Plus: Adds the .csv from Huawei
Browse files- src/__pycache__/predict.cpython-310.pyc +0 -0
- src/predict.py +20 -0
src/__pycache__/predict.cpython-310.pyc
CHANGED
Binary files a/src/__pycache__/predict.cpython-310.pyc and b/src/__pycache__/predict.cpython-310.pyc differ
|
|
src/predict.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from pathlib import Path
|
2 |
import torch
|
|
|
3 |
|
4 |
from .tokenizer import load_tokenizer, preprocessing_text
|
5 |
from .model import load_model
|
@@ -17,6 +18,24 @@ tokenizer = load_tokenizer(model_name)
|
|
17 |
model = load_model(checkpoint_path, model_name, num_labels, divice)
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def get_predict(text):
|
21 |
inputs = preprocessing_text(text, tokenizer)
|
22 |
input_ids = inputs["input_ids"].to(divice)
|
@@ -24,4 +43,5 @@ def get_predict(text):
|
|
24 |
token_type_ids = inputs["token_type_ids"].to(divice)
|
25 |
outputs = model(input_ids, attention_mask, token_type_ids)
|
26 |
preds = torch.sigmoid(outputs).detach().cpu().numpy()
|
|
|
27 |
return preds
|
|
|
1 |
from pathlib import Path
|
2 |
import torch
|
3 |
+
import numpy as np
|
4 |
|
5 |
from .tokenizer import load_tokenizer, preprocessing_text
|
6 |
from .model import load_model
|
|
|
18 |
model = load_model(checkpoint_path, model_name, num_labels, divice)
|
19 |
|
20 |
|
21 |
+
# Label names, in the order in which `filter` indexes them: position i of the
# score vector is reported under RETURN_VALUES[i].
# NOTE(review): presumably this order matches the model's output columns —
# confirm against the training label order.
RETURN_VALUES = [
    "target_sentiment_negative",
    "target_sentiment_neutral",
    "target_sentiment_positive",
    "companies_sentiment_negative",
    "companies_sentiment_neutral",
    "companies_sentiment_positive",
    "consumers_sentiment_negative",
    "consumers_sentiment_neutral",
    "consumers_sentiment_positive",
]


def filter(preds, threshold=0.5):
    """Return the labels whose score is strictly greater than *threshold*.

    Args:
        preds: One-dimensional array-like of scores; position ``i`` is
            reported under ``RETURN_VALUES[i]``.
        threshold: Exclusive lower bound a score must exceed to be kept.

    Returns:
        A dict mapping label name to its score (as stored in the array), in
        label order. Empty when no score exceeds the threshold.

    Raises:
        ValueError: If *preds* is not one-dimensional.
        IndexError: If a score above the threshold sits at a position that
            has no entry in ``RETURN_VALUES``.
    """
    # NOTE: the function name shadows the builtin `filter`; it is kept because
    # callers in this module use it by that name.
    scores = np.asarray(preds)
    if scores.ndim != 1:
        # With a 2-D input, row indices would be paired with label names and
        # produce a silently wrong mapping, so reject it explicitly.
        raise ValueError(
            f"preds must be one-dimensional, got shape {scores.shape}"
        )

    above_threshold = scores > threshold
    selected = np.flatnonzero(above_threshold)
    return {RETURN_VALUES[index]: scores[index] for index in selected}
|
38 |
+
|
39 |
def get_predict(text):
|
40 |
inputs = preprocessing_text(text, tokenizer)
|
41 |
input_ids = inputs["input_ids"].to(divice)
|
|
|
43 |
token_type_ids = inputs["token_type_ids"].to(divice)
|
44 |
outputs = model(input_ids, attention_mask, token_type_ids)
|
45 |
preds = torch.sigmoid(outputs).detach().cpu().numpy()
|
46 |
+
preds = filter(preds[0])
|
47 |
return preds
|