File size: 2,997 Bytes
42303c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
This script is largely copied from the Vicuna repo: https://github.com/lm-sys/FastChat/blob/main/fastchat/data/split_long_conversation.py
We fixed a bug in `split_one_sample`, which previously included long conversations in the processed data. These long conversations are now skipped.
"""
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import transformers
from tqdm import tqdm

# Maps the ShareGPT ``from`` field of a turn to a chat-template role.
_SHAREGPT_ROLE_MAP = {
    "user": "user",
    "human": "user",
    "gpt": "assistant",
    "bing": "assistant",
    "chatgpt": "assistant",
    "bard": "assistant",
    # A system turn belongs to neither side of the dialogue: a leading one is
    # dropped, one anywhere else makes the conversation invalid.
    "system": "system",
}


def _normalize_sharegpt_conversation(convo):
    """Convert one ShareGPT conversation to alternating user/assistant turns.

    Args:
        convo: Sequence of turn dicts carrying a ``"from"`` (sender) key and a
            ``"value"`` (message text) key.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that starts with a
        ``user`` turn and strictly alternates ``user`` / ``assistant``, or
        ``[]`` when the conversation cannot be used: it is empty, contains a
        sender missing from ``_SHAREGPT_ROLE_MAP``, has a turn without
        ``"from"`` / ``"value"``, or does not alternate after an optional
        leading non-user turn has been dropped.
    """
    if not convo:
        return []

    turns = []
    for turn in convo:
        role = _SHAREGPT_ROLE_MAP.get(turn.get("from"))
        content = turn.get("value")
        # Reject instead of raising KeyError so a single malformed record
        # cannot abort the whole dataset pass.
        if role is None or content is None:
            return []
        turns.append({"role": role, "content": content})

    # A conversation may open with a system prompt or a stray model turn;
    # drop that single leading turn so the dialogue starts with the user.
    if turns[0]["role"] != "user":
        turns = turns[1:]

    # Every user turn needs a matching reply.
    if len(turns) % 2 != 0:
        return []

    # Enforce strict alternation. Merely counting user and assistant turns
    # would let out-of-order conversations through with turns silently
    # dropped, pairing a prompt with an unrelated reply.
    for index, turn in enumerate(turns):
        expected_role = "user" if index % 2 == 0 else "assistant"
        if turn["role"] != expected_role:
            return []
    return turns


def shareGPT_pipeline(tokenizer, raw_datasets, overwrite_cache):
    """Filter ShareGPT-style conversations and render them as chat text.

    The pipeline runs four stages on ``raw_datasets``:

    1. drop examples whose ``conversations`` column is empty;
    2. add a boolean ``ids_to_remove`` column flagging conversations that
       cannot be normalized (see ``_normalize_sharegpt_conversation``);
    3. drop the flagged examples;
    4. add a ``text`` column holding the conversation rendered with
       ``tokenizer.apply_chat_template`` (untokenized, no generation prompt).

    Args:
        tokenizer: Object exposing ``apply_chat_template(conversation,
            tokenize=..., add_generation_prompt=...)``.
        raw_datasets: Object exposing ``filter`` and batched ``map`` with a
            ``load_from_cache_file`` keyword and a ``conversations`` column
            (presumably a Hugging Face ``Dataset`` / ``DatasetDict`` --
            confirm against the caller).
        overwrite_cache: When true, every stage is called with
            ``load_from_cache_file=False``.

    Returns:
        The result of the final ``map`` call: the surviving examples with the
        added ``ids_to_remove`` (always ``False``) and ``text`` columns.
    """
    use_cache = not overwrite_cache

    def filter_incorrect_conversations(examples):
        # Batched: ``examples`` maps column names to lists of values.
        flags = [
            not _normalize_sharegpt_conversation(convo)
            for convo in examples["conversations"]
        ]
        return {"ids_to_remove": flags}

    def formatting_prompts_func(examples):
        texts = [
            tokenizer.apply_chat_template(
                _normalize_sharegpt_conversation(convo),
                tokenize=False,
                add_generation_prompt=False,
            )
            for convo in examples["conversations"]
        ]
        return {"text": texts}

    non_empty = raw_datasets.filter(
        lambda example: bool(example["conversations"]),
        load_from_cache_file=use_cache,
    )
    flagged = non_empty.map(
        filter_incorrect_conversations,
        batched=True,
        load_from_cache_file=use_cache,
    )
    valid = flagged.filter(
        lambda example: not example["ids_to_remove"],
        load_from_cache_file=use_cache,
    )
    raw_datasets_preprocessed = valid.map(
        formatting_prompts_func,
        batched=True,
        load_from_cache_file=use_cache,
    )

    return raw_datasets_preprocessed