# Spaces: Runtime error
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from models.modelNetA import Generator as GA
from models.modelNetB import Generator as GB
from models.modelNetC import Generator as GC
# load model | |
modeltype2path = { | |
'ModelA': 'DTM_exp_train10%_model_a/g-best.pth', | |
'ModelB': 'DTM_exp_train10%_model_b/g-best.pth', | |
'ModelC': 'DTM_exp_train10%_model_c/g-best.pth', | |
} | |
DEVICE='cpu' | |
MODELS_TYPE = list(modeltype2path.keys()) | |
generators = [GA(), GB(), GC()] | |
for i in range(len(generators)): | |
generators[i] = torch.nn.DataParallel(generators[i]) | |
state_dict = torch.load(modeltype2path[MODELS_TYPE[i]], map_location=torch.device('cpu')) | |
generators[i].load_state_dict(state_dict) | |
generators[i] = generators[i].module.to(DEVICE) | |
generators[i].eval() | |
preprocess = transforms.Compose([ | |
transforms.Grayscale(), | |
transforms.ToTensor() | |
]) | |
def predict(input_image, model_name): | |
pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB') | |
# transform image to torch and do preprocessing | |
torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0).to(DEVICE) | |
# model predict | |
with torch.no_grad(): | |
output = generators[MODELS_TYPE.index(model_name)](torch_img) | |
sr, sr_dem_selected = output[0], output[1] | |
# transform torch to image | |
sr = sr.squeeze(0).cpu() | |
torchvision.utils.save_image(sr, 'sr_pred.png') | |
sr = np.array(Image.open('sr_pred.png')) | |
sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy() | |
fig, ax = plt.subplots() | |
im = ax.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected)) | |
plt.colorbar(im, ax=ax) | |
fig.canvas.draw() | |
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
# return correct image | |
return sr, data | |
iface = gr.Interface( | |
fn=predict, | |
inputs=[ | |
gr.Image(shape=(512,512)), | |
gr.inputs.Radio(MODELS_TYPE) | |
], | |
outputs=[ | |
gr.Image(), | |
gr.Image() | |
], | |
examples=[ | |
["demo_imgs/fake.jpg", MODELS_TYPE[0]] # use real image | |
], | |
title="DTM Estimation", | |
description="This demo predict a DTM..." | |
) | |
iface.launch() |