# Kpenciler's picture
# Upload 53 files
# 88435ed verified
import json
import textwrap
from typing import Any
import tiktoken
from neollm.types import Function # , Functions, Messages
def normalize_model_name(model_name: str) -> str:
    """Canonicalize a model name so its token usage can be counted.

    Accepts OpenAI names (gpt-3.5-turbo-0613, gpt-4-0613, ...), fine-tuned
    names (ft:gpt-3.5-turbo:org_id) and Azure names (gpt-35-turbo-0613, ...).

    Args:
        model_name (str): raw model name.

    Returns:
        str: canonical, version-pinned model name.

    Raises:
        ValueError: if the name cannot be mapped to a known model.
    """
    # Reference: https://platform.openai.com/docs/models/gpt-3-5
    alias_to_pinned = (
        ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613"),
        ("gpt-3.5-turbo", "gpt-3.5-turbo-0613"),
        ("gpt-4-32k", "gpt-4-32k-0613"),
        ("gpt-4", "gpt-4-0613"),
    )
    known_models = [
        # gpt-3.5-turbo
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-3.5-turbo-0301",  # Legacy
        # gpt-4
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4-0314",  # Legacy
        "gpt-4-32k-0314",  # Legacy
    ]
    # Unify the Azure spelling ("gpt-35") with the OpenAI spelling ("gpt-3.5").
    normalized = model_name.replace("gpt-35", "gpt-3.5")
    # Resolve aliases / fine-tuned / new names to the newest pinned version.
    if normalized not in known_models:
        for alias, pinned in alias_to_pinned:
            if alias in normalized:
                normalized = pinned
                break
    if normalized in known_models:
        return normalized
    raise ValueError("model_name は以下から選んで.\n" + ",".join(known_models))
def count_tokens(messages: Any | None = None, model_name: str | None = None, functions: Any | None = None) -> int:
    """Count the prompt tokens consumed by a chat-API request.

    Args:
        messages (Messages | None, optional): `messages` payload. Defaults to None.
        model_name (str | None, optional): model name. Defaults to None,
            which is counted as the newest gpt-3.5-turbo.
        functions (Functions | None, optional): `functions` payload. Defaults to None.

    Returns:
        int: total token count.
    """
    # BUG FIX: the old fallback "cl100k_base" is an *encoding* name, not a
    # model name, so normalize_model_name() unconditionally raised ValueError
    # whenever model_name was None. Fall back to a real model alias instead.
    model_name = normalize_model_name(model_name or "gpt-3.5-turbo")
    num_tokens = _count_messages_tokens(messages, model_name)
    if functions is not None:
        num_tokens += _count_functions_tokens(functions, model_name)
    return num_tokens
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def _count_messages_tokens(messages: Any | None, model_name: str) -> int:
    """Count the tokens consumed by a chat `messages` payload.

    Based on the OpenAI cookbook recipe
    (How_to_count_tokens_with_tiktoken.ipynb).

    Args:
        messages (Messages | None): messages passed to the chat API.
        model_name (str): normalized model name, e.g. "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613", "gpt-4-0613", "gpt-4-32k-0613".

    Returns:
        int: total token count.
    """
    if messages is None:
        return 0
    encoding = tiktoken.encoding_for_model(model_name)  # cl100k_base family
    # Per-message overhead differs for the legacy 0301 snapshot.
    if model_name == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        per_message = 4
        per_name = -1  # if there's a name, the role is omitted
    else:
        per_message = 3
        per_name = 1
    total = 3  # every reply is primed with <|start|>assistant<|message|>
    for message in messages:
        total += per_message
        for key, value in message.items():
            if isinstance(value, str):
                total += len(encoding.encode(value))
            if key == "name":
                total += per_name
    return total
# https://gist.github.com/CGamesPlay/dd4f108f27e2eec145eedf5c717318f5
def _count_functions_tokens(functions: Any, model_name: str | None = None) -> int:
    """Count the tokens consumed by a `functions` payload.

    Based on: https://gist.github.com/CGamesPlay/dd4f108f27e2eec145eedf5c717318f5

    Args:
        functions (Functions): `functions` payload for the chat API.
        model_name (str | None, optional): model name. Defaults to None,
            which falls back to the cl100k_base encoding.

    Returns:
        int: token count.
    """
    # BUG FIX: "cl100k_base" is an encoding name; encoding_for_model() only
    # accepts model names and raised KeyError for it. Use get_encoding() for
    # the fallback instead.
    if model_name is None:
        encoding_model = tiktoken.get_encoding("cl100k_base")
    else:
        encoding_model = tiktoken.encoding_for_model(model_name)
    # 3 = fixed overhead for the tools section framing.
    num_tokens = 3 + len(encoding_model.encode(__functions2string(functions)))
    return num_tokens
