File size: 4,058 Bytes
54d726d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import functools

import cv2
import gradio as gr
import numpy as np
import torch

from models.densenet_v2 import Densenet
from models.pretrained_decv2 import enc_dec_model
from models.unet_resnet18 import ResNet18UNet
from models.unet_resnet50 import UNetWithResnet50Encoder

def cropping(img, crop_height=352, crop_width=1216):
    """Apply the "kb" crop (presumably KITTI-benchmark) to an image.

    Keeps the bottom ``crop_height`` rows and a horizontally centred window
    of ``crop_width`` columns. Any trailing axes (e.g. colour channels) are
    kept untouched. The result is a view into ``img``, not a copy.

    Args:
        img: array indexed as ``img[row, col, ...]``.
        crop_height: number of rows to keep, taken from the bottom.
        crop_width: number of columns to keep, centred horizontally.

    Returns:
        The ``crop_height`` x ``crop_width`` cropped array.

    Raises:
        ValueError: if ``img`` is smaller than the crop window. Otherwise the
            margins would go negative and the slice would silently select the
            wrong region.
    """
    h_im, w_im = img.shape[:2]
    if h_im < crop_height or w_im < crop_width:
        raise ValueError(
            f"image of size {h_im}x{w_im} is smaller than the "
            f"{crop_height}x{crop_width} crop window"
        )

    # Bottom-anchored vertically, centred horizontally. When the width excess
    # is odd, the extra column is dropped on the right.
    margin_top = h_im - crop_height
    margin_left = (w_im - crop_width) // 2

    return img[margin_top:margin_top + crop_height,
               margin_left:margin_left + crop_width]

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

# Base directory; checkpoints are read from "{CWD}/ckpt/<file name>".
CWD = "."

# Checkpoint file name for every (location, model) selection offered by the UI.
CKPT_FILE_NAMES = {
    'Indoor': {
        'Resnet_enc': 'resnet_nyu_best.ckpt',
        'Unet': 'resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt',
        'Densenet_enc': 'densenet_epoch_15_model.ckpt',
    },
    'Outdoor': {
        'Resnet_enc': 'resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt',
        'Unet': 'resnet50_unet_epoch_02_model_nyuandkitti.ckpt',
        'Densenet_enc': 'densenet_nyu_then_kitti_epoch_10_model.ckpt',
    },
}

# Model class to instantiate for every (location, model) selection. Only the
# 'Unet' entry differs: ResNet18UNet indoors, UNetWithResnet50Encoder outdoors.
MODEL_CLASSES = {
    'Indoor': dict(
        Resnet_enc=enc_dec_model,
        Unet=ResNet18UNet,
        Densenet_enc=Densenet,
    ),
    'Outdoor': dict(
        Resnet_enc=enc_dec_model,
        Unet=UNetWithResnet50Encoder,
        Densenet_enc=Densenet,
    ),
}

def load_model(ckpt, model, optimizer=None):
    """Load weights (and optionally optimizer state) from a checkpoint file.

    The file is loaded onto the CPU. It may be either a bare state dict or a
    dict holding the state dict under ``'model'`` (and optimizer state under
    ``'optimizer'``). A leading ``'module.'`` is stripped from every parameter
    name, presumably left by a DataParallel-style wrapper at save time.

    Args:
        ckpt: path to the checkpoint file passed to ``torch.load``.
        model: module whose ``load_state_dict`` receives the weights.
        optimizer: if given, restored from ``checkpoint['optimizer']``.

    Raises:
        KeyError: if ``optimizer`` is given but the checkpoint has no
            ``'optimizer'`` entry, or if it has ``'optimizer'`` but no
            ``'model'`` entry.
    """
    checkpoint = torch.load(ckpt, map_location='cpu')

    # Backward compatibility: older files are a bare state dict with neither
    # wrapper key present.
    is_bare_state_dict = not ('model' in checkpoint or 'optimizer' in checkpoint)
    raw_weights = checkpoint if is_bare_state_dict else checkpoint['model']

    prefix = 'module.'
    cleaned = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in raw_weights.items()
    }
    model.load_state_dict(cleaned)

    if optimizer is None:
        return
    optimizer.load_state_dict(checkpoint['optimizer'])


