ahassoun's picture
Upload 3018 files
ee6e328

A newer version of the Gradio SDK is available: 5.23.3

Upgrade

๋‹ค๊ตญ์–ด ๋ชจ๋ธ ์ถ”๋ก ํ•˜๊ธฐ[[multilingual-models-for-inference]]

[[open-in-colab]]

๐Ÿค— Transformers์—๋Š” ์—ฌ๋Ÿฌ ์ข…๋ฅ˜์˜ ๋‹ค๊ตญ์–ด(multilingual) ๋ชจ๋ธ์ด ์žˆ์œผ๋ฉฐ, ๋‹จ์ผ ์–ธ์–ด(monolingual) ๋ชจ๋ธ๊ณผ ์ถ”๋ก  ์‹œ ์‚ฌ์šฉ๋ฒ•์ด ๋‹ค๋ฆ…๋‹ˆ๋‹ค. ๊ทธ๋ ‡๋‹ค๊ณ  ํ•ด์„œ ๋ชจ๋“  ๋‹ค๊ตญ์–ด ๋ชจ๋ธ์˜ ์‚ฌ์šฉ๋ฒ•์ด ๋‹ค๋ฅธ ๊ฒƒ์€ ์•„๋‹™๋‹ˆ๋‹ค.

bert-base-multilingual-uncased์™€ ๊ฐ™์€ ๋ช‡๋ช‡ ๋ชจ๋ธ์€ ๋‹จ์ผ ์–ธ์–ด ๋ชจ๋ธ์ฒ˜๋Ÿผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฒˆ ๊ฐ€์ด๋“œ์—์„œ ๋‹ค๊ตญ์–ด ๋ชจ๋ธ์˜ ์ถ”๋ก  ์‹œ ์‚ฌ์šฉ ๋ฐฉ๋ฒ•์„ ์•Œ์•„๋ณผ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

XLM[[xlm]]

XLM์—๋Š” 10๊ฐ€์ง€ ์ฒดํฌํฌ์ธํŠธ(checkpoint)๊ฐ€ ์žˆ๋Š”๋ฐ, ์ด ์ค‘ ํ•˜๋‚˜๋งŒ ๋‹จ์ผ ์–ธ์–ด์ž…๋‹ˆ๋‹ค. ๋‚˜๋จธ์ง€ ์ฒดํฌํฌ์ธํŠธ 9๊ฐœ๋Š” ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ์‚ฌ์šฉํ•˜๋Š” ์ฒดํฌํฌ์ธํŠธ์™€ ๊ทธ๋ ‡์ง€ ์•Š์€ ์ฒดํฌํฌ์ธํŠธ์˜ ๋‘ ๊ฐ€์ง€ ๋ฒ”์ฃผ๋กœ ๋‚˜๋ˆŒ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ์‚ฌ์šฉํ•˜๋Š” XLM[[xlm-with-language-embeddings]]

๋‹ค์Œ XLM ๋ชจ๋ธ์€ ์ถ”๋ก  ์‹œ์— ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค:

  • xlm-mlm-ende-1024 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, ์˜์–ด-๋…์ผ์–ด)
  • xlm-mlm-enfr-1024 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, ์˜์–ด-ํ”„๋ž‘์Šค์–ด)
  • xlm-mlm-enro-1024 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, ์˜์–ด-๋ฃจ๋งˆ๋‹ˆ์•„์–ด)
  • xlm-mlm-xnli15-1024 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, XNLI ๋ฐ์ดํ„ฐ ์„ธํŠธ์—์„œ ์ œ๊ณตํ•˜๋Š” 15๊ฐœ ๊ตญ์–ด)
  • xlm-mlm-tlm-xnli15-1024 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง + ๋ฒˆ์—ญ, XNLI ๋ฐ์ดํ„ฐ ์„ธํŠธ์—์„œ ์ œ๊ณตํ•˜๋Š” 15๊ฐœ ๊ตญ์–ด)
  • xlm-clm-enfr-1024 (Causal language modeling, ์˜์–ด-ํ”„๋ž‘์Šค์–ด)
  • xlm-clm-ende-1024 (Causal language modeling, ์˜์–ด-๋…์ผ์–ด)

