File size: 1,043 Bytes
5ba7af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dab9b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from PIL import Image
# from strhub.data.module import SceneTextDataModule
from torchvision import transforms as T
import gradio as gr

# Load the pretrained PARSeq scene-text-recognition model from torch.hub.
# .eval() switches off dropout/batch-norm updates for inference.
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()

# Preprocessing pipeline matching the model's training setup:
# resize to the model's expected input size, convert to tensor,
# then normalize to [-1, 1] (mean=0.5, std=0.5 per channel).
transform = T.Compose([
            T.Resize(parseq.hparams.img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])


def infer(inps):
    """Run PARSeq text recognition on a single PIL image.

    Args:
        inps: a PIL.Image (any mode); converted to RGB before preprocessing.

    Returns:
        The decoded text string for the image.
    """
    img = inps.convert('RGB')
    # Model expects a batch of images with shape (B, C, H, W).
    img = transform(img).unsqueeze(0)

    # inference_mode() skips autograd graph construction — without it every
    # request would build (and hold) gradient bookkeeping for no reason.
    with torch.inference_mode():
        logits = parseq(img)
        pred = logits.softmax(-1)
        # decode() returns parallel lists of labels and per-token confidences;
        # only the label of the single batch element is needed here.
        label, _confidence = parseq.tokenizer.decode(pred)
    return label[0]


# Build and launch the Gradio UI.
# NOTE: the gr.inputs / gr.outputs namespaces are the deprecated pre-3.0 API
# and were removed in Gradio 3.x/4.x — components are used directly instead.
demo = gr.Interface(fn=infer,
             inputs=gr.Image(type="pil"),
             outputs=gr.Textbox(label="Output Text")
             )

demo.launch()