# Source: Hugging Face Space app (page status at capture time: "Sleeping").
### 1. Imports ###
import gradio as gr
import os
import torch
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import BertForSequenceClassification
from timeit import default_timer as timer

### 2. Model and tokenizer preparation ###
# Tokenizer for the bert-base-uncased vocabulary used by the model below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                          do_lower_case=True)

# BERT with a 2-label sequence-classification head; its weights are then
# overwritten by the fine-tuned checkpoint loaded on CPU.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=2,
                                                      output_attentions=False,
                                                      output_hidden_states=False)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code when it is loaded.
# NOTE(review): requires torch >= 1.13; confirm the Space's torch version.
model.load_state_dict(torch.load(f='finetuned_BERT_epoch_10.model',
                                 map_location=torch.device('cpu'),
                                 weights_only=True))
### 3. Predict function ###
def predict(text):
    """Classify the sentiment of ``text`` with the fine-tuned BERT model.

    The text is tokenized with the module-level ``tokenizer``. It is padded
    to 256 tokens, or truncated if longer, and then scored by the
    module-level ``model``.

    Args:
        text: The input string to classify.

    Returns:
        A ``(prediction, pred_time)`` tuple. ``prediction`` is either
        ``"positive"`` or ``"negative"``. ``pred_time`` is the wall-clock
        duration of the call in seconds, rounded to 5 decimal places.
    """
    start_time = timer()

    # ``padding``/``truncation`` replace the deprecated ``pad_to_max_length``.
    # Explicit truncation keeps long inputs within the 256-token budget.
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    model.eval()  # inference mode: disables dropout
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )

    # outputs[0] is the logits tensor for the single-item batch
    # (return_tensors="pt" adds a batch dimension of 1).
    logits = outputs[0]
    label_id = int(logits.argmax(dim=-1).item())

    # Label id 0 is reported as "positive", anything else as "negative".
    # NOTE(review): confirm this matches the label encoding used in training.
    prediction = "positive" if label_id == 0 else "negative"

    pred_time = round(timer() - start_time, 5)
    return prediction, pred_time
### 4. Gradio app ###
# Title and description shown at the top of the demo page.
title = "Sentiment Analysis"
description = ("A fine-tuned BERT (bert-base-uncased) model that classifies "
               "the sentiment of a piece of text as positive or negative.")

# ``predict`` takes exactly one argument, so the interface exposes exactly one
# input component. Gradio passes one positional argument per input component.
# The two outputs match predict's (prediction, pred_time) return tuple.
demo = gr.Interface(fn=predict,
                    inputs="text",
                    outputs=["text",
                             gr.Number(label="Prediction time (s)")],
                    title=title,
                    description=description)

# Launch the demo!
demo.launch()