import gradio as gr
from transformers import ViTForImageClassification, ViTImageProcessor
import numpy as np
from PIL import Image
import logging
from pytorch_grad_cam import run_dff_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from face_grab import FaceGrabber
from gradcam import GradCam

logging.basicConfig(level=logging.INFO)


# Load the fine-tuned ViT attraction classifier and its image processor from the Hub
model = ViTForImageClassification.from_pretrained("ongkn/attraction-classifier")
processor = ViTImageProcessor.from_pretrained("ongkn/attraction-classifier")

# Helpers: face detection/cropping and Grad-CAM/DFF utilities for the ViT
faceGrabber = FaceGrabber()
gradCam = GradCam()

# Per-class Grad-CAM targets and the ViT layers used for the DFF and Grad-CAM overlays
targetsForGradCam = [ClassifierOutputTarget(gradCam.category_name_to_index(model, "pos")),
                     ClassifierOutputTarget(gradCam.category_name_to_index(model, "neg"))]
targetLayerDff = model.vit.layernorm
targetLayerGradCam = model.vit.encoder.layer[-2].output

def classify_image(img):
    # Detect and crop the face; bail out early if none is found
    face = faceGrabber.grab_faces(np.array(img))
    if face is None:
        return "No face detected", 0.0, "", img, None, None  # match the six Interface outputs
    face = Image.fromarray(face)
    faceResized = face.resize((224, 224))
    tensorResized = transforms.ToTensor()(faceResized)
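    # Deep Feature Factorisation: decompose the ViT activations into concepts and overlay them on the face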
    dffImage = run_dff_on_image(model=model,
                                target_layer=targetLayerDff,
                                classifier=model.classifier,
                                img_pil=faceResized,
                                img_tensor=tensorResized,
                                reshape_transform=gradCam.reshape_transform_vit_huggingface,
                                n_components=6,
                                top_k=15
                                )
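    # Classify the cropped face and build a Grad-CAM target for the predicted class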
    result = gradCam.get_top_category(model, tensorResized)
    cls = result[0]["label"]
    result[0]["score"] = round(result[0]["score"], 2)
    clsIdx = gradCam.category_name_to_index(model, cls)
    clsTarget = ClassifierOutputTarget(clsIdx)
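    # Grad-CAM: highlight the regions that most influenced the predicted class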
    gradCamImage = gradCam.run_grad_cam_on_image(model=model,
                                                 target_layer=targetLayerGradCam,
                                                 targets_for_gradcam=[clsTarget],
                                                 input_tensor=tensorResized,
                                                 input_image=faceResized,
                                                 reshape_transform=gradCam.reshape_transform_vit_huggingface)
    if result[0]["label"] == "pos" and result[0]["score"] > 0.85 and result[0]["score"] <= 0.9:
        return result[0]["label"], result[0]["score"], "Nice!", face, dffImage, gradCamImage
    elif result[0]["label"] == "pos" and result[0]["score"] > 0.9 and result[0]["score"] <= 0.95:
        return result[0]["label"], result[0]["score"], "Pretty!", face, dffImage, gradCamImage
    elif result[0]["label"] == "pos" and result[0]["score"] > 0.95 and result[0]["score"] <= 0.98:
        return result[0]["label"], result[0]["score"], "WHOA!!!!", face, dffImage, gradCamImage
    elif result[0]["label"] == "pos" and result[0]["score"] > 0.98:
        return result[0]["label"], result[0]["score"], "** ABSOLUTELY MINDBLOWING **", face, dffImage, gradCamImage
    else:
        return cls, result[0]["score"], "Indifferent", face, dffImage, gradCamImage

# Gradio UI: image in; label, score, remark, cropped face, and the two explanation overlays out
iface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs=["text", "number", "text", "image", "image", "image"],
    title="Attraction Classifier - subjective",
    description="Takes in a (224, 224, 3) RGB image and outputs an attraction class, 'pos' or 'neg', along with a Grad-CAM/DFF explanation. Face detection, cropping, and resizing are done internally. Uploaded images are not stored by us, but may be stored by HF. Refer to their [privacy policy](https://huggingface.co/privacy) for details.\nAssociated post: https://simtoon.ongakken.com/Projects/Personal/Girl+classifier/desc+-+girl+classifier"
)

iface.launch()