File size: 3,639 Bytes
a390bd6
 
2fb6499
 
311419e
 
 
 
 
 
 
 
 
 
 
 
 
2fb6499
 
311419e
 
 
 
a390bd6
 
 
 
311419e
a390bd6
311419e
 
 
 
 
 
 
a390bd6
311419e
 
 
 
 
 
 
 
 
 
a390bd6
311419e
a390bd6
311419e
 
 
 
 
 
a390bd6
311419e
 
 
 
 
 
a390bd6
 
311419e
 
a390bd6
 
311419e
a390bd6
 
 
 
 
311419e
a390bd6
 
 
 
 
 
 
 
 
 
 
7c4abde
 
 
 
 
 
a390bd6
 
 
 
 
 
 
 
 
 
7c4abde
a390bd6
 
 
 
 
 
c360cac
a390bd6
c360cac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
from PIL import Image
import spaces
import functools
import os
import tempfile
import numpy as np
import torch as torch
torch.backends.cuda.matmul.allow_tf32 = True

from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
)

from transformers import CLIPTextModel, AutoTokenizer

from DAI.pipeline_all import DAIPipeline

from DAI.controlnetvae import ControlNetVAEModel

from DAI.decoder import CustomAutoencoderKL

def process_image(pipe, vae_2, image):
    """Remove glass reflections from a single image with the DAI pipeline.

    Args:
        pipe: The pipeline built in ``__main__`` (a ``DAIPipeline``); it is
            called with the image, a fixed text prompt and ``vae_2``.
        vae_2: Secondary VAE forwarded to the pipeline as ``vae_2``.
        image: Input PIL image, as delivered by ``gr.Image(type="pil")``.

    Returns:
        The de-reflected result as a ``PIL.Image.Image``. A copy is also
        written to a fresh temporary directory as ``<name>_delight.png``.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. the user pressed submit
            without uploading an image).
    """
    if image is None:
        raise ValueError("No input image provided.")

    # Persist the input under a unique name. mkstemp creates the file
    # atomically (unlike the deprecated, race-prone mktemp); the descriptor
    # is closed right away so PIL can reopen the path by name.
    fd, temp_input_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    image.save(temp_input_path)

    name_base, name_ext = os.path.splitext(os.path.basename(temp_input_path))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")

    pipe_out = pipe(
        image=image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        # NOTE(review): None presumably lets the pipeline pick its default
        # processing resolution — confirm in DAIPipeline.
        processing_resolution=None,
    )

    # Map the prediction from [-1, 1] to [0, 1], then to 8-bit.
    # Assumption: `prediction` is a numpy array with a leading batch axis,
    # i.e. (batch, H, W, C) — TODO confirm against DAIPipeline.
    frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    frame = (frame[0] * 255).astype(np.uint8)
    result = Image.fromarray(frame)
    result.save(path_out_png)

    return result

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.float32
    # DAI-specific weights (ControlNet, UNet, secondary VAE).
    pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
    # Base model supplying the primary VAE, text encoder and tokenizer.
    pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
    revision = None
    variant = None

    # Load the dereflection-specific components.
    controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
    vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)

    # Load the base-model components with the same dtype as the rest of the
    # pipeline so changing `weight_dtype` applies everywhere.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="vae",
        revision=revision,
        variant=variant,
        torch_dtype=weight_dtype,
    ).to(device)

    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="text_encoder",
        revision=revision,
        variant=variant,
        torch_dtype=weight_dtype,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path2,
        subfolder="tokenizer",
        revision=revision,
        use_fast=False,
    )
    pipe = DAIPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        scheduler=None,
        feature_extractor=None,
        t_start=0,
    ).to(device)

    # xformers is optional: use its memory-efficient attention when it is
    # installed and usable, otherwise keep the default attention. Only the
    # missing-package case is silent; other failures are reported.
    try:
        import xformers  # noqa: F401
    except ImportError:
        pass  # run without xformers
    else:
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:  # best-effort, e.g. no CUDA device available
            print(f"xformers could not be enabled, running without it: {exc}")

    # Cache example images in memory.
    example_images_dir = "files/image"
    example_images = []
    for i in range(1, 9):
        image_path = os.path.join(example_images_dir, f"{i}.png")
        if os.path.exists(image_path):
            example_image = Image.open(image_path)
            # Image.open is lazy and keeps the file open. load() reads the
            # pixels and, for a single-frame file opened by path, closes the
            # handle so it is not held for the lifetime of the app.
            example_image.load()
            example_images.append([example_image])

    # Create a Gradio interface; spaces.GPU wraps the handler for GPU execution
    # on Hugging Face Spaces.
    interface = gr.Interface(
        fn=spaces.GPU(functools.partial(process_image, pipe, vae_2)),
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Dereflection Any Image",
        description="Upload an image to remove glass reflections.",
        examples=example_images,
    )

    interface.launch()