"""GCP Vertex AI Gemini model wrappers for neollm."""
from copy import deepcopy | |
from typing import Literal, cast, get_args | |
import vertexai | |
from vertexai.generative_models import GenerationConfig | |
from neollm.llm.abstract_llm import AbstractLLM | |
from neollm.llm.gemini.abstract_gemini import AbstractGemini | |
from neollm.types import APIPricing, ClientSettings, LLMSettings, StreamResponse | |
from neollm.types.mytypes import Messages, Response | |
from neollm.utils.utils import cprint | |
# price: https://ai.google.dev/pricing?hl=ja
# models: https://ai.google.dev/gemini-api/docs/models/gemini?hl=ja
# Model names accepted by get_gcp_llm; keep in sync with its supported_model_map.
SUPPORTED_MODELS = Literal["gemini-1.0-pro", "gemini-1.0-pro-vision", "gemini-1.5-pro-preview-0409"]
# Keys recognized by GoogleLLM.generate_config; any other key raises ValueError there.
AVAILABLE_CONFIG_VARIABLES = [
    "candidate_count",
    "stop_sequences",
    "temperature",
    "max_tokens",  # used as a fallback when "max_output_tokens" is not set
    "max_output_tokens",
    "top_p",
    "top_k",
]
def get_gcp_llm(model_name: SUPPORTED_MODELS | str, client_settings: ClientSettings) -> AbstractLLM:
    """Resolve a GCP Gemini model name to its concrete LLM wrapper.

    Args:
        model_name: One of the SUPPORTED_MODELS literals (the signature accepts
            any string, but unsupported names are rejected at runtime).
        client_settings: Keyword arguments forwarded to ``vertexai.init`` and to
            the selected model class's constructor.

    Returns:
        The AbstractLLM implementation matching ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    vertexai.init(**client_settings)
    # Map names to classes (not instances) so only the requested wrapper is
    # constructed — the original built all three up front just to return one.
    supported_model_map: dict[SUPPORTED_MODELS, type[GoogleLLM]] = {
        "gemini-1.0-pro": GCPGemini10Pro,
        "gemini-1.0-pro-vision": GCPGemini10ProVision,
        "gemini-1.5-pro-preview-0409": GCPGemini15Pro0409,
    }
    if model_name in supported_model_map:
        model_name = cast(SUPPORTED_MODELS, model_name)
        return supported_model_map[model_name](client_settings)
    raise ValueError(f"model_name must be {get_args(SUPPORTED_MODELS)}, but got {model_name}.")
class GoogleLLM(AbstractGemini):
    """Shared base for GCP Gemini wrappers; translates neollm settings to vertexai."""

    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
        """Build a vertexai GenerationConfig from neollm LLM settings.

        Reference: https://ai.google.dev/api/rest/v1/GenerationConfig?hl=ja

        Args:
            llm_settings: Mapping of generation options; accepted keys are
                listed in AVAILABLE_CONFIG_VARIABLES.

        Returns:
            GenerationConfig populated with the recognized settings.

        Raises:
            ValueError: If ``llm_settings`` contains unrecognized keys.
        """
        # Work on a copy so the caller's dict is not mutated by the pops below
        # (the original popped directly from the caller's mapping).
        settings = dict(llm_settings)
        candidate_count = settings.pop("candidate_count", None)
        stop_sequences = settings.pop("stop_sequences", None)
        temperature = settings.pop("temperature", None)
        max_output_tokens = settings.pop("max_output_tokens", None)
        top_p = settings.pop("top_p", None)
        top_k = settings.pop("top_k", None)
        # Accept neollm's generic "max_tokens" as an alias; an explicit
        # "max_output_tokens" takes precedence.
        max_tokens = settings.pop("max_tokens", None)
        if max_output_tokens is None:
            max_output_tokens = max_tokens
        # All known keys have been popped, so anything left is invalid.
        # (The original check skipped this error whenever "max_tokens" was
        # still present, letting unknown keys pass silently.)
        if settings:
            raise ValueError(f"llm_settings has unknown keys: {settings}")
        return GenerationConfig(
            candidate_count=candidate_count,
            stop_sequences=stop_sequences,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            top_k=top_k,
        )
class GCPGemini10Pro(GoogleLLM):
    """gemini-1.0-pro wrapper (text model)."""

    # Pricing per the link at the top of the file: https://ai.google.dev/pricing?hl=ja
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro"
    context_window: int = 32_000
class GCPGemini10ProVision(GoogleLLM):
    """gemini-1.0-pro-vision wrapper.

    This model has no native "system" role, so a leading system message is
    merged into the text of the following message before generation.
    """

    # Pricing per the link at the top of the file: https://ai.google.dev/pricing?hl=ja
    dollar_per_ktoken = APIPricing(input=0.125 / 1000, output=0.375 / 1000)
    model: str = "gemini-1.0-pro-vision"
    context_window: int = 32_000

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        """Generate a response, folding any leading system message into the next turn."""
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate(messages, llm_settings)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        """Stream a response, folding any leading system message into the next turn."""
        messages = self._preprocess_message_to_use_system(messages)
        return super().generate_stream(messages, llm_settings)

    def _preprocess_message_to_use_system(self, message: Messages) -> Messages:
        """Merge a leading system message into the following message's text.

        Returns the input unchanged when it is empty or does not start with a
        system message. The original crashed (IndexError) on empty or
        system-only conversations and dropped the system prompt when the next
        message's content was a plain string; those cases are now handled.
        """
        # Guard: nothing to do without a leading system message.
        if not message or message[0]["role"] != "system":
            return message
        preprocessed_message = deepcopy(message)
        system = preprocessed_message[0]["content"]
        del preprocessed_message[0]
        # A system-only conversation (or a non-string system content) has
        # nothing usable to merge into.
        if not preprocessed_message or not isinstance(system, str):
            cprint("WARNING: 入力形式が不正です", color="yellow", background=True)
            return preprocessed_message
        content = preprocessed_message[0]["content"]
        if isinstance(content, str):
            # Plain-string content: simple concatenation.
            preprocessed_message[0]["content"] = system + content
            return preprocessed_message
        if (
            isinstance(content, list)
            and content
            and isinstance(content[0], dict)
            and isinstance(content[0].get("text"), str)
        ):
            # Multi-part content: prepend the system text to the first text part.
            content[0]["text"] = system + content[0]["text"]
            return preprocessed_message
        cprint("WARNING: 入力形式が不正です", color="yellow", background=True)
        return preprocessed_message
class GCPGemini15Pro0409(GoogleLLM):
    """gemini-1.5-pro-preview-0409 wrapper (1M-token context preview model)."""

    # Pricing per the link at the top of the file: https://ai.google.dev/pricing?hl=ja
    dollar_per_ktoken = APIPricing(input=2.5 / 1000, output=7.5 / 1000)
    model: str = "gemini-1.5-pro-preview-0409"
    context_window: int = 1_000_000