Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,640 Bytes
a390bd6 311419e 7c4abde 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 311419e a390bd6 7c4abde a390bd6 7c4abde a390bd6 c360cac a390bd6 c360cac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
from PIL import Image
from DAI.pipeline_all import DAIPipeline
import os
import tempfile
import numpy as np
import torch as torch
torch.backends.cuda.matmul.allow_tf32 = True
import spaces
import functools
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, AutoTokenizer
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
def process_image(pipe, vae_2, image):
    """Run the dereflection pipeline on one image and return the result.

    The input is first written to a uniquely named temporary PNG. That file's
    base name is reused to name the output PNG, which is written into a fresh
    temporary directory. Neither file is deleted by this function.

    Args:
        pipe: Pipeline callable (a ``DAIPipeline`` in this app). It is called
            with ``image``, ``prompt``, ``vae_2`` and ``processing_resolution``
            keyword arguments and must return an object exposing a
            ``prediction`` attribute.
        vae_2: Second VAE, forwarded unchanged to ``pipe`` as ``vae_2``.
        image: Input PIL image (``None`` when the Gradio input is empty).

    Returns:
        The dereflected image as a PIL image.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading; show a
        # readable message instead of failing with an AttributeError below.
        raise gr.Error("Please upload an image first.")

    # Save the input image to a temporary file. mkstemp creates the file
    # atomically, unlike the deprecated and race-prone tempfile.mktemp.
    fd, temp_input_path = tempfile.mkstemp(suffix=".png")
    with os.fdopen(fd, "wb") as temp_input:
        image.save(temp_input, format="PNG")
    name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")

    pipe_out = pipe(
        image=image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        # None leaves the resolution choice to the pipeline
        # (presumably its default) — TODO confirm in DAIPipeline.
        processing_resolution=None,
    )

    # Map the prediction from [-1, 1] to [0, 1], then to 8-bit pixel values.
    # NOTE(review): assumes ``prediction`` is a NumPy array with a leading batch
    # axis and an image-shaped array at index 0 — confirm against DAIPipeline.
    prediction = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    frame_uint8 = (prediction[0] * 255).astype(np.uint8)
    result = Image.fromarray(frame_uint8)
    result.save(path_out_png)
    return result
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.float32
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Load the model. The dereflection-specific components come from the DAI
    # repository; the VAE, text encoder and tokenizer come from SD 2.1.
    controlnet = ControlNetVAEModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
    ).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
    ).to(device)
    # Use the same weight_dtype as the other components so that changing it in
    # one place keeps every module in a consistent precision.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="vae",
        revision=revision,
        variant=variant,
        torch_dtype=weight_dtype,
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="text_encoder",
        revision=revision,
        variant=variant,
        torch_dtype=weight_dtype,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )

    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # xformers attention is an optional speedup. Any failure (package missing,
    # unsupported device) falls back to the default attention. Unlike a bare
    # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
    try:
        import xformers  # noqa: F401  (import only probes availability)

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:
        print(f"Running without xformers: {err}")

    # Cache example images in memory. Image.open is lazy and keeps the file
    # open until the pixels are read, so decode now (copy() loads the data)
    # and let the `with` close the file handle.
    example_images_dir = "files/image"
    example_images = []
    for i in range(1, 9):
        image_path = os.path.join(example_images_dir, f"{i}.png")
        if os.path.exists(image_path):
            with Image.open(image_path) as example:
                example_images.append([example.copy()])

    # Create a Gradio interface. The pipeline and second VAE are bound into
    # the handler so Gradio only supplies the uploaded image.
    interface = gr.Interface(
        fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Dereflection Any Image",
        description="Upload an image to remove glass reflections.",
        # Pass None rather than an empty list when no example files exist.
        examples=example_images or None,
    )
    interface.launch()
|