# Token-counting utilities for OpenAI / Azure OpenAI chat models.
import json | |
import textwrap | |
from typing import Any | |
import tiktoken | |
from neollm.types import Function # , Functions, Messages | |
def normalize_model_name(model_name: str) -> str:
    """Canonicalize a model name for token counting.

    Args:
        model_name (str): model name, e.g.
            OpenAI: gpt-3.5-turbo-0613, gpt-3.5-turbo-16k-0613, gpt-4-0613, gpt-4-32k-0613
            OpenAI fine-tuned: ft:gpt-3.5-turbo:org_id
            Azure: gpt-35-turbo-0613, gpt-35-turbo-16k-0613, gpt-4-0613, gpt-4-32k-0613

    Returns:
        str: canonical (version-pinned) model name

    Raises:
        ValueError: if the model name cannot be mapped to a known version
    """
    # See: https://platform.openai.com/docs/models/gpt-3-5
    NEWEST_MAP = [
        ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613"),
        ("gpt-3.5-turbo", "gpt-3.5-turbo-0613"),
        ("gpt-4-32k", "gpt-4-32k-0613"),
        ("gpt-4", "gpt-4-0613"),
    ]
    ALL_VERSION_MODELS = [
        # gpt-3.5-turbo
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-3.5-turbo-0301",  # Legacy
        # gpt-4
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4-0314",  # Legacy
        "gpt-4-32k-0314",  # Legacy
    ]
    # Unify the Azure spelling ("gpt-35") with the OpenAI spelling ("gpt-3.5").
    canonical = model_name.replace("gpt-35", "gpt-3.5")
    if canonical not in ALL_VERSION_MODELS:
        # Map aliases / fine-tuned names onto a pinned version. First match
        # wins, so the more specific "-16k"/"-32k" aliases are listed first.
        canonical = next(
            (pinned for alias, pinned in NEWEST_MAP if alias in canonical),
            canonical,
        )
    if canonical in ALL_VERSION_MODELS:
        return canonical
    raise ValueError("model_name は以下から選んで.\n" + ",".join(ALL_VERSION_MODELS))
def count_tokens(messages: Any | None = None, model_name: str | None = None, functions: Any | None = None) -> int:
    """Count the tokens a chat-completion request will consume.

    Args:
        messages (Messages | None, optional): messages payload of the API request. Defaults to None.
        model_name (str | None, optional): model name; when None, falls back to
            "gpt-3.5-turbo". Defaults to None.
        functions (Functions | None, optional): functions payload of the API request. Defaults to None.

    Returns:
        int: total token count

    Raises:
        ValueError: if model_name is not a known model.
    """
    # BUG FIX: the previous fallback passed "cl100k_base" (an *encoding* name,
    # not a model name) to normalize_model_name(), which matches nothing and
    # always raised ValueError when model_name was omitted. Fall back to a
    # real model alias instead.
    model_name = normalize_model_name(model_name or "gpt-3.5-turbo")
    num_tokens = _count_messages_tokens(messages, model_name)
    if functions is not None:
        num_tokens += _count_functions_tokens(functions, model_name)
    return num_tokens
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def _count_messages_tokens(messages: Any | None, model_name: str) -> int:
    """Count the tokens consumed by a messages payload.

    Args:
        messages (Messages | None): messages payload of the API request
        model_name (str): model name, one of
            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314"
            "gpt-4-0613", "gpt-4-32k-0613", "gpt-3.5-turbo", "gpt-4"

    Returns:
        int: total token count
    """
    if messages is None:
        return 0
    encoding = tiktoken.encoding_for_model(model_name)  # cl100k_base family
    # Per-message overhead differs for the legacy 0301 snapshot.
    if model_name == "gpt-3.5-turbo-0301":
        per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        per_name = -1  # if there's a name, the role is omitted
    else:
        per_message = 3
        per_name = 1
    total = 3  # every reply is primed with <|start|>assistant<|message|>
    for message in messages:
        total += per_message
        for key, value in message.items():
            if not isinstance(value, str):
                continue
            total += len(encoding.encode(value))
            if key == "name":
                total += per_name
    return total
# https://gist.github.com/CGamesPlay/dd4f108f27e2eec145eedf5c717318f5
def _count_functions_tokens(functions: Any, model_name: str | None = None) -> int:
    """Count the tokens consumed by a functions payload.

    Args:
        functions (Functions): functions payload of the API request
        model_name (str | None, optional): model name. Defaults to None.

    Returns:
        int: token count
    """
    # NOTE(review): "cl100k_base" is an encoding name, not a model name, so
    # encoding_for_model may reject the fallback — in practice count_tokens
    # always supplies a normalized model name. Verify before relying on it.
    encoding = tiktoken.encoding_for_model(model_name or "cl100k_base")
    rendered = __functions2string(functions)
    # 3 tokens of fixed overhead + the rendered declaration text.
    return 3 + len(encoding.encode(rendered))
