### 1. Imports ###
import os
from timeit import default_timer as timer

import gradio as gr
import torch
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import BertForSequenceClassification

### 2. Model and tokenizer preparation ###
MODEL_NAME = "bert-base-uncased"
WEIGHTS_PATH = "finetuned_BERT_epoch_10.model"
MAX_LENGTH = 256
# Class index -> label name for the fine-tuned classification head
# (index 0 is "positive", index 1 is "negative").
LABELS = ("positive", "negative")

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True)

# Build the BERT classifier skeleton, then load the fine-tuned weights onto CPU.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    output_attentions=False,
    output_hidden_states=False,
)
# NOTE: torch.load unpickles the file; only load a weights file you trust.
model.load_state_dict(torch.load(f=WEIGHTS_PATH, map_location=torch.device("cpu")))
# Inference only: switch to eval mode once here rather than on every request.
model.eval()


### 3. Predict function ###
def predict(text):
    """Classify the sentiment of ``text`` and time the prediction.

    Args:
        text: Input string to classify. It is truncated/padded to
            ``MAX_LENGTH`` tokens.

    Returns:
        A tuple ``(prediction, pred_time)``. ``prediction`` is one of
        ``LABELS`` ("positive" or "negative"). ``pred_time`` is the elapsed
        wall-clock time in seconds, rounded to 5 decimal places.
    """
    start_time = timer()

    # Tokenize a single sequence. truncation=True makes the length cap explicit
    # instead of relying on the tokenizer's implicit fallback.
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )

    # outputs[0] is the logits tensor, one row per input (here a single row).
    logits = outputs[0]
    pred_index = int(torch.argmax(logits, dim=-1).item())
    prediction = LABELS[pred_index]

    pred_time = round(timer() - start_time, 5)
    return prediction, pred_time
### 4. Gradio app ###
# Title and description shown at the top of the app.
title = "Sentiment Analysis"
description = (
    "A fine-tuned BERT (bert-base-uncased) model that classifies the sentiment "
    "of input text as positive or negative."
)

# Create the Gradio demo. Gradio passes one positional argument per input
# component, and predict() takes exactly one (the text), so only a single
# text input is exposed.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs=["text", gr.Number(label="Prediction time (s)")],
    title=title,
    description=description,
)

# Launch the demo!
demo.launch()