Spaces:
Paused
A newer version of the Gradio SDK is available:
5.23.3
Hugging Face Transformers๋ฅผ ์ถ๊ฐํ๋ ๋ฐฉ๋ฒ์ ๋ฌด์์ธ๊ฐ์? [[how-to-add-a-model-to-transformers]]
Hugging Face Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ์ปค๋ฎค๋ํฐ ๊ธฐ์ฌ์๋ค ๋๋ถ์ ์๋ก์ด ๋ชจ๋ธ์ ์ ๊ณตํ ์ ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ํ์ง๋ง ์ด๋ ๋์ ์ ์ธ ํ๋ก์ ํธ์ด๋ฉฐ Hugging Face Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ตฌํํ ๋ชจ๋ธ์ ๋ํ ๊น์ ์ดํด๊ฐ ํ์ํฉ๋๋ค. Hugging Face์์๋ ๋ ๋ง์ ์ปค๋ฎค๋ํฐ ๋ฉค๋ฒ๊ฐ ๋ชจ๋ธ์ ์ ๊ทน์ ์ผ๋ก ์ถ๊ฐํ ์ ์๋๋ก ์ง์ํ๊ณ ์ ํ๋ฉฐ, ์ด ๊ฐ์ด๋๋ฅผ ํตํด PyTorch ๋ชจ๋ธ์ ์ถ๊ฐํ๋ ๊ณผ์ ์ ์๋ดํ๊ณ ์์ต๋๋ค (PyTorch๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํด์ฃผ์ธ์).
TensorFlow ๋ชจ๋ธ์ ๊ตฌํํ๊ณ ์ ํ๋ ๊ฒฝ์ฐ ๐ค Transformers ๋ชจ๋ธ์ TensorFlow๋ก ๋ณํํ๋ ๋ฐฉ๋ฒ ๊ฐ์ด๋๋ฅผ ์ดํด๋ณด์ธ์!
์ด ๊ณผ์ ์ ์งํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ ๋ด์ฉ์ ์ดํดํ๊ฒ ๋ฉ๋๋ค:
- ์คํ ์์ค์ ๋ชจ๋ฒ ์ฌ๋ก์ ๋ํ ํต์ฐฐ๋ ฅ์ ์ป์ต๋๋ค.
- ๊ฐ์ฅ ์ธ๊ธฐ ์๋ ๋ฅ๋ฌ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ค๊ณ ์์น์ ์ดํดํฉ๋๋ค.
- ๋๊ท๋ชจ ๋ชจ๋ธ์ ํจ์จ์ ์ผ๋ก ํ ์คํธํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์๋๋ค.
black
,ruff
,make fix-copies
์ ๊ฐ์ Python ์ ํธ๋ฆฌํฐ๋ฅผ ํตํฉํ์ฌ ๊น๋ํ๊ณ ๊ฐ๋ ์ฑ ์๋ ์ฝ๋๋ฅผ ์์ฑํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์๋๋ค.
Hugging Face ํ์ ํญ์ ๋์์ ์ค ์ค๋น๊ฐ ๋์ด ์์ผ๋ฏ๋ก ํผ์๊ฐ ์๋๋ผ๋ ์ ์ ๊ธฐ์ตํ์ธ์. ๐ค โค๏ธ
์์์ ์์ ๐ค Transformers์ ์ํ๋ ๋ชจ๋ธ์ ์ถ๊ฐํ๊ธฐ ์ํด New model addition ์ด์๋ฅผ ์ด์ด์ผ ํฉ๋๋ค. ํน์ ๋ชจ๋ธ์ ๊ธฐ์ฌํ๋ ๋ฐ ํน๋ณํ ๊น๋ค๋ก์ด ๊ธฐ์ค์ ๊ฐ์ง์ง ์๋ ๊ฒฝ์ฐ New model label์ ํํฐ๋งํ์ฌ ์์ฒญ๋์ง ์์ ๋ชจ๋ธ์ด ์๋์ง ํ์ธํ๊ณ ์์ ํ ์ ์์ต๋๋ค.
์๋ก์ด ๋ชจ๋ธ ์์ฒญ์ ์ด์๋ค๋ฉด ์ฒซ ๋ฒ์งธ ๋จ๊ณ๋ ๐ค Transformers์ ์ต์ํด์ง๋ ๊ฒ์ ๋๋ค!
๐ค Transformers์ ์ ๋ฐ์ ์ธ ๊ฐ์ [[general-overview-of-transformers]]
๋จผ์ ๐ค Transformers์ ๋ํ ์ ๋ฐ์ ์ธ ๊ฐ์๋ฅผ ํ์ ํด์ผ ํฉ๋๋ค. ๐ค Transformers๋ ๋งค์ฐ ์ฃผ๊ด์ ์ธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ด๊ธฐ ๋๋ฌธ์ ํด๋น ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ฒ ํ์ด๋ ์ค๊ณ ์ ํ ์ฌํญ์ ๋์ํ์ง ์์ ์๋ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ฐ๋ฆฌ์ ๊ฒฝํ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ธฐ๋ณธ์ ์ธ ์ค๊ณ ์ ํ๊ณผ ์ฒ ํ์ ๐ค Transformers์ ๊ท๋ชจ๋ฅผ ํจ์จ์ ์ผ๋ก ํ์ฅํ๋ฉด์ ์ ์ง ๋ณด์ ๋น์ฉ์ ํฉ๋ฆฌ์ ์ธ ์์ค์ผ๋ก ์ ์งํ๋ ๊ฒ์ ๋๋ค.
๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ฒ ํ์ ๋ํ ๋ฌธ์๋ฅผ ์ฝ๋ ๊ฒ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ ์ ์ดํดํ๋ ์ข์ ์์์ ์ ๋๋ค. ๋ชจ๋ ๋ชจ๋ธ์ ์ ์ฉํ๋ ค๋ ๋ช ๊ฐ์ง ์์ ๋ฐฉ์์ ๋ํ ์ ํ ์ฌํญ์ด ์์ต๋๋ค:
- ์ผ๋ฐ์ ์ผ๋ก ์ถ์ํ๋ณด๋ค๋ ๊ตฌ์ฑ์ ์ ํธํฉ๋๋ค.
- ์ฝ๋๋ฅผ ๋ณต์ ํ๋ ๊ฒ์ด ํญ์ ๋์ ๊ฒ์ ์๋๋๋ค. ์ฝ๋์ ๊ฐ๋ ์ฑ์ด๋ ์ ๊ทผ์ฑ์ ํฌ๊ฒ ํฅ์์ํจ๋ค๋ฉด ๋ณต์ ํ๋ ๊ฒ์ ์ข์ต๋๋ค.
- ๋ชจ๋ธ ํ์ผ์ ๊ฐ๋ฅํ ํ ๋
๋ฆฝ์ ์ผ๋ก ์ ์ง๋์ด์ผ ํฉ๋๋ค. ๋ฐ๋ผ์ ํน์ ๋ชจ๋ธ์ ์ฝ๋๋ฅผ ์ฝ์ ๋ ํด๋น
modeling_....py
ํ์ผ๋ง ํ์ธํ๋ฉด ๋ฉ๋๋ค.
์ฐ๋ฆฌ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ฝ๋๊ฐ ์ ํ์ ์ ๊ณตํ๋ ์๋จ๋ฟ๋ง ์๋๋ผ ๊ฐ์ ํ๊ณ ์ ํ๋ ์ ํ์ด๋ผ๊ณ ๋ ์๊ฐํฉ๋๋ค. ๋ฐ๋ผ์ ๋ชจ๋ธ์ ์ถ๊ฐํ ๋, ์ฌ์ฉ์๋ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ฌ๋๋ฟ๋ง ์๋๋ผ ์ฝ๋๋ฅผ ์ฝ๊ณ ์ดํดํ๊ณ ํ์ํ ๊ฒฝ์ฐ ์กฐ์ ํ ์ ์๋ ๋ชจ๋ ์ฌ๋๊น์ง๋ ํฌํจํ๋ค๋ ์ ์ ๊ธฐ์ตํด์ผ ํฉ๋๋ค.
์ด๋ฅผ ์ผ๋์ ๋๊ณ ์ผ๋ฐ์ ์ธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ค๊ณ์ ๋ํด ์กฐ๊ธ ๋ ์์ธํ ์์๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ธ ๊ฐ์ [[overview-of-models]]
๋ชจ๋ธ์ ์ฑ๊ณต์ ์ผ๋ก ์ถ๊ฐํ๋ ค๋ฉด ๋ชจ๋ธ๊ณผ ํด๋น ๊ตฌ์ฑ์ธ [PreTrainedModel
] ๋ฐ [PretrainedConfig
] ๊ฐ์ ์ํธ์์ฉ์ ์ดํดํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๐ค Transformers์ ์ถ๊ฐํ๋ ค๋ ๋ชจ๋ธ์ BrandNewBert
๋ผ๊ณ ๋ถ๋ฅด๊ฒ ์ต๋๋ค.
๋ค์์ ์ดํด๋ณด๊ฒ ์ต๋๋ค:

