# File size: 8,432 Bytes
# 39a23e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
import torch
from PIL import Image, ImageFilter, ImageOps,ImageEnhance
from scipy.ndimage import rank_filter, maximum_filter
import numpy as np
import pathlib
import glob
import os
from diffusers import StableDiffusionControlNetPipeline, DDIMScheduler, AutoencoderKL, ControlNetModel
from ip_adapter import IPAdapter


# Markdown header shown by the demo: paper title (linked to arXiv) and the teaser image.
DESCRIPTION = """# [FilterPrompt](https://arxiv.org/abs/2404.13263): Guiding Image Transfer in Diffusion Models
<img id="teaser" alt="teaser" src="https://raw.githubusercontent.com/Meaoxixi/FilterPrompt/gh-pages/resources/teaser.png" />
"""
##################################################################################################################
# 0. Get Pre-Models' Path Ready
##################################################################################################################
# `from_pretrained` resolves either a local directory or a Hub repo id
# ("namespace/name"). Browser URLs ending in "/tree/main" are neither, so the
# diffusers checkpoints are referenced by repo id.
vae_model_path = "stabilityai/sd-vae-ft-mse"
base_model_path = "runwayml/stable-diffusion-v1-5"
# NOTE(review): these two values are handed unchanged to `IPAdapter(...)`. They
# are still web-page URLs; the adapter presumably needs a local image-encoder
# directory and a local checkpoint file - confirm and point them at downloaded
# copies of h94/IP-Adapter (models/image_encoder, models/ip-adapter_sd15.bin).
image_encoder_path = "https://huggingface.co/h94/IP-Adapter/tree/main/models/image_encoder"
ip_ckpt = "https://huggingface.co/h94/IP-Adapter/tree/main/models/ip-adapter_sd15.bin"
controlnet_softEdge_model_path = "lllyasviel/control_v11p_sd15_softedge"
controlnet_depth_model_path = "lllyasviel/control_v11f1p_sd15_depth"
# Device passed to `IPAdapter`; it was previously commented out, which left the
# name undefined where it is used.
# NOTE(review): the models are loaded in float16, so the CPU fallback may not
# be usable for actual inference - confirm.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
##################################################################################################################
# 1. load pipeline
##################################################################################################################
torch.cuda.empty_cache()

## 1.1 noise_scheduler
# One DDIM scheduler instance, passed to both pipelines built below.
_scheduler_settings = dict(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
noise_scheduler = DDIMScheduler(**_scheduler_settings)

#  1.2 vae
# Loaded once, cast to half precision, and passed to both pipelines.
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)

#  1.3 ControlNet
## 1.3.1 load controlnet_softEdge
controlnet_softEdge = ControlNetModel.from_pretrained(
    controlnet_softEdge_model_path,
    torch_dtype=torch.float16,
)
## 1.3.2 load controlnet_depth
controlnet_depth = ControlNetModel.from_pretrained(
    controlnet_depth_model_path,
    torch_dtype=torch.float16,
)

