Spaces:
Sleeping
Sleeping
File size: 2,786 Bytes
d222e8f 30eb8c5 d222e8f 970b15c d222e8f f9f7672 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
### 1. Imports and class names setup ###
import functools
import os
from timeit import default_timer as timer

import gradio as gr
import torch
import torchtext
import torchtext.functional as F
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

from model import xlmr_base_encoder_model
# Setup class names: entry i labels the classifier's output logit i
# (predict pairs class_names[i] with probability column i).
class_names = ["Bad", "Good"]

### 2. Model and transforms preparation ###
# Build the XLM-R encoder with one output per label. The factory also returns
# `transforms`, which is not referenced below.
model, transforms = xlmr_base_encoder_model(num_classes=len(class_names))

# Restore the fine-tuned weights. map_location places every tensor on the CPU,
# so loading works on hosts without a GPU.
model.load_state_dict(
    torch.load("xlmr_base_encoder.pth", map_location=torch.device("cpu"))
)
### 3. Predict function ###
# Settings for building XLM-R model inputs.
_PADDING_IDX = 1  # padding value passed to F.to_tensor
_BOS_IDX = 0  # token id prepended to every sequence
_EOS_IDX = 2  # token id appended to every sequence
_MAX_SEQ_LEN = 256  # total length including the BOS/EOS tokens
_XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
_XLMR_SPM_MODEL_PATH = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"


@functools.lru_cache(maxsize=1)
def _get_text_transform():
    """Build the XLM-R text pipeline once and reuse it for every request.

    The SentencePiece model and the vocab are both loaded from the URLs above.
    Caching avoids redoing that work on every call to ``predict``.
    """
    return T.Sequential(
        T.SentencePieceTokenizer(_XLMR_SPM_MODEL_PATH),
        T.VocabTransform(load_state_dict_from_url(_XLMR_VOCAB_PATH)),
        # Leave room for the BOS and EOS tokens added next.
        T.Truncate(_MAX_SEQ_LEN - 2),
        T.AddToken(token=_BOS_IDX, begin=True),
        T.AddToken(token=_EOS_IDX, begin=False),
    )


def predict(string):
    """Classify ``string`` and report how long the prediction took.

    Args:
        string: Raw input text (from the Gradio textbox).

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``:
        ``pred_labels_and_probs`` maps each entry of ``class_names`` to its
        softmax probability, and ``pred_time`` is the elapsed wall-clock time
        in seconds, rounded to 4 decimal places.
    """
    start_time = timer()

    token_ids = _get_text_transform()(string)
    # Wrap in a list to form a batch of one, i.e. a tensor of shape (1, seq_len).
    input_tensor = F.to_tensor([token_ids], padding_value=_PADDING_IDX).to("cpu")

    # Run on the CPU in eval mode; eval() turns off training-only layers such as dropout.
    model.to("cpu")
    model.eval()
    with torch.inference_mode():
        logits = model(input_tensor)
        probs = torch.softmax(logits, dim=1)[0]

    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, probs)
    }

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
### 4. Gradio app ###
title = "Good or Bad"
description = "Using XLMR_BASE_ENCODER"

# Create the gradio demo: a single free-text input, and two outputs that
# mirror predict's return value (label -> probability dict, elapsed seconds).
demo = gr.Interface(
    predict,
    "textbox",
    [
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time(s) "),
    ],
    title=title,
    description=description,
)

# Launch the demo (starts the web server).
demo.launch()
|