import tiktoken

from neollm.llm.abstract_llm import AbstractLLM
from neollm.types import (
    ChatCompletion,
    ChatCompletionChunk,
    Message,
    Messages,
    OpenAIMessages,
    OpenAIResponse,
    OpenAIStreamResponse,
    Response,
    StreamResponse,
)


class AbstractGPT(AbstractLLM):
    def encode(self, text: str) -> list[int]:
        """Encode text into token ids with the tiktoken encoding for self.model."""
        tokenizer = tiktoken.encoding_for_model(self.model or "gpt-3.5-turbo")
        return tokenizer.encode(text)

    def decode(self, encoded: list[int]) -> str:
        """Decode token ids back into text with the tiktoken encoding for self.model."""
        tokenizer = tiktoken.encoding_for_model(self.model or "gpt-3.5-turbo")
        return tokenizer.decode(encoded)
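    # A minimal round-trip sketch (assuming a concrete subclass instance `llm`):
    #   ids = llm.encode("hello")   # list of token ids
    #   assert llm.decode(ids) == "hello"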

    def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
        """
        Count the number of tokens in messages.

        Args:
            messages (list[Message] | None): messages to count; None counts as 0 tokens.
            only_response (bool): if True, messages must contain exactly one message and
                only its content tokens are counted (no per-message or reply-priming overhead).

        Returns:
            int: number of tokens
        """
        if messages is None:
            return 0

        num_tokens: int = 0
        for message in messages:
            # per message: fixed wrapping overhead of ~4 tokens
            # (<|start|>{role/name}\n{content}<|end|>\n)
            num_tokens += 4
            # content: None contributes nothing; str and list parts are encoded
            content = message.get("content", None)
            if isinstance(content, str):
                num_tokens += len(self.encode(content))
            elif isinstance(content, list):
                for content_params in content:
                    if content_params["type"] == "text":
                        num_tokens += len(self.encode(content_params["text"]))
            # TODO: ChatCompletionFunctionMessageParam.name
            # tokens_per_name = 1
            # tool calls ------------------------------------------------
            # TODO: ChatCompletionAssistantMessageParam.function_call
            # TODO: ChatCompletionAssistantMessageParam.tool_calls

        if only_response:
            if len(messages) != 1:
                raise ValueError("When only_response=True, messages must contain exactly one message.")
            num_tokens -= 4  # drop the per-message overhead
        else:
            num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>

        return num_tokens
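    # Worked sketch of the formula above (hypothetical single-message call):
    #   llm.count_tokens([{"role": "user", "content": text}])
    #   == 4 (per-message overhead) + len(llm.encode(text)) + 3 (reply priming)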

    def _convert_to_response(self, platform_response: OpenAIResponse) -> Response:
        return ChatCompletion(**platform_response.model_dump())

    def _convert_to_platform_messages(self, messages: Messages) -> OpenAIMessages:
        # Messages already follow the OpenAI format by default, so no conversion is needed.
        platform_messages: OpenAIMessages = messages
        return platform_messages

    def _convert_to_streamresponse(self, platform_streamresponse: OpenAIStreamResponse) -> StreamResponse:
        for chunk in platform_streamresponse:
            yield ChatCompletionChunk(**chunk.model_dump())
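    # Note: this is a generator; each OpenAI chunk is re-wrapped as a
    # ChatCompletionChunk and yielded lazily, so callers can stream without
    # buffering the whole response.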