Spaces:
Running
on
Zero
A newer version of the Gradio SDK is available:
5.34.2
Token Merging (ํ ํฐ ๋ณํฉ)
Token Merging (introduced in Token Merging: Your ViT But Faster)์ ํธ๋์คํฌ๋จธ ๊ธฐ๋ฐ ๋คํธ์ํฌ์ forward pass์์ ์ค๋ณต ํ ํฐ์ด๋ ํจ์น๋ฅผ ์ ์ง์ ์ผ๋ก ๋ณํฉํ๋ ๋ฐฉ์์ผ๋ก ์๋ํฉ๋๋ค. ์ด๋ฅผ ํตํด ๊ธฐ๋ฐ ๋คํธ์ํฌ์ ์ถ๋ก ์ง์ฐ ์๊ฐ์ ๋จ์ถํ ์ ์์ต๋๋ค.
Token Merging(ToMe)์ด ์ถ์๋ ํ, ์ ์๋ค์ Fast Stable Diffusion์ ์ํ ํ ํฐ ๋ณํฉ์ ๋ฐํํ์ฌ Stable Diffusion๊ณผ ๋ ์ ํธํ๋๋ ToMe ๋ฒ์ ์ ์๊ฐํ์ต๋๋ค. ToMe๋ฅผ ์ฌ์ฉํ๋ฉด [DiffusionPipeline
]์ ์ถ๋ก ์ง์ฐ ์๊ฐ์ ๋ถ๋๋ฝ๊ฒ ๋จ์ถํ ์ ์์ต๋๋ค. ์ด ๋ฌธ์์์๋ ToMe๋ฅผ [StableDiffusionPipeline
]์ ์ ์ฉํ๋ ๋ฐฉ๋ฒ, ์์๋๋ ์๋ ํฅ์, [StableDiffusionPipeline
]์์ ToMe๋ฅผ ์ฌ์ฉํ ๋์ ์ง์ ์ธก๋ฉด์ ๋ํด ์ค๋ช
ํฉ๋๋ค.
ToMe ์ฌ์ฉํ๊ธฐ
ToMe์ ์ ์๋ค์ tomesd
๋ผ๋ ํธ๋ฆฌํ Python ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๊ณต๊ฐํ๋๋ฐ, ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ด์ฉํ๋ฉด [DiffusionPipeline
]์ ToMe๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ ์ฉํ ์ ์์ต๋๋ค:
from diffusers import StableDiffusionPipeline
import tomesd
pipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
+ tomesd.apply_patch(pipeline, ratio=0.5)
image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
์ด๊ฒ์ด ๋ค์ ๋๋ค!
tomesd.apply_patch()
๋ ํ์ดํ๋ผ์ธ ์ถ๋ก ์๋์ ์์ฑ๋ ํ ํฐ์ ํ์ง ์ฌ์ด์ ๊ท ํ์ ๋ง์ถ ์ ์๋๋ก ์ฌ๋ฌ ๊ฐ์ ์ธ์๋ฅผ ๋
ธ์ถํฉ๋๋ค. ์ด๋ฌํ ์ธ์ ์ค ๊ฐ์ฅ ์ค์ํ ๊ฒ์ ratio(๋น์จ)
์
๋๋ค. ratio
์ forward pass ์ค์ ๋ณํฉ๋ ํ ํฐ์ ์๋ฅผ ์ ์ดํฉ๋๋ค. tomesd
์ ๋ํ ์์ธํ ๋ด์ฉ์ ํด๋น ๋ฆฌํฌ์งํ ๋ฆฌ(https://github.com/dbolya/tomesd) ๋ฐ ๋
ผ๋ฌธ์ ์ฐธ๊ณ ํ์๊ธฐ ๋ฐ๋๋๋ค.
StableDiffusionPipeline
์ผ๋ก tomesd
๋ฒค์น๋งํนํ๊ธฐ
We benchmarked the impact of using tomesd
on [StableDiffusionPipeline
] along with xformers across different image resolutions. We used A100 and V100 as our test GPU devices with the following development environment (with Python 3.8.5):
๋ค์ํ ์ด๋ฏธ์ง ํด์๋์์ xformers๋ฅผ ์ ์ฉํ ์ํ์์, [StableDiffusionPipeline
]์ tomesd
๋ฅผ ์ฌ์ฉํ์ ๋์ ์ํฅ์ ๋ฒค์น๋งํนํ์ต๋๋ค. ํ
์คํธ GPU ์ฅ์น๋ก A100๊ณผ V100์ ์ฌ์ฉํ์ผ๋ฉฐ ๊ฐ๋ฐ ํ๊ฒฝ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค(Python 3.8.5 ์ฌ์ฉ):
- `diffusers` version: 0.15.1
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.27.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- tomesd version: 0.1.2
๋ฒค์น๋งํน์๋ ๋ค์ ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ์ต๋๋ค: https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335. ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
A100
ํด์๋ | ๋ฐฐ์น ํฌ๊ธฐ | Vanilla | ToMe | ToMe + xFormers | ToMe ์๋ ํฅ์ (%) | ToMe + xFormers ์๋ ํฅ์ (%) |
---|---|---|---|---|---|---|
512 | 10 | 6.88 | 5.26 | 4.69 | 23.54651163 | 31.83139535 |
768 | 10 | OOM | 14.71 | 11 | ||
8 | OOM | 11.56 | 8.84 | |||
4 | OOM | 5.98 | 4.66 | |||
2 | 4.99 | 3.24 | 3.1 | 35.07014028 | 37.8757515 | |
1 | 3.29 | 2.24 | 2.03 | 31.91489362 | 38.29787234 | |
1024 | 10 | OOM | OOM | OOM | ||
8 | OOM | OOM | OOM | |||
4 | OOM | 12.51 | 9.09 | |||
2 | OOM | 6.52 | 4.96 | |||
1 | 6.4 | 3.61 | 2.81 | 43.59375 | 56.09375 |
๊ฒฐ๊ณผ๋ ์ด ๋จ์์
๋๋ค. ์๋ ํฅ์์ Vanilla
๊ณผ ๋น๊ตํด ๊ณ์ฐ๋ฉ๋๋ค.
V100
ํด์๋ | ๋ฐฐ์น ํฌ๊ธฐ | Vanilla | ToMe | ToMe + xFormers | ToMe ์๋ ํฅ์ (%) | ToMe + xFormers ์๋ ํฅ์ (%) |
---|---|---|---|---|---|---|
512 | 10 | OOM | 10.03 | 9.29 | ||
8 | OOM | 8.05 | 7.47 | |||
4 | 5.7 | 4.3 | 3.98 | 24.56140351 | 30.1754386 | |
2 | 3.14 | 2.43 | 2.27 | 22.61146497 | 27.70700637 | |
1 | 1.88 | 1.57 | 1.57 | 16.4893617 | 16.4893617 | |
768 | 10 | OOM | OOM | 23.67 | ||
8 | OOM | OOM | 18.81 | |||
4 | OOM | 11.81 | 9.7 | |||
2 | OOM | 6.27 | 5.2 | |||
1 | 5.43 | 3.38 | 2.82 | 37.75322284 | 48.06629834 | |
1024 | 10 | OOM | OOM | OOM | ||
8 | OOM | OOM | OOM | |||
4 | OOM | OOM | 19.35 | |||
2 | OOM | 13 | 10.78 | |||
1 | OOM | 6.66 | 5.54 |
์์ ํ์์ ๋ณผ ์ ์๋ฏ์ด, ์ด๋ฏธ์ง ํด์๋๊ฐ ๋์์๋ก tomesd
๋ฅผ ์ฌ์ฉํ ์๋ ํฅ์์ด ๋์ฑ ๋๋๋ฌ์ง๋๋ค. ๋ํ tomesd
๋ฅผ ์ฌ์ฉํ๋ฉด 1024x1024์ ๊ฐ์ ๋ ๋์ ํด์๋์์ ํ์ดํ๋ผ์ธ์ ์คํํ ์ ์๋ค๋ ์ ๋ ํฅ๋ฏธ๋กญ์ต๋๋ค.
torch.compile()
์ ์ฌ์ฉํ๋ฉด ์ถ๋ก ์๋๋ฅผ ๋์ฑ ๋์ผ ์ ์์ต๋๋ค.
ํ์ง
As reported in the paper, ToMe can preserve the quality of the generated images to a great extent while speeding up inference. By increasing the ratio
, it is possible to further speed up inference, but that might come at the cost of a deterioration in the image quality.
To test the quality of the generated samples using our setup, we sampled a few prompts from the โParti Promptsโ (introduced in Parti) and performed inference with the [StableDiffusionPipeline
] in the following settings:
๋
ผ๋ฌธ์ ๋ณด๊ณ ๋ ๋ฐ์ ๊ฐ์ด, ToMe๋ ์์ฑ๋ ์ด๋ฏธ์ง์ ํ์ง์ ์๋น ๋ถ๋ถ ๋ณด์กดํ๋ฉด์ ์ถ๋ก ์๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค. ratio
์ ๋์ด๋ฉด ์ถ๋ก ์๋๋ฅผ ๋ ๋์ผ ์ ์์ง๋ง, ์ด๋ฏธ์ง ํ์ง์ด ์ ํ๋ ์ ์์ต๋๋ค.
ํด๋น ์ค์ ์ ์ฌ์ฉํ์ฌ ์์ฑ๋ ์ํ์ ํ์ง์ ํ
์คํธํ๊ธฐ ์ํด, "Parti ํ๋กฌํํธ"(Parti์์ ์๊ฐ)์์ ๋ช ๊ฐ์ง ํ๋กฌํํธ๋ฅผ ์ํ๋งํ๊ณ ๋ค์ ์ค์ ์์ [StableDiffusionPipeline
]์ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์ํํ์ต๋๋ค:
- Vanilla [
StableDiffusionPipeline
] - [
StableDiffusionPipeline
] + ToMe - [
StableDiffusionPipeline
] + ToMe + xformers
์์ฑ๋ ์ํ์ ํ์ง์ด ํฌ๊ฒ ์ ํ๋๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ง ๋ชปํ์ต๋๋ค. ๋ค์์ ์ํ์ ๋๋ค:
์์ฑ๋ ์ํ์ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค. ์ด ์คํ์ ์ํํ๊ธฐ ์ํด ์ด ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ์ต๋๋ค.