File size: 6,061 Bytes
5b85931
 
 
 
 
 
 
 
 
 
7b903ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright reserved by Yang Hao, Metaverse Developers Association. All rights reserved

This module converts Excel-formatted data into standardized (JSONL) fine-tuning datasets.

Authors: yanghao([email protected])
Date: 2023/09/12 19:23:06
"""
import openpyxl
import os
import json
import tiktoken
from collections import defaultdict
def GetTokenforStr(strText):
    """Return how many tokens ``strText`` encodes to with the gpt-3.5-turbo-0301 tokenizer."""
    return len(tiktoken.encoding_for_model('gpt-3.5-turbo-0301').encode(strText))
def CheckData(messages):
    """Validate examples against the chat fine-tuning format.

    Each example is expected to look like
    ``{"messages": [{"role": ..., "content": ...}, ...]}``.

    Args:
        messages: A single example dict, or an iterable of example dicts.

    Returns:
        ``(True, {})`` when no problem was found, otherwise
        ``(False, errors)`` where ``errors`` is a ``defaultdict(int)`` that maps
        an error category to its number of occurrences.
    """
    format_errors = defaultdict(int)
    # Accept a single example as a convenience.
    examples = [messages] if isinstance(messages, dict) else messages
    for ex in examples:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        # Distinct name so the function argument is not shadowed.
        ex_messages = ex.get("messages", None)
        # A non-list value (e.g. a string) would otherwise be iterated
        # character by character and crash on ``.get`` below.
        if not ex_messages or not isinstance(ex_messages, (list, tuple)):
            format_errors["missing_messages_list"] += 1
            continue

        for message in ex_messages:
            if not isinstance(message, dict):
                # Non-dict entries cannot be inspected with ``in`` / ``.get``.
                format_errors["message_data_type"] += 1
                continue

            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1

        has_assistant = any(
            isinstance(m, dict) and m.get("role", None) == "assistant"
            for m in ex_messages
        )
        if not has_assistant:
            format_errors["example_missing_assistant_message"] += 1

    if format_errors:
        return False, format_errors
    return True, {}
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Return the number of tokens used by a list of messages.

    Each message costs a fixed per-message overhead plus the encoded length of
    every field value, adjusted by ``tokens_per_name`` when a ``name`` field is
    present. A final 3 tokens are added for reply priming.

    Args:
        messages: A single message dict, a list/tuple of message dicts, or
            fine-tuning examples shaped like ``{"messages": [...]}`` (alone or
            in a list). Examples are expanded into their inner messages, so
            the wrapper's Python repr is not counted as text.
        model: Model name used to choose the tokenizer and overhead values.

    Returns:
        The estimated token count as an int.

    Raises:
        NotImplementedError: If ``model`` has no known overhead values.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")

    # A lone dict is treated as one message (or one example).
    if isinstance(messages, dict):
        messages = [messages]

    # Expand fine-tuning examples ({"messages": [...]}) into their messages.
    flat_messages = []
    for item in messages:
        inner = item.get("messages") if isinstance(item, dict) else None
        if isinstance(inner, (list, tuple)):
            flat_messages.extend(inner)
        else:
            flat_messages.append(item)

    num_tokens = 0
    for message in flat_messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            # Non-string values (e.g. numbers or None from a spreadsheet cell)
            # are counted by their str() form.
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens

def DataFormat(inputPath, OpenAPItype, max_tokens=4096):
    """Convert an Excel sheet into a chat fine-tuning JSONL file and report on it.

    Rows 2..max_row of the active worksheet are read. Row 1 is skipped,
    presumably because it is a header. Columns 1, 2 and 3 become the
    ``system``, ``user`` and ``assistant`` messages of one example. Examples
    above ``max_tokens`` tokens are printed and skipped. The others are
    written to ``Format_<input stem>.jsonl`` in the input file's directory.

    Args:
        inputPath: Path (str or path-like) to the .xlsx file, or an object
            whose ``.name`` holds that path. The object case is assumed to be
            an uploaded-file wrapper; confirm with the caller.
        OpenAPItype: Target model name. Only names starting with ``gpt-3.5``
            are handled.
        max_tokens: Per-example token limit. Longer examples are skipped.

    Returns:
        ``(outputPath, summary)``, where ``summary`` reports the token totals,
        the estimated cost and the format-check results. Returns ``None`` when
        ``OpenAPItype`` is not a gpt-3.5 model.
    """
    # Plain paths are used as-is. pathlib.Path must not go through ``.name``,
    # which would drop the directory. Other objects carry the path in ``.name``.
    if not isinstance(inputPath, (str, os.PathLike)):
        inputPath = getattr(inputPath, "name", inputPath)

    if not OpenAPItype.startswith('gpt-3.5'):
        # Only the gpt-3.5 chat fine-tuning format is implemented.
        return None

    # Step 1: load the file containing user inputs and GPT outputs.
    book = openpyxl.load_workbook(inputPath)
    sheet = book.active
    maxrow = sheet.max_row

    # Step 2: format each row as a fine-tuning example.
    print("训练用例条数:{}".format(maxrow-1))
    stem = os.path.splitext(os.path.basename(inputPath))[0]
    # os.path.join avoids producing "/Format_x.jsonl" (filesystem root)
    # when inputPath has no directory component.
    outputPath = os.path.join(os.path.dirname(inputPath), "Format_{}.jsonl".format(stem))

    messages = []
    messagesTokens = 0
    with open(outputPath, 'w', encoding='utf-8') as w:
        for i in range(2, maxrow + 1):
            messagesJson = {"messages": [
                {"role": "system", "content": sheet.cell(i, 1).value},
                {"role": "user", "content": sheet.cell(i, 2).value},
                {"role": "assistant", "content": sheet.cell(i, 3).value},
            ]}
            # Count the chat messages themselves. Passing the wrapper dict
            # would count the str() of the whole list and inflate the total.
            messAgeTokens = num_tokens_from_messages(messagesJson["messages"], 'gpt-3.5-turbo-0301')
            if messAgeTokens > max_tokens:
                print('用例{} tokens数为{},无法发送'.format(i, messAgeTokens))
                continue
            json.dump(messagesJson, w, ensure_ascii=False)
            w.write('\n')
            messages.append(messagesJson)
            # The dataset total is the sum of the kept examples' counts.
            messagesTokens += messAgeTokens

    # Estimated cost: 0.008 USD per 1K tokens, times 3 epochs (price as
    # hard-coded here; presumably the gpt-3.5 fine-tuning rate at the time).
    cost = messagesTokens / 1000 * 0.008 * 3

    ans = '整个微调数据集token总数:{}\n训练费用:经过3个epoch训练,参与训练总token数:{}。\n预计基于该jsonl微调数据的训练成本约为:{:.3f}美元'.format(messagesTokens, messagesTokens*3, cost)
    print(ans)
    ret, errorsItem = CheckData(messages)
    if not ret:
        ans += "\n\n格式检查:有格式问题!数据错误统计:"
        print("格式检查:有格式问题!数据错误统计:")
        for k, v in errorsItem.items():
            ans += f"\n{k}: {v}"
            print(f"{k}: {v}")
    else:
        ans += "\n格式检查:检查完毕!该微调数据集无格式问题。"
        print("格式检查:检查完毕!该微调数据集无格式问题。")
    return outputPath, ans
if __name__ == "__main__":
    # Running this file directly does nothing. Its functions are presumably
    # used by importing the module.
    ...