from __future__ import annotations import collections import contextlib import sys from collections.abc import Iterable, AsyncIterable import dataclasses import itertools import textwrap from typing import TypedDict, Union import google.protobuf.json_format import google.api_core.exceptions from google.ai import generativelanguage as glm from google.generativeai import string_utils __all__ = [ "AsyncGenerateContentResponse", "BlockedPromptException", "StopCandidateException", "IncompleteIterationError", "BrokenResponseError", "GenerationConfigDict", "GenerationConfigType", "GenerationConfig", "GenerateContentResponse", ] if sys.version_info < (3, 10): def aiter(obj): return obj.__aiter__() async def anext(obj, default=None): try: return await obj.__anext__() except StopAsyncIteration: if default is not None: return default else: raise class BlockedPromptException(Exception): pass class StopCandidateException(Exception): pass class IncompleteIterationError(Exception): pass class BrokenResponseError(Exception): pass class GenerationConfigDict(TypedDict): # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/ candidate_count: int stop_sequences: Iterable[str] max_output_tokens: int temperature: float @dataclasses.dataclass class GenerationConfig: """A simple dataclass used to configure the generation parameters of `GenerativeModel.generate_content`. Attributes: candidate_count: Number of generated responses to return. stop_sequences: The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response. max_output_tokens: The maximum number of tokens to include in a candidate. If unset, this will default to output_token_limit specified in the model's specification. temperature: Controls the randomness of the output. Note: The default value varies by model, see the `Model.temperature` attribute of the `Model` returned the `genai.get_model` function. Values can range from [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied and creative, while a value closer to 0.0 will typically result in more straightforward responses from the model. top_p: Optional. The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Tokens are sorted based on their assigned probabilities so that only the most likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while Nucleus sampling limits number of tokens based on the cumulative probability. Note: The default value varies by model, see the `Model.top_p` attribute of the `Model` returned the `genai.get_model` function. top_k (int): Optional. The maximum number of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Top-k sampling considers the set of `top_k` most probable tokens. Defaults to 40. Note: The default value varies by model, see the `Model.top_k` attribute of the `Model` returned the `genai.get_model` function. """ candidate_count: int | None = None stop_sequences: Iterable[str] | None = None max_output_tokens: int | None = None temperature: float | None = None top_p: float | None = None top_k: int | None = None GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] def to_generation_config_dict(generation_config: GenerationConfigType): if generation_config is None: return {} elif isinstance(generation_config, glm.GenerationConfig): return type(generation_config).to_dict(generation_config) # pytype: disable=attribute-error elif isinstance(generation_config, GenerationConfig): generation_config = dataclasses.asdict(generation_config) return {key: value for key, value in generation_config.items() if value is not None} elif hasattr(generation_config, "keys"): return dict(generation_config) else: raise TypeError( "Did not understand `generation_config`, expected a `dict` or" f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:" f" {generation_config}" ) def _join_citation_metadatas( citation_metadatas: Iterable[glm.CitationMetadata], ): citation_metadatas = list(citation_metadatas) return citation_metadatas[-1] def _join_safety_ratings_lists( safety_ratings_lists: Iterable[list[glm.SafetyRating]], ): ratings = {} blocked = collections.defaultdict(list) for safety_ratings_list in safety_ratings_lists: for rating in safety_ratings_list: ratings[rating.category] = rating.probability blocked[rating.category].append(rating.blocked) blocked = {category: any(blocked) for category, blocked in blocked.items()} safety_list = [] for (category, probability), blocked in zip(ratings.items(), blocked.values()): safety_list.append( glm.SafetyRating(category=category, probability=probability, blocked=blocked) ) return safety_list def _join_contents(contents: Iterable[glm.Content]): contents = tuple(contents) roles = [c.role for c in contents if c.role] if roles: role = roles[0] else: role = "" parts = [] for content in contents: parts.extend(content.parts) merged_parts = [parts.pop(0)] for part in parts: if not merged_parts[-1].text: merged_parts.append(part) continue if not part.text: merged_parts.append(part) continue merged_part = glm.Part(merged_parts[-1]) merged_part.text += part.text merged_parts[-1] = merged_part return glm.Content( role=role, parts=merged_parts, ) def _join_candidates(candidates: Iterable[glm.Candidate]): candidates = tuple(candidates) index = candidates[0].index # These should all be the same. return glm.Candidate( index=index, content=_join_contents([c.content for c in candidates]), finish_reason=candidates[-1].finish_reason, safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]), citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]), ) def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): # Assuming that is a candidate ends, it is no longer returned in the list of # candidates and that's why candidates have an index candidates = collections.defaultdict(list) for candidate_list in candidate_lists: for candidate in candidate_list: candidates[candidate.index].append(candidate) new_candidates = [] for index, candidate_parts in sorted(candidates.items()): new_candidates.append(_join_candidates(candidate_parts)) return new_candidates def _join_prompt_feedbacks( prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback], ): # Always return the first prompt feedback. return next(iter(prompt_feedbacks)) def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]): return glm.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), ) _INCOMPLETE_ITERATION_MESSAGE = """\ Please let the response complete iteration before accessing the final accumulated attributes (or call `response.resolve()`)""" class BaseGenerateContentResponse: def __init__( self, done: bool, iterator: ( None | Iterable[glm.GenerateContentResponse] | AsyncIterable[glm.GenerateContentResponse] ), result: glm.GenerateContentResponse, chunks: Iterable[glm.GenerateContentResponse] | None = None, ): self._done = done self._iterator = iterator self._result = result if chunks is None: self._chunks = [result] else: self._chunks = list(chunks) if result.prompt_feedback.block_reason: self._error = BlockedPromptException(result) else: self._error = None @property def candidates(self): """The list of candidate responses. Raises: IncompleteIterationError: With `stream=True` if iteration over the stream was not completed. """ if not self._done: raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE) return self._result.candidates @property def parts(self): """A quick accessor equivalent to `self.candidates[0].parts` Raises: ValueError: If the candidate list does not contain exactly one candidate. """ candidates = self.candidates if not candidates: raise ValueError( "The `response.parts` quick accessor only works for a single candidate, " "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked." ) if len(candidates) > 1: raise ValueError( "The `response.parts` quick accessor only works with a " "single candidate. With multiple candidates use " "result.candidates[index].text" ) parts = candidates[0].content.parts return parts @property def text(self): """A quick accessor equivalent to `self.candidates[0].parts[0].text` Raises: ValueError: If the candidate list or parts list does not contain exactly one entry. """ parts = self.parts if not parts: raise ValueError( "The `response.text` quick accessor only works when the response contains a valid " "`Part`, but none was returned. Check the `candidate.safety_ratings` to see if the " "response was blocked." ) return parts[0].text @property def prompt_feedback(self): return self._result.prompt_feedback def __str__(self) -> str: if self._done: _iterator = "None" else: _iterator = f"<{self._iterator.__class__.__name__}>" _result = f"glm.GenerateContentResponse({type(self._result).to_dict(self._result)})" if self._error: _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}" else: _error = "" return ( textwrap.dedent( f"""\ response: {type(self).__name__}( done={self._done}, iterator={_iterator}, result={_result}, )""" ) + _error ) __repr__ = __str__ @contextlib.contextmanager def rewrite_stream_error(): try: yield except (google.protobuf.json_format.ParseError, AttributeError) as e: raise google.api_core.exceptions.BadRequest( "Unknown error trying to retrieve streaming response. " "Please retry with `stream=False` for more details." ) GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method. These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`. This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback` and `candidates` attributes. This class adds several quick accessors for common use cases. The same object type is returned for both `stream=True/False`. ### Streaming When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`, iterate over this object to receive chunks of the response: ``` response = model.generate_content(..., stream=True): for chunk in response: print(chunk.text) ``` `GenerateContentResponse.prompt_feedback` is available immediately but `GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`), are only available after the iteration is complete. """ ASYNC_GENERATE_CONTENT_RESPONSE_DOC = ( """This is the async version of `genai.GenerateContentResponse`.""" ) @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC) class GenerateContentResponse(BaseGenerateContentResponse): @classmethod def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): iterator = iter(iterator) with rewrite_stream_error(): response = next(iterator) return cls( done=False, iterator=iterator, result=response, ) @classmethod def from_response(cls, response: glm.GenerateContentResponse): return cls( done=True, iterator=None, result=response, ) def __iter__(self): # This is not thread safe. if self._done: for chunk in self._chunks: yield GenerateContentResponse.from_response(chunk) return # Always have the next chunk available. if len(self._chunks) == 0: self._chunks.append(next(self._iterator)) for n in itertools.count(): if self._error: raise self._error if n >= len(self._chunks) - 1: # Look ahead for a new item, so that you know the stream is done # when you yield the last item. if self._done: return try: item = next(self._iterator) except StopIteration: self._done = True except Exception as e: self._error = e self._done = True else: self._chunks.append(item) self._result = _join_chunks([self._result, item]) item = self._chunks[n] item = GenerateContentResponse.from_response(item) yield item def resolve(self): if self._done: return for _ in self: pass @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC) class AsyncGenerateContentResponse(BaseGenerateContentResponse): @classmethod async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]): iterator = aiter(iterator) # type: ignore with rewrite_stream_error(): response = await anext(iterator) # type: ignore return cls( done=False, iterator=iterator, result=response, ) @classmethod def from_response(cls, response: glm.GenerateContentResponse): return cls( done=True, iterator=None, result=response, ) async def __aiter__(self): # This is not thread safe. if self._done: for chunk in self._chunks: yield GenerateContentResponse.from_response(chunk) return # Always have the next chunk available. if len(self._chunks) == 0: self._chunks.append(await anext(self._iterator)) # type: ignore for n in itertools.count(): if self._error: raise self._error if n >= len(self._chunks) - 1: # Look ahead for a new item, so that you know the stream is done # when you yield the last item. if self._done: return try: item = await anext(self._iterator) # type: ignore except StopAsyncIteration: self._done = True except Exception as e: self._error = e self._done = True else: self._chunks.append(item) self._result = _join_chunks([self._result, item]) item = self._chunks[n] item = GenerateContentResponse.from_response(item) yield item async def resolve(self): if self._done: return async for _ in self: pass