# Helpers for rendering functions as a string ---------------------------------------------------------------------
def __functions2string(functions: Any) -> str:
    """Render a functions payload to the string form used for token counting.

    Args:
        functions (Functions): functions payload of the API request

    Returns:
        str: rendered string (TypeScript-like namespace of function types)
    """
    # BUG FIX: the function declarations were previously appended AFTER the
    # closing "} // namespace functions", i.e. outside the namespace block.
    # The reference rendering places them inside it.
    header = "# Tools\n\n## functions\n\nnamespace functions {\n\n"
    footer = "} // namespace functions\n"
    return header + "".join(__function2string(function) for function in functions) + footer
def __function2string(function: Function) -> str:
    """Render one function definition as a TypeScript-like type declaration.

    Args:
        function (Function): one element of the functions payload

    Returns:
        str: rendered declaration
    """
    params = __format_object(function["parameters"])
    # Parameterless functions render as "()", otherwise "(_: {...})".
    arg_part = "" if params is None else "_: " + params
    return f"// {function['description']}\ntype {function['name']} = (" + arg_part + ") => any;\n\n"
def __format_object(schema: dict[str, Any], indent: int = 0) -> str | None:
    """Render an object schema as a TypeScript-like object literal.

    Args:
        schema (dict[str, Any]): JSON-schema node of type "object"
        indent (int, optional): current indentation depth. Defaults to 0.

    Returns:
        str | None: rendered object literal, "object" for free-form objects
            (additionalProperties), or None when there is nothing to render.
    """
    if "properties" not in schema or len(schema["properties"]) == 0:
        if schema.get("additionalProperties", False):
            return "object"
        return None
    result = "{\n"
    for key, value in dict(schema["properties"]).items():
        # value <- resolve_ref(value)
        value_rendered = __format_schema(value, indent + 1)
        if value_rendered is None:
            continue
        # BUG FIX: `description` was only assigned when the property had a
        # "description" key, so the first property without one raised
        # NameError, and later ones silently reused a stale description.
        description = ""
        if "description" in value:
            description = "".join(
                " " * indent + f"// {description_i}\n"
                for description_i in textwrap.dedent(value["description"]).strip().split("\n")
            )
        # Properties not listed in "required" are marked optional with "?".
        optional = "" if key in schema.get("required", {}) else "?"
        # Append a "// default: ..." comment when the schema declares one.
        default_comment = "" if "default" not in value else f" // default: {__format_default(value)}"
        result += description + " " * indent + f"{key}{optional}: {value_rendered},{default_comment}\n"
    result += (" " * (indent - 1)) + "}"
    return result
# NOTE(review): purpose unclear — $ref resolution helper, kept commented out for reference
# def resolve_ref(schema):
#     if schema.get("$ref") is not None:
#         ref = schema["$ref"][14:]
#         schema = json_schema["definitions"][ref]
#     return schema
def __format_schema(schema: dict[str, Any], indent: int) -> str | None:
    """Render one JSON-schema node to its TypeScript-like type expression.

    Args:
        schema (dict[str, Any]): JSON-schema node
        indent (int): current indentation depth

    Returns:
        str | None: rendered type expression (None for empty objects)

    Raises:
        ValueError: unknown schema type
    """
    # schema <- resolve_ref(schema)
    if "enum" in schema:
        return __format_enum(schema)
    schema_type = schema["type"]
    if schema_type == "object":
        return __format_object(schema, indent)
    if schema_type in ("integer", "number"):
        return "number"
    if schema_type == "string":
        return "string"
    if schema_type == "array":
        return str(__format_schema(schema["items"], indent)) + "[]"
    raise ValueError("unknown schema type " + schema_type)
def __format_enum(schema: dict[str, Any]) -> str: | |
# "A" | "B" | "C" | |
return " | ".join(json.dumps(element, ensure_ascii=False) for element in schema["enum"]) | |
def __format_default(schema: dict[str, Any]) -> str: | |
default = schema["default"] | |
if schema["type"] == "number" and float(default).is_integer(): | |
# numberの時、0 → 0.0 | |
return f"{default:.1f}" | |
else: | |
return str(default) | |