update readme
Browse files- README.md +33 -112
- config.json +1 -1
README.md
CHANGED
@@ -11,15 +11,18 @@ license: mit
|
|
11 |
datasets:
|
12 |
- cc100
|
13 |
- wikipedia
|
|
|
|
|
|
|
14 |
---
|
15 |
|
16 |
# japanese-roberta-base
|
17 |
|
18 |
![rinna-icon](./rinna.png)
|
19 |
|
20 |
-
This repository provides a base-sized Japanese RoBERTa model. The model
|
21 |
|
22 |
-
# How to
|
23 |
|
24 |
*NOTE:* Use `T5Tokenizer` to initiate the tokenizer.
|
25 |
|
@@ -34,149 +37,67 @@ model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
|
|
34 |
|
35 |
# How to use the model for masked token prediction
|
36 |
|
37 |
-
|
38 |
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
~~~~
|
42 |
# original text
|
43 |
-
text = "
|
44 |
|
45 |
# prepend [CLS]
|
46 |
text = "[CLS]" + text
|
47 |
|
48 |
# tokenize
|
49 |
tokens = tokenizer.tokenize(text)
|
50 |
-
print(tokens) # output: ['[CLS]', '▁', '
|
51 |
|
52 |
# mask a token
|
53 |
-
masked_idx =
|
54 |
tokens[masked_idx] = tokenizer.mask_token
|
55 |
-
print(tokens) # output: ['[CLS]', '▁', '
|
56 |
|
57 |
# convert to ids
|
58 |
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
59 |
-
print(token_ids) # output: [4,
|
60 |
|
61 |
# convert to tensor
|
62 |
import torch
|
63 |
token_tensor = torch.tensor([token_ids])
|
64 |
|
65 |
-
# get the top
|
66 |
model = model.eval()
|
67 |
with torch.no_grad():
|
68 |
outputs = model(token_tensor)
|
69 |
-
predictions = outputs[0][0, masked_idx].topk(
|
|
|
70 |
for i, index_t in enumerate(predictions.indices):
|
71 |
index = index_t.item()
|
72 |
token = tokenizer.convert_ids_to_tokens([index])[0]
|
73 |
print(i, token)
|
74 |
|
75 |
"""
|
76 |
-
0
|
77 |
-
1
|
78 |
-
2
|
79 |
-
3
|
80 |
-
4
|
81 |
-
5
|
82 |
-
6
|
83 |
-
7
|
84 |
-
8
|
85 |
-
9
|
86 |
-
10 日本代表
|
87 |
-
11 高校野球
|
88 |
-
12 福岡ソフトバンクホークス
|
89 |
-
13 プレミアリーグ
|
90 |
-
14 ファイターズ
|
91 |
-
15 ラグビー
|
92 |
-
16 東北楽天ゴールデンイーグルス
|
93 |
-
17 中日ドラゴンズ
|
94 |
-
18 アイスホッケー
|
95 |
-
19 フットサル
|
96 |
-
20 サッカー選手
|
97 |
-
21 スポーツ
|
98 |
-
22 チャンピオンズリーグ
|
99 |
-
23 ジャイアンツ
|
100 |
-
24 ソフトボール
|
101 |
-
25 バスケット
|
102 |
-
26 フットボール
|
103 |
-
27 新日本プロレス
|
104 |
-
28 バドミントン
|
105 |
-
29 千葉ロッテマリーンズ
|
106 |
-
30 <unk>
|
107 |
-
31 北京オリンピック
|
108 |
-
32 広島東洋カープ
|
109 |
-
33 キックボクシング
|
110 |
-
34 オリンピック
|
111 |
-
35 ロンドンオリンピック
|
112 |
-
36 読売ジャイアンツ
|
113 |
-
37 テニス
|
114 |
-
38 東京オリンピック
|
115 |
-
39 日本シリーズ
|
116 |
-
40 ヤクルトスワローズ
|
117 |
-
41 タイガース
|
118 |
-
42 サッカークラブ
|
119 |
-
43 ハンドボール
|
120 |
-
44 野球
|
121 |
-
45 バルセロナ
|
122 |
-
46 ホッケー
|
123 |
-
47 格闘技
|
124 |
-
48 大相撲
|
125 |
-
49 ブンデスリーガ
|
126 |
-
50 スキージャンプ
|
127 |
-
51 プロサッカー選手
|
128 |
-
52 ヤンキース
|
129 |
-
53 社会人野球
|
130 |
-
54 クライマックスシリーズ
|
131 |
-
55 クリケット
|
132 |
-
56 トップリーグ
|
133 |
-
57 パラリンピック
|
134 |
-
58 クラブチーム
|
135 |
-
59 ニュージーランド
|
136 |
-
60 総合格闘技
|
137 |
-
61 ウィンブルドン
|
138 |
-
62 ドラゴン���ール
|
139 |
-
63 レスリング
|
140 |
-
64 ドラゴンズ
|
141 |
-
65 プロ野球選手
|
142 |
-
66 リオデジャネイロオリンピック
|
143 |
-
67 ホークス
|
144 |
-
68 全日本プロレス
|
145 |
-
69 プロレス
|
146 |
-
70 ヴェルディ
|
147 |
-
71 都市対抗野球
|
148 |
-
72 ライオンズ
|
149 |
-
73 グランプリシリーズ
|
150 |
-
74 日本プロ野球
|
151 |
-
75 アテネオリンピック
|
152 |
-
76 ヤクルト
|
153 |
-
77 イーグルス
|
154 |
-
78 巨人
|
155 |
-
79 ワールドシリーズ
|
156 |
-
80 アーセナル
|
157 |
-
81 マスターズ
|
158 |
-
82 ソフトバンク
|
159 |
-
83 日本ハム
|
160 |
-
84 クロアチア
|
161 |
-
85 マリナーズ
|
162 |
-
86 サッカーリーグ
|
163 |
-
87 アトランタオリンピック
|
164 |
-
88 ゴルフ
|
165 |
-
89 ジャニーズ
|
166 |
-
90 甲子園
|
167 |
-
91 夏の甲子園
|
168 |
-
92 陸上競技
|
169 |
-
93 ベースボール
|
170 |
-
94 卓球
|
171 |
-
95 プロ
|
172 |
-
96 南アフリカ
|
173 |
-
97 レッズ
|
174 |
-
98 ウルグアイ
|
175 |
-
99 オールスターゲーム
|
176 |
"""
|
177 |
~~~~
|
178 |
|
179 |
-
|
180 |
# Model architecture
|
181 |
A 12-layer, 768-hidden-size transformer-based masked language model.
|
182 |
|
|
|
11 |
datasets:
|
12 |
- cc100
|
13 |
- wikipedia
|
14 |
+
widget:
|
15 |
+
- text: "[CLS]4年に1度[MASK]は開かれる。"
|
16 |
+
mask_token: "[MASK]"
|
17 |
---
|
18 |
|
19 |
# japanese-roberta-base
|
20 |
|
21 |
![rinna-icon](./rinna.png)
|
22 |
|
23 |
+
This repository provides a base-sized Japanese RoBERTa model. The model was trained using code from Github repository [rinnakk/japanese-pretrained-models](https://github.com/rinnakk/japanese-pretrained-models) by [rinna Co., Ltd.](https://corp.rinna.co.jp/)
|
24 |
|
25 |
+
# How to load the model
|
26 |
|
27 |
*NOTE:* Use `T5Tokenizer` to initiate the tokenizer.
|
28 |
|
|
|
37 |
|
38 |
# How to use the model for masked token prediction
|
39 |
|
40 |
+
## Note 1: Use `[CLS]`
|
41 |
|
42 |
+
To predict a masked token, be sure to add a `[CLS]` token before the sentence for the model to correctly encode it, as it is used during the model training.
|
43 |
+
|
44 |
+
## Note 2: Use `[MASK]` after tokenization
|
45 |
+
|
46 |
+
A) Directly typing `[MASK]` in an input string and B) replacing a token with `[MASK]` after tokenization will yield different token sequences, and thus different prediction results. It is more appropriate to use `[MASK]` after tokenization (as it is consistent with how the model was pretrained). However, the Huggingface Inference API only supports typing `[MASK]` in the input string and produces less robust predictions.
|
47 |
+
|
48 |
+
## Example
|
49 |
+
|
50 |
+
Here is an example by to illustrate how our model works as a masked language model. Notice the difference between running the following code example and running the Huggingface Inference API.
|
51 |
|
52 |
~~~~
|
53 |
# original text
|
54 |
+
text = "4年に1度オリンピックは開かれる。"
|
55 |
|
56 |
# prepend [CLS]
|
57 |
text = "[CLS]" + text
|
58 |
|
59 |
# tokenize
|
60 |
tokens = tokenizer.tokenize(text)
|
61 |
+
print(tokens) # output: ['[CLS]', '▁4', '年に', '1', '度', 'オリンピック', 'は', '開かれる', '。']
|
62 |
|
63 |
# mask a token
|
64 |
+
masked_idx = 6
|
65 |
tokens[masked_idx] = tokenizer.mask_token
|
66 |
+
print(tokens) # output: ['[CLS]', '▁4', '年に', '1', '度', '[MASK]', 'は', '開かれる', '。']
|
67 |
|
68 |
# convert to ids
|
69 |
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
70 |
+
print(token_ids) # output: [4, 1602, 44, 24, 368, 6, 11, 21583, 8]
|
71 |
|
72 |
# convert to tensor
|
73 |
import torch
|
74 |
token_tensor = torch.tensor([token_ids])
|
75 |
|
76 |
+
# get the top 10 predictions of the masked token
|
77 |
model = model.eval()
|
78 |
with torch.no_grad():
|
79 |
outputs = model(token_tensor)
|
80 |
+
predictions = outputs[0][0, masked_idx].topk(10)
|
81 |
+
|
82 |
for i, index_t in enumerate(predictions.indices):
|
83 |
index = index_t.item()
|
84 |
token = tokenizer.convert_ids_to_tokens([index])[0]
|
85 |
print(i, token)
|
86 |
|
87 |
"""
|
88 |
+
0 ワールドカップ
|
89 |
+
1 フェスティバル
|
90 |
+
2 オリンピック
|
91 |
+
3 サミット
|
92 |
+
4 東京オリンピック
|
93 |
+
5 総会
|
94 |
+
6 全国大会
|
95 |
+
7 イベント
|
96 |
+
8 世界選手権
|
97 |
+
9 パーティー
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
"""
|
99 |
~~~~
|
100 |
|
|
|
101 |
# Model architecture
|
102 |
A 12-layer, 768-hidden-size transformer-based masked language model.
|
103 |
|
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"architectures": [
|
4 |
"RobertaForMaskedLM"
|
5 |
],
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "rinna/japanese-roberta-base",
|
3 |
"architectures": [
|
4 |
"RobertaForMaskedLM"
|
5 |
],
|