Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
from transformers import BertTokenizer, BertForSequenceClassification, AdamW | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import classification_report | |
from tqdm import tqdm | |
import gradio as gr | |
import string | |
# Hugging Face Hub id of the pretrained Portuguese BERT; the same id is used
# below both for the tokenizer and as the base of the classifier.
model_name = 'neuralmind/bert-base-portuguese-cased'

# WordPiece tokenizer matching that checkpoint (fetched from the Hub or the
# local cache on first use).
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
def predict(model, loader, device=None):
    """Run *model* over every batch in *loader* and return the argmax classes.

    Args:
        model: A classifier whose call ``model(input_ids, attention_mask=...)``
            returns an object with a ``logits`` attribute; class scores are
            assumed to lie along dim 1 (``argmax(dim=1)`` below).
        loader: Iterable yielding ``(input_ids, attention_mask)`` tensor pairs,
            e.g. a ``DataLoader`` over a ``TensorDataset``.
        device: Device each batch is moved to before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU if
            the model has no parameters). Previously this silently relied on a
            module-level ``device`` global defined later in the file.

    Returns:
        A flat list of predicted class indices (Python ints), in loader order.
    """
    # Switch off dropout etc. so inference is deterministic.
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    predictions = []
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        for input_ids, attention_mask in loader:
            outputs = model(
                input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            # Move to CPU before converting to Python ints.
            predictions.extend(outputs.logits.argmax(dim=1).cpu().tolist())
    return predictions
def preprocess_text(text):
    """Normalise one sentence before it is tokenised.

    Every ASCII punctuation character (anything in ``string.punctuation``) is
    dropped, then the remaining text is lower-cased. Non-ASCII punctuation
    such as « » is left in place.
    """
    # NOTE(review): the checkpoint name is "...-cased" yet input is
    # lower-cased here; presumably this mirrors how best_model8.pt was
    # trained — confirm before changing either side.
    kept_chars = (ch for ch in text if ch not in string.punctuation)
    return "".join(kept_chars).lower()
def generate_predictions(text):
    """Classify each sentence of *text* with the fine-tuned model.

    The text is split on "." and each piece is normalised with
    ``preprocess_text``. Pieces that are empty (or whitespace only) after
    normalisation are skipped. ``"eu te amo."`` splits into
    ``["eu te amo", ""]``, and the trailing empty string would otherwise be
    classified as if it were real input.

    Args:
        text: Raw user input.

    Returns:
        A list with one predicted class index per non-empty sentence, in the
        order the sentences appear. Empty when *text* has no non-empty
        sentence.
    """
    sentences = [preprocess_text(piece) for piece in text.split(".")]
    sentences = [sentence for sentence in sentences if sentence.strip()]
    if not sentences:
        return []

    # Tokenise every sentence in one call. padding=True pads to the longest
    # sentence, and the attention mask keeps the padding out of attention.
    encodings = tokenizer(
        sentences, truncation=True, padding=True, max_length=512, return_tensors='pt'
    )
    dataset = torch.utils.data.TensorDataset(
        encodings['input_ids'], encodings['attention_mask']
    )
    # Order must be preserved so predictions line up with sentences. The
    # previous pin_memory=True only helps GPU transfers and warns on CPU.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=8, shuffle=False, num_workers=0
    )
    return predict(loaded_model, loader)
# All inference in this app runs on the CPU.
device = torch.device('cpu')

# Build the Portuguese BERT with a freshly initialised 2-way classification
# head, then overwrite its weights with the fine-tuned checkpoint.
# map_location remaps any tensors saved on another device onto the CPU.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
loaded_model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)
loaded_model.load_state_dict(
    torch.load('best_model8.pt', map_location=device)
)
loaded_model = loaded_model.to(device)
def _prediction_summary(text):
    """Adapt ``generate_predictions`` output to a value ``gr.Label`` can show.

    ``generate_predictions`` returns a *list* of per-sentence class indices,
    but ``gr.Label`` only accepts a str/int/float label or a
    ``{label: confidence}`` dict, so handing it the list fails at render
    time. Each class index (as a string) is mapped to the fraction of
    sentences predicted as that class.

    Returns:
        The ``{class: fraction}`` dict, or ``None`` (an empty Label) when no
        sentence was classified. An empty dict is not handled by every Gradio
        version.
    """
    sentence_predictions = generate_predictions(text)
    if not sentence_predictions:
        return None
    total = len(sentence_predictions)
    return {
        str(label): sentence_predictions.count(label) / total
        for label in range(loaded_model.config.num_labels)
    }


# Define the Gradio interface.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level components below work in both.
iface = gr.Interface(
    fn=_prediction_summary,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    examples=[
        ["Seu Comunista!"],
        ['Os imigrantes não deveriam ser impedidos de entrar no meu país'],
        ['Os imigrantes deveriam ser impedidos de entrar no meu país'],
        ['eu te amo'],
        ['aquele cara é um babaca'],
    ],
)

# Launch the interface
iface.launch()