from copy import deepcopy
from typing import Literal, cast, get_args

import vertexai
from vertexai.generative_models import GenerationConfig

from neollm.llm.abstract_llm import AbstractLLM
from neollm.llm.gemini.abstract_gemini import AbstractGemini
from neollm.types import APIPricing, ClientSettings, LLMSettings, StreamResponse
from neollm.types.mytypes import Messages, Response
from neollm.utils.utils import cprint

# price: https://ai.google.dev/pricing?hl=ja
# models: https://ai.google.dev/gemini-api/docs/models/gemini?hl=ja
SUPPORTED_MODELS = Literal["gemini-1.0-pro", "gemini-1.0-pro-vision", "gemini-1.5-pro-preview-0409"]

# Keys that generate_config understands in an LLMSettings dict.
AVAILABLE_CONFIG_VARIABLES = [
    "candidate_count",
    "stop_sequences",
    "temperature",
    "max_tokens",  # fallback used when "max_output_tokens" is not set
    "max_output_tokens",
    "top_p",
    "top_k",
]


def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Initialize Vertex AI and return the LLM wrapper for ``model_name``.

    Args:
        model_name: One of ``SUPPORTED_MODELS``.
        client_settings: Keyword arguments forwarded to ``vertexai.init`` and
            to the selected wrapper's constructor.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    vertexai.init(**client_settings)
    # Map model names to wrapper *classes* and instantiate only the requested
    # one (the previous version eagerly constructed all three wrappers just to
    # return a single instance).
    supported_model_map: dict[SUPPORTED_MODELS, type[AbstractLLM]] = {
        "gemini-1.0-pro": GCPGemini10Pro,
        "gemini-1.0-pro-vision": GCPGemini10ProVision,
        "gemini-1.5-pro-preview-0409": GCPGemini15Pro0409,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")


class GoogleLLM(AbstractGemini):
    """Common base class for GCP-hosted Gemini models."""

    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
        """Build a ``GenerationConfig`` from ``llm_settings``.

        Reference: https://ai.google.dev/api/rest/v1/GenerationConfig?hl=ja

        Note: known keys are popped from ``llm_settings`` in place (same
        side effect as before).

        Raises:
            ValueError: If ``llm_settings`` contains keys outside
                ``AVAILABLE_CONFIG_VARIABLES``.
        """
        # gemini-native settings
        candidate_count = llm_settings.pop("candidate_count", None)
        stop_sequences = llm_settings.pop("stop_sequences", None)
        temperature = llm_settings.pop("temperature", None)
        max_output_tokens = llm_settings.pop("max_output_tokens", None)
        top_p = llm_settings.pop("top_p", None)
        top_k = llm_settings.pop("top_k", None)
        # Also accept neollm's generic "max_tokens" as a fallback alias.
        if max_output_tokens is None:
            max_output_tokens = llm_settings.pop("max_tokens", None)
        # BUGFIX: the old check (`len(...) > 0 and "max_tokens" not in ...`)
        # skipped validation entirely whenever "max_tokens" was still present,
        # letting genuinely unknown keys pass silently. Only "max_tokens" is
        # allowed to remain (it may be left behind when "max_output_tokens"
        # took precedence above).
        if set(llm_settings) - {"max_tokens"}:
            raise ValueError(f"llm_settings has unknown keys: {llm_settings}")
        return GenerationConfig(
            candidate_count=candidate_count,
            stop_sequences=stop_sequences,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            top_k=top_k,
        )


class GCPGemini10Pro(GoogleLLM):
    # Pricing per 1k tokens (USD); see the price URL at the top of the file.
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro"
    context_window: int = 32_000


class GCPGemini10ProVision(GoogleLLM):
    # Pricing per 1k tokens (USD); see the price URL at the top of the file.
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro-vision"
    context_window: int = 32_000

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Generate a response, folding any system message into the first user turn."""
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate(messages, llm_settings)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        """Streaming variant of :meth:`generate` with the same system-message folding."""
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate_stream(messages, llm_settings)

    def _preprocess_message_to_use_system(self, message: Messages) -> Messages:
        """Merge a leading system message into the next message's first text part.

        The vision model has no system role, so the system text is prepended
        to the following message's first text content. The input is not
        mutated; a deep copy is returned when rewriting is needed.
        """
        # BUGFIX: guard against an empty message list (previously IndexError).
        if not message or message[0]["role"] != "system":
            return message
        preprocessed_message = deepcopy(message)
        system = preprocessed_message[0]["content"]
        del preprocessed_message[0]
        # BUGFIX: guard against a system-only message list (previously IndexError).
        if (
            preprocessed_message
            and isinstance(system, str)
            and isinstance(preprocessed_message[0]["content"], list)
            and isinstance(preprocessed_message[0]["content"][0]["text"], str)
        ):
            preprocessed_message[0]["content"][0]["text"] = system + preprocessed_message[0]["content"][0]["text"]
        else:
            # Unexpected message shape: warn and leave the content untouched.
            cprint("WARNING: 入力形式が不正です", color="yellow", background=True)
        return preprocessed_message


class GCPGemini15Pro0409(GoogleLLM):
    # Pricing per 1k tokens (USD); see the price URL at the top of the file.
    dollar_per_ktoken = APIPricing(input=2.5 / 1000, output=7.5 / 1000)
    model: str = "gemini-1.5-pro-preview-0409"
    context_window: int = 1_000_000