fffiloni committed
Commit 270d2eb · verified · 1 Parent(s): c7901e9

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
import gradio as gr

import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import numpy as np

# load a simple face detector
from retinaface import RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"

# load Gaze-LLE model
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)
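# torch.hub.load returns both the model and its matching image preprocessing
# transform; the pretrained weights are downloaded on the first call. The
# "inout" variant additionally predicts whether each person's gaze target
# lies inside the frame.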

def main(image_input):
    # load image (convert to RGB so RGBA/grayscale uploads don't break the 3-channel transform)
    image = Image.open(image_input).convert("RGB")
    width, height = image.size

    # detect faces
    resp = RetinaFace.detect_faces(np.array(image))
    print(resp)
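    # retina-face returns a dict keyed "face_1", "face_2", ..., where each
    # entry's "facial_area" is [xmin, ymin, xmax, ymax] in pixel coordinates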
    bboxes = [resp[key]["facial_area"] for key in resp.keys()]
    print(bboxes)

    # prepare gazelle input
    img_tensor = transform(image).unsqueeze(0).to(device)
    norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

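    # Gaze-LLE expects head boxes normalized to [0, 1] by image size, nested as
    # one list per image in the batch (here a single image)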
    input = {
        "images": img_tensor,   # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes   # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
    }

    with torch.no_grad():
        output = model(input)

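    # output['heatmap'][i][j] is the [64, 64] gaze heatmap for person j in
    # image i; output['inout'][i][j] is that person's in-frame score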
    img1_person1_heatmap = output['heatmap'][0][0]  # [64, 64] heatmap
    print(img1_person1_heatmap.shape)
    if model.inout:
        img1_person1_inout = output['inout'][0][0]  # gaze in frame score (if model supports inout prediction)
        print(img1_person1_inout.item())

    # visualize predicted gaze heatmap for each person and gaze in/out of frame score

    def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.detach().cpu().numpy()
        # upscale the 64x64 heatmap to the image size and colorize it
        heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
        heatmap = plt.cm.jet(np.array(heatmap) / 255.)
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        heatmap = Image.fromarray(heatmap).convert("RGBA")
        heatmap.putalpha(90)
        overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

        if bbox is not None:
            width, height = pil_image.size
            xmin, ymin, xmax, ymax = bbox
            draw = ImageDraw.Draw(overlay_image)
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))

            if inout_score is not None:
                text = f"in-frame: {inout_score:.2f}"
                text_height = int(height * 0.01)
                text_x = xmin * width
                text_y = ymax * height + text_height
                draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
        return overlay_image
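    # note: visualize_heatmap is a single-person debugging helper; the image
    # returned to the UI is produced by visualize_all below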


    # combined visualization with maximal gaze points for each person

    def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
        colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size

        for i in range(len(bboxes)):
            bbox = bboxes[i]
            xmin, ymin, xmax, ymax = bbox
            color = colors[i % len(colors)]
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))

            if inout_scores is not None:
                inout_score = inout_scores[i]
                text = f"in-frame: {inout_score:.2f}"
                text_height = int(height * 0.01)
                text_x = xmin * width
                text_y = ymax * height + text_height
                draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

            # only draw a gaze point for people predicted to be looking inside the frame
            if inout_scores is not None and inout_score > inout_thresh:
                heatmap = heatmaps[i]
                heatmap_np = heatmap.detach().cpu().numpy()
                # the argmax of the heatmap is the predicted gaze target;
                # rescale its 64x64 grid coordinates to image pixels
                max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
                gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
                gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                draw.ellipse([(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)], fill=color)
                draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005 * min(width, height)))

        return overlay_image
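
    # draw every person's box, in-frame score, and (for gazes above the 0.5
    # threshold) the predicted gaze point on the single input image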
    result_gazed = visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5)

    return result_gazed


with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                result = gr.Image(label="Result")

        submit_button.click(
            fn=main,
            inputs=[input_image],
            outputs=[result]
        )
demo.queue().launch(show_api=False, show_error=True)
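
# To run locally, an assumed (unpinned) dependency sketch; this commit does
# not include a requirements.txt:
#   pip install gradio torch matplotlib pillow numpy retina-face
# Note: ImageFont.load_default(size=...) requires Pillow >= 10.1, and the
# retina-face package pulls in TensorFlow for its detector.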