# Kpenciler's picture
# Upload 53 files
# 88435ed verified
import json
from typing import Any
from openai.types.chat import ChatCompletionAssistantMessageParam
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall
from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from neollm.types import (
ChatCompletionMessage,
Chunk,
ClientSettings,
InputType,
LLMSettings,
Message,
Messages,
OutputType,
PriceInfo,
PrintColor,
Role,
TokenInfo,
)
from neollm.utils.postprocess import json2dict
from neollm.utils.utils import CPrintParam, cprint
# Color used by cprint for the section headers ("[metadata]", "[inputs]", "[messages]", ...).
TITLE_COLOR: PrintColor = "blue"
# Yen-per-dollar rate used to print a yen estimate next to the dollar price in print_metadata.
# Original note (231027): "it has become 150 yen, sadly".
# NOTE(review): the note says 150 but the value is still 140.0 — confirm which rate is intended.
YEN_PAR_DOLLAR: float = 140.0
def _ChatCompletionMessage2dict(message: ChatCompletionMessage) -> Message:
    """Convert a ``ChatCompletionMessage`` object into an assistant message param dict.

    ``content`` and ``role`` are always copied. ``function_call`` and
    ``tool_calls`` are added only when the message carries them (not None).
    """
    result = ChatCompletionAssistantMessageParam(content=message.content, role=message.role)

    legacy_call = message.function_call
    if legacy_call is not None:
        result["function_call"] = FunctionCall(arguments=legacy_call.arguments, name=legacy_call.name)

    source_calls = message.tool_calls
    if source_calls is not None:
        converted = []
        for call in source_calls:
            converted.append(
                ChatCompletionMessageToolCallParam(
                    id=call.id,
                    function=Function(arguments=call.function.arguments, name=call.function.name),
                    type=call.type,
                )
            )
        result["tool_calls"] = converted

    return result
def _get_tool_calls(message_dict: Message) -> list[ChatCompletionMessageToolCallParam]:
    """Gather every tool call of a message dict as ``ChatCompletionMessageToolCallParam`` values.

    Entries of ``tool_calls`` are copied field by field. A legacy
    ``function_call`` is appended last, as a ``"function"`` tool call with an
    empty id. A ``tool_calls`` value that is not a list, or a ``function_call``
    that is not a dict, is ignored.
    """
    collected: list[ChatCompletionMessageToolCallParam] = []

    raw_tool_calls = message_dict.get("tool_calls")
    # isinstance is also required for the type checker to accept the indexing below.
    if isinstance(raw_tool_calls, list):
        collected.extend(
            ChatCompletionMessageToolCallParam(
                id=raw["id"],
                function=Function(
                    arguments=raw["function"]["arguments"],
                    name=raw["function"]["name"],
                ),
                type=raw["type"],
            )
            for raw in raw_tool_calls
        )

    legacy_call = message_dict.get("function_call")
    # isinstance is also required for the type checker to accept the indexing below.
    if isinstance(legacy_call, dict):
        collected.append(
            ChatCompletionMessageToolCallParam(
                id="",
                function=Function(
                    arguments=legacy_call["arguments"],
                    name=legacy_call["name"],
                ),
                type="function",
            )
        )

    return collected
def print_metadata(time: float, token: TokenInfo, price: PriceInfo) -> None:
    """Print one line with elapsed time, token counts, and cost in dollars and yen.

    The yen figure is ``price.total * YEN_PAR_DOLLAR``. Any exception is
    reported with a red cprint instead of propagating.
    """
    try:
        cprint("[metadata]", color=TITLE_COLOR, kwargs={"end": " "})
        summary = (
            f"{time:.1f}s; "
            f"{token.total:,}({token.input:,}+{token.output:,})tokens; "
            f"${price.total:.2g}; ¥{price.total*YEN_PAR_DOLLAR:.2g}"
        )
        print(summary)
    except Exception as err:
        cprint(err, color="red", background=True)
def print_inputs(inputs: InputType) -> None:
    """Pretty-print *inputs* as indented JSON under an ``[inputs]`` header.

    Values JSON cannot represent are first stringified by
    ``_arange_dumpable_object``. Any exception is reported with a red cprint
    instead of propagating.
    """
    try:
        cprint("[inputs]", color=TITLE_COLOR)
        dumpable = _arange_dumpable_object(inputs)
        rendered = json.dumps(dumpable, indent=2, ensure_ascii=False)
        print(rendered)
    except Exception as err:
        cprint(err, color="red", background=True)
def print_outputs(outputs: OutputType) -> None:
    """Pretty-print *outputs* as indented JSON under an ``[outputs]`` header.

    Values JSON cannot represent are first stringified by
    ``_arange_dumpable_object``. Any exception is reported with a red cprint
    instead of propagating.
    """
    try:
        cprint("[outputs]", color=TITLE_COLOR)
        dumpable = _arange_dumpable_object(outputs)
        rendered = json.dumps(dumpable, indent=2, ensure_ascii=False)
        print(rendered)
    except Exception as err:
        cprint(err, color="red", background=True)
