Meaowangxi committed (verified) · Commit a9967fb · Parent(s): 1cae769

initial test

Files changed (1): app.py (+380, -0)
app.py ADDED
@@ -0,0 +1,380 @@
+ import gradio as gr
+ import torch
+ from PIL import Image, ImageFilter, ImageOps, ImageEnhance
+ from scipy.ndimage import rank_filter, maximum_filter
+ from skimage.filters import gabor
+ import skimage.color
+ import numpy as np
+ import pathlib
+ import glob
+ import os
+ from diffusers import StableDiffusionControlNetPipeline, DDIMScheduler, AutoencoderKL, ControlNetModel
+ from ip_adapter import IPAdapter
+
+
+ DESCRIPTION = """# [FilterPrompt](https://arxiv.org/abs/2404.13263): Guiding Image Transfer in Diffusion Models
+ <img id="teaser" alt="teaser" src="https://raw.githubusercontent.com/Meaoxixi/FilterPrompt/gh-pages/resources/teaser.png" />
+ """
+ # <img id="overview" alt="overview" src="https://github.com/Meaoxixi/FilterPrompt/blob/gh-pages/resources/teaser.png" />
+ # Note: to fetch the raw file behind a GitHub page link, replace github.com with raw.githubusercontent.com and drop the blob/<branch> segment from the URL.
+ ##################################################################################################################
+ # 0. Get the pre-trained models' paths ready
+ ##################################################################################################################
+ base_model_path = "models/stable-diffusion-v1-5"
+ vae_model_path = "models/sd-vae-ft-mse"
+ image_encoder_path = "models/IP-Adapter/image_encoder"
+ ip_ckpt = "models/IP-Adapter/ip-adapter_sd15.bin"
+ controlnet_softEdge_model_path = "models/ControlNet/ControlNet_softEdge"
+ controlnet_depth_model_path = "models/ControlNet/ControlNet_depth"
+ device = "cuda:0"
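The device string is hard-coded to "cuda:0". If there is any chance the Space falls back to CPU hardware, a guarded assignment is a small optional variant (a sketch, not taken from the commit; the float16 weights loaded below still assume a GPU):

    # Optional fallback: prefer the first GPU, otherwise run on CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"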
+ ##################################################################################################################
+ # 1. load pipeline
+ ##################################################################################################################
+ torch.cuda.empty_cache()
+ ## 1.1 noise_scheduler
+ noise_scheduler = DDIMScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     beta_schedule="scaled_linear",
+     clip_sample=False,
+     set_alpha_to_one=False,
+     steps_offset=1,
+ )
+ # 1.2 vae
+ vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
+ # 1.3 ControlNet
+ ## 1.3.1 load controlnet_softEdge
+ controlnet_softEdge = ControlNetModel.from_pretrained(controlnet_softEdge_model_path, torch_dtype=torch.float16)
+ ## 1.3.2 load controlnet_depth
+ controlnet_depth = ControlNetModel.from_pretrained(controlnet_depth_model_path, torch_dtype=torch.float16)
+ # 1.4 load SD pipeline
+ pipe_softEdge = StableDiffusionControlNetPipeline.from_pretrained(
+     base_model_path,
+     controlnet=controlnet_softEdge,
+     torch_dtype=torch.float16,
+     scheduler=noise_scheduler,
+     vae=vae,
+     feature_extractor=None,
+     safety_checker=None
+ )
+ pipe_depth = StableDiffusionControlNetPipeline.from_pretrained(
+     base_model_path,
+     controlnet=controlnet_depth,
+     torch_dtype=torch.float16,
+     scheduler=noise_scheduler,
+     vae=vae,
+     feature_extractor=None,
+     safety_checker=None
+ )
+ print("1 Model loading completed !")
+ print("##################################################################")
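Both pipelines are built from the same base model, VAE, and scheduler, so from_pretrained loads the SD-1.5 weights twice. A possible memory-saving variant, sketched here and not taken from the commit (the components property and this constructor signature depend on the installed diffusers version), reuses the first pipeline's sub-modules and only swaps the ControlNet:

    # Sketch: build the depth pipeline from the soft-edge pipeline's components
    # so the UNet and text encoder are only loaded once.
    shared = dict(pipe_softEdge.components)      # vae, unet, text_encoder, tokenizer, ...
    shared["controlnet"] = controlnet_depth      # swap in the depth ControlNet
    pipe_depth = StableDiffusionControlNetPipeline(**shared)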
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows * cols
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
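image_grid expects exactly rows * cols equally sized PIL images; in this app it is only ever called with a single sample. A tiny illustrative call with hypothetical inputs:

    # Hypothetical usage: tile two 512x512 images into one 1024x512 preview.
    left = Image.new("RGB", (512, 512), "white")
    right = Image.new("RGB", (512, 512), "black")
    preview = image_grid([left, right], rows=1, cols=2)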
+ #########################################################################
+ # Next come the function definitions used by the demo
+ ## functions for task 1: style transfer
+ #########################################################################
+ def gaussian_blur(image, blur_radius):
+     image = Image.open(image)
+     blurred_image = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+     return blurred_image
+
+ def task1_StyleTransfer(photo, blur_radius, sketch):
+     photoImage = Image.open(photo)
+     blurPhoto = gaussian_blur(photo, blur_radius)
+
+     Control_factor = 1.2
+     IP_factor = 0.6
+     ip_model = IPAdapter(pipe_depth, image_encoder_path, ip_ckpt, device, Control_factor=Control_factor, IP_factor=IP_factor)
+
+     depth_image = Image.open(sketch)
+     img_array = np.array(depth_image)
+     gray_img_array = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
+     # invert
+     inverted_array = 255 - gray_img_array
+     gray_img_array = inverted_array.astype(np.uint8)
+     processed_image = Image.fromarray(gray_img_array)
+     contrast_factor = 2
+     enhancer = ImageEnhance.Contrast(processed_image)
+     processed_image = enhancer.enhance(contrast_factor)
+
+     images = ip_model.generate(pil_image=photoImage, image=processed_image, num_samples=1, num_inference_steps=30, seed=52)
+     original = image_grid(images, 1, 1)
+     images = ip_model.generate(pil_image=blurPhoto, image=processed_image, num_samples=1, num_inference_steps=30, seed=52)
+     result = image_grid(images, 1, 1)
+
+     return original, result
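The sketch preprocessing inside task1_StyleTransfer (luma grayscale, inversion so dark strokes become bright, then a contrast boost) is what turns the sketch into a depth-like hint for the ControlNet branch. A minimal sketch of the same steps factored into a helper (hypothetical name, assuming the imports already present in this file):

    def sketch_to_control_hint(sketch_path, contrast_factor=2):
        # Drop any alpha channel, convert to luma grayscale, invert, then boost contrast.
        rgb = np.array(Image.open(sketch_path))[..., :3]
        gray = np.dot(rgb, [0.2989, 0.5870, 0.1140])
        inverted = (255 - gray).astype(np.uint8)
        hint = Image.fromarray(inverted)
        return ImageEnhance.Contrast(hint).enhance(contrast_factor)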
+
+ def task1_test(photo, blur_radius, sketch):
+     original = photo
+     print(type(original))
+     # <class 'str'>
+     result = sketch
+     return original, result
+ #########################################################################
+ ## functions for task 2: color transfer
+ #########################################################################
+ # filter function definitions
+ def desaturate_filter(image):
+     image = Image.open(image)
+     return ImageOps.grayscale(image)
+
+ def gabor_filter(image):
+     image = Image.open(image)
+     image_array = np.array(image.convert('L'))  # convert to grayscale
+     filtered_real, filtered_imag = gabor(image_array, frequency=0.6)
+     filtered_image = np.sqrt(filtered_real**2 + filtered_imag**2)
+     return Image.fromarray(np.uint8(filtered_image))
+
+ def rank_filter_func(image):
+     image = Image.open(image)
+     image_array = np.array(image.convert('L'))
+     filtered_image = rank_filter(image_array, rank=5, size=5)
+     return Image.fromarray(np.uint8(filtered_image))
+
+ def max_filter_func(image):
+     image = Image.open(image)
+     image_array = np.array(image.convert('L'))
+     filtered_image = maximum_filter(image_array, size=20)
+     return Image.fromarray(np.uint8(filtered_image))
+ # dispatch function: apply the selected filter to both images
+ def fun2(image, image2, filter_name):
+     if filter_name == "Desaturate Filter":
+         return desaturate_filter(image), desaturate_filter(image2)
+     elif filter_name == "Gabor Filter":
+         return gabor_filter(image), gabor_filter(image2)
+     elif filter_name == "Rank Filter":
+         return rank_filter_func(image), rank_filter_func(image2)
+     elif filter_name == "Max Filter":
+         return max_filter_func(image), max_filter_func(image2)
+     else:
+         return image, image2
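One caveat worth noting: the Gabor magnitude np.sqrt(real**2 + imag**2) is not bounded to [0, 255], so the plain np.uint8 cast can wrap. A small optional variant (a sketch, not taken from the commit) rescales before the cast:

    import numpy as np
    from PIL import Image
    from skimage.filters import gabor

    def gabor_filter_normalized(image_path, frequency=0.6):
        image_array = np.array(Image.open(image_path).convert('L'))
        real, imag = gabor(image_array, frequency=frequency)
        magnitude = np.hypot(real, imag)
        magnitude = 255.0 * magnitude / max(magnitude.max(), 1e-8)  # rescale to [0, 255]
        return Image.fromarray(magnitude.astype(np.uint8))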
+
+
+ #############################################
+ # Demo page
+ #############################################
+ # the style is controlled via custom CSS
+ #with gr.Blocks(theme=gr.themes.Soft()) as demo:
+ with gr.Blocks(css="style.css") as demo:
+     # 0. Title and teaser first
+     gr.Markdown(DESCRIPTION)
+
+     # 1. UI code for the first task, style transfer (bronze-ware rubbing to photo)
+     with gr.Group():
+         ## 1.1 Task description
+         gr.Markdown(
+             """
+ ## Case 1: Style transfer
+ - In this task, our main goal is to achieve the style transfer from sketch to photo.
+ - In the original generation result, the surface of the object shows redundant pattern features carried over from the style image.
+ - Next, you can adjust the Gaussian blur radius to weaken the expression of these redundant pattern features in the generated result.
+ """)
+         ## 1.2 Input/output widget layout
+         #### Column() controls how the widgets are arranged into columns
+         with gr.Row():
+             # first column
+             with gr.Column():
+                 with gr.Row():
+                     ### 1.2.1.1 input real photo
+                     photo = gr.Image(label="Input photo", type="filepath")
+                     print(photo)
+                     print(type(photo))
+                 with gr.Row():
+                     ### 1.2.1.2 Gaussian kernel widget
+                     gaussianKernel = gr.Slider(minimum=0, maximum=8, step=1, value=2, label="Gaussian Blur Radius")
+             # second column
+             with gr.Column():
+                 with gr.Row():
+                     # 1.2.2.1 input sketch
+                     sketch = gr.Image(label="Input sketch", type="filepath")
+                     #print(sketch)
+                 with gr.Row():
+                     # 1.2.2.2 button: start generating images
+                     task1Button = gr.Button("Preprocess")
+             # third column: show the original generation result
+             with gr.Column():
+                 with gr.Row():
+                     original_result_task1 = gr.Image(label="Original generation result", interactive=False, type="pil")
+             # fourth column: show the result generated after Gaussian blur
+             with gr.Column():
+                 result_image_1 = gr.Image(label="Generated result after using GaussianBlur", type="pil")
+
+         ## 1.3 Example images
+         with gr.Row():
+             paths = sorted(pathlib.Path("images/inputExample").glob("*.jpg"))
+             gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=sketch)
+         with gr.Row():
+             gr.Image(value="images/1_gaussian_filter.png", label="Task example Image", type="filepath")
+
+     # 1. the task 1 (style transfer) UI is done; now wire up the interaction between the widgets
+     task1Button.click(
+         fn=task1_StyleTransfer,
+         #fn=task1_test,
+         inputs=[photo, gaussianKernel, sketch],
+         outputs=[original_result_task1, result_image_1],
+     )
+     #
+     # # 2. Second task: strengthening the protection of geometric properties - color transfer
+     # with gr.Group():
+     #     ## 2.1 Task description
+     #     gr.Markdown(
+     #         """
+     # ## Case 2: Color transfer
+     # - In this task, our main goal is to transfer color from image A to image B. You can see how the filter affects the preservation of geometric properties.
+     # - In the standard ControlNet-depth mode, the ideal input is the depth map.
+     # - Here, we choose to feed the network the result processed by some filters instead of the original depth map.
+     # - By trying different filters you can see that "decolorization + inversion + contrast enhancement" retains the most detailed geometric information from the original image.
+     # """)
+     #     ## 2.2 Input/output widget layout
+     #     with gr.Row():
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 input_appearImage = gr.Image(label="Input Appearance Image", type="filepath")
+     #             with gr.Row():
+     #                 filter_dropdown = gr.Dropdown(
+     #                     choices=["Desaturate Filter", "Gabor Filter", "Rank Filter", "Max Filter"],
+     #                     label="Select Filter",
+     #                     value="Desaturate Filter"
+     #                 )
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 input_strucImage = gr.Image(label="Input Structure Image", type="filepath")
+     #             with gr.Row():
+     #                 geometry_button = gr.Button("Preprocess")
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 afterFilterImage = gr.Image(label="Appearance image after the chosen filter", interactive=False)
+     #         with gr.Column():
+     #             result_task2 = gr.Image(label="Generate results")
+     #     # instyle = gr.State()
+     #
+     #     ## 2.3 Example images
+     #     with gr.Row():
+     #         gr.Image(value="task/color_transfer.png", label="example Image", type="filepath")
+     #
+     # # 3. Third task: improving the lighting effect
+     # with gr.Group():
+     #     ## 3.1 Task description
+     #     gr.Markdown(
+     #         """
+     # ## Case 3: Image-to-Image translation
+     # - In this example, our goal is to turn a simple outline drawing/sketch into a detailed and realistic photo.
+     # - Here, we provide the original mask generation results, as well as the results generated after superimposing the image on the mask and passing it through the decolorization filter.
+     # - From this, you can see that the mask obtained by the decolorization operation retains a certain amount of the original lighting information and improves the texture of the generated results.
+     # """)
+     #     ## 3.2 Input/output widget layout
+     #     with gr.Row():
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 input_appearImage = gr.Image(label="Input Appearance Image", type="filepath")
+     #             with gr.Row():
+     #                 filter_dropdown = gr.Dropdown(
+     #                     choices=["Desaturate Filter", "Gabor Filter", "Rank Filter", "Max Filter"],
+     #                     label="Select Filter",
+     #                     value="Desaturate Filter"
+     #                 )
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 input_strucImage = gr.Image(label="Input Structure Image", type="filepath")
+     #             with gr.Row():
+     #                 geometry_button = gr.Button("Preprocess")
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 afterFilterImage = gr.Image(label="Appearance image after the chosen filter", interactive=False)
+     #         with gr.Column():
+     #             result_task2 = gr.Image(label="Generate results")
+     #     # instyle = gr.State()
+     #
+     #     ## 3.3 Example images
+     #     with gr.Row():
+     #         gr.Image(value="task/4_light.jpg", label="example Image", type="filepath")
+     #
+     # # 4. Fourth task: materials transfer
+     # with gr.Group():
+     #     ## 4.1 Task description
+     #     gr.Markdown(
+     #         """
+     # ## Case 4: Materials Transfer
+     # - In this example, our goal is to transfer the material appearance of one object image to another object image. The process involves changing the surface properties of objects in the image so that they appear to be made of another material.
+     # - Here, we provide the original generation results together with a variety of filters for editing.
+     # - You can pick any filtering operation and directly observe how the filtering affects the rendering properties of the generated results.
+     # - For example, a sharpen filter can sharpen the texture of a stone, a Gaussian blur can smooth it, and a custom filter can change its style. These all show that FilterPrompt is simple and intuitive.
+     # """)
+     #     ## 4.2 Input/output widget layout
+     #     with gr.Row():
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 input_appearImage = gr.Image(label="Input Appearance Image", type="filepath")
+     #             with gr.Row():
+     #                 filter_dropdown = gr.Dropdown(
+     #                     choices=["Desaturate Filter", "Gabor Filter", "Rank Filter", "Max Filter"],
+     #                     label="Select Filter",
+     #                     value="Desaturate Filter"
+     #                 )
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 input_strucImage = gr.Image(label="Input Structure Image", type="filepath")
+     #             with gr.Row():
+     #                 geometry_button = gr.Button("Preprocess")
+     #         with gr.Column():
+     #             with gr.Row():
+     #                 afterFilterImage = gr.Image(label="Appearance image after the chosen filter", interactive=False)
+     #         with gr.Column():
+     #             result_task2 = gr.Image(label="Generate results")
+     #     # instyle = gr.State()
+     #
+     #     ## 4.3 Example images
+     #     with gr.Row():
+     #         gr.Image(value="task/3mateialsTransfer.jpg", label="example Image", type="filepath")
+     #
+     #
+     # geometry_button.click(
+     #     fn=fun2,
+     #     inputs=[input_strucImage, input_appearImage, filter_dropdown],
+     #     outputs=[afterFilterImage, result_task2],
+     # )
+     # aligned_face.change(
+     #     fn=model.reconstruct_face,
+     #     inputs=[aligned_face, encoder_type],
+     #     outputs=[
+     #         reconstructed_face,
+     #         instyle,
+     #     ],
+     # )
+     # style_type.change(
+     #     fn=update_slider,
+     #     inputs=style_type,
+     #     outputs=style_index,
+     # )
+     # style_type.change(
+     #     fn=update_style_image,
+     #     inputs=style_type,
+     #     outputs=style_image,
+     # )
+     # generate_button.click(
+     #     fn=model.generate,
+     #     inputs=[
+     #         style_type,
+     #         style_index,
+     #         structure_weight,
+     #         color_weight,
+     #         structure_only,
+     #         instyle,
+     #     ],
+     #     outputs=result,
+     # )
+ ##################################################################################################################
+ # 2. run Demo on gradio
+ ##################################################################################################################
+
+ if __name__ == "__main__":
+     demo.queue(max_size=5).launch()
+     #demo.queue(max_size=5).launch(server_port=12345)
+     #demo.queue(max_size=5).launch(server_port=12345, share=True)
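The commented-out launch lines above already hint at the available options; for completeness, a hedged example of a non-default launch using standard Blocks.launch arguments (the values here are hypothetical):

    # Bind to all interfaces on a fixed port without creating a public share link.
    demo.queue(max_size=5).launch(server_name="0.0.0.0", server_port=7860, share=False)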