import gradio as gr
import torch
from PIL import Image, ImageFilter, ImageOps,ImageEnhance
from scipy.ndimage import rank_filter, maximum_filter
import numpy as np
import pathlib
import glob
import os
from diffusers import StableDiffusionControlNetPipeline, DDIMScheduler, AutoencoderKL, ControlNetModel
from ip_adapter import IPAdapter
# Markdown heading shown at the top of the demo page (passed to gr.Markdown).
DESCRIPTION = """# [FilterPrompt](https://arxiv.org/abs/2404.13263): Guiding Image Transfer in Diffusion Models
"""
##################################################################################################################
# 0. Get Pre-Models' Path Ready
##################################################################################################################
# Model identifiers handed to the various ``from_pretrained`` loaders below.
vae_model_path = "stabilityai/sd-vae-ft-mse"
base_model_path = "runwayml/stable-diffusion-v1-5"
# Relative paths handed to IPAdapter (image encoder directory and checkpoint).
image_encoder_path = "models/image_encoder"
ip_ckpt = "models/ip-adapter_sd15.bin"
controlnet_softEdge_model_path = "lllyasviel/control_v11p_sd15_softedge"
controlnet_depth_model_path = "lllyasviel/control_v11f1p_sd15_depth"
# Device string handed to IPAdapter. The demo currently runs on the CPU;
# use "cuda:0" instead to run on the first GPU.
# device = "cuda:0"
device = "cpu"
##################################################################################################################
# 1. load pipeline
##################################################################################################################
torch.cuda.empty_cache()


def _build_controlnet_pipeline(controlnet):
    """Create a Stable Diffusion ControlNet pipeline around ``controlnet``.

    Every pipeline built here uses the same base model, the module-level
    ``noise_scheduler`` and ``vae`` objects, float32 weights, and has both the
    feature extractor and the safety checker disabled.
    """
    return StableDiffusionControlNetPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float32,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )


## 1.1 noise scheduler (DDIM) shared by both pipelines
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# 1.2 VAE, loaded and then cast to float32
vae = AutoencoderKL.from_pretrained(vae_model_path)
vae = vae.to(dtype=torch.float32)

# 1.3 ControlNet weights: soft-edge first, then depth
controlnet_softEdge = ControlNetModel.from_pretrained(
    controlnet_softEdge_model_path, torch_dtype=torch.float32
)
controlnet_depth = ControlNetModel.from_pretrained(
    controlnet_depth_model_path, torch_dtype=torch.float32
)

# 1.4 one Stable Diffusion pipeline per ControlNet
pipe_softEdge = _build_controlnet_pipeline(controlnet_softEdge)
pipe_depth = _build_controlnet_pipeline(controlnet_depth)

