Intrinsic / app.py
ccareaga's picture
loading models at start up
a3aeb1a
raw
history blame
4.07 kB
import spaces
import gradio as gr
import numpy as np
import torch
from chrislib.general import uninvert, invert, view, view_scale
from intrinsic.pipeline import load_models, run_pipeline
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load every supported set of weights once, at startup, and keep the loaded
# models keyed by version string so each request is only a dict lookup.
# The keys must stay in sync with the choices of the "Model Version" dropdown.
model_cache = {
    version: load_models(version, device=DEVICE)
    for version in ('v2', 'v2.1')
}
def generate_pipeline(models):
    """Return a callable that runs ``run_pipeline`` with ``models`` pre-bound.

    The returned function accepts an image plus arbitrary keyword arguments
    and forwards them unchanged as ``run_pipeline(models, image, **kwargs)``.
    """
    return lambda image, **kwargs: run_pipeline(models, image, **kwargs)
@spaces.GPU
def process_image(image, model_version):
    """Run the intrinsic decomposition pipeline on one input image.

    Args:
        image: Pixel array from the ``gr.Image`` input, or ``None`` when the
            input is empty. Assumed to be 8-bit values (it is divided by 255
            below) -- TODO confirm for non-default ``gr.Image`` settings.
        model_version: Key into ``model_cache`` selecting which preloaded
            weights to use (the dropdown offers ``"v2"`` and ``"v2.1"``).

    Returns:
        A 3-element list matching the ``[alb_img, shd_img, dif_img]``
        outputs: the viewed ``'hr_alb'`` result, ``1 - invert(...)`` of the
        ``'dif_shd'`` result, and the ``view_scale``d ``'pos_res'`` result.
        Returns ``[None, None, None]`` when no image is provided.

    Raises:
        KeyError: If ``model_version`` is not a key of ``model_cache``.
    """
    # This handler also fires when the model changes with no image loaded,
    # or when the input is cleared; blank all outputs in that case.
    if image is None:
        return [None, None, None]

    print(f"Processing with model version: {model_version}")
    print(image.shape)

    # Convert to float32 in [0, 1] (assuming 8-bit input) for the pipeline.
    image = image.astype(np.single) / 255.

    # Models are preloaded into model_cache at startup; this is a plain
    # lookup, nothing is loaded lazily here.
    models = model_cache[model_version]
    pipeline_func = generate_pipeline(models)
    result = pipeline_func(image, device=DEVICE, resize_conf=1024)

    return [
        view(result['hr_alb']),
        1 - invert(result['dif_shd']),
        # NOTE(review): this map is shown under the "Diffuse Image" label but
        # comes from the 'pos_res' key -- confirm the label/key pairing.
        view_scale(result['pos_res']),
    ]
# Build the Gradio UI: a model-version selector plus a row holding the input
# image and the three decomposition outputs.
with gr.Blocks(
    css="""
    #download {
        height: 118px;
    }
    .slider .inner {
        width: 5px;
        background: #FFF;
    }
    .viewport {
        aspect-ratio: 4/3;
    }
    .tabs button.selected {
        font-size: 20px !important;
        color: crimson !important;
    }
    h1 {
        text-align: center;
        display: block;
    }
    h2 {
        text-align: center;
        display: block;
    }
    h3 {
        text-align: center;
        display: block;
    }
    .md_feedback li {
        margin-bottom: 0px !important;
    }
    .image-gallery {
        display: flex;
        flex-wrap: wrap;
        gap: 10px;
        justify-content: center;
    }
    .image-gallery > * {
        flex: 1;
        min-width: 200px;
    }
    """,
) as demo:
    gr.Markdown(
        """
        # Colorful Diffuse Intrinsic Image Decomposition in the Wild
        """
    )

    # Model version selector with information panel. The choices must match
    # the keys of model_cache, which process_image indexes with this value.
    with gr.Row():
        model_version = gr.Dropdown(
            choices=["v2", "v2.1"],
            value="v2",
            label="Model Version",
            info="Select which model weights to use",
            scale=1,
        )
        # The <p> block is closed and followed by a blank line so the
        # version notes below are parsed as Markdown (bold text, separate
        # paragraphs) instead of being swallowed into the raw HTML block.
        gr.Markdown("""
        <p align="center">
        <a title="Website" href="https://yaksoy.github.io/ColorfulShading/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="badge-website">
        </a>
        <a title="Github" href="https://github.com/compphoto/Intrinsic" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/compphoto/Intrinsic?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
        </a>
        </p>

        **V2**: Original weights from the paper.

        **V2.1**: More albedo detail and improved diffuse shading estimation.
        """)

    # Gallery-style layout for all images
    with gr.Row(elem_classes="image-gallery"):
        input_img = gr.Image(label="Input Image")
        alb_img = gr.Image(label="Albedo")
        shd_img = gr.Image(label="Diffuse Shading")
        dif_img = gr.Image(label="Diffuse Image")

    # Re-run the decomposition whenever either the input image or the
    # selected model version changes; both events share inputs and outputs.
    for trigger in (input_img.change, model_version.change):
        trigger(
            process_image,
            inputs=[input_img, model_version],
            outputs=[alb_img, shd_img, dif_img],
        )
# Start the server only when executed as a script, so that importing this
# module (e.g. for testing or Gradio's reload mode) has no launch side effect.
# show_error=True displays handler exceptions in the browser UI.
if __name__ == "__main__":
    demo.launch(show_error=True)