kovacsvi commited on
Commit
8869f68
·
1 Parent(s): 3149885

Removed hand-rolled HTML result rendering; output top predictions via gr.Label instead

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor.py +36 -70
interfaces/cap_minor.py CHANGED
@@ -8,19 +8,17 @@ from transformers import AutoModelForSequenceClassification
8
  from transformers import AutoTokenizer
9
  from huggingface_hub import HfApi
10
 
11
- from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES, CAP_LABEL_NAMES
12
 
13
- from .utils import is_disk_full, release_model
14
- from itertools import islice
 
 
 
 
15
 
16
- def take(n, iterable):
17
- """Return the first n items of the iterable as a list."""
18
- return list(islice(iterable, n))
19
 
20
- def score_to_color(prob):
21
- red = int(255 * (1 - prob))
22
- green = int(255 * prob)
23
- return f"rgb({red},{green},0)"
24
 
25
  HF_TOKEN = os.environ["hf_read"]
26
 
@@ -39,20 +37,17 @@ domains = {
39
  "judiciary": "judiciary",
40
  "budget": "budget",
41
  "public opinion": "publicopinion",
42
- "local government agenda": "localgovernment"
43
  }
44
 
45
- def convert_minor_to_major(minor_topic):
46
- if minor_topic == 999:
47
- major_code = 999
48
- else:
49
- major_code = str(minor_topic)[:-2]
50
 
51
-
52
- label = CAP_LABEL_NAMES[int(major_code)]
 
 
 
 
53
 
54
- return label
55
-
56
 
57
  def check_huggingface_path(checkpoint_path: str):
58
  try:
@@ -62,11 +57,13 @@ def check_huggingface_path(checkpoint_path: str):
62
  except:
63
  return False
64
 
 
65
  def build_huggingface_path(language: str, domain: str):
66
  if domain in ["social"]:
67
  return "poltextlab/xlm-roberta-large-twitter-cap-minor"
68
  return "poltextlab/xlm-roberta-large-pooled-cap-minor-v3"
69
 
 
70
  def predict(text, model_id, tokenizer_id):
71
  device = torch.device("cpu")
72
 
@@ -80,74 +77,43 @@ def predict(text, model_id, tokenizer_id):
80
 
81
  # Tokenize input
82
  inputs = tokenizer(
83
- text,
84
- max_length=64,
85
- truncation=True,
86
- padding=True,
87
- return_tensors="pt"
88
  )
89
  inputs = {k: v.to(device) for k, v in inputs.items()}
90
 
91
  with torch.no_grad():
92
  output = model(inputs["input_ids"], inputs["attention_mask"])
93
- print(output) # debug
94
  logits = output["logits"]
95
-
96
  release_model(model, model_id)
97
 
98
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
99
- output_pred = {f"[{'999' if str(CAP_MIN_NUM_DICT[i]) == '999' else str(CAP_MIN_NUM_DICT[i])[:-2]}]{convert_minor_to_major(CAP_MIN_NUM_DICT[i])} [{CAP_MIN_NUM_DICT[i]}]{CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
100
-
101
-
102
- output_pred = dict(sorted(output_pred.items(), key=lambda item: item[1], reverse=True))
103
- first_n_items = take(5, output_pred.items())
104
-
105
- html = ""
106
- html += '<div style="background-color: white">'
107
- first = True
108
- for label, prob in first_n_items:
109
- bar_color = "#e0d890" if first else "#ccc"
110
- text_color = "black"
111
- bar_width = int(prob * 100)
112
-
113
- bar_color = score_to_color(prob)
114
-
115
-
116
-
117
- if first:
118
- html += f"""
119
- <div style="text-align: center; font-weight: bold; font-size: 27px; margin-bottom: 10px; margin-left: 10px; margin-right: 10px;">
120
- <span style="color: {text_color};">{label}</span>
121
- </div>"""
122
-
123
- html += f"""
124
- <div style="height: 4px; background-color: green; width: {bar_width}%; margin-bottom: 8px;"></div>
125
- <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;">
126
- <span style="color: {text_color};">{label} — {int(prob * 100)}%</span>
127
- </div>
128
- """
129
- first = False
130
-
131
-
132
- html += '</div>'
133
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
134
- return html, output_info
 
135
 
136
  def predict_cap(text, language, domain):
137
  domain = domains[domain]
138
  model_id = build_huggingface_path(language, domain)
139
  tokenizer_id = "xlm-roberta-large"
140
-
141
  if is_disk_full():
142
- os.system('rm -rf /data/models*')
143
- os.system('rm -r ~/.cache/huggingface/hub')
144
-
145
  return predict(text, model_id, tokenizer_id)
146
 
 
147
  demo = gr.Interface(
148
  title="CAP Minor Topics Babel Demo",
149
  fn=predict_cap,
150
- inputs=[gr.Textbox(lines=6, label="Input"),
151
- gr.Dropdown(languages, label="Language", value=languages[0]),
152
- gr.Dropdown(domains.keys(), label="Domain", value=list(domains.keys())[0])],
153
- outputs=[gr.HTML(label="Output"), gr.Markdown()])
 
 
 
 
8
  from transformers import AutoTokenizer
9
  from huggingface_hub import HfApi
10
 
11
+ from collections import defaultdict
12
 
13
+ from label_dicts import (
14
+ CAP_NUM_DICT,
15
+ CAP_LABEL_NAMES,
16
+ CAP_MIN_NUM_DICT,
17
+ CAP_MIN_LABEL_NAMES,
18
+ )
19
 
20
+ from .utils import is_disk_full, release_model
 
 
21
 
 
 
 
 
22
 
23
  HF_TOKEN = os.environ["hf_read"]
24
 
 
37
  "judiciary": "judiciary",
38
  "budget": "budget",
39
  "public opinion": "publicopinion",
40
+ "local government agenda": "localgovernment",
41
  }