print("1 Model loading completed !")
print("##################################################################")
def image_grid(imgs, rows, cols):
    """Tile ``imgs`` into a single RGB sheet of ``rows`` x ``cols`` cells.

    The cell size is taken from the first image; images are placed row by
    row, left to right. ``len(imgs)`` must equal ``rows * cols``.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.
    """
    assert len(imgs) == rows * cols

    cell_w, cell_h = imgs[0].size
    sheet = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(imgs):
        # divmod yields (row, column) of this tile in the row-major layout.
        row, col = divmod(index, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet
#########################################################################
## functions for task 1 : style transfer
#########################################################################
def gaussian_blur(image, blur_radius):
    """Open an image and return a Gaussian-blurred copy of it.

    Args:
        image: File path (or file object) accepted by ``Image.open``.
        blur_radius: Radius handed to ``ImageFilter.GaussianBlur``.

    Returns:
        A new image holding the blurred picture.
    """
    # The context manager releases the file handle as soon as the filtered
    # copy exists instead of leaving the close to garbage collection.
    with Image.open(image) as source:
        return source.filter(ImageFilter.GaussianBlur(radius=blur_radius))
def _sketch_to_control_image(sketch):
    """Load a sketch file and return its inverted, contrast-boosted grayscale image."""
    # Convert to RGB first so grayscale, palette and alpha-carrying sketches
    # all produce an (H, W, 3) array; slicing a 2-D grayscale array with
    # [..., :3] would cut image columns instead of selecting channels.
    with Image.open(sketch) as sketch_image:
        rgb = np.array(sketch_image.convert("RGB"))
    # Weighted channel sum -> single-channel grayscale (usual luma weights).
    gray = np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
    # Invert the intensities, then go back to 8-bit for PIL.
    inverted = (255 - gray).astype(np.uint8)
    control_image = Image.fromarray(inverted)
    contrast_factor = 2
    return ImageEnhance.Contrast(control_image).enhance(contrast_factor)


def task1_StyleTransfer(photo, blur_radius, sketch):
    """Generate an image from a blurred photo prompt and a sketch control image.

    Args:
        photo: File path of the photo used as the image prompt.
        blur_radius: Gaussian blur radius applied to ``photo`` before it is
            handed to the IP-Adapter.
        sketch: File path of the sketch; it is converted to an inverted,
            contrast-enhanced grayscale image and used as the control image
            of the depth ControlNet pipeline.

    Returns:
        ``(blurred_photo, result)``: the blurred photo and the generated
        image (a 1x1 grid of the single sample).

    Raises:
        gr.Error: If the photo or the sketch has not been provided.
    """
    if photo is None or sketch is None:
        # Surface a readable message in the UI instead of failing inside PIL.
        raise gr.Error("Please provide both an input photo and an input sketch.")

    blurred_photo = gaussian_blur(photo, blur_radius)

    # Weights forwarded to this project's IPAdapter wrapper.
    # NOTE(review): presumably the ControlNet and image-prompt strengths;
    # confirm against the ip_adapter module.
    control_factor = 1.2
    ip_factor = 0.6
    ip_model = IPAdapter(
        pipe_depth,
        image_encoder_path,
        ip_ckpt,
        device,
        Control_factor=control_factor,
        IP_factor=ip_factor,
    )

    control_image = _sketch_to_control_image(sketch)

    # Fixed seed and step count, one sample per request.
    images = ip_model.generate(
        pil_image=blurred_photo,
        image=control_image,
        num_samples=1,
        num_inference_steps=10,
        seed=52,
    )
    result = image_grid(images, 1, 1)
    # NOTE: the first return value is the blurred photo itself, not a
    # generation made from the un-blurred photo.
    return blurred_photo, result
def task1_test(photo, blur_radius, sketch):
    """Debug stand-in for the style-transfer handler.

    Prints the type of ``photo`` and echoes ``photo`` and ``sketch`` back
    unchanged; ``blur_radius`` is accepted but ignored.
    """
    print(type(photo))
    return photo, sketch
#########################################################################
## functions for task 2 : color transfer
#########################################################################
# todo
#############################################
# Demo
#############################################
# Monochrome theme with a blue primary hue; loader and slider are set to red.
theme = gr.themes.Monochrome(primary_hue="blue").set(
    loader_color="#FF0000",
    slider_color="#FF0000",
)
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(DESCRIPTION)
    # 1. UI for the first task, style transfer (bronze-ware rubbing to photo).
    with gr.Group():
        ## 1.1 Task description
        gr.Markdown(
            """
## Case 1: Style transfer
- In this task, our main goal is to achieve the style transfer from sketch to photo.
- In the original generation result, the surface of the object has redundant pattern representation from the style image.
- Next, you can control the Gaussian kernel size of GaussianBlur to weaken the expression of redundant pattern features in the generated results.
""")
        ## 1.2 Layout of the input and output components
        #### Column() controls how the components are arranged side by side.
        with gr.Row():
            # First column: photo input and blur control.
            with gr.Column():
                with gr.Row():
                    ### 1.2.1.1 Input: real photo (delivered to the handler as a file path)
                    photo = gr.Image(label="Input photo", type="filepath")
                with gr.Row():
                    ### 1.2.1.2 Gaussian blur radius control
                    gaussianKernel = gr.Slider(minimum=0, maximum=8, step=1, value=2, label="Gaussian Blur Radius")
            # Second column: sketch input and trigger button.
            with gr.Column():
                with gr.Row():
                    # 1.2.2.1 Input: sketch (delivered to the handler as a file path)
                    sketch = gr.Image(label="Input sketch", type="filepath")
                with gr.Row():
                    # 1.2.2.2 Button that starts the generation
                    task1Button = gr.Button("Preprocess")
            # Third column: shows the first value returned by the handler.
            with gr.Column():
                with gr.Row():
                    original_result_task1 = gr.Image(label="Original generation result", interactive=False, type="pil")
            # Fourth column: shows the image generated after Gaussian filtering.
            with gr.Column():
                result_image_1 = gr.Image(label="Generate results after using GaussianBlur", type="pil")
        ## 1.3 Example images
        with gr.Row():
            # Every .jpg in the example folder becomes a clickable sketch example.
            paths = sorted(pathlib.Path("images/inputExample").glob("*.jpg"))
            gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=sketch)
        with gr.Row():
            gr.Image(value="images/1_gaussian_filter.png", label=" Task example Image", type="filepath")
    # The task 1 layout is complete; wire the interaction between components.
    # task1_test can be substituted for fn when debugging the UI without
    # running the model.
    task1Button.click(
        fn=task1_StyleTransfer,
        inputs=[photo, gaussianKernel, sketch],
        outputs=[original_result_task1, result_image_1],
    )
##################################################################################################################
# 2. run Demo on gradio
##################################################################################################################
if __name__ == "__main__":
    # Enable the request queue (max_size=5), then start the Gradio server.
    demo.queue(max_size=5)
    demo.launch()