DaviLima's picture
Update app.py
0bfefe0
raw
history blame
2.19 kB
import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm
import gradio as gr
# Hugging Face hub id of the Portuguese BERT checkpoint. The same id is reused
# when building the classifier, so tokenizer and model share one vocabulary.
model_name: str = "neuralmind/bert-base-portuguese-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)
def predict(model, loader, device=None):
    """Run ``model`` over every batch in ``loader`` and collect argmax labels.

    Args:
        model: Classifier whose call accepts ``input_ids`` positionally and an
            ``attention_mask`` keyword, and returns an object exposing a
            ``logits`` attribute (e.g. ``BertForSequenceClassification``).
        loader: Iterable of ``(input_ids, attention_mask)`` tensor pairs.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the function does not depend on a module-level ``device`` global.

    Returns:
        list[int]: One predicted class index per input row, in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # inference mode (e.g. dropout disabled)
    predictions = []
    with torch.no_grad():
        for input_ids, attention_mask in loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            # argmax over dim 1; assumes logits are (batch, num_labels) — TODO confirm
            predictions.extend(logits.argmax(dim=1).cpu().tolist())
    return predictions
def generate_predictions(text):
    """Classify one piece of text with the module-level ``loaded_model``.

    Args:
        text: Input string coming from the Gradio textbox.

    Returns:
        int: Predicted class index for ``text``.
    """
    encodings = tokenizer(
        text, truncation=True, padding=True, max_length=512, return_tensors='pt'
    )
    # For a single string the tokenizer already returns a batch of one row, so
    # it is handed to ``predict`` as a one-element loader. Wrapping it in a
    # DataLoader with ``pin_memory=True`` is unnecessary on this CPU-only app
    # and makes PyTorch emit a warning when no accelerator is present.
    single_batch = [(encodings['input_ids'], encodings['attention_mask'])]
    predictions = predict(loaded_model, single_batch)
    return predictions[0]
# All inference runs on the CPU.
device = torch.device('cpu')

# Build the 2-label classification architecture from the base checkpoint, then
# replace its weights with the state dict saved in best_model8.pt; map_location
# remaps any saved tensors onto the CPU device chosen above.
loaded_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
saved_state = torch.load('best_model8.pt', map_location=device)
loaded_model.load_state_dict(saved_state)
loaded_model.to(device)
# Gradio UI: a multi-line text box in, the predicted class shown as a label.
# ``gr.inputs`` / ``gr.outputs`` were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level ``gr.Textbox`` / ``gr.Label`` components replace them.
iface = gr.Interface(
    fn=generate_predictions,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Label(label="Prediction"),
    examples=[
        ["Seu Comunista!"],
        ['Os imigrantes não deveriam ser impedidos de entrar no meu país'],
        ['Os imigrantes deveriam ser impedidos de entrar no meu país'],
        ['eu te amo'],
        ['aquele cara é um babaca'],
    ],
)
# Start the Gradio web server.
iface.launch()