import gradio as gr import torch import torch.nn.functional as F import numpy as np from corpy.morphodita import Tokenizer import transformers from transformers import AutoTokenizer, AutoModelForSequenceClassification model_checkpoint = 'ufal/robeczech-base' device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") transformers.logging.set_verbosity(transformers.logging.ERROR) def classify_sentence(sent:str): toksentence = tokenizer(sent,truncation=True,return_tensors="pt") model.eval() with torch.no_grad(): toksentence.to(device) output = model(**toksentence) return F.softmax(output.logits,dim=1).argmax(dim=1) def classify_text(text:str): tokenizer_morphodita = Tokenizer("czech") all = [] for sentence in tokenizer_morphodita.tokenize(text, sents=True): all.append(sentence) sentences = np.array([' '.join(x) for x in all]) annotations = np.array(list(map(classify_sentence,sentences))) return annotations def classify_text_wrapper(text:str): result = classify_text(text) n = len(result) non_biased = np.where(result==0)[0].shape[0] biased = np.where(result==1)[0].shape[0] return {'Non-biased':non_biased/n,'Biased':biased/n} def interpret_bias(text:str): result = classify_text(text) tokenizer_morphodita = Tokenizer("czech") interpretation = [] all = [] for sentence in tokenizer_morphodita.tokenize(text, sents=True): all.append(sentence) sentences = np.array([' '.join(x) for x in all]) for idx,sentence in enumerate(sentences): score = 0 #non biased if result[idx] == 0: score = -1 #biased if result[idx] == 1: score = 1 interpretation.append((sentence, score)) return interpretation tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) model = AutoModelForSequenceClassification.from_pretrained("sagittariusA/media_bias_classifier_cs") model.eval() label = gr.outputs.Label(num_top_classes=2) inputs = gr.inputs.Textbox(placeholder=None, default="", label=None) app = gr.Interface(fn=classify_text_wrapper,title='Bias classifier',theme='default', inputs="textbox",layout='unaligned', outputs=label, capture_session=True ,interpretation=interpret_bias) app.launch(inbrowser=True)