# Hugging Face Space: "Dereflection Any Image" — runs on ZeroGPU.
# (File-viewer chrome — Space status, file size, commit hashes, and the
# line-number gutter — was captured with the source; condensed to this
# comment so the module parses.)
import gradio as gr
from PIL import Image
import spaces
import functools
import os
import tempfile
import numpy as np
import torch as torch
torch.backends.cuda.matmul.allow_tf32 = True
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, AutoTokenizer
from DAI.pipeline_all import DAIPipeline
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
def process_image(pipe, vae_2, image):
    """Run the dereflection pipeline on a single image.

    Args:
        pipe: DAIPipeline (callable) whose output carries a ``prediction``
            array in [-1, 1]; indexed as ``prediction[0]`` below, so
            presumably shape (1, H, W, C) — TODO confirm against DAIPipeline.
        vae_2: CustomAutoencoderKL forwarded to the pipeline call.
        image: input ``PIL.Image`` to de-reflect.

    Returns:
        ``PIL.Image``: the dereflected image (also saved to a temp dir).
    """
    # Persist the input to a named temp file so we have a unique basename
    # for logging and for the saved output artifact.  NamedTemporaryFile
    # replaces the deprecated, race-prone tempfile.mktemp.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        temp_input_path = tmp.name
    image.save(temp_input_path)
    name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")

    pipe_out = pipe(
        image=image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=None,  # None: let the pipeline choose
    )

    # prediction is in [-1, 1]; map to [0, 255] uint8 for a displayable image.
    processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    processed_frame = (processed_frame[0] * 255).astype(np.uint8)
    processed_frame = Image.fromarray(processed_frame)
    processed_frame.save(path_out_png)
    return processed_frame
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.float32

    # Dereflection-specific weights (controlnet / unet / vae_2) come from the
    # DAI checkpoint; the base VAE, text encoder, and tokenizer come from
    # Stable Diffusion 2.1.
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Load the model components.
    controlnet = ControlNetVAEModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
    ).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )
    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # xformers is optional: enable memory-efficient attention when present,
    # otherwise run with the default attention.  `except Exception` keeps the
    # original best-effort behavior without a bare `except:` that would also
    # swallow KeyboardInterrupt/SystemExit.
    try:
        import xformers  # noqa: F401 -- presence check only
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # run without xformers

    # Cache example images in memory (files/image/1.png .. 8.png, if present).
    example_images_dir = "files/image"
    example_images = []
    for i in range(1, 9):
        image_path = os.path.join(example_images_dir, f"{i}.png")
        if os.path.exists(image_path):
            example_images.append([Image.open(image_path)])

    # Create a Gradio interface; spaces.GPU schedules the handler on ZeroGPU
    # hardware, and functools.partial binds the loaded pipeline components.
    interface = gr.Interface(
        fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Dereflection Any Image",
        description="Upload an image to remove glass reflections.",
        examples=example_images,
    )
    interface.launch()