"""
Split long conversations based on certain max length.

Usage: python3 -m fastchat.data.split_long_conversation \
    --in sharegpt_clean.json \
    --out sharegpt_split.json \
    --model-name-or-path $<model-name>
"""
import argparse
import json

import tqdm
import transformers
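
# A sketch of the input record layout this script assumes, inferred from
# split_sample() and filter_invalid_roles() below (the field values are
# illustrative):
#
#   {
#       "id": "abc123",
#       "conversations": [
#           {"from": "human", "value": "..."},
#           {"from": "gpt", "value": "..."},
#       ],
#   }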


def split_sample(sample, start_idx, end_idx):
    """Copy out the messages in [start_idx, end_idx) as a new sample."""
    # A round is a (human, gpt) message pair, so a split must cover an even
    # number of messages.
    assert (end_idx - start_idx) % 2 == 0
    return {
        "id": sample["id"] + "_" + str(start_idx),
        "conversations": sample["conversations"][start_idx:end_idx],
    }
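
# E.g., split_sample({"id": "a", "conversations": msgs}, 0, 4) returns
# {"id": "a_0", "conversations": msgs[0:4]}; the "_0" suffix records the
# message index at which the split begins.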


def split_contents(content, begin, end, tokenizer, max_length):
    """
    Keep the maximum number of conversation rounds within the max token
    length constraint.
    """
    content = content[begin:end]
    new_content = []

    for sample in tqdm.tqdm(content):
        tokenized_lens = []
        conversations = sample["conversations"]
        # Drop a trailing unpaired message so every round is a (human, gpt) pair.
        conversations = conversations[: len(conversations) // 2 * 2]
        for c in conversations:
            # Reserve a few extra tokens per message as a buffer for the role
            # and separator markup added around it at training time.
            length = len(tokenizer(c["value"]).input_ids) + 5
            tokenized_lens.append(length)

        start_idx = 0
        cur_len = 0
        assert len(conversations) % 2 == 0, f"id: {sample['id']}"
        for i in range(0, len(conversations), 2):
            tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
            if cur_len + tmp_len > max_length:
                # Adding this round would overflow: flush the accumulated
                # rounds and start a new split at this round.
                new_content.append(split_sample(sample, start_idx, i))
                start_idx = i
                cur_len = 0
            if i == len(conversations) - 2:
                # Emit whatever remains. This must be an `if`, not an `elif`,
                # so the final round is kept even when it also overflowed.
                new_content.append(split_sample(sample, start_idx, i + 2))

            cur_len += tmp_len

    return new_content
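
# A worked example of the split loop with hypothetical token counts: given
# per-round lengths [800, 900, 700] and max_length=2048, rounds 1-2 accumulate
# to 1700; round 3 would bring the total to 2400 > 2048, so the sample is
# split into "<id>_0" (rounds 1-2) and "<id>_4" (round 3 alone).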


def filter_invalid_roles(content):
    """Drop empty samples and samples whose roles do not alternate human/gpt."""
    roles = ["human", "gpt"]
    new_content = []
    for c in content:
        if len(c["conversations"]) <= 0:
            continue

        valid = True
        for j, s in enumerate(c["conversations"]):
            if s["from"] != roles[j % 2]:
                valid = False
                break

        if valid:
            new_content.append(c)

    return new_content
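
# E.g., a sample whose first message comes from "gpt", or whose roles repeat
# ("human", "human"), is dropped; strict human/gpt alternation is kept.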


def main(args):
    with open(args.in_file, "r") as fin:
        content = json.load(fin)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    new_content = split_contents(
        content, args.begin, args.end, tokenizer, args.max_length
    )
    new_content = filter_invalid_roles(new_content)

    print(f"total: {len(content)}, new: {len(new_content)}")
    with open(args.out_file, "w") as fout:
        json.dump(new_content, fout, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
    parser.add_argument("--begin", type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument("--model-name-or-path", type=str, required=True)
    parser.add_argument("--max-length", type=int, default=2048)
    args = parser.parse_args()
    main(args)