@functools.lru_cache(maxsize=1)
def _get_model(location, model_name):
    """Return the ready-to-run depth model for a UI selection.

    Builds ``MODEL_CLASSES[location][model_name]`` on ``DEVICE``, loads the
    matching checkpoint from ``{CWD}/ckpt/`` and switches the model to eval
    mode. The result is cached for the most recent (location, model_name)
    pair. Repeated clicks with an unchanged selection therefore skip
    rebuilding the network and re-reading the checkpoint, and the single
    slot keeps at most one model alive.

    Raises:
        KeyError: if ``location``/``model_name`` is not a known selection.
    """
    ckpt_path = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model_name]}"
    # NOTE(review): every model is built with max_depth=80. A value of 10 for
    # NYU/indoor data was apparently intended, but it was keyed on
    # location == 'nyu', which is not a selectable location, so 80 has always
    # been used. The heat-map is normalised by its peak (see _depth_to_color),
    # so this only matters if a model uses max_depth non-linearly. TODO:
    # confirm the value the Indoor checkpoints were trained with.
    max_depth = 80
    model = MODEL_CLASSES[location][model_name](max_depth).to(DEVICE)
    load_model(ckpt_path, model)
    # Inference mode. In training mode, any BatchNorm layers would normalise
    # with the statistics of the single input image (and update their running
    # stats), and any Dropout layers would stay active.
    model.eval()
    return model


def _depth_to_color(depth):
    """Map a 2-D depth array to an RGB uint8 heat-map image (COLORMAP_RAINBOW).

    Values are scaled so the peak depth below the first 15 rows maps to 255,
    and anything above that peak is clipped.
    """
    # The first 15 rows are excluded when picking the peak (presumably to
    # ignore top-border artefacts; TODO confirm). Very short maps use all
    # rows so .max() never sees an empty slice.
    reference = depth[15:, :] if depth.shape[0] > 15 else depth
    peak = reference.max()
    if peak > 0:
        scaled = np.clip((depth / peak) * 255, 0, 255)
    else:
        # A non-positive peak would divide by zero (NaNs) or flip signs;
        # render a uniform image instead.
        scaled = np.zeros_like(depth)
    color_bgr = cv2.applyColorMap(scaled.astype(np.uint8), cv2.COLORMAP_RAINBOW)
    return cv2.cvtColor(color_bgr, cv2.COLOR_BGR2RGB)


def predict(location, model_name, img):
    """Gradio callback: predict a depth map for ``img`` and return it as a heat-map.

    Args:
        location: 'Indoor' or 'Outdoor' (the keys of ``CKPT_FILE_NAMES``).
        model_name: 'Unet', 'Resnet_enc' or 'Densenet_enc'.
        img: image from ``gr.Image``. Assumed to be an HxWx3 numpy array
            (Gradio's default numpy RGB output; TODO confirm), or None when
            no image was provided.

    Returns:
        An RGB uint8 heat-map of the predicted depth.

    Raises:
        gr.Error: if no input image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    model = _get_model(location, model_name)

    # Frames of exactly 375x1242x3 (presumably raw KITTI size) get the "kb" crop.
    if img.shape == (375, 1242, 3):
        img = cropping(img)
    # HWC -> CHW float tensor, plus a leading batch dimension of 1.
    input_RGB = torch.tensor(img).permute(2, 0, 1).float().unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        pred_d = model(input_RGB)['pred_d']
    return _depth_to_color(pred_d.squeeze().cpu().numpy())

# Example images offered below the input widget.
_EXAMPLE_IMAGES = [
    './demo_data/Bathroom.jpg',
    './demo_data/Bedroom.jpg',
    './demo_data/Bookstore.jpg',
    './demo_data/Classroom.jpg',
    './demo_data/Computerlab.jpg',
    './demo_data/kitti_1.png',
]

with gr.Blocks() as demo:
    gr.Markdown("# Monocular Depth Estimation")
    with gr.Row():
        location = gr.Radio(choices=['Indoor', 'Outdoor'], value='Indoor', label="Select Location Type")
        model_name = gr.Radio(['Unet', 'Resnet_enc', 'Densenet_enc'], value="Densenet_enc", label="Select model")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image for Depth Estimation")
        with gr.Column():
            output_depth_map = gr.Image(label="Depth prediction Heatmap")
    with gr.Row():
        predict_btn = gr.Button("Generate Depthmap")
        predict_btn.click(fn=predict, inputs=[location, model_name, input_image], outputs=output_depth_map)
    with gr.Row():
        gr.Examples(_EXAMPLE_IMAGES, inputs=input_image)

# Start the (blocking) server only when run as a script, so importing this
# module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()