poltextlab committed on
Commit
89d4ec8
·
verified ·
1 Parent(s): ea732a8

add major code mapping

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor.py +18 -4
interfaces/cap_minor.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoModelForSequenceClassification
8
  from transformers import AutoTokenizer
9
  from huggingface_hub import HfApi
10
 
11
- from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES
12
 
13
  from .utils import is_disk_full
14
 
@@ -32,6 +32,19 @@ domains = {
32
  "local government agenda": "localgovernment"
33
  }
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def check_huggingface_path(checkpoint_path: str):
36
  try:
37
  hf_api = HfApi(token=HF_TOKEN)
@@ -59,9 +72,10 @@ def predict(text, model_id, tokenizer_id):
59
  logits = model(**inputs).logits
60
 
61
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
62
- output_pred = {f"[{CAP_MIN_NUM_DICT[i]}] {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
 
63
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
64
- return output_pred, output_info
65
 
66
  def predict_cap(text, language, domain):
67
  domain = domains[domain]
@@ -80,4 +94,4 @@ demo = gr.Interface(
80
  inputs=[gr.Textbox(lines=6, label="Input"),
81
  gr.Dropdown(languages, label="Language"),
82
  gr.Dropdown(domains.keys(), label="Domain")],
83
- outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
 
8
  from transformers import AutoTokenizer
9
  from huggingface_hub import HfApi
10
 
11
+ from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES, CAP_LABEL_NAMES
12
 
13
  from .utils import is_disk_full
14
 
 
32
  "local government agenda": "localgovernment"
33
  }
34
 
35
def convert_minor_to_major(results) -> dict:
    """Aggregate minor-topic probabilities into CAP major-topic predictions.

    Args:
        results: 1-D array-like of class probabilities indexed by minor-class
            position (i.e. the softmax output, one entry per minor code).
            NOTE(review): the previous version iterated ``range(results)``
            over an index array and read an undefined ``probs`` name — both
            failed at runtime; callers should pass the probability vector.

    Returns:
        dict mapping "[major_code] label" to the summed probability of all
        minor codes under that major code, ordered by descending probability.
    """
    major_probs = {}
    for idx, prob in enumerate(results):
        # A CAP minor code is its major code followed by two digits, so
        # dropping the last two characters yields the major code.
        # str() is a no-op for string codes and guards against int values.
        major_code = str(CAP_MIN_NUM_DICT[idx])[:-2]
        # Sum (not overwrite, as before) so each major code carries the
        # total probability mass of its minor codes.
        major_probs[major_code] = major_probs.get(major_code, 0.0) + float(prob)

    # assumes CAP_LABEL_NAMES is keyed by the major-code string — TODO confirm
    return {
        f"[{code}] {CAP_LABEL_NAMES[code]}": p
        for code, p in sorted(major_probs.items(), key=lambda kv: kv[1], reverse=True)
    }
46
+
47
+
48
  def check_huggingface_path(checkpoint_path: str):
49
  try:
50
  hf_api = HfApi(token=HF_TOKEN)
 
72
  logits = model(**inputs).logits
73
 
74
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
75
+ output_pred_minor = {f"[{CAP_MIN_NUM_DICT[i]}] {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
76
+ output_pred_major = convert_minor_to_major(probs)
77
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
78
+ return output_pred_minor, output_pred_major, output_info
79
 
80
  def predict_cap(text, language, domain):
81
  domain = domains[domain]
 
94
  inputs=[gr.Textbox(lines=6, label="Input"),
95
  gr.Dropdown(languages, label="Language"),
96
  gr.Dropdown(domains.keys(), label="Domain")],
97
+ outputs=[gr.Label(num_top_classes=5, label="Output minor"), gr.Label(num_top_classes=5, label="Output major"), gr.Markdown()])