"""Gradio demo: classify a text as ABUSE / INSULT / OTHER / PROFANITY.

The script builds a small web UI with a text box and a "Training method"
dropdown, runs the selected model on the preprocessed text and shows the
per-label scores.
"""
import functools
import re

import gradio as gr
from transformers import AutoTokenizer
from unidecode import unidecode

# NOTE(review): FixMatchTune is presumably provided by this star import;
# confirm against models.py.
from models import *

tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")


def preprocess(x):
    """Normalize the input string ``x`` before tokenization.

    Steps, in order:
      1. transliterate to ASCII with ``unidecode`` and lowercase;
      2. drop bracketed lowercase tags such as ``[tag]`` and any ``*``;
      3. replace every run of non-alphanumeric characters with one space;
      4. collapse every run of a repeated character to a single occurrence
         (this also affects digits and legitimately doubled letters).

    Args:
        x: raw input text (must be a ``str``).

    Returns:
        The normalized string.
    """
    s = unidecode(x)
    s = s.lower()
    s = re.sub(r"\[[a-z]+\]", "", s)
    s = re.sub(r"\*", "", s)
    s = re.sub(r"[^a-zA-Z0-9]+", " ", s)
    s = re.sub(r" +", " ", s)
    s = re.sub(r"(.)\1+", r"\1", s)
    return s


label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]

# Models that can actually be loaded: model_type -> (encoder name, checkpoint
# path of the classification head). Only "freematch" is wired up; the other
# dropdown entries have no loading code yet.
_MODEL_SPECS = {
    "freematch": (
        "andrei-saceleanu/ro-offense-freematch",
        "./checkpoints/freematch_tune",
    ),
}


@functools.lru_cache(maxsize=None)
def _load_model(model_type):
    """Build the model for ``model_type`` once and reuse it across requests."""
    encoder_name, checkpoint_path = _MODEL_SPECS[model_type]
    model = FixMatchTune(encoder_name=encoder_name)
    model.cls_head.load_weights(checkpoint_path)
    return model


def ssl_predict(in_text, model_type):
    """Predict label scores for ``in_text`` with the selected model.

    Args:
        in_text: text typed by the user; ``None`` is treated as an empty
            string.
        model_type: value of the "Training method" dropdown.

    Returns:
        A dict mapping each name in ``label_names`` to its score as a
        built-in ``float``, suitable for ``gr.Label``.

    Raises:
        gr.Error: if no model is selected or the selected model has no
            loading code (previously this crashed with UnboundLocalError).
    """
    if model_type not in _MODEL_SPECS:
        available = ", ".join(sorted(_MODEL_SPECS))
        if model_type is None:
            raise gr.Error(
                f"Please select a training method (available: {available})."
            )
        raise gr.Error(
            f"Model '{model_type}' is not available (available: {available})."
        )

    # The textbox may deliver None (e.g. right after "Clear").
    preprocessed = preprocess(in_text or "")

    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    model = _load_model(model_type)
    preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)

    # Convert NumPy scalars to plain floats so the result is JSON-safe.
    probs = preds[0].numpy()
    return {label: float(p) for label, p in zip(label_names, probs)}


with gr.Blocks() as ssl_interface:
    with gr.Row():
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            model_list = gr.Dropdown(
                choices=["fixmatch", "freematch", "mixmatch"],
                max_choices=1,
                label="Training method",
                allow_custom_value=False,
                info="Select trained model according to different SSL techniques from paper",
            )

            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")

        with gr.Column():
            out_field = gr.Label(num_top_classes=4, label="Prediction")

    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field],
    )

    # Reset both the input text box and the prediction panel.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[in_text, out_field],
    )

ssl_interface.launch(server_name="0.0.0.0", server_port=7860)