import gradio as gr
import os
import time

import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import CAP_NUM_DICT, CAP_LABEL_NAMES
from utils import is_disk_full, release_model  # top-level module; a relative import fails when app.py runs as a script
HF_TOKEN = os.environ["hf_read"]
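
# is_disk_full() and release_model() live in a separate utils module that is
# not shown here. A minimal sketch of what they plausibly do (an assumption,
# not the real implementation):
#
#     import gc, shutil
#
#     def is_disk_full(path="/data", threshold=0.95):
#         usage = shutil.disk_usage(path)
#         return usage.used / usage.total >= threshold
#
#     def release_model(model, model_id):
#         del model  # drop the last reference so the weights can be freed
#         gc.collect()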

languages = [
    "English",
    "Multilingual",
]

# UI label -> domain slug used in the model paths
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment",
}

def check_huggingface_path(checkpoint_path: str):
    """Return True if the checkpoint exists on the Hub and is readable with HF_TOKEN."""
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
        return True
    except Exception:
        return False
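
# check_huggingface_path() is a guard for use before loading; a hypothetical
# call site (not in the original flow) could look like:
#
#     if not check_huggingface_path(model_id):
#         raise gr.Error(f"Model {model_id} is not available on the Hub")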

def build_huggingface_path(language: str, domain: str):
    """Map a (language, domain) pair to the matching CAP checkpoint."""
    language = language.lower()
    base_path = "xlm-roberta-large"
    lang_path = f"poltextlab/{base_path}-{language}-cap-v3"
    # Custom mapping: a few English domains have dedicated checkpoints;
    # everything else falls back to the pooled multilingual model.
    if language in ["english"]:
        if domain in ["media", "legislative"]:
            return f"poltextlab/{base_path}-{language}-{domain}-cap-v4"
        elif domain in ["social"]:
            return f"poltextlab/{base_path}-{language}-{domain}-cap-v3"
        return lang_path
    else:
        if domain in ["social"]:
            return f"poltextlab/{base_path}-{domain}-cap-v3"
        return "poltextlab/xlm-roberta-large-pooled-cap-v3"

def predict(text, model_id, tokenizer_id):
    device = torch.device("cpu")

    # Load the TorchScript-compiled model from the local cache.
    t0 = time.perf_counter()
    jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
    model = torch.jit.load(jit_model_path).to(device)
    model.eval()
    print(f"Model loading: {time.perf_counter() - t0:.3f}s")

    t1 = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    print(f"Tokenizer loading: {time.perf_counter() - t1:.3f}s")

    t2 = time.perf_counter()
    inputs = tokenizer(
        text,
        max_length=64,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    print(f"Tokenization: {time.perf_counter() - t2:.3f}s")

    t3 = time.perf_counter()
    with torch.no_grad():
        output = model(inputs["input_ids"], inputs["attention_mask"])
        logits = output["logits"]
    print(f"Inference: {time.perf_counter() - t3:.3f}s")

    release_model(model, model_id)

    # Convert logits to a label -> probability mapping, sorted descending.
    t4 = time.perf_counter()
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    output_pred = {
        f"[{CAP_NUM_DICT[i]}] {CAP_LABEL_NAMES[CAP_NUM_DICT[i]]}": probs[i]
        for i in np.argsort(probs)[::-1]
    }
    output_info = (
        '<p style="text-align: center; display: block">Prediction was made using the '
        f'<a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    )
    print(f"Post-processing: {time.perf_counter() - t4:.3f}s")
    return output_pred, output_info
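
# predict() expects a pre-compiled TorchScript file under /data/jit_models/,
# but the export step is not part of this file. export_jit_model below is a
# hypothetical helper sketching how such a file could be produced; it is not
# called anywhere in the app.
def export_jit_model(model_id: str, tokenizer_id: str = "xlm-roberta-large"):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id, token=HF_TOKEN)
    model.eval()
    example = tokenizer(
        "example input",
        max_length=64,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # strict=False lets the traced module return a dict, which keeps the
    # output["logits"] lookup in predict() working.
    traced = torch.jit.trace(
        model,
        (example["input_ids"], example["attention_mask"]),
        strict=False,
    )
    os.makedirs("/data/jit_models", exist_ok=True)
    torch.jit.save(traced, f"/data/jit_models/{model_id.replace('/', '_')}.pt")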

def predict_cap(text, language, domain):
    domain = domains[domain]
    model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"

    # Free disk space before loading anything new.
    if is_disk_full():
        os.system("rm -rf /data/models*")
        os.system("rm -rf ~/.cache/huggingface/hub")

    return predict(text, model_id, tokenizer_id)
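
# Quick smoke test outside the Gradio UI (hypothetical example values):
#
#     pred, info = predict_cap(
#         "The government passed a new budget bill.", "Multilingual", "media"
#     )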

demo = gr.Interface(
    title="CAP Babel Demo",
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language", value=languages[-1]),
        gr.Dropdown(list(domains.keys()), label="Domain", value=list(domains.keys())[0]),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)

if __name__ == "__main__":
    demo.launch()