# 3 / app.py
# paranitik — commit 3479b02: "Update app.py"
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
# Run on the GPU when one is present. preprocess_text places every input
# tensor on this device, so the model weights must live there too.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Overlay the fine-tuned weights (saved as a state dict). The checkpoint is
# mapped to the CPU first so it also opens on machines without a GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
_load_result = model.load_state_dict(
    torch.load('model_after_train.pt', map_location=torch.device('cpu')),
    strict=False,
)
# strict=False silently ignores key mismatches; report them so a wrong or
# partial checkpoint does not quietly leave layers at their pre-trained values.
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "Checkpoint key mismatch - missing: %s, unexpected: %s"
        % (_load_result.missing_keys, _load_result.unexpected_keys)
    )

# Previously the model stayed on the CPU while inputs were sent to `device`,
# which fails with a device-mismatch error on CUDA machines.
model.to(device)
model.eval()  # inference mode for dropout etc.
def preprocess_text(text, delta=300, max_parts=5):
    """Split *text* into word chunks and tokenize each chunk for the model.

    The text is split on single spaces and regrouped into chunks of ``delta``
    words. The number of *cuts* is capped at ``max_parts``, so at most
    ``max_parts + 1`` chunks are produced. Words beyond
    ``(max_parts + 1) * delta`` are ignored.

    Args:
        text: Raw input string.
        delta: Number of space-separated words per chunk (default 300).
        max_parts: Maximum number of cuts (default 5), i.e. at most
            ``max_parts + 1`` chunks.

    Returns:
        A list of token-id tensors from ``tokenizer.encode(...,
        return_tensors="pt")``, each moved to ``device``. Each chunk is
        truncated to 500 tokens. An empty trailing chunk is never emitted.
        An empty ``text`` still yields exactly one (special-tokens-only)
        chunk.
    """
    words = text.split(' ')  # split once rather than once per chunk
    nb_cuts = min(len(words) // delta, max_parts)
    parts = []
    for i in range(nb_cuts + 1):
        chunk = words[i * delta:(i + 1) * delta]
        # When the word count is an exact multiple of delta, the final slice
        # is empty. Encoding '' would feed a special-tokens-only sequence to
        # the model and skew the summed logits, so skip it.
        if not chunk:
            continue
        encoded = tokenizer.encode(
            ' '.join(chunk),
            return_tensors="pt",
            max_length=500,
            truncation=True,  # explicit; same truncation, no warning
        )
        parts.append(encoded.to(device))
    return parts
def test(text):
    """Classify *text* as real or fake news.

    Every chunk produced by ``preprocess_text`` is run through ``model``. The
    logits are summed across chunks, and the softmax of that sum picks the
    label. Class index 0 is reported as "real"; any other index as "fake".

    Args:
        text: Raw article text entered in the Gradio textbox.

    Returns:
        A string such as ``"real at 87 %"``: the winning label and its
        softmax probability, truncated to a whole percent.

    Raises:
        RuntimeError: any runtime error that is not an out-of-memory error
            (e.g. a device mismatch) is re-raised instead of being
            reported as a memory problem.
    """
    text_parts = preprocess_text(text)
    overall_output = torch.zeros((1, 2), device=device)
    # Pure inference: without no_grad each forward pass keeps its autograd
    # graph alive while the logits are accumulated, wasting memory and making
    # out-of-memory errors far more likely.
    with torch.no_grad():
        try:
            for part in text_parts:
                # part is (1, seq_len). numel() checks for actual tokens,
                # whereas len(part) is the batch size and always 1.
                if part.numel() > 0:
                    overall_output += model(part.reshape(1, -1))[0]
        except RuntimeError as err:
            # Best effort only for genuine OOM: keep the logits gathered so
            # far. Anything else is a real bug and must not be disguised.
            if "out of memory" not in str(err):
                raise
            print("GPU out of memory, skipping this entry.")
    probabilities = F.softmax(overall_output[0], dim=-1)
    value, result = probabilities.max(0)
    term = "real" if result.item() == 0 else "fake"
    return term + " at " + str(int(value.item() * 100)) + " %"
# Static text shown on the Gradio page.
title = "Fake News Detector"
description = (
    "Fake news detector trained using pre-trained model bert-base-uncased, "
    "fine-tuned on https://www.kaggle.com/clmentbisaillon/fake-and-real-news-dataset dataset"
)

# One sample article offered as a clickable example input.
examples = [
    (
        "CNN - Two people in China are being treated for plague, authorities said Tuesday. "
        "It’s the second time the disease, the same one that caused the Black Death, "
        "one of the deadliest pandemics in human history, has been detected in the region – "
        "in May, a Mongolian couple died from bubonic plague after eating the raw kidney of a "
        "marmot, a local folk health remedy."
    ),
]

# Text in, text out: the textbox content is passed to `test` and its
# returned string is displayed.
iface = gr.Interface(
    fn=test,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    examples=examples,
)
iface.launch()