# Source: neo-llm-module-v1.3.5 / project/ex_module/ex_profile_extractor.py
# Uploaded by Kpenciler ("Upload 53 files", commit 88435ed)
from typing import Any, Literal, TypedDict
from neollm import MyLLM
from neollm.types import Functions
from neollm.utils.postprocess import json2dict
from neollm.utils.preprocess import optimize_token
class ProfileExtractorInputType(TypedDict, total=True):
    """Input payload accepted by ProfileExtractor."""

    # Raw text from which the profile information is extracted.
    text: str
class ProfileExtractorOuputType(TypedDict, total=True):
    """Profile returned by ProfileExtractor.

    The (misspelled) class name is kept as-is because it is public API.
    ProfileExtractor fills values it cannot extract with "" (name/domain),
    -1 (birth_year) and "JPN" (lang).
    """

    # Person's name.
    name: str
    # Year of birth as YYYY.
    birth_year: int
    # Research domain(s), comma-separated as requested by the function schema.
    domain: str
    # Language of the source text.
    lang: Literal["ENG", "JPN"]
class ProfileExtractor(MyLLM):
    """MyLLM that extracts a person's profile from free-form text.

    The model is asked, through the ``extract_profile`` function-calling schema,
    for the person's name, birth year, research domain and the language of the
    text.

    Notes:
        inputs:
            {"text": str}
        outputs:
            {"name": str, "birth_year": int, "domain": str, "lang": "ENG" | "JPN"}
            Values that cannot be extracted fall back to "" (name/domain),
            -1 (birth_year) and "JPN" (lang).
    """

    def _preprocess(self, inputs: ProfileExtractorInputType):
        """Build the system/user chat messages for the extraction request.

        Args:
            inputs: Mapping holding the raw ``text`` to extract from.

        Returns:
            A two-element list of ``{"role", "content"}`` message dicts, each
            content passed through neollm's ``optimize_token``. The (Japanese)
            system prompt says "extract information from <input>; use null when
            it does not exist". The user prompt wraps the stripped text in triple
            single quotes after an ``<input>`` tag.
        """
        system_prompt = "<input>より情報を抽出する。存在しない場合nullとする"
        user_prompt = "<input>\n" f"'''{inputs['text'].strip()}'''"
        messages = [
            {"role": "system", "content": optimize_token(system_prompt)},
            {"role": "user", "content": optimize_token(user_prompt)},
        ]
        return messages

    def _check_input(
        self, inputs: ProfileExtractorInputType, messages
    ) -> tuple[bool, ProfileExtractorOuputType | None]:
        """Decide whether a request is needed and choose the model.

        Args:
            inputs: The original inputs.
            messages: Messages built by ``_preprocess``.

        Returns:
            ``(False, default_profile)`` when the text is blank: no request is
            sent and a rule-based empty profile is returned.
            ``(True, None)`` otherwise, after setting ``self.model``.
        """
        # Blank input: do not send a request; return a rule-based empty profile.
        if not inputs["text"].strip():
            return False, {"name": "", "birth_year": -1, "domain": "", "lang": "JPN"}
        # Use the 16k-context model for long prompts.
        # NOTE(review): count_tokens(messages) presumably does not include the
        # function-schema tokens — confirm the 1600 threshold leaves enough room.
        if self.llm.count_tokens(messages) >= 1600:
            self.model = "gpt-3.5-turbo-16k"
        else:
            self.model = "gpt-3.5-turbo"
        return True, None

    @staticmethod
    def _to_birth_year(value: Any) -> int:
        """Coerce the model's ``birth_year`` to int; return -1 if missing or invalid."""
        # bool is an int subclass; treat True/False as "not provided".
        if isinstance(value, bool) or not value:
            return -1
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # e.g. "unknown", "1990年", a list, or an infinite float.
            return -1

    def _postprocess(self, response) -> ProfileExtractorOuputType:
        """Turn the chat-completion response into a ProfileExtractorOuputType.

        Reads the ``function_call.arguments`` of the first choice. Missing,
        malformed or wrongly-typed values fall back to the same defaults used for
        blank input: "" for name/domain, -1 for birth_year and "JPN" for lang.

        Args:
            response: Chat-completion response, indexed as
                ``response["choices"][0]["message"]``.

        Returns:
            The normalized profile dict.
        """
        message = response["choices"][0]["message"]
        extracted_data: Any = {}
        if dict(message).get("function_call"):
            try:
                extracted_data = json2dict(message["function_call"]["arguments"])
            except Exception:
                # Deliberate best effort: unparseable arguments yield the defaults.
                extracted_data = {}
        # json2dict's return type is not visible here; a non-dict result (e.g. a
        # JSON list) would otherwise crash on .get() below.
        if not isinstance(extracted_data, dict):
            extracted_data = {}

        # Tuple (not set) membership so an unhashable value such as a list falls
        # back to the default instead of raising TypeError.
        lang = extracted_data.get("lang")
        if lang not in ("ENG", "JPN"):
            lang = "JPN"

        outputs: ProfileExtractorOuputType = {
            "name": str(extracted_data.get("name") or ""),
            "birth_year": self._to_birth_year(extracted_data.get("birth_year")),
            "domain": str(extracted_data.get("domain") or ""),
            "lang": lang,
        }
        return outputs

    # Required when using function calling.
    def _add_functions(self, inputs: Any) -> Functions | None:
        """Return the ``extract_profile`` function schema sent with the request.

        All four fields are marked required; ``lang`` is restricted to
        "ENG"/"JPN". The schema does not depend on ``inputs``.
        """
        functions: Functions = [
            {
                "name": "extract_profile",
                "description": "extract profile of a person",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "名前",
                        },
                        "domain": {
                            "type": "string",
                            "description": "研究ドメイン カンマ区切り",
                        },
                        "birth_year": {
                            "type": "integer",
                            "description": "the year of the birth YYYY",
                        },
                        "lang": {
                            "type": "string",
                            "description": "the language of the text",
                            "enum": ["ENG", "JPN"],
                        },
                    },
                    "required": ["name", "birth_year", "domain", "lang"],
                },
            }
        ]
        return functions

    # Overridden only to give callers precise input/output types.
    def __call__(self, inputs: ProfileExtractorInputType) -> ProfileExtractorOuputType:
        """Run the extraction and return the profile (typed wrapper over MyLLM)."""
        outputs: ProfileExtractorOuputType = super().__call__(inputs)
        return outputs