Token Merging

Token Merging (introduced in Token Merging: Your ViT But Faster) works by progressively merging redundant tokens or patches in the forward pass of a Transformer-based network. This can reduce the inference latency of the underlying network.
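The following is a minimal pure-Python sketch of one merging step in the style of ToMe's bipartite soft matching. It is a simplification built on our own assumptions and is not the tomesd implementation:

- Tokens are split alternately into two sets, A and B.
- The token vectors themselves serve as the similarity metric.
- Each A token is paired with its most similar B token by cosine similarity.
- The r most similar A tokens are averaged into their partners.

cosine and bipartite_merge are hypothetical helpers written for illustration. The sketch omits refinements described in the paper, such as its choice of similarity features and token-size tracking.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def bipartite_merge(tokens: List[Vector], r: int) -> List[Vector]:
    """Remove up to r tokens with one simplified bipartite soft matching step."""
    a_set = tokens[0::2]  # even positions form set A
    b_set = tokens[1::2]  # odd positions form set B
    if not b_set:
        return list(tokens)  # nothing to merge into
    r = max(0, min(r, len(a_set)))  # at most every A token can be merged

    # Pair every A token with its most similar B token.
    matches = []
    for i, a in enumerate(a_set):
        j = max(range(len(b_set)), key=lambda k: cosine(a, b_set[k]))
        matches.append((cosine(a, b_set[j]), i, j))

    # Keep only the r most similar pairs for merging.
    matches.sort(reverse=True)
    to_merge = matches[:r]
    merged_ids = {i for _, i, _ in to_merge}

    # Average each merged A token into its B partner (several A tokens may share one).
    sums = [list(b) for b in b_set]
    counts = [1] * len(b_set)
    for _, i, j in to_merge:
        sums[j] = [s + x for s, x in zip(sums[j], a_set[i])]
        counts[j] += 1
    merged_b = [[s / c for s in vec] for vec, c in zip(sums, counts)]

    # Unmerged A tokens are kept unchanged, followed by the (possibly merged) B tokens.
    kept_a = [a for i, a in enumerate(a_set) if i not in merged_ids]
    return kept_a + merged_b


tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(len(bipartite_merge(tokens, r=2)))  # 4 tokens reduced to 2
```

Each call removes r tokens, capped at the size of set A. Applying such a step in successive layers is what progressively shrinks the sequence that later layers must process.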

After Token Merging (ToMe) was released, its authors published Token Merging for Fast Stable Diffusion. That work introduced a version of ToMe that works better with Stable Diffusion. With ToMe, you can smoothly reduce the inference latency of a [DiffusionPipeline]. This document covers three topics:

- how to apply ToMe to [StableDiffusionPipeline];
- the speedups you can expect;
- the qualitative impact of using ToMe with [StableDiffusionPipeline].

Using ToMe

ToMe์˜ ์ €์ž๋“ค์€ tomesd๋ผ๋Š” ํŽธ๋ฆฌํ•œ Python ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ๊ณต๊ฐœํ–ˆ๋Š”๋ฐ, ์ด ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์ด์šฉํ•˜๋ฉด [DiffusionPipeline]์— ToMe๋ฅผ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

```python
import torch
import tomesd
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Patch the pipeline so that ToMe merges tokens with a ratio of 0.5
tomesd.apply_patch(pipeline, ratio=0.5)

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```

That's it!

tomesd.apply_patch() exposes several arguments for trading off pipeline inference speed against the quality of the generated images. The most important of these is ratio. It controls what fraction of the tokens is merged during the forward pass; for example, ratio=0.5 merges roughly half of them. For more details on tomesd, see its repository (https://github.com/dbolya/tomesd) and the paper.
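The sketch below illustrates what ratio means in practice. It counts the tokens at the full-resolution level of the Stable Diffusion UNet and how many a given ratio would remove. It rests on three assumptions:

- The VAE downsamples images by a factor of 8, as in Stable Diffusion v1.5.
- Each latent position is one token.
- ratio is the fraction of those tokens merged away, which is how the tomesd README describes it.

tokens_merged is a hypothetical helper and is not part of tomesd.

```python
from typing import Tuple


def tokens_merged(height: int, width: int, ratio: float, vae_factor: int = 8) -> Tuple[int, int]:
    """Estimate (tokens, tokens_removed) at the full-resolution UNet level.

    Assumes one token per latent position and that ratio is the fraction
    of tokens merged away (hypothetical helper, not a tomesd API).
    """
    tokens = (height // vae_factor) * (width // vae_factor)
    removed = int(tokens * ratio)
    return tokens, removed


for side in (512, 768, 1024):
    n, r = tokens_merged(side, side, ratio=0.5)
    print(f"{side}x{side}: {n} tokens -> {n - r} after merging {r}")
```

With ratio=0.5, this reduces 4096 tokens to 2048 at 512x512 and 16384 to 8192 at 1024x1024. The token count grows quadratically with the image side length, which is consistent with the larger speedups reported below at higher resolutions.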

Benchmarking tomesd with StableDiffusionPipeline

We benchmarked the impact of using tomesd on [StableDiffusionPipeline], together with xFormers, across different image resolutions. We used A100 and V100 GPUs as test devices, with the following development environment:

- `diffusers` version: 0.15.1
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.27.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- tomesd version: 0.1.2

๋ฒค์น˜๋งˆํ‚น์—๋Š” ๋‹ค์Œ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค: https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335. ๊ฒฐ๊ณผ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

A100

ํ•ด์ƒ๋„ ๋ฐฐ์น˜ ํฌ๊ธฐ Vanilla ToMe ToMe + xFormers ToMe ์†๋„ ํ–ฅ์ƒ (%) ToMe + xFormers ์†๋„ ํ–ฅ์ƒ (%)
512 10 6.88 5.26 4.69 23.54651163 31.83139535
768 10 OOM 14.71 11
8 OOM 11.56 8.84
4 OOM 5.98 4.66
2 4.99 3.24 3.1 35.07014028 37.8757515
1 3.29 2.24 2.03 31.91489362 38.29787234
1024 10 OOM OOM OOM
8 OOM OOM OOM
4 OOM 12.51 9.09
2 OOM 6.52 4.96
1 6.4 3.61 2.81 43.59375 56.09375

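All results are in seconds. Speedup is the percentage reduction in latency relative to Vanilla, that is, (Vanilla - method) / Vanilla × 100. It is left blank where Vanilla ran out of memory (OOM).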

V100

ํ•ด์ƒ๋„ ๋ฐฐ์น˜ ํฌ๊ธฐ Vanilla ToMe ToMe + xFormers ToMe ์†๋„ ํ–ฅ์ƒ (%) ToMe + xFormers ์†๋„ ํ–ฅ์ƒ (%)
512 10 OOM 10.03 9.29
8 OOM 8.05 7.47
4 5.7 4.3 3.98 24.56140351 30.1754386
2 3.14 2.43 2.27 22.61146497 27.70700637
1 1.88 1.57 1.57 16.4893617 16.4893617
768 10 OOM OOM 23.67
8 OOM OOM 18.81
4 OOM 11.81 9.7
2 OOM 6.27 5.2
1 5.43 3.38 2.82 37.75322284 48.06629834
1024 10 OOM OOM OOM
8 OOM OOM OOM
4 OOM OOM 19.35
2 OOM 13 10.78
1 OOM 6.66 5.54

์œ„์˜ ํ‘œ์—์„œ ๋ณผ ์ˆ˜ ์žˆ๋“ฏ์ด, ์ด๋ฏธ์ง€ ํ•ด์ƒ๋„๊ฐ€ ๋†’์„์ˆ˜๋ก tomesd๋ฅผ ์‚ฌ์šฉํ•œ ์†๋„ ํ–ฅ์ƒ์ด ๋”์šฑ ๋‘๋“œ๋Ÿฌ์ง‘๋‹ˆ๋‹ค. ๋˜ํ•œ tomesd๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด 1024x1024์™€ ๊ฐ™์€ ๋” ๋†’์€ ํ•ด์ƒ๋„์—์„œ ํŒŒ์ดํ”„๋ผ์ธ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ๋„ ํฅ๋ฏธ๋กญ์Šต๋‹ˆ๋‹ค.

It may be possible to speed up inference even further with torch.compile().

Quality

As reported in the paper, ToMe can preserve the quality of the generated images to a great extent while speeding up inference. By increasing the ratio, it is possible to further speed up inference, but that might come at the cost of a deterioration in the image quality.

To test the quality of the generated samples with our setup, we sampled a few prompts from "Parti Prompts" (introduced in Parti) and ran inference with [StableDiffusionPipeline] in the following settings:


  • Vanilla [StableDiffusionPipeline]
  • [StableDiffusionPipeline] + ToMe
  • [StableDiffusionPipeline] + ToMe + xformers

์ƒ์„ฑ๋œ ์ƒ˜ํ”Œ์˜ ํ’ˆ์งˆ์ด ํฌ๊ฒŒ ์ €ํ•˜๋˜๋Š” ๊ฒƒ์„ ๋ฐœ๊ฒฌํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์Œ์€ ์ƒ˜ํ”Œ์ž…๋‹ˆ๋‹ค:

[Image: tome-samples]

์ƒ์„ฑ๋œ ์ƒ˜ํ”Œ์€ ์—ฌ๊ธฐ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ์‹คํ—˜์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ์ด ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค.