Model Description

In this paper, we make the first attempt to align diffusion models for image inpainting with human aesthetic standards via a reinforcement learning framework, significantly improving the quality and visual appeal of inpainted images. Specifically, instead of directly measuring the divergence from paired images, we train a reward model on a dataset we constructed, consisting of nearly 51,000 images annotated with human preferences. We then use reinforcement learning to fine-tune a pre-trained inpainting diffusion model so that its output distribution shifts toward higher reward. Moreover, we derive a theoretical upper bound on the error of the reward model, which indicates how confident the reward estimates are throughout the reinforcement alignment process and thereby enables accurate regularization. Our code and dataset are publicly available at https://prefpaint.github.io.

Usage

import os
from PIL import Image
from diffusers import AutoPipelineForInpainting

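# Download (or load from the local Hugging Face cache) the PrefPaint weights.
# Pass cache_dir='<your/cache/dir>' to store them somewhere other than the default cache.
pipe = AutoPipelineForInpainting.from_pretrained('kd5678/prefpaint-v1.0')
pipe = pipe.to("cuda")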

color_path = 'images.png'
mask_path = 'mask.png'
save_path = './results'
os.makedirs(save_path, exist_ok=True)

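# Load the input as RGB and the mask as single-channel grayscale.
image = Image.open(color_path).convert('RGB')
mask = Image.open(mask_path).convert('L')
# You can provide your prompt here; this example uses a blank prompt.
prompt = " "
# eta is forwarded to the scheduler and only affects DDIM-style schedulers.
result = pipe(prompt=prompt, image=image, mask_image=mask, eta=1.0).images[0]
result.save(os.path.join(save_path, 'results.png'))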
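
The snippet above assumes the mask already matches the input image and that the saved result can be used as is. The following is a minimal sketch, not part of the released code, of two optional helpers built only on Pillow. The names prepare_mask and composite_result and the threshold of 128 are hypothetical choices for illustration. The sketch also assumes the usual diffusers inpainting convention in which white mask pixels are repainted and black pixels are preserved.

```python
from PIL import Image


def prepare_mask(image, mask, threshold=128):
    """Return an 'L' mask sized like `image`, with every pixel snapped to 0 or 255.

    `threshold` is an illustrative assumption, not a value from the model card.
    """
    mask = mask.convert("L")
    if mask.size != image.size:
        # NEAREST keeps the mask hard-edged instead of introducing gray values.
        mask = mask.resize(image.size, Image.NEAREST)
    return mask.point(lambda v: 255 if v >= threshold else 0)


def composite_result(original, result, mask):
    """Take generated pixels where the mask is white and original pixels elsewhere."""
    original = original.convert("RGB")
    result = result.convert("RGB")
    if result.size != original.size:
        # The pipeline may return a different resolution than the input.
        result = result.resize(original.size, Image.LANCZOS)
    return Image.composite(result, original, mask)
```

In the usage example, one could call mask = prepare_mask(image, mask) before running the pipeline and final = composite_result(image, result, mask) afterwards, so that pixels outside the masked region are copied from the input rather than from the generated output.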
  

How to Cite

@article{liu2024prefpaint,
  title={PrefPaint: Aligning Image Inpainting Diffusion Model with Human Preference},
  author={Liu, Kendong and Zhu, Zhiyu and Li, Chuanhao and Liu, Hui and Zeng, Huanqiang and Hou, Junhui},
  journal={arXiv preprint arXiv:2410.21966},
  year={2024}
}