import gradio as gr
from transformers import pipeline, ViTForImageClassification, ViTImageProcessor
import numpy as np
from PIL import Image
import cv2 as cv
import dlib
import logging
from typing import Optional
from pytorch_grad_cam import run_dff_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from face_grab import FaceGrabber
from gradcam import GradCam
from torchvision import transforms


logging.basicConfig(level=logging.INFO)

def grab_faces(img: np.ndarray) -> Optional[np.ndarray]:
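    """
    Standalone fallback face grabber: try dlib's HOG detector, then the MMOD
    CNN detector, then OpenCV Haar cascades; return the first detected face,
    padded by 10% and clamped to the image bounds, or None if nothing is found.
    Expects an OpenCV-style BGR image. (The Gradio app below uses
    FaceGrabber.grab_faces from face_grab rather than this helper.)
    """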
    cascades = [
        "haarcascade_frontalface_default.xml",
        "haarcascade_frontalface_alt.xml",
        "haarcascade_frontalface_alt2.xml",
        "haarcascade_frontalface_alt_tree.xml"
    ]

    detector = dlib.get_frontal_face_detector() # dlib HOG face detector
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks_GTX.dat") # 68-point landmark predictor (unused in this helper)
    mmod = dlib.cnn_face_detection_model_v1("mmod_human_face_detector.dat") # dlib CNN (MMOD) face detector

    paddingBy = 0.1 # padding by 10%

    gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY) # convert to grayscale

    detected = None

    if detected is None:
        faces = detector(gray) # detect faces
        if len(faces) > 0:
            detected = faces[0]
            detected = (detected.left(), detected.top(), detected.width(), detected.height())
            logging.info("Face detected by dlib")

    if detected is None:
        faces = mmod(img)
        if len(faces) > 0:
            detected = faces[0]
            detected = (detected.rect.left(), detected.rect.top(), detected.rect.width(), detected.rect.height())
            logging.info("Face detected by mmod")

    if detected is None: # fall back to the Haar cascades only if dlib found nothing
        for cascade in cascades:
            cascadeClassifier = cv.CascadeClassifier(cv.data.haarcascades + cascade)
            faces = cascadeClassifier.detectMultiScale(gray, scaleFactor=1.3, minNeighbors=5) # detect faces
            if len(faces) > 0:
                detected = faces[0]
                logging.info(f"Face detected by {cascade}")
                break

    if detected is not None: # a face was found
        x, y, w, h = detected
        padW = int(paddingBy * w) # padding width
        padH = int(paddingBy * h) # padding height
        imgH, imgW = img.shape[:2] # image dims
        x = max(0, x - padW) # grow the box by the padding on every side,
        y = max(0, y - padH) # clamping it to the image bounds
        w = min(imgW - x, w + 2 * padW)
        h = min(imgH - y, h + 2 * padH)
        face = img[y:y+h, x:x+w] # crop the padded face (may sit slightly off-center at image borders)
        return face

    return None
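
# Minimal usage sketch for the helper above (hypothetical file names; assumes
# the dlib .dat model files sit next to this script):
#   img = cv.imread("portrait.jpg")  # cv.imread yields BGR, matching grab_faces
#   face = grab_faces(img)
#   if face is not None:
#       cv.imwrite("face_crop.jpg", face)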

model = ViTForImageClassification.from_pretrained("ongkn/attraction-classifier")
processor = ViTImageProcessor.from_pretrained("ongkn/attraction-classifier")

pipe = pipeline("image-classification", model=model, feature_extractor=processor)
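# The pipeline wires the image processor to the model, so classify_image can
# feed it a raw PIL image and get back [{"label": ..., "score": ...}, ...].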

faceGrabber = FaceGrabber()
gradCam = GradCam()

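# Explanation setup: DFF factorizes activations at the ViT's final layernorm,
# while Grad-CAM uses the output of the second-to-last encoder block. The
# explicit "pos"/"neg" targets are defined but not passed to the helpers below.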
targetsForGradCam = [ClassifierOutputTarget(gradCam.category_name_to_index(model, "pos")),
                     ClassifierOutputTarget(gradCam.category_name_to_index(model, "neg"))]
targetLayerDff = model.vit.layernorm
targetLayerGradCam = model.vit.encoder.layer[-2].output

def classify_image(image):
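    """Detect and crop a face in `image`, classify it, and return
    (label, score, cropped face, DFF overlay, Grad-CAM overlay)."""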
    face = faceGrabber.grab_faces(np.array(image))
    if face is None:
        # must return as many values as the Interface declares outputs
        return "No face detected", 0.0, image, None, None
    face = Image.fromarray(face)
    tensor = transforms.ToTensor()(face) # CHW float tensor in [0, 1] for the CAM helpers
    dffImage = run_dff_on_image(model=model,
                                target_layer=targetLayerDff,
                                classifier=model.classifier,
                                img_pil=face,
                                img_tensor=tensor,
                                reshape_transform=gradCam.reshape_transform_vit_huggingface,
                                n_components=5,
                                top_k=10,
                                threshold=0,
                                )
    gradCamImage = gradCam.run_grad_cam_on_image(model=model,
                                                 target_layer=targetLayerGradCam,
                                                 classifier=model.classifier,
                                                 img_pil=face,
                                                 img_tensor=tensor,
                                                 reshape_transform=gradCam.reshape_transform_vit_huggingface,
                                                 n_components=5,
                                                 top_k=10,
                                                 threshold=0,
                                                 )
    result = pipe(face)
    return result[0]["label"], result[0]["score"], face, dffImage, gradCamImage

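# The five outputs below line up with classify_image's return values:
# label text, confidence score, cropped face, DFF overlay, Grad-CAM overlay.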
iface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs=["text", "number", "image", "image", "image"],
    title="Attraction Classifier - subjective",
    description="Takes in a (224, 224) image and outputs an attraction class, 'pos' or 'neg', along with Grad-CAM and DFF explanations. Face detection, cropping, and resizing are done internally. Uploaded images are not stored by us, but may be stored by HF. Refer to their privacy policy for details."
)
if __name__ == "__main__":
    iface.launch()