from __future__ import annotations

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable
import dataclasses
import itertools
import textwrap
from typing import TypedDict, Union

import google.protobuf.json_format
import google.api_core.exceptions

from google.ai import generativelanguage as glm
from google.generativeai import string_utils
__all__ = [
    "AsyncGenerateContentResponse",
    "BlockedPromptException",
    "StopCandidateException",
    "IncompleteIterationError",
    "BrokenResponseError",
    "GenerationConfigDict",
    "GenerationConfigType",
    "GenerationConfig",
    "GenerateContentResponse",
]
if sys.version_info < (3, 10):

    def aiter(obj):
        return obj.__aiter__()

    async def anext(obj, default=None):
        try:
            return await obj.__anext__()
        except StopAsyncIteration:
            if default is not None:
                return default
            else:
                raise
class BlockedPromptException(Exception):
    pass


class StopCandidateException(Exception):
    pass


class IncompleteIterationError(Exception):
    pass


class BrokenResponseError(Exception):
    pass
class GenerationConfigDict(TypedDict):
    # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/
    candidate_count: int
    stop_sequences: Iterable[str]
    max_output_tokens: int
    temperature: float
@dataclasses.dataclass
class GenerationConfig:
    """A simple dataclass used to configure the generation parameters of `GenerativeModel.generate_content`.

    Attributes:
        candidate_count:
            Number of generated responses to return.
        stop_sequences:
            The set of character sequences (up
            to 5) that will stop output generation. If
            specified, the API will stop at the first
            appearance of a stop sequence. The stop sequence
            will not be included as part of the response.
        max_output_tokens:
            The maximum number of tokens to include in a
            candidate.

            If unset, this will default to the output_token_limit specified
            in the model's specification.
        temperature:
            Controls the randomness of the output. Note: The
            default value varies by model, see the `Model.temperature`
            attribute of the `Model` returned by the `genai.get_model`
            function.

            Values can range from [0.0, 1.0], inclusive. A value closer
            to 1.0 will produce responses that are more varied and
            creative, while a value closer to 0.0 will typically result
            in more straightforward responses from the model.
        top_p:
            Optional. The maximum cumulative probability of tokens to
            consider when sampling.

            The model uses combined Top-k and nucleus sampling.

            Tokens are sorted based on their assigned probabilities so
            that only the most likely tokens are considered. Top-k
            sampling directly limits the maximum number of tokens to
            consider, while nucleus sampling limits the number of tokens
            based on the cumulative probability.

            Note: The default value varies by model, see the
            `Model.top_p` attribute of the `Model` returned by the
            `genai.get_model` function.
        top_k (int):
            Optional. The maximum number of tokens to consider when
            sampling.

            The model uses combined Top-k and nucleus sampling.

            Top-k sampling considers the set of `top_k` most probable
            tokens. Defaults to 40.

            Note: The default value varies by model, see the
            `Model.top_k` attribute of the `Model` returned by the
            `genai.get_model` function.
    """

    candidate_count: int | None = None
    stop_sequences: Iterable[str] | None = None
    max_output_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
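# A minimal usage sketch (illustrative only; `genai.GenerativeModel` and the model
# name below are assumptions about the calling code, not part of this module):
#
#     import google.generativeai as genai
#
#     config = GenerationConfig(temperature=0.2, max_output_tokens=128, stop_sequences=["\n\n"])
#     model = genai.GenerativeModel("gemini-pro")
#     response = model.generate_content("Say hello.", generation_config=config)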
GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig]
def to_generation_config_dict(generation_config: GenerationConfigType):
    if generation_config is None:
        return {}
    elif isinstance(generation_config, glm.GenerationConfig):
        return type(generation_config).to_dict(generation_config)  # pytype: disable=attribute-error
    elif isinstance(generation_config, GenerationConfig):
        generation_config = dataclasses.asdict(generation_config)
        return {key: value for key, value in generation_config.items() if value is not None}
    elif hasattr(generation_config, "keys"):
        return dict(generation_config)
    else:
        raise TypeError(
            "Did not understand `generation_config`, expected a `dict` or"
            f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:"
            f" {generation_config}"
        )
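# Sketch of how the accepted `GenerationConfigType` forms normalize (values are
# illustrative; the `glm.GenerationConfig` form may also carry the proto's other
# fields with their default values):
#
#     to_generation_config_dict(GenerationConfig(temperature=0.5))    # {'temperature': 0.5}
#     to_generation_config_dict({"temperature": 0.5})                 # {'temperature': 0.5}
#     to_generation_config_dict(glm.GenerationConfig(temperature=0.5))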
def _join_citation_metadatas(
    citation_metadatas: Iterable[glm.CitationMetadata],
):
    citation_metadatas = list(citation_metadatas)
    return citation_metadatas[-1]
def _join_safety_ratings_lists(
    safety_ratings_lists: Iterable[list[glm.SafetyRating]],
):
    ratings = {}
    blocked = collections.defaultdict(list)
    for safety_ratings_list in safety_ratings_lists:
        for rating in safety_ratings_list:
            ratings[rating.category] = rating.probability
            blocked[rating.category].append(rating.blocked)

    blocked = {category: any(blocked) for category, blocked in blocked.items()}

    safety_list = []
    for (category, probability), blocked in zip(ratings.items(), blocked.values()):
        safety_list.append(
            glm.SafetyRating(category=category, probability=probability, blocked=blocked)
        )
    return safety_list
def _join_contents(contents: Iterable[glm.Content]):
    contents = tuple(contents)
    roles = [c.role for c in contents if c.role]
    if roles:
        role = roles[0]
    else:
        role = ""

    parts = []
    for content in contents:
        parts.extend(content.parts)

    merged_parts = [parts.pop(0)]
    for part in parts:
        if not merged_parts[-1].text:
            merged_parts.append(part)
            continue

        if not part.text:
            merged_parts.append(part)
            continue

        merged_part = glm.Part(merged_parts[-1])
        merged_part.text += part.text
        merged_parts[-1] = merged_part

    return glm.Content(
        role=role,
        parts=merged_parts,
    )
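# Sketch of the merge rule above: adjacent text parts are concatenated, while a
# non-text part (or a part following one) starts a new entry (values illustrative):
#
#     _join_contents([
#         glm.Content(role="model", parts=[glm.Part(text="Hel")]),
#         glm.Content(parts=[glm.Part(text="lo!")]),
#     ])
#     # -> glm.Content(role="model", parts=[glm.Part(text="Hello!")])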
def _join_candidates(candidates: Iterable[glm.Candidate]):
    candidates = tuple(candidates)

    index = candidates[0].index  # These should all be the same.

    return glm.Candidate(
        index=index,
        content=_join_contents([c.content for c in candidates]),
        finish_reason=candidates[-1].finish_reason,
        safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
        citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
    )
def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]):
    # Assuming that if a candidate ends, it is no longer returned in the list of
    # candidates, and that's why candidates have an index.
    candidates = collections.defaultdict(list)
    for candidate_list in candidate_lists:
        for candidate in candidate_list:
            candidates[candidate.index].append(candidate)

    new_candidates = []
    for index, candidate_parts in sorted(candidates.items()):
        new_candidates.append(_join_candidates(candidate_parts))

    return new_candidates
def _join_prompt_feedbacks(
    prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback],
):
    # Always return the first prompt feedback.
    return next(iter(prompt_feedbacks))
def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]):
    # Materialize the chunks so both generator expressions below can iterate them.
    chunks = tuple(chunks)
    return glm.GenerateContentResponse(
        candidates=_join_candidate_lists(c.candidates for c in chunks),
        prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
    )
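# Sketch of how streamed chunks accumulate (field values are illustrative; proto-plus
# messages generally accept nested dicts like this):
#
#     a = glm.GenerateContentResponse(
#         candidates=[{"index": 0, "content": {"role": "model", "parts": [{"text": "Hel"}]}}]
#     )
#     b = glm.GenerateContentResponse(
#         candidates=[{"index": 0, "content": {"parts": [{"text": "lo!"}]}}]
#     )
#     _join_chunks([a, b]).candidates[0].content.parts[0].text  # "Hello!"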
_INCOMPLETE_ITERATION_MESSAGE = """\
Please let the response complete iteration before accessing the final accumulated
attributes (or call `response.resolve()`)"""
class BaseGenerateContentResponse:
    def __init__(
        self,
        done: bool,
        iterator: (
            None
            | Iterable[glm.GenerateContentResponse]
            | AsyncIterable[glm.GenerateContentResponse]
        ),
        result: glm.GenerateContentResponse,
        chunks: Iterable[glm.GenerateContentResponse] | None = None,
    ):
        self._done = done
        self._iterator = iterator
        self._result = result
        if chunks is None:
            self._chunks = [result]
        else:
            self._chunks = list(chunks)
        if result.prompt_feedback.block_reason:
            self._error = BlockedPromptException(result)
        else:
            self._error = None
    @property
    def candidates(self):
        """The list of candidate responses.

        Raises:
            IncompleteIterationError: With `stream=True` if iteration over the stream was not completed.
        """
        if not self._done:
            raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE)
        return self._result.candidates
    @property
    def parts(self):
        """A quick accessor equivalent to `self.candidates[0].parts`.

        Raises:
            ValueError: If the candidate list does not contain exactly one candidate.
        """
        candidates = self.candidates
        if not candidates:
            raise ValueError(
                "The `response.parts` quick accessor only works for a single candidate, "
                "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked."
            )
        if len(candidates) > 1:
            raise ValueError(
                "The `response.parts` quick accessor only works with a "
                "single candidate. With multiple candidates use "
                "result.candidates[index].text"
            )
        parts = candidates[0].content.parts
        return parts
    @property
    def text(self):
        """A quick accessor equivalent to `self.candidates[0].parts[0].text`.

        Raises:
            ValueError: If the candidate list or parts list does not contain exactly one entry.
        """
        parts = self.parts
        if not parts:
            raise ValueError(
                "The `response.text` quick accessor only works when the response contains a valid "
                "`Part`, but none was returned. Check the `candidate.safety_ratings` to see if the "
                "response was blocked."
            )
        return parts[0].text
    @property
    def prompt_feedback(self):
        return self._result.prompt_feedback
    def __str__(self) -> str:
        if self._done:
            _iterator = "None"
        else:
            _iterator = f"<{self._iterator.__class__.__name__}>"

        _result = f"glm.GenerateContentResponse({type(self._result).to_dict(self._result)})"

        if self._error:
            _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"
        else:
            _error = ""

        return (
            textwrap.dedent(
                f"""\
                response:
                {type(self).__name__}(
                    done={self._done},
                    iterator={_iterator},
                    result={_result},
                )"""
            )
            + _error
        )

    __repr__ = __str__
@contextlib.contextmanager
def rewrite_stream_error():
    try:
        yield
    except (google.protobuf.json_format.ParseError, AttributeError) as e:
        raise google.api_core.exceptions.BadRequest(
            "Unknown error trying to retrieve streaming response. "
            "Please retry with `stream=False` for more details."
        ) from e
GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method.

These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`.

This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback`
and `candidates` attributes. This class adds several quick accessors for common use cases.

The same object type is returned for both `stream=True/False`.

### Streaming

When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`,
iterate over this object to receive chunks of the response:

```
response = model.generate_content(..., stream=True)
for chunk in response:
    print(chunk.text)
```

`GenerateContentResponse.prompt_feedback` is available immediately but
`GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`),
are only available after the iteration is complete.
"""
ASYNC_GENERATE_CONTENT_RESPONSE_DOC = (
    """This is the async version of `genai.GenerateContentResponse`."""
)
class GenerateContentResponse(BaseGenerateContentResponse):
    @classmethod
    def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]):
        iterator = iter(iterator)
        with rewrite_stream_error():
            response = next(iterator)

        return cls(
            done=False,
            iterator=iterator,
            result=response,
        )

    @classmethod
    def from_response(cls, response: glm.GenerateContentResponse):
        return cls(
            done=True,
            iterator=None,
            result=response,
        )
    def __iter__(self):
        # This is not thread safe.
        if self._done:
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(next(self._iterator))

        for n in itertools.count():
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = next(self._iterator)
                except StopIteration:
                    self._done = True
                except Exception as e:
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item
    def resolve(self):
        if self._done:
            return

        for _ in self:
            pass
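# Sketch: with `stream=True`, `resolve()` drains the remaining chunks so the
# accumulated attributes become usable without an explicit loop (`model` and the
# prompt are assumptions about the calling code, not defined in this module):
#
#     response = model.generate_content("Tell me a story.", stream=True)
#     response.resolve()
#     print(response.text)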
class AsyncGenerateContentResponse(BaseGenerateContentResponse):
    @classmethod
    async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]):
        iterator = aiter(iterator)  # type: ignore
        with rewrite_stream_error():
            response = await anext(iterator)  # type: ignore

        return cls(
            done=False,
            iterator=iterator,
            result=response,
        )

    @classmethod
    def from_response(cls, response: glm.GenerateContentResponse):
        return cls(
            done=True,
            iterator=None,
            result=response,
        )
    async def __aiter__(self):
        # This is not thread safe.
        if self._done:
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(await anext(self._iterator))  # type: ignore

        for n in itertools.count():
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = await anext(self._iterator)  # type: ignore
                except StopAsyncIteration:
                    self._done = True
                except Exception as e:
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item
    async def resolve(self):
        if self._done:
            return

        async for _ in self:
            pass
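# Async counterpart of the streaming sketch above (`model.generate_content_async` and
# the prompt are assumptions about the calling code, not defined in this module):
#
#     response = await model.generate_content_async("Tell me a story.", stream=True)
#     async for chunk in response:
#         print(chunk.text)
#     # or: await response.resolve()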