#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright reserved by Yang Hao, Metaverse Developers Association. All rights reserved.
This module provides the conversion of Excel-formatted data into standardized fine-tuning datasets.
Authors: yanghao([email protected])
Date: 2023/09/12 19:23:06
"""
import json
import os
from collections import defaultdict

import openpyxl
import tiktoken
def GetTokenforStr(strText, model='gpt-3.5-turbo-0301'):
    """Return the number of tokens *strText* encodes to for an OpenAI model.

    Args:
        strText: Text to tokenize.
        model: Model name used to select the tiktoken encoding. The
            default preserves the original hard-coded behavior.

    Returns:
        int: Number of tokens produced by encoding the text.
    """
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(strText))
def CheckData(messages):
    """Validate fine-tuning examples against the OpenAI chat format.

    Each example must be a dict containing a non-empty "messages" list;
    every message needs a "role" in (system, user, assistant) and a
    non-empty string "content"; only the keys role/content/name are
    recognized; each example must contain at least one assistant message.

    Args:
        messages: A single example dict or a list of example dicts.

    Returns:
        tuple: ``(True, {})`` when every example is valid, otherwise
        ``(False, errors)`` where ``errors`` is a ``defaultdict(int)``
        mapping error-category name to occurrence count.
    """
    format_errors = defaultdict(int)
    # Accept a bare example dict as well as a list of them.
    if isinstance(messages, dict):
        messages = [messages]
    for ex in messages:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue
        # Fix: the original rebound the name `messages` here, shadowing the
        # function argument mid-iteration; use a distinct local name.
        msg_list = ex.get("messages", None)
        if not msg_list:
            format_errors["missing_messages_list"] += 1
            continue
        for message in msg_list:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1
            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1
            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1
            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1
        if not any(m.get("role", None) == "assistant" for m in msg_list):
            format_errors["example_missing_assistant_message"] += 1
    if format_errors:
        return False, format_errors
    else:
        return True, {}
# Token counting: length of the list returned by encoding.encode
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Returns the number of tokens used by a list of messages.

    Follows the OpenAI cookbook accounting scheme: a fixed per-message
    overhead plus the encoded length of every field value, with a
    per-name correction, plus 3 tokens priming the assistant reply.

    Args:
        messages: A single message dict or a list of message dicts.
        model: Model name that determines the accounting constants.

    Returns:
        int: Estimated prompt token count.

    Raises:
        NotImplementedError: If no accounting constants are known for *model*.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    num_tokens = 0
    # Fix idiom: accept a bare message dict (the original compared
    # `type(messages) != type([1])`; isinstance is the correct check).
    if not isinstance(messages, list):
        messages = [messages]
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            # Values may be non-strings (e.g. None from empty cells);
            # encode their str() form as the original did.
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
def DataFormat(inputPath, OpenAPItype):
    """Convert an Excel sheet of (system, user, assistant) rows into an
    OpenAI fine-tuning JSONL file, report token usage/cost, and run a
    format check on the resulting dataset.

    Row 1 is assumed to be a header; columns 1-3 of each following row
    hold the system / user / assistant contents respectively.

    Args:
        inputPath: Path string, or a file-like object exposing ``.name``
            (e.g. an upload widget's file handle).
        OpenAPItype: Target model family; only 'gpt-3.5*' is processed.

    Returns:
        tuple: ``(outputPath, report_string)`` for gpt-3.5 models;
        implicitly ``None`` for any other model type (behavior kept for
        backward compatibility).
    """
    # 1. Resolve the actual path; uploaded file objects carry `.name`.
    #    (Replaces the original bare try/except, which hid real errors.)
    inputPath = getattr(inputPath, 'name', inputPath)
    book = openpyxl.load_workbook(inputPath)
    sheet = book.active
    maxrow = sheet.max_row
    if OpenAPItype[:7] == 'gpt-3.5':
        # 2. Walk the prepared rows and emit the fine-tuning format.
        print("训练用例条数:{}".format(maxrow - 1))
        messages = []
        # Output path: same directory, named Format_<basename>.jsonl.
        outputPath = "{}/Format_{}.jsonl".format(os.path.dirname(inputPath), os.path.splitext(os.path.basename(inputPath))[0])
        with open(outputPath, 'w', encoding='utf-8') as w:
            for i in range(2, maxrow + 1):
                systemJson = {"role": "system", "content": sheet.cell(i, 1).value}
                userJson = {"role": "user", "content": sheet.cell(i, 2).value}
                AssistantJson = {"role": "assistant", "content": sheet.cell(i, 3).value}
                messagesJson = {"messages": [systemJson, userJson, AssistantJson]}
                messAgeTokens = num_tokens_from_messages(messagesJson, 'gpt-3.5-turbo-0301')
                if messAgeTokens > 4096:
                    # Skip examples exceeding the model context window.
                    print('用例{} tokens数为{},无法发送'.format(i, messAgeTokens))
                else:
                    json.dump(messagesJson, w, ensure_ascii=False)
                    w.write('\n')
                    messages.append(messagesJson)
        messagesTokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo-0301')
        # Cost estimate: $0.008 / 1K tokens across 3 training epochs.
        cost = messagesTokens / 1000 * 0.008 * 3
        ans = '整个微调数据集token总数:{}\n训练费用:经过3个epoch训练,参与训练总token数:{}。\n预计基于该jsonl微调数据的训练成本约为:{:.3f}美元'.format(messagesTokens, messagesTokens * 3, cost)
        print(ans)
        ret, errorsItem = CheckData(messages)
        if not ret:
            ans += "\n\n格式检查:有格式问题!数据错误统计:"
            print("格式检查:有格式问题!数据错误统计:")
            for k, v in errorsItem.items():
                ans += f"\n{k}: {v}"
                print(f"{k}: {v}")
        else:
            ans += "\n格式检查:检查完毕!该微调数据集无格式问题。"
            print("格式检查:检查完毕!该微调数据集无格式问题。")
        return outputPath, ans
if __name__ == "__main__":
    # This module is used as a library; no CLI entry point is defined.
    pass