# Hugging Face Spaces status: Running on Zero
import gradio as gr | |
from PIL import Image | |
import spaces | |
import functools | |
import os | |
import tempfile | |
import numpy as np | |
import torch as torch | |
torch.backends.cuda.matmul.allow_tf32 = True | |
from diffusers import ( | |
AutoencoderKL, | |
UNet2DConditionModel, | |
) | |
from transformers import CLIPTextModel, AutoTokenizer | |
from DAI.pipeline_all import DAIPipeline | |
from DAI.controlnetvae import ControlNetVAEModel | |
from DAI.decoder import CustomAutoencoderKL | |
def process_image(pipe, vae_2, image):
    """Run the dereflection pipeline on ``image`` and return the result.

    The input is first written to a scratch PNG; the generated file name is
    used for the log line and as the stem of the output file name. The
    pipeline prediction is then clipped to [-1, 1], rescaled to [0, 255],
    converted to an 8-bit image, written next to the input, and returned.
    All scratch files are removed before the function returns.

    Args:
        pipe: Callable pipeline. It is invoked with ``image``, the fixed
            prompt ``"remove glass reflection"``, ``vae_2`` and
            ``processing_resolution=None``; its result must expose a
            ``prediction`` array supporting ``clip`` and ``astype``.
        vae_2: Forwarded unchanged to ``pipe`` as the ``vae_2`` argument.
        image: Input image; must provide ``save(path)``. The Gradio input
            component in this file is configured with ``type="pil"``.

    Returns:
        The processed frame as an image built by ``Image.fromarray``.
    """
    # One self-cleaning directory holds both the input and output files, so
    # repeated requests on a long-running server do not accumulate temp data.
    with tempfile.TemporaryDirectory() as scratch_dir:
        # mkstemp creates the file atomically; the deprecated mktemp only
        # returned a name, leaving a window for another process to claim it.
        fd, temp_input_path = tempfile.mkstemp(suffix=".png", dir=scratch_dir)
        os.close(fd)  # the image is written by path, the descriptor is not needed
        image.save(temp_input_path)

        name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
        print(f"Processing image {name_base}{name_ext}")

        path_out_png = os.path.join(scratch_dir, f"{name_base}_delight.png")

        pipe_out = pipe(
            image=image,
            prompt="remove glass reflection",
            vae_2=vae_2,
            processing_resolution=None,
        )

        # Map the prediction from [-1, 1] to [0, 1], then take the first
        # entry and scale it to 8-bit pixel values.
        processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
        processed_frame = (processed_frame[0] * 255).astype(np.uint8)
        processed_frame = Image.fromarray(processed_frame)
        processed_frame.save(path_out_png)

    # The returned image was built from an in-memory array, so it stays valid
    # after the scratch directory is gone.
    return processed_frame
if __name__ == "__main__":
    # Script entry point: load all model components, assemble the pipeline,
    # collect example images, and launch the Gradio UI.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weight_dtype = torch.float32
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Load the model components specific to the dereflection checkpoint.
    controlnet = ControlNetVAEModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
    ).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
    ).to(device)

    # The VAE, text encoder and tokenizer come from the second checkpoint.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )

    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # xformers is an optional speed-up. Fall back to the default attention on
    # any ordinary failure, but let KeyboardInterrupt/SystemExit propagate and
    # report why the optimisation was skipped.
    try:
        import xformers  # noqa: F401  (imported only to probe availability)

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"xformers not enabled, running without it: {exc}")

    # Cache example images in memory
    example_images_dir = "files/image"
    example_images = []
    for i in range(1, 9):
        image_path = os.path.join(example_images_dir, f"{i}.png")
        if os.path.exists(image_path):
            example = Image.open(image_path)
            # Image.open is lazy; load() reads the pixel data now so the
            # image really is held in memory rather than behind an open file.
            example.load()
            example_images.append([example])

    # Create a Gradio interface. process_image is bound to the loaded pipeline
    # and wrapped with spaces.GPU before being handed to Gradio.
    interface = gr.Interface(
        fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Dereflection Any Image",
        description="Upload an image to remove glass reflections.",
        examples=example_images,
    )

    interface.launch()