|
import torch |
|
from torch import nn |
|
from transformers import BertTokenizer, BertForSequenceClassification, AdamW |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.metrics import classification_report |
|
from tqdm import tqdm |
|
import gradio as gr |
|
|
|
# Hugging Face Hub id of the Portuguese BERT checkpoint. The same id is used
# both for this tokenizer and for the classification model loaded further down.
model_name = 'neuralmind/bert-base-portuguese-cased'

tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
|
|
|
def predict(model, loader, device=None):
    """Return the argmax class index for every example yielded by *loader*.

    The model is switched to evaluation mode (and left in it) and is run
    without gradient tracking.

    Args:
        model: a classifier whose forward pass accepts
            ``(input_ids, attention_mask=...)`` and returns an object with a
            ``.logits`` tensor whose dimension 1 is the class axis.
        loader: any iterable of batches, e.g. a ``DataLoader``. Each batch is a
            sequence whose first two items are the ``input_ids`` and
            ``attention_mask`` tensors. Further items (e.g. labels) are ignored.
        device: device to move each batch to before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU if
            the model has no parameters), so the function does not rely on a
            module-level ``device`` global.

    Returns:
        A flat list of predicted class indices (Python ints), in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in loader:
            # Extra per-batch fields (such as labels) are tolerated and ignored.
            input_ids, attention_mask, *_ = batch
            outputs = model(
                input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            # Move to CPU before converting so this also works for GPU models.
            predictions.extend(outputs.logits.argmax(dim=1).cpu().tolist())
    return predictions
|
|
|
def generate_predictions(text):
    """Classify a single piece of text and return its predicted class index.

    This is the Gradio callback. It tokenizes *text* with the module-level
    ``tokenizer`` (truncated to at most 512 tokens) and runs it through the
    module-level ``loaded_model`` via ``predict``.

    Args:
        text: the raw input string.

    Returns:
        The argmax class index (an int) for *text*.
    """
    encodings = tokenizer(
        text, truncation=True, padding=True, max_length=512, return_tensors='pt'
    )
    # One example does not need a TensorDataset/DataLoader. pin_memory=True
    # in particular only produces a warning on a CPU-only host. predict() just
    # iterates over (input_ids, attention_mask) batches, so a one-element list
    # suffices. With return_tensors='pt', the tensors already carry a leading
    # batch dimension.
    batches = [(encodings['input_ids'], encodings['attention_mask'])]
    return predict(loaded_model, batches)[0]
|
|
|
|
|
# Device used for inference; the model below is moved onto it.
device = torch.device("cpu")

# Build the BERT architecture with a 2-class sequence-classification head,
# then replace its weights with the state_dict stored in best_model8.pt
# (presumably the fine-tuned checkpoint). load_state_dict is strict by
# default, so the checkpoint's keys must match the architecture exactly.
loaded_model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
loaded_model.load_state_dict(
    torch.load("best_model8.pt", map_location=device)
)
loaded_model = loaded_model.to(device)
|
|
|
|
|
# Gradio UI: a multi-line text box in, the predicted class index out.
# gr.Textbox / gr.Label replace the gr.inputs.* / gr.outputs.* namespaces,
# which were deprecated in Gradio 3.x and removed in Gradio 4.
iface = gr.Interface(
    fn=generate_predictions,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Label(label="Prediction"),
    examples=[
        ["Seu Comunista!"],
        ['Os imigrantes não deveriam ser impedidos de entrar no meu país'],
        ['Os imigrantes deveriam ser impedidos de entrar no meu país'],
        ['eu te amo'],
        ['aquele cara é um babaca'],
    ],
)

iface.launch()
|
|