import os
import time
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Final, Generator, Literal, Optional

from neollm.exceptions import ContentFilterError
from neollm.llm import AbstractLLM, get_llm
from neollm.llm.gpt.azure_llm import AzureLLM
from neollm.myllm.abstract_myllm import AbstractMyLLM
from neollm.myllm.print_utils import (
    print_client_settings,
    print_delta,
    print_llm_settings,
    print_messages,
)
from neollm.types import (
    Chunk,
    ClientSettings,
    Functions,
    InputType,
    LLMSettings,
    Message,
    Messages,
    OutputType,
    PriceInfo,
    Response,
    StreamOutputType,
    TimeInfo,
    TokenInfo,
    Tools,
)
from neollm.types.openai.chat_completion import CompletionUsageForCustomPriceCalculation
from neollm.utils.preprocess import dict2json
from neollm.utils.utils import cprint

if TYPE_CHECKING:
    from neollm.myllm.myl3m2 import MyL3M2

    _MyL3M2 = MyL3M2[Any, Any]

_State = dict[Any, Any]

DEFAULT_LLM_SETTINGS: LLMSettings = {"temperature": 0}
DEFAULT_PLATFORM: Final[str] = "azure"


class MyLLM(AbstractMyLLM[InputType, OutputType]):
    """Class that wraps a single LLM request."""

    def __init__(
        self,
        model: str,
        parent: Optional["_MyL3M2"] = None,
        llm_settings: LLMSettings | None = None,
        client_settings: ClientSettings | None = None,
        platform: str | None = None,
        verbose: bool = False,
        stream_verbose: bool = False,
        silent_list: list[Literal["llm_settings", "inputs", "outputs", "messages", "metadata"]] | None = None,
        log_dir: str | None = None,
    ) -> None:
        """
        Initialize MyLLM.

        Args:
            model (str): LLM model name
            parent (Optional[MyL3M2]): parent MyL3M2 instance (self or None)
            llm_settings (LLMSettings): parameters passed to the LLM
            client_settings (ClientSettings): parameters for the LLM client
            platform (Optional[str]): LLM platform name (default: os.environ["LLM_PLATFORM"] or "azure")
                (enum: openai, azure)
            verbose (bool): whether to print progress output
            stream_verbose (bool): whether to stream the assistant output
                (ignored when verbose=False or "messages" is in silent_list)
            silent_list (list[Literal["llm_settings", "inputs", "outputs", "messages", "metadata"]]):
                elements whose output is suppressed when verbose=True
            log_dir (Optional[str]): directory to save logs to; nothing is saved when None
        """
        self.parent: _MyL3M2 | None = parent
        self.llm_settings = llm_settings or DEFAULT_LLM_SETTINGS
        self.client_settings = client_settings or {}
        self.model: str = model
        self.platform: str = platform or os.environ.get("LLM_PLATFORM", DEFAULT_PLATFORM) or DEFAULT_PLATFORM
        self.verbose: bool = verbose and (True if self.parent is None else self.parent.verbose)  # follow the parent
        self.silent_set = set(silent_list or [])
        self.stream_verbose: bool = stream_verbose if verbose and ("messages" not in self.silent_set) else False
        self.log_dir: str | None = log_dir

        self.inputs: InputType | None = None
        self.outputs: OutputType | None = None
        self.messages: Messages | None = None
        self.functions: Functions | None = None
        self.tools: Tools | None = None
        self.response: Response | None = None
        self.called: bool = False
        self.do_stream: bool = self.stream_verbose
        self.llm: AbstractLLM = get_llm(
            model_name=self.model, platform=self.platform, client_settings=self.client_settings
        )

    @abstractmethod
    def _preprocess(self, inputs: InputType) -> Messages:
        """
        Preprocess `inputs` into the `messages` sent to the API.

        Args:
            inputs (InputType): input

        Returns:
            Messages: `messages` sent to the API
            >>> [{"role": "system", "content": "system_prompt"}, {"role": "user", "content": "user_prompt"}]
        """

    @abstractmethod
    def _postprocess(self, response: Response) -> OutputType:
        """
        Postprocess the API `response` into `outputs`.

        Args:
            response (Response): API response
                >>> {"choices": [{"message": {"role": "assistant",
                >>>                           "content": "This is a test!"}}]}
                >>> {"choices": [{"message": {"role": "assistant",
                >>>                           "function_call": {"name": "func", "arguments": "{a: 1}"}}}]}

        Returns:
            OutputType: output
        """

    def _ruleprocess(self, inputs: InputType) -> OutputType | None:
        """
        Decide between rule-based processing and an API request.

        Args:
            inputs (InputType): MyLLM input

        Returns:
            OutputType | None:
                the MyLLM output when handled by rules,
                None when an API request should be made
        """
        return None
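
    # Illustrative override (hypothetical subclass, not part of this module):
    # return an output directly to short-circuit the API call, or None to fall
    # through to the normal request.
    #
    #     def _ruleprocess(self, inputs: str) -> str | None:
    #         if not inputs.strip():
    #             return ""   # rule-based: empty input -> empty output, no API call
    #         return None     # otherwise, send the API request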

    def _update_settings(self) -> None:
        """
        Update the API settings before the request.

        Note:
            count the tokens of the messages
            >>> self.llm.count_tokens(self.messages)
            change the model
            >>> self.model = "gpt-3.5-turbo-16k"
            change the parameters
            >>> self.llm_settings = {"temperature": 0.2}
        """
        return None

    def _add_tools(self, inputs: InputType) -> Tools | None:
        return None
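
    # Illustrative override (hypothetical subclass): `tools` follows the
    # OpenAI-style tool schema, i.e. function definitions wrapped in
    # {"type": "function", "function": {...}} objects.
    #
    #     def _add_tools(self, inputs: str) -> Tools | None:
    #         return [
    #             {
    #                 "type": "function",
    #                 "function": {
    #                     "name": "get_weather",
    #                     "description": "Look up the weather for a city",
    #                     "parameters": {
    #                         "type": "object",
    #                         "properties": {"city_name": {"type": "string", "description": "city name"}},
    #                         "required": ["city_name"],
    #                     },
    #                 },
    #             }
    #         ]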

    def _add_functions(self, inputs: InputType) -> Functions | None:
        """
        Add `functions`.

        Args:
            inputs (InputType): input

        Returns:
            Functions | None: functions, or None to add nothing

        JSON Schema reference: https://json-schema.org/understanding-json-schema/reference/index.html
        >>> {
        >>>     "name": "function name",
        >>>     "description": "What the function does; GPT decides whether to call it based on this.",
        >>>     "parameters": {
        >>>         "type": "object",
        >>>         "properties": {"city_name": {"type": "string", "description": "city name"}},
        >>>     },
        >>> }
        """
        return None

    def _stream_postprocess(
        self,
        new_chunk: Chunk,
        state: "_State",
    ) -> StreamOutputType:
        """Postprocess each chunk yielded by call_stream's generator.

        Args:
            new_chunk (Chunk): newly received chunk
            state (dict[Any, Any]): dict that keeps state across chunks; starts as {}.
                Do not overwrite it, or the state is lost.

        Returns:
            StreamOutputType: intermediate output
        """
        if len(new_chunk.choices) == 0:
            return ""
        return new_chunk.choices[0].delta.content
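
    # Illustrative override (hypothetical subclass): use `state` to accumulate
    # the streamed text and yield the running total instead of each delta.
    #
    #     def _stream_postprocess(self, new_chunk: Chunk, state: "_State") -> str:
    #         if len(new_chunk.choices) == 0:
    #             return state.get("text", "")
    #         state["text"] = state.get("text", "") + (new_chunk.choices[0].delta.content or "")
    #         return state["text"]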

    def _generate(self, stream: bool) -> Generator[StreamOutputType, None, None]:
        """
        Get the LLM output and store it in `self.response`.

        Args:
            stream (bool): whether to request a streaming response
        """
        # guard clauses -------------------------------------------------------
        if self.messages is None:
            raise ValueError("messages is None.")
        # kwargs --------------------------------------------------------------
        generate_kwargs = dict(**self.llm_settings)
        if self.functions is not None:
            generate_kwargs["functions"] = self.functions
        if self.tools is not None:
            generate_kwargs["tools"] = self.tools
        # generate ------------------------------------------------------------
        self._print_messages()  # verbose
        self.llm = get_llm(model_name=self.model, platform=self.platform, client_settings=self.client_settings)
        # [stream]
        if stream or self.stream_verbose:
            it = self.llm.generate_stream(messages=self.messages, llm_settings=generate_kwargs)
            chunk_list: list[Chunk] = []
            state: "_State" = {}
            for chunk in it:
                chunk_list.append(chunk)
                self._print_delta(chunk=chunk)  # verbose: stop -> newline, content; TODO: print function calls
                yield self._stream_postprocess(new_chunk=chunk, state=state)
            self.response = self.llm.convert_nonstream_response(chunk_list, self.messages, self.functions)
        # [non-stream]
        else:
            self.response = self.llm.generate(messages=self.messages, llm_settings=generate_kwargs)
            self._print_message_assistant()
        # ContentFilterError --------------------------------------------------
        if len(self.response.choices) == 0:
            cprint(self.response, color="red", background=True)
            raise ContentFilterError("The input was blocked by the content filter.")
        if self.response.choices[0].finish_reason == "content_filter":
            cprint(self.response, color="red", background=True)
            raise ContentFilterError("The output was blocked by the content filter.")

    def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
        """
        Run the LLM pipeline (preprocess, generate, postprocess).

        Args:
            inputs (InputType): input data

        Returns:
            OutputType: output data

        Raises:
            RuntimeError: raised when the instance has already been called
        """
        if self.called:
            raise RuntimeError("A MyLLM instance can only be called once.")
        self._print_start(sep="-")
        # main -----------------------------------------------------------
        t_start = time.time()
        self.inputs = inputs
        self._print_inputs()
        rulebase_output = self._ruleprocess(inputs)
        if rulebase_output is None:  # send an API request
            self._update_settings()
            self.messages = self._preprocess(inputs)
            self.functions = self._add_functions(inputs)
            self.tools = self._add_tools(inputs)
            t_preprocessed = time.time()
            # [generate]
            it = self._generate(stream=stream)
            for delta_content in it:  # yields nothing when not streaming
                yield delta_content
            if self.response is None:
                raise ValueError("response is None.")
            t_generated = time.time()
            # [postprocess]
            self.outputs = self._postprocess(self.response)
            t_postprocessed = time.time()
        else:  # handled by rules
            self.outputs = rulebase_output
            t_preprocessed = t_generated = t_postprocessed = time.time()
        self.time_detail = TimeInfo(
            total=t_postprocessed - t_start,
            preprocess=t_preprocessed - t_start,
            main=t_generated - t_preprocessed,
            postprocess=t_postprocessed - t_generated,
        )
        self.time = t_postprocessed - t_start
        # print -----------------------------------------------------------
        self._print_outputs()
        self._print_client_settings()
        self._print_llm_settings()
        self._print_metadata()
        self._print_end(sep="-")
        # append to the parent MyL3M2 ---------------------------------------
        if self.parent is not None:
            self.parent.myllm_list.append(self)
        self.called = True
        # log -----------------------------------------------------------
        self._save_log()
        return self.outputs
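
    # Note: `_call` is a generator whose `return` value is the final output.
    # Illustrative way to drain it and recover that value (an assumption about
    # how the caller, e.g. AbstractMyLLM, consumes it):
    #
    #     gen = my_llm._call(inputs, stream=False)
    #     try:
    #         while True:
    #             next(gen)
    #     except StopIteration as stop:
    #         outputs = stop.value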

    @property
    def log(self) -> dict[str, Any]:
        return {
            "inputs": self.inputs,
            "outputs": self.outputs,
            "response": self.response.model_dump() if self.response is not None else None,
            "input_token": self.token.input,
            "output_token": self.token.output,
            "total_token": self.token.total,
            "input_price": self.price.input,
            "output_price": self.price.output,
            "total_price": self.price.total,
            "time": self.time,
            "time_stamp": time.time(),
            "llm_settings": self.llm_settings,
            "client_settings": self.client_settings,
            "model": self.model,
            "platform": self.platform,
            "verbose": self.verbose,
            "messages": self.messages,
            "assistant_message": self.assistant_message,
            "functions": self.functions,
            "tools": self.tools,
        }

    def _save_log(self) -> None:
        if self.log_dir is None:
            return
        try:
            log = self.log
            json_string = dict2json(log)
            save_log_path = os.path.join(self.log_dir, f"{log['time_stamp']}.json")
            os.makedirs(self.log_dir, exist_ok=True)
            with open(save_log_path, mode="w") as f:
                f.write(json_string)
        except Exception as e:
            cprint(e, color="red", background=True)

    @property
    def token(self) -> TokenInfo:
        if self.response is None or self.response.usage is None:
            return TokenInfo(input=0, output=0, total=0)
        return TokenInfo(
            input=self.response.usage.prompt_tokens,
            output=self.response.usage.completion_tokens,
            total=self.response.usage.total_tokens,
        )

    @property
    def custom_token(self) -> TokenInfo | None:
        if not self.llm._custom_price_calculation:
            return None
        if self.response is None:
            return TokenInfo(input=0, output=0, total=0)
        usage_for_price = getattr(self.response, "usage_for_price", None)
        if not isinstance(usage_for_price, CompletionUsageForCustomPriceCalculation):
            cprint("usage_for_price is None; tokens cannot be counted correctly.", color="red", background=True)
            return TokenInfo(input=0, output=0, total=0)
        return TokenInfo(
            input=usage_for_price.prompt_tokens,
            output=usage_for_price.completion_tokens,
            total=usage_for_price.total_tokens,
        )

    @property
    def price(self) -> PriceInfo:
        if self.response is None:
            return PriceInfo(input=0.0, output=0.0, total=0.0)
        if self.llm._custom_price_calculation:
            # For Gemini, custom_token is always expected to be available.
            if self.custom_token is None:
                cprint("custom_token is None; tokens cannot be counted correctly.", color="red", background=True)
            else:
                return PriceInfo(
                    input=self.llm.calculate_price(num_input_tokens=self.custom_token.input),
                    output=self.llm.calculate_price(num_output_tokens=self.custom_token.output),
                    total=self.llm.calculate_price(
                        num_input_tokens=self.custom_token.input, num_output_tokens=self.custom_token.output
                    ),
                )
        return PriceInfo(
            input=self.llm.calculate_price(num_input_tokens=self.token.input),
            output=self.llm.calculate_price(num_output_tokens=self.token.output),
            total=self.llm.calculate_price(num_input_tokens=self.token.input, num_output_tokens=self.token.output),
        )

    @property
    def assistant_message(self) -> Message | None:
        if self.response is None or len(self.response.choices) == 0:
            return None
        return self.response.choices[0].message.to_typeddict_message()

    @property
    def chat_history(self) -> Messages:
        chat_history: Messages = []
        if self.messages:
            chat_history += self.messages
        if self.assistant_message is not None:
            chat_history.append(self.assistant_message)
        return chat_history

    def _print_llm_settings(self) -> None:
        if not ("llm_settings" not in self.silent_set and self.verbose):
            return
        print_llm_settings(
            llm_settings=self.llm_settings,
            model=self.model,
            platform=self.platform,
            engine=self.llm.engine if isinstance(self.llm, AzureLLM) else None,
        )

    def _print_messages(self) -> None:
        if not ("messages" not in self.silent_set and self.verbose):
            return
        print_messages(self.messages, title=True)

    def _print_message_assistant(self) -> None:
        if self.response is None or len(self.response.choices) == 0:
            return
        if not ("messages" not in self.silent_set and self.verbose):
            return
        print_messages(messages=[self.response.choices[0].message], title=False)

    def _print_delta(self, chunk: Chunk) -> None:
        if not ("messages" not in self.silent_set and self.verbose):
            return
        print_delta(chunk)

    def _print_client_settings(self) -> None:
        if not ("client_settings" not in self.silent_set and self.verbose):
            return
        print_client_settings(self.llm.client_settings)

    def __repr__(self) -> str:
        return f"MyLLM({self.__class__.__name__})"
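

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). `EchoMyLLM` is
# a hypothetical subclass showing how `_preprocess` / `_postprocess` are meant
# to be overridden; the commented-out calls assume AbstractMyLLM makes
# instances callable and that valid API credentials are configured.
class EchoMyLLM(MyLLM[str, str]):
    def _preprocess(self, inputs: str) -> Messages:
        # Turn the raw input string into chat messages for the API.
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": inputs},
        ]

    def _postprocess(self, response: Response) -> str:
        # Extract the assistant text from the first choice.
        return response.choices[0].message.content or ""


# my_llm = EchoMyLLM(model="gpt-4o-mini", platform="openai", verbose=True)
# print(my_llm("Hello!"))    # preprocess -> generate -> postprocess
# print(my_llm.price.total)  # price of this single request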