sjtu-deepvision's picture
Upload app.py
2fb6499 verified
raw
history blame
3.64 kB
import gradio as gr
from PIL import Image
import spaces
import functools
import os
import tempfile
import numpy as np
import torch as torch
torch.backends.cuda.matmul.allow_tf32 = True
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, AutoTokenizer
from DAI.pipeline_all import DAIPipeline
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
def process_image(pipe, vae_2, image):
    """Remove glass reflections from a single image.

    Inference runs entirely in memory: no temporary files are written, so
    nothing accumulates on disk across requests.

    Args:
        pipe: The ``DAIPipeline`` instance used for inference.
        vae_2: Secondary VAE forwarded to the pipeline as ``vae_2``.
        image: Input ``PIL.Image.Image`` from the Gradio image component.
            May be ``None`` when the user submits without uploading.

    Returns:
        The de-reflected result as an 8-bit ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError
        # traceback on an empty submission.
        raise gr.Error("Please upload an image.")

    print(f"Processing image of size {image.width}x{image.height}")
    pipe_out = pipe(
        image=image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        # None leaves the resolution policy to DAIPipeline
        # (assumed to mean "native size" -- TODO confirm).
        processing_resolution=None,
    )

    # Assumes `prediction` is a NumPy array in [-1, 1] laid out as
    # (batch, H, W, C) -- TODO confirm against DAIPipeline. Map it to [0, 1],
    # take the first batch item and quantize it to uint8 for PIL.
    prediction = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    frame = (prediction[0] * 255).astype(np.uint8)
    return Image.fromarray(frame)
if __name__ == "__main__":
    # Script entry point: load every model component, assemble the
    # dereflection pipeline, and serve it through a Gradio interface.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.float32
    # Repo supplying the fine-tuned controlnet, unet and vae_2 weights.
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    # Repo supplying the VAE, text encoder and tokenizer.
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Load the model components and move them to the target device.
    controlnet = ControlNetVAEModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
    ).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )

    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # xformers is an optional attention speed-up. Fall back to the default
    # attention if it is missing or cannot be enabled. Catch Exception rather
    # than using a bare except, so KeyboardInterrupt/SystemExit still
    # propagate, and report the reason instead of hiding it.
    try:
        import xformers  # noqa: F401  (imported only to probe availability)
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"Running without xformers: {exc}")

    # Cache example images in memory. PIL opens files lazily, so copy each
    # image inside a `with` block: the copy holds the pixel data and the
    # file handle is closed right away.
    example_images_dir = "files/image"
    example_images = []
    for i in range(1, 9):
        image_path = os.path.join(example_images_dir, f"{i}.png")
        if os.path.exists(image_path):
            with Image.open(image_path) as example:
                example_images.append([example.copy()])

    # Create a Gradio interface; pipe and vae_2 are bound ahead of time, so
    # Gradio only supplies the uploaded image.
    interface = gr.Interface(
        fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Dereflection Any Image",
        description="Upload an image to remove glass reflections.",
        examples=example_images,
    )
    interface.launch()