Fred Zhang committed
Commit
eaacc51
1 Parent(s): 3826cd5
Files changed (4)
  1. Config.py +36 -0
  2. Model.py +117 -0
  3. config.json +48 -0
  4. pytorch_model.bin +3 -0
Config.py ADDED
@@ -0,0 +1,36 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+ classes_example = {
+     0: 'nsfw_gore',
+     1: 'nsfw_suggestive',
+     2: 'safe'
+ }
+
+ class InceptionV3Config(PretrainedConfig):
+     model_type = "inceptionv3"
+     def __init__(self, model_name: str = "inception_v3", input_channels: int = 3, num_classes: int = 3, input_size: List[int] = [3, 299, 299], pool_size: List[int] = [8, 8, 2048], crop_pct: float = 0.875, interpolation: str = "bicubic", mean: List[float] = [0.5, 0.5, 0.5], std: List[float] = [0.5, 0.5, 0.5], first_conv: str = "Conv2d_1a_3x3.conv", classifier: str = "fc", has_aux: bool = True, label_offset: int = 1, classes: dict = classes_example, output_channels: int = 2048, use_jit=False, **kwargs):
+         self.model_name = model_name
+         self.input_channels = input_channels
+         self.num_classes = num_classes
+         self.input_size = input_size
+         self.pool_size = pool_size
+         self.crop_pct = crop_pct
+         self.interpolation = interpolation
+         self.mean = mean
+         self.std = std
+         self.first_conv = first_conv
+         self.classifier = classifier
+         self.has_aux = has_aux
+         self.label_offset = label_offset
+         self.classes = classes
+         self.output_channels = output_channels
+         self.use_jit = use_jit
+         super().__init__(**kwargs)
+
+ """
+
+ inceptionv3_config = InceptionV3Config()
+ inceptionv3_config.save_pretrained("inceptionv3_config")
+
+ """
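The commented-out snippet at the end of Config.py shows how the config was serialized. Below is a minimal sketch of the reverse direction, assuming the inceptionv3_config directory produced by that snippet exists; the AutoConfig.register call (available in recent transformers releases) is optional and not part of this commit.

# Sketch only: load the saved config back.
from transformers import AutoConfig
from Config import InceptionV3Config

config = InceptionV3Config.from_pretrained("inceptionv3_config")
# Note: after the JSON round trip the class-id keys come back as strings ("0", "1", "2").
print(config.classes)

# Optional: register the custom model_type so AutoConfig can resolve it in this process.
AutoConfig.register("inceptionv3", InceptionV3Config)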
Model.py ADDED
@@ -0,0 +1,117 @@
+ from transformers import PreTrainedModel
+ import torch
+ import os
+
+ # Reference URL for the torchvision InceptionV3 weights (not used directly below).
+ url_map = {
+     "inception_v3": "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth"
+ }
+
+ class InceptionV3ModelForImageClassification(PreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         model_path = f"{self.config.model_name}.bin".replace("/", "_")
+
+         if self.config.model_name == "google-safesearch-mini":
+             # TorchScript checkpoint expected in the working directory
+             self.model = torch.jit.load(model_path)
+         elif self.config.model_name == "inception_v3":
+             # Stock InceptionV3 weights from torchvision via torch.hub
+             self.model = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
+         else:
+             # Fetch the checkpoint from the Hugging Face Hub if it is not cached locally
+             if not os.path.exists(model_path):
+                 from urllib.request import urlretrieve
+                 urlretrieve(f"https://huggingface.co/{self.config.model_name}/resolve/main/pytorch_model.bin", model_path)
+             self.model = torch.jit.load(model_path) if self.config.use_jit else torch.load(model_path)
+
+     def forward(self, input_ids):
+         out, aux = self.model(input_ids)
+         return out, aux
+
+     def freeze(self):
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     def unfreeze(self):
+         for param in self.model.parameters():
+             param.requires_grad = True
+
+     def train(self, mode=True):
+         super().train(mode)
+         self.model.train(mode)
+
+     def eval(self):
+         return self.train(False)
+
+     def to(self, device):
+         self.model.to(device)
+         return self
+
+     def cuda(self, device=None):
+         return self.to("cuda")
+
+     def cpu(self):
+         return self.to("cpu")
+
+     def state_dict(self, destination=None, prefix='', keep_vars=False):
+         return self.model.state_dict(destination, prefix, keep_vars)
+
+     def load_state_dict(self, state_dict, strict=True):
+         return self.model.load_state_dict(state_dict, strict)
+
+     def parameters(self, recurse=True):
+         return self.model.parameters(recurse)
+
+     def named_parameters(self, prefix='', recurse=True):
+         return self.model.named_parameters(prefix, recurse)
+
+     def children(self):
+         return self.model.children()
+
+     def named_children(self):
+         return self.model.named_children()
+
+     def modules(self):
+         return self.model.modules()
+
+     def named_modules(self, memo=None, prefix=''):
+         return self.model.named_modules(memo, prefix)
+
+     def zero_grad(self, set_to_none=False):
+         return self.model.zero_grad(set_to_none)
+
+     def share_memory(self):
+         return self.model.share_memory()
+
+     def transform(self, image):
+         # Resize and normalize a PIL image to the 299x299 InceptionV3 input format
+         from torchvision import transforms
+         transform = transforms.Compose([
+             transforms.Resize(299),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.config.mean, std=self.config.std)
+         ])
+         image = transform(image)
+         return image
+
+     def open_image(self, path):
+         # Accept either a local file path or an http(s) URL
+         from PIL import Image
+         if path.startswith('http://') or path.startswith('https://'):
+             import requests
+             from io import BytesIO
+             response = requests.get(path)
+             image = Image.open(BytesIO(response.content)).convert('RGB')
+         else:
+             image = Image.open(path).convert('RGB')
+         return image
+
+     def predict(self, path, device="cuda"):
+         image = self.open_image(path)
+         image = self.transform(image)
+         image = image.unsqueeze(0)
+         self.eval()
+         if device == "cuda":
+             self.cuda()  # keep the wrapped model and the input on the same device
+             image = image.cuda()
+         with torch.no_grad():
+             out, aux = self(image)
+         print(out)
+         _, predicted = torch.max(out.data, 1)
+         # Class keys are ints in Config.py but strings once the config is loaded from config.json
+         key = predicted.item()
+         return self.config.classes.get(key, self.config.classes.get(str(key)))
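A minimal usage sketch for the wrapper above, under assumptions that are not part of this commit: the TorchScript checkpoint is available locally as google-safesearch-mini.bin (the filename __init__ derives from model_name), the scripted model returns the (logits, aux) pair that forward unpacks, and the image path below is a placeholder.

# Sketch only: local CPU inference with the classes defined in Config.py and Model.py.
from Config import InceptionV3Config
from Model import InceptionV3ModelForImageClassification

config = InceptionV3Config(model_name="google-safesearch-mini", use_jit=True)
model = InceptionV3ModelForImageClassification(config)

label = model.predict("path/to/image.jpg", device="cpu")  # placeholder path or URL
print(label)  # one of: nsfw_gore, nsfw_suggestive, safe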
config.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "architectures": [
+     "InceptionV3ModelForImageClassification"
+   ],
+   "auto_map": {
+     "AutoConfig": "Config.InceptionV3Config",
+     "AutoModelForImageClassification": "Model.InceptionV3ModelForImageClassification"
+   },
+   "classes": {
+     "0": "nsfw_gore",
+     "1": "nsfw_suggestive",
+     "2": "safe"
+   },
+   "classifier": "fc",
+   "crop_pct": 0.875,
+   "first_conv": "Conv2d_1a_3x3.conv",
+   "has_aux": true,
+   "input_channels": 3,
+   "input_size": [
+     3,
+     299,
+     299
+   ],
+   "interpolation": "bicubic",
+   "label_offset": 1,
+   "mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "model_name": "google-safesearch-mini",
+   "model_type": "inceptionv3",
+   "num_classes": 3,
+   "output_channels": 2048,
+   "pool_size": [
+     8,
+     8,
+     2048
+   ],
+   "std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.21.2",
+   "use_jit": true
+ }
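The auto_map block above is what lets the transformers Auto classes pull Config.py and Model.py from the Hub as remote code. A hedged sketch; the repository id below is a placeholder, not taken from this commit.

# Sketch only: resolving the custom config class through auto_map.
# "<user>/<repo>" is a placeholder; trust_remote_code=True is required because
# InceptionV3Config ships with the repository rather than with transformers.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("<user>/<repo>", trust_remote_code=True)
print(type(config).__name__)  # InceptionV3Config

# AutoModelForImageClassification maps to Model.InceptionV3ModelForImageClassification the
# same way; note that with model_name == "google-safesearch-mini" the model's __init__ still
# expects a local google-safesearch-mini.bin TorchScript file in the working directory.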
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db510376e428b0d5f1472e4f56d31a4bfbee69b3e8a58c67a802098e00d42d12
+ size 100804217