szzzzz commited on
Commit
f51b502
·
1 Parent(s): 841c792

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +226 -0
model.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Union
3
+ import pickle
4
+ import torch
5
+ import torchvision
6
+ from torch import nn
7
+ import tarfile
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
+
12
+ def read_im(input: Image.Image) -> Image.Image:
13
+ """read im
14
+
15
+ Args:
16
+ input (Image.Image):
17
+ img
18
+
19
+ Returns:
20
+ Image.Image
21
+
22
+ """
23
+ im = input
24
+ if not isinstance(im, Image.Image):
25
+ raise ValueError("""`input` should be a str or bytes or Image.Image!""")
26
+ im = im.convert("RGB")
27
+
28
+ return im
29
+
30
+
31
+ class Classifier(nn.Module):
32
+ """Toxic Classifier.
33
+
34
+ Given a transformed image,`classifier` will get a toxic socre on it.
35
+
36
+ Attributes:
37
+ config (Optional[Dict],optional):
38
+ Modeling config.
39
+ Defaults to None.
40
+ """
41
+
42
+ def __init__(self, config: Optional[Dict] = None) -> None:
43
+ super().__init__()
44
+ self.config = {} if config is None else config
45
+
46
+ self.resnet = torchvision.models.resnet50()
47
+ self.resnet.fc = nn.Linear(
48
+ in_features=self.config.get("in_features", 2048),
49
+ out_features=self.config.get("tag_num", 2),)
50
+
51
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
52
+ out = self.resnet(x)
53
+ return out
54
+
55
+ @torch.no_grad()
56
+ def score(self, input: torch.Tensor) -> List[float]:
57
+ """Scoring the input image(one input).
58
+
59
+ Args:
60
+ input (torch.Tensor):
61
+ img input(should be transformed).
62
+
63
+ Returns:
64
+ List[float]:
65
+ The toxic score of the input .
66
+ """
67
+
68
+ return (
69
+ torch.softmax(self.forward(input), dim=1).detach().cpu().view(-1).tolist())
70
+
71
+ class Detector():
72
+ """Toxic detector .
73
+
74
+ Attributes:
75
+ config (Optional[Dict],optional):
76
+ Modeling config.
77
+ Defaults to None.
78
+ """
79
+
80
+ def __init__(self,*,config: Optional[Dict] = None,) -> None:
81
+ super().__init__()
82
+
83
+ if config is None:
84
+ config = {}
85
+ self._config = config
86
+ self._in_features = config.get("in_features", 2048)
87
+ self._tag_num = config.get("tag_num", 2)
88
+ self._tags = config.get("tags", ["obscene"])
89
+
90
+ self._classifier = Classifier(self.config)
91
+ self._trans = transforms.Compose(
92
+ [
93
+ # transforms.ToPILImage()
94
+ transforms.Resize(256),
95
+ transforms.CenterCrop(size=(224, 224)),
96
+ transforms.ToTensor(),
97
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
98
+ ])
99
+
100
+ @property
101
+ def config(self):
102
+ return self._config
103
+
104
+ @config.setter
105
+ def config(self, config: Dict):
106
+ self._config = config
107
+ self._in_features = config.get("in_features", 2048)
108
+ self._tag_num = config.get("tag_num", 2)
109
+ self._tags = config.get("tags", ["obscene"])
110
+
111
+ @property
112
+ def classifier(self):
113
+ return self._classifier
114
+
115
+ def _load_pkl(self, path: str) -> Dict:
116
+ with open(path, "rb") as f:
117
+ file = pickle.load(f)
118
+ return file
119
+
120
+ def _unzip2dir(self, file: str, dir: Optional[str] = None) -> None:
121
+ if dir is None:
122
+ dir = self._tmpdir.name
123
+ if not os.path.isdir(dir):
124
+ raise ValueError("""`dir` shoud be a dir!""")
125
+ tar = tarfile.open(file, "r")
126
+ tar.extractall(path=dir)
127
+ tar.close()
128
+
129
+ def load(self, model: str) -> None:
130
+ """Load state dict from local model path .
131
+
132
+ Args:
133
+ model (str):
134
+ Model file need to be loaded.
135
+ A string, the path of a pretrained model.
136
+
137
+ Raises:
138
+ ValueError: str model should be a path!
139
+ """
140
+
141
+ if isinstance(model, str):
142
+ if os.path.isdir(model):
143
+ self._load_from_dir(model)
144
+ elif os.path.isfile(model):
145
+ dir = "./toxic_detection"
146
+ if os.path.exists(dir):
147
+ pass
148
+ else:
149
+ os.mkdir(dir)
150
+ self._unzip2dir(model, dir)
151
+ self._load_from_dir(dir)
152
+ else:
153
+ raise ValueError("""str model should be a path!""")
154
+
155
+ else:
156
+ raise ValueError("""str model should be a path!""")
157
+
158
+ def _load_from_dir(self, model_dir: str) -> None:
159
+ """Set model params from `model_file`.
160
+
161
+ Args:
162
+ model_dir (str):
163
+ Dir containing model params.
164
+ """
165
+ config = self._load_pkl(os.path.join(model_dir, "config.pkl"))
166
+ self.config = config
167
+ self._classifier = Classifier(config)
168
+ self._classifier.load_state_dict(
169
+ torch.load(os.path.join(model_dir, "classifier.pkl"), map_location="cpu"))
170
+ self._classifier.eval()
171
+
172
+ def _transform(self, input: Union[str, bytes, Image.Image]) -> torch.Tensor:
173
+ """Transforms image to torch tensor.
174
+
175
+ Args:
176
+ input (Union[str,bytes,Image.Image]):
177
+ Image .
178
+
179
+ Raises:
180
+ ValueError:
181
+ `input` should be a str or bytes!
182
+
183
+ Returns:
184
+ torch.Tensor:
185
+ Transformed torch tensor.
186
+ """
187
+
188
+ im = read_im(input)
189
+ out = self._trans(im).view(1, 3, 224, 224).float()
190
+ return out
191
+
192
+ def _score(self, input: torch.Tensor) -> List[float]:
193
+ """Scoring the input image."""
194
+ toxic_score = self._classifier.score(input)
195
+ toxic_score = [round(s, 3) for s in toxic_score][1:]
196
+ return toxic_score
197
+
198
+ def detect(self, input: Union[str, bytes, Image.Image]) -> Dict:
199
+ """Detects toxic contents from image `input`.
200
+
201
+ Args:
202
+ input (Union[str,bytes,Image.Image]):
203
+ Image path of bytes.
204
+
205
+ Raises:
206
+ ValueError:
207
+ `input` should be a str or bytes!
208
+
209
+ Returns:
210
+ Dict:
211
+ Pattern as {
212
+ "toxic_score " : Dict[str,float]
213
+ }.
214
+ """
215
+
216
+ im = self._transform(input)
217
+ toxic_score = self._score(im)
218
+
219
+ out = {
220
+ "toxic_score": dict(
221
+ zip(
222
+ self._tags,
223
+ toxic_score,
224
+ )
225
+ ),}
226
+ return out