#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright reserved by Yang Hao, Metaverse Developers Association. All rights reserved

This module provides the conversion of Excel formatted data into
standardized fine-tuning datasets.

Authors: yanghao(yanghao31@baidu.com)
Date: 2023/09/12 19:23:06
"""
import openpyxl
import os
import json
import tiktoken
from collections import defaultdict

# Examples whose estimated token count exceeds this value are not written.
_MAX_EXAMPLE_TOKENS = 4096
# Constants used by the training-cost estimate printed by DataFormat.
_TRAIN_PRICE_PER_1K_TOKENS = 0.008
_TRAIN_EPOCHS = 3


def GetTokenforStr(strText: str) -> int:
    """Return the number of tokens in ``strText``.

    The text is encoded with the tiktoken encoding registered for
    ``gpt-3.5-turbo-0301`` and the length of the encoded list is returned.
    """
    encoding = tiktoken.encoding_for_model('gpt-3.5-turbo-0301')
    return len(encoding.encode(strText))


def CheckData(messages):
    """Validate fine-tuning examples and tally format errors.

    Args:
        messages: A list of examples, or a single example dict (which is
            wrapped in a list). Each example is expected to be a dict with
            a non-empty ``"messages"`` list of chat-message dicts.

    Returns:
        ``(True, {})`` when no problem was found, otherwise
        ``(False, errors)`` where ``errors`` maps an error name to the
        number of times it occurred. The error names are ``data_type``,
        ``missing_messages_list``, ``message_missing_key``,
        ``message_unrecognized_key``, ``unrecognized_role``,
        ``missing_content`` and ``example_missing_assistant_message``.
    """
    format_errors = defaultdict(int)
    dataset = [messages] if isinstance(messages, dict) else messages

    for example in dataset:
        if not isinstance(example, dict):
            format_errors["data_type"] += 1
            continue

        chat = example.get("messages", None)
        if not chat:
            format_errors["missing_messages_list"] += 1
            continue

        for message in chat:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1
            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1
            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1
            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1

        # Every example needs at least one assistant turn to learn from.
        if not any(m.get("role", None) == "assistant" for m in chat):
            format_errors["example_missing_assistant_message"] += 1

    if format_errors:
        return False, format_errors
    return True, {}


def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Return the number of tokens used by a list of messages.

    The count is the sum, over all items, of a fixed per-message overhead
    plus the encoded length of every value in the item, plus 3 tokens for
    the reply priming.

    Args:
        messages: A list of dicts, or a single dict (which is wrapped in a
            list). Every value is converted with ``str()`` before being
            encoded, so a fine-tuning example such as
            ``{"messages": [...]}`` is counted on the string form of its
            nested list; in that case the result is an approximation.
        model: Model name. ``"gpt-3.5-turbo"`` and ``"gpt-4"`` are mapped
            to ``"gpt-3.5-turbo-0301"`` and ``"gpt-4-0314"`` with a
            printed warning.

    Raises:
        NotImplementedError: If ``model`` is not one of the four names
            handled above.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    if model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return num_tokens_from_messages(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(
            f"num_tokens_from_messages() is not implemented for model {model}. "
            "See https://github.com/openai/openai-python/blob/main/chatml.md "
            "for information on how messages are converted to tokens."
        )

    if not isinstance(messages, list):
        messages = [messages]

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            # Values may be None or nested lists; stringify before encoding.
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def DataFormat(inputPath, OpenAPItype):
    """Convert an Excel workbook into a JSONL fine-tuning dataset.

    Rows 2..max_row of the active sheet are read (row 1 is skipped,
    presumably a header row). Column 1 becomes the system message, column 2
    the user message and column 3 the assistant message. Each row whose
    estimated token count is at most 4096 is written as one JSON line to
    ``Format_<input stem>.jsonl`` next to the input file; larger rows are
    reported on stdout and left out.

    Args:
        inputPath: Path of the workbook (``str`` or ``os.PathLike``), or an
            object exposing the path as its ``name`` attribute (e.g. an
            uploaded temporary file).
        OpenAPItype: Model name. Only names starting with ``gpt-3.5`` are
            handled.

    Returns:
        ``(outputPath, ans)`` where ``ans`` is a report containing the token
        total, the estimated training cost and the format-check result.
        ``None`` when ``OpenAPItype`` does not start with ``gpt-3.5``.
    """
    # Step 1: load the workbook containing the user inputs and GPT outputs.
    if isinstance(inputPath, os.PathLike):
        inputPath = os.fspath(inputPath)
    else:
        # File-like upload objects expose their path as ``.name``.
        inputPath = getattr(inputPath, "name", inputPath)

    book = openpyxl.load_workbook(inputPath)
    sheet = book.active
    maxrow = sheet.max_row

    if not OpenAPItype.startswith('gpt-3.5'):
        return None

    # Step 2: walk the prepared dataset and format it for fine-tuning.
    print("训练用例条数:{}".format(maxrow-1))
    messages = []
    # os.path.join keeps the output beside the input even when the input
    # path has no directory part (plain "{}/..." formatting would then
    # point at the filesystem root).
    stem = os.path.splitext(os.path.basename(inputPath))[0]
    outputPath = os.path.join(os.path.dirname(inputPath), "Format_{}.jsonl".format(stem))

    with open(outputPath, 'w', encoding='utf-8') as w:
        for i in range(2, maxrow + 1):
            example = {
                "messages": [
                    {"role": "system", "content": sheet.cell(i, 1).value},
                    {"role": "user", "content": sheet.cell(i, 2).value},
                    {"role": "assistant", "content": sheet.cell(i, 3).value},
                ]
            }
            example_tokens = num_tokens_from_messages(example, 'gpt-3.5-turbo-0301')
            if example_tokens > _MAX_EXAMPLE_TOKENS:
                print('用例{} tokens数为{},无法发送'.format(i, example_tokens))
                continue
            json.dump(example, w, ensure_ascii=False)
            w.write('\n')
            messages.append(example)

    total_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo-0301')
    cost = total_tokens / 1000 * _TRAIN_PRICE_PER_1K_TOKENS * _TRAIN_EPOCHS
    ans = '整个微调数据集token总数:{}\n训练费用:经过3个epoch训练,参与训练总token数:{}。\n预计基于该jsonl微调数据的训练成本约为:{:.3f}美元'.format(total_tokens, total_tokens * _TRAIN_EPOCHS, cost)
    print(ans)

    ok, error_counts = CheckData(messages)
    if not ok:
        ans += "\n\n格式检查:有格式问题!数据错误统计:"
        print("格式检查:有格式问题!数据错误统计:")
        for k, v in error_counts.items():
            ans += f"\n{k}: {v}"
            print(f"{k}: {v}")
    else:
        ans += "\n格式检查:检查完毕!该微调数据集无格式问题。"
        print("格式检查:检查完毕!该微调数据集无格式问题。")
    return outputPath, ans


if __name__ == "__main__":
    pass