import gradio as gr

import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import numpy as np

# load a simple face detector
from retinaface import RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"

# load Gaze-LLE model
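# "gazelle_dinov2_vitl14_inout" pairs a frozen DINOv2 ViT-L/14 encoder with a lightweight
# gaze decoder; the "inout" head additionally predicts whether the gaze target is in frame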
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)

def main(image_input, progress=gr.Progress(track_tqdm=True)):
    # load image
    image = Image.open(image_input).convert("RGB") # ensure 3-channel input (e.g. for PNGs with alpha)
    width, height = image.size

    # detect faces
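    # RetinaFace returns a dict keyed by face id; each entry's "facial_area" is a
    # pixel-space [xmin, ymin, xmax, ymax] box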
    resp = RetinaFace.detect_faces(np.array(image))
    # guard against images where no face is detected
    if not isinstance(resp, dict) or len(resp) == 0:
        raise gr.Error("No face was detected in the input image.")
    bboxes = [resp[key]["facial_area"] for key in resp.keys()]

    # prepare gazelle input
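    # the transform resizes the image to the model's 448x448 input; bboxes are normalized
    # to [0, 1] coordinates relative to the original image size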
    img_tensor = transform(image).unsqueeze(0).to(device)
    norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

    model_input = {
        "images": img_tensor, # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
    }

    with torch.no_grad():
        output = model(model_input)

    # output['heatmap'][0][i] is a [64, 64] gaze heatmap for person i of the first image;
    # if model.inout, output['inout'][0][i] is that person's gaze-in-frame score

    # visualize predicted gaze heatmap for each person and gaze in/out of frame score

    def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
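        # upsample the 64x64 heatmap to the image size, colorize it with the jet
        # colormap, and alpha-composite the result over the original image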
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.detach().cpu().numpy()
        heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
        heatmap = plt.cm.jet(np.array(heatmap) / 255.)
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        heatmap = Image.fromarray(heatmap).convert("RGBA")
        heatmap.putalpha(90)
        overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

        if bbox is not None:
            width, height = pil_image.size
            xmin, ymin, xmax, ymax = bbox
            draw = ImageDraw.Draw(overlay_image)
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))

            if inout_score is not None:
                text = f"in-frame: {inout_score:.2f}"
                text_height = int(height * 0.01)
                text_x = xmin * width
                text_y = ymax * height + text_height
                draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
        return overlay_image

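    # build one heatmap overlay per detected person for the gallery output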
    heatmap_results = []
    for i in range(len(bboxes)):
        overlay_img = visualize_heatmap(image, output['heatmap'][0][i], norm_bboxes[0][i], inout_score=output['inout'][0][i] if output['inout'] is not None else None)
        heatmap_results.append(overlay_img)

    # combined visualization with maximal gaze points for each person

    def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
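        # draw each person's box and in-frame score; for people predicted to be looking
        # in frame, draw a line from the head to the most likely gaze point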
        colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size

        for i in range(len(bboxes)):
            bbox = bboxes[i]
            xmin, ymin, xmax, ymax = bbox
            color = colors[i % len(colors)]
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))

            if inout_scores is not None:
                inout_score = inout_scores[i]
                text = f"in-frame: {inout_score:.2f}"
                text_height = int(height * 0.01)
                text_x = xmin * width
                text_y = ymax * height + text_height
                draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

            if inout_scores is not None and inout_score > inout_thresh:
                heatmap = heatmaps[i]
                heatmap_np = heatmap.detach().cpu().numpy()
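                # the predicted gaze target is the heatmap's peak, rescaled from the
                # 64x64 grid to pixel coordinates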
                max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
                gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
                gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                draw.ellipse([(gaze_target_x-5, gaze_target_y-5), (gaze_target_x+5, gaze_target_y+5)], fill=color)
                draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005*min(width, height)))

        return overlay_image

    result_gazed = visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5)

    return result_gazed, heatmap_results

css="""
div#col-container{
    margin: 0 auto;
    max-width: 982px;
}
"""

with gr.Blocks(css=css) as demo: 
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/fkryan/gazelle">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a> 
            <a href="https://arxiv.org/abs/2412.09586">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/Gaze-LLE?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                result = gr.Image(label="Result")
                heatmaps = gr.Gallery(label="Heatmap")

    submit_button.click(
        fn=main,
        inputs=[input_image],
        outputs=[result, heatmaps]
    )
demo.queue().launch(show_api=False, show_error=True)