YangHao520's picture
Update DataFormat.py
5b85931
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright reserved by Yang Hao, Metaverse Developers Association. All rights reserved
This module converts Excel-formatted data into standardized fine-tuning datasets.
Authors: yanghao([email protected])
Date: 2023/09/12 19:23:06
"""
import openpyxl
import os
import json
import tiktoken
from collections import defaultdict
def GetTokenforStr(strText):
    """Return how many tokens ``strText`` occupies.

    The text is tokenised with the tiktoken encoding registered for the
    ``gpt-3.5-turbo-0301`` model; the result is the length of the token list.
    """
    tokenizer = tiktoken.encoding_for_model('gpt-3.5-turbo-0301')
    return len(tokenizer.encode(strText))
def CheckData(messages):
    """Validate fine-tuning examples and count every format problem found.

    Args:
        messages: Either one example dict of the form ``{"messages": [...]}``
            or an iterable of such examples.

    Returns:
        ``(True, {})`` when no problem was found; otherwise
        ``(False, format_errors)`` where ``format_errors`` maps an error name
        to the number of times it occurred. Error names:

        * ``data_type`` - an example, or an entry of its message list, is
          not a dict.
        * ``missing_messages_list`` - ``"messages"`` is absent, empty, or
          not a list/tuple.
        * ``message_missing_key`` - a message lacks ``role`` or ``content``.
        * ``message_unrecognized_key`` - a message has a key other than
          ``role``, ``content`` or ``name``.
        * ``unrecognized_role`` - role is not system/user/assistant.
        * ``missing_content`` - content is empty or not a string.
        * ``example_missing_assistant_message`` - no assistant message.
    """
    format_errors = defaultdict(int)
    # A single example is accepted as shorthand for a one-element dataset.
    examples = [messages] if isinstance(messages, dict) else messages
    for example in examples:
        if not isinstance(example, dict):
            format_errors["data_type"] += 1
            continue
        chat = example.get("messages", None)
        # A string/dict/number here is not a usable message list; report it
        # instead of iterating over it and failing later.
        if not chat or not isinstance(chat, (list, tuple)):
            format_errors["missing_messages_list"] += 1
            continue
        has_assistant = False
        for message in chat:
            if not isinstance(message, dict):
                # Malformed entry: count it rather than raising on .get().
                format_errors["data_type"] += 1
                continue
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1
            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1
            role = message.get("role", None)
            if role not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1
            if role == "assistant":
                has_assistant = True
            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1
        if not has_assistant:
            format_errors["example_missing_assistant_message"] += 1
    if format_errors:
        return False, format_errors
    return True, {}
# Token counts below are the length of the list returned by encode.
def _split_conversations(messages):
    """Group the input into conversations (lists of chat-message dicts)."""
    items = messages if isinstance(messages, list) else [messages]
    wrapped = [
        isinstance(item, dict) and isinstance(item.get("messages"), list)
        for item in items
    ]
    if items and all(wrapped):
        # Fine-tuning records: each {"messages": [...]} is its own conversation.
        return [item["messages"] for item in items]
    # Plain message dict(s): a single conversation.
    return [items]


def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Return the number of tokens used by chat messages.

    Args:
        messages: One of
            * a chat-message dict (``{"role": ..., "content": ...}``),
            * a list of chat-message dicts (one conversation),
            * a fine-tuning record ``{"messages": [...]}``, or
            * a list of such records (several conversations).
            Records are unwrapped so that their individual messages are
            counted; without this the whole list would be stringified and
            tokenised as if it were one message.
        model: Model name selecting the encoding and per-message overhead.
            ``gpt-3.5-turbo`` and ``gpt-4`` are resolved to the ``-0301`` /
            ``-0314`` snapshots after printing a warning.

    Returns:
        The total token count. Each conversation contributes its messages'
        tokens plus 3 tokens for the primed assistant reply.

    Raises:
        NotImplementedError: If ``model`` is not one of the supported names.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    num_tokens = 0
    for conversation in _split_conversations(messages):
        for message in conversation:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Values are stringified first so non-string values
                # (e.g. None) can still be encoded.
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
def DataFormat(inputPath,OpenAPItype):
    """Convert an Excel sheet of chat examples into a fine-tuning JSONL file.

    The active worksheet is read from row 2 onwards (row 1 is presumably a
    header). In each row, column 1 is the system prompt, column 2 the user
    message and column 3 the assistant reply. Every row becomes one
    ``{"messages": [...]}`` record; records above 4096 tokens are reported
    and left out. The output ``Format_<input stem>.jsonl`` is written next
    to the input file.

    Args:
        inputPath: Path of the workbook (``str`` or ``os.PathLike``), or an
            object exposing the path as ``.name``
            (NOTE(review): presumably an uploaded-file wrapper - confirm
            against the caller).
        OpenAPItype: Model name; only values starting with ``gpt-3.5`` are
            processed.

    Returns:
        ``(outputPath, ans)`` where ``ans`` is the human-readable report
        (token totals, estimated cost, format-check result), or ``None``
        when ``OpenAPItype`` does not start with ``gpt-3.5``.
    """
    # 1. Resolve the workbook path and load the sheet holding the user
    #    inputs and GPT outputs.
    if isinstance(inputPath, (str, os.PathLike)):
        # Path objects also have ``.name`` (the basename only), so they must
        # be converted here instead of going through the ``.name`` lookup.
        inputPath = os.fspath(inputPath)
    else:
        inputPath = getattr(inputPath, "name", inputPath)
    book = openpyxl.load_workbook(inputPath)
    sheet = book.active
    maxrow = sheet.max_row

    if not OpenAPItype.startswith('gpt-3.5'):
        # Only the gpt-3.5 family is handled; other values yield no result.
        return None

    # 2. Walk the prepared rows and write them in the fine-tuning format.
    print("训练用例条数:{}".format(maxrow-1))
    messages = []
    messagesTokens = 0
    stem = os.path.splitext(os.path.basename(inputPath))[0]
    # os.path.join keeps the path relative when the input has no directory
    # part; plain "{}/..." formatting would point at the filesystem root.
    outputPath = os.path.join(os.path.dirname(inputPath), "Format_{}.jsonl".format(stem))
    with open(outputPath, 'w', encoding='utf-8') as w:
        for i in range(2, maxrow + 1):
            messagesJson = {"messages": [
                {"role": "system", "content": sheet.cell(i, 1).value},
                {"role": "user", "content": sheet.cell(i, 2).value},
                {"role": "assistant", "content": sheet.cell(i, 3).value},
            ]}
            # Count the message list itself, not the wrapping record, so
            # each message is tokenised individually.
            messAgeTokens = num_tokens_from_messages(messagesJson["messages"], 'gpt-3.5-turbo-0301')
            if messAgeTokens > 4096:
                print('用例{} tokens数为{},无法发送'.format(i,messAgeTokens))
                continue
            json.dump(messagesJson, w, ensure_ascii=False)
            w.write('\n')
            messages.append(messagesJson)
            # Reuse the per-example count instead of re-tokenising everything.
            messagesTokens += messAgeTokens

    # 3. Report totals and the cost estimate (0.008 per 1000 tokens, 3 epochs).
    cost = messagesTokens/1000*0.008*3
    ans = '整个微调数据集token总数:{}\n训练费用:经过3个epoch训练,参与训练总token数:{}。\n预计基于该jsonl微调数据的训练成本约为:{:.3f}美元'.format(messagesTokens,messagesTokens*3,cost)
    print(ans)

    # 4. Run the format check and append its outcome to the report.
    ret, errorsItem = CheckData(messages)
    if not ret:
        ans += "\n\n格式检查:有格式问题!数据错误统计:"
        print("格式检查:有格式问题!数据错误统计:")
        for k, v in errorsItem.items():
            ans += f"\n{k}: {v}"
            print(f"{k}: {v}")
    else:
        ans += "\n格式检查:检查完毕!该微调数据集无格式问题。"
        print("格式检查:检查完毕!该微调数据集无格式问题。")
    return outputPath, ans
# Running this file directly is a deliberate no-op.
if __name__ == "__main__":
    pass