
DiffEdit

[[open-in-colab]]

Image editing typically requires providing a mask of the area to be edited. DiffEdit generates the mask automatically from a text query, so you can create a mask without any image editing software. The DiffEdit algorithm works in three steps:

  1. The diffusion model denoises the image conditioned on a query text and on a reference text. This produces different noise estimates for different regions of the image, and the difference between them is used to infer a mask that identifies which regions of the image must change to match the query text.
  2. The input image is encoded into latent space with DDIM.
  3. The latents are decoded with the diffusion model conditioned on the text query. The mask serves as a guide, so pixels outside the mask stay the same as in the input image.
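Step 1 is the least familiar part, so here is a toy, pure-Python sketch of the idea. It averages the absolute difference between query-conditioned and reference-conditioned noise estimates over several noise samples, clamps the result at a multiple of its mean, rescales it to [0, 1], and binarizes it at 0.5. This is an illustration under simplifying assumptions, not the diffusers implementation. The function name `infer_edit_mask`, the single-channel 2D grids, and the `thresholding_ratio` default are all hypothetical. `generate_mask` in diffusers operates on multi-channel latents and exposes its own thresholding options.

```python
def infer_edit_mask(query_noise, reference_noise, thresholding_ratio=3.0):
    """Toy version of DiffEdit step 1 on single-channel 2D grids (hypothetical helper).

    query_noise / reference_noise: lists of 2D grids, one grid per noise sample,
    holding made-up noise estimates conditioned on the query text and on the
    reference text. Returns a binary mask with 1 where the image should change.
    """
    num_samples = len(query_noise)
    height, width = len(query_noise[0]), len(query_noise[0][0])

    # 1) Average |eps_query - eps_reference| over the noise samples.
    diff = [[0.0] * width for _ in range(height)]
    for q, r in zip(query_noise, reference_noise):
        for i in range(height):
            for j in range(width):
                diff[i][j] += abs(q[i][j] - r[i][j]) / num_samples

    # 2) Clamp at a multiple of the mean difference so outliers don't dominate.
    mean_diff = sum(map(sum, diff)) / (height * width)
    clamp = mean_diff * thresholding_ratio
    if clamp == 0:  # identical estimates: nothing to edit
        return [[0] * width for _ in range(height)]

    # 3) Rescale to [0, 1] and binarize at 0.5.
    return [[1 if min(d, clamp) / clamp > 0.5 else 0 for d in row] for row in diff]


# Usage: one noise sample where only the bottom-right value reacts to the query text.
query = [[[0.0, 0.0], [0.0, 2.0]]]
reference = [[[0.0, 0.0], [0.0, 0.0]]]
print(infer_edit_mask(query, reference))  # [[0, 0], [0, 1]]
```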

This guide explains how to use DiffEdit to edit images without creating a mask manually.

Before you begin, make sure you have the following libraries installed:

# Uncomment to install the necessary libraries in Colab
#!pip install -q diffusers transformers accelerate

The [StableDiffusionDiffEditPipeline] requires an image mask and a set of partially inverted latents. The image mask is generated by the [~StableDiffusionDiffEditPipeline.generate_mask] function, which takes two parameters, source_prompt and target_prompt. These parameters determine what to edit in the image. For example, to change a bowl of fruits into a bowl of pears:

source_prompt = "a bowl of fruits"
target_prompt = "a bowl of pears"

The partially inverted latents are generated by the [~StableDiffusionDiffEditPipeline.invert] function. It is generally a good idea to include a prompt or caption describing the image, because it helps guide the inverse latent sampling process. The caption can often be your source_prompt, but feel free to experiment with other text descriptions!
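To see what "inverting" an image means, the following scalar sketch runs a deterministic (eta = 0) DDIM update in reverse, from clean to noisy, and then samples back. It keeps every intermediate latent, which mirrors the idea that the pipeline works with a *set* of inverted latents. The alpha-bar schedule, the `constant_eps` stand-in for the noise-prediction network, and the helper names are hypothetical. With a real network the noise prediction changes between steps, so inversion is only approximate. A constant predictor makes the round trip exact, which is what the usage below checks.

```python
import math


def ddim_step(x, eps, alpha_from, alpha_to):
    """Deterministic DDIM update from noise level alpha_from to alpha_to.

    alpha_* are cumulative alpha-bar values in (0, 1]. Moving to a smaller
    alpha-bar adds noise (inversion); moving to a larger one removes it (sampling).
    """
    x0_pred = (x - math.sqrt(1.0 - alpha_from) * eps) / math.sqrt(alpha_from)
    return math.sqrt(alpha_to) * x0_pred + math.sqrt(1.0 - alpha_to) * eps


def ddim_invert(x0, eps_fn, alpha_bars):
    """Run DDIM from clean to noisy and keep every intermediate latent."""
    x, trajectory = x0, []
    for a_from, a_to in zip(alpha_bars[:-1], alpha_bars[1:]):
        x = ddim_step(x, eps_fn(x, a_from), a_from, a_to)
        trajectory.append(x)
    return trajectory


def ddim_sample(x_t, eps_fn, alpha_bars):
    """Run DDIM from noisy back to clean over the same schedule."""
    x = x_t
    rev = alpha_bars[::-1]
    for a_from, a_to in zip(rev[:-1], rev[1:]):
        x = ddim_step(x, eps_fn(x, a_from), a_from, a_to)
    return x


def constant_eps(x, alpha_bar):
    """Hypothetical stand-in for the noise-prediction network."""
    return 0.3


alpha_bars = [1.0, 0.9, 0.7, 0.5]  # hypothetical schedule, clean -> noisy
latents = ddim_invert(0.8, constant_eps, alpha_bars)
restored = ddim_sample(latents[-1], constant_eps, alpha_bars)
print(len(latents), round(restored, 6))  # 3 0.8
```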

Load the pipeline, the scheduler, and the inverse scheduler, and enable some optimizations to reduce memory usage:

import torch
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16,
    safety_checker=None,
    use_safetensors=True,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

Load the image to edit:

from diffusers.utils import load_image, make_image_grid

img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
raw_image = load_image(img_url).resize((768, 768))
raw_image

Use the [~StableDiffusionDiffEditPipeline.generate_mask] function to generate the image mask. Pass it the source_prompt and target_prompt to specify what to edit in the image:

from PIL import Image

source_prompt = "a bowl of fruits"
target_prompt = "a basket of pears"
mask_image = pipeline.generate_mask(
    image=raw_image,
    source_prompt=source_prompt,
    target_prompt=target_prompt,
)
Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
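The mask returned by generate_mask is at the latent resolution. For Stable Diffusion's 8× VAE downsampling, a 768×768 input gives a 96×96 mask, which is why the line above resizes it before viewing. The pure-Python sketch below shows the two operations that line performs: scaling {0, 1} values to 8-bit gray levels, then upsampling. The helper names are hypothetical. Nearest-neighbour upsampling is used here for simplicity. PIL's `Image.resize` applies its own default resampling filter, so the real preview may have softened edges.

```python
def mask_to_uint8(mask):
    """Scale a [0, 1] mask to 0..255 integers, like (mask * 255).astype("uint8")."""
    return [[int(v * 255) for v in row] for row in mask]


def upscale_nearest(grid, factor):
    """Nearest-neighbour upsampling: repeat every cell factor x factor times."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(factor)]
        out.extend(list(wide) for _ in range(factor))
    return out


latent_mask = [[0, 1], [1, 0]]  # hypothetical 2x2 latent-resolution mask
preview = upscale_nearest(mask_to_uint8(latent_mask), 8)
print(len(preview), len(preview[0]))  # 16 16
```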

Next, create the inverted latents, passing in a caption that describes the image (here, the source_prompt):

inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents

Finally, pass the image mask and the inverted latents to the pipeline. The target_prompt now becomes the prompt, and the source_prompt is used as the negative_prompt:

output_image = pipeline(
    prompt=target_prompt,
    mask_image=mask_image,
    image_latents=inv_latents,
    negative_prompt=source_prompt,
).images[0]
mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
(Figure: original image, edited image)
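Two mechanisms drive the final pipeline call. The toy sketch below illustrates both on flat lists of numbers, under stated assumptions. First, classifier-free guidance with the negative prompt standing in for the unconditional branch pushes the noise estimate toward the target_prompt and away from the source_prompt. Second, after each denoising step the latents outside the mask are replaced with the inverted latents from the same timestep, so unmasked pixels end up matching the input image (step 3 of the algorithm). The function names, the `halve` denoiser, and the trajectory indexing (last entry = clean input latent) are hypothetical simplifications, not the diffusers code.

```python
def guided_noise(eps_negative, eps_positive, guidance_scale):
    """Classifier-free guidance with the negative prompt in place of the unconditional branch."""
    return [n + guidance_scale * (p - n) for n, p in zip(eps_negative, eps_positive)]


def blend_latents(denoised, inverted, mask):
    """Keep the new value where mask == 1 and the inverted latent where mask == 0."""
    return [d * m + i * (1 - m) for d, i, m in zip(denoised, inverted, mask)]


def masked_denoise(start, inverted_trajectory, mask, denoise_step):
    """Toy edit loop: denoise one step, then restore the region outside the mask.

    inverted_trajectory[k] is assumed to be the inverted latent at the noise
    level reached after step k; the last entry is the clean input latent.
    """
    x = start
    for target in inverted_trajectory:
        x = blend_latents(denoise_step(x), target, mask)
    return x


def halve(latents):
    """Hypothetical stand-in for one scheduler step driven by guided noise."""
    return [v * 0.5 for v in latents]


print(guided_noise([0.0, 1.0], [1.0, 1.0], 7.5))  # [7.5, 1.0]
result = masked_denoise([8.0, 8.0], [[6.0, 6.0], [3.0, 3.0], [0.5, 0.5]], [1, 0], halve)
print(result)  # [1.0, 0.5]: the masked entry was edited, the other matches the input latent
```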

Generate source and target embeddings

Instead of writing the source and target prompts by hand, you can generate them automatically with the Flan-T5 model and then encode them into source and target embeddings.

Load the Flan-T5 model and tokenizer from the 🤗 Transformers library:

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", torch_dtype=torch.float16)

Provide some initial text to prompt the model to generate the source and target prompts:

source_concept = "bowl"
target_concept = "basket"

# Parentheses make Python join the two string literals into a single prompt
source_text = (
    f"Provide a caption for images containing a {source_concept}. "
    "The captions should be in English and should be no longer than 150 characters."
)

target_text = (
    f"Provide a caption for images containing a {target_concept}. "
    "The captions should be in English and should be no longer than 150 characters."
)

Next, create a utility function to generate the prompts:

@torch.no_grad()
def generate_prompts(input_prompt):
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")

    outputs = model.generate(
        input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

source_prompts = generate_prompts(source_text)
target_prompts = generate_prompts(target_text)
print(source_prompts)
print(target_prompts)

If you're interested in learning more about strategies for generating text of varying quality, check out the generation strategies guide.
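For intuition about the sampling arguments in generate_prompts, here is a simplified single-step sketch of what `do_sample=True`, `temperature=0.8`, and `top_k=10` roughly do. Only the 10 highest-scoring tokens are kept, their logits are divided by the temperature (values below 1 sharpen the distribution), and one token is drawn from the resulting softmax. `num_return_sequences=16` repeats the whole decoding process for 16 independent sequences. The helper `sample_top_k` is hypothetical and is not the 🤗 Transformers implementation.

```python
import math
import random


def sample_top_k(logits, top_k=10, temperature=0.8, rng=random):
    """Draw one token id from the temperature-scaled softmax over the top_k logits."""
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    scaled = [logits[i] / temperature for i in candidates]
    peak = max(scaled)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [float(i) for i in range(20)]  # hypothetical scores for a 20-token vocabulary
draws = [sample_top_k(logits, rng=rng) for _ in range(16)]
print(min(draws) >= 10)  # True: only the 10 highest-scoring tokens can be sampled
```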

Load the text encoder model used by the [StableDiffusionDiffEditPipeline] to encode the text. Use the text encoder to compute the text embeddings. Each list of generated prompts is embedded and averaged into a single embedding:

import torch
from diffusers import StableDiffusionDiffEditPipeline

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, use_safetensors=True
)
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

@torch.no_grad()
def embed_prompts(sentences, tokenizer, text_encoder, device="cuda"):
    embeddings = []
    for sent in sentences:
        text_inputs = tokenizer(
            sent,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
        embeddings.append(prompt_embeds)
    return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)

source_embeds = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder)
target_embeds = embed_prompts(target_prompts, pipeline.tokenizer, pipeline.text_encoder)
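embed_prompts pads every prompt to `tokenizer.model_max_length`, so each per-prompt embedding has the same shape (1, seq_len, dim). Concatenating along dim 0 gives (N, seq_len, dim). Taking the mean over dim 0 and then calling `.unsqueeze(0)` yields one (1, seq_len, dim) embedding that represents the whole set of prompts. The pure-Python sketch below reproduces that shape logic on nested lists. `mean_embedding` is a hypothetical helper, and each input is given as seq_len x dim with the leading batch axis already dropped.

```python
def mean_embedding(per_prompt):
    """Element-wise mean over N prompt embeddings (each seq_len x dim).

    Returns a 1 x seq_len x dim nested list, mirroring
    torch.cat(embeddings, dim=0).mean(dim=0).unsqueeze(0).
    """
    n = len(per_prompt)
    seq_len, dim = len(per_prompt[0]), len(per_prompt[0][0])
    mean = [[sum(e[s][d] for e in per_prompt) / n for d in range(dim)] for s in range(seq_len)]
    return [mean]  # leading batch axis, like .unsqueeze(0)


print(mean_embedding([[[1.0, 2.0]], [[3.0, 4.0]]]))  # [[[2.0, 3.0]]]
```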

Finally, pass the embeddings to the [~StableDiffusionDiffEditPipeline.generate_mask] and [~StableDiffusionDiffEditPipeline.invert] functions, and to the pipeline, to generate the image. The diff below shows what changes compared to the prompt-based version:

  from diffusers import DDIMInverseScheduler, DDIMScheduler
  from diffusers.utils import load_image, make_image_grid
  from PIL import Image

  pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
  pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)

  img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
  raw_image = load_image(img_url).resize((768, 768))

  mask_image = pipeline.generate_mask(
      image=raw_image,
-     source_prompt=source_prompt,
-     target_prompt=target_prompt,
+     source_prompt_embeds=source_embeds,
+     target_prompt_embeds=target_embeds,
  )

  inv_latents = pipeline.invert(
-     prompt=source_prompt,
+     prompt_embeds=source_embeds,
      image=raw_image,
  ).latents

  output_image = pipeline(
      mask_image=mask_image,
      image_latents=inv_latents,
-     prompt=target_prompt,
-     negative_prompt=source_prompt,
+     prompt_embeds=target_embeds,
+     negative_prompt_embeds=source_embeds,
  ).images[0]
  mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L")
  make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)

Generate a caption for inversion

While you can use the source_prompt as a caption to help generate the partially inverted latents, you can also use the BLIP model to generate a caption automatically.

Load the BLIP model and processor from the 🤗 Transformers library:

import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16, low_cpu_mem_usage=True)

Create a utility function to generate a caption from the input image:

@torch.no_grad()
def generate_caption(images, caption_generator, caption_processor):
    text = "a photograph of"

    inputs = caption_processor(images, text, return_tensors="pt").to(device="cuda", dtype=caption_generator.dtype)
    caption_generator.to("cuda")
    outputs = caption_generator.generate(**inputs, max_new_tokens=128)

    # Offload the caption generator back to the CPU
    caption_generator.to("cpu")

    caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return caption

Load an input image and generate a caption for it with the generate_caption function:

from diffusers.utils import load_image

img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
raw_image = load_image(img_url).resize((768, 768))
caption = generate_caption(raw_image, model, processor)
generated caption: "a photograph of a bowl of fruit on a table"

Now you can pass the caption to the [~StableDiffusionDiffEditPipeline.invert] function to generate the partially inverted latents!