### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import torchtext
import torchtext.functional as F
import torchtext.transforms as T
from model import xlmr_base_encoder_model
from timeit import default_timer as timer
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader
from torch.hub import load_state_dict_from_url

# Class index -> human-readable label (index 0 = "Bad", index 1 = "Good").
class_names = ["Bad", "Good"]

### 2. Model and transforms preparation ###
# Build the XLM-R base encoder with a 2-class classification head.
model, transforms = xlmr_base_encoder_model(num_classes=2)

# Load saved weights onto the CPU (no GPU assumed at inference time).
model.load_state_dict(
    torch.load(
        f="xlmr_base_encoder.pth",
        map_location=torch.device("cpu"),  # Load the model to the CPU
    )
)
model.to("cpu")
model.eval()  # inference only — disable dropout etc. once, not per request

# XLM-R special-token ids and sequence budget.
PADDING_IDX = 1
BOS_IDX = 0
EOS_IDX = 2
MAX_SEQ_LEN = 256
XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
XLMR_SPM_MODEL_PATH = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

# Build the text transform ONCE at module load. The original rebuilt this
# pipeline inside predict(), re-fetching the vocab on every request.
text_transform = T.Sequential(
    T.SentencePieceTokenizer(XLMR_SPM_MODEL_PATH),
    T.VocabTransform(load_state_dict_from_url(XLMR_VOCAB_PATH)),
    T.Truncate(MAX_SEQ_LEN - 2),  # leave room for the BOS and EOS tokens
    T.AddToken(token=BOS_IDX, begin=True),
    T.AddToken(token=EOS_IDX, begin=False),
)

### 3. Predict function ###
def predict(string):
    """Classify a text string as "Bad" or "Good".

    Args:
        string: raw input text from the Gradio textbox.

    Returns:
        A tuple of:
          * dict mapping each class name to its softmax probability
          * prediction time in seconds, rounded to 4 decimal places
    """
    start_time = timer()

    # Tokenise the single input and shape it as a batch of one:
    # (1, seq_len), padded with the XLM-R padding index.
    token_ids = text_transform(string)
    value = F.to_tensor([token_ids], padding_value=PADDING_IDX).to("cpu")

    # Forward pass with gradients disabled; softmax turns the class
    # logits into probabilities.
    with torch.inference_mode():
        logits = model(value)
    probs = torch.softmax(logits, dim=1)

    pred_labels_and_probs = {
        class_names[i]: float(probs[0][i]) for i in range(len(class_names))
    }

    # Calculate pred time
    end_time = timer()
    pred_time = round(end_time - start_time, 4)

    # Return pred dict and pred time
    return pred_labels_and_probs, pred_time

### 4. Gradio app ###
title = "Good or Bad"
description = "Using XLMR_BASE_ENCODER"

# Create the gradio demo
demo = gr.Interface(
    fn=predict,  # maps inputs to outputs
    inputs="textbox",
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time(s) "),
    ],
    title=title,
    description=description,
    # article = article
)

# launch the demo!
demo.launch()