|
import logging
from typing import Any

from transformers import pipeline

from constants import SAFETY_CHECKER_MODEL
|
|
|
|
|
|
class SafetyChecker:
    """Classify images as safe or NSFW using a Hugging Face image classifier.

    The classifier's output is read for the labels ``"nsfw"`` and
    ``"normal"``; label matching is case-insensitive.
    """

    # Per-image score diagnostics go through logging instead of print(), so
    # library users can control verbosity and destination.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        mode_id: str = SAFETY_CHECKER_MODEL,
    ):
        """Build the ``image-classification`` pipeline for the given model.

        Args:
            mode_id: Model identifier forwarded to
                ``transformers.pipeline(model=...)``. The parameter name is a
                historical typo for ``model_id``; it is kept unchanged so
                existing keyword callers keep working.
        """
        self.classifier = pipeline(
            "image-classification",
            model=mode_id,
        )

    def is_safe(
        self,
        image: Any,
    ) -> bool:
        """Return whether *image* is classified as safe.

        Args:
            image: A single image in whatever form the pipeline accepts
                (presumably a PIL image, file path, or URL — confirm with
                callers; a batched list input is not handled here).

        Returns:
            ``True`` only when the ``"normal"`` score is strictly greater than
            the ``"nsfw"`` score. A label missing from the predictions counts
            as score 0, so a tie, or a model emitting neither label, yields
            ``False`` (fail closed).
        """
        predictions = self.classifier(image)
        # Lowercase labels so models emitting e.g. "NSFW"/"Normal" still match
        # instead of silently scoring 0 for both and rejecting every image.
        scores = {
            prediction["label"].lower(): prediction["score"]
            for prediction in predictions
        }
        nsfw_score = scores.get("nsfw", 0)
        normal_score = scores.get("normal", 0)
        self._logger.debug(
            "NSFW score: %s, Normal score: %s", nsfw_score, normal_score
        )
        return normal_score > nsfw_score
|
|
|