์–ธ์–ด ์ž„๋ฒ ๋”ฉ์€ ๋ชจ๋ธ์— ์ „๋‹ฌ๋œ input_ids์™€ ๋™์ผํ•œ shape์˜ ํ…์„œ๋กœ ํ‘œํ˜„๋ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ํ…์„œ์˜ ๊ฐ’์€ ์‚ฌ์šฉ๋œ ์–ธ์–ด์— ๋”ฐ๋ผ ๋‹ค๋ฅด๋ฉฐ ํ† ํฌ๋‚˜์ด์ €์˜ lang2id ๋ฐ id2lang ์†์„ฑ์— ์˜ํ•ด ์‹๋ณ„๋ฉ๋‹ˆ๋‹ค.

๋‹ค์Œ ์˜ˆ์ œ์—์„œ๋Š” xlm-clm-enfr-1024 ์ฒดํฌํฌ์ธํŠธ(์ฝ”์ž˜ ์–ธ์–ด ๋ชจ๋ธ๋ง(causal language modeling), ์˜์–ด-ํ”„๋ž‘์Šค์–ด)๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค:

>>> import torch
>>> from transformers import XLMTokenizer, XLMWithLMHeadModel

>>> tokenizer = XLMTokenizer.from_pretrained("xlm-clm-enfr-1024")
>>> model = XLMWithLMHeadModel.from_pretrained("xlm-clm-enfr-1024")

ํ† ํฌ๋‚˜์ด์ €์˜ lang2id ์†์„ฑ์€ ๋ชจ๋ธ์˜ ์–ธ์–ด์™€ ํ•ด๋‹น ID๋ฅผ ํ‘œ์‹œํ•ฉ๋‹ˆ๋‹ค:

>>> print(tokenizer.lang2id)
{'en': 0, 'fr': 1}

๋‹ค์Œ์œผ๋กœ, ์˜ˆ์ œ ์ž…๋ ฅ์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค:

>>> input_ids = torch.tensor([tokenizer.encode("Wikipedia was used to")])  # ๋ฐฐ์น˜ ํฌ๊ธฐ๋Š” 1์ž…๋‹ˆ๋‹ค

์–ธ์–ด ID๋ฅผ "en"์œผ๋กœ ์„ค์ •ํ•ด ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์€ ์˜์–ด์˜ ์–ธ์–ด ID์ธ 0์œผ๋กœ ์ฑ„์›Œ์ง„ ํ…์„œ์ž…๋‹ˆ๋‹ค. ์ด ํ…์„œ๋Š” input_ids์™€ ๊ฐ™์€ ํฌ๊ธฐ์—ฌ์•ผ ํ•ฉ๋‹ˆ๋‹ค.

>>> language_id = tokenizer.lang2id["en"]  # 0
>>> langs = torch.tensor([language_id] * input_ids.shape[1])  # torch.tensor([0, 0, 0, ..., 0])

>>> # (batch_size, sequence_length) shape์˜ ํ…์„œ๊ฐ€ ๋˜๋„๋ก ๋งŒ๋“ญ๋‹ˆ๋‹ค.
>>> langs = langs.view(1, -1)  # ์ด์ œ [1, sequence_length] shape์ด ๋˜์—ˆ์Šต๋‹ˆ๋‹ค(๋ฐฐ์น˜ ํฌ๊ธฐ๋Š” 1์ž…๋‹ˆ๋‹ค)

์ด์ œ input_ids์™€ ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ๋ชจ๋ธ๋กœ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค:

>>> outputs = model(input_ids, langs=langs)

run_generation.py ์Šคํฌ๋ฆฝํŠธ๋กœ xlm-clm ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์‚ฌ์šฉํ•ด ํ…์ŠคํŠธ์™€ ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์–ธ์–ด ์ž„๋ฒ ๋”ฉ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š” XLM[[xlm-without-language-embeddings]]

๋‹ค์Œ XLM ๋ชจ๋ธ์€ ์ถ”๋ก  ์‹œ์— ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์ด ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค:

  • xlm-mlm-17-1280 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, 17๊ฐœ ๊ตญ์–ด)
  • xlm-mlm-100-1280 (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, 100๊ฐœ ๊ตญ์–ด)

์ด์ „์˜ XLM ์ฒดํฌํฌ์ธํŠธ์™€ ๋‹ฌ๋ฆฌ ์ด ๋ชจ๋ธ์€ ์ผ๋ฐ˜ ๋ฌธ์žฅ ํ‘œํ˜„์— ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.