def print_messages(messages: list[ChatCompletionMessage] | Messages | None, title: bool = True) -> None:
    """Pretty-print a conversation: each message's role, content, and tool calls.

    Args:
        messages: Messages to print. ``ChatCompletionMessage`` objects are
            converted with ``_ChatCompletionMessage2dict``; other items are used
            as message dicts directly. ``None`` prints a red notice that
            ``_preprocess`` has not run yet, and nothing else.
        title: When True, print a ``[messages]`` header first.

    Exceptions raised while printing (e.g. a message dict without ``"role"``)
    propagate to the caller.
    """
    if messages is None:
        cprint("Not yet running _preprocess", color="red")
        return
    if title:
        cprint("[messages]", color=TITLE_COLOR)

    # cprint style used for each role name.
    role2params: dict[Role, CPrintParam] = {
        "system": {"color": "green"},
        "user": {"color": "green"},
        "assistant": {"color": "green"},
        "function": {"color": "green", "background": True},
        "tool": {"color": "green", "background": True},
    }
    # Style for roles missing from the table (e.g. newer API roles such as
    # "developer"), so an unrecognised role does not raise KeyError.
    fallback_params: CPrintParam = {"color": "green"}

    for message in messages:
        message_dict: Message
        if isinstance(message, ChatCompletionMessage):
            message_dict = _ChatCompletionMessage2dict(message)
        else:
            message_dict = message

        # role ----------------------------------------
        print(" ", end="")
        role = message_dict["role"]
        cprint(role, **role2params.get(role, fallback_params))

        # content ----------------------------------------
        content = message_dict.get("content", None)
        if isinstance(content, str):
            print(" " + content.replace("\n", "\n "))
        elif isinstance(content, list):
            for content_part in content:
                if content_part["type"] == "text":
                    print(" " + content_part["text"].replace("\n", "\n "))
                elif content_part["type"] == "image_url":
                    cprint(" <image_url>", color="green", kwargs={"end": " "})
                    print(content_part["image_url"])
                    # TODO: render the image itself
                    # TODO: for previews, content_part["image"] may be either a str or a dict
                else:
                    # TODO: print content part types that are not handled yet
                    pass

        # tool calls ----------------------------------------
        for tool_call in _get_tool_calls(message_dict):
            print(" ", end="")
            cprint(tool_call["function"]["name"], color="green", background=True)
            # NOTE(review): json2dict presumably parses the JSON argument string; confirm error_key=None semantics.
            print(" " + str(json2dict(tool_call["function"]["arguments"], error_key=None)).replace("\n", "\n "))
def print_delta(chunk: Chunk) -> None:
    """Stream-print the delta of a chunk's first choice.

    Prints the role (when present), content, legacy function-call name and
    arguments, and tool-call names and arguments, without trailing newlines,
    so that successive chunks concatenate. A newline is printed once the
    choice has a ``finish_reason``. Chunks without choices print nothing.
    """
    if not chunk.choices:
        return
    choice = chunk.choices[0]  # TODO: support n >= 2
    delta = choice.delta

    if delta.role is not None:
        print(" ", end="")
        cprint(delta.role, color="green")
        print(" ", end="")

    if delta.content is not None:
        print(delta.content.replace("\n", "\n "), end="")

    legacy_call = delta.function_call
    if legacy_call is not None:
        if legacy_call.name is not None:
            cprint(legacy_call.name, color="green", background=True)
            print(" ", end="")
        if legacy_call.arguments is not None:
            print(legacy_call.arguments.replace("\n", "\n "), end="")

    if delta.tool_calls is not None:
        for tool_call in delta.tool_calls:
            func = tool_call.function
            if func is None:
                continue
            if func.name is not None:
                # Every tool call after the first starts on a new line.
                if tool_call.index != 0:
                    print("\n ", end="")
                cprint(func.name, color="green", background=True)
                print(" ", end="")
            if func.arguments is not None:
                print(func.arguments.replace("\n", "\n "), end="")

    if choice.finish_reason is not None:
        print()
def print_llm_settings(llm_settings: LLMSettings, model: str, engine: str | None, platform: str) -> None:
    """Print platform, LLM settings and model on one line after an ``[llm_settings]`` header.

    The printed dict starts with ``platform``, then the entries of
    *llm_settings*, then ``model`` (overriding any ``model`` key). For the
    ``"azure"`` platform, ``engine`` is added as well. Any exception is
    reported with a red cprint instead of propagating.
    """
    try:
        cprint("[llm_settings]", color=TITLE_COLOR, kwargs={"end": " "})
        shown = dict(platform=platform, **llm_settings)
        shown["model"] = model
        if platform == "azure":
            # Only for Azure: also show the engine.
            shown["engine"] = engine
        print(shown if shown else "-")
    except Exception as err:
        cprint(err, color="red", background=True)
def print_client_settings(client_settings: ClientSettings) -> None:
    """Print the client settings on one line after a ``[client_settings]`` header.

    Prints ``-`` when the settings are empty (falsy). Any exception is
    reported with a red cprint instead of propagating.
    """
    try:
        cprint("[client_settings]", color=TITLE_COLOR, kwargs={"end": " "})
        shown = client_settings if client_settings else "-"
        print(shown)
    except Exception as err:
        cprint(err, color="red", background=True)
# -------
_DumplableEntity = int | float | str | bool | None | list[Any] | dict[Any, Any]
DumplableType = _DumplableEntity | list["DumplableType"] | dict["DumplableType", "DumplableType"]
def _arange_dumpable_object(obj: Any) -> DumplableType:
# 基本データ型の場合、そのまま返す
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
# リストの場合、再帰的に各要素を変換
if isinstance(obj, list):
return [_arange_dumpable_object(item) for item in obj]
# 辞書の場合、再帰的に各キーと値を変換
if isinstance(obj, dict):
return {_arange_dumpable_object(key): _arange_dumpable_object(value) for key, value in obj.items()}
# それ以外の型の場合、型情報を含めて文字列に変換
return f"<{type(obj).__name__}>{str(obj)}"