Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import torch | |
from newsclassifier.config.config import Cfg, logger | |
from newsclassifier.data import prepare_input | |
from newsclassifier.models import CustomModel | |
from transformers import RobertaTokenizer | |
# Class names in the order stored in Cfg.index_to_class; used to label the model's outputs.
labels = list(Cfg.index_to_class.values())

# Load the tokenizer and restore the trained weights on CPU.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = CustomModel(num_classes=7)
model.load_state_dict(torch.load(os.path.join(Cfg.artifacts_path, "model.pt"), map_location=torch.device("cpu")))
# Switch to evaluation mode: torch.no_grad() alone does not disable
# dropout / batch-norm training behaviour, which would make predictions noisy.
model.eval()
def prediction(text):
    """Score a piece of text against every known label.

    Args:
        text: The text to classify; it is handed to ``prepare_input``
            together with the module-level tokenizer.

    Returns:
        A dict mapping each entry of ``labels`` to a Python float taken from
        the first element of ``model.predict_proba``'s output.
    """
    encoded = prepare_input(tokenizer, text)

    # Prepend a dimension of size 1 to each tensor before feeding the model.
    batch = {
        "input_ids": encoded["input_ids"].unsqueeze(0).to("cpu"),
        "attention_mask": encoded["attention_mask"].unsqueeze(0).to("cpu"),
    }

    # Inference only: no gradients need to be tracked.
    with torch.no_grad():
        batch_scores = model.predict_proba(batch)
    first_scores = batch_scores[0]

    scores_by_label = {}
    for position, label in enumerate(labels):
        scores_by_label[label] = float(first_scores[position])
    return scores_by_label
# Static text passed to the Gradio interface.
title = "NewsClassifier"
description = "Enter a news headline, and this app will classify it into one of the categories."
instructions = "Type or paste a news headline in the textbox and press Enter."

# Sample headlines offered as examples in the UI.
_example_headlines = [
    "Global Smartphone Shipments Will Hit Lowest Point in a Decade, IDC Says",
    "John Wick's First Spinoff is the Rare Prequel That Justifies Its Existence",
    "Research provides a better understanding of how light stimulates the brain",
    "Lionel Messi scores free kick golazo for Argentina in World Cup qualifiers",
]

# One textbox in, one label widget out; `prediction` supplies the per-label scores.
headline_box = gr.Textbox()
category_label = gr.Label(num_top_classes=7)

iface = gr.Interface(
    fn=prediction,
    inputs=headline_box,
    outputs=category_label,
    title=title,
    description=description,
    # Each example is a one-element list because the interface has a single input.
    examples=[[headline] for headline in _example_headlines],
    article=instructions,
)

# Start the app, requesting a shareable link.
iface.launch(share=True)