# 1.4 load SD pipeline
# Everything except the ControlNet is identical between the two pipelines.
_shared_pipeline_kwargs = dict(
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None,
)
pipe_softEdge = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet_softEdge,
    **_shared_pipeline_kwargs,
)
pipe_depth = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet_depth,
    **_shared_pipeline_kwargs,
)
print("1 Model loading completed !")
print("##################################################################")
def image_grid(imgs, rows, cols):
    """Paste images into a single ``rows`` x ``cols`` RGB grid, row by row.

    The tile size is taken from the first image; every image is pasted at a
    multiple of that size (assumes all images share it).

    Args:
        imgs: Non-empty sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks rather than `assert`, which disappears under `python -O`
    # and previously let an empty list through to `imgs[0]`.
    if not imgs:
        raise ValueError("image_grid() needs at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
#########################################################################
## functions for task 1 : style transfer
#########################################################################
def gaussian_blur(image, blur_radius):
    """Open the image at ``image`` and return a Gaussian-blurred copy.

    Args:
        image: Anything ``PIL.Image.open`` accepts (the callers in this file
            pass a file path).
        blur_radius: Radius forwarded to ``ImageFilter.GaussianBlur``.

    Returns:
        The filtered PIL image.
    """
    source = Image.open(image)
    return source.filter(ImageFilter.GaussianBlur(radius=blur_radius))

def task1_StyleTransfer(photo, blur_radius, sketch):
    """Run task 1 (style transfer) twice: with the raw photo and with a blurred photo.

    The sketch is turned into an inverted, contrast-boosted grayscale control
    image for the depth ControlNet pipeline. The photo (raw, then
    Gaussian-blurred) is the IP-Adapter image prompt. Both runs use the same
    seed, so they differ only in the prompt image.

    Args:
        photo: File path of the photo (the Gradio input uses ``type="filepath"``).
        blur_radius: Gaussian blur radius applied to ``photo`` for the second run.
        sketch: File path of the sketch image.

    Returns:
        ``(original, result)``: the image generated from the unblurred photo
        and the image generated from the blurred photo.

    Raises:
        gr.Error: If the photo or the sketch has not been provided.
    """
    # Gradio passes None for an empty image input; fail with a readable message
    # instead of an opaque error from inside PIL.
    if not photo or not sketch:
        raise gr.Error("Please provide both an input photo and an input sketch.")

    photoImage = Image.open(photo)
    blurPhoto = gaussian_blur(photo, blur_radius)

    Control_factor = 1.2
    IP_factor = 0.6
    # NOTE(review): `device` is read from module scope and must be defined
    # there before the first call.
    ip_model = IPAdapter(pipe_depth, image_encoder_path, ip_ckpt, device, Control_factor=Control_factor, IP_factor=IP_factor)

    # Force three channels first: grayscale, palette or alpha-bearing sketches
    # would otherwise not have the (H, W, 3) layout the luminance dot product needs.
    sketch_rgb = np.array(Image.open(sketch).convert("RGB"))
    gray_img_array = np.dot(sketch_rgb, [0.2989, 0.5870, 0.1140])
    # Invert
    inverted_array = (255 - gray_img_array).astype(np.uint8)
    contrast_factor = 2
    processed_image = ImageEnhance.Contrast(Image.fromarray(inverted_array)).enhance(contrast_factor)

    # Identical settings for both runs; only the image prompt changes.
    generate_kwargs = dict(image=processed_image, num_samples=1, num_inference_steps=30, seed=52)
    original = image_grid(ip_model.generate(pil_image=photoImage, **generate_kwargs), 1, 1)
    result = image_grid(ip_model.generate(pil_image=blurPhoto, **generate_kwargs), 1, 1)

    return original, result

def task1_test(photo, blur_radius, sketch):
    """Debug stand-in for the task 1 callback: echo the inputs back.

    Prints the runtime type of ``photo`` (e.g. ``<class 'str'>`` for a file
    path) and returns ``(photo, sketch)`` unchanged. ``blur_radius`` is ignored.
    """
    print(type(photo))
    return photo, sketch
#########################################################################
## functions for task 2 : color transfer
#########################################################################
# todo

#############################################
# Demo
#############################################
# Monochrome theme with a blue primary hue; loader and slider colors are overridden with red.
_theme_overrides = {
    "loader_color": "#FF0000",
    "slider_color": "#FF0000",
}
theme = gr.themes.Monochrome(primary_hue="blue").set(**_theme_overrides)

with gr.Blocks(theme=theme) as demo:
    gr.Markdown(DESCRIPTION)

    # 1. UI for the first task, Style Transfer (bronze-ware rubbing to photo)
    with gr.Group():
        ## 1.1 Task description
        gr.Markdown(
            """
            ## Case 1: Style transfer
                - In this task, our main goal is to achieve the style transfer from sketch to photo.
                - In the original generation result, the surface of the object has redundant pattern representation from the style image.
                - Next, you can control the Gaussian kernel size of GaussianBlur to weaken the expression of redundant pattern features in the generated results.
            """)
        ## 1.2 Layout of the input / output components
        #### Column() controls how the components are arranged side by side
        with gr.Row():
            # First column
            with gr.Column():
                with gr.Row():
                    ### 1.2.1.1 Input: real photo (delivered to the callback as a file path)
                    photo = gr.Image(label="Input photo", type="filepath")
                with gr.Row():
                    ### 1.2.1.2 Gaussian blur radius control
                    gaussianKernel = gr.Slider(minimum=0, maximum=8, step=1, value=2, label="Gaussian Blur Radius")
            # Second column
            with gr.Column():
                with gr.Row():
                    # 1.2.2.1 Input: sketch (delivered to the callback as a file path)
                    sketch = gr.Image(label="Input sketch", type="filepath")
                with gr.Row():
                    # 1.2.2.2 Button: start generating the images
                    task1Button = gr.Button("Preprocess")
            # Third column: shows the initial generation result
            with gr.Column():
                with gr.Row():
                    original_result_task1 = gr.Image(label="Original generation result", interactive=False, type="pil")
            # Fourth column: shows the generation result after Gaussian filtering
            with gr.Column():
                result_image_1 = gr.Image(label="Generate results after using GaussianBlur",type="pil")

        ## 1.3 Example images
        with gr.Row():
            # Every .jpg under images/inputExample becomes a clickable example for the sketch input.
            paths = sorted(pathlib.Path("images/inputExample").glob("*.jpg"))
            gr.Examples(examples=[[path.as_posix()] for path in paths], inputs = sketch)
        with gr.Row():
            gr.Image(value="images/1_gaussian_filter.png", label=" Task example Image", type="filepath")

    # The task 1 (style transfer) layout is complete; wire up the interaction between components.
    # `task1_test` has the same signature and can be swapped in as `fn` to
    # exercise the UI without running the models.
    task1Button.click(
        fn=task1_StyleTransfer,
        inputs=[photo, gaussianKernel, sketch],
        outputs=[original_result_task1, result_image_1],
    )

##################################################################################################################
# 2. run Demo on gradio
##################################################################################################################

if __name__ == "__main__":
    # Enable queuing (max_size=5) on the demo, then start serving it.
    queued_demo = demo.queue(max_size=5)
    queued_demo.launch()