update readme
Browse files
README.md
CHANGED
@@ -21,23 +21,162 @@ This repository provides a base-sized Japanese RoBERTa model. The model is provi
|
|
21 |
|
22 |
# How to use the model
|
23 |
|
24 |
-
Since this is a private repo, first login your huggingface account from the command line:
|
25 |
-
|
26 |
-
~~~
|
27 |
-
transformer-cli login
|
28 |
-
~~~
|
29 |
-
|
30 |
*NOTE:* Use `T5Tokenizer` to initiate the tokenizer.
|
31 |
|
32 |
~~~~
|
33 |
from transformers import T5Tokenizer, RobertaForMaskedLM
|
34 |
|
35 |
-
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base"
|
36 |
tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
|
37 |
|
38 |
-
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
~~~~
|
40 |
|
|
|
41 |
# Model architecture
|
42 |
A 12-layer, 768-hidden-size transformer-based masked language model.
|
43 |
|
|
|
21 |
|
22 |
# How to use the model
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
*NOTE:* Use `T5Tokenizer` to initiate the tokenizer.
|
25 |
|
26 |
~~~~
|
27 |
from transformers import T5Tokenizer, RobertaForMaskedLM
|
28 |
|
29 |
+
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
|
30 |
tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
|
31 |
|
32 |
+
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
|
33 |
+
~~~~
|
34 |
+
|
35 |
+
# How to use the model for masked token prediction
|
36 |
+
|
37 |
+
*NOTE:* To predict a masked token, be sure to add a `[CLS]` token before the sentence for the model to correctly encode it, as it is used during the model training.
|
38 |
+
|
39 |
+
Here we adopt the example by [kenta1984](https://qiita.com/kenta1984/items/7f3a5d859a15b20657f3#%E6%97%A5%E6%9C%AC%E8%AA%9Epre-trained-models) to illustrate how our model works as a masked language model.
|
40 |
+
|
41 |
+
~~~~
|
42 |
+
# original text
|
43 |
+
text = "テレビでサッカーの試合を見る。"
|
44 |
+
|
45 |
+
# prepend [CLS]
|
46 |
+
text = "[CLS]" + text
|
47 |
+
|
48 |
+
# tokenize
|
49 |
+
tokens = tokenizer.tokenize(text)
|
50 |
+
print(tokens) # output: ['[CLS]', '▁', 'テレビ', 'で', 'サッカー', 'の試合', 'を見る', '。']
|
51 |
+
|
52 |
+
# mask a token
|
53 |
+
masked_idx = 4
|
54 |
+
tokens[masked_idx] = tokenizer.mask_token
|
55 |
+
print(tokens) # output: ['[CLS]', '▁', 'テレビ', 'で', '[MASK]', 'の試合', 'を見る', '。']
|
56 |
+
|
57 |
+
# convert to ids
|
58 |
+
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
59 |
+
print(token_ids) # output: [4, 9, 480, 19, 6, 8466, 6518, 8]
|
60 |
+
|
61 |
+
# convert to tensor
|
62 |
+
import torch
|
63 |
+
token_tensor = torch.tensor([token_ids])
|
64 |
+
|
65 |
+
# get the top 50 predictions of the masked token
|
66 |
+
model = model.eval()
|
67 |
+
with torch.no_grad():
|
68 |
+
outputs = model(token_tensor)
|
69 |
+
predictions = outputs[0][0, masked_idx].topk(100)
|
70 |
+
for i, index_t in enumerate(predictions.indices):
|
71 |
+
index = index_t.item()
|
72 |
+
token = tokenizer.convert_ids_to_tokens([index])[0]
|
73 |
+
print(i, token)
|
74 |
+
|
75 |
+
"""
|
76 |
+
0 サッカー
|
77 |
+
1 メジャーリーグ
|
78 |
+
2 フィギュアスケート
|
79 |
+
3 バレーボール
|
80 |
+
4 ボクシング
|
81 |
+
5 ワールドカップ
|
82 |
+
6 バスケットボール
|
83 |
+
7 阪神タイガース
|
84 |
+
8 プロ野球
|
85 |
+
9 アメリカンフットボール
|
86 |
+
10 日本代表
|
87 |
+
11 高校野球
|
88 |
+
12 福岡ソフトバンクホークス
|
89 |
+
13 プレミアリーグ
|
90 |
+
14 ファイターズ
|
91 |
+
15 ラグビー
|
92 |
+
16 東北楽天ゴールデンイーグルス
|
93 |
+
17 中日ドラゴンズ
|
94 |
+
18 アイスホッケー
|
95 |
+
19 フットサル
|
96 |
+
20 サッカー選手
|
97 |
+
21 スポーツ
|
98 |
+
22 チャンピオンズリーグ
|
99 |
+
23 ジャイアンツ
|
100 |
+
24 ソフトボール
|
101 |
+
25 バスケット
|
102 |
+
26 フットボール
|
103 |
+
27 新日本プロレス
|
104 |
+
28 バドミントン
|
105 |
+
29 千葉ロッテマリーンズ
|
106 |
+
30 <unk>
|
107 |
+
31 北京オリンピック
|
108 |
+
32 広島東洋カープ
|
109 |
+
33 キックボクシング
|
110 |
+
34 オリンピック
|
111 |
+
35 ロンドンオリンピック
|
112 |
+
36 読売ジャイアンツ
|
113 |
+
37 テニス
|
114 |
+
38 東京オリンピック
|
115 |
+
39 日本シリーズ
|
116 |
+
40 ヤクルトスワローズ
|
117 |
+
41 タイガース
|
118 |
+
42 サッカークラブ
|
119 |
+
43 ハンドボール
|
120 |
+
44 野球
|
121 |
+
45 バルセロナ
|
122 |
+
46 ホッケー
|
123 |
+
47 格闘技
|
124 |
+
48 大相撲
|
125 |
+
49 ブンデスリーガ
|
126 |
+
50 スキージャンプ
|
127 |
+
51 プロサッカー選手
|
128 |
+
52 ヤンキース
|
129 |
+
53 社会人野球
|
130 |
+
54 クライマックスシリーズ
|
131 |
+
55 クリケット
|
132 |
+
56 トップリーグ
|
133 |
+
57 パラリンピック
|
134 |
+
58 クラブチーム
|
135 |
+
59 ニュージーランド
|
136 |
+
60 総合格闘技
|
137 |
+
61 ウィンブルドン
|
138 |
+
62 ドラゴンボール
|
139 |
+
63 レスリング
|
140 |
+
64 ドラゴンズ
|
141 |
+
65 プロ野球選手
|
142 |
+
66 リオデジャネイロオリンピック
|
143 |
+
67 ホークス
|
144 |
+
68 全日本プロレス
|
145 |
+
69 プロレス
|
146 |
+
70 ヴェルディ
|
147 |
+
71 都市対抗野球
|
148 |
+
72 ライオンズ
|
149 |
+
73 グランプリシリーズ
|
150 |
+
74 日本プロ野球
|
151 |
+
75 アテネオリンピック
|
152 |
+
76 ヤクルト
|
153 |
+
77 イーグルス
|
154 |
+
78 巨人
|
155 |
+
79 ワールドシリーズ
|
156 |
+
80 アーセナル
|
157 |
+
81 マスターズ
|
158 |
+
82 ソフトバンク
|
159 |
+
83 日本ハム
|
160 |
+
84 クロアチア
|
161 |
+
85 マリナーズ
|
162 |
+
86 サッカーリーグ
|
163 |
+
87 アトランタオリンピック
|
164 |
+
88 ゴルフ
|
165 |
+
89 ジャニーズ
|
166 |
+
90 甲子園
|
167 |
+
91 夏の甲子園
|
168 |
+
92 陸上競技
|
169 |
+
93 ベースボール
|
170 |
+
94 卓球
|
171 |
+
95 プロ
|
172 |
+
96 南アフリカ
|
173 |
+
97 レッズ
|
174 |
+
98 ウルグアイ
|
175 |
+
99 オールスターゲーム
|
176 |
+
"""
|
177 |
~~~~
|
178 |
|
179 |
+
|
180 |
# Model architecture
|
181 |
A 12-layer, 768-hidden-size transformer-based masked language model.
|
182 |
|