Spaces:
Configuration error
Configuration error
import json | |
from typing import Any | |
from openai.types.chat import ChatCompletionAssistantMessageParam | |
from openai.types.chat.chat_completion_assistant_message_param import FunctionCall | |
from openai.types.chat.chat_completion_message_tool_call_param import ( | |
ChatCompletionMessageToolCallParam, | |
Function, | |
) | |
from neollm.types import ( | |
ChatCompletionMessage, | |
Chunk, | |
ClientSettings, | |
InputType, | |
LLMSettings, | |
Message, | |
Messages, | |
OutputType, | |
PriceInfo, | |
PrintColor, | |
Role, | |
TokenInfo, | |
) | |
from neollm.utils.postprocess import json2dict | |
from neollm.utils.utils import CPrintParam, cprint | |
# Colour passed to cprint for the bracketed section titles ("[metadata]", "[inputs]", ...).
TITLE_COLOR: PrintColor = "blue"
# Multiplier used to convert a dollar price into yen when printing metadata.
# Original author's note (231027): the real rate has since reached 150 yen.
YEN_PAR_DOLLAR: float = 140.0
def _ChatCompletionMessage2dict(message: ChatCompletionMessage) -> Message:
    """Convert a ``ChatCompletionMessage`` object into a dict-style assistant message.

    ``content`` and ``role`` are always copied. ``function_call`` and
    ``tool_calls`` are copied only when they are not ``None`` on the source
    message.

    Args:
        message: The message object to convert.

    Returns:
        The dict-style message built from ``message``.
    """
    converted = ChatCompletionAssistantMessageParam(content=message.content, role=message.role)

    legacy_call = message.function_call
    if legacy_call is not None:
        converted["function_call"] = FunctionCall(arguments=legacy_call.arguments, name=legacy_call.name)

    source_calls = message.tool_calls
    if source_calls is not None:
        call_params = []
        for call in source_calls:
            call_function = Function(arguments=call.function.arguments, name=call.function.name)
            call_params.append(
                ChatCompletionMessageToolCallParam(id=call.id, function=call_function, type=call.type)
            )
        converted["tool_calls"] = call_params

    return converted
def _get_tool_calls(message_dict: Message) -> list[ChatCompletionMessageToolCallParam]:
    """Collect every tool call carried by a dict-style message.

    Entries under ``"tool_calls"`` come first, in their original order. A
    ``"function_call"`` entry, when present, is appended last as a tool call
    with an empty ``id`` and ``type="function"``.

    Args:
        message_dict: The dict-style message to inspect.

    Returns:
        The collected tool calls; an empty list when the message has none.
    """
    collected: list[ChatCompletionMessageToolCallParam] = []

    # The isinstance checks also cover the missing-key / None cases; the
    # original author notes that the type check does not pass without them.
    raw_calls = message_dict.get("tool_calls", None)
    if isinstance(raw_calls, list):
        collected.extend(
            [
                ChatCompletionMessageToolCallParam(
                    id=raw_call["id"],
                    function=Function(
                        arguments=raw_call["function"]["arguments"],
                        name=raw_call["function"]["name"],
                    ),
                    type=raw_call["type"],
                )
                for raw_call in raw_calls
            ]
        )

    legacy_call = message_dict.get("function_call", None)
    if isinstance(legacy_call, dict):
        legacy_function = Function(arguments=legacy_call["arguments"], name=legacy_call["name"])
        collected.append(ChatCompletionMessageToolCallParam(id="", function=legacy_function, type="function"))

    return collected
def print_metadata(time: float, token: TokenInfo, price: PriceInfo) -> None:
    """Print a one-line summary of elapsed time, token counts and price.

    The price is shown in dollars and in yen (dollars multiplied by
    ``YEN_PAR_DOLLAR``). Any exception raised while printing is caught and
    printed in red instead of being propagated.

    Args:
        time: Elapsed time, printed with an ``s`` suffix.
        token: Token counts; ``total``, ``input`` and ``output`` are read.
        price: Price information; ``total`` is read.
    """
    try:
        cprint("[metadata]", color=TITLE_COLOR, kwargs={"end": " "})
        time_part = f"{time:.1f}s; "
        token_part = f"{token.total:,}({token.input:,}+{token.output:,})tokens; "
        price_part = f"${price.total:.2g}; ¥{price.total*YEN_PAR_DOLLAR:.2g}"
        print(time_part + token_part + price_part)
    except Exception as e:
        cprint(e, color="red", background=True)
