Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,377 Bytes
59880dc 311419e 1cedc13 651dfe7 1cedc13 7cdacae 651dfe7 1cedc13 311419e 1cedc13 cd178d4 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 84063e2 7cdacae c9620cb 7cdacae 1cedc13 68a7cdb 6d71d5c 84063e2 6d71d5c 02713a6 a41990b 7cdacae 311419e 651dfe7 311419e c9620cb 311419e 7cdacae 311419e 84063e2 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 c9620cb 1cedc13 7c72554 1cedc13 7c72554 8895d65 c9620cb 02713a6 c9620cb 8895d65 6d71d5c 8895d65 c9620cb 6d71d5c 8895d65 c9620cb 84063e2 8895d65 651dfe7 1cedc13 651dfe7 7cdacae 651dfe7 1cedc13 cd178d4 c360cac 651dfe7 1cedc13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import spaces # 必须放在最前面
import os
import numpy as np
import torch
from PIL import Image
import gradio as gr
# Dtype passed as `torch_dtype` to the model components loaded below.
# NOTE(review): this line used to be labelled "delay CUDA initialization", but it
# only selects a dtype; any CUDA deferral presumably comes from importing
# `spaces` first (ZeroGPU) — confirm.
weight_dtype = torch.float32
# Model component classes: DAI-specific modules plus diffusers/transformers parts.
from DAI.pipeline_all import DAIPipeline
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, AutoTokenizer

# Hub repositories: the DAI weights (controlnet, unet, vae_2) and the
# Stable Diffusion 2.1 base (vae, text encoder, tokenizer).
pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load all model weights with the same dtype so the pipeline never mixes
# precisions if `weight_dtype` is changed (vae/text_encoder previously ignored it).
controlnet = ControlNetVAEModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype
).to(device)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype
).to(device)
vae_2 = CustomAutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype
).to(device)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path2, subfolder="vae", torch_dtype=weight_dtype
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path2, subfolder="text_encoder", torch_dtype=weight_dtype
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False
)

# Assemble the inference pipeline. No safety checker, scheduler or feature
# extractor is supplied; `vae_2` is not registered here but is passed per call
# in `process_image`. The meaning of t_start=0 is defined by DAIPipeline.
pipe = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
).to(device)
def resize_image(image, max_size, resample=None):
    """Downscale ``image`` so that its longer side is at most ``max_size`` pixels.

    The aspect ratio is preserved. An image that already fits is returned
    unchanged (the same object). The shorter side is truncated to an integer,
    but it is clamped to at least 1 pixel. Without the clamp, an extreme aspect
    ratio (e.g. 10000x1) would produce an invalid zero-sized target and
    ``resize`` would raise.

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.resize``).
        max_size: Maximum allowed length, in pixels, of the longer side.
        resample: Resampling filter forwarded to ``image.resize``. Defaults to
            ``Image.LANCZOS``.

    Returns:
        ``image`` itself if it already fits, otherwise the result of
        ``image.resize``.
    """
    width, height = image.size
    if max(width, height) <= max_size:
        return image
    if resample is None:
        resample = Image.LANCZOS
    if width > height:
        new_size = (max_size, max(1, int(height * (max_size / width))))
    else:
        new_size = (max(1, int(width * (max_size / height))), max_size)
    return image.resize(new_size, resample)
@spaces.GPU
def process_image(input_image, resolution_choice):
    """Run the dereflection pipeline on one image from the Gradio UI.

    Args:
        input_image: Value of the ``gr.Image(type="numpy")`` component, i.e. a
            NumPy array, or ``None`` when nothing has been uploaded.
        resolution_choice: Value of the resolution radio button, either
            ``"768"`` or ``"Original Resolution"``.

    Returns:
        A ``(input_pil, output_pil)`` pair. The first item is the input as
        actually fed to the pipeline (possibly downscaled). The second is the
        processed image.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Clicking the button with an empty image component passes None; report it
    # to the user instead of crashing inside Image.fromarray.
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    image = Image.fromarray(input_image)

    # Choose the working size and the pipeline's processing_resolution.
    if resolution_choice == "768":
        image = resize_image(image, 768)
        processing_resolution = None  # presumably "pipeline default" — confirm in DAIPipeline
    elif max(image.size) > 2560:
        image = resize_image(image, 2560)
        processing_resolution = 2560  # cap the maximum resolution
    else:
        processing_resolution = 0  # process at the original resolution

    pipe_out = pipe(
        image=image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=processing_resolution,
    )

    # Map the prediction from [-1, 1] to uint8 [0, 255]. Index 0 presumably
    # selects the single image of the output batch — confirm prediction layout.
    frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    frame = (frame[0] * 255).astype(np.uint8)
    return image, Image.fromarray(frame)
# Build the Gradio interface.
def create_gradio_interface():
    """Assemble the Gradio Blocks UI for the dereflection demo.

    The UI has an input image, a resolution selector and a submit button in the
    left column, and the processed image in the right column. Example inputs
    are listed below the row.

    Returns:
        gr.Blocks: The constructed (not yet launched) demo.
    """
    # Example inputs files/image/1.png ... files/image/13.png, all run at 768.
    example_images = [
        [os.path.join("files", "image", f"{i}.png"), "768"] for i in range(1, 14)
    ]
    title = "# Dereflection Any Image"
    # NOTE(review): the [paper]() link below has an empty URL — fill it in.
    description = """Official demo for **Dereflection Any Image**.
Please refer to our [paper](), [project page](https://abuuu122.github.io/DAI.github.io/), and [github](https://github.com/Abuuu122/Dereflection-Any-Image) for more details."""
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")
                resolution_choice = gr.Radio(
                    choices=["768", "Original Resolution"],
                    label="Processing Resolution",
                    value="768",  # default to 768 (faster, more stable)
                )
                gr.Markdown(
                    "Select the resolution for processing the image, 768 is recommended for faster processing and stable performance. Higher resolution may take longer to process, we restrict the maximum resolution to 2560."
                )
                submit_btn = gr.Button("Remove Reflection", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Processed Image")
        # process_image returns (resized input, result), so the examples must
        # declare the same two outputs as the button. With a single output they
        # would fail once caching or run-on-click is enabled.
        gr.Examples(
            examples=example_images,
            inputs=[input_image, resolution_choice],
            outputs=[input_image, output_image],
            fn=process_image,
            cache_examples=False,  # examples are not pre-computed at startup
            label="Example Images",
        )
        # Run the pipeline on click and show the resized input next to the result.
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, resolution_choice],
            outputs=[input_image, output_image],
        )
    return demo
# Entry point.
def main():
    """Build the demo UI and serve it on all interfaces at port 7860."""
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    create_gradio_interface().launch(**launch_options)
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    main()