### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import torchtext
from model import xlmr_base_encoder_model
from timeit import default_timer as timer
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
import torchtext.functional as F
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
# Setup class names
class_names = ["Bad", "Good"]
### 2. Model and transforms preparation ###
model, transforms = xlmr_base_encoder_model(
    num_classes=2
)
# Load saved weights
model.load_state_dict(
    torch.load(
        f="xlmr_base_encoder.pth",
        map_location=torch.device("cpu")  # Load the model to the CPU
    )
)
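
# model.py is not part of this file. A minimal sketch of what xlmr_base_encoder_model
# presumably does, assuming it wraps torchtext's XLMR_BASE_ENCODER bundle with a
# RobertaClassificationHead (the helper's exact contents are an assumption here):
#
#     from torchtext.models import XLMR_BASE_ENCODER, RobertaClassificationHead
#
#     def xlmr_base_encoder_model(num_classes):
#         head = RobertaClassificationHead(num_classes=num_classes, input_dim=768)
#         model = XLMR_BASE_ENCODER.get_model(head=head)
#         transforms = XLMR_BASE_ENCODER.transform()
#         return model, transforms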
### 3. Predict function ###
def predict(string):
    start_time = timer()

    # Pair the input string with a dummy label (only the text is used at inference time)
    var = (string, -9999999)
    dp = IterableWrapper([var])
    dp = dp.sharding_filter()

    # XLM-R special token ids and maximum sequence length
    padding_idx = 1
    bos_idx = 0
    eos_idx = 2
    max_seq_len = 256
    xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
    xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

    text_transform = T.Sequential(
        T.SentencePieceTokenizer(xlmr_spm_model_path),
        T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
        T.Truncate(max_seq_len - 2),  # leave room for the BOS and EOS tokens
        T.AddToken(token=bos_idx, begin=True),
        T.AddToken(token=eos_idx, begin=False)
    )
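
    # For reference: text_transform maps a raw string to a list of XLM-R token ids,
    # truncated to max_seq_len - 2 and wrapped with bos_idx (0) at the start and
    # eos_idx (2) at the end, matching the transform the XLMR_BASE_ENCODER bundle provides.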
    # Transform the raw text using the non-batched API (i.e. apply the transformation line by line)
    def apply_transform(x):
        return text_transform(x[0]), x[1]

    dp = dp.map(apply_transform)
    dp = dp.batch(1)
    dp = dp.rows2columnar(["token_ids", "target"])
    dp = DataLoader(dp, batch_size=None)

    # Grab the single batch, then pad and convert the token ids to a tensor
    val = next(iter(dp))
    model.to("cpu")
    value = F.to_tensor(val["token_ids"], padding_value=padding_idx).to("cpu")
    # Pass the transformed text through the model and turn the prediction logits into probabilities
    model.eval()
    with torch.inference_mode():
        answer = model(value)
        print(answer)
        # answer = answer.argmax(1)
        answer = torch.softmax(answer, dim=1)
        pred_labels_and_probs = {class_names[i]: float(answer[0][i]) for i in range(len(class_names))}

    # Calculate pred time
    end_time = timer()
    pred_time = round(end_time - start_time, 4)

    # Return pred dict and pred time
    return pred_labels_and_probs, pred_time
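
# For a quick local check (outside the Gradio UI), predict can be called directly, e.g.
# predict("some input text"); it returns a dict mapping each class name to its probability
# plus the prediction time in seconds.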
### 4. Gradio app ###
title = "Good or Bad"
description = "Using XLMR_BASE_ENCODER"
# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # maps inputs to outputs
    inputs="textbox",
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time (s)")
    ],
    title=title,
    description=description,
    # article=article
)
# launch the demo!
demo.launch()
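
# Note: demo.launch() serves the app locally (or inside the hosting Space). Passing
# share=True would also create a temporary public link when running locally.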