BERT[[bert]]

๋‹ค์Œ BERT ๋ชจ๋ธ์€ ๋‹ค๊ตญ์–ด ํƒœ์Šคํฌ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

  • bert-base-multilingual-uncased (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง + ๋‹ค์Œ ๋ฌธ์žฅ ์˜ˆ์ธก, 102๊ฐœ ๊ตญ์–ด)
  • bert-base-multilingual-cased (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง + ๋‹ค์Œ ๋ฌธ์žฅ ์˜ˆ์ธก, 104๊ฐœ ๊ตญ์–ด)

์ด๋Ÿฌํ•œ ๋ชจ๋ธ์€ ์ถ”๋ก  ์‹œ์— ์–ธ์–ด ์ž„๋ฒ ๋”ฉ์ด ํ•„์š”ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋ฌธ๋งฅ์—์„œ ์–ธ์–ด๋ฅผ ์‹๋ณ„ํ•˜๊ณ , ์‹๋ณ„๋œ ์–ธ์–ด๋กœ ์ถ”๋ก ํ•ฉ๋‹ˆ๋‹ค.

XLM-RoBERTa[[xlmroberta]]

๋‹ค์Œ XLM-RoBERTa ๋˜ํ•œ ๋‹ค๊ตญ์–ด ๋‹ค๊ตญ์–ด ํƒœ์Šคํฌ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

  • xlm-roberta-base (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, 100๊ฐœ ๊ตญ์–ด)
  • xlm-roberta-large (๋งˆ์Šคํ‚น๋œ ์–ธ์–ด ๋ชจ๋ธ๋ง, 100๊ฐœ ๊ตญ์–ด)

XLM-RoBERTa๋Š” 100๊ฐœ ๊ตญ์–ด์— ๋Œ€ํ•ด ์ƒˆ๋กœ ์ƒ์„ฑ๋˜๊ณ  ์ •์ œ๋œ 2.5TB ๊ทœ๋ชจ์˜ CommonCrawl ๋ฐ์ดํ„ฐ๋กœ ํ•™์Šต๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด์ „์— ๊ณต๊ฐœ๋œ mBERT๋‚˜ XLM๊ณผ ๊ฐ™์€ ๋‹ค๊ตญ์–ด ๋ชจ๋ธ์— ๋น„ํ•ด ๋ถ„๋ฅ˜, ์‹œํ€€์Šค ๋ผ๋ฒจ๋ง, ์งˆ์˜ ์‘๋‹ต๊ณผ ๊ฐ™์€ ๋‹ค์šด์ŠคํŠธ๋ฆผ(downstream) ์ž‘์—…์—์„œ ์ด์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

M2M100[[m2m100]]

๋‹ค์Œ M2M100 ๋ชจ๋ธ ๋˜ํ•œ ๋‹ค๊ตญ์–ด ๋‹ค๊ตญ์–ด ํƒœ์Šคํฌ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

  • facebook/m2m100_418M (๋ฒˆ์—ญ)
  • facebook/m2m100_1.2B (๋ฒˆ์—ญ)

์ด ์˜ˆ์ œ์—์„œ๋Š” facebook/m2m100_418M ์ฒดํฌํฌ์ธํŠธ๋ฅผ ๊ฐ€์ ธ์™€์„œ ์ค‘๊ตญ์–ด๋ฅผ ์˜์–ด๋กœ ๋ฒˆ์—ญํ•ฉ๋‹ˆ๋‹ค. ํ† ํฌ๋‚˜์ด์ €์—์„œ ๋ฒˆ์—ญ ๋Œ€์ƒ ์–ธ์–ด(source language)๋ฅผ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

>>> en_text = "Do not meddle in the affairs of wizards, for they are subtle and quick to anger."
>>> chinese_text = "ไธ่ฆๆ’ๆ‰‹ๅทซๅธซ็š„ไบ‹ๅ‹™, ๅ› ็‚บไป–ๅ€‘ๆ˜ฏๅพฎๅฆ™็š„, ๅพˆๅฟซๅฐฑๆœƒ็™ผๆ€’."

>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="zh")
>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")

๋ฌธ์žฅ์„ ํ† ํฐํ™”ํ•ฉ๋‹ˆ๋‹ค:

