Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,731 Bytes
7d1a82d 311419e 1cedc13 651dfe7 1cedc13 7cdacae 651dfe7 1cedc13 311419e 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 311419e 651dfe7 311419e 1cedc13 311419e 7cdacae 311419e 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7cdacae 1cedc13 7d1a82d 1cedc13 651dfe7 7cdacae 1cedc13 7cdacae 651dfe7 1cedc13 651dfe7 7cdacae 651dfe7 1cedc13 c360cac 651dfe7 1cedc13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import spaces # 必须放在最前面
import os
import numpy as np
import torch
from PIL import Image
import gradio as gr
from gradio_imageslider import ImageSlider
# Precision used when loading the model weights below.
weight_dtype = torch.float32
# Choose the compute device once at import time: CUDA when torch reports it as
# available, otherwise CPU.
# NOTE(review): the original comment describes this as "deferring CUDA
# initialization" - presumably relying on the `spaces` import; confirm.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the model components
from DAI.pipeline_all import DAIPipeline
from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, AutoTokenizer
# Hub repository providing the dereflection-specific weights (controlnet,
# unet, vae_2) and the repository providing the remaining components
# (vae, text encoder, tokenizer).
pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"

# Load the models. Every model is requested in `weight_dtype` so that all
# components keep the same precision if that setting is ever changed.
controlnet = ControlNetVAEModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="controlnet",
    torch_dtype=weight_dtype,
).to(device)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="unet",
    torch_dtype=weight_dtype,
).to(device)
vae_2 = CustomAutoencoderKL.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="vae_2",
    torch_dtype=weight_dtype,
).to(device)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path2,
    subfolder="vae",
    torch_dtype=weight_dtype,
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path2,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path2,
    subfolder="tokenizer",
    use_fast=False,
)

# Build the inference pipeline. `vae_2` is not a constructor argument; it is
# passed to the pipeline at call time instead.
pipe = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
).to(device)
# Wrap the inference function with spaces.GPU
@spaces.GPU
def process_image(input_image):
    """Remove reflections from one image and return the before/after pair.

    Args:
        input_image: The image delivered by the Gradio input component as a
            numpy array, or ``None`` when no image has been provided.

    Returns:
        A ``(original, processed)`` tuple of PIL images, in that order.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    # Gradio passes None when the button is pressed without an image; fail
    # with a readable message instead of an opaque error from Image.fromarray.
    if input_image is None:
        raise gr.Error("Please upload an image before running reflection removal.")

    # Convert the Gradio input into a PIL image
    input_image = Image.fromarray(input_image)

    # Run the pipeline
    pipe_out = pipe(
        image=input_image,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=None,
    )

    # Convert the output to an 8-bit image: clip to [-1, 1], rescale to
    # [0, 1], then scale to [0, 255].
    # NOTE(review): assumes `prediction` is a numpy array with a leading
    # batch axis - confirm against DAIPipeline.
    processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
    processed_frame = (processed_frame[0] * 255).astype(np.uint8)
    processed_frame = Image.fromarray(processed_frame)

    # Return the input image together with the processed image
    return input_image, processed_frame
# Build the Gradio interface
def create_gradio_interface():
    """Assemble the demo UI and return the ``gr.Blocks`` application.

    The page has an image input with a submit button on one side and a
    before/after slider on the other, plus a gallery of example images.
    Both the examples and the button run ``process_image``.
    """
    # Example images: files/image/1.png through files/image/8.png
    example_dir = os.path.join("files", "image")
    example_paths = [os.path.join(example_dir, f"{idx}.png") for idx in range(1, 9)]

    with gr.Blocks() as demo:
        gr.Markdown("# Dereflection Any Image")

        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Input Image", type="numpy")
                run_button = gr.Button("Remove Reflection", variant="primary")
            with gr.Column():
                # ImageSlider shows the before/after comparison
                comparison = ImageSlider(
                    label="Before & After",
                    show_download_button=True,
                    show_share_button=True,
                )

        # The examples gallery and the button share the same function wiring.
        # NOTE(review): the nesting level of these two calls is inferred
        # (placed directly inside the Blocks context) - confirm the layout.
        wiring = {
            "fn": process_image,
            "inputs": source_image,
            "outputs": comparison,
        }

        # Add the examples
        gr.Examples(
            examples=example_paths,
            cache_examples=True,  # cache results to speed up loading
            label="Example Images",
            **wiring,
        )

        # Bind the button click event
        run_button.click(**wiring)

    return demo
# Entry point
def main():
    """Build the Gradio app and serve it on 0.0.0.0, port 7860."""
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    app = create_gradio_interface()
    app.launch(**launch_options)


if __name__ == "__main__":
    main()