#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright reserved by Yang Hao, Metaverse Developers Association. All rights reserved
This module provides the conversion of Excel formatted data into standardized fine-tuning datasets
Authors: yanghao([email protected])
Date: 2023/09/12 19:23:06
"""
import openpyxl
import os
import json
import tiktoken
from collections import defaultdict


def GetTokenforStr(strText):
    """Return the number of tokens in strText using the gpt-3.5-turbo-0301 encoding."""
    encoding = tiktoken.encoding_for_model('gpt-3.5-turbo-0301')
    num_tokens = len(encoding.encode(strText))
    return num_tokens


def CheckData(dataset):
    """Validate a fine-tuning dataset and return (is_valid, error_counts)."""
    format_errors = defaultdict(int)
    if isinstance(dataset, dict):
        dataset = [dataset]
    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue
        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue
        for message in messages:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1
            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1
            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1
            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1
        if not any(message.get("role", None) == "assistant" for message in messages):
            format_errors["example_missing_assistant_message"] += 1
    if format_errors:
        return False, format_errors
    else:
        return True, {}


# Count tokens by measuring the length of the token list returned by encode
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
"""Returns the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo":
print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
elif model == "gpt-4":
print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    num_tokens = 0
    # Accept a single message dict as well as a list of messages
    if not isinstance(messages, list):
        messages = [messages]
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            value = str(value)
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def DataFormat(inputPath, OpenAPItype):
    """Convert an Excel workbook of (system, user, assistant) rows into a JSONL fine-tuning dataset."""
    # Step 1: load the file that contains the user inputs and the GPT outputs.
    # inputPath may be a file object exposing a .name attribute, or a plain path string.
    inputPath = getattr(inputPath, "name", inputPath)
    book = openpyxl.load_workbook(inputPath)
    sheet = book.active
    maxrow = sheet.max_row
    if OpenAPItype.startswith('gpt-3.5'):
        # Step 2: iterate over the prepared dataset and format each row for fine-tuning
        print("Number of training examples: {}".format(maxrow - 1))
        messages = []
        # output path for the formatted JSONL file
        outputPath = "{}/Format_{}.jsonl".format(os.path.dirname(inputPath), os.path.splitext(os.path.basename(inputPath))[0])
        with open(outputPath, 'w', encoding='utf-8') as w:
            for i in range(2, maxrow + 1):
                systemJson = {"role": "system", "content": sheet.cell(i, 1).value}
                userJson = {"role": "user", "content": sheet.cell(i, 2).value}
                AssistantJson = {"role": "assistant", "content": sheet.cell(i, 3).value}
                messagesJson = {"messages": [systemJson, userJson, AssistantJson]}
                messAgeTokens = num_tokens_from_messages(messagesJson, 'gpt-3.5-turbo-0301')
                if messAgeTokens > 4096:
                    print('Example {} has {} tokens and cannot be sent'.format(i, messAgeTokens))
                else:
                    json.dump(messagesJson, w, ensure_ascii=False)
                    w.write('\n')
                    messages.append(messagesJson)
        messagesTokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo-0301')
        cost = messagesTokens / 1000 * 0.008 * 3
        ans = ('Total tokens in the entire fine-tuning dataset: {}\n'
               'Training cost: after 3 epochs of training, the total number of trained tokens is {}.\n'
               'The estimated training cost for this JSONL fine-tuning data is about {:.3f} USD').format(messagesTokens, messagesTokens * 3, cost)
        print(ans)
        ret, errorsItem = CheckData(messages)
        if not ret:
            ans += "\n\nFormat check: format problems found! Data error counts:"
            print("Format check: format problems found! Data error counts:")
            for k, v in errorsItem.items():
                ans += f"\n{k}: {v}"
                print(f"{k}: {v}")
        else:
            ans += "\nFormat check: complete. This fine-tuning dataset has no format problems."
            print("Format check: complete. This fine-tuning dataset has no format problems.")
        return outputPath, ans


if __name__ == "__main__":
    pass
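    # Illustrative usage sketch (not part of the original workflow); "train_data.xlsx"
    # is a hypothetical example path:
    #   outputPath, ans = DataFormat("train_data.xlsx", "gpt-3.5-turbo")
    #   print(outputPath)
    #   print(ans)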