Spaces:
Running
Running
add major code mapping
Browse files- interfaces/cap_minor.py +18 -4
interfaces/cap_minor.py
CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoModelForSequenceClassification
|
|
8 |
from transformers import AutoTokenizer
|
9 |
from huggingface_hub import HfApi
|
10 |
|
11 |
-
from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES
|
12 |
|
13 |
from .utils import is_disk_full
|
14 |
|
@@ -32,6 +32,19 @@ domains = {
|
|
32 |
"local government agenda": "localgovernment"
|
33 |
}
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def check_huggingface_path(checkpoint_path: str):
|
36 |
try:
|
37 |
hf_api = HfApi(token=HF_TOKEN)
|
@@ -59,9 +72,10 @@ def predict(text, model_id, tokenizer_id):
|
|
59 |
logits = model(**inputs).logits
|
60 |
|
61 |
probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
|
62 |
-
|
|
|
63 |
output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
|
64 |
-
return
|
65 |
|
66 |
def predict_cap(text, language, domain):
|
67 |
domain = domains[domain]
|
@@ -80,4 +94,4 @@ demo = gr.Interface(
|
|
80 |
inputs=[gr.Textbox(lines=6, label="Input"),
|
81 |
gr.Dropdown(languages, label="Language"),
|
82 |
gr.Dropdown(domains.keys(), label="Domain")],
|
83 |
-
outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
|
|
|
8 |
from transformers import AutoTokenizer
|
9 |
from huggingface_hub import HfApi
|
10 |
|
11 |
+
from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES, CAP_LABEL_NAMES
|
12 |
|
13 |
from .utils import is_disk_full
|
14 |
|
|
|
32 |
"local government agenda": "localgovernment"
|
33 |
}
|
34 |
|
35 |
+
def convert_minor_to_major(results) -> dict:
|
36 |
+
results_as_text = dict()
|
37 |
+
for i in results:
|
38 |
+
prob = probs[i]
|
39 |
+
major_code = CAP_MIN_NUM_DICT[i][:-2]
|
40 |
+
label = CAP_LABEL_NAMES[major_code]
|
41 |
+
|
42 |
+
key = f"[{major_code}] {label}"
|
43 |
+
results_as_text[key] = prob
|
44 |
+
|
45 |
+
return results_as_text
|
46 |
+
|
47 |
+
|
48 |
def check_huggingface_path(checkpoint_path: str):
|
49 |
try:
|
50 |
hf_api = HfApi(token=HF_TOKEN)
|
|
|
72 |
logits = model(**inputs).logits
|
73 |
|
74 |
probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
|
75 |
+
output_pred_minor = {f"[{CAP_MIN_NUM_DICT[i]}] {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
|
76 |
+
output_pred_major = convert_minor_to_major(np.argsort(probs)[::-1])
|
77 |
output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
|
78 |
+
return output_pred_minor, output_pred_major, output_info
|
79 |
|
80 |
def predict_cap(text, language, domain):
|
81 |
domain = domains[domain]
|
|
|
94 |
inputs=[gr.Textbox(lines=6, label="Input"),
|
95 |
gr.Dropdown(languages, label="Language"),
|
96 |
gr.Dropdown(domains.keys(), label="Domain")],
|
97 |
+
outputs=[gr.Label(num_top_classes=5, label="Output minor"), gr.Label(num_top_classes=5, label="Output major"), gr.Markdown()])
|