File size: 6,003 Bytes
2d8da09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processing data for megatron pretraining.
Example to create dataset used for training attribute prediction model:
    python preprocessing.py --input_file=dataset/2023-04-12_oasst_all.trees.jsonl --output_file_prefix=oasst_output --mask_role=User --type=TEXT_TO_VALUE --split_ratio=0.95 --seed=10
Example to create dataset used for attribute conditioned SFT model:
    python preprocessing.py --input_file=dataset/2023-04-12_oasst_all.trees.jsonl --output_file_prefix=oasst_output --mask_role=User --type=VALUE_TO_TEXT --split_ratio=0.95 --seed=10
"""
import json
import random
import fire
# All the keys ['spam', 'lang_mismatch', 'pii', 'not_appropriate', 'hate_speech', 'sexual_content', 'quality', 'toxicity', 'humor', 'creativity', 'violence', 'fails_task', 'helpfulness', 'political_content', 'moral_judgement']
# Label names that encode_labels() emits, in this order.
selected_keys = [
    'quality',
    'toxicity',
    'humor',
    'creativity',
    'violence',
    'helpfulness',
    'not_appropriate',
    'hate_speech',
    'sexual_content',
    'fails_task',
    'political_content',
    'moral_judgement',
]

# Filled in while parsing: maps a label name to the set of values seen for it.
label_values = {}

# Label values are multiplied by (likert_scale - 1) and rounded when encoded.
likert_scale = 5

# System prompt stored with every emitted conversation record.
system_prompt = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
)
def encode_labels(labels):
    """Turn a labels mapping into a comma-separated ``key:score`` string.

    Only keys listed in ``selected_keys`` are emitted, in that order. Each
    label's ``'value'`` entry is multiplied by ``likert_scale - 1`` and rounded
    to the nearest integer.

    Args:
        labels (dict): maps a label name to a dict holding a numeric ``'value'``.

    Returns:
        str: for example ``'quality:3,toxicity:0'``; the empty string when none
        of the selected keys is present.
    """
    top_score = likert_scale - 1
    return ','.join(
        f'{name}:{round(labels[name]["value"] * top_score)}' for name in selected_keys if name in labels
    )
def parse_conversations(tree_obj):
    """Recursively collect every conversation thread rooted at ``tree_obj``.

    Args:
        tree_obj (dict): current conversation node. Either a tree wrapper that
            holds the root message under ``'prompt'``, or a message node with
            ``'text'`` and ``'role'`` (plus optional ``'labels'``, ``'lang'``
            and ``'replies'``).

    Returns:
        list: one list of turns per path from the current node down to a leaf;
        every thread starts with the turn built from the current node.
        ``[[]]`` is returned when ``tree_obj`` is neither a wrapper nor a
        message node.

    Raises:
        ValueError: if the node's role is neither ``'prompter'`` nor
            ``'assistant'``.
    """
    if 'prompt' in tree_obj:
        node = tree_obj['prompt']
    elif 'text' in tree_obj and 'role' in tree_obj:
        node = tree_obj
    else:
        return [[]]

    if node['role'] == 'prompter':
        role = 'User'
    elif node['role'] == 'assistant':
        role = 'Assistant'
    else:
        raise ValueError(f'unknown role {node["role"]}')

    turn = {'value': node['text'], 'from': role}
    if 'labels' in node:
        # Only the encoded label string is kept; the raw human labels are not stored on the turn.
        turn['label'] = encode_labels(node['labels'])
    if 'lang' in node:
        # Keep only the part before the first '-', e.g. 'pt-BR' -> 'pt'.
        turn['lang'] = node['lang'].split('-')[0]
        lang_item = f'lang:{turn["lang"]}'
        # A node may carry 'lang' without 'labels'; treat the missing label as
        # empty instead of failing with KeyError.
        existing = turn.get('label', '')
        turn['label'] = f'{existing},{lang_item}' if existing else lang_item
        label_values.setdefault('lang', set()).add(turn['lang'])

    # A node without a 'replies' entry is treated as a leaf.
    sub_threads = []
    for reply in node.get('replies') or ():
        sub_threads.extend(parse_conversations(reply))

    if not sub_threads:
        return [[turn]]
    return [[turn] + thread for thread in sub_threads]
def get_data_records(objs, mask_role, type):
    """Flatten conversation trees into one record per multi-turn thread.

    Args:
        objs: iterable of tree objects; each is passed to
            ``parse_conversations`` and must provide ``'message_tree_id'``.
        mask_role: stored under ``'mask'`` in every record.
        type: stored under ``'type'`` in every record.

    Returns:
        list: dicts with the keys ``'conversations'``, ``'tree_id'``,
        ``'system'``, ``'mask'`` and ``'type'``.
    """
    records = []
    for tree in objs:
        for thread in parse_conversations(tree):
            # Threads with at most one turn are dropped.
            if len(thread) > 1:
                records.append(
                    {
                        'conversations': thread,
                        'tree_id': tree['message_tree_id'],
                        'system': system_prompt,
                        'mask': mask_role,
                        'type': type,
                    }
                )
    return records
def main(
    input_file='2023-04-12_oasst_all.trees.jsonl',
    output_file_prefix='oasst_output',
    mask_role='User',
    type='TEXT_TO_VALUE',
    split_ratio=0.95,
    seed=10,
):
    """Convert a conversation-tree JSONL file into train/validation JSONL files.

    The trees are shuffled with ``seed``, split by ``split_ratio``, flattened
    with ``get_data_records`` and written to
    ``{output_file_prefix}_train.jsonl`` and ``{output_file_prefix}_val.jsonl``.
    The label values collected during parsing are printed at the end.

    Args:
        input_file: path to the input file, one JSON tree per line.
        output_file_prefix: prefix of the two output files.
        mask_role: value stored in each record's ``'mask'`` field.
        type: value stored in each record's ``'type'`` field. The name shadows
            the builtin but is kept because it is part of the command-line
            interface.
        split_ratio: fraction of the trees assigned to the training split.
        seed: seed for the shuffle that precedes the split.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. an empty trailing line) instead of letting
        # json.loads fail on them.
        all_objs = [json.loads(line) for line in f if line.strip()]

    random.seed(seed)
    random.shuffle(all_objs)
    train_num = int(len(all_objs) * split_ratio)

    # Both splits are parsed before anything is written, train first.
    train_records = get_data_records(all_objs[:train_num], mask_role, type)
    val_records = get_data_records(all_objs[train_num:], mask_role, type)

    for split_name, records in (('train', train_records), ('val', val_records)):
        with open(f'{output_file_prefix}_{split_name}.jsonl', 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)

    for label, values in label_values.items():
        print(f'{label} values: {sorted(values)}')
# Script entry point: hand main() to Fire so its parameters become command-line flags.
if __name__ == "__main__":
    fire.Fire(main)
|