# `spaces` is imported first so ZeroGPU can install its hooks before torch
# initializes CUDA.
import spaces

import os
from typing import cast

import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler

from load_image import load_exr_image, load_ldr_image
from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline

# Let OpenCV read EXR files.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

current_directory = os.path.dirname(os.path.abspath(__file__))
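
# Load the x->RGB pipeline once at startup, in half precision, and keep it on the GPU.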
_pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
    "zheng95z/x-to-rgb",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVDropoutPipeline, _pipe)
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
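

# On ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration of each call.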
@spaces.GPU
def generate(
    albedo,
    normal,
    roughness,
    metallic,
    irradiance,
    prompt: str,
    seed: int,
    inference_step: int,
    num_samples: int,
    guidance_scale: float,
    image_guidance_scale: float,
) -> list[tuple[np.ndarray, str]]:
    generator = torch.Generator(device="cuda").manual_seed(seed)
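
    # Helper: read an uploaded file onto the GPU, taking the HDR path for .exr
    # and the LDR path for .png/.jpg/.jpeg; missing or unsupported files yield None.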
    def process_image(file, **kwargs):
        if file is None:
            return None
        if file.name.endswith(".exr"):
            return load_exr_image(file.name, **kwargs).to("cuda")
        elif file.name.endswith((".png", ".jpg", ".jpeg")):
            return load_ldr_image(file.name, **kwargs).to("cuda")
        return None
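
    # The keyword flags mirror the loaders' parameters: color-like channels are
    # clamped to [0, 1], normals are normalized, and irradiance is tonemapped
    # (`tonemaping` follows the loader's parameter spelling).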
    albedo_image = process_image(albedo, clamp=True)
    normal_image = process_image(normal, normalize=True)
    roughness_image = process_image(roughness, clamp=True)
    metallic_image = process_image(metallic, clamp=True)
    irradiance_image = process_image(irradiance, tonemaping=True, clamp=True)
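
    # Default to 768x768, but adopt the resolution of the first provided AOV
    # (tensors are C x H x W).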
    height, width = 768, 768
    for img in [
        albedo_image,
        normal_image,
        roughness_image,
        metallic_image,
        irradiance_image,
    ]:
        if img is not None:
            height, width = img.shape[1], img.shape[2]
            break
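
    # Draw num_samples images; the shared generator advances across iterations,
    # so each sample differs while the whole batch stays reproducible from the seed.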
    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    return_list = []

    for i in range(num_samples):
        generated_image = pipe(
            prompt=prompt,
            albedo=albedo_image,
            normal=normal_image,
            roughness=roughness_image,
            metallic=metallic_image,
            irradiance=irradiance_image,
            num_inference_steps=inference_step,
            height=height,
            width=width,
            generator=generator,
            required_aovs=required_aovs,
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
            guidance_rescale=0.7,
            output_type="np",
        ).images[0]

        return_list.append((generated_image, f"Generated Image {i}"))
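
    # Append the conditioning channels to the gallery so the inputs can be
    # compared with the generated results side by side.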
    def post_process_image(img, label="Image"):
        # Convert a C x H x W tensor to an H x W x C array for gr.Gallery; use a
        # black placeholder for channels that were not provided.
        if img is not None:
            return (img.cpu().numpy().transpose(1, 2, 0), label)
        return (np.zeros((height, width, 3)), label)

    return_list.extend(
        [
            post_process_image(albedo_image, label="Albedo"),
            post_process_image(normal_image, label="Normal"),
            post_process_image(roughness_image, label="Roughness"),
            post_process_image(metallic_image, label="Metallic"),
            post_process_image(irradiance_image, label="Irradiance"),
        ]
    )

    return return_list
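

# Gradio UI: intrinsic-channel uploads and parameters on the left, output
# gallery on the right.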
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Given intrinsic channels")
            image_types = [".exr", ".png", ".jpg", ".jpeg"]
            albedo = gr.File(label="Albedo", file_types=image_types)
            normal = gr.File(label="Normal", file_types=image_types)
            roughness = gr.File(label="Roughness", file_types=image_types)
            metallic = gr.File(label="Metallic", file_types=image_types)
            irradiance = gr.File(label="Irradiance", file_types=image_types)
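
            # These controls map one-to-one onto generate()'s parameters
            # (see run_button.click below).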
            gr.Markdown("### Parameters")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                image_guidance_scale = gr.Slider(
                    label="Image Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.5,
                )

        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    run_button.click(
        fn=generate,
        inputs=[
            albedo,
            normal,
            roughness,
            metallic,
            irradiance,
            prompt,
            seed,
            inference_step,
            num_samples,
            guidance_scale,
            image_guidance_scale,
        ],
        outputs=result_gallery,
        queue=True,
    )
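

# A (hypothetical) programmatic call for local smoke-testing, bypassing the UI:
#   results = generate(None, None, None, None, None, prompt="a cozy living room",
#                      seed=0, inference_step=50, num_samples=1,
#                      guidance_scale=7.5, image_guidance_scale=1.5)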
if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)