Spaces: Running on Zero
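The Space's app.py, a ZeroGPU Gradio demo of Colorful Diffuse Intrinsic Image Decomposition in the Wild: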
import spaces
import gradio as gr
import numpy as np
import torch

from chrislib.general import uninvert, invert, view, view_scale
from intrinsic.pipeline import load_models, run_pipeline

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instead of loading models at startup, cache them and load lazily on first use
# so the Space starts quickly.
model_cache = {}

def get_model(model_version):
    if model_version not in model_cache:
        model_cache[model_version] = load_models(model_version, device=DEVICE)
    return model_cache[model_version]
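# The first call for a given version loads the weights; later calls hit the cache:
#   models = get_model('v2')  # loads v2 weights once
#   models = get_model('v2')  # returns the cached models immediately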

def generate_pipeline(models):
    def pipeline_func(image, **kwargs):
        return run_pipeline(models, image, **kwargs)
    return pipeline_func
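# pipeline_func closes over `models`, so downstream code can run the
# pipeline without passing the model dict through every call site.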

# On ZeroGPU Spaces, CUDA work must happen inside a function decorated
# with @spaces.GPU, which attaches a GPU for the duration of the call.
@spaces.GPU
def process_image(image, model_version):
    # Nothing to process until an image is provided.
    if image is None:
        return [None, None, None]

    print(f"Processing with model version: {model_version}")
    print(image.shape)

    # Convert the uint8 input to float32 in [0, 1].
    image = image.astype(np.single) / 255.

    # Get (or lazily load) the selected model and run the pipeline.
    models = get_model(model_version)
    pipeline_func = generate_pipeline(models)
    result = pipeline_func(image, device=DEVICE, resize_conf=1024)

    # Albedo, diffuse shading, and diffuse image, matching the UI outputs below.
    return [view(result['hr_alb']), 1 - invert(result['dif_shd']), view_scale(result['pos_res'])]
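# Hypothetical local usage, bypassing the UI (assumes an image file and that
# imageio is installed; neither is part of this Space):
#   import imageio.v3 as iio
#   rgb = iio.imread('photo.jpg')  # HxWx3 uint8 array
#   albedo, shading, diffuse = process_image(rgb, 'v2')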

with gr.Blocks(
    css="""
    #download {
        height: 118px;
    }
    .slider .inner {
        width: 5px;
        background: #FFF;
    }
    .viewport {
        aspect-ratio: 4/3;
    }
    .tabs button.selected {
        font-size: 20px !important;
        color: crimson !important;
    }
    h1, h2, h3 {
        text-align: center;
        display: block;
    }
    .md_feedback li {
        margin-bottom: 0px !important;
    }
    .image-gallery {
        display: flex;
        flex-wrap: wrap;
        gap: 10px;
        justify-content: center;
    }
    .image-gallery > * {
        flex: 1;
        min-width: 200px;
    }
    """,
) as demo:
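    # The CSS above targets components via element IDs and elem_classes
    # (e.g. the .image-gallery class on the results row below).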
    gr.Markdown(
        """
        # Colorful Diffuse Intrinsic Image Decomposition in the Wild
        <p align="center">
        <a title="Website" href="https://yaksoy.github.io/ColorfulShading/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
        </a>
        <a title="Github" href="https://github.com/compphoto/Intrinsic" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/github/stars/compphoto/Intrinsic?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
        </a>
        </p>
        """
    )
    # Model version selector with an information panel.
    with gr.Row():
        model_version = gr.Dropdown(
            choices=["v2", "v2.1"],
            value="v2",
            label="Model Version",
            info="Select which model weights to use",
            scale=1
        )
        gr.Markdown("""
            The model may take a few seconds to load the first time you use it.
            Subsequent decompositions will be faster once the model is cached.
        """)
    # Gallery-style layout for the input and the three decomposition outputs.
    with gr.Row(elem_classes="image-gallery"):
        input_img = gr.Image(label="Input Image")
        alb_img = gr.Image(label="Albedo")
        shd_img = gr.Image(label="Diffuse Shading")
        dif_img = gr.Image(label="Diffuse Image")
    # Re-run the decomposition when the input image changes...
    input_img.change(
        process_image,
        inputs=[input_img, model_version],
        outputs=[alb_img, shd_img, dif_img]
    )

    # ...and when a different model version is selected.
    model_version.change(
        process_image,
        inputs=[input_img, model_version],
        outputs=[alb_img, shd_img, dif_img]
    )

demo.launch(show_error=True)
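# Note: @spaces.GPU is a no-op outside Hugging Face Spaces, so this script
# should also run locally, using CUDA or CPU as selected by DEVICE above.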