42
 
 
 
 
 
 
43
 
44
def get_label_name(idx):
    """Build the display label '[major] name [minor] name' for class index *idx*.

    Looks up the minor CAP code for the model's class index, derives the
    matching major code, and combines both human-readable names.
    """
    minor = CAP_MIN_NUM_DICT[idx]
    # The residual codes 99 / 999 / 9999 all belong to the catch-all
    # major topic 999; every other minor code embeds its major code
    # in its leading digits (e.g. 1507 -> 15).
    if minor in (99, 999, 9999):
        major = 999
    else:
        major = minor // 100
    return f"[{major}] {CAP_LABEL_NAMES[major]} [{minor}] {CAP_MIN_LABEL_NAMES[minor]}"
50
 
 
 
51
 
52
  def check_huggingface_path(checkpoint_path: str):
53
  try:
 
57
  except:
58
  return False
59
 
60
+
61
def build_huggingface_path(language: str, domain: str):
    """Return the Hugging Face repo id of the CAP minor-topic model.

    The social-media domain uses a Twitter-trained checkpoint; every
    other domain (regardless of language) shares the pooled v3 model.
    """
    if domain == "social":
        return "poltextlab/xlm-roberta-large-twitter-cap-minor"
    return "poltextlab/xlm-roberta-large-pooled-cap-minor-v3"
65
 
66
+
67
  def predict(text, model_id, tokenizer_id):
68
  device = torch.device("cpu")
69
 
 
77
 
78
  # Tokenize input
79
  inputs = tokenizer(
80
+ text, max_length=64, truncation=True, padding=True, return_tensors="pt"
 
 
 
 
81
  )
82
  inputs = {k: v.to(device) for k, v in inputs.items()}
83
 
84
  with torch.no_grad():
85
  output = model(inputs["input_ids"], inputs["attention_mask"])
86
+ print(output) # debug
87
  logits = output["logits"]
88
+
89
  release_model(model, model_id)
90
 
91
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
92
+
93
+ output_pred = {get_label_name(i): probs[i] for i in np.argsort(probs)[::-1]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
95
+ return output_pred, output_info
96
+
97
 
98
def predict_cap(text, language, domain):
    """Gradio entry point: resolve the model for (language, domain) and predict.

    Translates the UI-facing domain label into its internal key, picks the
    matching checkpoint, and delegates to predict().
    """
    domain_key = domains[domain]
    model_id = build_huggingface_path(language, domain_key)
    tokenizer_id = "xlm-roberta-large"

    # Best-effort cache eviction when the persistent disk is full;
    # any shell errors are deliberately ignored.
    if is_disk_full():
        os.system("rm -rf /data/models*")
        os.system("rm -r ~/.cache/huggingface/hub")

    return predict(text, model_id, tokenizer_id)
108
 
109
+
110
# Gradio UI: free-text input plus language/domain dropdowns; the top five
# predicted minor topics are shown via gr.Label, with model info below.
demo = gr.Interface(
    fn=predict_cap,
    title="CAP Minor Topics Babel Demo",
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language", value=languages[0]),
        gr.Dropdown(domains.keys(), label="Domain", value=list(domains.keys())[0]),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)