muellerzr HF staff commited on
Commit
3655067
1 Parent(s): 56584f9

No fastai init

Browse files
Files changed (6) hide show
  1. exported_model.pth +3 -0
  2. requirements.txt +5 -0
  3. src/__init__.py +0 -0
  4. src/app.py +53 -0
  5. src/model.py +63 -0
  6. src/transform.py +106 -0
exported_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09649d02c644e6f4f0f61afb9ac97e46e00a8c91aca7511cf46d6b37be21da80
3
+ size 67880401
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio=="3.18.0"
2
+ pillow=="9.4.0"
3
+ timm=="0.6.12"
4
+ torch=="1.13.1"
5
+ torchvision=="0.14.1"
src/__init__.py ADDED
File without changes
src/app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from model import get_model, apply_weights, copy_weight
5
+ from transform import crop, pad, gpu_crop
6
+ from torchvision.transforms import Normalize, ToTensor
7
+
8
+ # Vocab
9
+ vocab = [
10
+ 'Abyssinian', 'Bengal', 'Birman',
11
+ 'Bombay', 'British_Shorthair',
12
+ 'Egyptian_Mau', 'Maine_Coon',
13
+ 'Persian', 'Ragdoll', 'Russian_Blue',
14
+ 'Siamese', 'Sphynx', 'american_bulldog',
15
+ 'american_pit_bull_terrier', 'basset_hound',
16
+ 'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel',
17
+ 'english_setter', 'german_shorthaired',
18
+ 'great_pyrenees', 'havanese',
19
+ 'japanese_chin', 'keeshond',
20
+ 'leonberger', 'miniature_pinscher', 'newfoundland',
21
+ 'pomeranian', 'pug', 'saint_bernard', 'samoyed',
22
+ 'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier',
23
+ 'wheaten_terrier', 'yorkshire_terrier'
24
+ ]
25
+
26
+
27
+ model = get_model()
28
+ state = torch.load('../exported_model.pth')["model"]
29
+ apply_weights(model, state, copy_weight)
30
+ model.cuda()
31
+
32
+ to_tensor = ToTensor()
33
+ norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
34
+
35
+ def classify_image(inp):
36
+ inp = Image.fromarray(inp)
37
+ transformed_input = pad(crop(inp, (460, 460)), (460, 460))
38
+ transformed_input = to_tensor(transformed_input).unsqueeze(0)
39
+ transformed_input = gpu_crop(transformed_input, (224, 224))
40
+ transformed_input = norm(transformed_input).cuda()
41
+ model.eval()
42
+ with torch.no_grad():
43
+ pred = model(transformed_input)
44
+ pred = torch.argmax(pred, dim=1)
45
+ return vocab[pred]
46
+
47
+ iface = gr.Interface(
48
+ fn=classify_image,
49
+ inputs=gr.inputs.Image(),
50
+ outputs="text",
51
+ title="NO Fastai Classifier",
52
+ description="An example of not using Fastai in Gradio.",
53
+ ).launch()
src/model.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torchvision.transforms.functional as tvf
4
+ import torchvision.transforms as tvtfms
5
+ import operator as op
6
+ from PIL import Image
7
+ from torch import nn
8
+ from timm import create_model
9
+
10
+ # For type hinting later on
11
+ import collections
12
+ import typing
13
+
14
+ def get_model():
15
+ net = create_model("vit_tiny_patch16_224", pretrained=False, num_classes=0, in_chans=3)
16
+ head = nn.Sequential(
17
+ nn.BatchNorm1d(192),
18
+ nn.Dropout(0.25),
19
+ nn.Linear(192, 512, bias=False),
20
+ nn.ReLU(inplace=True),
21
+ nn.BatchNorm1d(512),
22
+ nn.Dropout(0.5),
23
+ nn.Linear(512, 37, bias=False)
24
+ )
25
+ model = nn.Sequential(net, head)
26
+ return model
27
+
28
+ def copy_weight(name, parameter, state_dict):
29
+ """
30
+ Takes in a layer `name`, model `parameter`, and `state_dict`
31
+ and loads the weights from `state_dict` into `parameter`
32
+ if it exists.
33
+ """
34
+ # Part of the body
35
+ if name[0] == "0":
36
+ name = name[:2] + "model." + name[2:]
37
+ if name in state_dict.keys():
38
+ input_parameter = state_dict[name]
39
+ if input_parameter.shape == parameter.shape:
40
+ parameter.copy_(input_parameter)
41
+ else:
42
+ print(f'Shape mismatch at layer: {name}, skipping')
43
+ else:
44
+ print(f'{name} is not in the state_dict, skipping.')
45
+
46
+ def apply_weights(input_model:nn.Module, input_weights:collections.OrderedDict, application_function:callable):
47
+ """
48
+ Takes an input state_dict and applies those weights to the `input_model`, potentially
49
+ with a modifier function.
50
+
51
+ Args:
52
+ input_model (`nn.Module`):
53
+ The model that weights should be applied to
54
+ input_weights (`collections.OrderedDict`):
55
+ A dictionary of weights, the trained model's `state_dict()`
56
+ application_function (`callable`):
57
+ A function that takes in one parameter and layer name from `input_model`
58
+ and the `input_weights`. Should apply the weights from the state dict into `input_model`.
59
+ """
60
+ model_dict = input_model.state_dict()
61
+ for name, parameter in model_dict.items():
62
+ application_function(name, parameter, input_weights)
63
+ input_model.load_state_dict(model_dict)
src/transform.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import torch
3
+ from PIL import Image
4
+ import torchvision.transforms.functional as tvf
5
+ import torch.nn.functional as F
6
+
7
+ def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:
8
+ """
9
+ Takes a `PIL.Image` and crops it `size` unless one
10
+ dimension is larger than the actual image. Padding
11
+ must be performed afterwards if so.
12
+
13
+ Args:
14
+ image (`PIL.Image`):
15
+ An image to perform cropping on
16
+ size (`tuple` of integers):
17
+ A size to crop to, should be in the form
18
+ of (width, height)
19
+
20
+ Returns:
21
+ An augmented `PIL.Image`
22
+ """
23
+ top = (image.size[-2] - size[0]) // 2
24
+ left = (image.size[-1] - size[1]) // 2
25
+
26
+ top = max(top, 0)
27
+ left = max(left, 0)
28
+
29
+ height = min(top + size[0], image.size[-2])
30
+ width = min(left + size[1], image.size[-1])
31
+ return image.crop((top, left, height, width))
32
+
33
+ def pad(image, size:typing.Tuple[int,int]) -> Image:
34
+ """
35
+ Takes a `PIL.Image` and pads it to `size` with
36
+ zeros.
37
+
38
+ Args:
39
+ image (`PIL.Image`):
40
+ An image to perform padding on
41
+ size (`tuple` of integers):
42
+ A size to pad to, should be in the form
43
+ of (width, height)
44
+
45
+ Returns:
46
+ An augmented `PIL.Image`
47
+ """
48
+ top = (image.size[-2] - size[0]) // 2
49
+ left = (image.size[-1] - size[1]) // 2
50
+
51
+ pad_top = max(-top, 0)
52
+ pad_left = max(-left, 0)
53
+
54
+ height, width = (
55
+ max(size[1] - image.size[-2] + top, 0),
56
+ max(size[0] - image.size[-1] + left, 0)
57
+ )
58
+ return tvf.pad(
59
+ image,
60
+ [pad_top, pad_left, height, width],
61
+ padding_mode="constant"
62
+ )
63
+
64
+ def gpu_crop(
65
+ batch:torch.tensor,
66
+ size:typing.Tuple[int,int]
67
+ ):
68
+ """
69
+ Crops each image in `batch` to a particular `size`.
70
+
71
+ Args:
72
+ batch (array of `torch.Tensor`):
73
+ A batch of images, should be of shape `NxCxWxH`
74
+ size (`tuple` of integers):
75
+ A size to pad to, should be in the form
76
+ of (width, height)
77
+
78
+ Returns:
79
+ A batch of cropped images
80
+ """
81
+ # Split into multiple lines for clarity
82
+ affine_matrix = torch.eye(3, device=batch.device).float()
83
+ affine_matrix = affine_matrix.unsqueeze(0)
84
+ affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
85
+ affine_matrix = affine_matrix.contiguous()[:,:2]
86
+
87
+ coords = F.affine_grid(
88
+ affine_matrix, batch.shape[:2] + size, align_corners=True
89
+ )
90
+
91
+ top_range, bottom_range = coords.min(), coords.max()
92
+ zoom = 1/(bottom_range - top_range).item()*2
93
+
94
+ resizing_limit = min(
95
+ batch.shape[-2]/coords.shape[-2],
96
+ batch.shape[-1]/coords.shape[-1]
97
+ )/2
98
+
99
+ if resizing_limit > 1 and resizing_limit > zoom:
100
+ batch = F.interpolate(
101
+ batch,
102
+ scale_factor=1/resizing_limit,
103
+ mode='area',
104
+ recompute_scale_factor=True
105
+ )
106
+ return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)