sneedium committed on
Commit
653f12c
·
1 Parent(s): cb0ed46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -30
app.py CHANGED
@@ -4,39 +4,12 @@ os.system("curl -L -o tensor.pt https://seyarabata.com/btfo_by_24mb_model")
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
- from torchvision import transforms as T
8
- from typing import Tuple
9
 
10
-
11
- def rand_augment_transform(magnitude=5, num_layers=3):
12
- # These are tuned for magnitude=5, which means that effective magnitudes are half of these values.
13
- hparams = {
14
- 'rotate_deg': 30,
15
- 'shear_x_pct': 0.9,
16
- 'shear_y_pct': 0.2,
17
- 'translate_x_pct': 0.10,
18
- 'translate_y_pct': 0.30
19
- }
20
- ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS)
21
- # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
22
- choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))]
23
- return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
24
-
25
- def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0):
26
- transforms = []
27
- if augment:
28
- transforms.append(rand_augment_transform())
29
- if rotation:
30
- transforms.append(lambda img: img.rotate(rotation, expand=True))
31
- transforms.extend([
32
- T.Resize(img_size, T.InterpolationMode.BICUBIC),
33
- T.ToTensor(),
34
- T.Normalize(0.5, 0.5)
35
- ])
36
- return T.Compose(transforms)
37
 
38
  parseq = torch.load('tensor.pt', map_location=torch.device('cpu')).eval()
39
- img_transform = get_transform(parseq.hparams.img_size)
40
 
41
  def captcha_solver(img):
42
  img = img.convert('RGB')
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
 
7
 
8
+ from strhub.data.module import SceneTextDataModule
9
+ from strhub.models.utils import load_from_checkpoint, parse_model_args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  parseq = torch.load('tensor.pt', map_location=torch.device('cpu')).eval()
12
+ img_transform = SceneTextDataModule.get_transform(parseq.hparams.img_size)
13
 
14
  def captcha_solver(img):
15
  img = img.convert('RGB')