# Helpers for rendering `functions` as a string -----------------------------------------------------
def __functions2string(functions: Any) -> str:
    """Render a `functions` payload the way the API serializes it.

    Args:
        functions (Functions): `functions` payload for the chat API.

    Returns:
        str: TypeScript-namespace-style rendering of the functions.
    """
    # BUG FIX: the function declarations belong *inside* the namespace block;
    # previously they were concatenated after the closing brace, which does
    # not match the serialization the token estimate is modeled on.
    header = "# Tools\n\n## functions\n\nnamespace functions {\n\n"
    footer = "} // namespace functions\n"
    body = "".join(__function2string(function) for function in functions)
    return header + body + footer
def __function2string(function: Function) -> str:
    """Render one function entry as a TypeScript-style type declaration.

    Args:
        function (Function): one element of the `functions` payload.

    Returns:
        str: string rendering of the function.
    """
    params = __format_object(function["parameters"])
    # A parameterless function renders as `() => any`, otherwise `(_: {...}) => any`.
    arg_string = "" if params is None else "_: " + params
    return f"// {function['description']}\ntype {function['name']} = (" + arg_string + ") => any;\n\n"
def __format_object(schema: dict[str, Any], indent: int = 0) -> str | None:
    """Render a JSON-schema object as a TypeScript-style object literal.

    Args:
        schema (dict[str, Any]): JSON schema node of type "object".
        indent (int, optional): current indentation depth. Defaults to 0.

    Returns:
        str | None: rendered object, "object" for an opaque additionalProperties
            schema, or None when there is nothing to render.
    """
    if "properties" not in schema or len(schema["properties"]) == 0:
        if schema.get("additionalProperties", False):
            return "object"
        return None
    result = "{\n"
    for key, value in dict(schema["properties"]).items():
        # value <- resolve_ref(value)
        value_rendered = __format_schema(value, indent + 1)
        if value_rendered is None:
            continue
        # BUG FIX: `description` was only assigned inside the if-branch, so a
        # property without a description reused the previous property's text
        # (or raised NameError when the first property had none).
        description = ""
        if "description" in value:
            description = "".join(
                " " * indent + f"// {description_i}\n"
                for description_i in textwrap.dedent(value["description"]).strip().split("\n")
            )
        # optional marker: properties absent from "required" get a "?".
        optional = "" if key in schema.get("required", {}) else "?"
        # default value, rendered as a trailing comment
        default_comment = "" if "default" not in value else f" // default: {__format_default(value)}"
        result += description + " " * indent + f"{key}{optional}: {value_rendered},{default_comment}\n"
    result += (" " * (indent - 1)) + "}"
    return result
# NOTE: $ref resolution from the reference gist; kept for reference, not currently used.
# def resolve_ref(schema):
# if schema.get("$ref") is not None:
# ref = schema["$ref"][14:]
# schema = json_schema["definitions"][ref]
# return schema
def __format_schema(schema: dict[str, Any], indent: int) -> str | None:
    """Render one JSON-schema node as a TypeScript-style type expression.

    Args:
        schema (dict[str, Any]): JSON schema node.
        indent (int): current indentation depth.

    Returns:
        str | None: rendered type expression.

    Raises:
        ValueError: for an unsupported schema type.
    """
    # schema <- resolve_ref(schema)
    if "enum" in schema:
        return __format_enum(schema)
    schema_type = schema["type"]
    if schema_type == "object":
        return __format_object(schema, indent)
    if schema_type in {"integer", "number"}:
        return "number"
    if schema_type == "string":
        return "string"
    if schema_type == "array":
        return str(__format_schema(schema["items"], indent)) + "[]"
    raise ValueError("unknown schema type " + schema_type)
def __format_enum(schema: dict[str, Any]) -> str:
    """Render an enum schema as a union of JSON literals, e.g. "A" | "B" | "C"."""
    literals = [json.dumps(choice, ensure_ascii=False) for choice in schema["enum"]]
    return " | ".join(literals)
def __format_default(schema: dict[str, Any]) -> str:
    """Render a schema's default value for a "// default: ..." comment."""
    default = schema["default"]
    if schema["type"] == "number" and float(default).is_integer():
        # Integral number defaults render with a decimal point: 0 -> "0.0".
        return f"{default:.1f}"
    return str(default)