toxic_detection / model.py
szzzzz's picture
Create model.py
f51b502
raw
history blame
6.33 kB
import io
import os
import pickle
import tarfile
import tempfile
from typing import Dict, List, Optional, Union

import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
def read_im(input: Union[str, bytes, Image.Image]) -> Image.Image:
    """Load ``input`` as an RGB PIL image.

    Args:
        input (Union[str, bytes, Image.Image]):
            A filesystem path to an image file, the raw encoded bytes of an
            image file, or an already-opened ``PIL.Image.Image``.

    Raises:
        ValueError: ``input`` is none of the supported types.

    Returns:
        Image.Image: The image converted to ``"RGB"`` mode.
    """
    if isinstance(input, Image.Image):
        return input.convert("RGB")
    if isinstance(input, str):
        # `convert` forces a full load, so the file can be closed right after.
        with Image.open(input) as im:
            return im.convert("RGB")
    if isinstance(input, bytes):
        # Pillow needs a file-like object to decode in-memory data.
        with Image.open(io.BytesIO(input)) as im:
            return im.convert("RGB")
    raise ValueError("""`input` should be a str or bytes or Image.Image!""")


class Classifier(nn.Module):
    """ResNet-50 backbone with a configurable linear classification head.

    Given an already-transformed image tensor, :meth:`score` returns softmax
    probabilities over the head's output classes (the toxic score).

    Args:
        config (Optional[Dict], optional):
            Modeling config, stored on ``self.config``. Keys read here are
            ``"in_features"`` (default 2048) and ``"tag_num"`` (default 2),
            which size the replacement ``fc`` layer. Defaults to None, which
            is treated as an empty dict.
    """

    def __init__(self, config: Optional[Dict] = None) -> None:
        super().__init__()
        self.config = config if config is not None else {}
        in_dim = self.config.get("in_features", 2048)
        n_classes = self.config.get("tag_num", 2)
        # The attribute names `resnet` and `resnet.fc` define the state_dict
        # keys, so they must not change or saved weights will fail to load.
        backbone = torchvision.models.resnet50()
        backbone.fc = nn.Linear(in_features=in_dim, out_features=n_classes)
        self.resnet = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return its raw (pre-softmax) outputs."""
        logits = self.resnet(x)
        return logits

    @torch.no_grad()
    def score(self, input: torch.Tensor) -> List[float]:
        """Return softmax class probabilities for ``input`` as a flat list.

        Gradient tracking is disabled for the whole call.

        Args:
            input (torch.Tensor):
                Already-transformed image batch; intended to hold one image.

        Returns:
            List[float]:
                Softmax (along dim 1) probabilities. If the batch holds more
                than one image, all rows are concatenated into one flat list.
        """
        probs = torch.softmax(self.forward(input), dim=1)
        flat = probs.detach().cpu().view(-1)
        return flat.tolist()


class Detector:
    """Image toxic-content detector.

    Bundles a :class:`Classifier` with the preprocessing pipeline. The
    pipeline resizes to 256, center-crops to 224x224, converts to a tensor
    and applies mean/std normalization with the usual ImageNet statistics.
    Results are returned as a ``{tag: score}`` mapping.

    Args:
        config (Optional[Dict], optional):
            Modeling config. Keys read: ``"in_features"`` (default 2048),
            ``"tag_num"`` (default 2) and ``"tags"`` (default
            ``["obscene"]``). Defaults to None, treated as an empty dict.
            Replaced by the pickled config when :meth:`load` is called.
    """

    def __init__(self, *, config: Optional[Dict] = None) -> None:
        super().__init__()
        # Route through the `config` setter so the derived fields are
        # computed in exactly one place.
        self.config = {} if config is None else config
        self._classifier = Classifier(self.config)
        # Inference-only object: keep BatchNorm layers in eval mode so that
        # calling `detect` before `load` does not mutate running statistics.
        self._classifier.eval()
        # Temporary extraction directory, created lazily by `_unzip2dir`.
        self._tmpdir = None
        self._trans = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(size=(224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    @property
    def config(self):
        """The current modeling config dict."""
        return self._config

    @config.setter
    def config(self, config: Dict):
        """Store ``config`` and refresh the fields derived from it.

        The classifier is not rebuilt here; ``_load_from_dir`` does that.
        """
        self._config = config
        self._in_features = config.get("in_features", 2048)
        self._tag_num = config.get("tag_num", 2)
        self._tags = config.get("tags", ["obscene"])

    @property
    def classifier(self):
        """The underlying :class:`Classifier` module."""
        return self._classifier

    def _load_pkl(self, path: str) -> Dict:
        """Unpickle and return the object stored at ``path``.

        Security: ``pickle.load`` can execute arbitrary code while loading,
        so only call this on files from a trusted source.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def _unzip2dir(self, file: str, dir: Optional[str] = None) -> None:
        """Extract the tar archive ``file`` into ``dir``.

        Args:
            file (str): Path of the tar archive (opened with ``tarfile`` mode
                ``"r"``, which auto-detects compression).
            dir (Optional[str], optional): Target directory. When None, a
                private ``tempfile.TemporaryDirectory`` owned by this detector
                is created on first use and reused afterwards.

        Raises:
            ValueError: ``dir`` is not an existing directory.
        """
        if dir is None:
            # `self._tmpdir` used to be referenced without ever being created,
            # which raised AttributeError on this path.
            if getattr(self, "_tmpdir", None) is None:
                self._tmpdir = tempfile.TemporaryDirectory()
            dir = self._tmpdir.name
        if not os.path.isdir(dir):
            raise ValueError("""`dir` shoud be a dir!""")
        # `with` guarantees the archive is closed even if extraction fails.
        with tarfile.open(file, "r") as tar:
            if hasattr(tarfile, "data_filter"):
                # Reject absolute paths, `..` traversal and links leaving `dir`.
                tar.extractall(path=dir, filter="data")
            else:
                # NOTE(review): this Python has no extraction filters, so members
                # are extracted unsanitised. Only load archives you trust.
                tar.extractall(path=dir)

    def load(self, model: str) -> None:
        """Load config and weights from a local model path.

        Args:
            model (str):
                Either a directory containing ``config.pkl`` and
                ``classifier.pkl``, or a tar archive holding those files at its
                root. An archive is first extracted into ``./toxic_detection``
                (relative to the current working directory).

        Raises:
            ValueError: ``model`` is not a str, or is not an existing path.
        """
        if not isinstance(model, str):
            raise ValueError("""str model should be a path!""")
        if os.path.isdir(model):
            self._load_from_dir(model)
        elif os.path.isfile(model):
            extract_dir = "./toxic_detection"
            if not os.path.exists(extract_dir):
                os.mkdir(extract_dir)
            self._unzip2dir(model, extract_dir)
            self._load_from_dir(extract_dir)
        else:
            raise ValueError("""str model should be a path!""")

    def _load_from_dir(self, model_dir: str) -> None:
        """Rebuild the classifier from the files in ``model_dir``.

        Reads ``config.pkl`` (the config dict) and ``classifier.pkl`` (the
        classifier state dict, mapped onto CPU), then switches to eval mode.

        Args:
            model_dir (str):
                Dir containing model params.
        """
        config = self._load_pkl(os.path.join(model_dir, "config.pkl"))
        self.config = config
        self._classifier = Classifier(config)
        # NOTE(review): `torch.load` is pickle-based too; trusted files only.
        self._classifier.load_state_dict(
            torch.load(os.path.join(model_dir, "classifier.pkl"), map_location="cpu")
        )
        self._classifier.eval()

    def _transform(self, input: Union[str, bytes, Image.Image]) -> torch.Tensor:
        """Convert an image into a batched, normalized float tensor.

        Args:
            input (Union[str, bytes, Image.Image]):
                Image, passed through ``read_im``.

        Raises:
            ValueError:
                Raised by ``read_im`` when it rejects ``input``.

        Returns:
            torch.Tensor:
                Tensor of shape ``(1, 3, 224, 224)``.
        """
        im = read_im(input)
        # Add the leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
        return self._trans(im).view(1, 3, 224, 224).float()

    def _score(self, input: torch.Tensor) -> List[float]:
        """Return class probabilities rounded to 3 decimals, skipping class 0.

        The remaining scores are paired positionally with ``self._tags`` in
        :meth:`detect`. Presumably class 0 is the non-toxic class; confirm
        against the training setup.
        """
        probs = self._classifier.score(input)
        return [round(p, 3) for p in probs[1:]]

    def detect(self, input: Union[str, bytes, Image.Image]) -> Dict:
        """Detect toxic content in image ``input``.

        Args:
            input (Union[str, bytes, Image.Image]):
                Image, passed through ``read_im``.

        Raises:
            ValueError:
                Raised by ``read_im`` when it rejects ``input``.

        Returns:
            Dict:
                ``{"toxic_score": Dict[str, float]}``, mapping each configured
                tag to its score. ``zip`` truncates to the shorter of the tags
                and the scores.
        """
        scores = self._score(self._transform(input))
        return {"toxic_score": dict(zip(self._tags, scores))}