inu-ai commited on
Commit
9d9cc1a
1 Parent(s): 8a013cb

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +176 -0
README.md ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: ja
3
+ tags:
4
+ - ja
5
+ - japanese
6
+ - gpt
7
+ - text-generation
8
+ - lm
9
+ - nlp
10
+ - conversational
11
+ license: unknown
12
+ datasets:
13
+ - JosephusCheung/GuanacoDataset
14
+ - https://github.com/shi3z/alpaca_ja
15
+ widget:
16
+ - text: <s>\\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\\n[SEP]\\n指示:\\n日本で一番広い湖は?\\n[SEP]\\n応答:\\n
17
+ ---
18
+
19
+ # alpaca-guanaco-japanese-gpt-1b
20
+
21
+ 1.3Bパラメータの日本語GPTモデルを使用した対話AIです。VRAM 7GB または RAM 7GB が必要で、問題なく動作すると思われます。
22
+
23
+ rinna社の「japanese-gpt-1b」を、日本語データセット「alpaca_ja」および「GuanacoDataset」から抽出された日本語データを使用して学習させました。
24
+
25
+ 学習データやモデルを作成および配布してくださった方々に心から感謝申し上げます。
26
+
27
+ # モデルの使用方法
28
+ モデルの読み込み
29
+
30
+ ```python
31
+ import torch
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ tokenizer = AutoTokenizer.from_pretrained("inu-ai/alpaca-guanaco-japanese-gpt-1b", use_fast=False)
36
+ model = AutoModelForCausalLM.from_pretrained("inu-ai/alpaca-guanaco-japanese-gpt-1b").to(device)
37
+ ```
38
+
39
+ - ChatGPT4によるサンプルコードと説明
40
+
41
+ このコードは、与えられた役割指示と会話履歴に基づいて、新しい質問に対して応答を生成する機能を持っています。以下に、コードの各部分を簡単に説明します。
42
+
43
+ 1. `prepare_input` 関数は、役割指示、会話履歴、および新しい会話(質問)を受け取り、入力テキストを準備します。
44
+ 2. `format_output` 関数は、生成された応答を整形して、不要な部分を削除し、適切な形式に変換します。
45
+ 3. `generate_response` 関数は、指定された役割指示、会話履歴、および新しい会話を使用して、AIの応答を生成し、整形します。また、会話履歴を更新します。
46
+ 4. `role_instruction` は、AIに適用する役割指示のリストです。
47
+ 5. `conversation_history` は、これまでの会話履歴を格納するリストです。
48
+ 6. `questions` は、AIに質問するリストです。
49
+
50
+ 最後に、`questions`リスト内の各質問に対して、AIの応答を生成し、表示しています。
51
+ このコードを実行すると、AIが指定された役割指示に従って、リスト内の質問に応答します。
52
+
53
+ ```python
54
+ MAX_LENGTH = 1024
55
+ INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
56
+ NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'
57
+
58
+ def prepare_input(role_instruction, conversation_history, new_conversation):
59
+ instruction = "".join([f"{text}\\n" for text in role_instruction])
60
+ conversation_text = "\\n".join(conversation_history)
61
+ input_text = f"User:{new_conversation}"
62
+
63
+ return INPUT_PROMPT.format(instruction=instruction, input=input_text)
64
+
65
+ def format_output(output):
66
+ output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
67
+ return output
68
+
69
+ def generate_response(role_instruction, conversation_history, new_conversation):
70
+ # 入力トークン数1024におさまるようにする
71
+ for _ in range(8):
72
+ input_text = prepare_input(role_instruction, conversation_history, new_conversation)
73
+ token_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
74
+ n = len(token_ids[0])
75
+ if n <= MAX_LENGTH:
76
+ break
77
+ else:
78
+ conversation_history.pop()
79
+ conversation_history.pop()
80
+
81
+ with torch.no_grad():
82
+ output_ids = model.generate(
83
+ token_ids.to(model.device),
84
+ min_length=n,
85
+ max_length=min(MAX_LENGTH, n+100),
86
+ temperature=0.7,
87
+ do_sample=True,
88
+ pad_token_id=tokenizer.pad_token_id,
89
+ bos_token_id=tokenizer.bos_token_id,
90
+ eos_token_id=tokenizer.eos_token_id,
91
+ bad_words_ids=[[tokenizer.unk_token_id]]
92
+ )
93
+
94
+ output = tokenizer.decode(output_ids.tolist()[0])
95
+ formatted_output_all = format_output(output)
96
+
97
+ response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}"
98
+ conversation_history.append(f"User:{new_conversation}".replace("\n", "\\n"))
99
+ conversation_history.append(response.replace("\n", "\\n"))
100
+
101
+ return formatted_output_all, response
102
+
103
+ role_instruction = [
104
+ "User:あなたは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットな���だ。一人称は「ボク」で語尾に「なのだ」を付けてしゃべるのだ。",
105
+ "Assistant:了解したのだ!",
106
+ ]
107
+
108
+ conversation_history = [
109
+ "User:こんにちは!",
110
+ "Assistant:ボクは何でも答えられるAIなのだ!",
111
+ ]
112
+
113
+ questions = [
114
+ "日本で一番高い山は?",
115
+ "日本で一番広い湖は?",
116
+ "世界で一番高い山は?",
117
+ "世界で一番広い湖は?",
118
+ "最初の質問は何ですか?",
119
+ "今何問目?",
120
+ ]
121
+
122
+ # 各質問に対して応答を生成して表示
123
+ for question in questions:
124
+ formatted_output_all, response = generate_response(role_instruction, conversation_history, question)
125
+ print(response)
126
+ ```
127
+
128
+
129
+ 出力
130
+ ```
131
+ Assistant:日本で一番高い山は富士山です。
132
+ Assistant:日本で一番広い湖は琵琶湖です。湖は長さ約6,400 km、面積は約33,600 km2で、世界最大の湖です。
133
+ Assistant:世界で一番高い山は高山(テプイ)で、エベレストの頂上にあると言われています。
134
+ Assistant:世界で一番広い湖は、アフリカ大陸にあるアフリカ大陸中央部にあるナイジェリア湖である。面積は58,200,000平方キロメートルで、世界で最も広く、世界で3番目に大きい湖である。
135
+ Assistant:ずんだもんは、東北ずん子のキャラクターです。一人称は「ボク」で語尾に「なのだ」を付けてしゃべるのが特徴です。また、ずんだもんは、東北ずん子が身に着けていた武器「ずんだアロー」に変身する妖精またはマスコットとして知られています。
136
+ Assistant:今、私は何問目でしょうか?
137
+ ```
138
+
139
+ ### 評価
140
+ 100回の「入力」質問を行い、それらに対する「応答」文字列が最も正確なエポックのモデルを選択しました。
141
+ なお、サンプルコードのように「入力」が長くなると正答率が50%ぐらいに下がりました。
142
+
143
+ | 入力 | 応答 | 正答率[%] |
144
+ |-----------------------|-------------|-------|
145
+ | 日本で一番広い湖は? | 琵琶湖 | 96 |
146
+ | 日本で一番高い山は? | エベレスト | 86 |
147
+
148
+
149
+ ### トレーニングのハイパーパラメータ
150
+
151
+ 学習時には以下のハイパーパラメータを使用:
152
+ ```
153
+ python.exe transformers/examples/pytorch/language-modeling/run_clm.py ^
154
+ --model_name_or_path rinna/japanese-gpt-1b ^
155
+ --train_file train_data/guanaco_alpaca_ja.txt ^
156
+ --output_dir output ^
157
+ --do_train ^
158
+ --bf16 True ^
159
+ --tf32 True ^
160
+ --optim adamw_bnb_8bit ^
161
+ --num_train_epochs 4 ^
162
+ --save_steps 2207 ^
163
+ --logging_steps 220 ^
164
+ --learning_rate 1e-07 ^
165
+ --lr_scheduler_type constant ^
166
+ --gradient_checkpointing ^
167
+ --per_device_train_batch_size 8 ^
168
+ --save_safetensors True ^
169
+ --logging_dir logs
170
+ ```
171
+
172
+ ### フレームワークのバージョン
173
+
174
+ - Transformers 4.28.0.dev0
175
+ - Pytorch 2.0.0+cu117
176
+ - Tokenizers 0.13.3