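"""Compute token-length statistics for instruction-tuning data.

The script inspects the first instance of a JSON lines file to decide whether
it holds prompt/completion pairs or chat-style "messages" lists, tokenizes the
text, and prints summary statistics as JSON (optionally saving them with
--save_path).

Example invocation (the script filename here is assumed):

    python get_statistics.py --data_path data/train.jsonl --save_path stats.json
"""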
import argparse
import json

import numpy as np
import pandas as pd
import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer


def get_statistics_for_messages_data(data_path, tokenizer_name_or_path):
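    """Summarize token-length statistics for chat-style "messages" data.

    Each instance is expected to carry an "id" field and a list of messages
    with "role" and "content" keys.
    """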
    # Load the dataset from a JSON lines file.
    dataset = load_dataset("json", data_files={"train": data_path})
    # Load the tokenizer used to measure lengths; the slow tokenizer is used
    # so counts are measured consistently across both data formats.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=False)
    # get statistics
    num_instances = len(dataset["train"])
    num_of_turns = [len(instance["messages"]) for instance in dataset["train"]]
    user_prompt_lengths = []
    assistant_response_lengths = []
    instance_lengths = []
    for instance in tqdm.tqdm(dataset["train"], desc="Processing instances"):
        instance_length = 0
        for message in instance["messages"]:
            if message["role"] == "user":
                user_prompt_lengths.append(len(tokenizer(message["content"], truncation=False, add_special_tokens=False)["input_ids"]))
                instance_length += user_prompt_lengths[-1]
            elif message["role"] == "assistant":
                assistant_response_lengths.append(len(tokenizer(message["content"], truncation=False, add_special_tokens=False)["input_ids"]))
                instance_length += assistant_response_lengths[-1]
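            # Messages with any other role (e.g., "system") are not counted.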
        instance_lengths.append(instance_length)

    # Record the ids of the 100 longest instances, longest first.
    top_100_longest_instances = np.argsort(instance_lengths)[-100:][::-1].tolist()
    top_100_longest_instances = [dataset["train"][i]["id"] for i in top_100_longest_instances]

    instance_lengths_array = np.array(instance_lengths)
    result = {
        "num_instances": num_instances,
        "turns_summary": pd.Series(num_of_turns).describe(),
        "user_prompt_lengths_summary": pd.Series(user_prompt_lengths).describe(),
        "assistant_response_lengths_summary": pd.Series(assistant_response_lengths).describe(),
        "total_lengths_summary": pd.Series(instance_lengths).describe(),
        "num_instances_with_total_length_gt_512": np.sum(instance_lengths_array > 512),
        "num_instances_with_total_length_gt_768": np.sum(instance_lengths_array > 768),
        "num_instances_with_total_length_gt_1024": np.sum(instance_lengths_array > 1024),
        "num_instances_with_total_length_gt_1536": np.sum(instance_lengths_array > 1536),
        "num_instances_with_total_length_gt_2048": np.sum(instance_lengths_array > 2048),
        "num_instances_with_total_length_gt_4096": np.sum(instance_lengths_array > 4096),
        "top_100_longest_instances": top_100_longest_instances,
    }
          
    # Convert pandas/numpy values so the result is JSON-serializable.
    for key, value in result.items():
        if isinstance(value, pd.Series):
            result[key] = value.to_dict()
        elif isinstance(value, np.ndarray):
            result[key] = value.tolist()
        elif isinstance(value, np.integer):
            result[key] = int(value)
    return result


def get_statistics_for_prompt_completion_data(data_path, tokenizer_name_or_path):
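    """Summarize token-length statistics for prompt/completion data."""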
    # Load the dataset from a JSON lines file.
    dataset = load_dataset("json", data_files={"train": data_path})
    prompts = [instance["prompt"] for instance in dataset["train"]]
    completions = [instance["completion"] for instance in dataset["train"]]
    # Batch-tokenize all prompts and completions; use the slow tokenizer for
    # consistency with get_statistics_for_messages_data.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=False)
    tokenized_prompts = tokenizer(prompts, truncation=False, add_special_tokens=False)
    tokenized_completions = tokenizer(completions, truncation=False, add_special_tokens=False)
    # get statistics
    num_instances = len(dataset["train"])
    prompt_lengths = [len(tokenized_prompts["input_ids"][i]) for i in range(num_instances)]
    completion_lengths = [len(tokenized_completions["input_ids"][i]) for i in range(num_instances)]
    prompt_completion_lengths = [prompt_lengths[i] + completion_lengths[i] for i in range(num_instances)]

    prompt_lengths_array = np.array(prompt_lengths)
    completion_lengths_array = np.array(completion_lengths)
    prompt_completion_lengths_array = np.array(prompt_completion_lengths)
    result = {
        "num_instances": num_instances,
        "prompt_lengths_summary": pd.Series(prompt_lengths).describe(),
        "completion_lengths_summary": pd.Series(completion_lengths).describe(),
        "prompt_completion_lengths_summary": pd.Series(prompt_completion_lengths).describe(),
        "num_instances_with_prompt_length_gt_512": np.sum(prompt_lengths_array > 512),
        "num_instances_with_completion_length_gt_512": np.sum(completion_lengths_array > 512),
        "num_instances_with_prompt_completion_length_gt_512": np.sum(prompt_completion_lengths_array > 512),
        "num_instances_with_completion_length_gt_768": np.sum(completion_lengths_array > 768),
        "num_instances_with_prompt_completion_length_gt_1024": np.sum(prompt_completion_lengths_array > 1024),
    }

    # Convert pandas/numpy values so the result is JSON-serializable.
    for key, value in result.items():
        if isinstance(value, pd.Series):
            result[key] = value.to_dict()
        elif isinstance(value, np.ndarray):
            result[key] = value.tolist()
        elif isinstance(value, np.integer):
            result[key] = int(value)

    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama_models/7B",
        help="Tokenizer used to measure token lengths; the default points to a cluster-local LLaMA 7B copy.",
    )
    parser.add_argument("--save_path", type=str, help="Path to save the statistics.")
    args = parser.parse_args()

    # Peek at the first instance to infer the data format.
    with open(args.data_path, "r") as f:
        sample = json.loads(f.readline())
    if "prompt" in sample:
        statistics = get_statistics_for_prompt_completion_data(args.data_path, args.tokenizer_name_or_path)
    elif "messages" in sample:
        statistics = get_statistics_for_messages_data(args.data_path, args.tokenizer_name_or_path)
    else:
        raise ValueError("Invalid data format: instances should contain either a 'prompt' field or a 'messages' field.")

    print(json.dumps(statistics, indent=4))

    if args.save_path is not None:
        with open(args.save_path, "w") as f:
            json.dump(statistics, f, indent=4)