File size: 5,377 Bytes
59880dc
311419e
 
1cedc13
651dfe7
1cedc13
7cdacae
 
 
 
 
651dfe7
 
 
1cedc13
311419e
 
1cedc13
 
cd178d4
1cedc13
7cdacae
1cedc13
 
 
 
 
 
 
7cdacae
1cedc13
 
 
 
 
 
 
 
 
 
 
 
84063e2
 
 
 
 
 
 
 
 
 
 
 
 
7cdacae
c9620cb
7cdacae
1cedc13
 
68a7cdb
6d71d5c
84063e2
6d71d5c
 
02713a6
 
 
 
 
a41990b
7cdacae
311419e
651dfe7
311419e
 
c9620cb
311419e
 
7cdacae
311419e
 
 
 
84063e2
1cedc13
7cdacae
1cedc13
7cdacae
1cedc13
c9620cb
1cedc13
7c72554
 
 
 
1cedc13
7c72554
 
8895d65
 
 
c9620cb
 
 
 
 
 
02713a6
c9620cb
8895d65
 
6d71d5c
8895d65
 
 
 
c9620cb
6d71d5c
8895d65
 
 
 
 
 
 
 
c9620cb
84063e2
8895d65
651dfe7
1cedc13
651dfe7
7cdacae
651dfe7
1cedc13
cd178d4
c360cac
651dfe7
1cedc13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import spaces  # must be imported first (ZeroGPU decorator hooks CUDA before torch touches it)
import os
import numpy as np
import torch
from PIL import Image
import gradio as gr

# Defer CUDA initialization; weights are loaded in float32.
weight_dtype = torch.float32

# Load model components.
from DAI.pipeline_all import DAIPipeline
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, AutoTokenizer

# Model repos: DAI-specific weights plus Stable Diffusion 2.1 base components.
pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the individual model parts onto the target device.
controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path2, subfolder="vae").to(device)
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path2, subfolder="text_encoder").to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False)

# Assemble the inference pipeline from the loaded parts.
pipe = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
).to(device)

def resize_image(image, max_size):
    """Downscale *image* so its longer side is at most *max_size*.

    Aspect ratio is preserved and LANCZOS resampling is used.
    Images already within the limit are returned unchanged.
    """
    w, h = image.size
    if max(w, h) <= max_size:
        return image
    # Scale the longer side to max_size; derive the other from the ratio.
    if w > h:
        target = (max_size, int(h * (max_size / w)))
    else:
        target = (int(w * (max_size / h)), max_size)
    return image.resize(target, Image.LANCZOS)

@spaces.GPU
def process_image(input_image, resolution_choice):
    """Remove glass reflection from a numpy image.

    Returns a tuple ``(resized_input, processed)`` of PIL images so the UI
    can display the (possibly downscaled) input next to the result.
    """
    # Gradio hands us a numpy array; work with a PIL image from here on.
    image = Image.fromarray(input_image)

    # Pick the processing resolution based on the user's choice.
    if resolution_choice == "768":
        image = resize_image(image, 768)
        processing_resolution = None
    elif max(image.size) > 2560:
        # Cap very large inputs at 2560 on the longer side.
        processing_resolution = 2560
        image = resize_image(image, 2560)
    else:
        # 0 means: process at the image's original resolution.
        processing_resolution = 0

    # Run the dereflection pipeline.
    pipe_out = pipe(
        image=image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=processing_resolution,
    )

    # Map the prediction from [-1, 1] to [0, 255] uint8 and wrap as PIL.
    prediction = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    processed = Image.fromarray((prediction[0] * 255).astype(np.uint8))

    return image, processed

# Build the Gradio UI.
def create_gradio_interface():
    """Build and return the Gradio Blocks demo for reflection removal."""
    # Example images: files/image/1.png .. 13.png, each processed at 768.
    example_images = [
        [os.path.join("files", "image", f"{i}.png"), "768"] for i in range(1, 14)
    ]
    title = "# Dereflection Any Image"
    description = """Official demo for **Dereflection Any Image**.
    Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")
                resolution_choice = gr.Radio(
                    choices=["768", "Original Resolution"],
                    label="Processing Resolution",
                    value="768",  # default to the faster, more stable 768 mode
                )
                gr.Markdown(
                    "Select the resolution for processing the image, 768 is recommended for faster processing and stable performance. Higher resolution may take longer to process, we restrict the maximum resolution to 2560."
                )
                submit_btn = gr.Button("Remove Reflection", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Processed Image")

        # Example gallery. process_image returns a 2-tuple (resized input,
        # result), so both components must be listed as outputs — previously
        # only output_image was, mismatching the function's return arity.
        gr.Examples(
            examples=example_images,
            inputs=[input_image, resolution_choice],
            outputs=[input_image, output_image],
            fn=process_image,
            cache_examples=False,  # examples are not precomputed
            label="Example Images",
        )

        # Run processing on click; show the resized input next to the result.
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, resolution_choice],
            outputs=[input_image, output_image],
        )

    return demo

# Entry point.
def main():
    """Launch the demo server on all interfaces at port 7860."""
    create_gradio_interface().launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    main()