from copy import deepcopy
from typing import Literal, cast, get_args
import vertexai
from vertexai.generative_models import GenerationConfig
from neollm.llm.abstract_llm import AbstractLLM
from neollm.llm.gemini.abstract_gemini import AbstractGemini
from neollm.types import APIPricing, ClientSettings, LLMSettings, StreamResponse
from neollm.types.mytypes import Messages, Response
from neollm.utils.utils import cprint
# price: https://ai.google.dev/pricing?hl=ja
# models: https://ai.google.dev/gemini-api/docs/models/gemini?hl=ja
SUPPORTED_MODELS = Literal["gemini-1.0-pro", "gemini-1.0-pro-vision", "gemini-1.5-pro-preview-0409"]
AVAILABLE_CONFIG_VARIABLES = [
    "candidate_count",
    "stop_sequences",
    "temperature",
    "max_tokens",  # used as a fallback when "max_output_tokens" is not set
    "max_output_tokens",
    "top_p",
    "top_k",
]

def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    vertexai.init(**client_settings)
    # Map model names to wrapper classes; only the requested model is instantiated.
    supported_model_map: dict[SUPPORTED_MODELS, type[GoogleLLM]] = {
        "gemini-1.0-pro": GCPGemini10Pro,
        "gemini-1.0-pro-vision": GCPGemini10ProVision,
        "gemini-1.5-pro-preview-0409": GCPGemini15Pro0409,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be one of {get_args(SUPPORTED_MODELS)}, but got {model_name!r}.")

class GoogleLLM(AbstractGemini):
    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
        """
        Reference: https://ai.google.dev/api/rest/v1/GenerationConfig?hl=ja
        """
        llm_settings = dict(llm_settings)  # copy so the caller's settings are not mutated
        # gemini-native settings
        candidate_count = llm_settings.pop("candidate_count", None)
        stop_sequences = llm_settings.pop("stop_sequences", None)
        temperature = llm_settings.pop("temperature", None)
        max_output_tokens = llm_settings.pop("max_output_tokens", None)
        top_p = llm_settings.pop("top_p", None)
        top_k = llm_settings.pop("top_k", None)
        # Also accept neollm's argument name: fall back to "max_tokens" when
        # "max_output_tokens" is not set.
        max_tokens = llm_settings.pop("max_tokens", None)
        if max_output_tokens is None:
            max_output_tokens = max_tokens
        if len(llm_settings) > 0:
            raise ValueError(f"llm_settings has unknown keys: {set(llm_settings)}")
        return GenerationConfig(
            candidate_count=candidate_count,
            stop_sequences=stop_sequences,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            top_k=top_k,
        )
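
# Sketch of the fallback above: these two calls should yield equivalent
# GenerationConfig objects, since "max_tokens" is only neollm's alias for
# "max_output_tokens" (llm is any GoogleLLM instance; values are illustrative):
#
#   llm.generate_config({"temperature": 0.2, "max_output_tokens": 1024})
#   llm.generate_config({"temperature": 0.2, "max_tokens": 1024})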

class GCPGemini10Pro(GoogleLLM):
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro"
    context_window: int = 32_000

class GCPGemini10ProVision(GoogleLLM):
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro-vision"
    context_window: int = 32_000

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate(messages, llm_settings)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate_stream(messages, llm_settings)

    def _preprocess_message_to_use_system(self, messages: Messages) -> Messages:
        # gemini-1.0-pro-vision does not support the system role, so fold a
        # leading system message into the text of the first remaining message.
        if messages[0]["role"] != "system":
            return messages
        preprocessed_messages = deepcopy(messages)
        system = preprocessed_messages[0]["content"]
        del preprocessed_messages[0]
        if (
            len(preprocessed_messages) > 0
            and isinstance(system, str)
            and isinstance(preprocessed_messages[0]["content"], list)
            and isinstance(preprocessed_messages[0]["content"][0]["text"], str)
        ):
            preprocessed_messages[0]["content"][0]["text"] = system + preprocessed_messages[0]["content"][0]["text"]
        else:
            cprint("WARNING: invalid input format", color="yellow", background=True)
        return preprocessed_messages
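
# Illustration of the preprocessing above (message shape assumed from this
# module's Messages type; note the system text is prepended with no separator):
#
#   [{"role": "system", "content": "Be terse. "},
#    {"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
#   becomes
#   [{"role": "user", "content": [{"type": "text", "text": "Be terse. Hello"}]}]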

class GCPGemini15Pro0409(GoogleLLM):
    dollar_per_ktoken = APIPricing(input=2.5 / 1000, output=7.5 / 1000)
    model: str = "gemini-1.5-pro-preview-0409"
    context_window: int = 1_000_000
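
# Cost sketch, assuming APIPricing stores dollars per 1K tokens as the
# attribute name suggests (token counts are illustrative):
#
#   pricing = GCPGemini15Pro0409.dollar_per_ktoken
#   cost = 10_000 / 1000 * pricing.input + 2_000 / 1000 * pricing.output
#   # = 10 * 0.0025 + 2 * 0.0075 = $0.04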