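"""Gradio demo for CAP topic classification with poltextlab XLM-RoBERTa models.

Resolves a language- and domain-specific checkpoint on the Hugging Face Hub
and falls back to the pooled multilingual model when none is available.
"""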
import gradio as gr
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from label_dicts import CAP_NUM_DICT, CAP_LABEL_NAMES
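# read-only Hub access token, expected in the environment (presumably a Space secret)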
HF_TOKEN = os.environ["hf_read"]
languages = [
    "Danish",
    "Dutch",
    "English",
    "French",
    "German",
    "Hungarian",
    "Italian",
    "Polish",
    "Portuguese",
    "Spanish",
    "Czech",
    "Slovak",
    "Norwegian",
]
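# maps the dropdown display names to the domain suffixes used in model ids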
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment",
}
def check_huggingface_path(checkpoint_path: str):
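    """Return True if `checkpoint_path` exists as a model repo on the Hub."""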
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
        return True
    except RepositoryNotFoundError:
        return False
def build_huggingface_path(language: str, domain: str):
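    """Pick the poltextlab CAP model id for a language/domain pair.

    language_domain_models.csv (if present) selects one of the path_map
    variants per language and domain; the chosen model is verified on the
    Hub, with the pooled multilingual model as the final fallback.
    """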
    # dropdown values are capitalized, while the csv and the Hub model ids use lowercase
    language = language.lower()
    base_path = "xlm-roberta-large"
    lang_domain_path = f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
    lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
    path_map = {
        "L": lang_path,
        "L-D": lang_domain_path,
        "X": lang_domain_path,
    }
    value = None
    try:
        lang_domain_table = pd.read_csv("language_domain_models.csv")
        lang_domain_table["language"] = lang_domain_table["language"].str.lower()
        lang_domain_table.columns = lang_domain_table.columns.str.lower()
        # get the row for the language, then read the value from the domain column
        row = lang_domain_table[lang_domain_table["language"] == language]
        tmp = row.get(domain)
        if tmp is not None and not tmp.empty:
            value = tmp.iloc[0]
    except (AttributeError, FileNotFoundError):
        value = None
    if value and value in path_map:
        model_path = path_map[value]
        if check_huggingface_path(model_path):
            # if the preferred model is available on Huggingface, return its path
            return model_path
        # if the preferred model is not available, try the remaining candidates
        filtered_path_map = {k: v for k, v in path_map.items() if k != value}
        for v in filtered_path_map.values():
            if check_huggingface_path(v):
                return v
    elif check_huggingface_path(lang_domain_path):
        return lang_domain_path
    elif check_huggingface_path(lang_path):
        return lang_path
    # nothing matched (or every candidate was missing): fall back to the
    # multilingual pooled model
    return "poltextlab/xlm-roberta-large-pooled-cap"
def predict(text, model_id, tokenizer_id):
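    """Classify `text` and return (label -> probability dict, info HTML)."""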
    device = torch.device("cpu")
    gr.Info("Loading model")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id, low_cpu_mem_usage=True, device_map="auto", token=HF_TOKEN
    )
    gr.Info("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    gr.Info("Tokenizing")
    inputs = tokenizer(
        text,
        max_length=512,  # XLM-RoBERTa's maximum input length
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(device)
    model.eval()
    gr.Info("Running inference")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    output_pred = {
        f"[{CAP_NUM_DICT[i]}] {CAP_LABEL_NAMES[CAP_NUM_DICT[i]]}": probs[i]
        for i in np.argsort(probs)[::-1]
    }
    output_info = (
        f'<p style="text-align: center; display: block">Prediction was made using '
        f'the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    )
    return output_pred, output_info
def predict_cap(text, language, domain):
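    """Gradio handler: resolve the model for the UI selections, then predict."""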
    domain = domains[domain]
    model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"
    return predict(text, model_id, tokenizer_id)
demo = gr.Interface(
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
        gr.Dropdown(list(domains.keys()), label="Domain"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)
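
# Hugging Face Spaces launches the top-level `demo` automatically; the guard
# below is only needed when running the script locally.
if __name__ == "__main__":
    demo.launch()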