import base64
import time
from abc import abstractmethod
from typing import Iterable, cast

from google.cloud.aiplatform_v1beta1.types import CountTokensResponse
from google.cloud.aiplatform_v1beta1.types.content import Candidate
from vertexai.generative_models import (
    Content,
    GenerationConfig,
    GenerationResponse,
    GenerativeModel,
    Part,
)
from vertexai.generative_models._generative_models import ContentsType

from neollm.llm.abstract_llm import AbstractLLM
from neollm.types import (
    ChatCompletion,
    CompletionUsageForCustomPriceCalculation,
    LLMSettings,
    Message,
    Messages,
    Response,
    StreamResponse,
)
from neollm.types.openai.chat_completion import (
    ChatCompletionMessage,
    Choice,
    CompletionUsage,
)
from neollm.types.openai.chat_completion import FinishReason as FinishReasonVertex
from neollm.types.openai.chat_completion_chunk import (
    ChatCompletionChunk,
    ChoiceDelta,
    ChunkChoice,
)
from neollm.utils.utils import cprint


class AbstractGemini(AbstractLLM):
    @abstractmethod
    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig: ...

    # Unused
    def encode(self, text: str) -> list[int]:
        return [ord(char) for char in text]

    # Unused
    def decode(self, decoded: list[int]) -> str:
        return "".join([chr(number) for number in decoded])

    def _count_tokens_vertex(self, contents: ContentsType) -> CountTokensResponse:
        model = GenerativeModel(model_name=self.model)
        return cast(CountTokensResponse, model.count_tokens(contents))

    def count_tokens(self, messages: list[Message] | None = None, only_response: bool = False) -> int:
        """Count the number of tokens.

        Args:
            messages (Messages): messages

        Returns:
            int: token count
        """
        # only_response is accepted for interface compatibility but unused here.
        if messages is None:
            return 0
        _system, _message = self._convert_to_platform_messages(messages)
        total_tokens = 0
        if _system:
            total_tokens += int(self._count_tokens_vertex(_system).total_tokens)
        if _message:
            # Accumulate (the original assignment overwrote the system token count).
            total_tokens += int(self._count_tokens_vertex(_message).total_tokens)
        return total_tokens

    def _convert_to_platform_messages(self, messages: Messages) -> tuple[str | None, list[Content]]:
        _system = None
        _message: list[Content] = []
        for message in messages:
            if message["role"] == "system":
                _system = "\n" + message["content"]
            elif message["role"] == "user":
                if isinstance(message["content"], str):
                    _message.append(Content(role="user", parts=[Part.from_text(message["content"])]))
                else:
                    try:
                        if isinstance(message["content"], list) and message["content"][1]["type"] == "image_url":
                            # Data URLs look like "data:image/jpeg;base64,<payload>"; keep only the payload.
                            encoded_image = message["content"][1]["image_url"]["url"].split(",")[-1]
                            _message.append(
                                Content(
                                    role="user",
                                    parts=[
                                        Part.from_text(message["content"][0]["text"]),
                                        # Part.from_data expects raw bytes, so decode the base64 payload.
                                        Part.from_data(data=base64.b64decode(encoded_image), mime_type="image/jpeg"),
                                    ],
                                )
                            )
                    except (KeyError, IndexError):
                        cprint("WARNING: unsupported message format", color="yellow", background=True)
                    except Exception as e:
                        cprint(e, color="red", background=True)
            elif message["role"] == "assistant":
                if isinstance(message["content"], str):
                    _message.append(Content(role="model", parts=[Part.from_text(message["content"])]))
                else:
                    cprint("WARNING: unsupported message format", color="yellow", background=True)
        return _system, _message

    def _convert_finish_reason(self, stop_reason: Candidate.FinishReason) -> FinishReasonVertex | None:
        """Map Vertex AI finish reasons to OpenAI-style finish reasons.

        Reference: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
            0: FINISH_REASON_UNSPECIFIED  Default value. This value is unused.
            1: STOP                       Natural stop point of the model or provided stop sequence.
            2: MAX_TOKENS                 The maximum number of tokens as specified in the request was reached.
            3: SAFETY                     The candidate content was flagged for safety reasons.
            4: RECITATION                 The candidate content was flagged for recitation reasons.
            5: OTHER                      Unknown reason.
        """
        # STOP (1) is included here: a natural stop point maps to "stop"
        # (the original omitted it and returned None for normal completions).
        if stop_reason.value in [0, 1, 3, 4, 5]:
            return "stop"
        if stop_reason.value in [2]:
            return "length"
        return None

    def _convert_to_response(
        self, platform_response: GenerationResponse, system: str | None, message: list[Content]
    ) -> Response:
        # Billable character count for the input
        input_billable_characters = 0
        if system:
            input_billable_characters += self._count_tokens_vertex(system).total_billable_characters
        if message:
            input_billable_characters += self._count_tokens_vertex(message).total_billable_characters
        # Billable character count for the output
        output_billable_characters = 0
        if platform_response.text:
            output_billable_characters += self._count_tokens_vertex(platform_response.text).total_billable_characters
        return ChatCompletion(  # type: ignore[call-arg]
            id="",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        content=platform_response.text,
                        role="assistant",
                    ),
                    finish_reason=self._convert_finish_reason(platform_response.candidates[0].finish_reason),
                )
            ],
            created=int(time.time()),
            model=self.model,
            object="chat.completion",
            system_fingerprint=None,
            usage=CompletionUsage(
                prompt_tokens=platform_response.usage_metadata.prompt_token_count,
                completion_tokens=platform_response.usage_metadata.candidates_token_count,
                total_tokens=platform_response.usage_metadata.prompt_token_count
                + platform_response.usage_metadata.candidates_token_count,
            ),
            usage_for_price=CompletionUsageForCustomPriceCalculation(
                prompt_tokens=input_billable_characters,
                completion_tokens=output_billable_characters,
                total_tokens=input_billable_characters + output_billable_characters,
            ),
        )

    def _convert_to_streamresponse(self, platform_streamresponse: Iterable[GenerationResponse]) -> StreamResponse:
        created = int(time.time())
        content: str | None = None
        for chunk in platform_streamresponse:
            content = chunk.text
            yield ChatCompletionChunk(
                id="",
                choices=[
                    ChunkChoice(
                        delta=ChoiceDelta(
                            content=content,
                            role="assistant",
                        ),
                        finish_reason=self._convert_finish_reason(chunk.candidates[0].finish_reason),
                        index=0,  # the upstream index may not be 0-based, so force it to 0
                    )
                ],
                created=created,
                model=self.model,
                object="chat.completion.chunk",
            )

    def generate(self, messages: Messages, llm_settings: LLMSettings) -> Response:
        _system, _message = self._convert_to_platform_messages(messages)
        model = GenerativeModel(
            model_name=self.model,
            system_instruction=_system,
        )
        response = model.generate_content(
            contents=_message,
            stream=False,
            generation_config=self.generate_config(llm_settings),
        )
        return self._convert_to_response(platform_response=response, system=_system, message=_message)

    def generate_stream(self, messages: Messages, llm_settings: LLMSettings) -> StreamResponse:
        _system, _message = self._convert_to_platform_messages(messages)
        model = GenerativeModel(
            model_name=self.model,
            system_instruction=_system,
        )
        response = model.generate_content(
            contents=_message,
            stream=True,
            generation_config=self.generate_config(llm_settings),
        )
        return self._convert_to_streamresponse(platform_streamresponse=response)
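

# --- Usage sketch (illustrative, not part of the library) --------------------
# A minimal sketch of how a concrete subclass might implement generate_config
# and be called. The class name, the `model` attribute, the settings keys
# ("temperature", "max_tokens"), and the no-argument constructor are all
# assumptions for illustration; AbstractLLM may require additional members
# (e.g. pricing) that are omitted here.
class _ExampleGemini(AbstractGemini):  # hypothetical subclass
    model = "gemini-1.5-pro"  # assumed model identifier

    def generate_config(self, llm_settings: LLMSettings) -> GenerationConfig:
        # Forward a couple of common settings to Vertex AI's GenerationConfig.
        return GenerationConfig(
            temperature=llm_settings.get("temperature", 0.0),
            max_output_tokens=llm_settings.get("max_tokens", 1024),
        )


if __name__ == "__main__":
    # Requires Vertex AI credentials and project configuration in the environment.
    llm = _ExampleGemini()
    completion = llm.generate(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        llm_settings={"temperature": 0.0, "max_tokens": 256},
    )
    print(completion.choices[0].message.content)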