|
import gradio as gr |
|
import torch |
|
|
|
import torch.nn.functional as F |
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
|
|
|
|
# Device used for inference. Chosen before the model is prepared so the model
# and the input tensors built in preprocess_text end up on the same device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Tokenizer and base architecture come from the stock bert-base-uncased
# checkpoint; the fine-tuned weights are loaded on top of it below.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Load the checkpoint onto the CPU first so it can be read on machines
# without a GPU, regardless of where it was saved from.
# NOTE(review): strict=False silently ignores missing/unexpected keys, so a
# mismatched checkpoint will not raise here - confirm this is intended.
model.load_state_dict(
    torch.load('model_after_train.pt', map_location=torch.device('cpu')),
    strict=False,
)

# Move the weights to the inference device. Without this, inputs sent to a
# CUDA device would not match the CPU-resident model.
model.to(device)
model.eval()
|
|
|
|
|
def preprocess_text(text):
    """Split ``text`` into word chunks and encode each chunk for the model.

    The text is split on single spaces and cut into consecutive chunks of
    ``delta`` words. At most ``max_parts + 1`` chunks are produced; words
    beyond that limit are ignored. Each chunk is encoded with the
    module-level ``tokenizer`` and moved to the module-level ``device``.

    Args:
        text: Raw input string.

    Returns:
        A list with one encoded tensor per chunk, as returned by
        ``tokenizer.encode(..., return_tensors="pt")``. The list always
        contains at least one entry, even for empty input.
    """
    delta = 300     # number of words per chunk
    max_parts = 5   # upper bound on the number of cuts

    # Split once; re-splitting the full text for every chunk is wasted work.
    words = text.split(' ')
    nb_cuts = min(len(words) // delta, max_parts)

    parts = []
    for i in range(nb_cuts + 1):
        chunk = words[i * delta:(i + 1) * delta]
        # When the word count is an exact multiple of ``delta`` the final
        # slice is empty and carries no text, so it is skipped. The first
        # chunk is always kept so callers receive at least one tensor.
        if i > 0 and not chunk:
            continue
        text_part = ' '.join(chunk)
        # Ask for truncation explicitly so ``max_length`` is enforced by
        # request rather than through the tokenizer's implicit fallback.
        encoded = tokenizer.encode(
            text_part,
            return_tensors="pt",
            max_length=500,
            truncation=True,
        )
        parts.append(encoded.to(device))

    return parts
|
|
|
def test(text):
    """Classify ``text`` as real or fake news and report the confidence.

    The text is chunked by ``preprocess_text``; the model's first output
    (indexed with ``[0]``) is summed over all chunks, passed through a
    softmax, and the highest-scoring class is reported. Class index 0 is
    labelled "real"; any other index is labelled "fake".

    If a ``RuntimeError`` is raised while running the model, a message is
    printed and the result is computed from whatever was accumulated before
    the error (all zeros if the first chunk failed).

    Args:
        text: Raw input string to classify.

    Returns:
        A string of the form ``"<real|fake> at <N> %"`` where ``N`` is the
        winning probability as a percentage, truncated to an integer.
    """
    text_parts = preprocess_text(text)
    overall_output = torch.zeros((1, 2), device=device)

    try:
        # Inference only: disabling autograd avoids keeping a graph alive
        # across every chunk, which needlessly grows memory use.
        with torch.no_grad():
            for part in text_parts:
                # numel() counts the actual elements; len() only measures
                # the first dimension and cannot detect an empty sequence
                # once the tensor carries a leading batch dimension.
                if part.numel() > 0:
                    overall_output += model(part.reshape(1, -1))[0]
    except RuntimeError:
        # Best effort: keep whatever has been accumulated so far.
        print("GPU out of memory, skipping this entry.")

    overall_output = F.softmax(overall_output[0], dim=-1)
    value, result = overall_output.max(0)

    term = "real" if result.item() == 0 else "fake"
    return f"{term} at {int(value.item() * 100)} %"
|
|
|
|
|
# Static text shown by the Gradio interface.
title = "Fake News Detector"
description = (
    "Fake news detector trained using pre-trained model bert-base-uncased, "
    "fine-tuned on https://www.kaggle.com/clmentbisaillon/fake-and-real-news-dataset dataset"
)

# Sample input offered to the user as a ready-made example.
examples = [
    "CNN - Two people in China are being treated for plague, authorities said Tuesday. It’s the second time the disease, the same one that caused the Black Death, one of the deadliest pandemics in human history, has been detected in the region – in May, a Mongolian couple died from bubonic plague after eating the raw kidney of a marmot, a local folk health remedy.",
]

# Wire the classifier into a text-in / text-out interface and start serving.
iface = gr.Interface(
    fn=test,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
    examples=examples,
)
iface.launch()