File size: 4,361 Bytes
88435ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generator, Generic, Optional, TypeAlias, cast

from neollm.myllm.print_utils import print_inputs, print_metadata, print_outputs
from neollm.types import (
    InputType,
    OutputType,
    PriceInfo,
    StreamOutputType,
    TimeInfo,
    TokenInfo,
)
from neollm.utils.utils import cprint

# Type-checking-only section: TYPE_CHECKING is False at runtime, so nothing in
# this block is executed then and `_MyL3M2` may only appear inside string
# (forward-reference) annotations.
# NOTE(review): presumably guarded to avoid a circular import with
# neollm.myllm.myl3m2 — confirm.
if TYPE_CHECKING:
    from typing import Any

    from neollm.myllm.myl3m2 import MyL3M2

    # Alias for MyL3M2 with both of its type parameters set to Any.
    _MyL3M2: TypeAlias = MyL3M2[Any, Any]


class AbstractMyLLM(ABC, Generic[InputType, OutputType]):
    """Abstract base class for MyLLM and MyL3M2.

    Subclasses implement ``token``, ``price`` and ``_call``. This class
    provides the public entry points (``__call__`` and ``call_stream``) that
    drive the generator returned by ``_call``, plus small console-printing
    helpers controlled by ``verbose`` and ``silent_set``.
    """

    # Last inputs / outputs; the print helpers skip printing while these are None.
    inputs: InputType | None
    outputs: OutputType | None
    # Section names ("inputs", "outputs", "metadata") whose printing is suppressed.
    silent_set: set[str]
    # Nothing is printed by the helpers below unless this is True.
    verbose: bool
    # Value reported by _print_metadata alongside token and price.
    time: float = 0.0
    # NOTE(review): this default instance is created once at class-definition
    # time and is shared by every instance that does not assign its own
    # `time_detail`; if TimeInfo is mutable, confirm subclasses rebind it.
    time_detail: TimeInfo = TimeInfo()
    # Owning MyL3M2, if any; when None, _print_start prints a "PARENT" marker.
    parent: Optional["_MyL3M2"] = None
    # Passed as the `stream` flag to _call when the object is called directly.
    do_stream: bool

    @property
    @abstractmethod
    def token(self) -> TokenInfo:
        """Number of tokens used by the LLM.

        Returns:
            TokenInfo: token counts (input, output, total)
            >>> TokenInfo(input=1588, output=128, total=1716)
        """

    @property
    def custom_token(self) -> TokenInfo | None:
        """Token counts used for price calculation (for Gemini).

        Returns:
            TokenInfo | None: always None in this base implementation.
        """
        return None

    @property
    @abstractmethod
    def price(self) -> PriceInfo:
        """Usage price of the LLM (USD).

        Returns:
            PriceInfo: usage price in USD (input, output, total)
            >>> PriceInfo(input=0.002382, output=0.000256, total=0.002638)
        """

    @abstractmethod
    def _call(self, inputs: InputType, stream: bool = False) -> Generator[StreamOutputType, None, OutputType]:
        """Main logic to be implemented by MyLLM subclasses.

        Both the streaming and the non-streaming code paths must be written.

        Args:
            inputs (InputType): input to the LLM
            stream (bool, optional): whether to stream. Defaults to False.

        Yields:
            Generator[StreamOutputType, None, OutputType]: streamed LLM output

        Returns:
            OutputType: output of the LLM
        """

    def __call__(self, inputs: InputType) -> OutputType:
        """Run the LLM and return its final output.

        The generator returned by ``_call`` is driven to completion; any
        yielded stream chunks are discarded and only the generator's return
        value is handed back. Exceptions raised by ``_call`` propagate
        unchanged.

        Args:
            inputs (InputType): input to the LLM

        Returns:
            OutputType: output of the LLM
        """
        stream_gen: Generator[StreamOutputType, None, OutputType] = self._call(inputs, stream=self.do_stream)
        try:
            # Exhaust the generator; its return value arrives on StopIteration.
            while True:
                next(stream_gen)
        except StopIteration as stop:
            return cast(OutputType, stop.value)

    def call_stream(self, inputs: InputType) -> Generator[StreamOutputType, None, OutputType]:
        """Run the LLM in streaming mode.

        Delegates to ``_call(inputs, stream=True)``: every chunk it yields is
        re-yielded to the caller, and its return value becomes this
        generator's return value. Exceptions raised by ``_call`` propagate
        unchanged.

        Args:
            inputs (InputType): input to the LLM

        Yields:
            Generator[StreamOutputType, None, OutputType]: streamed LLM output

        Returns:
            Output of the LLM
        """
        # `yield from` also forwards close()/throw() to the inner generator.
        final_outputs = yield from self._call(inputs, stream=True)
        return cast(OutputType, final_outputs)

    def _print_inputs(self) -> None:
        """Print ``self.inputs`` unless it is None, silenced, or verbose is off."""
        if self.inputs is None:
            return
        if "inputs" in self.silent_set or not self.verbose:
            return
        print_inputs(self.inputs)

    def _print_outputs(self) -> None:
        """Print ``self.outputs`` unless it is None, silenced, or verbose is off."""
        if self.outputs is None:
            return
        if "outputs" in self.silent_set or not self.verbose:
            return
        print_outputs(self.outputs)

    def _print_metadata(self) -> None:
        """Print time, token and price unless silenced or verbose is off."""
        if "metadata" in self.silent_set or not self.verbose:
            return
        print_metadata(self.time, self.token, self.price)

    def _print_start(self, sep: str = "-") -> None:
        """Print the header line for this object when verbose is on.

        A "PARENT" marker is printed first when ``self.parent`` is None.

        Args:
            sep (str, optional): character used for the rule. Defaults to "-".
        """
        if not self.verbose:
            return
        if self.parent is None:
            cprint("PARENT", color="red", background=True)
        # str(self) + one space + padding gives 100 columns; a negative
        # repeat count (long str(self)) simply yields no padding.
        label = str(self)
        print(self, sep * (99 - len(label)))

    def _print_end(self, sep: str = "-") -> None:
        """Print a 100-character closing rule when verbose is on.

        Args:
            sep (str, optional): character used for the rule. Defaults to "-".
        """
        if not self.verbose:
            return
        print(sep * 100)