|
import os
import shutil
import subprocess
import sys

import gradio as gr
|
|
|
|
|
# HTML header rendered above the app (passed as ``description`` to gr.Interface):
# project badges followed by a one-paragraph usage hint.
desc = """
<p align="center">
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arxiv">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore the examples listed at the bottom.
</p>
"""
|
|
|
|
|
def init_persistence(purge=False, data_dir="/data"):
    """Redirect checkpoint and Hugging Face cache locations to persistent storage.

    Does nothing when ``data_dir`` does not exist. Otherwise sets the
    ``ckpt_dir``, ``TRANSFORMERS_CACHE``, ``HF_DATASETS_CACHE`` and ``HF_HOME``
    environment variables to paths under ``data_dir``.

    Args:
        purge: If True, delete the contents of ``<data_dir>/Marigold_ckpt``
            while keeping the directory itself. Hidden entries are left in
            place, matching the shell glob ``*`` the original used.
        data_dir: Root of the persistent storage. Defaults to ``/data``.
    """
    if not os.path.exists(data_dir):
        return

    ckpt_dir = os.path.join(data_dir, "Marigold_ckpt")
    hf_cache = os.path.join(data_dir, "hfcache")

    os.environ['ckpt_dir'] = ckpt_dir
    os.environ['TRANSFORMERS_CACHE'] = hf_cache
    os.environ['HF_DATASETS_CACHE'] = hf_cache
    os.environ['HF_HOME'] = hf_cache

    if purge and os.path.isdir(ckpt_dir):
        # Remove entries in-process instead of shelling out to ``rm -rf``, so
        # failures raise instead of being silently ignored.
        for entry in os.scandir(ckpt_dir):
            if entry.name.startswith('.'):
                continue
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                # Regular files and symlinks (the link itself, not its target).
                os.remove(entry.path)
|
|
|
|
|
def download_code_weights():
    """Clone the Marigold repository and run its weight download script.

    Then prints directory listings of the ``/data`` checkpoint locations for
    debugging. This is best-effort: a failing step is reported on stderr
    instead of being silently ignored, and the function never raises. Later
    inference calls will fail if the code or weights are missing.
    """

    def _run(cmd, cwd=None):
        # Run without a shell and surface failures instead of discarding them.
        try:
            returncode = subprocess.run(cmd, cwd=cwd).returncode
        except OSError as err:
            print(f"WARNING: could not run {cmd}: {err}", file=sys.stderr, flush=True)
            return False
        if returncode != 0:
            print(f"WARNING: {cmd} exited with status {returncode}", file=sys.stderr, flush=True)
            return False
        return True

    # ``git clone`` refuses to clone into an existing directory; skip the clone
    # when the repository is already present so repeated runs stay quiet.
    if not os.path.isdir(os.path.join('Marigold', '.git')):
        _run(['git', 'clone', 'https://github.com/prs-eth/Marigold.git'])

    _run(['bash', 'script/download_weights.sh'], cwd='Marigold')

    # Debug output: show what ended up in persistent storage.
    for path in ('/data', '/data/Marigold_ckpt', '/data/Marigold_ckpt/Marigold_v1_merged'):
        print(path, flush=True)
        try:
            subprocess.run(['ls', '-la', path])
        except OSError as err:
            print(f"WARNING: could not list {path}: {err}", file=sys.stderr, flush=True)
|
|
|
|
|
def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or None.

    The extension check is case-insensitive. Entries are visited in sorted name
    order, so the result is deterministic (``os.listdir`` order is arbitrary).
    Sub-directories are ignored. A missing directory yields None instead of
    raising, so callers can treat it the same as "no output produced".
    """
    if not os.path.isdir(directory):
        return None
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if name.lower().endswith(".png") and os.path.isfile(path):
            return path
    return None
|
|
|
|
|
def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold depth estimation on one image and return the result paths.

    The input image is copied into ``<path_input>.input``. ``Marigold/run.py``
    is then executed, writing its results under ``<path_input>.output``.

    Args:
        path_input: Path of the input image (a file path from the Gradio
            ``Image`` component).
        path_out_png: Optional precomputed colored depth map. When given
            together with ``path_out_obj``, inference is skipped and both are
            returned unchanged. Presumably this is for pre-rendered examples;
            confirm against the examples list.
        path_out_obj: Optional precomputed 16-bit depth PNG (see above).
        path_out_2_png: Unused; accepted only for backward compatibility.

    Returns:
        Tuple ``(colored_depth_png_path, bw_depth_png_path)``, matching the two
        output components of the interface.

    Raises:
        gr.Error: If Marigold could not be started or produced no output.
    """
    # Return exactly two values, one per output component.
    if path_out_png is not None and path_out_obj is not None:
        return path_out_png, path_out_obj

    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Build an argument list rather than a shell string: path_input comes from
    # a user upload and must never be interpreted by a shell.
    cmd = ["python3", "run.py"]
    if os.path.exists('/data'):
        cmd += ["--checkpoint", "/data/Marigold_ckpt/Marigold_v1_merged"]
    cmd += [
        # Absolute paths, because run.py executes with Marigold/ as its cwd.
        "--input_rgb_dir", os.path.abspath(path_input_dir),
        "--output_dir", os.path.abspath(path_output_dir),
        "--n_infer", "10",
        "--denoise_steps", "10",
    ]
    try:
        subprocess.run(cmd, cwd="Marigold")
    except OSError as err:
        raise gr.Error("Processing failed") from err

    def _first_png_in(subdir):
        # A missing subdirectory means the run failed before writing output.
        directory = os.path.join(path_output_dir, subdir)
        return find_first_png(directory) if os.path.isdir(directory) else None

    path_out_colored = _first_png_in("depth_colored")
    path_out_bw = _first_png_in("depth_bw")

    # Explicit raise instead of ``assert``, which ``python -O`` strips.
    if path_out_colored is None or path_out_bw is None:
        raise gr.Error("Processing failed")

    return path_out_colored, path_out_bw
|
|
|
|
|
# Input components, created before the outputs (same order as before). Only
# the image is visible; the two File inputs are hidden (visible=False) and
# carry the same labels as the two outputs.
_input_components = [
    gr.Image(label="Input Image", type="filepath"),
    gr.File(label="Predicted depth (red-near, blue-far)", visible=False),
    gr.File(label="Predicted depth (16-bit PNG)", visible=False),
]

# Output components; the second one is styled by the ``imgdownload`` CSS rule.
_output_components = [
    gr.Image(label="Predicted depth (red-near, blue-far)", type="pil"),
    gr.Image(
        label="Predicted depth (16-bit PNG)",
        type="pil",
        elem_classes="imgdownload",
    ),
]

# Custom page CSS. ``.viewport`` is not referenced by any component here.
_CSS = """
.viewport {
    aspect-ratio: 4/3;
}
.imgdownload {
    height: 32px;
}
"""

iface = gr.Interface(
    fn=marigold_process,
    inputs=_input_components,
    outputs=_output_components,
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    css=_CSS,
    allow_flagging="never",
    cache_examples=True,
)
|
|
|
|
|
def _main():
    """Set up storage and code/weights, then serve the interface on port 7860."""
    init_persistence()
    download_code_weights()
    # Enable Gradio's request queue, then listen on all interfaces.
    iface.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _main()
|
|