import gradio as gr
import torch
from models.pretrained_decv2 import enc_dec_model
from models.densenet_v2 import Densenet
from models.unet_resnet18 import ResNet18UNet
from models.unet_resnet50 import UNetWithResnet50Encoder
import numpy as np
import cv2


# kb cropping
def cropping(img):
    """Return a 352 x 1216 (rows x columns) window cut out of ``img``.

    The window covers the bottom 352 rows and is centred horizontally.
    The input size is not validated: for an image smaller than the window
    the computed margins become negative and plain slice semantics apply.

    Args:
        img: Array whose first two axes are (height, width).

    Returns:
        The sliced view ``img[top:top + 352, left:left + 1216]``.
    """
    height, width = img.shape[:2]
    top = int(height - 352)
    left = int((width - 1216) / 2)
    return img[top:top + 352, left:left + 1216]


def load_model(ckpt, model, optimizer=None):
    """Load checkpoint weights into ``model`` (and optionally ``optimizer``).

    Two checkpoint layouts are accepted:

    * a bare state dict (neither a ``'model'`` nor an ``'optimizer'`` key), or
    * a dict that stores the weights under ``'model'``.

    A leading ``'module.'`` on any parameter name is removed before the
    weights are handed to ``model.load_state_dict``.

    Args:
        ckpt: Path (or file-like object) passed straight to ``torch.load``.
        model: Object exposing ``load_state_dict``; updated in place.
        optimizer: Optional object exposing ``load_state_dict``; when given,
            it is restored from the checkpoint's ``'optimizer'`` entry.
    """
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location='cpu')

    # keep backward compatibility
    is_wrapped = 'model' in checkpoint or 'optimizer' in checkpoint
    raw_state = checkpoint['model'] if is_wrapped else checkpoint

    prefix = 'module.'
    cleaned = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in raw_state.items()
    }
    model.load_state_dict(cleaned)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])


# Run on the GPU whenever one is visible to torch, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

# Base directory used when building checkpoint paths.
CWD = "."
# Checkpoint file name (looked up under ``{CWD}/ckpt``) for every
# (location type, model name) combination offered in the UI.
CKPT_FILE_NAMES = {
    'Indoor': {
        'Resnet_enc': 'resnet_nyu_best.ckpt',
        'Unet': 'resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt',
        'Densenet_enc': 'densenet_epoch_15_model.ckpt',
    },
    'Outdoor': {
        'Resnet_enc': 'resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt',
        'Unet': 'resnet50_unet_epoch_02_model_nyuandkitti.ckpt',
        'Densenet_enc': 'densenet_nyu_then_kitti_epoch_10_model.ckpt',
    },
}

# NOTE: despite the name, the values are model *instances*, built once at
# import time: max_depth=10 for 'Indoor', max_depth=80 for 'Outdoor'.
MODEL_CLASSES = {
    'Indoor': {
        'Resnet_enc': enc_dec_model(max_depth=10),
        'Unet': ResNet18UNet(max_depth=10),
        'Densenet_enc': Densenet(max_depth=10),
    },
    'Outdoor': {
        'Resnet_enc': enc_dec_model(max_depth=80),
        'Unet': UNetWithResnet50Encoder(max_depth=80),
        'Densenet_enc': Densenet(max_depth=80),
    },
}

location_types = ['Indoor', 'Outdoor']
Models = ['Resnet_enc', 'Unet', 'Densenet_enc']

# Load every checkpoint once at start-up so requests never touch the disk.
for location in location_types:
    for model in Models:
        ckpt_dir = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model]}"
        load_model(ckpt_dir, MODEL_CLASSES[location][model])
        # The networks are only used for inference: switch them out of
        # training mode so that any BatchNorm / Dropout layers use their
        # inference behaviour.
        MODEL_CLASSES[location][model].eval()


def predict(location, model_name, img):
    """Run the selected network on ``img`` and return a colour depth map.

    Args:
        location: Key into ``MODEL_CLASSES`` (``'Indoor'`` or ``'Outdoor'``).
        model_name: Key into ``MODEL_CLASSES[location]``.
        img: Image array supplied by the ``gr.Image`` input; it must have
            three axes ordered (height, width, channels) because it is
            permuted to channels-first. ``None`` when no image was provided.

    Returns:
        A ``uint8`` RGB image: the prediction scaled to 0-255 and rendered
        with OpenCV's RAINBOW colour map.

    Raises:
        gr.Error: If no input image was provided.
    """
    if img is None:
        # Clicking the button with an empty image box used to crash with an
        # AttributeError; report it to the user instead.
        raise gr.Error("Please provide an input image before generating a depth map.")

    model = MODEL_CLASSES[location][model_name].to(DEVICE)
    model.eval()  # idempotent; guarantees inference mode for this call

    # Only images of exactly this size are reduced with cropping();
    # every other size is fed to the network unchanged.
    if img.shape == (375, 1242, 3):
        img = cropping(img)

    # HWC array -> CHW float tensor with a leading batch axis of size 1.
    # NOTE(review): pixel values are passed through unscaled - confirm this
    # matches the preprocessing used when the checkpoints were trained.
    img = torch.tensor(img).permute(2, 0, 1).float().to(DEVICE)
    input_RGB = img.unsqueeze(0)
    print(input_RGB.shape)

    with torch.no_grad():
        pred = model(input_RGB)
    pred_d = pred['pred_d']
    pred_d_numpy = pred_d.squeeze().cpu().numpy()

    # Scale by the peak found from row 15 downwards (the first 15 rows are
    # excluded - presumably to ignore border artefacts; TODO confirm), then
    # clip so larger values in the excluded rows saturate at 255.
    peak = pred_d_numpy[15:, :].max()
    if not (np.isfinite(peak) and peak > 0):
        # Avoid dividing by zero / NaN, which would yield an undefined
        # uint8 image; fall back to an unscaled prediction.
        peak = 1.0
    pred_d_numpy = np.clip((pred_d_numpy / peak) * 255, 0, 255)
    pred_d_numpy = pred_d_numpy.astype(np.uint8)

    # applyColorMap yields BGR; convert to RGB for display.
    pred_d_color = cv2.applyColorMap(pred_d_numpy, cv2.COLORMAP_RAINBOW)
    pred_d_color = cv2.cvtColor(pred_d_color, cv2.COLOR_BGR2RGB)
    return pred_d_color


# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Monocular Depth Estimation")

    with gr.Row():
        location = gr.Radio(choices=['Indoor', 'Outdoor'], value='Indoor', label="Select Location Type")
        model_name = gr.Radio(['Unet', 'Resnet_enc', 'Densenet_enc'], value="Densenet_enc", label="Select model")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image for Depth Estimation")
        with gr.Column():
            output_depth_map = gr.Image(label="Depth prediction Heatmap")

    with gr.Row():
        predict_btn = gr.Button("Generate Depthmap")
        predict_btn.click(fn=predict, inputs=[location, model_name, input_image], outputs=output_depth_map)

    with gr.Row():
        gr.Examples(
            [
                './demo_data/Bathroom.jpg',
                './demo_data/Bedroom.jpg',
                './demo_data/Bookstore.jpg',
                './demo_data/Classroom.jpg',
                './demo_data/Computerlab.jpg',
                './demo_data/kitti_1.png',
            ],
            inputs=input_image,
        )

demo.launch()