File size: 13,713 Bytes
57d7ed3
 
 
 
 
 
0f2d9f6
57d7ed3
0f2d9f6
57d7ed3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f2d9f6
57d7ed3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f2d9f6
 
57d7ed3
0f2d9f6
 
 
 
57d7ed3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f2d9f6
57d7ed3
 
 
 
0f2d9f6
57d7ed3
 
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d7ed3
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d7ed3
 
 
 
 
0f2d9f6
57d7ed3
 
0f2d9f6
 
 
 
 
 
 
 
 
 
57d7ed3
 
 
 
0f2d9f6
57d7ed3
 
 
 
 
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d7ed3
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d7ed3
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
57d7ed3
0f2d9f6
 
 
 
 
 
57d7ed3
0f2d9f6
57d7ed3
0f2d9f6
 
57d7ed3
0f2d9f6
 
 
 
 
57d7ed3
 
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d7ed3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# taken from: https://huggingface.co/spaces/frgfm/torch-cam/blob/main/app.py

# streamlit run app.py
from io import BytesIO
import os
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
from torchvision import transforms

from torchcam.methods import CAM
from torchcam import methods as torchcam_methods
from torchcam.utils import overlay_mask
import os.path as osp

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model
from registry_utils import import_registered_modules

import_registered_modules()
# from torchcam.methods._utils import locate_candidate_layer

# CAM methods offered by the app. Only plain CAM is enabled; the other
# torchcam methods (GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM,
# ISCAM, XGradCAM, LayerCAM) are intentionally left out.
CAM_METHODS = ["CAM"]
# Registered backbones; weights are read from pre_trained_models/<name>/.
TV_MODELS = ["ResNet18", "ResNet50"]
# Super-resolution models selectable when upscaling is enabled.
SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
# Supported upscaling factors.
UPSCALE = [2, 4]
# Interpolation-based upscaling methods selectable when upscaling is enabled.
UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
# Pupil labels; also the non-blank options of the "Pupil Selection" box.
LABEL_MAP = ["left_pupil", "right_pupil"]


@torch.no_grad()
def _load_model(model_configs, device="cpu"):
    """Build a registered model and load its trained weights.

    Args:
        model_configs: Dict holding a ``"model_path"`` entry (joined onto
            ``root_path``) plus the configuration forwarded as
            ``get_model(model_configs=...)``. The caller's dict is NOT modified.
        device: Device used as ``map_location`` and for the returned model.

    Returns:
        The model with the checkpoint's state dict loaded, moved to ``device``
        and switched to eval mode.
    """
    # Work on a copy so callers can reuse their config dict; the original
    # popped "model_path" from the caller's own dict.
    configs = dict(model_configs)
    model_path = os.path.join(root_path, configs.pop("model_path"))
    # NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
    state_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=configs)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def _display_size(width, height):
    """Return the ``(max_width, max_height)`` bound used to shrink an upload.

    Square images of at least 256 px are bounded to 256 x 256. Otherwise the
    width is capped at 640 px and the height at 480 px. The 64 / 32 px lower
    bounds only raise the requested bound; ``Image.thumbnail`` never enlarges.
    """
    if width == height and width >= 256:
        return 256, 256
    max_width, max_height = width, height
    if width >= 640:
        max_width = 640
    elif width < 64:
        max_width = 64
    if height >= 480:
        max_height = 480
    elif height < 32:
        max_height = 32
    return max_width, max_height


def _show_figure(container, fig):
    """Render a matplotlib figure in a Streamlit container, then close it.

    Streamlit re-executes the script on every interaction; closing the figure
    keeps matplotlib from accumulating open figures across reruns.
    """
    container.pyplot(fig)
    plt.close(fig)


def _build_preprocess():
    """Return the transform mapping a PIL RGB image to a 3 x 32 x 64 tensor."""
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(
                [32, 64],
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            ),
        ]
    )


def _crop_to_batch(crop, preprocess_function):
    """Turn one eye crop into a batched ``(1, 3, 32, 64)`` tensor, or None."""
    if crop is None:
        return None
    return preprocess_function(to_pil_image(crop).convert("RGB")).unsqueeze(0)


def _eye_batches(ds_results, input_img, pupil_selection, preprocess_function):
    """Return ``{"left_eye": tensor | None, "right_eye": tensor | None}``.

    When the dataset-creation step returned nothing or no ``"eyes"`` entry, the
    whole (resized) upload is used as the eye image for the selected pupil(s).
    """
    if ds_results is None or "eyes" not in ds_results:
        whole = preprocess_function(input_img).unsqueeze(0)
        return {
            "left_eye": None if pupil_selection == "right_pupil" else whole,
            "right_eye": None if pupil_selection == "left_pupil" else whole,
        }
    # presumably a dict of eye crops; a missing or None entry means "not found"
    eyes = ds_results["eyes"]
    return {
        eye_type: _crop_to_batch(eyes.get(eye_type), preprocess_function)
        for eye_type in ("left_eye", "right_eye")
    }


def _cam_target_layer(model, registered_model_name):
    """Return the last convolution of ``model.resnet.layer4`` to hook for CAM."""
    if registered_model_name == "ResNet18":
        return model.resnet.layer4[-1].conv2
    if registered_model_name == "ResNet50":
        return model.resnet.layer4[-1].conv3
    raise ValueError(
        f"No target layer available for selected model: {registered_model_name}"
    )


