from typing import Any, Literal, TypedDict

from neollm import MyLLM
from neollm.types import Functions
from neollm.utils.postprocess import json2dict
from neollm.utils.preprocess import optimize_token

# Inputs whose prompt reaches this many tokens are sent to the 16k-context model.
# NOTE(review): threshold kept from the original; confirm it is the intended cut-off.
_LONG_INPUT_TOKEN_THRESHOLD = 1600


class ProfileExtractorInputType(TypedDict):
    """Input of ProfileExtractor."""

    text: str  # free text from which the profile is extracted


class ProfileExtractorOuputType(TypedDict):
    """Output of ProfileExtractor.

    Missing or invalid values are normalized to "" (name/domain), -1 (birth_year)
    and "JPN" (lang).
    """

    # NOTE: the "Ouput" spelling is kept so existing importers keep working.
    name: str
    birth_year: int
    domain: str
    lang: Literal["ENG", "JPN"]


def _default_output() -> ProfileExtractorOuputType:
    """Return a fresh "nothing extracted" output."""
    return {"name": "", "birth_year": -1, "domain": "", "lang": "JPN"}


def _parse_function_call_arguments(message: Any) -> dict[str, Any]:
    """Return the function-call arguments of a chat message as a dict, or {} on any failure.

    Parsing is best-effort: a missing function call, unparsable JSON or a
    non-object payload all yield an empty dict instead of raising.
    """
    function_call = dict(message).get("function_call")
    if not function_call:
        return {}
    try:
        extracted = json2dict(function_call["arguments"])
    except Exception:
        # json2dict's failure modes are not visible here; treat any failure as "no data".
        return {}
    # The model may return valid JSON that is not an object (e.g. a list or a string).
    return extracted if isinstance(extracted, dict) else {}


def _to_birth_year(value: Any) -> int:
    """Coerce a model-provided birth year to int; return -1 if missing or unparsable."""
    # bool is an int subclass, but True/False is never a meaningful year.
    if value is None or isinstance(value, bool):
        return -1
    try:
        year = int(value)
    except (TypeError, ValueError, OverflowError):
        # e.g. "1990年", "unknown", a list, or float('inf') from JSON "Infinity"
        return -1
    # 0 is treated as "not present", matching the original `or -1` fallback.
    return year or -1


class ProfileExtractor(MyLLM):
    """Extract a person's profile (name, birth year, research domain, language) from text.

    The profile is obtained through function calling (`extract_profile`).

    Notes:
        inputs:
            {"text": str}
        outputs:
            {"name": str, "birth_year": int, "domain": str, "lang": "ENG" | "JPN"}
            Missing or invalid values fall back to "", -1, "" and "JPN" respectively.
    """

    def _preprocess(self, inputs: ProfileExtractorInputType):
        """Build the chat messages (system + user) for the extraction request.

        Both message contents are passed through `optimize_token`.
        """
        # NOTE(review): the system prompt starts with "より" ("from ...") and the user
        # prompt is only a newline before the quoted text, which suggests a subject or
        # label was lost. Confirm the intended wording before changing these strings.
        system_prompt = "より情報を抽出する。存在しない場合nullとする"
        user_prompt = "\n" f"'''{inputs['text'].strip()}'''"
        messages = [
            {"role": "system", "content": optimize_token(system_prompt)},
            {"role": "user", "content": optimize_token(user_prompt)},
        ]
        return messages

    def _check_input(
        self, inputs: ProfileExtractorInputType, messages
    ) -> tuple[bool, ProfileExtractorOuputType | None]:
        """Decide whether to call the LLM, and select the model.

        Returns:
            (False, default output) for blank input, so no request is made.
            Otherwise (True, None).

        Side effects:
            Sets `self.model` to the 16k-context model when the prompt has at
            least `_LONG_INPUT_TOKEN_THRESHOLD` tokens, else the standard model.
        """
        # Blank input: skip the request and answer with a rule-based default.
        if not inputs["text"].strip():
            return False, _default_output()

        # Long inputs need the larger-context model.
        if self.llm.count_tokens(messages) >= _LONG_INPUT_TOKEN_THRESHOLD:
            self.model = "gpt-3.5-turbo-16k"
        else:
            self.model = "gpt-3.5-turbo"

        # Make the request; there is no precomputed output.
        return True, None

    def _postprocess(self, response) -> ProfileExtractorOuputType:
        """Convert the LLM response into a normalized ProfileExtractorOuputType.

        Never raises on malformed model output. Each field falls back to the
        default ("" / -1 / "JPN") when it is missing or has the wrong form.
        """
        extracted_data = _parse_function_call_arguments(response["choices"][0]["message"])

        # Compare with == rather than set membership: the model may return an
        # unhashable value (list/dict), which would make `in {...}` raise TypeError.
        lang_ = extracted_data.get("lang")
        lang: Literal["ENG", "JPN"] = "ENG" if lang_ == "ENG" else "JPN"

        outputs: ProfileExtractorOuputType = {
            "name": str(extracted_data.get("name") or ""),
            "birth_year": _to_birth_year(extracted_data.get("birth_year")),
            "domain": str(extracted_data.get("domain") or ""),
            "lang": lang,
        }
        return outputs

    # Required when using Function Calling.
    def _add_functions(self, inputs: Any) -> Functions | None:
        """Return the `extract_profile` function schema offered to the model."""
        functions: Functions = [
            {
                "name": "extract_profile",
                "description": "extract profile of a person",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "名前",
                        },
                        "domain": {
                            "type": "string",
                            "description": "研究ドメイン カンマ区切り",
                        },
                        "birth_year": {
                            "type": "integer",
                            "description": "the year of the birth YYYY",
                        },
                        "lang": {
                            "type": "string",
                            "description": "the language of the text",
                            "enum": ["ENG", "JPN"],
                        },
                    },
                    "required": ["name", "birth_year", "domain", "lang"],
                },
            }
        ]
        return functions

    # Overridden only to give callers precise input/output types.
    def __call__(self, inputs: ProfileExtractorInputType) -> ProfileExtractorOuputType:
        """Run the extraction and return the typed profile."""
        outputs: ProfileExtractorOuputType = super().__call__(inputs)
        return outputs