# Intrinsic / app.py
# Copied from the Hugging Face Space file view: last commit de3bb71 by ccareaga
# ("small formatting change"), file size 4.15 kB.
import spaces
import gradio as gr
import numpy as np
import torch
from chrislib.general import uninvert, invert, view, view_scale
from intrinsic.pipeline import load_models, run_pipeline
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Models are loaded lazily rather than at import time. This dict maps a
# model-version string to whatever load_models() returned for it; get_model()
# fills it on first use of each version.
model_cache = {}
def get_model(model_version):
    """Return the models for ``model_version``, loading them on first request.

    Loaded models are stored in the module-level ``model_cache`` so that later
    calls for the same version reuse them instead of calling ``load_models``
    again.

    NOTE(review): when this runs inside a ``@spaces.GPU`` function on ZeroGPU
    hardware, the GPU work may execute in a separate worker process. In that
    case the cache might not survive between requests. Confirm before relying
    on the "faster after first load" behaviour there.

    Args:
        model_version: Version identifier forwarded to ``load_models``
            (the UI offers "v2" and "v2.1").

    Returns:
        The object ``load_models(model_version, device=DEVICE)`` produced.
    """
    if model_version in model_cache:
        return model_cache[model_version]

    loaded = load_models(model_version, device=DEVICE)
    model_cache[model_version] = loaded
    return loaded
def generate_pipeline(models):
    """Bind ``models`` into a reusable inference callable.

    Args:
        models: The value obtained from ``get_model``. It is passed through
            untouched as the first argument of ``run_pipeline``.

    Returns:
        A callable ``f(image, **kwargs)`` that forwards to
        ``run_pipeline(models, image, **kwargs)`` and returns its result.
    """
    return lambda image, **kwargs: run_pipeline(models, image, **kwargs)
@spaces.GPU
def process_image(image, model_version):
    """Decompose one input image into the three images shown in the UI.

    Args:
        image: Array from the ``gr.Image`` input, or ``None`` when the
            component is empty. Presumably 8-bit RGB in [0, 255], given the
            ``/ 255.`` scaling below. TODO confirm.
        model_version: Weights identifier passed to ``get_model``.

    Returns:
        A list ``[albedo, shading, diffuse]`` feeding the three output images,
        or ``[None, None, None]`` when no image was provided.
    """
    # Empty or cleared input: blank all three outputs.
    if image is None:
        return [None, None, None]

    print(f"Processing with model version: {model_version}")
    print(image.shape)

    # Convert to single-precision floats scaled down by 255.
    normalized = image.astype(np.single) / 255.

    # Fetch the weights for this version (loaded on first use), then run them.
    run = generate_pipeline(get_model(model_version))
    outputs = run(normalized, device=DEVICE, resize_conf=1024)

    albedo = view(outputs['hr_alb'])
    # invert() comes from chrislib. Presumably 1 - invert(...) compresses the
    # shading into a displayable range. Confirm against chrislib.
    shading = 1 - invert(outputs['dif_shd'])
    diffuse = view_scale(outputs['pos_res'])
    return [albedo, shading, diffuse]
# Build the web UI: a title/badge header, a model-version picker, and a
# gallery row that shows the input image next to the three decomposition
# outputs. Changing either the input image or the selected model version
# re-runs process_image. The app is launched at import time, as a Spaces
# entry script.
with gr.Blocks(
    css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
.image-gallery {
display: flex;
flex-wrap: wrap;
gap: 10px;
justify-content: center;
}
.image-gallery > * {
flex: 1;
min-width: 200px;
}
""",
) as demo:
    # Header. The centred <p> is now closed explicitly, and the website badge
    # gets alt text like the GitHub badge beside it.
    gr.Markdown(
        """
# Colorful Diffuse Intrinsic Image Decomposition in the Wild
<p align="center">
<a title="Website" href="https://yaksoy.github.io/ColorfulShading/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="badge-website">
</a>
<a title="Github" href="https://github.com/compphoto/Intrinsic" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/compphoto/Intrinsic?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
</p>
"""
    )

    # Model version selector with an information panel beside it.
    with gr.Row():
        model_version = gr.Dropdown(
            choices=["v2", "v2.1"],
            value="v2",
            label="Model Version",
            info="Select which model weights to use",
            scale=1
        )
        gr.Markdown("""
The model may take a few seconds to load the first time you use it.
Subsequent decompositions should be faster after the model is loaded.
""")

    # Gallery-style layout for the input and all outputs (see .image-gallery CSS).
    with gr.Row(elem_classes="image-gallery"):
        input_img = gr.Image(label="Input Image")
        alb_img = gr.Image(label="Albedo")
        shd_img = gr.Image(label="Diffuse Shading")
        dif_img = gr.Image(label="Diffuse Image")

    # A new input image and a new model version both re-run the decomposition
    # with the same handler, inputs and outputs. Register the two triggers in
    # their original order.
    for trigger in (input_img.change, model_version.change):
        trigger(
            process_image,
            inputs=[input_img, model_version],
            outputs=[alb_img, shd_img, dif_img],
        )

demo.launch(show_error=True)