File size: 2,786 Bytes
d222e8f
 
 
 
 
 
 
 
 
 
 
30eb8c5
 
d222e8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970b15c
d222e8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9f7672
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

### 1. Imports and class names setup ###
import os
from functools import lru_cache
from timeit import default_timer as timer

import gradio as gr
import torch
import torchtext
import torchtext.functional as F
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

from model import xlmr_base_encoder_model

# Output labels, listed in the order of the classifier's logits
# (index 0 -> "Bad", index 1 -> "Good").
class_names = ["Bad", "Good"]

### 2. Model and transforms preparation ###
# Build the XLM-R encoder with one output per class label. The returned
# `transforms` object is not referenced anywhere else in this file.
model, transforms = xlmr_base_encoder_model(num_classes=len(class_names))

# Load the fine-tuned weights from xlmr_base_encoder.pth (resolved relative to
# the current working directory). map_location="cpu" remaps every stored
# tensor onto the CPU so the app runs without a GPU. The state dict is passed
# straight through rather than bound to a name, so it can be freed right away.
model.load_state_dict(torch.load("xlmr_base_encoder.pth", map_location="cpu"))

### 3. Predict function ###

# XLM-R special-token ids and the maximum encoded length (including BOS/EOS).
_PADDING_IDX = 1
_BOS_IDX = 0
_EOS_IDX = 2
_MAX_SEQ_LEN = 256
_XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
_XLMR_SPM_MODEL_PATH = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"


@lru_cache(maxsize=None)
def _get_text_transform():
  """Build, once, the raw-text -> token-id pipeline used by ``predict``.

  The SentencePiece model and vocabulary are obtained from the URLs above the
  first time this is called. Caching the assembled transform avoids reloading
  them on every request.
  """
  return T.Sequential(
      T.SentencePieceTokenizer(_XLMR_SPM_MODEL_PATH),
      T.VocabTransform(load_state_dict_from_url(_XLMR_VOCAB_PATH)),
      T.Truncate(_MAX_SEQ_LEN - 2),  # leave room for the BOS and EOS tokens
      T.AddToken(token=_BOS_IDX, begin=True),
      T.AddToken(token=_EOS_IDX, begin=False),
  )


def predict(string):
  """Classify ``string`` with the global ``model``.

  Args:
    string: Raw text entered in the Gradio textbox.

  Returns:
    A tuple ``(pred_labels_and_probs, pred_time)``:
    ``pred_labels_and_probs`` maps each name in ``class_names`` to its softmax
    probability (a Python float), and ``pred_time`` is the wall-clock duration
    of this call in seconds, rounded to 4 decimal places.
  """
  start_time = timer()

  # Encode the single input directly; a DataPipe/DataLoader adds nothing for
  # one example. Wrapping the ids in a list gives a batch of one, so
  # to_tensor returns a (1, seq_len) tensor.
  token_ids = _get_text_transform()(string)
  value = F.to_tensor([token_ids], padding_value=_PADDING_IDX).to('cpu')

  # Run the model in inference mode and turn the logits into probabilities.
  model.to('cpu')
  model.eval()
  with torch.inference_mode():
    logits = model(value)
  probs = torch.softmax(logits, dim=1)[0]  # row for the single input
  pred_labels_and_probs = {
      name: float(probs[i]) for i, name in enumerate(class_names)
  }

  # Calculate pred time
  end_time = timer()
  pred_time = round(end_time - start_time, 4)

  # Return pred dict and pred time
  return pred_labels_and_probs, pred_time

### 4. Gradio app ###
title = "Good or Bad"
description = "Using XLMR_BASE_ENCODER"

# Create the gradio demo: one textbox in; out come the per-class probabilities
# and the prediction time, matching the two values returned by `predict`.
demo = gr.Interface(
    fn=predict,  # maps inputs to outputs
    inputs="textbox",
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time(s) "),
    ],
    title=title,
    description=description,
)

# Launch the demo only when this file is run as a script (e.g. `python app.py`),
# so that merely importing the module does not start a web server.
if __name__ == "__main__":
  demo.launch()