# Spaces: Configuration error
import logging
from typing import Any, Literal, TypedDict

from neollm import MyLLM
from neollm.types import Functions
from neollm.utils.postprocess import json2dict
from neollm.utils.preprocess import optimize_token
class ProfileExtractorInputType(TypedDict):
    """Input payload accepted by ``ProfileExtractor``."""

    # Free text from which the profile is extracted.
    text: str
class ProfileExtractorOuputType(TypedDict):
    """Profile produced by ``ProfileExtractor``.

    The (misspelled) class name is kept as-is because other code may
    reference it.
    """

    # Person's name; "" when nothing was extracted.
    name: str
    # Year of birth; -1 when nothing was extracted.
    birth_year: int
    # Research domain(s), requested from the model as comma-separated text;
    # "" when nothing was extracted.
    domain: str
    # Language of the analysed text; "JPN" is the fallback.
    lang: Literal["ENG", "JPN"]
class ProfileExtractor(MyLLM):
    """MyLLM that extracts a person's profile from text via function calling.

    Notes:
        inputs:
            >>> {"text": str}
        outputs:
            >>> {"name": str, "birth_year": int, "domain": str, "lang": "ENG" or "JPN"}

        Values that could not be extracted fall back to ``""`` for ``name``
        and ``domain``, ``-1`` for ``birth_year`` and ``"JPN"`` for ``lang``.
    """

    def _preprocess(self, inputs: ProfileExtractorInputType):
        """Build the chat messages for the request.

        Args:
            inputs: Mapping whose ``"text"`` entry is the text to analyse.

        Returns:
            A list with a system message and a user message; the user
            message wraps the stripped text in triple single quotes, and both
            contents are passed through ``optimize_token``.
        """
        system_prompt = "<input>より情報を抽出する。存在しない場合nullとする"
        user_prompt = f"<input>\n'''{inputs['text'].strip()}'''"
        return [
            {"role": "system", "content": optimize_token(system_prompt)},
            {"role": "user", "content": optimize_token(user_prompt)},
        ]

    def _check_input(
        self, inputs: ProfileExtractorInputType, messages
    ) -> tuple[bool, ProfileExtractorOuputType | None]:
        """Decide whether to send a request and pick the model to use.

        Args:
            inputs: The raw input mapping.
            messages: The messages built by ``_preprocess``.

        Returns:
            ``(False, default_output)`` when the text is blank, so no request
            is made; otherwise ``(True, None)``. In the latter case
            ``self.model`` is set as a side effect.
        """
        # Blank input: skip the request and return a rule-based default.
        if inputs["text"].strip() == "":
            return False, {"name": "", "birth_year": -1, "domain": "", "lang": "JPN"}

        # Use the 16k-context model for long prompts.
        # NOTE(review): the cut-off is 1600 tokens; confirm this is intended
        # rather than a larger value.
        if self.llm.count_tokens(messages) >= 1600:
            self.model = "gpt-3.5-turbo-16k"
        else:
            self.model = "gpt-3.5-turbo"
        return True, None

    def _postprocess(self, response) -> ProfileExtractorOuputType:
        """Convert the model response into a normalised profile.

        Never raises on malformed model output: anything that is missing or
        cannot be parsed is replaced by the field's fallback value.

        Args:
            response: Chat-completion response; the first choice's message is
                inspected for a ``function_call`` entry.

        Returns:
            The extracted profile with every key present.
        """
        message = dict(response["choices"][0]["message"])
        function_call = message.get("function_call")

        extracted_data: dict = {}
        if function_call:
            try:
                parsed = json2dict(function_call["arguments"])
            except Exception:
                # Best effort: a malformed function call must not abort the
                # caller, but it should not vanish silently either.
                logging.getLogger(__name__).warning(
                    "Could not parse function_call arguments; using default profile",
                    exc_info=True,
                )
                parsed = None
            # Guard against a parser result that is valid JSON but not an
            # object (e.g. a list), which has no ``.get``.
            if isinstance(parsed, dict):
                extracted_data = parsed

        # Tuple membership (not a set) so an unhashable value such as a list
        # simply fails the check instead of raising TypeError.
        lang_ = extracted_data.get("lang")
        lang = lang_ if lang_ in ("ENG", "JPN") else "JPN"

        # The model may return a non-numeric year ("unknown", "1990年", ...)
        # or a non-scalar; treat any of those as "not extracted".
        try:
            birth_year = int(extracted_data.get("birth_year") or -1)
        except (TypeError, ValueError, OverflowError):
            birth_year = -1

        outputs: ProfileExtractorOuputType = {
            "name": str(extracted_data.get("name") or ""),
            "birth_year": birth_year,
            "domain": str(extracted_data.get("domain") or ""),
            "lang": lang,
        }
        return outputs

    # Required when using function calling.
    def _add_functions(self, inputs: Any) -> Functions | None:
        """Return the function schema the model is asked to fill in.

        Args:
            inputs: Unused; the schema does not depend on the input.

        Returns:
            A single-entry list describing ``extract_profile`` with the
            required fields ``name``, ``birth_year``, ``domain`` and ``lang``.
        """
        functions: Functions = [
            {
                "name": "extract_profile",
                "description": "extract profile of a person",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "名前",
                        },
                        "domain": {
                            "type": "string",
                            "description": "研究ドメイン カンマ区切り",
                        },
                        "birth_year": {
                            "type": "integer",
                            "description": "the year of the birth YYYY",
                        },
                        "lang": {
                            "type": "string",
                            "description": "the language of the text",
                            "enum": ["ENG", "JPN"],
                        },
                    },
                    "required": ["name", "birth_year", "domain", "lang"],
                },
            }
        ]
        return functions

    # Overridden only to give callers precise input/output types.
    def __call__(self, inputs: ProfileExtractorInputType) -> ProfileExtractorOuputType:
        """Run the extraction by delegating to ``MyLLM.__call__``."""
        outputs: ProfileExtractorOuputType = super().__call__(inputs)
        return outputs