Upload train_data
Browse files- .gitattributes +2 -0
- train_data/databricks-dolly-15k-ja.json +3 -0
- train_data/make_json_from_oasst1_ja.py +107 -0
- train_data/oasst1_ja.json +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
train_data/databricks-dolly-15k-ja.json filter=lfs diff=lfs merge=lfs -text
|
36 |
+
train_data/oasst1_ja.json filter=lfs diff=lfs merge=lfs -text
|
train_data/databricks-dolly-15k-ja.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7053bd9081719ea68765e0e743c6e222fd78de65a41de3411d11362b631815e
|
3 |
+
size 17061804
|
train_data/make_json_from_oasst1_ja.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from datasets import load_dataset
|
3 |
+
import json
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
ds = load_dataset("OpenAssistant/oasst1")
|
7 |
+
train = ds['train']
|
8 |
+
val = ds['validation']
|
9 |
+
|
10 |
+
# データフレームを連結
|
11 |
+
df = pd.concat([pd.DataFrame(train), pd.DataFrame(val)])
|
12 |
+
|
13 |
+
ds_ja = load_dataset("kunishou/oasst1-89k-ja")
|
14 |
+
|
15 |
+
# データフレーム
|
16 |
+
df_ja = pd.DataFrame(ds_ja['train'])
|
17 |
+
|
18 |
+
# 'message_id' をキーにして df_ja と df を結合し、df_ja の列名が優先されるようにします。
|
19 |
+
merged_df = df_ja.merge(df, on='message_id', how='left', suffixes=('', '_y'))
|
20 |
+
|
21 |
+
# 重複した列を削除します。
|
22 |
+
merged_df = merged_df.drop(columns=[col for col in merged_df.columns if col.endswith('_y')])
|
23 |
+
|
24 |
+
# 同じmessage_tree_idでデータをグループ化
|
25 |
+
grouped = merged_df.groupby('message_tree_id')
|
26 |
+
|
27 |
+
def find_longest_chain(group, root_message_id):
|
28 |
+
max_length = 0 # 最長のチェーンの長さを初期化
|
29 |
+
min_toxicity = 2.0 # 最小の毒性を初期化
|
30 |
+
leaf_id = None # 最長のチェーンの末端のメッセージIDを初期化
|
31 |
+
|
32 |
+
# グループ内の各行に対して処理を行う
|
33 |
+
for _, row in group.iterrows():
|
34 |
+
current_id = row['message_id']
|
35 |
+
if current_id == root_message_id:
|
36 |
+
continue # ルートメッセージを処理しない
|
37 |
+
|
38 |
+
chain_length = 0 # チェーンの長さを初期化
|
39 |
+
toxicity = 1.0 # 毒性を初期化
|
40 |
+
|
41 |
+
# ルートメッセージにたどり着くまでチェーンを辿る
|
42 |
+
while current_id != 'nan':
|
43 |
+
chain_length += 1
|
44 |
+
detoxify_data = group.loc[group['message_id'] == current_id, 'detoxify'].iloc[0]
|
45 |
+
toxicity = detoxify_data['toxicity'] if detoxify_data is not None else 1.0 # 毒性がない場合は1.0を代入
|
46 |
+
current_id = group.loc[group['message_id'] == current_id, 'parent_id'].values[0]
|
47 |
+
|
48 |
+
# チェーンが現在の最長のチェーンと同じか長く、毒性が現在の最小の毒性以下の場合
|
49 |
+
if chain_length >= max_length and toxicity <= min_toxicity:
|
50 |
+
max_length = chain_length
|
51 |
+
min_toxicity = toxicity
|
52 |
+
leaf_id = row['message_id'] # 末端のメッセージIDを更新
|
53 |
+
|
54 |
+
return leaf_id # 最長のチェーンの末端のメッセージIDを返す
|
55 |
+
|
56 |
+
leafs = [] # 最長チェーンの末端のメッセージIDを格納するリストを初期化
|
57 |
+
|
58 |
+
|
59 |
+
for _, group in tqdm(grouped):
|
60 |
+
# parent_idがnullのメッセージを見つける(ルートメッセージ)
|
61 |
+
root_message = group[group['parent_id'] == 'nan'].iloc[0]
|
62 |
+
root_message_id = root_message['message_id']
|
63 |
+
|
64 |
+
# 英語かスペイン語か日本語
|
65 |
+
if root_message['lang'] in ['en', 'es', 'ja']:
|
66 |
+
leaf_id = find_longest_chain(group, root_message_id)
|
67 |
+
leafs.append(leaf_id)
|
68 |
+
|
69 |
+
# 最も深いメッセージから辿ってメッセージを作成する関数
|
70 |
+
def create_message_path(message):
|
71 |
+
role = "User" if message['role'] == "prompter" else "Assistant" # メッセージの役割に応じて、UserかAssistantを選択
|
72 |
+
formatted_message = f"{role}:{message['text_ja']}" # 役割とメッセージを連結
|
73 |
+
if pd.isnull(message['parent_id']): # 親メッセージがない場合
|
74 |
+
return [formatted_message]
|
75 |
+
else:
|
76 |
+
parent_messages = merged_df[merged_df['message_id'] == message['parent_id']] # 親メッセージを検索
|
77 |
+
if parent_messages.empty: # 親メッセージが見つからない場合
|
78 |
+
return [formatted_message]
|
79 |
+
parent_message = parent_messages.iloc[0] # 親メッセージを取得
|
80 |
+
# 親メッセージから再帰的にメッセージを作成し、現在のメッセージを追加
|
81 |
+
return create_message_path(parent_message) + [formatted_message]
|
82 |
+
|
83 |
+
result = [] # 結果を格納するリストを初期化
|
84 |
+
for leaf_id in tqdm(leafs): # 進捗状況を表示するためにtqdmを使用
|
85 |
+
leaf_message = merged_df[merged_df['message_id'] == leaf_id].iloc[0] # 末端のメッセージを取得
|
86 |
+
leaf_text = create_message_path(leaf_message) # 末端のメッセージからメッセージのチェーンを作成
|
87 |
+
leaf_json = {}
|
88 |
+
odd = len(leaf_text) % 2
|
89 |
+
if len(leaf_text) <= 3: # メッセージのチェーンが3つ以下の場合
|
90 |
+
leaf_json['instruction'] = leaf_text[0].replace("User:", "", 1)
|
91 |
+
leaf_json['input'] = ""
|
92 |
+
leaf_json['output'] = leaf_text[1].replace("Assistant:", "", 1)
|
93 |
+
else: # メッセージのチェーンが4つ以上の場合
|
94 |
+
instruction = ""
|
95 |
+
for t in leaf_text[0:-2-odd]: # 最後の2つのメッセージを除いて、指示文を作成
|
96 |
+
instruction += t + " "
|
97 |
+
leaf_json['instruction'] = instruction
|
98 |
+
leaf_json['input'] = leaf_text[-2-odd] # 入力メッセージを設定
|
99 |
+
leaf_json['output'] = leaf_text[-1-odd].replace("Assistant:", "", 1) # 出力メッセージを設定
|
100 |
+
result.append(leaf_json) # 結果リスト��JSONを追加
|
101 |
+
|
102 |
+
# JSON データを作成
|
103 |
+
json_data = json.dumps(result, ensure_ascii=False, indent=4)
|
104 |
+
|
105 |
+
# JSON をファイルに保存
|
106 |
+
with open("oasst1_ja.json", "w", encoding="utf-8") as json_file:
|
107 |
+
json_file.write(json_data)
|
train_data/oasst1_ja.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f726ad7113b60c4bfe322884d385c9d9c7da1c99b0e7deeac27dc6934ac3b3c1
|
3 |
+
size 14463337
|