File size: 1,812 Bytes
49c9603
 
 
 
 
 
 
 
 
ce8dba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c9603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os 
import subprocess

# Download the model checkpoint into the working directory as tensor.pt.
# --fail makes curl exit non-zero on HTTP errors instead of silently saving the
# error page as tensor.pt; check=True turns that (or a missing curl binary)
# into an exception here instead of a confusing failure later in torch.load.
# NOTE(security): this file comes from a third-party host and is later
# unpickled by torch.load, which can execute arbitrary code. Pin a trusted
# source and/or verify a checksum before deploying.
subprocess.run(
    [
        "curl",
        "-L",
        "--fail",
        "-o",
        "tensor.pt",
        "https://seyarabata.com/btfo_by_24mb_model",
    ],
    check=True,
)

import torch
from PIL import Image
import gradio as gr
from torchvision import transforms as T
from typing import Tuple


def rand_augment_transform(magnitude=5, num_layers=3):
    """Build a RandAugment transform with fixed geometric hyper-parameters.

    Args:
        magnitude: augmentation magnitude, forwarded to ``rand_augment_ops``.
        num_layers: number of ops applied per image, forwarded to ``RandAugment``.

    Returns:
        The ``auto_augment.RandAugment`` instance.

    NOTE(review): ``auto_augment`` and ``_RAND_TRANSFORMS`` are not defined or
    imported anywhere in this file, so calling this raises NameError. It is
    only reached through ``get_transform(..., augment=True)``, which this
    script never does. Presumably they come from timm / an augment module that
    was not copied over — confirm before enabling augmentation.
    """
    # These are tuned for magnitude=5, which means the effective magnitudes
    # are half of these values.
    geometry = dict(
        rotate_deg=30,
        shear_x_pct=0.9,
        shear_y_pct=0.2,
        translate_x_pct=0.10,
        translate_y_pct=0.30,
    )
    ops = auto_augment.rand_augment_ops(magnitude, geometry, transforms=_RAND_TRANSFORMS)

    # Passing explicit (uniform) weights disables replacement in the random
    # op selection, i.e. the same op is not applied twice.
    op_count = len(ops)
    uniform_weights = [1. / op_count for _ in range(op_count)]
    return auto_augment.RandAugment(ops, num_layers, uniform_weights)

def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0):
    """Compose the preprocessing pipeline applied to every input image.

    Args:
        img_size: target size handed to ``T.Resize``. This script passes
            ``parseq.hparams.img_size``; presumably (height, width) — confirm.
        augment: when true, prepend the RandAugment stage built by
            ``rand_augment_transform``.
        rotation: when non-zero, rotate the PIL image by this many degrees
            (``expand=True`` enlarges the canvas to fit) before resizing.

    Returns:
        A ``T.Compose`` that optionally augments/rotates, then resizes with
        bicubic interpolation, converts to a tensor and normalizes with
        mean 0.5 and std 0.5.
    """
    stages = [rand_augment_transform()] if augment else []
    if rotation:
        stages.append(lambda img: img.rotate(rotation, expand=True))
    stages += [
        T.Resize(img_size, T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ]
    return T.Compose(stages)

# Load the pickled model (presumably a PARSeq text recognizer, judging by the
# name) onto the CPU and put it in evaluation mode.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; tensor.pt is downloaded from a third-party host at startup, so only
# run this against a trusted checkpoint.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which may refuse to load a whole pickled module like
# this — confirm against the torch version this app is pinned to.
parseq = torch.load('tensor.pt', map_location='cpu')
parseq = parseq.eval()

# Preprocessing pipeline sized to whatever input size the model was trained on.
img_transform = get_transform(parseq.hparams.img_size)

def captcha_solver(img):
    """Read the text in a CAPTCHA image with the loaded model.

    Args:
        img: a PIL image, as delivered by the Gradio ``Image(type="pil")``
            input, or ``None`` when the form is submitted without an image.

    Returns:
        The decoded string for the image, or ``""`` if no image was given.
    """
    if img is None:
        # Gradio may pass None for an empty input; the original crashed with
        # AttributeError on img.convert.
        return ""

    # Force 3 channels, preprocess, and add a leading batch dimension of 1.
    batch = img_transform(img.convert('RGB')).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = parseq(batch)
        # Turn logits into per-position probabilities and decode them to text
        # (the original author describes this as greedy decoding).
        probs = logits.softmax(-1)
        labels, _confidences = parseq.tokenizer.decode(probs)
    return labels[0]

# Gradio UI: one PIL image in, the decoded text out.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4,
# so prefer the top-level components and fall back only on older releases.
if hasattr(gr, "Image") and hasattr(gr, "Textbox"):
    _image_input = gr.Image(type="pil")
    _text_output = gr.Textbox()
else:
    _image_input = gr.inputs.Image(type="pil")
    _text_output = gr.outputs.Textbox()

demo = gr.Interface(fn=captcha_solver, inputs=_image_input, outputs=_text_output)
demo.launch()