File size: 4,003 Bytes
88435ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Any, Literal, TypedDict

from neollm import MyLLM
from neollm.types import Functions
from neollm.utils.postprocess import json2dict
from neollm.utils.preprocess import optimize_token


class ProfileExtractorInputType(TypedDict):
    """Input payload accepted by ``ProfileExtractor``."""

    # Raw text to extract a profile from; ProfileExtractor strips surrounding
    # whitespace before building the prompt.
    text: str


class ProfileExtractorOuputType(TypedDict):
    """Structured profile returned by ``ProfileExtractor``.

    NOTE: the class name is misspelled ("Ouput"); it is a public name, so it
    is kept as-is for backward compatibility.
    """

    # Person's name; ProfileExtractor emits "" when no name was extracted.
    name: str
    # Year of birth; ProfileExtractor emits -1 when it is unavailable.
    birth_year: int
    # Research domain(s); the function schema requests a comma-separated
    # string. ProfileExtractor emits "" when it is unavailable.
    domain: str
    # Language code of the text; ProfileExtractor falls back to "JPN".
    lang: Literal["ENG", "JPN"]


class ProfileExtractor(MyLLM):
    """MyLLM that extracts a person's profile from free text via function calling.

    Notes:
        inputs:
            {"text": str}
        outputs:
            {"name": str, "birth_year": int, "domain": str, "lang": "ENG" | "JPN"}

            Fields that are missing or cannot be parsed fall back to
            "" (name, domain), -1 (birth_year) and "JPN" (lang).
    """

    def _preprocess(self, inputs: ProfileExtractorInputType):
        """Build the chat messages (system + user) for the extraction request.

        Args:
            inputs: Payload whose ``text`` is the document to analyse.

        Returns:
            A two-element message list: the system instruction followed by the
            user message wrapping the stripped text in ``'''`` quotes.
        """
        # The prompt strings are sent to the model verbatim; do not translate them.
        system_prompt = "<input>より情報を抽出する。存在しない場合nullとする"
        user_prompt = f"<input>\n'''{inputs['text'].strip()}'''"
        return [
            {"role": "system", "content": optimize_token(system_prompt)},
            {"role": "user", "content": optimize_token(user_prompt)},
        ]

    def _check_input(
        self, inputs: ProfileExtractorInputType, messages
    ) -> tuple[bool, ProfileExtractorOuputType | None]:
        """Decide whether to call the LLM and which model to use.

        Args:
            inputs: Original input payload.
            messages: Messages produced by ``_preprocess``.

        Returns:
            ``(False, default_output)`` when the text is blank (no request is
            made), otherwise ``(True, None)``.
        """
        # Blank input: skip the request and return a rule-based default output.
        if not inputs["text"].strip():
            return False, {"name": "", "birth_year": -1, "domain": "", "lang": "JPN"}

        # Switch to the 16k-context model for long prompts.
        # NOTE(review): the threshold is 1600 tokens, not 16000 - presumably
        # intentional to leave headroom, but confirm against the model limits.
        if self.llm.count_tokens(messages) >= 1600:
            self.model = "gpt-3.5-turbo-16k"
        else:
            self.model = "gpt-3.5-turbo"

        # Proceed with the request; no precomputed output.
        return True, None

    def _postprocess(self, response) -> ProfileExtractorOuputType:
        """Convert the raw chat completion into a normalized profile dict.

        Never raises on malformed model output: any field that is missing or
        cannot be coerced to its target type gets its fallback value.

        Args:
            response: Chat completion response; the first choice's message is
                inspected for a ``function_call``.

        Returns:
            The extracted profile with every key populated.
        """
        message = dict(response["choices"][0]["message"])
        function_call = message.get("function_call")

        extracted_data: Any = {}
        if function_call:
            try:
                extracted_data = json2dict(function_call["arguments"])
            except Exception:
                # Deliberate best-effort: the arguments are model-generated
                # and may be malformed; fall back to defaults instead of failing.
                extracted_data = {}
        # Guard against valid JSON that is not an object (e.g. a list or string).
        if not isinstance(extracted_data, dict):
            extracted_data = {}

        # Anything other than the two supported codes falls back to "JPN".
        lang: Literal["ENG", "JPN"] = "ENG" if extracted_data.get("lang") == "ENG" else "JPN"

        # The model may return a non-numeric year (e.g. "unknown"); treat it
        # as missing rather than letting int() raise.
        try:
            birth_year = int(extracted_data.get("birth_year") or -1)
        except (TypeError, ValueError, OverflowError):
            birth_year = -1

        outputs: ProfileExtractorOuputType = {
            "name": str(extracted_data.get("name") or ""),
            "birth_year": birth_year,
            "domain": str(extracted_data.get("domain") or ""),
            "lang": lang,
        }
        return outputs

    # Required when using function calling.
    def _add_functions(self, inputs: Any) -> Functions | None:
        """Return the function schema the model fills in with the profile.

        Args:
            inputs: Unused; the schema does not depend on the input.

        Returns:
            A single-function list describing ``extract_profile``.
        """
        # The description strings are sent to the model verbatim; do not translate them.
        functions: Functions = [
            {
                "name": "extract_profile",
                "description": "extract profile of a person",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "description": "名前",
                        },
                        "domain": {
                            "type": "string",
                            "description": "研究ドメイン カンマ区切り",
                        },
                        "birth_year": {
                            "type": "integer",
                            "description": "the year of the birth YYYY",
                        },
                        "lang": {
                            "type": "string",
                            "description": "the language of the text",
                            "enum": ["ENG", "JPN"],
                        },
                    },
                    "required": ["name", "birth_year", "domain", "lang"],
                },
            }
        ]
        return functions

    # Overridden only to give callers precise input/output types.
    def __call__(self, inputs: ProfileExtractorInputType) -> ProfileExtractorOuputType:
        """Run the extraction pipeline and return the typed profile."""
        outputs: ProfileExtractorOuputType = super().__call__(inputs)
        return outputs