File size: 2,108 Bytes
8c753d1
 
b4eade4
8c753d1
b4eade4
 
 
 
 
e3bb30a
8c753d1
 
b4eade4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c753d1
 
 
 
b4eade4
8c753d1
b4eade4
 
8c753d1
b4eade4
8c753d1
b4eade4
 
 
 
 
 
 
 
a05fab5
8c753d1
 
 
 
f5085fa
 
a05fab5
f5085fa
8c753d1
e3bb30a
8c753d1
 
f5085fa
8c753d1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
from PIL import Image
from collections import OrderedDict
import torch
from models.model import GLPDepth
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import os

# Device on which the model and its input tensors are placed.
DEVICE = 'cpu'


def load_mde_model(path):
    """Build a GLPDepth network and load its weights from a checkpoint.

    Args:
        path: Location of a checkpoint readable by ``torch.load`` that
            contains a ``'model_state_dict'`` entry.

    Returns:
        The GLPDepth model on ``DEVICE``, with weights loaded and switched
        to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should ever be passed here.
    checkpoint = torch.load(path, map_location=torch.device(DEVICE))
    state_dict = checkpoint['model_state_dict']

    # Drop a leading "module." from parameter names (presumably left by a
    # DataParallel-style wrapper at training time -- confirm). Matching the
    # exact prefix per key avoids truncating names that merely contain the
    # word "module" somewhere else, and copes with an empty state dict.
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model

# Load the network once at import time so every call to predict() reuses it.
model = load_mde_model('best_model.ckpt')
# Input pipeline applied to each image before inference: resize to 512x512,
# then convert the PIL image to a torch tensor.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
]) 

def predict(input_image):
    """Run the depth model on an image and render the result as a colour plot.

    Args:
        input_image: Image array (presumably the NumPy array delivered by the
            Gradio image input -- confirm). It is cast to ``uint8`` and
            converted to RGB before preprocessing.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width, 3)`` holding the
        rendered Matplotlib figure: the predicted map drawn with the ``jet``
        colormap plus a colorbar.

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise ValueError("No input image was provided.")

    # Let PIL infer the mode from the array shape and convert explicitly, so
    # grayscale or RGBA inputs are not misread as packed RGB data.
    pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')
    # Preprocess and add the batch dimension expected by the model call.
    torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_patch = model(torch_img)
    predicted_image = output_patch['pred_d'].squeeze().cpu().detach().numpy()

    fig, ax = plt.subplots()
    try:
        im = ax.imshow(
            predicted_image,
            cmap='jet',
            vmin=0,
            vmax=np.max(predicted_image),
        )
        # Use the figure's own method rather than pyplot's "current figure".
        fig.colorbar(im, ax=ax)

        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); drop the alpha
        # channel and copy so the pixels outlive the figure closed below.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = np.ascontiguousarray(rgba[..., :3])
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each request would leak one figure.
        plt.close(fig)

    return data

# Web UI: one image input, one image output (the rendered colour plot), with
# every file found in demo_imgs/ offered as a clickable example.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(512,512)),
    outputs=[
        gr.Image(shape=(512,512)),
    ],
    examples=[
        # os.listdir returns entries in arbitrary order; sort them so the
        # examples gallery is laid out the same way on every start-up.
        [f"demo_imgs/{name}"] for name in sorted(os.listdir('demo_imgs'))
    ],
    title="DTM Estimation",
    description="This demo predict a DTM using GLP Depth model. It will scale input image to 512x512 and at the end it will apply a colormap to better visualize the output."
)
# Start the Gradio server.
iface.launch()