|
import os |
|
import gradio as gr |
|
import cv2 |
|
import torch |
|
import urllib.request |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import subprocess |
|
|
|
def calculate_depth(model_type, gan_type, dim, slider, img):
    """Estimate a depth map with MiDaS, then run the WiggleGAN scripts on it.

    The input image and its depth map are written to ``Images/Input-Test``
    (as ``1.png`` and ``1_d.png``), after which ``main.py`` and
    ``WiggleResults/split.py`` are executed as child processes.

    Args:
        model_type: MiDaS entry point name passed to ``torch.hub.load``
            (``"DPT_Large"``, ``"DPT_Hybrid"`` or anything else, which is
            treated as a "small" model when picking the transform).
        gan_type: Key selecting the saved GAN checkpoint id that is passed to
            ``main.py`` as ``--seedLoad``.
        dim: Image dimension forwarded to both scripts (``--imageDim`` /
            ``--dim``); converted with ``str()`` before use.
        slider: Value forwarded to ``main.py`` as ``--wiggleDepth``.
        img: Object exposing ``save(path, format)`` (a PIL image when called
            from the UI).

    Returns:
        A list ``[depth_image, gif_path, mp4_path, jpg_path]`` where
        ``depth_image`` is a PIL image and the paths point into
        ``WiggleResults/``.

    Raises:
        ValueError: If ``img`` is ``None``.
        KeyError: If ``gan_type`` is not a known checkpoint name.
        RuntimeError: If the saved input image cannot be read back.
        subprocess.CalledProcessError: If either child script exits non-zero.
    """
    # Function-scope import: only needed to locate the running interpreter.
    import sys

    if img is None:
        raise ValueError("An input image is required.")

    dict_saved_gans = {
        'Cycle': '74962_110',
        'Cycle(half)': '66942',
        'noCycle': '31219_110',
        'noCycle-noCr': '92332_110',
        'noCycle-noCr-noL1': '82122_110',
        'OnlyGen': '70944_110',
    }
    # Resolve the checkpoint first so an unknown name fails before the
    # (expensive) depth model is downloaded and evaluated.
    seed_load = dict_saved_gans[gan_type]

    # NOTE(review): 'temp' is not used in this function; presumably the child
    # scripts expect it to exist - confirm against main.py.
    os.makedirs('temp', exist_ok=True)

    filename = "Images/Input-Test/1.png"
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    img.save(filename, "PNG")

    midas = torch.hub.load("intel-isl/MiDaS", model_type)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    midas.to(device)
    midas.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    if model_type in ("DPT_Large", "DPT_Hybrid"):
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    # Re-read the file that was just written so the model sees the same
    # pixels that are stored on disk.
    img = cv2.imread(filename)
    if img is None:
        raise RuntimeError(f"Could not read back the saved input image: {filename}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    input_batch = transform(img).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)

        # Resize the prediction to the height/width of the input image.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu().numpy()

    # Scale so that the maximum maps to 255; a non-positive (or NaN) maximum
    # would otherwise divide by zero / produce NaNs, so fall back to black.
    peak = np.max(output)
    if peak > 0:
        formatted = (output * 255.0 / peak).astype('uint8')
    else:
        formatted = np.zeros(output.shape, dtype='uint8')
    out_im = Image.fromarray(formatted)
    out_im.save("Images/Input-Test/1_d.png", "PNG")

    c_images = '1'
    name_output = 'out'
    dim = str(dim)

    # Use the interpreter running this app (not whatever "python" is on PATH)
    # and fail loudly if a script errors instead of returning stale outputs.
    subprocess.run(
        [
            sys.executable, "main.py",
            "--gan_type", 'WiggleGAN',
            "--expandGen", "4",
            "--expandDis", "4",
            "--batch_size", c_images,
            "--cIm", c_images,
            "--visdom", "false",
            "--wiggleDepth", str(slider),
            "--seedLoad", seed_load,
            "--gpu_mode", "false",
            "--imageDim", dim,
            "--name_wiggle", name_output,
        ],
        check=True,
    )
    subprocess.run(
        [sys.executable, "WiggleResults/split.py", "--dim", dim],
        check=True,
    )

    return [
        out_im,
        f'WiggleResults/{name_output}_0.gif',
        f'WiggleResults/{name_output}_0.mp4',
        f'WiggleResults/{name_output}.jpg',
    ]
|
|
|
|
|
# Gradio UI: three dropdowns and a slider pick the models/settings, an image
# is uploaded, and the button runs calculate_depth to fill the four outputs.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")

    midas_models = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
    gan_models = ["Cycle", "Cycle(half)", "noCycle", "noCycle-noCr", "noCycle-noCr-noL1", "OnlyGen"]
    dim = ['256', '512', '1024']

    # The order of `inp` must match calculate_depth's positional parameters:
    # (model_type, gan_type, dim, slider, img).
    with gr.Row():
        inp = [gr.Dropdown(midas_models, value="MiDaS_small", label="Depth estimation model type")]
        inp.append(gr.Dropdown(gan_models, value="Cycle", label="Different GAN trainings"))
        inp.append(gr.Dropdown(dim, value="256", label="Wiggle dimension result"))
        # The initial value is set through `value`; `default` is not a
        # gr.Slider parameter, so the intended start of 2 was not applied.
        inp.append(gr.Slider(1, 15, value=2, label='StepCycles', step=1))
    with gr.Row():
        inp.append(gr.Image(type="pil", label="Input"))
        out = [gr.Image(type="pil", label="depth_estimation")]
    # The order of `out` must match the list returned by calculate_depth:
    # [depth image, gif path, mp4 path, jpg path].
    with gr.Row():
        out.append(gr.Image(type="file", label="Output_wiggle_gif"))
        out.append(gr.Video(type="file", label="Output_wiggle_video"))
        out.append(gr.Image(type="file", label="Output_images"))
    btn = gr.Button("Calculate depth + Wiggle")
    btn.click(fn=calculate_depth, inputs=inp, outputs=out)


demo.launch()