# babel_machine/interfaces/manifesto.py
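"""Gradio interface for multilingual manifesto classification.

Loads the poltextlab/xlm-roberta-large-manifesto model and returns
per-label probabilities for a user-supplied text in any of the
supported languages.
"""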
import os
import time

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from label_dicts import MANIFESTO_LABEL_NAMES
class RuntimeMeasure:
    """Context manager that reports a block's wall-clock runtime as a Gradio toast."""

    def __init__(self, msg):
        self.msg = msg

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        runtime = time.time() - self.start_time
        gr.Info(f"{self.msg}: {runtime:.2f} seconds")


def m(msg):
    """Shorthand constructor for RuntimeMeasure."""
    return RuntimeMeasure(msg)
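# Usage: wrap any step to surface its elapsed time as a toast in the UI:
#     with m("Loading model"):
#         ...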
# Read-only Hugging Face token, injected via the Space's secrets.
HF_TOKEN = os.environ["hf_read"]
languages = [
"Armenian", "Bulgarian", "Croatian", "Czech", "Danish", "Dutch", "English",
"Estonian", "Finnish", "French", "Georgian", "German", "Greek", "Hebrew",
"Hungarian", "Icelandic", "Italian", "Japanese", "Korean", "Latvian",
"Lithuanian", "Norwegian", "Polish", "Portuguese", "Romanian", "Russian",
"Serbian", "Slovak", "Slovenian", "Spanish", "Swedish", "Turkish"
]
def build_huggingface_path(language: str):
    # Every supported language is served by the same multilingual model.
    return "poltextlab/xlm-roberta-large-manifesto"
def predict(text, model_id, tokenizer_id):
    # Debug: surface the contents of the Space's persistent storage mount.
    gr.Info("\n".join(os.listdir("/data/")))
    device = torch.device("cpu")

    with m("Loading model"):
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            device_map="auto",  # resolves to CPU on CPU-only hardware
            token=HF_TOKEN,
        )
    with m("Loading tokenizer"):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    with m("Tokenizing"):
        inputs = tokenizer(
            text,
            max_length=256,
            truncation=True,
            padding="do_not_pad",
            return_tensors="pt",
        ).to(device)

    with m("model.eval()"):
        model.eval()
    with m("Inference"):
        with torch.no_grad():
            logits = model(**inputs).logits
    with m("Softmax"):
        probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()

    with m("Output formatting"):
        # Keys look like "[<label id>] <human-readable label name>",
        # sorted by descending probability.
        output_pred = {
            f"[{model.config.id2label[i]}] {MANIFESTO_LABEL_NAMES[int(model.config.id2label[i])]}": probs[i]
            for i in np.argsort(probs)[::-1]
        }
        output_info = (
            '<p style="text-align: center; display: block">Prediction was made '
            f'using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
        )
    return output_pred, output_info
def predict_cap(text, language):
    with m("WHOLE PROCESS"):
        model_id = build_huggingface_path(language)
        # The fine-tuned model reuses the base xlm-roberta-large tokenizer.
        tokenizer_id = "xlm-roberta-large"
        prediction = predict(text, model_id, tokenizer_id)
    return prediction
demo = gr.Interface(
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)
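# The section above ends at the interface definition; a Gradio app still needs
# a launch call to serve. A minimal guard, assuming no extra launch options:
if __name__ == "__main__":
    demo.launch()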