import os
import time
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Final, Generator, Literal, Optional
from neollm.exceptions import ContentFilterError
from neollm.llm import AbstractLLM, get_llm
from neollm.llm.gpt.azure_llm import AzureLLM
from neollm.myllm.abstract_myllm import AbstractMyLLM
from neollm.myllm.print_utils import (
print_client_settings,
print_delta,
print_llm_settings,
print_messages,
)
from neollm.types import (
Chunk,
ClientSettings,
Functions,
InputType,
LLMSettings,
Message,
Messages,
OutputType,
PriceInfo,
Response,
StreamOutputType,
TimeInfo,
TokenInfo,
Tools,
)
from neollm.types.openai.chat_completion import CompletionUsageForCustomPriceCalculation
from neollm.utils.preprocess import dict2json
from neollm.utils.utils import cprint
if TYPE_CHECKING:
from neollm.myllm.myl3m2 import MyL3M2
_MyL3M2 = MyL3M2[Any, Any]
_State = dict[Any, Any]
DEFAULT_LLM_SETTINGS: LLMSettings = {"temperature": 0}
DEFAULT_PLATFORM: Final[str] = "azure"
class MyLLM(AbstractMyLLM[InputType, OutputType]):
"""LLMの単一リクエストをまとめるクラス"""
def __init__(
self,
model: str,
parent: Optional["_MyL3M2"] = None,
llm_settings: LLMSettings | None = None,
client_settings: ClientSettings | None = None,
platform: str | None = None,
verbose: bool = False,
stream_verbose: bool = False,
silent_list: list[Literal["llm_settings", "client_settings", "inputs", "outputs", "messages", "metadata"]] | None = None,
log_dir: str | None = None,
) -> None:
"""
MyLLMクラスの初期化
Args:
model (Optional[str]): LLMモデル名
parent (Optional[MyL3M2]): 親のMyL3M2のインスタンス (self or None)
llm_settings (LLMSettings): LLMの設定パラメータ
client_settings (ClientSettings): llmのclientの設定パラメータ
platform (Optional[str]): LLMのプラットフォーム名 (デフォルト: os.environ["PLATFORM"] or "azure")
(enum: openai, azure)
verbose (bool): 出力をするかどうかのフラグ
stream_verbose (bool): assitantをstreamで出力するか(verbose=False, message in "messages"の時、無効)
silent_list (list[Literal["llm_settings", "inputs", "outputs", "messages", "metadata"]]):
verbose=True時, 出力を抑制する要素のリスト
log_dir (Optional[str]): ログを保存するディレクトリのパス Noneの時、保存しない
"""
self.parent: _MyL3M2 | None = parent
self.llm_settings = llm_settings or DEFAULT_LLM_SETTINGS
self.client_settings = client_settings or {}
self.model: str = model
self.platform: str = platform or os.environ.get("LLM_PLATFORM", DEFAULT_PLATFORM) or DEFAULT_PLATFORM
self.verbose: bool = verbose and (True if self.parent is None else self.parent.verbose)  # follow the parent
self.silent_set = set(silent_list or [])
self.stream_verbose: bool = stream_verbose if verbose and ("messages" not in self.silent_set) else False
self.log_dir: str | None = log_dir
self.inputs: InputType | None = None
self.outputs: OutputType | None = None
self.messages: Messages | None = None
self.functions: Functions | None = None
self.tools: Tools | None = None
self.response: Response | None = None
self.called: bool = False
self.do_stream: bool = self.stream_verbose
self.llm: AbstractLLM = get_llm(
model_name=self.model, platform=self.platform, client_settings=self.client_settings
)
@abstractmethod
def _preprocess(self, inputs: InputType) -> Messages:
"""
inputs を API入力 の messages に前処理する
Args:
inputs (InputType): 入力
Returns:
Messages: API入力 の messages
>>> [{"role": "system", "content": "system_prompt"}, {"role": "user", "content": "user_prompt"}]
"""
@abstractmethod
def _postprocess(self, response: Response) -> OutputType:
"""
API の response を outputs に後処理する
Args:
response (Response): API の response
>>> {"choices": [{"message": {"role": "assistant",
>>> "content": "This is a test!"}}]}
>>> {"choices": [{"message": {"role": "assistant",
>>> "function_call": {"name": "func", "arguments": "{a: 1}"}}]}
Returns:
OutputType: 出力
"""
def _ruleprocess(self, inputs: InputType) -> OutputType | None:
"""
ルールベース処理 or APIリクエスト の判断
Args:
inputs (InputType): MyLLMの入力
Returns:
RuleOutputs:
ルールベース処理の時、MyLLMの出力を返す
APIリクエストの時、Noneを返す
"""
return None
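# A minimal override sketch (hypothetical subclass code, not part of this
# class): short-circuit the API request for trivial inputs.
#
#     def _ruleprocess(self, inputs: InputType) -> OutputType | None:
#         if not inputs:      # e.g. empty input needs no LLM call
#             return ""       # assumes OutputType is str in the subclass
#         return None         # fall through to the API request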
def _update_settings(self) -> None:
"""
APIの設定の更新
Note:
messageのトークン数
>>> self.llm.count_tokens(self.messsage)
モデル変更
>>> self.model = "gpt-3.5-turbo-16k"
パラメータ変更
>>> self.llm_settings = {"temperature": 0.2}
"""
return None
def _add_tools(self, inputs: InputType) -> Tools | None:
return None
def _add_functions(self, inputs: InputType) -> Functions | None:
"""
functions の追加
Args:
inputs (InputType): 入力
Returns:
Functions | None: functions。追加しない場合None
https://json-schema.org/understanding-json-schema/reference/index.html
>>> {
>>> "name": "関数名",
>>> "description": "関数の動作の説明。GPTは説明を見て利用するか選ぶ",
>>> "parameters": {
>>> "type": "object", "properties": {"city_name": {"type": "string", "description": "都市名"}},
>>> json-schema[https://json-schema.org/understanding-json-schema/reference/index.html]
>>> }
>>> }
"""
return None
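# Sketch of a concrete `_add_functions` override (hypothetical subclass code;
# the function name and fields are illustrative, with the schema shape assumed
# from the docstring example above):
#
#     def _add_functions(self, inputs: InputType) -> Functions | None:
#         return [
#             {
#                 "name": "get_city_info",
#                 "description": "Look up information about a city.",
#                 "parameters": {
#                     "type": "object",
#                     "properties": {"city_name": {"type": "string", "description": "city name"}},
#                     "required": ["city_name"],
#                 },
#             }
#         ]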
def _stream_postprocess(
self,
new_chunk: Chunk,
state: "_State",
) -> StreamOutputType:
"""call_streamのGeneratorのpostprocess
Args:
new_chunk (OpenAIChunkResponse): 新しいchunk
state (dict[Any, Any]): 状態を持てるdict. 初めは、default {}. 状態が消えてしまうのでoverwriteしない。
Returns:
StreamOutputType: 一時的なoutput
"""
if len(new_chunk.choices) == 0:
return ""
return new_chunk.choices[0].delta.content
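# A stateful override sketch (hypothetical subclass code): accumulate the
# streamed text in `state` so later chunks can see what came before. Note
# that `state` is mutated in place, never reassigned, as the docstring warns.
#
#     def _stream_postprocess(self, new_chunk: Chunk, state: "_State") -> StreamOutputType:
#         if len(new_chunk.choices) == 0:
#             return ""
#         delta = new_chunk.choices[0].delta.content or ""
#         state["buffer"] = state.get("buffer", "") + delta
#         return delta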
def _generate(self, stream: bool) -> Generator[StreamOutputType, None, None]:
"""
LLMの出力を得て、`self.response`に格納する
Args:
messages (list[dict[str, str]]): LLMの入力メッセージ
"""
# validation ---------------------------------------------------------
if self.messages is None:
raise ValueError("MessagesがNoneです。")
# kwargs -----------------------------------------------------------
generate_kwargs = dict(**self.llm_settings)
if self.functions is not None:
generate_kwargs["functions"] = self.functions
if self.tools is not None:
generate_kwargs["tools"] = self.tools
# generate ----------------------------------------------------------
self._print_messages() # verbose
self.llm = get_llm(model_name=self.model, platform=self.platform, client_settings=self.client_settings)
# [stream]
if stream or self.stream_verbose:
it = self.llm.generate_stream(messages=self.messages, llm_settings=generate_kwargs)
chunk_list: list[Chunk] = []
state: "_State" = {}
for chunk in it:
chunk_list.append(chunk)
self._print_delta(chunk=chunk)  # verbose: stop → newline, content; TODO: fc → output
yield self._stream_postprocess(new_chunk=chunk, state=state)
self.response = self.llm.convert_nonstream_response(chunk_list, self.messages, self.functions)
# [non-stream]
else:
self.response = self.llm.generate(messages=self.messages, llm_settings=generate_kwargs)
self._print_message_assistant()
# ContentFilterError -------------------------------------------------
if len(self.response.choices) == 0:
cprint(self.response, color="red", background=True)
raise ContentFilterError("The input was blocked by the content filter.")
if self.response.choices[0].finish_reason == "content_filter":
cprint(self.response, color="red", background=True)
raise ContentFilterError("The output was blocked by the content filter.")
def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
"""
LLMの処理を行う (preprocess, check_input, generate, postprocess)
Args:
inputs (InputType): 入力データを保持する辞書
Returns:
OutputType: 処理結果の出力データ
Raises:
RuntimeError: 既に呼び出されている場合に発生
"""
if self.called:
raise RuntimeError("MyLLM can only be called once.")
self._print_start(sep="-")
# main -----------------------------------------------------------
t_start = time.time()
self.inputs = inputs
self._print_inputs()
rulebase_output = self._ruleprocess(inputs)
if rulebase_output is None:  # when an API request is sent
self._update_settings()
self.messages = self._preprocess(inputs)
self.functions = self._add_functions(inputs)
self.tools = self._add_tools(inputs)
t_preprocessed = time.time()
# [generate]
it = self._generate(stream=stream)
for delta_content in it:  # empty generator when stream=False
yield delta_content
if self.response is None:
raise ValueError("response is None.")
t_generated = time.time()
# [postprocess]
self.outputs = self._postprocess(self.response)
t_postprocessed = time.time()
else:  # rule-based case
self.outputs = rulebase_output
t_preprocessed = t_generated = t_postprocessed = time.time()
self.time_detail = TimeInfo(
total=t_postprocessed - t_start,
preprocess=t_preprocessed - t_start,
main=t_generated - t_preprocessed,
postprocess=t_postprocessed - t_generated,
)
self.time = t_postprocessed - t_start
# print -----------------------------------------------------------
self._print_outputs()
self._print_client_settings()
self._print_llm_settings()
self._print_metadata()
self._print_end(sep="-")
# append to the parent MyL3M2 -----------------------------------------------------------
if self.parent is not None:
self.parent.myllm_list.append(self)
self.called = True
# log -----------------------------------------------------------
self._save_log()
return self.outputs
@property
def log(self) -> dict[str, Any]:
return {
"inputs": self.inputs,
"outputs": self.outputs,
"resposnse": self.response.model_dump() if self.response is not None else None,
"input_token": self.token.input,
"output_token": self.token.output,
"total_token": self.token.total,
"input_price": self.price.input,
"output_price": self.price.output,
"total_price": self.price.total,
"time": self.time,
"time_stamp": time.time(),
"llm_settings": self.llm_settings,
"client_settings": self.client_settings,
"model": self.model,
"platform": self.platform,
"verbose": self.verbose,
"messages": self.messages,
"assistant_message": self.assistant_message,
"functions": self.functions,
"tools": self.tools,
}
def _save_log(self) -> None:
if self.log_dir is None:
return
try:
log = self.log
json_string = dict2json(log)
save_log_path = os.path.join(self.log_dir, f"{log['time_stamp']}.json")
os.makedirs(self.log_dir, exist_ok=True)
with open(save_log_path, mode="w") as f:
f.write(json_string)
except Exception as e:
cprint(e, color="red", background=True)
@property
def token(self) -> TokenInfo:
if self.response is None or self.response.usage is None:
return TokenInfo(input=0, output=0, total=0)
return TokenInfo(
input=self.response.usage.prompt_tokens,
output=self.response.usage.completion_tokens,
total=self.response.usage.total_tokens,
)
@property
def custom_token(self) -> TokenInfo | None:
if not self.llm._custom_price_calculation:
return None
if self.response is None:
return TokenInfo(input=0, output=0, total=0)
usage_for_price = getattr(self.response, "usage_for_price", None)
if not isinstance(usage_for_price, CompletionUsageForCustomPriceCalculation):
cprint("usage_for_priceがNoneです。正しくトークン計算できません", color="red", background=True)
return TokenInfo(input=0, output=0, total=0)
return TokenInfo(
input=usage_for_price.prompt_tokens,
output=usage_for_price.completion_tokens,
total=usage_for_price.total_tokens,
)
@property
def price(self) -> PriceInfo:
if self.response is None:
return PriceInfo(input=0.0, output=0.0, total=0.0)
if self.llm._custom_price_calculation:
# Gemini is expected to always provide custom_token
if self.custom_token is None:
cprint("custom_tokenがNoneです。正しくトークン計算できません", color="red", background=True)
else:
return PriceInfo(
input=self.llm.calculate_price(num_input_tokens=self.custom_token.input),
output=self.llm.calculate_price(num_output_tokens=self.custom_token.output),
total=self.llm.calculate_price(
num_input_tokens=self.custom_token.input, num_output_tokens=self.custom_token.output
),
)
return PriceInfo(
input=self.llm.calculate_price(num_input_tokens=self.token.input),
output=self.llm.calculate_price(num_output_tokens=self.token.output),
total=self.llm.calculate_price(num_input_tokens=self.token.input, num_output_tokens=self.token.output),
)
@property
def assistant_message(self) -> Message | None:
if self.response is None or len(self.response.choices) == 0:
return None
return self.response.choices[0].message.to_typeddict_message()
@property
def chat_history(self) -> Messages:
chat_history: Messages = []
if self.messages:
chat_history += self.messages
if self.assistant_message is not None:
chat_history.append(self.assistant_message)
return chat_history
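# Sketch of carrying a conversation forward (hypothetical; `first_turn` is an
# already-called MyLLM instance). Since each instance runs only once, a
# follow-up turn builds on `chat_history` inside a new instance's `_preprocess`:
#
#     history = first_turn.chat_history
#     # ...pass `history` into the next MyLLM's inputs and extend it there.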
def _print_llm_settings(self) -> None:
if not ("llm_settings" not in self.silent_set and self.verbose):
return
print_llm_settings(
llm_settings=self.llm_settings,
model=self.model,
platform=self.platform,
engine=self.llm.engine if isinstance(self.llm, AzureLLM) else None,
)
def _print_messages(self) -> None:
if not ("messages" not in self.silent_set and self.verbose):
return
print_messages(self.messages, title=True)
def _print_message_assistant(self) -> None:
if self.response is None or len(self.response.choices) == 0:
return
if not ("messages" not in self.silent_set and self.verbose):
return
print_messages(messages=[self.response.choices[0].message], title=False)
def _print_delta(self, chunk: Chunk) -> None:
if not ("messages" not in self.silent_set and self.verbose):
return
print_delta(chunk)
def _print_client_settings(self) -> None:
if not ("client_settings" not in self.silent_set and self.verbose):
return
print_client_settings(self.llm.client_settings)
def __repr__(self) -> str:
return f"MyLLM({self.__class__.__name__})"