szzzzz commited on
Commit
bd9dd85
·
1 Parent(s): 854b18b

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -226
model.py DELETED
@@ -1,226 +0,0 @@
1
- import os
2
- from typing import Dict, List, Optional, Union
3
- import pickle
4
- import torch
5
- import torchvision
6
- from torch import nn
7
- import tarfile
8
- from PIL import Image
9
- from torchvision import transforms
10
-
11
-
12
- def read_im(input: Image.Image) -> Image.Image:
13
- """read im
14
-
15
- Args:
16
- input (Image.Image):
17
- img
18
-
19
- Returns:
20
- Image.Image
21
-
22
- """
23
- im = input
24
- if not isinstance(im, Image.Image):
25
- raise ValueError("""`input` should be a str or bytes or Image.Image!""")
26
- im = im.convert("RGB")
27
-
28
- return im
29
-
30
-
31
- class Classifier(nn.Module):
32
- """Toxic Classifier.
33
-
34
- Given a transformed image,`classifier` will get a toxic socre on it.
35
-
36
- Attributes:
37
- config (Optional[Dict],optional):
38
- Modeling config.
39
- Defaults to None.
40
- """
41
-
42
- def __init__(self, config: Optional[Dict] = None) -> None:
43
- super().__init__()
44
- self.config = {} if config is None else config
45
-
46
- self.resnet = torchvision.models.resnet50()
47
- self.resnet.fc = nn.Linear(
48
- in_features=self.config.get("in_features", 2048),
49
- out_features=self.config.get("tag_num", 2),)
50
-
51
- def forward(self, x: torch.Tensor) -> torch.Tensor:
52
- out = self.resnet(x)
53
- return out
54
-
55
- @torch.no_grad()
56
- def score(self, input: torch.Tensor) -> List[float]:
57
- """Scoring the input image(one input).
58
-
59
- Args:
60
- input (torch.Tensor):
61
- img input(should be transformed).
62
-
63
- Returns:
64
- List[float]:
65
- The toxic score of the input .
66
- """
67
-
68
- return (
69
- torch.softmax(self.forward(input), dim=1).detach().cpu().view(-1).tolist())
70
-
71
- class Detector():
72
- """Toxic detector .
73
-
74
- Attributes:
75
- config (Optional[Dict],optional):
76
- Modeling config.
77
- Defaults to None.
78
- """
79
-
80
- def __init__(self,*,config: Optional[Dict] = None,) -> None:
81
- super().__init__()
82
-
83
- if config is None:
84
- config = {}
85
- self._config = config
86
- self._in_features = config.get("in_features", 2048)
87
- self._tag_num = config.get("tag_num", 2)
88
- self._tags = config.get("tags", ["obscene"])
89
-
90
- self._classifier = Classifier(self.config)
91
- self._trans = transforms.Compose(
92
- [
93
- # transforms.ToPILImage()
94
- transforms.Resize(256),
95
- transforms.CenterCrop(size=(224, 224)),
96
- transforms.ToTensor(),
97
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
98
- ])
99
-
100
- @property
101
- def config(self):
102
- return self._config
103
-
104
- @config.setter
105
- def config(self, config: Dict):
106
- self._config = config
107
- self._in_features = config.get("in_features", 2048)
108
- self._tag_num = config.get("tag_num", 2)
109
- self._tags = config.get("tags", ["obscene"])
110
-
111
- @property
112
- def classifier(self):
113
- return self._classifier
114
-
115
- def _load_pkl(self, path: str) -> Dict:
116
- with open(path, "rb") as f:
117
- file = pickle.load(f)
118
- return file
119
-
120
- def _unzip2dir(self, file: str, dir: Optional[str] = None) -> None:
121
- if dir is None:
122
- dir = self._tmpdir.name
123
- if not os.path.isdir(dir):
124
- raise ValueError("""`dir` shoud be a dir!""")
125
- tar = tarfile.open(file, "r")
126
- tar.extractall(path=dir)
127
- tar.close()
128
-
129
- def load(self, model: str) -> None:
130
- """Load state dict from local model path .
131
-
132
- Args:
133
- model (str):
134
- Model file need to be loaded.
135
- A string, the path of a pretrained model.
136
-
137
- Raises:
138
- ValueError: str model should be a path!
139
- """
140
-
141
- if isinstance(model, str):
142
- if os.path.isdir(model):
143
- self._load_from_dir(model)
144
- elif os.path.isfile(model):
145
- dir = "./toxic_detection"
146
- if os.path.exists(dir):
147
- pass
148
- else:
149
- os.mkdir(dir)
150
- self._unzip2dir(model, dir)
151
- self._load_from_dir(dir)
152
- else:
153
- raise ValueError("""str model should be a path!""")
154
-
155
- else:
156
- raise ValueError("""str model should be a path!""")
157
-
158
- def _load_from_dir(self, model_dir: str) -> None:
159
- """Set model params from `model_file`.
160
-
161
- Args:
162
- model_dir (str):
163
- Dir containing model params.
164
- """
165
- config = self._load_pkl(os.path.join(model_dir, "config.pkl"))
166
- self.config = config
167
- self._classifier = Classifier(config)
168
- self._classifier.load_state_dict(
169
- torch.load(os.path.join(model_dir, "classifier.pkl"), map_location="cpu"))
170
- self._classifier.eval()
171
-
172
- def _transform(self, input: Union[str, bytes, Image.Image]) -> torch.Tensor:
173
- """Transforms image to torch tensor.
174
-
175
- Args:
176
- input (Union[str,bytes,Image.Image]):
177
- Image .
178
-
179
- Raises:
180
- ValueError:
181
- `input` should be a str or bytes!
182
-
183
- Returns:
184
- torch.Tensor:
185
- Transformed torch tensor.
186
- """
187
-
188
- im = read_im(input)
189
- out = self._trans(im).view(1, 3, 224, 224).float()
190
- return out
191
-
192
- def _score(self, input: torch.Tensor) -> List[float]:
193
- """Scoring the input image."""
194
- toxic_score = self._classifier.score(input)
195
- toxic_score = [round(s, 3) for s in toxic_score][1:]
196
- return toxic_score
197
-
198
- def detect(self, input: Union[str, bytes, Image.Image]) -> Dict:
199
- """Detects toxic contents from image `input`.
200
-
201
- Args:
202
- input (Union[str,bytes,Image.Image]):
203
- Image path of bytes.
204
-
205
- Raises:
206
- ValueError:
207
- `input` should be a str or bytes!
208
-
209
- Returns:
210
- Dict:
211
- Pattern as {
212
- "toxic_score " : Dict[str,float]
213
- }.
214
- """
215
-
216
- im = self._transform(input)
217
- toxic_score = self._score(im)
218
-
219
- out = {
220
- "toxic_score": dict(
221
- zip(
222
- self._tags,
223
- toxic_score,
224
- )
225
- ),}
226
- return out