# File: app.py (3.64 kB)
# Uploaded by sjtu-deepvision in commit 7c4abde ("Upload app.py", verified).
import gradio as gr
from PIL import Image
from DAI.pipeline_all import DAIPipeline
import os
import tempfile
import numpy as np
import torch as torch
torch.backends.cuda.matmul.allow_tf32 = True
import spaces
import functools
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, AutoTokenizer
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
def process_image(pipe, vae_2, image):
    """Run the reflection-removal pipeline on one image and return the result.

    Args:
        pipe: Callable pipeline. It is invoked with the keyword arguments
            ``image``, ``prompt``, ``vae_2`` and ``processing_resolution``,
            and its result must expose a ``prediction`` attribute.
        vae_2: Forwarded unchanged to ``pipe`` as ``vae_2``.
        image: Input image; must provide a ``save(path)`` method.

    Returns:
        The first entry of the prediction, converted to an 8-bit PIL image.
    """
    # Scratch copies of the input and output live in a private directory that
    # is removed when the block exits (also on error), so repeated requests do
    # not pile up files in the system temp location.
    with tempfile.TemporaryDirectory() as scratch_dir:
        # mkstemp creates the file atomically; the deprecated mktemp only
        # returned a name that another process could claim first.
        fd, temp_input_path = tempfile.mkstemp(suffix=".png", dir=scratch_dir)
        os.close(fd)  # only the path is needed; image.save reopens it
        image.save(temp_input_path)

        name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
        print(f"Processing image {name_base}{name_ext}")
        path_out_png = os.path.join(scratch_dir, f"{name_base}_delight.png")

        # NOTE(review): None presumably lets the pipeline choose its own
        # processing resolution -- confirm against DAIPipeline.
        resolution = None
        pipe_out = pipe(
            image=image,
            prompt="remove glass reflection",
            vae_2=vae_2,
            processing_resolution=resolution,
        )

        # Clip to [-1, 1], map to [0, 1], then take the first entry and scale
        # to 8-bit. Assumes prediction is a NumPy-like array -- TODO confirm.
        processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
        processed_frame = (processed_frame[0] * 255).astype(np.uint8)
        processed_frame = Image.fromarray(processed_frame)
        processed_frame.save(path_out_png)

    # The returned image is held in memory, so it outlives the scratch files.
    return processed_frame
# Script entry point: load all model components, assemble the pipeline and
# serve it through a Gradio interface.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.float32

    # First repo supplies the controlnet, unet and vae_2 subfolders; the
    # second supplies the vae, text_encoder and tokenizer subfolders.
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Load the model
    controlnet = ControlNetVAEModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
    ).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
    ).to(device)

    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )

    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # xformers is optional. Catch Exception (not a bare except) so that
    # Ctrl-C / SystemExit still propagate, and report why the fallback is used.
    try:
        import xformers  # noqa: F401 -- imported only to probe availability

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"xformers not enabled, running without it: {exc}")

    # Cache example images in memory
    example_images_dir = "files/image"
    example_images = []
    for i in range(1, 9):
        image_path = os.path.join(example_images_dir, f"{i}.png")
        if os.path.exists(image_path):
            example_image = Image.open(image_path)
            # Image.open is lazy; load() reads the pixel data now so the
            # example really is held in memory.
            example_image.load()
            example_images.append([example_image])

    # Create a Gradio interface
    interface = gr.Interface(
        # pipe and vae_2 are bound up front; Gradio supplies only the image.
        fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Dereflection Any Image",
        description="Upload an image to remove glass reflections.",
        examples=example_images,
    )
    interface.launch()