# Spaces:
# Sleeping
# Sleeping
### 1. Imports and class names setup ###
import functools
import os
from timeit import default_timer as timer

import gradio as gr
import torch
import torchtext
import torchtext.functional as F
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

from model import xlmr_base_encoder_model
# Setup class names. Position i is the label for the model's i-th output
# logit; predict() reads them back in this same order.
class_names = ["Bad", "Good"]

### 2. Model and transforms preparation ###
# Build the classifier with one output per class name.
model, transforms = xlmr_base_encoder_model(num_classes=len(class_names))

# Load the saved weights. map_location="cpu" places every stored tensor on the
# CPU regardless of the device it was saved from.
model.load_state_dict(torch.load("xlmr_base_encoder.pth", map_location="cpu"))
### 3. Predict function ###
# Special-token ids and maximum sequence length used to build XLM-R inputs.
PADDING_IDX = 1
BOS_IDX = 0
EOS_IDX = 2
MAX_SEQ_LEN = 256
XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
XLMR_SPM_MODEL_PATH = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"


@functools.lru_cache(maxsize=1)
def _get_text_transform():
    """Build the text -> token-id pipeline once and reuse it for every request.

    The SentencePiece model and the vocabulary are loaded from the URLs above.
    Constructing the pipeline inside ``predict`` would re-create and re-load
    them on every call. The cache keeps the single instance after the first
    successful build; a failed build raises and is retried on the next call.
    """
    return T.Sequential(
        T.SentencePieceTokenizer(XLMR_SPM_MODEL_PATH),
        T.VocabTransform(load_state_dict_from_url(XLMR_VOCAB_PATH)),
        # Reserve two positions for the BOS/EOS tokens added below.
        T.Truncate(MAX_SEQ_LEN - 2),
        T.AddToken(token=BOS_IDX, begin=True),
        T.AddToken(token=EOS_IDX, begin=False),
    )


def predict(string):
    """Classify ``string`` as one of ``class_names`` and time the prediction.

    Args:
        string: Raw text entered in the Gradio textbox.

    Returns:
        ``(pred_labels_and_probs, pred_time)``:
        - ``pred_labels_and_probs`` maps every entry of ``class_names`` to its
          softmax probability.
        - ``pred_time`` is the wall-clock time spent in this call, in seconds,
          rounded to 4 decimal places.
    """
    start_time = timer()

    # Tokenize the single input, then wrap it as a batch of one sequence.
    # Padding with PADDING_IDX is a no-op for a single sequence.
    token_ids = _get_text_transform()(string)
    batch = F.to_tensor([token_ids], padding_value=PADDING_IDX).to("cpu")

    # Pass the token ids through the model in inference mode.
    model.to("cpu")
    model.eval()
    with torch.inference_mode():
        logits = model(batch)

    # Turn the logits of the only row into per-class probabilities.
    probs = torch.softmax(logits, dim=1)[0]
    pred_labels_and_probs = {
        name: float(probs[i]) for i, name in enumerate(class_names)
    }

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
### 4. Gradio app ###
title = "Good or Bad"
description = "Using XLMR_BASE_ENCODER"

# Output components, in the same order as predict()'s return tuple:
# the label -> probability dict, then the elapsed time in seconds.
output_components = [
    gr.Label(num_top_classes=len(class_names), label="Predictions"),
    gr.Number(label="Prediction time(s) "),
]

# Create the gradio demo: one free-text input routed through predict().
demo = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs=output_components,
    title=title,
    description=description,
)

# Launch the demo.
demo.launch()