# app.py — author: DaviLima; last commit cb9ee2f ("Update app.py"); 2.76 kB.
import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm
import gradio as gr
import string
# Checkpoint identifier handed to `from_pretrained`; the same id is reused
# further down when the classifier itself is built.
model_name = "neuralmind/bert-base-portuguese-cased"
# Loaded once at import time and shared by every prediction request.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
def predict(model, loader, device=None):
    """Run ``model`` over every batch from ``loader`` and collect class ids.

    Args:
        model: A ``torch.nn.Module`` whose forward pass accepts
            ``(input_ids, attention_mask=...)`` and returns an object with a
            ``logits`` attribute.
        loader: Iterable yielding ``(input_ids, attention_mask)`` tensor pairs.
        device: Where each batch is moved before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU
            if the model has no parameters), so inputs always end up next to
            the weights instead of relying on a module-level global.

    Returns:
        A flat ``list`` of ``int`` predictions, obtained as the argmax over
        ``dim=1`` of each batch's logits, in the order the loader yields them.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Inference mode: disables training-only behaviour such as dropout.
    model.eval()
    predictions = []
    with torch.no_grad():
        for input_ids, attention_mask in loader:
            outputs = model(
                input_ids.to(device), attention_mask=attention_mask.to(device)
            )
            # Bring results back to the host before converting to Python ints.
            predictions.extend(outputs.logits.argmax(dim=1).cpu().tolist())
    return predictions
def preprocess_text(text):
    """Normalize ``text`` before it is shown to the classifier.

    Every character listed in ``string.punctuation`` (ASCII punctuation) is
    dropped, and what remains is converted to lower case.
    """
    punctuation = frozenset(string.punctuation)
    without_punctuation = "".join(ch for ch in text if ch not in punctuation)
    return without_punctuation.lower()
def generate_predictions(text):
    """Classify each period-separated sentence of ``text`` and format as HTML.

    Each piece of ``text.split(".")`` is normalized with ``preprocess_text``
    and classified on its own by ``loaded_model`` through ``predict``.

    Args:
        text: Raw user input, possibly containing several sentences.

    Returns:
        One ``"<sentence>: <class id>"`` entry per non-blank sentence, joined
        with ``"<br>"``. Returns an empty string when no sentence has content
        (for example empty input or input made only of periods).
    """
    predictions = []
    for raw_sentence in text.split("."):
        # A trailing "." (or "..", or empty input) produces empty or
        # whitespace-only pieces. Skip them; otherwise they would be classified
        # and shown as ": <class id>" lines.
        sentence = preprocess_text(raw_sentence).strip()
        if not sentence:
            continue

        input_encodings = tokenizer(
            sentence, truncation=True, padding=True, max_length=512, return_tensors='pt'
        )
        input_dataset = torch.utils.data.TensorDataset(
            input_encodings['input_ids'], input_encodings['attention_mask']
        )
        # Pinned memory only speeds up host-to-GPU copies. With the CPU device
        # used by this app it is at best ignored (with a warning in recent PyTorch).
        input_loader = torch.utils.data.DataLoader(
            input_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            pin_memory=device.type == 'cuda',
        )
        sentence_prediction = predict(loaded_model, input_loader)[0]
        # preprocess_text has removed ASCII punctuation, including the
        # HTML-special characters < > & " ', so the sentence can be embedded as-is.
        predictions.append(f"{sentence}: {sentence_prediction}")
    return "<br>".join(predictions)
# Run inference on the CPU.
device = torch.device('cpu')
# Build the 2-label classifier on top of the base checkpoint, then replace its
# weights with the fine-tuned state dict. map_location places the tensors on
# `device`. (Assumes best_model8.pt was saved from this same architecture;
# load_state_dict is strict by default and fails on mismatched keys.)
loaded_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
loaded_model.load_state_dict(torch.load('best_model8.pt', map_location=device))
loaded_model.to(device)
# generate_predictions returns "<br>"-separated HTML, so render it with an HTML
# component; a Label component would display the raw string, tags included.
# Top-level components replace the gr.inputs / gr.outputs namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=generate_predictions,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.HTML(label="Prediction"),
    examples=[
        ["Seu Comunista!"],
        ['Os imigrantes não deveriam ser impedidos de entrar no meu país'],
        ['Os imigrantes deveriam ser impedidos de entrar no meu país'],
        ['eu te amo'],
        ['aquele cara é um babaca'],
    ],
)
# Launch the interface.
iface.launch()