def _predict_and_render(container, eye_type, eye_batch, tv_model, cam_method):
    """Predict the pupil diameter (mm) for one eye and show its CAM overlay."""
    model = _load_model(
        {
            "model_path": os.path.join(
                root_path, "pre_trained_models", tv_model, f"{eye_type}.pt"
            ),
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
    )

    cam_extractor = None
    if cam_method is not None:
        # The extractor hooks the model, so it must exist before the forward pass.
        cam_extractor = getattr(torchcam_methods, cam_method)(
            model,
            target_layer=_cam_target_layer(model, tv_model),
            fc_layer=model.resnet.fc,
            input_shape=eye_batch.shape,
        )

    out = model(eye_batch)
    container.markdown(
        f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
        unsafe_allow_html=True,
    )
    if cam_extractor is None:
        return

    # Retrieve the CAM for output index 0 and fuse if several layers were hooked.
    act_maps = cam_extractor(0, out)
    activation_map = (
        act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
    )
    cam_extractor.remove_hooks()

    input_image_pil = to_pil_image(eye_batch.squeeze(0))
    result = overlay_mask(
        input_image_pil,
        to_pil_image(activation_map, mode="F"),
        alpha=0.5,
    )

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax, image, title in (
        (axs[0], input_image_pil, "Input Image"),
        (axs[1], result, "Overlayed CAM"),
    ):
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(title)
    _show_figure(container, fig)
    container.text(
        f"eye image size: {eye_batch.shape[-1]} x {eye_batch.shape[-2]}"
    )


def main():
    """Run the EyeDentify Streamlit playground.

    The user uploads a face or eye image and picks a pupil and a backbone. On
    "Predict Diameter & Compute CAM" the eyes are cropped (falling back to the
    whole image if no eyes are returned), the pupil diameter is predicted per
    selected eye, and the CAM overlay is displayed.
    """
    # Wide mode
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")

    # Designing the interface
    st.title("EyeDentify Playground")
    st.write("\n")
    cols = st.columns((1, 1))
    cols[0].header("Input image")
    cols[-1].header("Prediction")

    # Sidebar: file selection
    st.sidebar.title("Upload Face or Eye")
    # Disabling warning.
    # NOTE(review): newer Streamlit releases dropped this option — confirm the
    # pinned version still accepts it.
    st.set_option("deprecation.showfileUploaderEncoding", False)
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image", type=["png", "jpeg", "jpg"]
    )

    input_img = None
    if uploaded_file is not None:
        input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
        width, height = input_img.size
        cols[0].text(f"Input Image: {width} x {height}")
        input_img.thumbnail(_display_size(width, height))  # Bicubic resampling
        fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
        axs0.imshow(input_img)
        axs0.axis("off")
        axs0.set_title("Input Image")
        _show_figure(cols[0], fig0)
        # Report the real size: thumbnail() keeps the aspect ratio and never
        # enlarges, so the result can differ from the requested bound.
        cols[0].text(
            f"Input Image Resized: {input_img.size[0]} x {input_img.size[1]}"
        )

    st.sidebar.title("Setup")

    # Upscaling is disabled; restore this selectbox to enable it.
    upscale = "-"
    # upscale = st.sidebar.selectbox(
    #     "Upscale",
    #     ["-"] + UPSCALE,
    #     help="Upscale the uploaded image 2 or 4 times. Keep blank for no upscaling",
    # )
    upscale_method_or_model = None
    if upscale != "-":
        upscale_method_or_model = st.sidebar.selectbox(
            "Upscale Method / Model",
            UPSCALE_METHODS + SR_METHODS,
            help="Select a method or model to upscale the uploaded image",
        )

    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection",
        ["-"] + LABEL_MAP,
        help="Select left or right pupil OR keep blank for both pupil diameter estimation",
    )

    tv_model = st.sidebar.selectbox(
        "Classification model",
        TV_MODELS,
        help="Supported Models for Pupil Diameter Estimation",
    )

    # Only plain CAM is exposed; a CAM_METHODS selectbox could replace this.
    cam_method = "CAM"

    st.sidebar.write("\n")

    if not st.sidebar.button("Predict Diameter & Compute CAM"):
        return
    if uploaded_file is None:
        st.sidebar.error("Please upload an image first")
        return

    with st.spinner("Analyzing..."):
        sr_configs = (
            None
            if upscale == "-"
            else {"method": upscale_method_or_model, "params": {"upscale": upscale}}
        )
        ds_results = EyeDentityDatasetCreation(
            feature_extraction_configs={
                "blink_detection": False,
                "upscale": upscale,
                "extraction_library": "mediapipe",
            },
            sr_configs=sr_configs,
        )(np.array(input_img))

        eye_batches = _eye_batches(
            ds_results, input_img, pupil_selection, _build_preprocess()
        )
        selected_eyes = {
            "-": ["left_eye", "right_eye"],
            "left_pupil": ["left_eye"],
            "right_pupil": ["right_eye"],
        }[pupil_selection]

        for eye_type in selected_eyes:
            eye_batch = eye_batches[eye_type]
            if eye_batch is None:
                # Report the missing eye instead of aborting the remaining ones.
                cols[-1].error(
                    f"No {eye_type.replace('_', ' ')} found in the uploaded image."
                )
                continue
            _predict_and_render(cols[-1], eye_type, eye_batch, tv_model, cam_method)


if __name__ == "__main__":
    # Script entry point; the app is launched with `streamlit run app.py`.
    main()