File size: 4,489 Bytes
88435ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from copy import deepcopy
from typing import Literal, cast, get_args

import vertexai
from vertexai.generative_models import GenerationConfig

from neollm.llm.abstract_llm import AbstractLLM
from neollm.llm.gemini.abstract_gemini import AbstractGemini
from neollm.types import APIPricing, ClientSettings, LLMSettings, StreamResponse
from neollm.types.mytypes import Messages, Response
from neollm.utils.utils import cprint

# price: https://ai.google.dev/pricing?hl=ja
# models: https://ai.google.dev/gemini-api/docs/models/gemini?hl=ja

# Model names that `get_gcp_llm` knows how to build.
SUPPORTED_MODELS = Literal[
    "gemini-1.0-pro",
    "gemini-1.0-pro-vision",
    "gemini-1.5-pro-preview-0409",
]

# Keys that `GoogleLLM.generate_config` pops from `llm_settings`.
AVAILABLE_CONFIG_VARIABLES = [
    "candidate_count",
    "stop_sequences",
    "temperature",
    "max_tokens",  # neollm-style alias; used only when "max_output_tokens" is not set
    "max_output_tokens",
    "top_p",
    "top_k",
]


def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Build the Vertex AI (GCP) Gemini wrapper for ``model_name``.

    Args:
        model_name: One of the names listed in ``SUPPORTED_MODELS``.
        client_settings: Passed as keyword arguments to ``vertexai.init`` and
            positionally to the selected wrapper's constructor.

    Returns:
        A newly constructed wrapper for the requested model.

    Raises:
        ValueError: If ``model_name`` is not a supported model. In that case
            ``vertexai.init`` is not called.
    """
    # Map names to classes, not instances, so only the requested wrapper is
    # constructed. The original built all three on every call.
    supported_model_map: dict[SUPPORTED_MODELS, type[AbstractLLM]] = {
        "gemini-1.0-pro": GCPGemini10Pro,
        "gemini-1.0-pro-vision": GCPGemini10ProVision,
        "gemini-1.5-pro-preview-0409": GCPGemini15Pro0409,
    }
    if model_name not in supported_model_map:
        raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")

    # Initialize the SDK only after the name has been validated, and before the
    # wrapper is constructed (same ordering as before for valid names).
    vertexai.init(**client_settings)

    model_class = supported_model_map[cast(SUPPORTED_MODELS, model_name)]
    return model_class(client_settings)


class GoogleLLM(AbstractGemini):
    """Base class for the Vertex AI Gemini wrappers in this module.

    Its only override here is ``generate_config``. That method converts neollm
    ``llm_settings`` into a ``vertexai.generative_models.GenerationConfig``.
    """

    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
        """Build a ``GenerationConfig`` from ``llm_settings``.

        Reference: https://ai.google.dev/api/rest/v1/GenerationConfig?hl=ja

        Recognized keys are ``candidate_count``, ``stop_sequences``,
        ``temperature``, ``max_output_tokens``, ``top_p`` and ``top_k``.
        ``max_tokens`` is also accepted; it is used as ``max_output_tokens``
        when the latter is absent or ``None``.

        Recognized keys are popped from ``llm_settings`` in place, including
        ``max_tokens``.

        Args:
            llm_settings: Generation settings. Missing keys default to ``None``.

        Returns:
            A ``GenerationConfig`` populated from the recognized keys.

        Raises:
            ValueError: If ``llm_settings`` contains any unrecognized key.
        """
        # Gemini-native generation parameters.
        candidate_count = llm_settings.pop("candidate_count", None)
        stop_sequences = llm_settings.pop("stop_sequences", None)
        temperature = llm_settings.pop("temperature", None)
        max_output_tokens = llm_settings.pop("max_output_tokens", None)
        top_p = llm_settings.pop("top_p", None)
        top_k = llm_settings.pop("top_k", None)

        # Also accept neollm's `max_tokens`. Always consume it, even when
        # `max_output_tokens` is set; otherwise a leftover `max_tokens` would
        # disable the unknown-key check below.
        max_tokens = llm_settings.pop("max_tokens", None)
        if max_output_tokens is None:
            max_output_tokens = max_tokens

        if llm_settings:
            raise ValueError(f"llm_settings has unknown keys: {llm_settings}")

        return GenerationConfig(
            candidate_count=candidate_count,
            stop_sequences=stop_sequences,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            top_k=top_k,
        )


class GCPGemini10Pro(GoogleLLM):
    """Vertex AI wrapper for the ``gemini-1.0-pro`` model.

    NOTE(review): confirm the ``dollar_per_ktoken`` figures against the current
    Vertex AI price list. Whether the unit is tokens or characters is not
    evident from this file.
    """

    model: str = "gemini-1.0-pro"
    context_window: int = 32_000
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)


class GCPGemini10ProVision(GoogleLLM):
    """Vertex AI wrapper for the ``gemini-1.0-pro-vision`` model.

    Before delegating to the parent ``generate`` / ``generate_stream``, a
    leading ``system`` message is folded into the following message. This is
    presumably because this model has no system-instruction support; confirm.
    """

    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro-vision"
    context_window: int = 32_000

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Fold any leading system message, then delegate to the parent ``generate``."""
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate(messages, llm_settings)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        """Fold any leading system message, then delegate to the parent ``generate_stream``."""
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate_stream(messages, llm_settings)

    def _preprocess_message_to_use_system(self, message: Messages) -> Messages:
        """Merge a leading system message into the next message.

        If ``message`` is empty or its first entry's role is not ``"system"``,
        the same object is returned unchanged. Otherwise the system entry is
        removed from a deep copy. Its string content is prepended, with no
        separator, to the next message:

        * directly, when that message's ``content`` is a ``str``;
        * to the first part's ``"text"``, when ``content`` is a list whose
          first part is a dict with a string ``"text"``.

        If neither applies, the system content is not a ``str``, or no message
        follows it, a warning is printed. The system message is then dropped.

        Args:
            message: Chat messages; each entry has ``"role"`` and ``"content"``.

        Returns:
            The original list, or a preprocessed deep copy.
        """
        if not message or message[0]["role"] != "system":
            return message

        preprocessed_message = deepcopy(message)
        system = preprocessed_message.pop(0)["content"]

        if isinstance(system, str) and preprocessed_message:
            content = preprocessed_message[0]["content"]
            if isinstance(content, str):
                # Plain-text content: previously this case dropped the system prompt.
                preprocessed_message[0]["content"] = system + content
                return preprocessed_message
            first_part = content[0] if isinstance(content, list) and content else None
            if isinstance(first_part, dict) and isinstance(first_part.get("text"), str):
                first_part["text"] = system + first_part["text"]
                return preprocessed_message

        # Runtime message (Japanese): "the input format is invalid".
        cprint("WARNING: 入力形式が不正です", color="yellow", background=True)
        return preprocessed_message


class GCPGemini15Pro0409(GoogleLLM):
    """Vertex AI wrapper for the ``gemini-1.5-pro-preview-0409`` model.

    NOTE(review): confirm the ``dollar_per_ktoken`` figures against the current
    Vertex AI price list. Whether the unit is tokens or characters is not
    evident from this file.
    """

    model: str = "gemini-1.5-pro-preview-0409"
    context_window: int = 1_000_000
    dollar_per_ktoken = APIPricing(input=2.5 / 1000, output=7.5 / 1000)