Update README.md
Browse files
README.md
CHANGED
@@ -7,19 +7,36 @@ tags:
|
|
7 |
- text-generation
|
8 |
- lm
|
9 |
- nlp
|
|
|
10 |
license: mit
|
11 |
datasets:
|
12 |
- kunishou/databricks-dolly-15k-ja
|
|
|
13 |
widget:
|
14 |
- text: >-
|
15 |
<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n日本で一番広い湖は?\n[SEP]\n応答:\n
|
16 |
---
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# dolly-japanese-gpt-1b
|
19 |
|
20 |
-
1.3Bパラメータの日本語GPT
|
21 |
|
22 |
-
rinna社の「[japanese-gpt-1b](https://huggingface.co/rinna/japanese-gpt-1b)
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
学習データやモデルを作成および配布してくださった方々に心から感謝申し上げます。
|
25 |
|
@@ -36,29 +53,38 @@ tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fa
|
|
36 |
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)
|
37 |
```
|
38 |
|
39 |
-
##
|
40 |
|
41 |
```python
|
42 |
MAX_ASSISTANT_LENGTH = 100
|
43 |
MAX_INPUT_LENGTH = 1024
|
44 |
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
|
45 |
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'
|
|
|
|
|
46 |
|
47 |
-
def prepare_input(
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
return
|
53 |
|
54 |
def format_output(output):
|
55 |
output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
|
56 |
return output
|
57 |
|
58 |
-
def generate_response(
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
with torch.no_grad():
|
64 |
output_ids = model.generate(
|
@@ -66,6 +92,7 @@ def generate_response(instruction, input_text):
|
|
66 |
min_length=n,
|
67 |
max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
|
68 |
temperature=0.7,
|
|
|
69 |
do_sample=True,
|
70 |
pad_token_id=tokenizer.pad_token_id,
|
71 |
bos_token_id=tokenizer.bos_token_id,
|
@@ -75,105 +102,151 @@ def generate_response(instruction, input_text):
|
|
75 |
|
76 |
output = tokenizer.decode(output_ids.tolist()[0])
|
77 |
formatted_output_all = format_output(output)
|
78 |
-
|
|
|
|
|
|
|
79 |
|
80 |
return formatted_output_all, response
|
81 |
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
questions = [
|
84 |
"日本で一番高い山は?",
|
85 |
"日本で一番広い湖は?",
|
|
|
86 |
"世界で一番高い山は?",
|
87 |
"世界で一番広い湖は?",
|
88 |
-
"
|
|
|
|
|
89 |
]
|
90 |
|
91 |
# 各質問に対して応答を生成して表示
|
92 |
for question in questions:
|
93 |
-
formatted_output_all, response = generate_response(
|
94 |
-
print(response)
|
95 |
```
|
96 |
|
97 |
## 出力
|
98 |
|
99 |
```
|
|
|
100 |
Assistant:富士山
|
101 |
-
|
102 |
-
|
103 |
-
Assistant
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
```
|
106 |
|
107 |
-
##
|
108 |
|
109 |
-
|
110 |
-
まず、`prepare_input`関数でプロンプトを作成し、`generate_response`関数でモデルから応答を生成します。
|
111 |
-
生成された応答を整形し、質問ごとに結果を表示します。
|
112 |
|
113 |
# 評価
|
114 |
-
|
115 |
-
一番正答率が高い
|
116 |
|
117 |
| 入力 | 応答 | 正答率[%] |
|
118 |
|-----------------------|-------------|-------|
|
119 |
-
| 日本で一番広い湖は? | 琵琶湖 |
|
120 |
-
| 世界で一番高い山は? | エベレスト |
|
121 |
|
122 |
# 学習データのフォーマット
|
123 |
|
124 |
[alpaca](https://github.com/tatsu-lab/stanford_alpaca)と同じように、以下のようなフォーマットにしています。
|
125 |
|
126 |
```
|
127 |
-
<s>
|
128 |
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
|
129 |
-
[SEP]
|
130 |
指示:
|
131 |
-
|
132 |
-
[SEP]
|
133 |
入力:
|
134 |
User:日本で一番高い山は?
|
135 |
-
[SEP]
|
136 |
応答:
|
137 |
富士山
|
138 |
</s>
|
139 |
```
|
140 |
|
141 |
transformersのコードでtxtファイルを学習する場合、1データ1行のようなので改行コードを一旦`\n`に置き換えています。
|
142 |
-
学習データは[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
# 学習のハイパーパラメータ
|
145 |
|
146 |
学習時には以下のハイパーパラメータを使用:
|
147 |
-
|
|
|
148 |
```
|
149 |
-
python.exe transformers/examples/pytorch/language-modeling/run_clm.py ^
|
150 |
--model_name_or_path rinna/japanese-gpt-1b ^
|
151 |
-
--train_file train_data/
|
152 |
-
|
153 |
--do_train ^
|
154 |
-
|
155 |
-
|
156 |
--optim adamw_bnb_8bit ^
|
157 |
-
--num_train_epochs
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
--gradient_checkpointing ^
|
163 |
--per_device_train_batch_size 8 ^
|
164 |
-
|
165 |
-
|
166 |
```
|
167 |
|
168 |
# ライブラリのバージョン
|
169 |
|
170 |
-
- Transformers 4.28.
|
171 |
- Pytorch 2.0.0+cu117
|
|
|
172 |
- Tokenizers 0.13.3
|
173 |
- bitsandbytes 0.37.2
|
174 |
|
175 |
# ライセンス
|
176 |
MITで大丈夫そうです。
|
177 |
|
178 |
-
- japanese-gpt-1b - mit
|
179 |
-
- databricks-dolly-15k-ja - CC BY SA 3.0
|
|
|
|
|
|
|
|
7 |
- text-generation
|
8 |
- lm
|
9 |
- nlp
|
10 |
+
- conversational
|
11 |
license: mit
|
12 |
datasets:
|
13 |
- kunishou/databricks-dolly-15k-ja
|
14 |
+
- kunishou/oasst1-89k-ja
|
15 |
widget:
|
16 |
- text: >-
|
17 |
<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n日本で一番広い湖は?\n[SEP]\n応答:\n
|
18 |
---
|
19 |
|
20 |
+
# 更新履歴
|
21 |
+
- 2023年5月7日
|
22 |
+
|
23 |
+
「[oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja)」データセットを追加して**対話システム**に対応しました。1024トークンまで会話履歴を保存できます。
|
24 |
+
前回のモデルで行った質疑応答の正答率は今回のモデルで下がりました。「日本で一番広い湖は?」が91%から89%、「世界で一番高い山は?」が84%から73%に下がりました。(対話は分けた方が良かったのか、それともoasst1の質が良くないとか)
|
25 |
+
|
26 |
+
- 2023年4月13日
|
27 |
+
|
28 |
+
「[japanese-gpt-1b](https://huggingface.co/rinna/japanese-gpt-1b)」モデルを「[databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja)」データセットで**RLHF** (人間のフィードバックからの強化学習)しました。
|
29 |
+
|
30 |
# dolly-japanese-gpt-1b
|
31 |
|
32 |
+
1.3Bパラメータの日本語GPT-2モデルを使用した対話型のAIです。VRAM 7GB または RAM 7GB が必要で、問題なく動作すると思われます。
|
33 |
|
34 |
+
rinna社の「[japanese-gpt-1b](https://huggingface.co/rinna/japanese-gpt-1b)」を、
|
35 |
+
日本語データセット「[databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja)」、
|
36 |
+
「[oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja)」、
|
37 |
+
「[OjousamaTalkScriptDataset](https://github.com/matsuvr/OjousamaTalkScriptDataset)」、
|
38 |
+
「[train_data/zundamon.json](train_data/zundamon.json)」
|
39 |
+
を使用して学習させました。
|
40 |
|
41 |
学習データやモデルを作成および配布してくださった方々に心から感謝申し上げます。
|
42 |
|
|
|
53 |
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)
|
54 |
```
|
55 |
|
56 |
+
## ChatGPT/GPT-4によるサンプルコード(少し修正)
|
57 |
|
58 |
```python
|
59 |
MAX_ASSISTANT_LENGTH = 100
|
60 |
MAX_INPUT_LENGTH = 1024
|
61 |
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
|
62 |
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'
|
63 |
+
USER_NAME = "User"
|
64 |
+
ASSISTANT_NAME = "Assistant"
|
65 |
|
66 |
+
def prepare_input(role_instruction, conversation_history, new_conversation):
|
67 |
+
instruction = "".join([f"{text} " for text in role_instruction])
|
68 |
+
instruction += " ".join(conversation_history)
|
69 |
+
input_text = f"{USER_NAME}:{new_conversation}"
|
70 |
+
|
71 |
+
return INPUT_PROMPT.format(instruction=instruction, input=input_text)
|
72 |
|
73 |
def format_output(output):
|
74 |
output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
|
75 |
return output
|
76 |
|
77 |
+
def generate_response(role_instruction, conversation_history, new_conversation):
|
78 |
+
# 入力トークン数1024におさまるようにする
|
79 |
+
for _ in range(8):
|
80 |
+
input_text = prepare_input(role_instruction, conversation_history, new_conversation)
|
81 |
+
token_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
|
82 |
+
n = len(token_ids[0])
|
83 |
+
if n + MAX_ASSISTANT_LENGTH <= MAX_INPUT_LENGTH:
|
84 |
+
break
|
85 |
+
else:
|
86 |
+
conversation_history.pop(0)
|
87 |
+
conversation_history.pop(0)
|
88 |
|
89 |
with torch.no_grad():
|
90 |
output_ids = model.generate(
|
|
|
92 |
min_length=n,
|
93 |
max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
|
94 |
temperature=0.7,
|
95 |
+
repetition_penalty=1.0, # 数値を大きくすると、文字列の繰り返しが減る
|
96 |
do_sample=True,
|
97 |
pad_token_id=tokenizer.pad_token_id,
|
98 |
bos_token_id=tokenizer.bos_token_id,
|
|
|
102 |
|
103 |
output = tokenizer.decode(output_ids.tolist()[0])
|
104 |
formatted_output_all = format_output(output)
|
105 |
+
|
106 |
+
response = f"{ASSISTANT_NAME}:{formatted_output_all.split('応答:')[-1].strip()}"
|
107 |
+
conversation_history.append(f"{USER_NAME}:{new_conversation}".replace("\n", "\\n"))
|
108 |
+
conversation_history.append(response.replace("\n", "\\n"))
|
109 |
|
110 |
return formatted_output_all, response
|
111 |
|
112 |
+
role_instruction = [
|
113 |
+
f"{USER_NAME}:きみは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットなのだ。一人称は「ボク」で語尾に「なのだー」を付けてしゃべるのだ。",
|
114 |
+
f"{ASSISTANT_NAME}:了解したのだ。",
|
115 |
+
f"{USER_NAME}:きみは同じ言葉を繰り返さず、何でも正確に要約して答えられるのだ。",
|
116 |
+
f"{ASSISTANT_NAME}:了解したのだ。",
|
117 |
+
]
|
118 |
+
|
119 |
+
conversation_history = [
|
120 |
+
]
|
121 |
+
|
122 |
questions = [
|
123 |
"日本で一番高い山は?",
|
124 |
"日本で一番広い湖は?",
|
125 |
+
"冗談を言ってください。",
|
126 |
"世界で一番高い山は?",
|
127 |
"世界で一番広い湖は?",
|
128 |
+
"最初の質問は何ですか?",
|
129 |
+
"今何問目?",
|
130 |
+
"自己紹介をしてください。",
|
131 |
]
|
132 |
|
133 |
# 各質問に対して応答を生成して表示
|
134 |
for question in questions:
|
135 |
+
formatted_output_all, response = generate_response(role_instruction, conversation_history, question)
|
136 |
+
print(f"{USER_NAME}:{question}\n{response}\n---")
|
137 |
```
|
138 |
|
139 |
## 出力
|
140 |
|
141 |
```
|
142 |
+
User:日本で一番高い山は?
|
143 |
Assistant:富士山
|
144 |
+
---
|
145 |
+
User:日本で一番広い湖は?
|
146 |
+
Assistant:日本列島には、太平洋、日本海、オホーツク海、インド洋、大西洋など、さまざまな海があります。
|
147 |
+
---
|
148 |
+
User:冗談を言ってください。
|
149 |
+
Assistant:冗談を言おう!
|
150 |
+
---
|
151 |
+
User:世界で一番高い山は?
|
152 |
+
Assistant:世界で最も高い山はエベレストです。
|
153 |
+
---
|
154 |
+
User:世界で一番広い湖は?
|
155 |
+
Assistant:太平洋に浮かぶ世界最大の島、ツバル
|
156 |
+
---
|
157 |
+
User:最初の質問は何ですか?
|
158 |
+
Assistant:私が最初の質問をするとするならば、それは「ずんだもん」です。
|
159 |
+
---
|
160 |
+
User:今何問目?
|
161 |
+
Assistant:今何問目ですか?
|
162 |
+
---
|
163 |
+
User:自己紹介をしてください。
|
164 |
+
Assistant:私は「ずんだもん」というあだ名で呼ばれています。
|
165 |
+
---
|
166 |
```
|
167 |
|
168 |
+
## ChatGPT/GPT-4による説明(少し修正)
|
169 |
|
170 |
+
このコードは、質問に答えるAIアシスタントを実装しています。質問リストに対して、役割指示に従った応答を生成し、会話を表示します。
|
|
|
|
|
171 |
|
172 |
# 評価
|
173 |
+
1000回の「入力」のような質問を行い、それらに対する「応答」に正解の文字列が含まれるかで評価しています。
|
174 |
+
一番正答率が高い10エポック目のモデルを選択しました。(やり過ぎたかもしれないです。)
|
175 |
|
176 |
| 入力 | 応答 | 正答率[%] |
|
177 |
|-----------------------|-------------|-------|
|
178 |
+
| 日本で一番広い湖は? | 琵琶湖 | 89 |
|
179 |
+
| 世界で一番高い山は? | エベレスト | 73 |
|
180 |
|
181 |
# 学習データのフォーマット
|
182 |
|
183 |
[alpaca](https://github.com/tatsu-lab/stanford_alpaca)と同じように、以下のようなフォーマットにしています。
|
184 |
|
185 |
```
|
186 |
+
<s>
|
187 |
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
|
188 |
+
[SEP]
|
189 |
指示:
|
190 |
+
User:きみは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットなのだ。一人称は「ボク」で語尾に「なのだー」を付けてしゃべるのだ。 Assistant:了解したのだ。 User:きみは同じ言葉を繰り返さず、何でも正確に要約して答えられるのだ。 Assistant:了解したのだ。
|
191 |
+
[SEP]
|
192 |
入力:
|
193 |
User:日本で一番高い山は?
|
194 |
+
[SEP]
|
195 |
応答:
|
196 |
富士山
|
197 |
</s>
|
198 |
```
|
199 |
|
200 |
transformersのコードでtxtファイルを学習する場合、1データ1行のようなので改行コードを一旦`\n`に置き換えています。
|
201 |
+
学習データは[dolly-oasst1-ja.txt](train_data/dolly-oasst1-ja.txt)です。
|
202 |
+
|
203 |
+
また学習データを作った過程のスクリプトとjsonファイルも[train_data](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/tree/main/train_data)に置いておきます。
|
204 |
+
|
205 |
+
手順は、
|
206 |
+
1. 各jsonファイルを作成
|
207 |
+
2. jsonファイルを一つのjsonファイルにマージ
|
208 |
+
3. マージしたjsonファイルを学習データのtxtファイルに変換
|
209 |
+
|
210 |
+
になります。
|
211 |
|
212 |
# 学習のハイパーパラメータ
|
213 |
|
214 |
学習時には以下のハイパーパラメータを使用:
|
215 |
+
|
216 |
+
※VRAMが足りない場合、optimをadafactorにするとVRAM使用量が減りました。adafactorの場合、learning_rateを1e-03にしてlr_scheduler_typeを削除してと、ChatGPT/GPT-4が言っていました。
|
217 |
```
|
218 |
+
venv/Scripts/python.exe transformers/examples/pytorch/language-modeling/run_clm.py ^
|
219 |
--model_name_or_path rinna/japanese-gpt-1b ^
|
220 |
+
--train_file train_data/dolly-oasst1-ja.txt ^
|
221 |
+
--output_dir output ^
|
222 |
--do_train ^
|
223 |
+
--bf16 True ^
|
224 |
+
--tf32 True ^
|
225 |
--optim adamw_bnb_8bit ^
|
226 |
+
--num_train_epochs 10 ^
|
227 |
+
--save_steps 721 ^
|
228 |
+
--logging_steps 72 ^
|
229 |
+
--learning_rate 1e-07 ^
|
230 |
+
--lr_scheduler_type constant ^
|
231 |
--gradient_checkpointing ^
|
232 |
--per_device_train_batch_size 8 ^
|
233 |
+
--save_safetensors True ^
|
234 |
+
--logging_dir logs
|
235 |
```
|
236 |
|
237 |
# ライブラリのバージョン
|
238 |
|
239 |
+
- Transformers 4.28.1
|
240 |
- Pytorch 2.0.0+cu117
|
241 |
+
- Datasets 2.11.0
|
242 |
- Tokenizers 0.13.3
|
243 |
- bitsandbytes 0.37.2
|
244 |
|
245 |
# ライセンス
|
246 |
MITで大丈夫そうです。
|
247 |
|
248 |
+
- [japanese-gpt-1b](rinna/japanese-gpt-1b) - mit
|
249 |
+
- [databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja) - CC BY SA 3.0
|
250 |
+
- [oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja) - apache-2.0
|
251 |
+
- [OjousamaTalkScriptDataset](https://github.com/matsuvr/OjousamaTalkScriptDataset) - mit
|
252 |
+
- [train_data/zundamon.json](train_data/zundamon.json) - mit
|