def print_inputs(inputs: InputType) -> None:
    """Print the ``[inputs]`` title followed by ``inputs`` as indented JSON.

    The value is first passed through ``_arange_dumpable_object`` so that it
    can be handed to ``json.dumps``. Any exception raised along the way is
    caught and printed in red instead of being propagated.

    Args:
        inputs: The value to display.
    """
    try:
        cprint("[inputs]", color=TITLE_COLOR)
        dumpable = _arange_dumpable_object(inputs)
        rendered = json.dumps(dumpable, indent=2, ensure_ascii=False)
        print(rendered)
    except Exception as e:
        cprint(e, color="red", background=True)
def print_outputs(outputs: OutputType) -> None:
    """Print the ``[outputs]`` title followed by ``outputs`` as indented JSON.

    The value is first passed through ``_arange_dumpable_object`` so that it
    can be handed to ``json.dumps``. Any exception raised along the way is
    caught and printed in red instead of being propagated.

    Args:
        outputs: The value to display.
    """
    try:
        cprint("[outputs]", color=TITLE_COLOR)
        dumpable = _arange_dumpable_object(outputs)
        rendered = json.dumps(dumpable, indent=2, ensure_ascii=False)
        print(rendered)
    except Exception as e:
        cprint(e, color="red", background=True)
def print_messages(messages: list[ChatCompletionMessage] | Messages | None, title: bool = True) -> None:
    """Print a list of chat messages: role, content and tool calls of each one.

    Unlike the other ``print_*`` helpers in this module, exceptions raised
    while printing are not caught here and reach the caller.

    Args:
        messages: Messages to print, given either as ``ChatCompletionMessage``
            objects or as dict-style messages. When ``None``, a notice is
            printed in red and nothing else happens.
        title: Whether to print the ``[messages]`` title line first.
    """
    if messages is None:
        cprint("Not yet running _preprocess", color="red")
        return
    if title:
        cprint("[messages]", color=TITLE_COLOR)
    role2prarams: dict[Role, CPrintParam] = {
        "system": {"color": "green"},
        "user": {"color": "green"},
        "assistant": {"color": "green"},
        "function": {"color": "green", "background": True},
        "tool": {"color": "green", "background": True},
    }
    # Styling for any role missing from the table above, so that an unexpected
    # role is still printed instead of aborting the whole dump with a KeyError.
    fallback_params: CPrintParam = {"color": "green"}
    for message in messages:
        message_dict: Message
        if isinstance(message, ChatCompletionMessage):
            message_dict = _ChatCompletionMessage2dict(message)
        else:
            message_dict = message
        # Role ----------------------------------------
        print(" ", end="")
        role = message_dict["role"]
        cprint(role, **role2prarams.get(role, fallback_params))
        # Content ----------------------------------------
        content = message_dict.get("content", None)
        if isinstance(content, str):
            print(" " + content.replace("\n", "\n "))
        elif isinstance(content, list):
            for content_part in content:
                if content_part["type"] == "text":
                    print(" " + content_part["text"].replace("\n", "\n "))
                elif content_part["type"] == "image_url":
                    cprint(" <image_url>", color="green", kwargs={"end": " "})
                    print(content_part["image_url"])
                    # TODO: render the image itself
                    # TODO: for Preview, content_part["image"] can be either a str or a dict
                else:
                    # TODO: print content part types that are not supported yet
                    pass
        # Tool calls ----------------------------------------
        for tool_call in _get_tool_calls(message_dict):
            print(" ", end="")
            cprint(tool_call["function"]["name"], color="green", background=True)
            print(" " + str(json2dict(tool_call["function"]["arguments"], error_key=None)).replace("\n", "\n "))
