Spaces:
Running
Running
File size: 4,076 Bytes
b1c2932 f1168c0 219ed3b b1c2932 7a079bf a6c2ca7 b1c2932 5e7a4a4 b1c2932 8dc5af0 b1c2932 38e644a 65e8066 3d40b96 38e644a b1c2932 a6c2ca7 b1c2932 182d6c8 b1c2932 84c21a9 b1c2932 219ed3b 466052e 219ed3b b1c2932 f7e1e22 b1c2932 f7e1e22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import glob
import os
import shutil

import gradio as gr
import numpy as np
import pandas as pd
import torch
from huggingface_hub import HfApi
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

from label_dicts import CAP_NUM_DICT, CAP_LABEL_NAMES
from .utils import is_disk_full
# Hugging Face access token read from the "hf_read" environment variable.
# It authenticates the Hub lookups in check_huggingface_path and the model
# download in predict. Indexing os.environ (rather than .get) makes a missing
# variable fail loudly with KeyError when this module is imported.
HF_TOKEN = os.environ["hf_read"]
# Choices for the "Language" dropdown (lower-cased later to build repo ids).
languages = ["English", "Multilingual"]

# Human-readable domain label shown in the "Domain" dropdown -> short domain
# slug used inside poltextlab model repository names. Insertion order is the
# order the dropdown displays.
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment",
}
def check_huggingface_path(checkpoint_path: str) -> bool:
    """Return True if the Hub reports model info for *checkpoint_path*.

    The lookup is authenticated with ``HF_TOKEN``. Any ordinary failure
    (unknown repository, no access, network problem, ...) is reported as
    ``False`` so callers can fall back to another model. ``KeyboardInterrupt``
    and ``SystemExit`` are no longer swallowed, unlike with a bare ``except``.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:  # deliberate best-effort probe: any failure means "not usable"
        return False
    return True
def build_huggingface_path(language: str, domain: str) -> str:
    """Return the Hub id of the CAP model to use for *language* and *domain*.

    Args:
        language: Language label from the UI (case-insensitive), e.g. "English".
        domain: Short domain slug, i.e. a value of ``domains`` such as "media".

    Returns:
        * ``poltextlab/xlm-roberta-large-english-{domain}-cap-v4`` for English
          media or legislative texts;
        * ``poltextlab/xlm-roberta-large-english-cap-v3`` for other English
          domains, provided the Hub reports that repository as available;
        * ``poltextlab/xlm-roberta-large-pooled-cap`` otherwise.

    NOTE(review): an earlier version also read ``language_domain_models.csv``
    and built a per-domain path map, but never used either result. That dead
    lookup (a file read on every request that could crash on a bad CSV) was
    removed. Reintroduce it deliberately if per-domain routing is wanted.
    """
    language = language.lower()
    base_path = "xlm-roberta-large"
    pooled_path = f"poltextlab/{base_path}-pooled-cap"

    # English media and legislative texts have dedicated v4 domain models.
    if language == "english" and domain in ("media", "legislative"):
        return f"poltextlab/{base_path}-{language}-{domain}-cap-v4"

    # Everything non-English is served by the pooled model; the previous
    # code probed the Hub for it but returned it regardless of the answer.
    if language != "english":
        return pooled_path

    lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
    return lang_path if check_huggingface_path(lang_path) else pooled_path
def predict(text, model_id, tokenizer_id):
    """Classify *text* into CAP topics with the given model.

    Args:
        text: Text to classify (presumably a single string from the UI box).
        model_id: Hub id of the sequence-classification model to load.
        tokenizer_id: Hub id of the tokenizer to encode *text* with.

    Returns:
        ``(output_pred, output_info)``: ``output_pred`` maps
        ``"[<code>] <label name>"`` to its softmax probability (a Python
        ``float``), inserted from most to least likely; ``output_info`` is an
        HTML paragraph linking to the model that was used.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # device_map="auto" chooses where the weights live (possibly a GPU), so
    # the inputs must follow the model's device; a hard-coded CPU device
    # raises a device-mismatch error whenever the model lands elsewhere.
    inputs = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(model.device)

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the label axis; flatten assumes a single input text.
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    # Plain floats (not numpy float32) keep the Label output JSON-friendly.
    output_pred = {
        f"[{CAP_NUM_DICT[i]}] {CAP_LABEL_NAMES[CAP_NUM_DICT[i]]}": float(probs[i])
        for i in np.argsort(probs)[::-1]
    }
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info
def _free_model_storage():
    """Best-effort removal of downloaded model files to reclaim disk space."""
    # Same targets as `rm -rf /data/models*`, but without spawning a shell.
    for path in glob.glob("/data/models*"):
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                # Deliberately best-effort, like `rm -f`.
                pass
    # Same target as `rm -r ~/.cache/huggingface/hub`; a missing cache is fine.
    shutil.rmtree(os.path.expanduser("~/.cache/huggingface/hub"), ignore_errors=True)


def predict_cap(text, language, domain):
    """Gradio callback: classify *text* with the CAP model for the chosen inputs.

    Args:
        text: Text entered in the input box.
        language: Language chosen in the dropdown (an entry of ``languages``).
        domain: Domain label chosen in the dropdown (a key of ``domains``).

    Returns:
        The ``(output_pred, output_info)`` pair produced by ``predict``.

    Raises:
        KeyError: if *domain* is not a key of ``domains``.
    """
    domain = domains[domain]
    model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"

    # Make room for the checkpoint download when storage is exhausted.
    if is_disk_full():
        _free_model_storage()

    return predict(text, model_id, tokenizer_id)
# Gradio UI: free text plus language and domain dropdowns go in; the five
# most probable CAP labels and an HTML note naming the model come out.
demo = gr.Interface(
    title="CAP Babel Demo",
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
        # Pass a real list rather than a dict_keys view: Dropdown documents a
        # sequence of choices, and some Gradio versions serialise it as-is.
        gr.Dropdown(list(domains), label="Domain"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)
|