>>> encoded_zh = tokenizer(chinese_text, return_tensors="pt")

M2M100์€ ๋ฒˆ์—ญ์„ ์ง„ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ์ฒซ ๋ฒˆ์งธ๋กœ ์ƒ์„ฑ๋˜๋Š” ํ† ํฐ์€ ๋ฒˆ์—ญํ•  ์–ธ์–ด(target language) ID๋กœ ๊ฐ•์ œ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค. ์˜์–ด๋กœ ๋ฒˆ์—ญํ•˜๊ธฐ ์œ„ํ•ด generate ๋ฉ”์†Œ๋“œ์—์„œ forced_bos_token_id๋ฅผ en์œผ๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค:

>>> generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en"))
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
'Do not interfere with the matters of the witches, because they are delicate and will soon be angry.'

MBart[[mbart]]

๋‹ค์Œ MBart ๋ชจ๋ธ ๋˜ํ•œ ๋‹ค๊ตญ์–ด ํƒœ์Šคํฌ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

  • facebook/mbart-large-50-one-to-many-mmt (์ผ๋Œ€๋‹ค ๋‹ค๊ตญ์–ด ๋ฒˆ์—ญ, 50๊ฐœ ๊ตญ์–ด)
  • facebook/mbart-large-50-many-to-many-mmt (๋‹ค๋Œ€๋‹ค ๋‹ค๊ตญ์–ด ๋ฒˆ์—ญ, 50๊ฐœ ๊ตญ์–ด)
  • facebook/mbart-large-50-many-to-one-mmt (๋‹ค๋Œ€์ผ ๋‹ค๊ตญ์–ด ๋ฒˆ์—ญ, 50๊ฐœ ๊ตญ์–ด)
  • facebook/mbart-large-50 (๋‹ค๊ตญ์–ด ๋ฒˆ์—ญ, 50๊ฐœ ๊ตญ์–ด)
  • facebook/mbart-large-cc25

์ด ์˜ˆ์ œ์—์„œ๋Š” ํ•€๋ž€๋“œ์–ด๋ฅผ ์˜์–ด๋กœ ๋ฒˆ์—ญํ•˜๊ธฐ ์œ„ํ•ด facebook/mbart-large-50-many-to-many-mmt ์ฒดํฌํฌ์ธํŠธ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค. ํ† ํฌ๋‚˜์ด์ €์—์„œ ๋ฒˆ์—ญ ๋Œ€์ƒ ์–ธ์–ด(source language)๋ฅผ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

>>> en_text = "Do not meddle in the affairs of wizards, for they are subtle and quick to anger."
>>> fi_text = "ร„lรค sekaannu velhojen asioihin, sillรค ne ovat hienovaraisia ja nopeasti vihaisia."

>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", src_lang="fi_FI")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

๋ฌธ์žฅ์„ ํ† ํฐํ™”ํ•ฉ๋‹ˆ๋‹ค:

>>> encoded_en = tokenizer(en_text, return_tensors="pt")

MBart๋Š” ๋ฒˆ์—ญ์„ ์ง„ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ์ฒซ ๋ฒˆ์งธ๋กœ ์ƒ์„ฑ๋˜๋Š” ํ† ํฐ์€ ๋ฒˆ์—ญํ•  ์–ธ์–ด(target language) ID๋กœ ๊ฐ•์ œ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค. ์˜์–ด๋กœ ๋ฒˆ์—ญํ•˜๊ธฐ ์œ„ํ•ด generate ๋ฉ”์†Œ๋“œ์—์„œ forced_bos_token_id๋ฅผ en์œผ๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค:

>>> generated_tokens = model.generate(**encoded_en, forced_bos_token_id=tokenizer.lang_code_to_id("en_XX"))
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
"Don't interfere with the wizard's affairs, because they are subtle, will soon get angry."

facebook/mbart-large-50-many-to-one-mmt ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ๋‹ค๋ฉด, ์ฒซ ๋ฒˆ์งธ๋กœ ์ƒ์„ฑ๋˜๋Š” ํ† ํฐ์„ ๋ฒˆ์—ญํ•  ์–ธ์–ด(target language) ID๋กœ ๊ฐ•์ œ ์ง€์ •ํ•  ํ•„์š”๋Š” ์—†์Šต๋‹ˆ๋‹ค.