def print_delta(chunk: Chunk) -> None:
    """Print the incremental content carried by one streamed chunk.

    Only the first choice of the chunk is printed. Depending on which delta
    fields are set, this prints the role, the content text, the function-call
    name/arguments and the name/arguments of each tool call. A final newline
    is printed once the choice reports a finish reason.

    Args:
        chunk: The streamed chunk to print. Nothing is printed when it has
            no choices.
    """
    if len(chunk.choices) == 0:
        return

    first_choice = chunk.choices[0]  # TODO: support n>2
    delta = first_choice.delta

    if delta.role is not None:
        print(" ", end="")
        cprint(delta.role, color="green")
        print(" ", end="")

    if delta.content is not None:
        print(delta.content.replace("\n", "\n "), end="")

    legacy_call = delta.function_call
    if legacy_call is not None:
        if legacy_call.name is not None:
            cprint(legacy_call.name, color="green", background=True)
            print(" ", end="")
        if legacy_call.arguments is not None:
            print(legacy_call.arguments.replace("\n", "\n "), end="")

    for tool_call in delta.tool_calls or ():
        called = tool_call.function
        if called is None:
            continue
        if called.name is not None:
            # Every tool call after the first starts on a fresh line.
            if tool_call.index != 0:
                print("\n ", end="")
            cprint(called.name, color="green", background=True)
            print(" ", end="")
        if called.arguments is not None:
            print(called.arguments.replace("\n", "\n "), end="")

    if first_choice.finish_reason is not None:
        print()
def print_llm_settings(llm_settings: LLMSettings, model: str, engine: str | None, platform: str) -> None:
    """Print the LLM settings together with the platform and model in use.

    A new dict is built from ``platform`` and ``llm_settings`` and then
    extended, so ``llm_settings`` itself is not modified. Any exception raised
    along the way is caught and printed in red instead of being propagated.

    Args:
        llm_settings: The settings mapping to display.
        model: Stored under the ``"model"`` key of the displayed dict.
        engine: Stored under the ``"engine"`` key, only when ``platform`` is
            ``"azure"``.
        platform: Stored under the ``"platform"`` key.
    """
    try:
        cprint("[llm_settings]", color=TITLE_COLOR, kwargs={"end": " "})
        displayed = dict(platform=platform, **llm_settings)
        displayed.update(model=model)
        # On Azure the engine is shown as well.
        if platform == "azure":
            displayed.update(engine=engine)
        print(displayed or "-")
    except Exception as e:
        cprint(e, color="red", background=True)
def print_client_settings(client_settings: ClientSettings) -> None:
    """Print the ``[client_settings]`` title followed by the client settings.

    A falsy ``client_settings`` (for example an empty mapping) is shown as
    ``-``. Any exception raised while printing is caught and printed in red
    instead of being propagated.

    Args:
        client_settings: The settings to display.
    """
    try:
        cprint("[client_settings]", color=TITLE_COLOR, kwargs={"end": " "})
        displayed = client_settings if client_settings else "-"
        print(displayed)
    except Exception as e:
        cprint(e, color="red", background=True)
# ------- | |
_DumplableEntity = int | float | str | bool | None | list[Any] | dict[Any, Any] | |
DumplableType = _DumplableEntity | list["DumplableType"] | dict["DumplableType", "DumplableType"] | |
def _arange_dumpable_object(obj: Any) -> DumplableType: | |
# 基本データ型の場合、そのまま返す | |
if isinstance(obj, (int, float, str, bool, type(None))): | |
return obj | |
# リストの場合、再帰的に各要素を変換 | |
if isinstance(obj, list): | |
return [_arange_dumpable_object(item) for item in obj] | |
# 辞書の場合、再帰的に各キーと値を変換 | |
if isinstance(obj, dict): | |
return {_arange_dumpable_object(key): _arange_dumpable_object(value) for key, value in obj.items()} | |
# それ以外の型の場合、型情報を含めて文字列に変換 | |
return f"<{type(obj).__name__}>{str(obj)}" | |