๋ณด๋ค์ํผ, ๐ค Transformers์์๋ ์์์ ์ฌ์ฉํ์ง๋ง ์ถ์ํ ์์ค์ ์ต์ํ์ผ๋ก ์ ์งํฉ๋๋ค. ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ด๋ค ๋ชจ๋ธ์์๋ ๋ ์์ค ์ด์์ ์ถ์ํ๊ฐ ์กด์ฌํ์ง ์์ต๋๋ค. BrandNewBertModel
์ BrandNewBertPreTrainedModel
์์ ์์๋ฐ๊ณ , ์ด ํด๋์ค๋ [PreTrainedModel
]์์ ์์๋ฐ์ต๋๋ค. ์ด๋ก์จ ์๋ก์ด ๋ชจ๋ธ์ [PreTrainedModel
]์๋ง ์์กดํ๋๋ก ํ๋ ค๊ณ ํฉ๋๋ค. ๋ชจ๋ ์๋ก์ด ๋ชจ๋ธ์ ์๋์ผ๋ก ์ ๊ณต๋๋ ์ค์ํ ๊ธฐ๋ฅ์ [~PreTrainedModel.from_pretrained
] ๋ฐ [~PreTrainedModel.save_pretrained
]์
๋๋ค. ์ด๋ฌํ ๊ธฐ๋ฅ ์ธ์๋ BrandNewBertModel.forward
์ ๊ฐ์ ๋ค๋ฅธ ์ค์ํ ๊ธฐ๋ฅ์ ์๋ก์ด modeling_brand_new_bert.py
์คํฌ๋ฆฝํธ์์ ์์ ํ ์ ์๋์ด์ผ ํฉ๋๋ค. ๋ํ BrandNewBertForMaskedLM
๊ณผ ๊ฐ์ ํน์ ํค๋ ๋ ์ด์ด๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ BrandNewBertModel
์ ์์๋ฐ์ง ์๊ณ forward pass์์ ํธ์ถํ ์ ์๋ BrandNewBertModel
์ ์ฌ์ฉํ์ฌ ์ถ์ํ ์์ค์ ๋ฎ๊ฒ ์ ์งํฉ๋๋ค. ๋ชจ๋ ์๋ก์ด ๋ชจ๋ธ์ BrandNewBertConfig
๋ผ๋ ๊ตฌ์ฑ ํด๋์ค๋ฅผ ํ์๋ก ํฉ๋๋ค. ์ด ๊ตฌ์ฑ์ ํญ์ [PreTrainedModel
]์ ์์ฑ์ผ๋ก ์ ์ฅ๋๋ฉฐ, ๋ฐ๋ผ์ BrandNewBertPreTrainedModel
์ ์์๋ฐ๋ ๋ชจ๋ ํด๋์ค์์ config
์์ฑ์ ํตํด ์ก์ธ์คํ ์ ์์ต๋๋ค:
model = BrandNewBertModel.from_pretrained("brandy/brand_new_bert")
model.config # model has access to its config
๋ชจ๋ธ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ๊ตฌ์ฑ์ [PretrainedConfig
]์์ ๊ธฐ๋ณธ ์ง๋ ฌํ ๋ฐ ์ญ์ง๋ ฌํ ๊ธฐ๋ฅ์ ์์๋ฐ์ต๋๋ค. ๊ตฌ์ฑ๊ณผ ๋ชจ๋ธ์ ํญ์ pytorch_model.bin ํ์ผ๊ณผ config.json ํ์ผ๋ก ๊ฐ๊ฐ ๋ณ๋๋ก ์ง๋ ฌํ๋ฉ๋๋ค. [~PreTrainedModel.save_pretrained
]๋ฅผ ํธ์ถํ๋ฉด ์๋์ผ๋ก [~PretrainedConfig.save_pretrained
]๋ ํธ์ถ๋๋ฏ๋ก ๋ชจ๋ธ๊ณผ ๊ตฌ์ฑ์ด ๋ชจ๋ ์ ์ฅ๋ฉ๋๋ค.
์ฝ๋ ์คํ์ผ [[code-style]]
์๋ก์ด ๋ชจ๋ธ์ ์์ฑํ ๋, Transformers๋ ์ฃผ๊ด์ ์ธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ด๋ฉฐ ๋ช ๊ฐ์ง ๋ ํนํ ์ฝ๋ฉ ์คํ์ผ์ด ์์ต๋๋ค:
- ๋ชจ๋ธ์ forward pass๋ ๋ชจ๋ธ ํ์ผ์ ์์ ํ ์์ฑ๋์ด์ผ ํฉ๋๋ค. ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ค๋ฅธ ๋ชจ๋ธ์์ ๋ธ๋ก์ ์ฌ์ฌ์ฉํ๋ ค๋ฉด ์ฝ๋๋ฅผ ๋ณต์ฌํ์ฌ ์์
# Copied from
์ฃผ์๊ณผ ํจ๊ป ๋ถ์ฌ๋ฃ์ผ๋ฉด ๋ฉ๋๋ค (์: ์ฌ๊ธฐ๋ฅผ ์ฐธ์กฐํ์ธ์). - ์ฝ๋๋ ์์ ํ ์ดํดํ๊ธฐ ์ฌ์์ผ ํฉ๋๋ค. ๋ณ์ ์ด๋ฆ์ ๋ช
ํํ๊ฒ ์ง์ ํ๊ณ ์ฝ์ด๋ฅผ ์ฌ์ฉํ์ง ์๋ ๊ฒ์ด ์ข์ต๋๋ค. ์๋ฅผ ๋ค์ด,
act
๋ณด๋ค๋activation
์ ์ ํธํฉ๋๋ค. ํ ๊ธ์ ๋ณ์ ์ด๋ฆ์ ๋ฃจํ์ ์ธ๋ฑ์ค์ธ ๊ฒฝ์ฐ๋ฅผ ์ ์ธํ๊ณ ๊ถ์ฅ๋์ง ์์ต๋๋ค. - ๋ ์ผ๋ฐ์ ์ผ๋ก, ์งง์ ๋ง๋ฒ ๊ฐ์ ์ฝ๋๋ณด๋ค๋ ๊ธธ๊ณ ๋ช ์์ ์ธ ์ฝ๋๋ฅผ ์ ํธํฉ๋๋ค.
- PyTorch์์
nn.Sequential
์ ํ์ ํด๋์ค๋ก ๋ง๋ค์ง ๋ง๊ณnn.Module
์ ํ์ ํด๋์ค๋ก ๋ง๋ค๊ณ forward pass๋ฅผ ์์ฑํ์ฌ ๋ค๋ฅธ ์ฌ๋์ด ์ฝ๋๋ฅผ ๋น ๋ฅด๊ฒ ๋๋ฒ๊ทธํ ์ ์๋๋ก ํฉ๋๋ค. print ๋ฌธ์ด๋ ์ค๋จ์ ์ ์ถ๊ฐํ ์ ์์ต๋๋ค. - ํจ์ ์๊ทธ๋์ฒ์๋ ํ์ ์ฃผ์์ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ๊ทธ ์ธ์๋ ํ์ ์ฃผ์๋ณด๋ค ๋ณ์ ์ด๋ฆ์ด ํจ์ฌ ์ฝ๊ธฐ ์ฝ๊ณ ์ดํดํ๊ธฐ ์ฝ์ต๋๋ค.
ํ ํฌ๋์ด์ ๊ฐ์ [[overview-of-tokenizers]]
์์ง ์ค๋น๋์ง ์์์ต๋๋ค :-( ์ด ์น์ ์ ๊ณง ์ถ๊ฐ๋ ์์ ์ ๋๋ค!
๐ค Transformers์ ๋ชจ๋ธ ์ถ๊ฐํ๋ ๋จ๊ณ๋ณ ๋ฐฉ๋ฒ [[stepbystep-recipe-to-add-a-model-to-transformers]]
๊ฐ์ ๋ชจ๋ธ์ ์ด์ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ ํธ๊ฐ ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ ๋ค๋ฅธ ๊ธฐ์ฌ์๋ค์ด Hugging Face์ ๋ชจ๋ธ์ ์ด์ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์์ฝ์ ์ดํด๋ณด๋ ๊ฒ์ด ๋งค์ฐ ์ ์ฉํ ์ ์์ต๋๋ค. ๋ค์์ ๋ชจ๋ธ์ ์ด์ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ์ปค๋ฎค๋ํฐ ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผ ๋ชฉ๋ก์ ๋๋ค:
๊ฒฝํ์ ๋ชจ๋ธ์ ์ถ๊ฐํ ๋ ์ฃผ์ํด์ผ ํ ๊ฐ์ฅ ์ค์ํ ์ฌํญ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ๊ฐ์ ์ผ์ ๋ฐ๋ณตํ์ง ๋ง์ธ์! ์๋ก์ด ๐ค Transformers ๋ชจ๋ธ์ ์ํด ์ถ๊ฐํ ์ฝ๋์ ๋๋ถ๋ถ์ ์ด๋ฏธ ๐ค Transformers ์ด๋๊ฐ์ ์กด์ฌํฉ๋๋ค. ์ด๋ฏธ ์กด์ฌํ๋ ๋ณต์ฌํ ์ ์๋ ์ ์ฌํ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ์ฐพ๋๋ฐ ์๊ฐ์ ํฌ์ํ์ธ์. grep์ rg๋ฅผ ์ฐธ๊ณ ํ์ธ์. ๋ชจ๋ธ์ ํ ํฌ๋์ด์ ๊ฐ ํ ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ํ๊ณ ๋ชจ๋ธ๋ง ์ฝ๋๊ฐ ๋ค๋ฅธ ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ๊ฒฝ์ฐ๊ฐ ์กด์ฌํ ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด FSMT์ ๋ชจ๋ธ๋ง ์ฝ๋๋ BART๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๊ณ FSMT์ ํ ํฌ๋์ด์ ์ฝ๋๋ XLM์ ๊ธฐ๋ฐ์ผ๋ก ํฉ๋๋ค.
- ์ด๊ฒ์ ๊ณผํ์ ์ธ ๋์ ๋ณด๋ค๋ ๊ณตํ์ ์ธ ๋์ ์ ๋๋ค. ๋ ผ๋ฌธ์ ๋ชจ๋ธ์ ๋ชจ๋ ์ด๋ก ์ ์ธก๋ฉด์ ์ดํดํ๋ ค๋ ๊ฒ๋ณด๋ค ํจ์จ์ ์ธ ๋๋ฒ๊น ํ๊ฒฝ์ ๋ง๋๋ ๋ฐ ๋ ๋ง์ ์๊ฐ์ ์๋นํด์ผ ํฉ๋๋ค.
- ๋งํ ๋ ๋์์ ์์ฒญํ์ธ์! ๋ชจ๋ธ์ ๐ค Transformers์ ํต์ฌ ๊ตฌ์ฑ ์์์ด๋ฏ๋ก Hugging Face์ ์ฐ๋ฆฌ๋ ๋น์ ์ด ๋ชจ๋ธ์ ์ถ๊ฐํ๋ ๊ฐ ๋จ๊ณ์์ ๊ธฐ๊บผ์ด ๋์์ ์ค ์ค๋น๊ฐ ๋์ด ์์ต๋๋ค. ์ง์ ์ด ์๋ค๊ณ ๋๋ผ๋ฉด ์ฃผ์ ํ์ง ๋ง๊ณ ๋์์ ์์ฒญํ์ธ์.
๋ค์์์๋ ๋ชจ๋ธ์ ๐ค Transformers๋ก ์ด์ํ๋ ๋ฐ ๊ฐ์ฅ ์ ์ฉํ ์ผ๋ฐ์ ์ธ ์ ์ฐจ๋ฅผ ์ ๊ณตํ๋ ค๊ณ ๋ ธ๋ ฅํฉ๋๋ค.
๋ค์ ๋ชฉ๋ก์ ๋ชจ๋ธ์ ์ถ๊ฐํ๋ ๋ฐ ์ํํด์ผ ํ ๋ชจ๋ ์์ ์ ์์ฝ์ด๋ฉฐ To-Do ๋ชฉ๋ก์ผ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค:
โ (์ ํ ์ฌํญ) BrandNewBert์ ์ด๋ก ์ ์ธก๋ฉด ์ดํด
โ Hugging Face ๊ฐ๋ฐ ํ๊ฒฝ ์ค๋น
โ ์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ์ ๋๋ฒ๊น
ํ๊ฒฝ ์ค์
โ ์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ์ฌ forward()
pass๊ฐ ์ฑ๊ณต์ ์ผ๋ก ์คํ๋๋ ์คํฌ๋ฆฝํธ ์์ฑ
โ ๐ค Transformers์ ๋ชจ๋ธ ์ค์ผ๋ ํค ์ฑ๊ณต์ ์ผ๋ก ์ถ๊ฐ
โ ์๋ณธ ์ฒดํฌํฌ์ธํธ๋ฅผ ๐ค Transformers ์ฒดํฌํฌ์ธํธ๋ก ์ฑ๊ณต์ ์ผ๋ก ๋ณํ
โ ๐ค Transformers์์ ์๋ณธ ์ฒดํฌํฌ์ธํธ์ ๋์ผํ ์ถ๋ ฅ์ ๋ด์ฃผ๋ forward()
pass ์ฑ๊ณต์ ์ผ๋ก ์คํ
โ ๐ค Transformers์์ ๋ชจ๋ธ ํ
์คํธ ์๋ฃ
โ ๐ค Transformers์ ํ ํฌ๋์ด์ ์ฑ๊ณต์ ์ผ๋ก ์ถ๊ฐ
โ ์ข
๋จ ๊ฐ ํตํฉ ํ
์คํธ ์คํ
โ ๋ฌธ์ ์์ฑ ์๋ฃ
โ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ํ๋ธ์ ์
๋ก๋
โ Pull request ์ ์ถ
โ (์ ํ ์ฌํญ) ๋ฐ๋ชจ ๋
ธํธ๋ถ ์ถ๊ฐ
์ฐ์ , ์ผ๋ฐ์ ์ผ๋ก๋ BrandNewBert
์ ์ด๋ก ์ ์ธ ์ดํด๋ก ์์ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ก ์ ์ธก๋ฉด์ ์ง์ ์ดํดํ๋ ๋์ ์ง์ ํด๋ณด๋ฉด์ ๋ชจ๋ธ์ ์ด๋ก ์ ์ธก๋ฉด์ ์ดํดํ๋ ๊ฒ์ ์ ํธํ๋ ๊ฒฝ์ฐ ๋ฐ๋ก BrandNewBert
์ฝ๋ ๋ฒ ์ด์ค๋ก ๋น ์ ธ๋๋ ๊ฒ๋ ๊ด์ฐฎ์ต๋๋ค. ์ด ์ต์
์ ์์ง๋์ด๋ง ๊ธฐ์ ์ด ์ด๋ก ์ ๊ธฐ์ ๋ณด๋ค ๋ ๋ฐ์ด๋ ๊ฒฝ์ฐ, BrandNewBert
์ ๋
ผ๋ฌธ์ ์ดํดํ๋ ๋ฐ ์ด๋ ค์์ด ์๋ ๊ฒฝ์ฐ, ๋๋ ๊ณผํ์ ์ธ ๋
ผ๋ฌธ์ ์ฝ๋ ๊ฒ๋ณด๋ค ํ๋ก๊ทธ๋๋ฐ์ ํจ์ฌ ๋ ํฅ๋ฏธ ์๋ ๊ฒฝ์ฐ์ ๋ ์ ํฉํ ์ ์์ต๋๋ค.
1. (์ ํ ์ฌํญ) BrandNewBert์ ์ด๋ก ์ ์ธก๋ฉด [[1-optional-theoretical-aspects-of-brandnewbert]]
๋ง์ฝ ๊ทธ๋ฐ ์์ ์ ์ธ ์์ ์ด ์กด์ฌํ๋ค๋ฉด, BrandNewBert์ ๋ ผ๋ฌธ์ ์ฝ์ด๋ณด๋ ์๊ฐ์ ๊ฐ์ ธ์ผ ํฉ๋๋ค. ์ดํดํ๊ธฐ ์ด๋ ค์ด ์น์ ์ด ๋ง์ ์ ์์ต๋๋ค. ๊ทธ๋ ๋๋ผ๋ ๊ฑฑ์ ํ์ง ๋ง์ธ์! ๋ชฉํ๋ ๋ ผ๋ฌธ์ ๊น์ ์ด๋ก ์ ์ดํด๊ฐ ์๋๋ผ BrandNewBert๋ฅผ ๐ค Transformers์์ ํจ๊ณผ์ ์ผ๋ก ์ฌ๊ตฌํํ๊ธฐ ์ํด ํ์ํ ์ ๋ณด๋ฅผ ์ถ์ถํ๋ ๊ฒ์ ๋๋ค. ์ด๋ฅผ ์ํด ์ด๋ก ์ ์ธก๋ฉด์ ๋๋ฌด ๋ง์ ์๊ฐ์ ํฌ์ํ ํ์๋ ์์ง๋ง ๋ค์๊ณผ ๊ฐ์ ์ค์ ์ ์ธ ์ธก๋ฉด์ ์ง์คํด์ผ ํฉ๋๋ค:
- BrandNewBert๋ ์ด๋ค ์ ํ์ ๋ชจ๋ธ์ธ๊ฐ์? BERT์ ์ ์ฌํ ์ธ์ฝ๋ ๋ชจ๋ธ์ธ๊ฐ์? GPT2์ ์ ์ฌํ ๋์ฝ๋ ๋ชจ๋ธ์ธ๊ฐ์? BART์ ์ ์ฌํ ์ธ์ฝ๋-๋์ฝ๋ ๋ชจ๋ธ์ธ๊ฐ์? ์ด๋ค ๊ฐ์ ์ฐจ์ด์ ์ ์ต์ํ์ง ์์ ๊ฒฝ์ฐmodel_summary๋ฅผ ์ฐธ์กฐํ์ธ์.
- BrandNewBert์ ์์ฉ ๋ถ์ผ๋ ๋ฌด์์ธ๊ฐ์? ํ ์คํธ ๋ถ๋ฅ์ธ๊ฐ์? ํ ์คํธ ์์ฑ์ธ๊ฐ์? ์์ฝ๊ณผ ๊ฐ์ Seq2Seq ์์ ์ธ๊ฐ์?
- brand_new_bert์ BERT/GPT-2/BART์ ์ฐจ์ด์ ์ ๋ฌด์์ธ๊ฐ์?
- brand_new_bert์ ๊ฐ์ฅ ์ ์ฌํ ๐ค Transformers ๋ชจ๋ธ์ ๋ฌด์์ธ๊ฐ์?
- ์ด๋ค ์ข ๋ฅ์ ํ ํฌ๋์ด์ ๊ฐ ์ฌ์ฉ๋๋์? Sentencepiece ํ ํฌ๋์ด์ ์ธ๊ฐ์? Word piece ํ ํฌ๋์ด์ ์ธ๊ฐ์? BERT ๋๋ BART์ ์ฌ์ฉ๋๋ ๋์ผํ ํ ํฌ๋์ด์ ์ธ๊ฐ์?
๋ชจ๋ธ์ ์ํคํ ์ฒ์ ๋ํด ์ถฉ๋ถํ ์ดํดํ๋ค๋ ์๊ฐ์ด ๋ ํ, ๊ถ๊ธํ ์ฌํญ์ด ์์ผ๋ฉด Hugging Face ํ์ ๋ฌธ์ํ์ญ์์ค. ์ด๋ ๋ชจ๋ธ์ ์ํคํ ์ฒ, ์ดํ ์ ๋ ์ด์ด ๋ฑ์ ๊ดํ ์ง๋ฌธ์ ํฌํจํ ์ ์์ต๋๋ค. Hugging Face์ ์ ์ง ๊ด๋ฆฌ์๋ค์ ๋ณดํต ์ฝ๋๋ฅผ ๊ฒํ ํ๋ ๊ฒ์ ๋ํด ๋งค์ฐ ๊ธฐ๋ปํ๋ฏ๋ก ๋น์ ์ ๋๋ ์ผ์ ๋งค์ฐ ํ์ํ ๊ฒ์ ๋๋ค!
2. ๊ฐ๋ฐ ํ๊ฒฝ ์ค์ [[2-next-prepare-your-environment]]
์ ์ฅ์ ํ์ด์ง์์ "Fork" ๋ฒํผ์ ํด๋ฆญํ์ฌ ์ ์ฅ์์ ์ฌ๋ณธ์ GitHub ์ฌ์ฉ์ ๊ณ์ ์ผ๋ก ๋ง๋ญ๋๋ค.
transformers
fork๋ฅผ ๋ก์ปฌ ๋์คํฌ์ ํด๋ก ํ๊ณ ๋ฒ ์ด์ค ์ ์ฅ์๋ฅผ ์๊ฒฉ ์ ์ฅ์๋ก ์ถ๊ฐํฉ๋๋ค:
git clone https://github.com/[your Github handle]/transformers.git
cd transformers
git remote add upstream https://github.com/huggingface/transformers.git
- ๊ฐ๋ฐ ํ๊ฒฝ์ ์ค์ ํฉ๋๋ค. ๋ค์ ๋ช ๋ น์ ์คํํ์ฌ ๊ฐ๋ฐ ํ๊ฒฝ์ ์ค์ ํ ์ ์์ต๋๋ค:
python -m venv .env
source .env/bin/activate
pip install -e ".[dev]"
๊ฐ ์ด์ ์ฒด์ ์ ๋ฐ๋ผ Transformers์ ์ ํ์ ์์กด์ฑ์ด ๊ฐ์๊ฐ ์ฆ๊ฐํ๋ฉด ์ด ๋ช ๋ น์ด ์คํจํ ์ ์์ต๋๋ค. ๊ทธ๋ฐ ๊ฒฝ์ฐ์๋ ์์ ์ค์ธ ๋ฅ ๋ฌ๋ ํ๋ ์์ํฌ (PyTorch, TensorFlow ๋ฐ/๋๋ Flax)์ ์ค์นํ ํ, ๋ค์ ๋ช ๋ น์ ์ํํ๋ฉด ๋ฉ๋๋ค:
pip install -e ".[quality]"
๋๋ถ๋ถ์ ๊ฒฝ์ฐ์๋ ์ด๊ฒ์ผ๋ก ์ถฉ๋ถํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ์์ ๋๋ ํ ๋ฆฌ๋ก ๋์๊ฐ๋๋ค.
cd ..
- Transformers์ brand_new_bert์ PyTorch ๋ฒ์ ์ ์ถ๊ฐํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. PyTorch๋ฅผ ์ค์นํ๋ ค๋ฉด ๋ค์ ๋งํฌ์ ์ง์นจ์ ๋ฐ๋ฅด์ญ์์ค: https://pytorch.org/get-started/locally/.
์ฐธ๊ณ : CUDA๋ฅผ ์ค์นํ ํ์๋ ์์ต๋๋ค. ์๋ก์ด ๋ชจ๋ธ์ด CPU์์ ์๋ํ๋๋ก ๋ง๋๋ ๊ฒ์ผ๋ก ์ถฉ๋ถํฉ๋๋ค.
- brand_new_bert๋ฅผ ์ด์ํ๊ธฐ ์ํด์๋ ํด๋น ์๋ณธ ์ ์ฅ์์ ์ ๊ทผํ ์ ์์ด์ผ ํฉ๋๋ค:
git clone https://github.com/org_that_created_brand_new_bert_org/brand_new_bert.git
cd brand_new_bert
pip install -e .
์ด์ brand_new_bert๋ฅผ ๐ค Transformers๋ก ์ด์ํ๊ธฐ ์ํ ๊ฐ๋ฐ ํ๊ฒฝ์ ์ค์ ํ์์ต๋๋ค.
3.-4. ์๋ณธ ์ ์ฅ์์์ ์ฌ์ ํ๋ จ๋ ์ฒดํฌํฌ์ธํธ ์คํํ๊ธฐ [[3.-4.-run-a-pretrained-checkpoint-using-the-original-repository]]
๋จผ์ , ์๋ณธ brand_new_bert ์ ์ฅ์์์ ์์ ์ ์์ํฉ๋๋ค. ์๋ณธ ๊ตฌํ์ ๋ณดํต "์ฐ๊ตฌ์ฉ"์ผ๋ก ๋ง์ด ์ฌ์ฉ๋ฉ๋๋ค. ์ฆ, ๋ฌธ์ํ๊ฐ ๋ถ์กฑํ๊ณ ์ฝ๋๊ฐ ์ดํดํ๊ธฐ ์ด๋ ค์ธ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๊ฒ์ด ๋ฐ๋ก brand_new_bert๋ฅผ ๋ค์ ๊ตฌํํ๋ ค๋ ๋๊ธฐ๊ฐ ๋์ด์ผ ํฉ๋๋ค. Hugging Face์์์ ์ฃผ์ ๋ชฉํ ์ค ํ๋๋ ๊ฑฐ์ธ์ ์ด๊นจ ์์ ์๋ ๊ฒ์ด๋ฉฐ, ์ด๋ ์ฌ๊ธฐ์์ ์ฝ๊ฒ ํด์๋์ด ๋์ํ๋ ๋ชจ๋ธ์ ๊ฐ์ ธ์์ ๊ฐ๋ฅํ ํ ์ ๊ทผ ๊ฐ๋ฅํ๊ณ ์ฌ์ฉ์ ์นํ์ ์ด๋ฉฐ ์๋ฆ๋ต๊ฒ ๋ง๋๋ ๊ฒ์ ๋๋ค. ์ด๊ฒ์ ๐ค Transformers์์ ๋ชจ๋ธ์ ๋ค์ ๊ตฌํํ๋ ๊ฐ์ฅ ์ค์ํ ๋๊ธฐ์ ๋๋ค - ์๋ก์ด ๋ณต์กํ NLP ๊ธฐ์ ์ ๋ชจ๋์๊ฒ ์ ๊ทผ ๊ฐ๋ฅํ๊ฒ ๋ง๋๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค.
๋ฐ๋ผ์ ์๋ณธ ์ ์ฅ์์ ๋ํด ์์ธํ ์ดํด๋ณด๋ ๊ฒ์ผ๋ก ์์ํด์ผ ํฉ๋๋ค.
์๋ณธ ์ ์ฅ์์์ ๊ณต์ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ฑ๊ณต์ ์ผ๋ก ์คํํ๋ ๊ฒ์ ์ข ์ข ๊ฐ์ฅ ์ด๋ ค์ด ๋จ๊ณ์ ๋๋ค. ์ฐ๋ฆฌ์ ๊ฒฝํ์ ๋ฐ๋ฅด๋ฉด, ์๋ณธ ์ฝ๋ ๋ฒ ์ด์ค์ ์ต์ํด์ง๋ ๋ฐ ์๊ฐ์ ํฌ์ํ๋ ๊ฒ์ด ๋งค์ฐ ์ค์ํฉ๋๋ค. ๋ค์์ ํ์ ํด์ผ ํฉ๋๋ค:
- ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ์ด๋์ ์ฐพ์ ์ ์๋์ง?
- ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ํด๋น ๋ชจ๋ธ์๋ก๋ํ๋ ๋ฐฉ๋ฒ์?
- ๋ชจ๋ธ๊ณผ ๋ ๋ฆฝ์ ์ผ๋ก ํ ํฌ๋์ด์ ๋ฅผ ์คํํ๋ ๋ฐฉ๋ฒ์?
- ๊ฐ๋จํ forward pass์ ํ์ํ ํด๋์ค์ ํจ์๋ฅผ ํ์ ํ๊ธฐ ์ํด forward pass๋ฅผ ํ ๋ฒ ์ถ์ ํด ๋ณด์ธ์. ์ผ๋ฐ์ ์ผ๋ก ํด๋น ํจ์๋ค๋ง ๋ค์ ๊ตฌํํ๋ฉด ๋ฉ๋๋ค.
- ๋ชจ๋ธ์ ์ค์ํ ๊ตฌ์ฑ ์์๋ฅผ ์ฐพ์ ์ ์์ด์ผ ํฉ๋๋ค. ๋ชจ๋ธ ํด๋์ค๋ ์ด๋์ ์๋์? ๋ชจ๋ธ ํ์ ํด๋์ค(EncoderModel, DecoderModel ๋ฑ)๊ฐ ์๋์? self-attention ๋ ์ด์ด๋ ์ด๋์ ์๋์? self-attention, cross-attention ๋ฑ ์ฌ๋ฌ ๊ฐ์ง ๋ค๋ฅธ ์ดํ ์ ๋ ์ด์ด๊ฐ ์๋์?
- ์๋ณธ ํ๊ฒฝ์์ ๋ชจ๋ธ์ ๋๋ฒ๊ทธํ ์ ์๋ ๋ฐฉ๋ฒ์ ๋ฌด์์ธ๊ฐ์? print ๋ฌธ์ ์ถ๊ฐํด์ผ ํ๋์? ipdb์ ๊ฐ์ ๋ํ์ ๋๋ฒ๊ฑฐ๋ฅผ ์ฌ์ฉํ ์ ์๋์? PyCharm๊ณผ ๊ฐ์ ํจ์จ์ ์ธ IDE๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋๋ฒ๊ทธํ ์ ์๋์?
์๋ณธ ์ ์ฅ์์์ ์ฝ๋๋ฅผ ์ด์ํ๋ ์์ ์ ์์ํ๊ธฐ ์ ์ ์๋ณธ ์ ์ฅ์์์ ์ฝ๋๋ฅผ ํจ์จ์ ์ผ๋ก ๋๋ฒ๊ทธํ ์ ์์ด์ผ ํฉ๋๋ค! ๋ํ, ์คํ ์์ค ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ์์ ํ๊ณ ์๋ค๋ ๊ฒ์ ๊ธฐ์ตํด์ผ ํฉ๋๋ค. ๋ฐ๋ผ์ ์๋ณธ ์ ์ฅ์์์ issue๋ฅผ ์ด๊ฑฐ๋ pull request๋ฅผ ์ด๊ธฐ๋ฅผ ์ฃผ์ ํ์ง ๋ง์ญ์์ค. ์ด ์ ์ฅ์์ ์ ์ง ๊ด๋ฆฌ์๋ค์ ๋๊ตฐ๊ฐ๊ฐ ์์ ๋ค์ ์ฝ๋๋ฅผ ์ดํด๋ณธ๋ค๋ ๊ฒ์ ๋ํด ๋งค์ฐ ๊ธฐ๋ปํ ๊ฒ์ ๋๋ค!
ํ์ฌ ์์ ์์, ์๋ ๋ชจ๋ธ์ ๋๋ฒ๊น ํ๊ธฐ ์ํด ์ด๋ค ๋๋ฒ๊น ํ๊ฒฝ๊ณผ ์ ๋ต์ ์ ํธํ๋์ง๋ ๋น์ ์๊ฒ ๋ฌ๋ ธ์ต๋๋ค. ์ฐ๋ฆฌ๋ ๊ณ ๊ฐ์ GPU ํ๊ฒฝ์ ๊ตฌ์ถํ๋ ๊ฒ์ ๋น์ถ์ฒํฉ๋๋ค. ๋์ , ์๋ ์ ์ฅ์๋ก ๋ค์ด๊ฐ์ ์์ ์ ์์ํ ๋์ ๐ค Transformers ๋ชจ๋ธ์ ๊ตฌํ์ ์์ํ ๋์๋ CPU์์ ์์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๋ชจ๋ธ์ด ์ด๋ฏธ ๐ค Transformers๋ก ์ฑ๊ณต์ ์ผ๋ก ์ด์๋์์ ๋์๋ง ๋ชจ๋ธ์ด GPU์์๋ ์์๋๋ก ์๋ํ๋์ง ํ์ธํด์ผํฉ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก, ์๋ ๋ชจ๋ธ์ ์คํํ๊ธฐ ์ํ ๋ ๊ฐ์ง ๊ฐ๋ฅํ ๋๋ฒ๊น ํ๊ฒฝ์ด ์์ต๋๋ค.
- Jupyter ๋ ธํธ๋ถ / Google Colab
- ๋ก์ปฌ Python ์คํฌ๋ฆฝํธ
Jupyter ๋ ธํธ๋ถ์ ์ฅ์ ์ ์ ๋จ์๋ก ์คํํ ์ ์๋ค๋ ๊ฒ์ ๋๋ค. ์ด๋ ๋ ผ๋ฆฌ์ ์ธ ๊ตฌ์ฑ ์์๋ฅผ ๋ ์ ๋ถ๋ฆฌํ๊ณ ์ค๊ฐ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ ์ ์์ผ๋ฏ๋ก ๋๋ฒ๊น ์ฌ์ดํด์ด ๋ ๋นจ๋ผ์ง ์ ์์ต๋๋ค. ๋ํ, ๋ ธํธ๋ถ์ ๋ค๋ฅธ ๊ธฐ์ฌ์์ ์ฝ๊ฒ ๊ณต์ ํ ์ ์์ผ๋ฏ๋ก Hugging Face ํ์ ๋์์ ์์ฒญํ๋ ค๋ ๊ฒฝ์ฐ ๋งค์ฐ ์ ์ฉํ ์ ์์ต๋๋ค. Jupyter ๋ ธํธ๋ถ์ ์ต์ํ๋ค๋ฉด ์ด๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ฐ๋ ฅํ ์ถ์ฒํฉ๋๋ค.
Jupyter ๋
ธํธ๋ถ์ ๋จ์ ์ ์ฌ์ฉ์ ์ต์ํ์ง ์์ ๊ฒฝ์ฐ ์๋ก์ด ํ๋ก๊ทธ๋๋ฐ ํ๊ฒฝ์ ์ ์ํ๋ ๋ฐ ์๊ฐ์ ํ ์ ํด์ผ ํ๋ฉฐ, ipdb
์ ๊ฐ์ ์๋ ค์ง ๋๋ฒ๊น
๋๊ตฌ๋ฅผ ๋ ์ด์ ์ฌ์ฉํ ์ ์์ ์๋ ์๋ค๋ ๊ฒ์
๋๋ค.
๊ฐ ์ฝ๋ ๋ฒ ์ด์ค์ ๋ํด ์ข์ ์ฒซ ๋ฒ์งธ ๋จ๊ณ๋ ํญ์ ์์ ์ฌ์ ํ๋ จ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ก๋ํ๊ณ ๋๋ฏธ ์ ์ ๋ฒกํฐ ์ ๋ ฅ์ ์ฌ์ฉํ์ฌ ๋จ์ผ forward pass๋ฅผ ์ฌํํ๋ ๊ฒ์ ๋๋ค. ์ด์ ๊ฐ์ ์คํฌ๋ฆฝํธ๋ ๋ค์๊ณผ ๊ฐ์ ์ ์์ต๋๋ค(์์ฌ ์ฝ๋๋ก ์์ฑ):
model = BrandNewBertModel.load_pretrained_checkpoint("/path/to/checkpoint/")
input_ids = [0, 4, 5, 2, 3, 7, 9] # vector of input ids
original_output = model.predict(input_ids)
๋ค์์ผ๋ก, ๋๋ฒ๊น ์ ๋ต์ ๋ํด ์ผ๋ฐ์ ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ๋ช ๊ฐ์ง ์ ํ์ง๊ฐ ์์ต๋๋ค:
- ์๋ณธ ๋ชจ๋ธ์ ๋ง์ ์์ ํ ์คํธ ๊ฐ๋ฅํ ๊ตฌ์ฑ ์์๋ก ๋ถํดํ๊ณ ๊ฐ๊ฐ์ ๋ํด forward pass๋ฅผ ์คํํ์ฌ ๊ฒ์ฆํฉ๋๋ค.
- ์๋ณธ ๋ชจ๋ธ์ ์๋ณธ tokenizer๊ณผ ์๋ณธ model๋ก๋ง ๋ถํดํ๊ณ ํด๋น ๋ถ๋ถ์ ๋ํด forward pass๋ฅผ ์คํํ ํ ๊ฒ์ฆ์ ์ํด ์ค๊ฐ ์ถ๋ ฅ(print ๋ฌธ ๋๋ ์ค๋จ์ )์ ์ฌ์ฉํฉ๋๋ค.
๋ค์ ๋งํ์ง๋ง, ์ด๋ค ์ ๋ต์ ์ ํํ ์ง๋ ๋น์ ์๊ฒ ๋ฌ๋ ค ์์ต๋๋ค. ์๋ณธ ์ฝ๋ ๋ฒ ์ด์ค์ ๋ฐ๋ผ ํ๋ ๋๋ ๋ค๋ฅธ ์ ๋ต์ด ์ ๋ฆฌํ ์ ์์ต๋๋ค.
์๋ณธ ์ฝ๋ ๋ฒ ์ด์ค๋ฅผ ๋ชจ๋ธ์ ์์ ํ์ ๊ตฌ์ฑ ์์๋ก ๋ถํดํ ์ ์๋์ง ์ฌ๋ถ, ์๋ฅผ ๋ค์ด ์๋ณธ ์ฝ๋ ๋ฒ ์ด์ค๊ฐ ์ฆ์ ์คํ ๋ชจ๋์์ ๊ฐ๋จํ ์คํ๋ ์ ์๋ ๊ฒฝ์ฐ, ๊ทธ๋ฐ ๊ฒฝ์ฐ์๋ ๊ทธ ๋ ธ๋ ฅ์ด ๊ฐ์น๊ฐ ์๋ค๋ ๊ฒ์ด ์ผ๋ฐ์ ์ ๋๋ค. ์ด๊ธฐ์ ๋ ์ด๋ ค์ด ๋ฐฉ๋ฒ์ ์ ํํ๋ ๊ฒ์๋ ๋ช ๊ฐ์ง ์ค์ํ ์ฅ์ ์ด ์์ต๋๋ค.
- ์๋ณธ ๋ชจ๋ธ์ ๐ค Transformers ๊ตฌํ๊ณผ ๋น๊ตํ ๋ ๊ฐ ๊ตฌ์ฑ ์์๊ฐ ์ผ์นํ๋์ง ์๋์ผ๋ก ํ์ธํ ์ ์์ต๋๋ค. ์ฆ, ์๊ฐ์ ์ธ ๋น๊ต(print ๋ฌธ์ ํตํ ๋น๊ต๊ฐ ์๋) ๋์ ๐ค Transformers ๊ตฌํ๊ณผ ๊ทธ์ ๋์ํ๋ ์๋ณธ ๊ตฌ์ฑ ์์๊ฐ ์ผ์นํ๋์ง ํ์ธํ ์ ์์ต๋๋ค.
- ์ ์ฒด ๋ชจ๋ธ์ ๋ชจ๋๋ณ๋ก, ์ฆ ์์ ๊ตฌ์ฑ ์์๋ก ๋ถํดํจ์ผ๋ก์จ ๋ชจ๋ธ์ ์ด์ํ๋ ํฐ ๋ฌธ์ ๋ฅผ ๋จ์ํ ๊ฐ๋ณ ๊ตฌ์ฑ ์์๋ฅผ ์ด์ํ๋ ์์ ๋ฌธ์ ๋ก ๋ถํดํ ์ ์์ผ๋ฏ๋ก ์์ ์ ๋ ์ ๊ตฌ์กฐํํ ์ ์์ต๋๋ค.
- ๋ชจ๋ธ์ ๋ ผ๋ฆฌ์ ์ผ๋ก ์๋ฏธ ์๋ ๊ตฌ์ฑ ์์๋ก ๋ถ๋ฆฌํ๋ ๊ฒ์ ๋ชจ๋ธ์ ์ค๊ณ์ ๋ํ ๋ ๋์ ๊ฐ์๋ฅผ ์ป๊ณ ๋ชจ๋ธ์ ๋ ์ ์ดํดํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค.
- ์ด๋ฌํ ๊ตฌ์ฑ ์์๋ณ ํ ์คํธ๋ฅผ ํตํด ์ฝ๋๋ฅผ ๋ณ๊ฒฝํ๋ฉด์ ํ๊ท๊ฐ ๋ฐ์ํ์ง ์๋๋ก ๋ณด์ฅํ ์ ์์ต๋๋ค.
Lysandre์ ELECTRA ํตํฉ ๊ฒ์ฌ๋ ์ด๋ฅผ ์ํํ๋ ์ข์ ์์ ์ ๋๋ค.
๊ทธ๋ฌ๋ ์๋ณธ ์ฝ๋ ๋ฒ ์ด์ค๊ฐ ๋งค์ฐ ๋ณต์กํ๊ฑฐ๋ ์ค๊ฐ ๊ตฌ์ฑ ์์๋ฅผ ์ปดํ์ผ๋ ๋ชจ๋์์ ์คํํ๋ ๊ฒ๋ง ํ์ฉํ๋ ๊ฒฝ์ฐ, ๋ชจ๋ธ์ ํ ์คํธ ๊ฐ๋ฅํ ์์ ํ์ ๊ตฌ์ฑ ์์๋ก ๋ถํดํ๋ ๊ฒ์ด ์๊ฐ์ด ๋ง์ด ์์๋๊ฑฐ๋ ๋ถ๊ฐ๋ฅํ ์๋ ์์ต๋๋ค. T5์ MeshTensorFlow ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ๋งค์ฐ ๋ณต์กํ๋ฉฐ ๋ชจ๋ธ์ ํ์ ๊ตฌ์ฑ ์์๋ก ๋ถํดํ๋ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ ์ ๊ณตํ์ง ์์ต๋๋ค. ์ด๋ฌํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ฒฝ์ฐ, ๋ณดํต print ๋ฌธ์ ํตํด ํ์ธํฉ๋๋ค.
์ด๋ค ์ ๋ต์ ์ ํํ๋๋ผ๋ ๊ถ์ฅ๋๋ ์ ์ฐจ๋ ๋์ผํฉ๋๋ค. ๋จผ์ ์์ ๋ ์ด์ด๋ฅผ ๋๋ฒ๊ทธํ๊ณ ๋ง์ง๋ง ๋ ์ด์ด๋ฅผ ๋ง์ง๋ง์ ๋๋ฒ๊ทธํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๋ค์ ์์๋ก ๊ฐ ๋ ์ด์ด์ ์ถ๋ ฅ์ ๊ฒ์ํ๋ ๊ฒ์ด ์ข์ต๋๋ค:
- ๋ชจ๋ธ์ ์ ๋ฌ๋ ์ ๋ ฅ ID ๊ฐ์ ธ์ค๊ธฐ
- ์๋ ์๋ฒ ๋ฉ ๊ฐ์ ธ์ค๊ธฐ
- ์ฒซ ๋ฒ์งธ Transformer ๋ ์ด์ด์ ์ ๋ ฅ ๊ฐ์ ธ์ค๊ธฐ
- ์ฒซ ๋ฒ์งธ Transformer ๋ ์ด์ด์ ์ถ๋ ฅ ๊ฐ์ ธ์ค๊ธฐ
- ๋ค์ n-1๊ฐ์ Transformer ๋ ์ด์ด์ ์ถ๋ ฅ ๊ฐ์ ธ์ค๊ธฐ
- BrandNewBert ๋ชจ๋ธ์ ์ถ๋ ฅ ๊ฐ์ ธ์ค๊ธฐ
์
๋ ฅ ID๋ ์ ์ ๋ฐฐ์ด๋ก ๊ตฌ์ฑ๋๋ฉฐ, ์๋ฅผ ๋ค์ด input_ids = [0, 4, 4, 3, 2, 4, 1, 7, 19]
์ ๊ฐ์ ์ ์์ต๋๋ค.
๋ค์ ๋ ์ด์ด์ ์ถ๋ ฅ์ ์ข ์ข ๋ค์ฐจ์ ์ค์ ๋ฐฐ์ด๋ก ๊ตฌ์ฑ๋๋ฉฐ, ๋ค์๊ณผ ๊ฐ์ด ๋ํ๋ผ ์ ์์ต๋๋ค:
[[
[-0.1465, -0.6501, 0.1993, ..., 0.1451, 0.3430, 0.6024],
[-0.4417, -0.5920, 0.3450, ..., -0.3062, 0.6182, 0.7132],
[-0.5009, -0.7122, 0.4548, ..., -0.3662, 0.6091, 0.7648],
...,
[-0.5613, -0.6332, 0.4324, ..., -0.3792, 0.7372, 0.9288],
[-0.5416, -0.6345, 0.4180, ..., -0.3564, 0.6992, 0.9191],
[-0.5334, -0.6403, 0.4271, ..., -0.3339, 0.6533, 0.8694]]],
๐ค Transformers์ ์ถ๊ฐ๋๋ ๋ชจ๋ ๋ชจ๋ธ์ ํตํฉ ํ ์คํธ๋ฅผ ํต๊ณผํด์ผ ํฉ๋๋ค. ์ฆ, ์๋ณธ ๋ชจ๋ธ๊ณผ ๐ค Transformers์ ์ฌ๊ตฌํ ๋ฒ์ ์ด 0.001์ ์ ๋ฐ๋๋ก ์ ํํ ๋์ผํ ์ถ๋ ฅ์ ๋ด์ผ ํฉ๋๋ค! ๋์ผํ ๋ชจ๋ธ์ด ๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ์์ฑ๋์์ ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ํ๋ ์์ํฌ์ ๋ฐ๋ผ ์ฝ๊ฐ ๋ค๋ฅธ ์ถ๋ ฅ์ ์ป๋ ๊ฒ์ ์ ์์ด๋ฏ๋ก 1e-3(0.001)์ ์ค์ฐจ๋ ํ์ฉํฉ๋๋ค. ๊ฑฐ์ ๋์ผํ ์ถ๋ ฅ์ ๋ด๋ ๊ฒ๋ง์ผ๋ก๋ ์ถฉ๋ถํ์ง ์์ผ๋ฉฐ, ์๋ฒฝํ ์ผ์นํ๋ ์์ค์ด์ด์ผ ํฉ๋๋ค. ๋ฐ๋ผ์ ๐ค Transformers ๋ฒ์ ์ ์ค๊ฐ ์ถ๋ ฅ์ brand_new_bert์ ์๋ ๊ตฌํ์ ์ค๊ฐ ์ถ๋ ฅ๊ณผ ์ฌ๋ฌ ๋ฒ ๋น๊ตํด์ผ ํฉ๋๋ค. ์ด ๊ฒฝ์ฐ ์๋ณธ ์ ์ฅ์์ ํจ์จ์ ์ธ ๋๋ฒ๊น ํ๊ฒฝ์ด ์ ๋์ ์ผ๋ก ์ค์ํฉ๋๋ค. ๋๋ฒ๊น ํ๊ฒฝ์ ๊ฐ๋ฅํ ํ ํจ์จ์ ์ผ๋ก ๋ง๋๋ ๋ช ๊ฐ์ง ์กฐ์ธ์ ์ ์ํฉ๋๋ค.
- ์ค๊ฐ ๊ฒฐ๊ณผ๋ฅผ ๋๋ฒ๊ทธํ๋ ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ ์ฐพ์ผ์ธ์. ์๋ณธ ์ ์ฅ์๊ฐ PyTorch๋ก ์์ฑ๋์๋ค๋ฉด ์๋ณธ ๋ชจ๋ธ์ ๋ ์์ ํ์ ๊ตฌ์ฑ ์์๋ก ๋ถํดํ์ฌ ์ค๊ฐ ๊ฐ์ ๊ฒ์ํ๋ ๊ธด ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ๋ ๊ฒ์ ์๊ฐ์ ํฌ์ํ ๊ฐ์น๊ฐ ์์ต๋๋ค. ์๋ณธ ์ ์ฅ์๊ฐ Tensorflow 1๋ก ์์ฑ๋์๋ค๋ฉด tf.print์ ๊ฐ์ Tensorflow ์ถ๋ ฅ ์์ ์ ์ฌ์ฉํ์ฌ ์ค๊ฐ ๊ฐ์ ์ถ๋ ฅํด์ผ ํ ์๋ ์์ต๋๋ค. ์๋ณธ ์ ์ฅ์๊ฐ Jax๋ก ์์ฑ๋์๋ค๋ฉด forward pass๋ฅผ ์คํํ ๋ ๋ชจ๋ธ์ด jit ๋์ง ์๋๋ก ํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด ์ด ๋งํฌ๋ฅผ ํ์ธํด ๋ณด์ธ์.
- ์ฌ์ฉ ๊ฐ๋ฅํ ๊ฐ์ฅ ์์ ์ฌ์ ํ๋ จ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ์ธ์. ์ฒดํฌํฌ์ธํธ๊ฐ ์์์๋ก ๋๋ฒ๊ทธ ์ฌ์ดํด์ด ๋ ๋นจ๋ผ์ง๋๋ค. ์ ๋ฐ์ ์ผ๋ก forward pass์ 10์ด ์ด์์ด ๊ฑธ๋ฆฌ๋ ๊ฒฝ์ฐ ํจ์จ์ ์ด์ง ์์ต๋๋ค. ๋งค์ฐ ํฐ ์ฒดํฌํฌ์ธํธ๋ง ์ฌ์ฉํ ์ ์๋ ๊ฒฝ์ฐ, ์ ํ๊ฒฝ์์ ์์๋ก ์ด๊ธฐํ๋ ๊ฐ์ค์น๋ก ๋๋ฏธ ๋ชจ๋ธ์ ๋ง๋ค๊ณ ํด๋น ๊ฐ์ค์น๋ฅผ ๐ค Transformers ๋ฒ์ ๊ณผ ๋น๊ตํ๊ธฐ ์ํด ์ ์ฅํ๋ ๊ฒ์ด ๋ ์๋ฏธ๊ฐ ์์ ์ ์์ต๋๋ค.
- ๋๋ฒ๊น
์ค์ ์์ ๊ฐ์ฅ ์ฝ๊ฒ forward pass๋ฅผ ํธ์ถํ๋ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ธ์. ์๋ณธ ์ ์ฅ์์์ ๋จ์ผ forward pass๋ง ํธ์ถํ๋ ํจ์๋ฅผ ์ฐพ๋ ๊ฒ์ด ์ด์์ ์
๋๋ค. ์ด ํจ์๋ ์ผ๋ฐ์ ์ผ๋ก
predict
,evaluate
,forward
,__call__
๊ณผ ๊ฐ์ด ํธ์ถ๋ฉ๋๋ค.autoregressive_sample
๊ณผ ๊ฐ์ ํ ์คํธ ์์ฑ์์forward
๋ฅผ ์ฌ๋ฌ ๋ฒ ํธ์ถํ์ฌ ํ ์คํธ๋ฅผ ์์ฑํ๋ ๋ฑ์ ์์ ์ ์ํํ๋ ํจ์๋ฅผ ๋๋ฒ๊ทธํ๊ณ ์ถ์ง ์์ ๊ฒ์ ๋๋ค. - ํ ํฐํ ๊ณผ์ ์ ๋ชจ๋ธ์ forward pass์ ๋ถ๋ฆฌํ๋ ค๊ณ ๋ ธ๋ ฅํ์ธ์. ์๋ณธ ์ ์ฅ์์์ ์ ๋ ฅ ๋ฌธ์์ด์ ์ ๋ ฅํด์ผ ํ๋ ์์ ๊ฐ ์๋ ๊ฒฝ์ฐ, ์ ๋ ฅ ๋ฌธ์์ด์ด ์ ๋ ฅ ID๋ก ๋ณ๊ฒฝ๋๋ ์๊ฐ์ ์ฐพ์์ ์์ํ์ธ์. ์ด ๊ฒฝ์ฐ ์ง์ ID๋ฅผ ์ ๋ ฅํ ์ ์๋๋ก ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ๊ฑฐ๋ ์๋ณธ ์ฝ๋๋ฅผ ์์ ํด์ผ ํ ์๋ ์์ต๋๋ค.
- ๋๋ฒ๊น ์ค์ ์์ ๋ชจ๋ธ์ด ํ๋ จ ๋ชจ๋๊ฐ ์๋๋ผ๋ ๊ฒ์ ํ์ธํ์ธ์. ํ๋ จ ๋ชจ๋์์๋ ๋ชจ๋ธ์ ์ฌ๋ฌ ๋๋กญ์์ ๋ ์ด์ด ๋๋ฌธ์ ๋ฌด์์ ์ถ๋ ฅ์ด ์์ฑ๋ ์ ์์ต๋๋ค. ๋๋ฒ๊น ํ๊ฒฝ์์ forward pass๊ฐ ๊ฒฐ์ ๋ก ์ ์ด๋๋ก ํด์ผ ํฉ๋๋ค. ๋๋ ๋์ผํ ํ๋ ์์ํฌ์ ์๋ ๊ฒฝ์ฐ transformers.utils.set_seed๋ฅผ ์ฌ์ฉํ์ธ์.
๋ค์ ์น์ ์์๋ brand_new_bert์ ๋ํด ์ด ์์ ์ ์ํํ๋ ๋ฐ ๋ ๊ตฌ์ฒด์ ์ธ ์ธ๋ถ ์ฌํญ/ํ์ ์ ๊ณตํฉ๋๋ค.
5.-14. ๐ค Transformers์ BrandNewBert๋ฅผ ์ด์ํ๊ธฐ [[5.-14.-port-brandnewbert-to-transformers]]
์ด์ , ๋ง์นจ๋ด ๐ค Transformers์ ์๋ก์ด ์ฝ๋๋ฅผ ์ถ๊ฐํ ์ ์์ต๋๋ค. ๐ค Transformers ํฌํฌ์ ํด๋ก ์ผ๋ก ์ด๋ํ์ธ์:
cd transformers
๋ค์๊ณผ ๊ฐ์ด ์ด๋ฏธ ์กด์ฌํ๋ ๋ชจ๋ธ์ ๋ชจ๋ธ ์ํคํ ์ฒ์ ์ ํํ ์ผ์นํ๋ ๋ชจ๋ธ์ ์ถ๊ฐํ๋ ํน๋ณํ ๊ฒฝ์ฐ์๋ ์ด ์น์ ์ ์ค๋ช ๋๋๋ก ๋ณํ ์คํฌ๋ฆฝํธ๋ง ์ถ๊ฐํ๋ฉด ๋ฉ๋๋ค. ์ด ๊ฒฝ์ฐ์๋ ์ด๋ฏธ ์กด์ฌํ๋ ๋ชจ๋ธ์ ์ ์ฒด ๋ชจ๋ธ ์ํคํ ์ฒ๋ฅผ ๊ทธ๋๋ก ์ฌ์ฌ์ฉํ ์ ์์ต๋๋ค.
๊ทธ๋ ์ง ์์ผ๋ฉด ์๋ก์ด ๋ชจ๋ธ ์์ฑ์ ์์ํฉ์๋ค. ์ฌ๊ธฐ์์ ๋ ๊ฐ์ง ์ ํ์ง๊ฐ ์์ต๋๋ค:
transformers-cli add-new-model-like
๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ์กด ๋ชจ๋ธ๊ณผ ์ ์ฌํ ์๋ก์ด ๋ชจ๋ธ ์ถ๊ฐํ๊ธฐtransformers-cli add-new-model
์ ์ฌ์ฉํ์ฌ ํ ํ๋ฆฟ์ ๊ธฐ๋ฐ์ผ๋ก ํ ์๋ก์ด ๋ชจ๋ธ ์ถ๊ฐํ๊ธฐ (์ ํํ ๋ชจ๋ธ ์ ํ์ ๋ฐ๋ผ BERT ๋๋ Bart์ ์ ์ฌํ ๋ชจ์ต์ผ ๊ฒ์ ๋๋ค)
๋ ๊ฒฝ์ฐ ๋ชจ๋, ๋ชจ๋ธ์ ๊ธฐ๋ณธ ์ ๋ณด๋ฅผ ์
๋ ฅํ๋ ์ค๋ฌธ์กฐ์ฌ๊ฐ ์ ์๋ฉ๋๋ค. ๋ ๋ฒ์งธ ๋ช
๋ น์ด๋ cookiecutter
๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค. ์์ธํ ์ ๋ณด๋ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค.
huggingface/transformers ๋ฉ์ธ ์ ์ฅ์์ Pull Request ์ด๊ธฐ
์๋์ผ๋ก ์์ฑ๋ ์ฝ๋๋ฅผ ์์ ํ๊ธฐ ์ ์, ์ง๊ธ์ "์์ ์งํ ์ค (WIP)" ํ ๋ฆฌํ์คํธ๋ฅผ ์ด๊ธฐ ์ํ ์๊ธฐ์ ๋๋ค. ์๋ฅผ ๋ค์ด, ๐ค Transformers์ "brand_new_bert ์ถ๊ฐ"๋ผ๋ ์ ๋ชฉ์ "[WIP] Add brand_new_bert" ํ ๋ฆฌํ์คํธ๋ฅผ ์ฝ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋น์ ๊ณผ Hugging Face ํ์ด ๐ค Transformers์ ๋ชจ๋ธ์ ํตํฉํ๋ ์์ ์ ํจ๊ปํ ์ ์์ต๋๋ค.
๋ค์์ ์ํํด์ผ ํฉ๋๋ค:
- ๋ฉ์ธ ๋ธ๋์น์์ ์์ ์ ์ ์ค๋ช ํ๋ ์ด๋ฆ์ผ๋ก ๋ธ๋์น ์์ฑ
git checkout -b add_brand_new_bert
- ์๋์ผ๋ก ์์ฑ๋ ์ฝ๋ ์ปค๋ฐ
git add .
git commit
- ํ์ฌ ๋ฉ์ธ์ ๊ฐ์ ธ์ค๊ณ ๋ฆฌ๋ฒ ์ด์ค
git fetch upstream
git rebase upstream/main
- ๋ณ๊ฒฝ ์ฌํญ์ ๊ณ์ ์ ํธ์
git push -u origin a-descriptive-name-for-my-changes
๋ง์กฑ์ค๋ฝ๋ค๋ฉด, GitHub์์ ์์ ์ ํฌํฌํ ์น ํ์ด์ง๋ก ์ด๋ํฉ๋๋ค. "Pull request"๋ฅผ ํด๋ฆญํฉ๋๋ค. Hugging Face ํ์ ์ผ๋ถ ๋ฉค๋ฒ์ GitHub ํธ๋ค์ ๋ฆฌ๋ทฐ์ด๋ก ์ถ๊ฐํ์ฌ Hugging Face ํ์ด ์์ผ๋ก์ ๋ณ๊ฒฝ ์ฌํญ์ ๋ํด ์๋ฆผ์ ๋ฐ์ ์ ์๋๋ก ํฉ๋๋ค.
GitHub ํ ๋ฆฌํ์คํธ ์น ํ์ด์ง ์ค๋ฅธ์ชฝ์ ์๋ "Convert to draft"๋ฅผ ํด๋ฆญํ์ฌ PR์ ์ด์์ผ๋ก ๋ณ๊ฒฝํฉ๋๋ค.
๋ค์์ผ๋ก, ์ด๋ค ์ง์ ์ ์ด๋ฃจ์๋ค๋ฉด ์์ ์ ์ปค๋ฐํ๊ณ ๊ณ์ ์ ํธ์ํ์ฌ ํ ๋ฆฌํ์คํธ์ ํ์๋๋๋ก ํด์ผ ํฉ๋๋ค. ๋ํ, ๋ค์๊ณผ ๊ฐ์ด ํ์ฌ ๋ฉ์ธ๊ณผ ์์ ์ ์ ๋ฐ์ดํธํด์ผ ํฉ๋๋ค:
git fetch upstream
git merge upstream/main
์ผ๋ฐ์ ์ผ๋ก, ๋ชจ๋ธ ๋๋ ๊ตฌํ์ ๊ดํ ๋ชจ๋ ์ง๋ฌธ์ ์์ ์ PR์์ ํด์ผ ํ๋ฉฐ, PR์์ ํ ๋ก ๋๊ณ ํด๊ฒฐ๋์ด์ผ ํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด Hugging Face ํ์ด ์๋ก์ด ์ฝ๋๋ฅผ ์ปค๋ฐํ๊ฑฐ๋ ์ง๋ฌธ์ ํ ๋ ํญ์ ์๋ฆผ์ ๋ฐ์ ์ ์์ต๋๋ค. Hugging Face ํ์๊ฒ ๋ฌธ์ ๋๋ ์ง๋ฌธ์ ํจ์จ์ ์ผ๋ก ์ดํดํ ์ ์๋๋ก ์ถ๊ฐํ ์ฝ๋๋ฅผ ๋ช ์ํ๋ ๊ฒ์ด ๋์์ด ๋ ๋๊ฐ ๋ง์ต๋๋ค.
์ด๋ฅผ ์ํด, ๋ณ๊ฒฝ ์ฌํญ์ ๋ชจ๋ ๋ณผ ์ ์๋ "Files changed" ํญ์ผ๋ก ์ด๋ํ์ฌ ์ง๋ฌธํ๊ณ ์ ํ๋ ์ค๋ก ์ด๋ํ ๋ค์ "+" ๊ธฐํธ๋ฅผ ํด๋ฆญํ์ฌ ์ฝ๋ฉํธ๋ฅผ ์ถ๊ฐํ ์ ์์ต๋๋ค. ์ง๋ฌธ์ด๋ ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋๋ฉด, ์์ฑ๋ ์ฝ๋ฉํธ์ "Resolve" ๋ฒํผ์ ํด๋ฆญํ ์ ์์ต๋๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก, Hugging Face ํ์ ์ฝ๋๋ฅผ ๋ฆฌ๋ทฐํ ๋ ์ฝ๋ฉํธ๋ฅผ ๋จ๊ธธ ๊ฒ์ ๋๋ค. ์ฐ๋ฆฌ๋ PR์์ ๋๋ถ๋ถ์ ์ง๋ฌธ์ GitHub์์ ๋ฌป๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ๊ณต๊ฐ์ ํฌ๊ฒ ๋์์ด ๋์ง ์๋ ๋งค์ฐ ์ผ๋ฐ์ ์ธ ์ง๋ฌธ์ ๊ฒฝ์ฐ, Slack์ด๋ ์ด๋ฉ์ผ์ ํตํด Hugging Face ํ์๊ฒ ๋ฌธ์ํ ์ ์์ต๋๋ค.
5. brand_new_bert์ ๋ํด ์์ฑ๋ ๋ชจ๋ธ ์ฝ๋๋ฅผ ์ ์ฉํ๊ธฐ
๋จผ์ , ์ฐ๋ฆฌ๋ ๋ชจ๋ธ ์์ฒด์๋ง ์ด์ ์ ๋ง์ถ๊ณ ํ ํฌ๋์ด์ ์ ๋ํด์๋ ์ ๊ฒฝ ์ฐ์ง ์์ ๊ฒ์
๋๋ค. ๋ชจ๋ ๊ด๋ จ ์ฝ๋๋ ๋ค์์ ์์ฑ๋ ํ์ผ์์ ์ฐพ์ ์ ์์ต๋๋ค: src/transformers/models/brand_new_bert/modeling_brand_new_bert.py
๋ฐ src/transformers/models/brand_new_bert/configuration_brand_new_bert.py
.
์ด์ ๋ง์นจ๋ด ์ฝ๋ฉ์ ์์ํ ์ ์์ต๋๋ค :). src/transformers/models/brand_new_bert/modeling_brand_new_bert.py
์ ์์ฑ๋ ์ฝ๋๋ ์ธ์ฝ๋ ์ ์ฉ ๋ชจ๋ธ์ธ ๊ฒฝ์ฐ BERT์ ๋์ผํ ์ํคํ
์ฒ๋ฅผ ๊ฐ์ง๊ฑฐ๋, ์ธ์ฝ๋-๋์ฝ๋ ๋ชจ๋ธ์ธ ๊ฒฝ์ฐ BART์ ๋์ผํ ์ํคํ
์ฒ๋ฅผ ๊ฐ์ง ๊ฒ์
๋๋ค. ์ด ์์ ์์, ๋ชจ๋ธ์ ์ด๋ก ์ ์ธก๋ฉด์ ๋ํด ๋ฐฐ์ด ๋ด์ฉ์ ๋ค์ ์๊ธฐํด์ผ ํฉ๋๋ค: ๋ชจ๋ธ์ด BERT ๋๋ BART์ ์ด๋ป๊ฒ ๋ค๋ฅธ๊ฐ์?. ์์ฃผ ๋ณ๊ฒฝํด์ผ ํ๋ ๊ฒ์ self-attention ๋ ์ด์ด, ์ ๊ทํ ๋ ์ด์ด์ ์์ ๋ฑ์ ๋ณ๊ฒฝํ๋ ๊ฒ์
๋๋ค. ๋ค์ ๋งํ์ง๋ง, ์์ ์ ๋ชจ๋ธ์ ๊ตฌํํ๋ ๋ฐ ๋์์ด ๋๋๋ก Transformers์์ ์ด๋ฏธ ์กด์ฌํ๋ ๋ชจ๋ธ์ ์ ์ฌํ ์ํคํ
์ฒ๋ฅผ ์ดํด๋ณด๋ ๊ฒ์ด ์ ์ฉํ ์ ์์ต๋๋ค.
์ฐธ๊ณ ๋ก ์ด ์์ ์์, ์ฝ๋๊ฐ ์์ ํ ์ ํํ๊ฑฐ๋ ๊นจ๋ํ๋ค๊ณ ํ์ ํ ํ์๋ ์์ต๋๋ค. ์คํ๋ ค ์ฒ์์๋ ์๋ณธ ์ฝ๋์ ์ฒซ ๋ฒ์งธ ๋ถ์์ ํ๊ณ ๋ณต์ฌ๋ ๋ฒ์ ์ src/transformers/models/brand_new_bert/modeling_brand_new_bert.py
์ ์ถ๊ฐํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ํ์ํ ๋ชจ๋ ์ฝ๋๊ฐ ์ถ๊ฐ๋ ๋๊น์ง ์ด๋ฌํ ์์
์ ์งํํ ํ, ๋ค์ ์น์
์์ ์ค๋ช
ํ ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ฝ๋๋ฅผ ์ ์ง์ ์ผ๋ก ๊ฐ์ ํ๊ณ ์์ ํ๋ ๊ฒ์ด ํจ์ฌ ํจ์จ์ ์
๋๋ค. ์ด ์์ ์์ ์๋ํด์ผ ํ๋ ์ ์ผํ ๊ฒ์ ๋ค์ ๋ช
๋ น์ด ์๋ํ๋ ๊ฒ์
๋๋ค:
from transformers import BrandNewBertModel, BrandNewBertConfig
model = BrandNewBertModel(BrandNewBertConfig())
์์ ๋ช
๋ น์ BrandNewBertConfig()
์ ์ ์๋ ๊ธฐ๋ณธ ๋งค๊ฐ๋ณ์์ ๋ฐ๋ผ ๋ฌด์์ ๊ฐ์ค์น๋ก ๋ชจ๋ธ์ ์์ฑํ๋ฉฐ, ์ด๋ก์จ ๋ชจ๋ ๊ตฌ์ฑ ์์์ init()
๋ฉ์๋๊ฐ ์๋ํจ์ ๋ณด์ฅํฉ๋๋ค.
๋ชจ๋ ๋ฌด์์ ์ด๊ธฐํ๋ BrandnewBertPreTrainedModel
ํด๋์ค์ _init_weights
๋ฉ์๋์์ ์ํ๋์ด์ผ ํฉ๋๋ค. ์ด ๋ฉ์๋๋ ๊ตฌ์ฑ ์ค์ ๋ณ์์ ๋ฐ๋ผ ๋ชจ๋ ๋ฆฌํ ๋ชจ๋์ ์ด๊ธฐํํด์ผ ํฉ๋๋ค. BERT์ _init_weights
๋ฉ์๋ ์์ ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
๋ช ๊ฐ์ง ๋ชจ๋์ ๋ํด ํน๋ณํ ์ด๊ธฐํ๊ฐ ํ์ํ ๊ฒฝ์ฐ ์ฌ์ฉ์ ์ ์ ๋ฐฉ์์ ์ฌ์ฉํ ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, Wav2Vec2ForPreTraining
์์ ๋ง์ง๋ง ๋ ๊ฐ์ ์ ํ ๋ ์ด์ด๋ ์ผ๋ฐ์ ์ธ PyTorch nn.Linear
์ ์ด๊ธฐํ๋ฅผ ๊ฐ์ ธ์ผ ํ์ง๋ง, ๋ค๋ฅธ ๋ชจ๋ ๋ ์ด์ด๋ ์์ ๊ฐ์ ์ด๊ธฐํ๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ์ด๋ ๋ค์๊ณผ ๊ฐ์ด ์ฝ๋ํ๋ฉ๋๋ค:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstnace(module, Wav2Vec2ForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
_is_hf_initialized
ํ๋๊ทธ๋ ์๋ธ๋ชจ๋์ ํ ๋ฒ๋ง ์ด๊ธฐํํ๋๋ก ๋ด๋ถ์ ์ผ๋ก ์ฌ์ฉ๋ฉ๋๋ค. module.project_q
๋ฐ module.project_hid
์ ๋ํด True
๋ก ์ค์ ํจ์ผ๋ก์จ, ์ฐ๋ฆฌ๊ฐ ์ํํ ์ฌ์ฉ์ ์ ์ ์ด๊ธฐํ๊ฐ ์ดํ์ ๋ฎ์ด์ฐ์ด์ง ์๋๋ก ํฉ๋๋ค. ์ฆ, _init_weights
ํจ์๊ฐ ์ด๋ค์๊ฒ ์ ์ฉ๋์ง ์์ต๋๋ค.
6. ๋ณํ ์คํฌ๋ฆฝํธ ์์ฑํ๊ธฐ
๋ค์์ผ๋ก, ๋๋ฒ๊ทธ์ ์ฌ์ฉํ ์ฒดํฌํฌ์ธํธ๋ฅผ ๊ธฐ์กด ์ ์ฅ์์์ ๋ง๋ ๐ค Transformers ๊ตฌํ๊ณผ ํธํ๋๋ ์ฒดํฌํฌ์ธํธ๋ก ๋ณํํ ์ ์๋ ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค. ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์ฒ์๋ถํฐ ์์ฑํ๋ ๊ฒ๋ณด๋ค๋ brand_new_bert์ ๋์ผํ ํ๋ ์์ํฌ๋ก ์์ฑ๋ ์ ์ฌํ ๋ชจ๋ธ์ ๋ณํํ ๊ธฐ์กด ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์ฐพ์๋ณด๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ๊ธฐ์กด ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ๋ณต์ฌํ์ฌ ์ฌ์ฉ ์ฌ๋ก์ ๋ง๊ฒ ์ฝ๊ฐ ์์ ํ๋ ๊ฒ์ผ๋ก ์ถฉ๋ถํฉ๋๋ค. ๋ชจ๋ธ์ ๋ํด ์ ์ฌํ ๊ธฐ์กด ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์ด๋์์ ์ฐพ์ ์ ์๋์ง Hugging Face ํ์๊ฒ ๋ฌธ์ํ๋ ๊ฒ์ ๋ง์ค์ด์ง ๋ง์ธ์.
- TensorFlow์์ PyTorch๋ก ๋ชจ๋ธ์ ์ด์ ํ๋ ๊ฒฝ์ฐ, ์ข์ ์ฐธ๊ณ ์๋ฃ๋ก BERT์ ๋ณํ ์คํฌ๋ฆฝํธ ์ฌ๊ธฐ๋ฅผ ์ฐธ์กฐํ ์ ์์ต๋๋ค.
- PyTorch์์ PyTorch๋ก ๋ชจ๋ธ์ ์ด์ ํ๋ ๊ฒฝ์ฐ, ์ข์ ์ฐธ๊ณ ์๋ฃ๋ก BART์ ๋ณํ ์คํฌ๋ฆฝํธ ์ฌ๊ธฐ๋ฅผ ์ฐธ์กฐํ ์ ์์ต๋๋ค.
๋ค์์์๋ PyTorch ๋ชจ๋ธ์ด ๋ ์ด์ด ๊ฐ์ค์น๋ฅผ ์ ์ฅํ๊ณ ๋ ์ด์ด ์ด๋ฆ์ ์ ์ํ๋ ๋ฐฉ๋ฒ์ ๋ํด ๊ฐ๋จํ ์ค๋ช
ํ๊ฒ ์ต๋๋ค. PyTorch์์ ๋ ์ด์ด์ ์ด๋ฆ์ ๋ ์ด์ด์ ์ง์ ํ ํด๋์ค ์์ฑ์ ์ด๋ฆ์ผ๋ก ์ ์๋ฉ๋๋ค. ๋ค์๊ณผ ๊ฐ์ด PyTorch์์ SimpleModel
์ด๋ผ๋ ๋๋ฏธ ๋ชจ๋ธ์ ์ ์ํด ๋ด
์๋ค:
from torch import nn
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.dense = nn.Linear(10, 10)
self.intermediate = nn.Linear(10, 10)
self.layer_norm = nn.LayerNorm(10)
์ด์ ์ด ๋ชจ๋ธ ์ ์์ ์ธ์คํด์ค๋ฅผ ์์ฑํ ์ ์์ผ๋ฉฐ dense
, intermediate
, layer_norm
๋ฑ์ ๊ฐ์ค์น๊ฐ ๋๋คํ๊ฒ ํ ๋น๋ฉ๋๋ค. ๋ชจ๋ธ์ ์ถ๋ ฅํ์ฌ ์ํคํ
์ฒ๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
model = SimpleModel()
print(model)
์ด๋ ๋ค์๊ณผ ๊ฐ์ด ์ถ๋ ฅ๋ฉ๋๋ค:
SimpleModel(
(dense): Linear(in_features=10, out_features=10, bias=True)
(intermediate): Linear(in_features=10, out_features=10, bias=True)
(layer_norm): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
)
์ฐ๋ฆฌ๋ ๋ ์ด์ด์ ์ด๋ฆ์ด PyTorch์์ ํด๋์ค ์์ฑ์ ์ด๋ฆ์ผ๋ก ์ ์๋์ด ์๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ํน์ ๋ ์ด์ด์ ๊ฐ์ค์น ๊ฐ์ ์ถ๋ ฅํ์ฌ ํ์ธํ ์ ์์ต๋๋ค:
print(model.dense.weight.data)
๊ฐ์ค์น๊ฐ ๋ฌด์์๋ก ์ด๊ธฐํ๋์์์ ํ์ธํ ์ ์์ต๋๋ค.
tensor([[-0.0818, 0.2207, -0.0749, -0.0030, 0.0045, -0.1569, -0.1598, 0.0212,
-0.2077, 0.2157],
[ 0.1044, 0.0201, 0.0990, 0.2482, 0.3116, 0.2509, 0.2866, -0.2190,
0.2166, -0.0212],
[-0.2000, 0.1107, -0.1999, -0.3119, 0.1559, 0.0993, 0.1776, -0.1950,
-0.1023, -0.0447],
[-0.0888, -0.1092, 0.2281, 0.0336, 0.1817, -0.0115, 0.2096, 0.1415,
-0.1876, -0.2467],
[ 0.2208, -0.2352, -0.1426, -0.2636, -0.2889, -0.2061, -0.2849, -0.0465,
0.2577, 0.0402],
[ 0.1502, 0.2465, 0.2566, 0.0693, 0.2352, -0.0530, 0.1859, -0.0604,
0.2132, 0.1680],
[ 0.1733, -0.2407, -0.1721, 0.1484, 0.0358, -0.0633, -0.0721, -0.0090,
0.2707, -0.2509],
[-0.1173, 0.1561, 0.2945, 0.0595, -0.1996, 0.2988, -0.0802, 0.0407,
0.1829, -0.1568],
[-0.1164, -0.2228, -0.0403, 0.0428, 0.1339, 0.0047, 0.1967, 0.2923,
0.0333, -0.0536],
[-0.1492, -0.1616, 0.1057, 0.1950, -0.2807, -0.2710, -0.1586, 0.0739,
0.2220, 0.2358]]).
๋ณํ ์คํฌ๋ฆฝํธ์์๋ ์ด๋ฌํ ๋ฌด์์๋ก ์ด๊ธฐํ๋ ๊ฐ์ค์น๋ฅผ ์ฒดํฌํฌ์ธํธ์ ํด๋น ๋ ์ด์ด์ ์ ํํ ๊ฐ์ค์น๋ก ์ฑ์์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
# retrieve matching layer weights, e.g. by
# recursive algorithm
layer_name = "dense"
pretrained_weight = array_of_dense_layer
model_pointer = getattr(model, "dense")
model_pointer.weight.data = torch.from_numpy(pretrained_weight)
์ด๋ ๊ฒ ํ๋ฉด PyTorch ๋ชจ๋ธ์ ๋ฌด์์๋ก ์ด๊ธฐํ๋ ๊ฐ ๊ฐ์ค์น์ ํด๋น ์ฒดํฌํฌ์ธํธ ๊ฐ์ค์น๊ฐ ๋ชจ์๊ณผ ์ด๋ฆ ๋ชจ๋์์ ์ ํํ ์ผ์นํ๋์ง ํ์ธํด์ผ ํฉ๋๋ค. ์ด๋ฅผ ์ํด ๋ชจ์์ ๋ํ assert ๋ฌธ์ ์ถ๊ฐํ๊ณ ์ฒดํฌํฌ์ธํธ ๊ฐ์ค์น์ ์ด๋ฆ์ ์ถ๋ ฅํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด ๋ค์๊ณผ ๊ฐ์ ๋ฌธ์ฅ์ ์ถ๊ฐํด์ผ ํฉ๋๋ค:
assert (
model_pointer.weight.shape == pretrained_weight.shape
), f"Pointer shape of random weight {model_pointer.shape} and array shape of checkpoint weight {pretrained_weight.shape} mismatched"
๋ํ ๋ ๊ฐ์ค์น์ ์ด๋ฆ์ ์ถ๋ ฅํ์ฌ ์ผ์นํ๋์ง ํ์ธํด์ผ ํฉ๋๋ค. ์์:
logger.info(f"Initialize PyTorch weight {layer_name} from {pretrained_weight.name}")
๋ชจ์ ๋๋ ์ด๋ฆ์ด ์ผ์นํ์ง ์๋ ๊ฒฝ์ฐ, ๋๋ค์ผ๋ก ์ด๊ธฐํ๋ ๋ ์ด์ด์ ์๋ชป๋ ์ฒดํฌํฌ์ธํธ ๊ฐ์ค์น๋ฅผ ํ ๋นํ ๊ฒ์ผ๋ก ์ถ์ธก๋ฉ๋๋ค.
์๋ชป๋ ๋ชจ์์ BrandNewBertConfig()
์ ๊ตฌ์ฑ ๋งค๊ฐ๋ณ์ ์ค์ ์ด ๋ณํํ๋ ค๋ ์ฒดํฌํฌ์ธํธ์ ์ฌ์ฉ๋ ์ค์ ๊ณผ ์ ํํ ์ผ์นํ์ง ์๊ธฐ ๋๋ฌธ์ผ ๊ฐ๋ฅ์ฑ์ด ๊ฐ์ฅ ํฝ๋๋ค. ๊ทธ๋ฌ๋ PyTorch์ ๋ ์ด์ด ๊ตฌํ ์์ฒด์์ ๊ฐ์ค์น๋ฅผ ์ ์นํด์ผ ํ ์๋ ์์ต๋๋ค.
๋ง์ง๋ง์ผ๋ก, ๋ชจ๋ ํ์ํ ๊ฐ์ค์น๊ฐ ์ด๊ธฐํ๋์๋์ง ํ์ธํ๊ณ ์ด๊ธฐํ์ ์ฌ์ฉ๋์ง ์์ ๋ชจ๋ ์ฒดํฌํฌ์ธํธ ๊ฐ์ค์น๋ฅผ ์ถ๋ ฅํ์ฌ ๋ชจ๋ธ์ด ์ฌ๋ฐ๋ฅด๊ฒ ๋ณํ๋์๋์ง ํ์ธํด์ผ ํฉ๋๋ค. ์๋ชป๋ ๋ชจ์ ๋ฌธ์ฅ์ด๋ ์๋ชป๋ ์ด๋ฆ ํ ๋น์ผ๋ก ์ธํด ๋ณํ ์๋๊ฐ ์คํจํ๋ ๊ฒ์ ์์ ํ ์ ์์
๋๋ค. ์ด๋ BrandNewBertConfig()
์์ ์๋ชป๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ๊ฑฐ๋ ๐ค Transformers ๊ตฌํ์์ ์๋ชป๋ ์ํคํ
์ฒ, ๐ค Transformers ๊ตฌํ์ ๊ตฌ์ฑ ์์ ์ค ํ๋์ init()
ํจ์์ ๋ฒ๊ทธ๊ฐ ์๋ ๊ฒฝ์ฐ์ด๊ฑฐ๋ ์ฒดํฌํฌ์ธํธ ๊ฐ์ค์น ์ค ํ๋๋ฅผ ์ ์นํด์ผ ํ๋ ๊ฒฝ์ฐ์ผ ๊ฐ๋ฅ์ฑ์ด ๊ฐ์ฅ ๋์ต๋๋ค.
์ด ๋จ๊ณ๋ ์ด์ ๋จ๊ณ์ ํจ๊ป ๋ฐ๋ณต๋์ด์ผ ํ๋ฉฐ ๋ชจ๋ ์ฒดํฌํฌ์ธํธ์ ๊ฐ์ค์น๊ฐ Transformers ๋ชจ๋ธ์ ์ฌ๋ฐ๋ฅด๊ฒ ๋ก๋๋์์ ๋๊น์ง ๊ณ์๋์ด์ผ ํฉ๋๋ค. ๐ค Transformers ๊ตฌํ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ๋ก๋ํ ํ์๋ /path/to/converted/checkpoint/folder
์ ๊ฐ์ ์ํ๋ ํด๋์ ๋ชจ๋ธ์ ์ ์ฅํ ์ ์์ด์ผ ํฉ๋๋ค. ํด๋น ํด๋์๋ pytorch_model.bin
ํ์ผ๊ณผ config.json
ํ์ผ์ด ๋ชจ๋ ํฌํจ๋์ด์ผ ํฉ๋๋ค.
model.save_pretrained("/path/to/converted/checkpoint/folder")
7. ์๋ฐฉํฅ ํจ์ค ๊ตฌํํ๊ธฐ
๐ค Transformers ๊ตฌํ์ ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ์ ํํ๊ฒ ๋ก๋ํ ํ์๋ ์๋ฐฉํฅ ํจ์ค๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ๊ตฌํ๋์๋์ง ํ์ธํด์ผ ํฉ๋๋ค. ์๋ณธ ์ ์ฅ์์ ์ต์ํด์ง๊ธฐ์์ ์ด๋ฏธ ์๋ณธ ์ ์ฅ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์๋ฐฉํฅ ํจ์ค๋ฅผ ์คํํ๋ ์คํฌ๋ฆฝํธ๋ฅผ ๋ง๋ค์์ต๋๋ค. ์ด์ ์๋ณธ ๋์ ๐ค Transformers ๊ตฌํ์ ์ฌ์ฉํ๋ ์ ์ฌํ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ์์ฑ๋์ด์ผ ํฉ๋๋ค:
model = BrandNewBertModel.from_pretrained("/path/to/converted/checkpoint/folder")
input_ids = [0, 4, 4, 3, 2, 4, 1, 7, 19]
output = model(input_ids).last_hidden_states
๐ค Transformers ๊ตฌํ๊ณผ ์๋ณธ ๋ชจ๋ธ ๊ตฌํ์ด ์ฒ์๋ถํฐ ์ ํํ ๋์ผํ ์ถ๋ ฅ์ ์ ๊ณตํ์ง ์๊ฑฐ๋ ์๋ฐฉํฅ ํจ์ค์์ ์ค๋ฅ๊ฐ ๋ฐ์ํ ๊ฐ๋ฅ์ฑ์ด ๋งค์ฐ ๋์ต๋๋ค. ์ค๋งํ์ง ๋ง์ธ์. ์์๋ ์ผ์
๋๋ค! ๋จผ์ , ์๋ฐฉํฅ ํจ์ค์์ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ง ์๋๋ก ํด์ผ ํฉ๋๋ค. ์ข
์ข
์๋ชป๋ ์ฐจ์์ด ์ฌ์ฉ๋์ด ์ฐจ์ ๋ถ์ผ์น ์ค๋ฅ๊ฐ ๋ฐ์ํ๊ฑฐ๋ ์๋ชป๋ ๋ฐ์ดํฐ ์ ํ ๊ฐ์ฒด๊ฐ ์ฌ์ฉ๋๋ ๊ฒฝ์ฐ๊ฐ ์์ต๋๋ค. ์๋ฅผ ๋ค๋ฉด torch.long
๋์ ์ torch.float32
๊ฐ ์ฌ์ฉ๋ ๊ฒฝ์ฐ์
๋๋ค. ํด๊ฒฐํ ์ ์๋ ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ฉด Hugging Face ํ์ ๋์์ ์์ฒญํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๐ค Transformers ๊ตฌํ์ด ์ฌ๋ฐ๋ฅด๊ฒ ์๋ํ๋์ง ํ์ธํ๋ ๋ง์ง๋ง ๋จ๊ณ๋ ์ถ๋ ฅ์ด 1e-3
์ ์ ๋ฐ๋๋ก ๋์ผํ์ง ํ์ธํ๋ ๊ฒ์
๋๋ค. ๋จผ์ , ์ถ๋ ฅ ๋ชจ์์ด ๋์ผํ๋๋ก ๋ณด์ฅํด์ผ ํฉ๋๋ค. ์ฆ, ๐ค Transformers ๊ตฌํ ์คํฌ๋ฆฝํธ์ ์๋ณธ ๊ตฌํ ์ฌ์ด์์ outputs.shape
๋ ๋์ผํ ๊ฐ์ ๋ฐํํด์ผ ํฉ๋๋ค. ๊ทธ ๋ค์์ผ๋ก, ์ถ๋ ฅ ๊ฐ์ด ๋์ผํ๋๋ก ํด์ผ ํฉ๋๋ค. ์ด๋ ์๋ก์ด ๋ชจ๋ธ์ ์ถ๊ฐํ ๋ ๊ฐ์ฅ ์ด๋ ค์ด ๋ถ๋ถ ์ค ํ๋์
๋๋ค. ์ถ๋ ฅ์ด ๋์ผํ์ง ์์ ์ผ๋ฐ์ ์ธ ์ค์ ์ฌ๋ก๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ์ผ๋ถ ๋ ์ด์ด๊ฐ ์ถ๊ฐ๋์ง ์์์ต๋๋ค. ์ฆ, ํ์ฑํ ๋ ์ด์ด๊ฐ ์ถ๊ฐ๋์ง ์์๊ฑฐ๋ ์์ฐจ ์ฐ๊ฒฐ์ด ๋น ์ก์ต๋๋ค.
- ๋จ์ด ์๋ฒ ๋ฉ ํ๋ ฌ์ด ์ฐ๊ฒฐ๋์ง ์์์ต๋๋ค.
- ์๋ชป๋ ์์น ์๋ฒ ๋ฉ์ด ์ฌ์ฉ๋์์ต๋๋ค. ์๋ณธ ๊ตฌํ์์๋ ์คํ์ ์ ์ฌ์ฉํฉ๋๋ค.
- ์๋ฐฉํฅ ํจ์ค ์ค์ Dropout์ด ์ ์ฉ๋์์ต๋๋ค. ์ด๋ฅผ ์์ ํ๋ ค๋ฉด model.training์ด False์ธ์ง ํ์ธํ๊ณ ์๋ฐฉํฅ ํจ์ค ์ค์ Dropout ๋ ์ด์ด๊ฐ ์๋ชป ํ์ฑํ๋์ง ์๋๋ก ํ์ธ์. ์ฆ, PyTorch์ ๊ธฐ๋ฅ์ Dropout์ self.training์ ์ ๋ฌํ์ธ์.
๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ ์ผ๋ฐ์ ์ผ๋ก ์๋ณธ ๊ตฌํ๊ณผ ๐ค Transformers ๊ตฌํ์ ์๋ฐฉํฅ ํจ์ค๋ฅผ ๋๋ํ ๋๊ณ ์ฐจ์ด์ ์ด ์๋์ง ํ์ธํ๋ ๊ฒ์
๋๋ค. ์ด์์ ์ผ๋ก๋ ์๋ฐฉํฅ ํจ์ค์ ์ค๊ฐ ์ถ๋ ฅ์ ๋๋ฒ๊ทธ/์ถ๋ ฅํ์ฌ ์๋ณธ ๊ตฌํ๊ณผ ๐ค Transformers ๊ตฌํ์ ์ ํํ ์์น๋ฅผ ์ฐพ์ ์ ์์ด์ผ ํฉ๋๋ค. ๋จผ์ , ๋ ์คํฌ๋ฆฝํธ์ ํ๋์ฝ๋ฉ๋ input_ids
๊ฐ ๋์ผํ์ง ํ์ธํ์ธ์. ๋ค์์ผ๋ก, input_ids
์ ์ฒซ ๋ฒ์งธ ๋ณํ์ ์ถ๋ ฅ(์ผ๋ฐ์ ์ผ๋ก ๋จ์ด ์๋ฒ ๋ฉ)์ด ๋์ผํ์ง ํ์ธํ์ธ์. ๊ทธ๋ฐ ๋ค์ ๋คํธ์ํฌ์ ๊ฐ์ฅ ๋ง์ง๋ง ๋ ์ด์ด๊น์ง ์งํํด๋ณด์ธ์. ์ด๋ ์์ ์์ ๋ ๊ตฌํ ์ฌ์ด์ ์ฐจ์ด๊ฐ ์๋ ๊ฒ์ ์๊ฒ ๋๋๋ฐ, ์ด๋ ๐ค Transformers ๊ตฌํ์ ๋ฒ๊ทธ ์์น๋ฅผ ๊ฐ๋ฆฌํฌ ๊ฒ์
๋๋ค. ์ ํฌ ๊ฒฝํ์์ผ๋ก๋ ์๋ณธ ๊ตฌํ๊ณผ ๐ค Transformers ๊ตฌํ ๋ชจ๋์์ ๋์ผํ ์์น์ ๋ง์ ์ถ๋ ฅ ๋ฌธ์ ์ถ๊ฐํ๊ณ ์ด๋ค์ ์ค๊ฐ ํํ์ ๋ํด ๋์ผํ ๊ฐ์ ๋ณด์ด๋ ์ถ๋ ฅ ๋ฌธ์ ์ฐ์์ ์ผ๋ก ์ ๊ฑฐํ๋ ๊ฒ์ด ๊ฐ๋จํ๊ณ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์
๋๋ค.
torch.allclose(original_output, output, atol=1e-3)
๋ก ์ถ๋ ฅ์ ํ์ธํ์ฌ ๋ ๊ตฌํ์ด ๋์ผํ ์ถ๋ ฅ์ ํ๋ ๊ฒ์ ํ์ ํ๋ค๋ฉด, ๊ฐ์ฅ ์ด๋ ค์ด ๋ถ๋ถ์ ๋๋ฌ์ต๋๋ค! ์ถํ๋๋ฆฝ๋๋ค. ๋จ์ ์์
์ ์ฌ์ด ์ผ์ด ๋ ๊ฒ์
๋๋ค ๐.
8. ํ์ํ ๋ชจ๋ ๋ชจ๋ธ ํ ์คํธ ์ถ๊ฐํ๊ธฐ
์ด ์์ ์์ ์๋ก์ด ๋ชจ๋ธ์ ์ฑ๊ณต์ ์ผ๋ก ์ถ๊ฐํ์ต๋๋ค. ๊ทธ๋ฌ๋ ํด๋น ๋ชจ๋ธ์ด ์๊ตฌ๋๋ ๋์์ธ์ ์์ ํ ๋ถํฉํ์ง ์์ ์๋ ์์ต๋๋ค. ๐ค Transformers์ ์๋ฒฝํ๊ฒ ํธํ๋๋ ๊ตฌํ์ธ์ง ํ์ธํ๊ธฐ ์ํด ๋ชจ๋ ์ผ๋ฐ ํ
์คํธ๋ฅผ ํต๊ณผํด์ผ ํฉ๋๋ค. Cookiecutter๋ ์๋ง๋ ๋ชจ๋ธ์ ์ํ ํ
์คํธ ํ์ผ์ ์๋์ผ๋ก ์ถ๊ฐํ์ ๊ฒ์
๋๋ค. ์๋ง๋ tests/models/brand_new_bert/test_modeling_brand_new_bert.py
์ ๊ฐ์ ๊ฒฝ๋ก์ ์์นํ ๊ฒ์
๋๋ค. ์ด ํ
์คํธ ํ์ผ์ ์คํํ์ฌ ์ผ๋ฐ ํ
์คํธ๊ฐ ๋ชจ๋ ํต๊ณผํ๋์ง ํ์ธํ์ธ์.
pytest tests/models/brand_new_bert/test_modeling_brand_new_bert.py
๋ชจ๋ ์ผ๋ฐ ํ ์คํธ๋ฅผ ์์ ํ ํ, ์ด์ ์ํํ ์์ ์ ์ถฉ๋ถํ ํ ์คํธํ์ฌ ๋ค์ ์ฌํญ์ ๋ณด์ฅํด์ผ ํฉ๋๋ค.
- a) ์ปค๋ฎค๋ํฐ๊ฐ brand_new_bert์ ํน์ ํ ์คํธ๋ฅผ ์ดํด๋ด์ผ๋ก์จ ์์ ์ ์ฝ๊ฒ ์ดํดํ ์ ์๋๋ก ํจ
- b) ๋ชจ๋ธ์ ๋ํ ํฅํ ๋ณ๊ฒฝ ์ฌํญ์ด ๋ชจ๋ธ์ ์ค์ํ ๊ธฐ๋ฅ์ ์์์ํค์ง ์๋๋ก ํจ
๋จผ์ ํตํฉ ํ
์คํธ๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค. ์ด๋ฌํ ํตํฉ ํ
์คํธ๋ ์ด์ ์ ๋ชจ๋ธ์ ๐ค Transformers๋ก ๊ตฌํํ๊ธฐ ์ํด ์ฌ์ฉํ ๋๋ฒ๊น
์คํฌ๋ฆฝํธ์ ๋์ผํ ์์
์ ์ํํฉ๋๋ค. Cookiecutter์ ์ด๋ฏธ ์ด๋ฌํ ๋ชจ๋ธ ํ
์คํธ์ ํ
ํ๋ฆฟ์ธ BrandNewBertModelIntegrationTests
๊ฐ ์ถ๊ฐ๋์ด ์์ผ๋ฉฐ, ์ฌ๋ฌ๋ถ์ด ์์ฑํด์ผ ํ ๋ด์ฉ์ผ๋ก๋ง ์ฑ์ ๋ฃ์ผ๋ฉด ๋ฉ๋๋ค. ์ด๋ฌํ ํ
์คํธ๊ฐ ํต๊ณผํ๋์ง ํ์ธํ๋ ค๋ฉด ๋ค์์ ์คํํ์ธ์.
RUN_SLOW=1 pytest -sv tests/models/brand_new_bert/test_modeling_brand_new_bert.py::BrandNewBertModelIntegrationTests
Windows๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ RUN_SLOW=1
์ SET RUN_SLOW=1
๋ก ๋ฐ๊ฟ์ผ ํฉ๋๋ค.
๋์งธ๋ก, brand_new_bert์ ํนํ๋ ๋ชจ๋ ๊ธฐ๋ฅ๋ ๋ณ๋์ ํ ์คํธ์์ ์ถ๊ฐ๋ก ํ ์คํธํด์ผ ํฉ๋๋ค. ์ด ๋ถ๋ถ์ ์ข ์ข ์ํ๋๋ฐ, ๋ ๊ฐ์ง ์ธก๋ฉด์์ ๊ต์ฅํ ์ ์ฉํฉ๋๋ค.
- brand_new_bert์ ํน์ ๊ธฐ๋ฅ์ด ์ด๋ป๊ฒ ์๋ํด์ผ ํ๋์ง ๋ณด์ฌ์ค์ผ๋ก์จ ์ปค๋ฎค๋ํฐ์๊ฒ ๋ชจ๋ธ ์ถ๊ฐ ๊ณผ์ ์์ ์ต๋ํ ์ง์์ ์ ๋ฌํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค.
- ํฅํ ๊ธฐ์ฌ์๋ ์ด๋ฌํ ํน์ ํ ์คํธ๋ฅผ ์คํํ์ฌ ๋ชจ๋ธ์ ๋ํ ๋ณ๊ฒฝ ์ฌํญ์ ๋น ๋ฅด๊ฒ ํ ์คํธํ ์ ์์ต๋๋ค.
9. ํ ํฌ๋์ด์ ๊ตฌํํ๊ธฐ
๋ค์์ผ๋ก, brand_new_bert์ ํ ํฌ๋์ด์ ๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค. ๋ณดํต ํ ํฌ๋์ด์ ๋ ๐ค Transformers์ ๊ธฐ์กด ํ ํฌ๋์ด์ ์ ๋์ผํ๊ฑฐ๋ ๋งค์ฐ ์ ์ฌํฉ๋๋ค.
ํ ํฌ๋์ด์ ๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ์๋ํ๋์ง ํ์ธํ๊ธฐ ์ํด ๋จผ์ ์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ๋ฌธ์์ด์ ์
๋ ฅํ๊ณ input_ids
๋ฅผ ๋ฐํํ๋ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๋ค์๊ณผ ๊ฐ์ ์ ์ฌํ ์คํฌ๋ฆฝํธ์ผ ์ ์์ต๋๋ค (์์ฌ ์ฝ๋๋ก ์์ฑ):
input_str = "This is a long example input string containing special characters .$?-, numbers 2872 234 12 and words."
model = BrandNewBertModel.load_pretrained_checkpoint("/path/to/checkpoint/")
input_ids = model.tokenize(input_str)
์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ ์์ธํ ์ดํด๋ณด๊ณ ์ฌ๋ฐ๋ฅธ ํ ํฌ๋์ด์ ํจ์๋ฅผ ์ฐพ๊ฑฐ๋, ๋ณต์ ๋ณธ์์ ๋ณ๊ฒฝ ์ฌํญ์ ์ ์ฉํ์ฌ input_ids
๋ง ์ถ๋ ฅํ๋๋ก ํด์ผ ํฉ๋๋ค. ์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ ๊ธฐ๋ฅ์ ์ธ ํ ํฐํ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ ํ, ๐ค Transformers์ ์ ์ฌํ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ์์ฑ๋์ด์ผ ํฉ๋๋ค:
from transformers import BrandNewBertTokenizer
input_str = "This is a long example input string containing special characters .$?-, numbers 2872 234 12 and words."
tokenizer = BrandNewBertTokenizer.from_pretrained("/path/to/tokenizer/folder/")
input_ids = tokenizer(input_str).input_ids
๋ ๊ฐ์ input_ids
๊ฐ ๋์ผํ ๊ฐ์ ๋ฐํํ ๋, ๋ง์ง๋ง ๋จ๊ณ๋ก ํ ํฌ๋์ด์ ํ
์คํธ ํ์ผ๋ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
brand_new_bert์ ๋ชจ๋ธ๋ง ํ ์คํธ ํ์ผ๊ณผ ์ ์ฌํ๊ฒ, brand_new_bert์ ํ ํฌ๋์ด์ ์ด์ ํ ์คํธ ํ์ผ์๋ ๋ช ๊ฐ์ง ํ๋์ฝ๋ฉ๋ ํตํฉ ํ ์คํธ๊ฐ ํฌํจ๋์ด์ผ ํฉ๋๋ค.
10. ์ข ๋จ ๊ฐ ํตํฉ ํ ์คํธ ์คํ
ํ ํฌ๋์ด์ ๋ฅผ ์ถ๊ฐํ ํ์๋ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ์ฌ์ฉํ์ฌ ๋ช ๊ฐ์ง ์ข
๋จ ๊ฐ ํตํฉ ํ
์คํธ๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค. tests/models/brand_new_bert/test_modeling_brand_new_bert.py
์ ์ถ๊ฐํด์ฃผ์ธ์. ์ด๋ฌํ ํ
์คํธ๋ ๐ค Transformers ๊ตฌํ์ด ์์๋๋ก ์๋ํ๋์ง๋ฅผ ์๋ฏธ ์๋ text-to-text ์์๋ก ๋ณด์ฌ์ค์ผ ํฉ๋๋ค. ๊ทธ ์์๋ก๋ ์๋ฅผ ๋ค์ด source-to-target ๋ฒ์ญ ์, article-to-summary ์, question-to-answer ์ ๋ฑ์ด ํฌํจ๋ ์ ์์ต๋๋ค. ๋ถ๋ฌ์จ ์ฒดํฌํฌ์ธํธ ์ค ์ด๋ ๊ฒ๋ ๋ค์ด์คํธ๋ฆผ ์์
์์ ๋ฏธ์ธ ์กฐ์ ๋์ง ์์๋ค๋ฉด, ๋ชจ๋ธ ํ
์คํธ๋ง์ผ๋ก ์ถฉ๋ถํฉ๋๋ค. ๋ชจ๋ธ์ด ์์ ํ ๊ธฐ๋ฅ์ ๊ฐ์ถ์๋์ง ํ์ธํ๊ธฐ ์ํด ๋ง์ง๋ง ๋จ๊ณ๋ก GPU์์ ๋ชจ๋ ํ
์คํธ๋ฅผ ์คํํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ๋ชจ๋ธ์ ๋ด๋ถ ํ
์์ ์ผ๋ถ์ .to(self.device)
๋ฌธ์ ์ถ๊ฐํ๋ ๊ฒ์ ์์์ ์ ์์ผ๋ฉฐ, ์ด ๊ฒฝ์ฐ ํ
์คํธ์์ ์ค๋ฅ๋ก ํ์๋ฉ๋๋ค. GPU์ ์ก์ธ์คํ ์ ์๋ ๊ฒฝ์ฐ, Hugging Face ํ์ด ํ
์คํธ๋ฅผ ๋์ ์คํํ ์ ์์ต๋๋ค.
11. ๊ธฐ์ ๋ฌธ์ ์ถ๊ฐ
์ด์ brand_new_bert์ ํ์ํ ๋ชจ๋ ๊ธฐ๋ฅ์ด ์ถ๊ฐ๋์์ต๋๋ค. ๊ฑฐ์ ๋๋ฌ์ต๋๋ค! ์ถ๊ฐํด์ผ ํ ๊ฒ์ ๋ฉ์ง ๊ธฐ์ ๋ฌธ์๊ณผ ๊ธฐ์ ๋ฌธ์ ํ์ด์ง์
๋๋ค. Cookiecutter๊ฐ docs/source/model_doc/brand_new_bert.md
๋ผ๋ ํ
ํ๋ฆฟ ํ์ผ์ ์ถ๊ฐํด์คฌ์ ๊ฒ์
๋๋ค. ์ด ํ์ด์ง๋ฅผ ์ฌ์ฉํ๊ธฐ ์ ์ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ์ฌ์ฉ์๋ค์ ์ผ๋ฐ์ ์ผ๋ก ์ด ํ์ด์ง๋ฅผ ๋จผ์ ํ์ธํฉ๋๋ค. ๋ฐ๋ผ์ ๋ฌธ์๋ ์ดํดํ๊ธฐ ์ฝ๊ณ ๊ฐ๊ฒฐํด์ผ ํฉ๋๋ค. ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ฃผ๊ธฐ ์ํด ํ์ ์ถ๊ฐํ๋ ๊ฒ์ด ์ปค๋ฎค๋ํฐ์ ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ๋
์คํธ๋ง์ ๊ด๋ จํ์ฌ Hugging Face ํ์ ๋ฌธ์ํ๋ ๊ฒ์ ์ฃผ์ ํ์ง ๋ง์ธ์.
๋ค์์ผ๋ก, src/transformers/models/brand_new_bert/modeling_brand_new_bert.py
์ ์ถ๊ฐ๋ ๋
์คํธ๋ง์ด ์ฌ๋ฐ๋ฅด๋ฉฐ ํ์ํ ๋ชจ๋ ์
๋ ฅ ๋ฐ ์ถ๋ ฅ์ ํฌํจํ๋๋ก ํ์ธํ์ธ์. ์ฌ๊ธฐ์์ ์ฐ๋ฆฌ์ ๋ฌธ์ ์์ฑ ๊ฐ์ด๋์ ๋
์คํธ๋ง ํ์์ ๋ํ ์์ธ ๊ฐ์ด๋๊ฐ ์์ต๋๋ค. ๋ฌธ์๋ ์ผ๋ฐ์ ์ผ๋ก ์ปค๋ฎค๋ํฐ์ ๋ชจ๋ธ์ ์ฒซ ๋ฒ์งธ ์ ์ ์ด๊ธฐ ๋๋ฌธ์, ๋ฌธ์๋ ์ ์ด๋ ์ฝ๋๋งํผ์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ฌ์ผ ํฉ๋๋ค.
์ฝ๋ ๋ฆฌํฉํ ๋ง
์ข์์, ์ด์ brand_new_bert๋ฅผ ์ํ ๋ชจ๋ ํ์ํ ์ฝ๋๋ฅผ ์ถ๊ฐํ์ต๋๋ค. ์ด ์์ ์์ ๋ค์์ ์คํํ์ฌ ์ ์ฌ์ ์ผ๋ก ์๋ชป๋ ์ฝ๋ ์คํ์ผ์ ์์ ํด์ผ ํฉ๋๋ค:
๊ทธ๋ฆฌ๊ณ ์ฝ๋ฉ ์คํ์ผ์ด ํ์ง ์ ๊ฒ์ ํต๊ณผํ๋์ง ํ์ธํ๊ธฐ ์ํด ๋ค์์ ์คํํ๊ณ ํ์ธํด์ผ ํฉ๋๋ค:
make style
๐ค Transformers์๋ ์ฌ์ ํ ์คํจํ ์ ์๋ ๋ช ๊ฐ์ง ๋งค์ฐ ์๊ฒฉํ ๋์์ธ ํ ์คํธ๊ฐ ์์ต๋๋ค. ์ด๋ ๋ ์คํธ๋ง์ ๋๋ฝ๋ ์ ๋ณด๋ ์๋ชป๋ ๋ช ๋ช ๋๋ฌธ์ ์ข ์ข ๋ฐ์ํฉ๋๋ค. ์ฌ๊ธฐ์ ๋งํ๋ฉด Hugging Face ํ์ด ๋์์ ์ค ๊ฒ์ ๋๋ค.
make quality
๋ง์ง๋ง์ผ๋ก, ์ฝ๋๊ฐ ์ ํํ ์๋ํ๋ ๊ฒ์ ํ์ธํ ํ์๋ ํญ์ ์ฝ๋๋ฅผ ๋ฆฌํฉํ ๋งํ๋ ๊ฒ์ด ์ข์ ์๊ฐ์ ๋๋ค. ๋ชจ๋ ํ ์คํธ๊ฐ ํต๊ณผ๋ ์ง๊ธ์ ์ถ๊ฐํ ์ฝ๋๋ฅผ ๋ค์ ๊ฒํ ํ๊ณ ๋ฆฌํฉํ ๋งํ๋ ์ข์ ์๊ธฐ์ ๋๋ค.
์ด์ ์ฝ๋ฉ ๋ถ๋ถ์ ์๋ฃํ์ต๋๋ค. ์ถํํฉ๋๋ค! ๐ ๋ฉ์ ธ์! ๐
12. ๋ชจ๋ธ์ ๋ชจ๋ธ ํ๋ธ์ ์ ๋ก๋ํ์ธ์
์ด ๋ง์ง๋ง ํํธ์์๋ ๋ชจ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณํํ์ฌ ๋ชจ๋ธ ํ๋ธ์ ์
๋ก๋ํ๊ณ ๊ฐ ์
๋ก๋๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ์ ๋ํ ๋ชจ๋ธ ์นด๋๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค. Model sharing and uploading Page๋ฅผ ์ฝ๊ณ ํ๋ธ ๊ธฐ๋ฅ์ ์ต์ํด์ง์ธ์. brand_new_bert์ ์ ์ ์กฐ์ง ์๋์ ๋ชจ๋ธ์ ์
๋ก๋ํ ์ ์๋ ํ์ํ ์ก์ธ์ค ๊ถํ์ ์ป๊ธฐ ์ํด Hugging Face ํ๊ณผ ํ์
ํด์ผ ํฉ๋๋ค. transformers
์ ๋ชจ๋ ๋ชจ๋ธ์ ์๋ push_to_hub
๋ฉ์๋๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ํ๋ธ์ ๋น ๋ฅด๊ณ ํจ์จ์ ์ผ๋ก ์
๋ก๋ํ๋ ๋ฐฉ๋ฒ์
๋๋ค. ์๋์ ์์ ์ฝ๋ ์กฐ๊ฐ์ด ๋ถ์ฌ์ ธ ์์ต๋๋ค:
๊ฐ ์ฒดํฌํฌ์ธํธ์ ์ ํฉํ ๋ชจ๋ธ ์นด๋๋ฅผ ๋ง๋๋ ๋ฐ ์๊ฐ์ ํ ์ ํ๋ ๊ฒ์ ๊ฐ์น๊ฐ ์์ต๋๋ค. ๋ชจ๋ธ ์นด๋๋ ์ฒดํฌํฌ์ธํธ์ ํน์ฑ์ ๊ฐ์กฐํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด ์ด ์ฒดํฌํฌ์ธํธ๋ ์ด๋ค ๋ฐ์ดํฐ์ ์์ ์ฌ์ ํ๋ จ/์ธ๋ถ ํ๋ จ๋์๋์ง? ์ด ๋ชจ๋ธ์ ์ด๋ค ํ์ ์์ ์์ ์ฌ์ฉํด์ผ ํ๋์ง? ๊ทธ๋ฆฌ๊ณ ๋ชจ๋ธ์ ์ฌ๋ฐ๋ฅด๊ฒ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ช ๊ฐ์ง ์ฝ๋๋ ํฌํจํด์ผ ํฉ๋๋ค.
brand_new_bert.push_to_hub("brand_new_bert")
# Uncomment the following line to push to an organization.
# brand_new_bert.push_to_hub("<organization>/brand_new_bert")
13. (์ ํ ์ฌํญ) ๋ ธํธ๋ถ ์ถ๊ฐ
brand_new_bert๋ฅผ ๋ค์ด์คํธ๋ฆผ ์์ ์์ ์ถ๋ก ๋๋ ๋ฏธ์ธ ์กฐ์ ์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์์ธํ ๋ณด์ฌ์ฃผ๋ ๋ ธํธ๋ถ์ ์ถ๊ฐํ๋ ๊ฒ์ด ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ์ด๊ฒ์ PR์ ๋ณํฉํ๋ ๋ฐ ํ์์ ์ด์ง๋ ์์ง๋ง ์ปค๋ฎค๋ํฐ์ ๋งค์ฐ ์ ์ฉํฉ๋๋ค.
14. ์๋ฃ๋ PR ์ ์ถ
์ด์ ํ๋ก๊ทธ๋๋ฐ์ ๋ง์ณค์ผ๋ฉฐ, ๋ง์ง๋ง ๋จ๊ณ๋ก PR์ ๋ฉ์ธ ๋ธ๋์น์ ๋ณํฉํด์ผ ํฉ๋๋ค. ๋ณดํต Hugging Face ํ์ ์ด๋ฏธ ์ฌ๊ธฐ๊น์ง ๋์์ ์ฃผ์์ ๊ฒ์ ๋๋ค. ๊ทธ๋ฌ๋ PR์ ๋ฉ์ง ์ค๋ช ์ ์ถ๊ฐํ๊ณ ๋ฆฌ๋ทฐ์ด์๊ฒ ํน์ ๋์์ธ ์ ํ ์ฌํญ์ ๊ฐ์กฐํ๋ ค๋ฉด ์๋ฃ๋ PR์ ์ฝ๊ฐ์ ์ค๋ช ์ ์ถ๊ฐํ๋ ์๊ฐ์ ํ ์ ํ๋ ๊ฒ์ด ๊ฐ์น๊ฐ ์์ต๋๋ค.
์์ ๋ฌผ์ ๊ณต์ ํ์ธ์!! [[share-your-work]]
์ด์ ์ปค๋ฎค๋ํฐ์์ ์์ ๋ฌผ์ ์ธ์ ๋ฐ์ ์๊ฐ์ ๋๋ค! ๋ชจ๋ธ ์ถ๊ฐ ์์ ์ ์๋ฃํ๋ ๊ฒ์ Transformers์ ์ ์ฒด NLP ์ปค๋ฎค๋ํฐ์ ํฐ ๊ธฐ์ฌ์ ๋๋ค. ๋น์ ์ ์ฝ๋์ ์ด์๋ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์๋ฐฑ, ์ฌ์ง์ด ์์ฒ ๋ช ์ ๊ฐ๋ฐ์์ ์ฐ๊ตฌ์์ ์ํด ํ์คํ ์ฌ์ฉ๋ ๊ฒ์ ๋๋ค. ๋น์ ์ ์์ ์ ์๋์ค๋ฌ์ํด์ผ ํ๋ฉฐ ์ด๋ฅผ ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํด์ผ ํฉ๋๋ค.
๋น์ ์ ์ปค๋ฎค๋ํฐ ๋ด ๋ชจ๋ ์ฌ๋๋ค์๊ฒ ๋งค์ฐ ์ฝ๊ฒ ์ ๊ทผ ๊ฐ๋ฅํ ๋ ๋ค๋ฅธ ๋ชจ๋ธ์ ๋ง๋ค์์ต๋๋ค! ๐คฏ