File size: 17,235 Bytes
9e3c734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
from __future__ import annotations

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable
import dataclasses
import itertools
import textwrap
from typing import TypedDict, Union

import google.protobuf.json_format
import google.api_core.exceptions

from google.ai import generativelanguage as glm
from google.generativeai import string_utils

# Names exported when this module is star-imported.
__all__ = [
    "AsyncGenerateContentResponse",
    "BlockedPromptException",
    "StopCandidateException",
    "IncompleteIterationError",
    "BrokenResponseError",
    "GenerationConfigDict",
    "GenerationConfigType",
    "GenerationConfig",
    "GenerateContentResponse",
]

if sys.version_info < (3, 10):
    # `aiter` / `anext` became builtins in Python 3.10; provide equivalents for
    # older interpreters.

    # Sentinel that distinguishes "no default supplied" from an explicit
    # `default=None`, so `anext(it, None)` returns `None` like the builtin does.
    _ANEXT_NO_DEFAULT = object()

    def aiter(obj):
        """Return the asynchronous iterator for `obj` by calling `obj.__aiter__()`."""
        return obj.__aiter__()

    async def anext(obj, default=_ANEXT_NO_DEFAULT):
        """Await and return the next item of the asynchronous iterator `obj`.

        Args:
            obj: An object with an `__anext__` method.
            default: Value returned when the iterator is exhausted. When omitted,
                exhaustion raises instead.

        Raises:
            StopAsyncIteration: If the iterator is exhausted and no `default` was given.
        """
        try:
            return await obj.__anext__()
        except StopAsyncIteration:
            if default is _ANEXT_NO_DEFAULT:
                raise
            return default


class BlockedPromptException(Exception):
    """Error recorded for a response whose prompt feedback carries a block reason."""


class StopCandidateException(Exception):
    """Exception type exported by this module; nothing in this file raises it.

    NOTE(review): presumably raised by callers when a candidate stops
    abnormally -- confirm against the code that raises it.
    """


class IncompleteIterationError(Exception):
    """Raised when accumulated response attributes are read before streaming has finished."""


class BrokenResponseError(Exception):
    """Exception type exported by this module; nothing in this file raises it.

    NOTE(review): presumably signals a response left in an unusable state --
    confirm against the code that raises it.
    """


class GenerationConfigDict(TypedDict, total=False):
    """Dictionary form of the generation settings.

    The keys mirror the fields of the `GenerationConfig` dataclass. Every key
    is optional; omitted keys are simply not sent.
    """

    # `total=False` makes every key optional on all supported Python versions.
    # Per-key `NotRequired` (https://peps.python.org/pep-0655/) needs Python 3.11+.
    candidate_count: int
    stop_sequences: Iterable[str]
    max_output_tokens: int
    temperature: float
    top_p: float
    top_k: int


@dataclasses.dataclass
class GenerationConfig:
    """A simple dataclass used to configure the generation parameters of `GenerativeModel.generate_content`.

    Every field defaults to `None`, meaning "not set".

    Attributes:
        candidate_count:
            Number of generated responses to return.
        stop_sequences:
            The set of character sequences (up
            to 5) that will stop output generation. If
            specified, the API will stop at the first
            appearance of a stop sequence. The stop sequence
            will not be included as part of the response.
        max_output_tokens:
            The maximum number of tokens to include in a
            candidate.

            If unset, this will default to output_token_limit specified
            in the model's specification.
        temperature:
            Controls the randomness of the output.

            Note: The default value varies by model, see the
            `Model.temperature` attribute of the `Model` returned by the
            `genai.get_model` function.

            Values can range from [0.0,1.0], inclusive. A value closer
            to 1.0 will produce responses that are more varied and
            creative, while a value closer to 0.0 will typically result
            in more straightforward responses from the model.
        top_p:
            Optional. The maximum cumulative probability of tokens to
            consider when sampling.

            The model uses combined Top-k and nucleus sampling.

            Tokens are sorted based on their assigned probabilities so
            that only the most likely tokens are considered. Top-k
            sampling directly limits the maximum number of tokens to
            consider, while Nucleus sampling limits number of tokens
            based on the cumulative probability.

            Note: The default value varies by model, see the
            `Model.top_p` attribute of the `Model` returned by the
            `genai.get_model` function.

        top_k:
            Optional. The maximum number of tokens to consider when
            sampling.

            The model uses combined Top-k and nucleus sampling.

            Top-k sampling considers the set of `top_k` most probable
            tokens. Defaults to 40.

            Note: The default value varies by model, see the
            `Model.top_k` attribute of the `Model` returned by the
            `genai.get_model` function.
    """

    candidate_count: int | None = None
    stop_sequences: Iterable[str] | None = None
    max_output_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None


# The accepted ways to specify generation settings: the proto message, a plain
# dict, or the `GenerationConfig` dataclass.
GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig]


def to_generation_config_dict(generation_config: GenerationConfigType):
    """Normalize any supported generation-config value into a plain `dict`.

    Args:
        generation_config: `None`, a `glm.GenerationConfig`, a `GenerationConfig`
            dataclass, or any object exposing `keys` (e.g. a `dict`).

    Returns:
        An empty dict for `None`; the `to_dict` conversion for a
        `glm.GenerationConfig`; the dataclass fields whose value is not `None`
        for a `GenerationConfig`; otherwise `dict(generation_config)`.

    Raises:
        TypeError: If the value is none of the supported kinds.
    """
    if generation_config is None:
        return {}

    if isinstance(generation_config, glm.GenerationConfig):
        return type(generation_config).to_dict(generation_config)  # pytype: disable=attribute-error

    if isinstance(generation_config, GenerationConfig):
        # Drop unset fields so that only explicitly chosen settings remain.
        settings = dataclasses.asdict(generation_config)
        return {name: setting for name, setting in settings.items() if setting is not None}

    if hasattr(generation_config, "keys"):
        return dict(generation_config)

    raise TypeError(
        "Did not understand `generation_config`, expected a `dict` or"
        f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:"
        f" {generation_config}"
    )


def _join_citation_metadatas(
    citation_metadatas: Iterable[glm.CitationMetadata],
):
    citation_metadatas = list(citation_metadatas)
    return citation_metadatas[-1]


def _join_safety_ratings_lists(
    safety_ratings_lists: Iterable[list[glm.SafetyRating]],
):
    ratings = {}
    blocked = collections.defaultdict(list)

    for safety_ratings_list in safety_ratings_lists:
        for rating in safety_ratings_list:
            ratings[rating.category] = rating.probability
            blocked[rating.category].append(rating.blocked)

    blocked = {category: any(blocked) for category, blocked in blocked.items()}

    safety_list = []
    for (category, probability), blocked in zip(ratings.items(), blocked.values()):
        safety_list.append(
            glm.SafetyRating(category=category, probability=probability, blocked=blocked)
        )

    return safety_list


def _join_contents(contents: Iterable[glm.Content]):
    """Merge a sequence of `glm.Content` chunks into a single `glm.Content`.

    The role is the first non-empty role among the chunks, or `""` when none
    has one. All parts are concatenated in order; adjacent parts that both
    carry text are fused into one part holding the combined text, while parts
    with empty `text` are kept as separate entries.

    Args:
        contents: The content chunks to merge. May be empty, or contain chunks
            without parts; the result then simply has no parts.

    Returns:
        A new `glm.Content` with the merged role and parts.
    """
    contents = tuple(contents)
    role = next((c.role for c in contents if c.role), "")

    merged_parts = []
    for content in contents:
        for part in content.parts:
            # Starting from an empty list (instead of popping the first part)
            # keeps this working when no chunk carries any part at all.
            if merged_parts and merged_parts[-1].text and part.text:
                # Copy before appending so the input part is not mutated.
                merged_part = glm.Part(merged_parts[-1])
                merged_part.text += part.text
                merged_parts[-1] = merged_part
            else:
                merged_parts.append(part)

    return glm.Content(
        role=role,
        parts=merged_parts,
    )


def _join_candidates(candidates: Iterable[glm.Candidate]):
    candidates = tuple(candidates)

    index = candidates[0].index  # These should all be the same.

    return glm.Candidate(
        index=index,
        content=_join_contents([c.content for c in candidates]),
        finish_reason=candidates[-1].finish_reason,
        safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
        citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
    )


def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]):
    # Assuming that is a candidate ends, it is no longer returned in the list of
    # candidates and that's why candidates have an index
    candidates = collections.defaultdict(list)
    for candidate_list in candidate_lists:
        for candidate in candidate_list:
            candidates[candidate.index].append(candidate)

    new_candidates = []
    for index, candidate_parts in sorted(candidates.items()):
        new_candidates.append(_join_candidates(candidate_parts))

    return new_candidates


def _join_prompt_feedbacks(
    prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback],
):
    # Always return the first prompt feedback.
    return next(iter(prompt_feedbacks))


def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]):
    """Merge streamed response chunks into one `glm.GenerateContentResponse`.

    Candidates are merged across chunks by `_join_candidate_lists`; the prompt
    feedback is chosen by `_join_prompt_feedbacks`.

    Args:
        chunks: The chunks to merge, in arrival order. Any iterable is
            accepted, including one-shot iterators.
    """
    # The chunks are traversed twice below, so materialize them once; a
    # one-shot iterator would otherwise be empty on the second pass.
    chunks = tuple(chunks)
    return glm.GenerateContentResponse(
        candidates=_join_candidate_lists(c.candidates for c in chunks),
        prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
    )


# Message of the `IncompleteIterationError` raised by
# `BaseGenerateContentResponse.candidates` while the response is not yet done.
_INCOMPLETE_ITERATION_MESSAGE = """\
Please let the response complete iteration before accessing the final accumulated
attributes (or call `response.resolve()`)"""


class BaseGenerateContentResponse:
    """State and quick accessors shared by the sync and async response classes.

    Holds the accumulated result, the chunks received so far, the remaining
    iterator, a `done` flag, and an error slot. The error slot starts out as a
    `BlockedPromptException` when the result's prompt feedback has a block reason.
    """

    def __init__(
        self,
        done: bool,
        iterator: (
            None
            | Iterable[glm.GenerateContentResponse]
            | AsyncIterable[glm.GenerateContentResponse]
        ),
        result: glm.GenerateContentResponse,
        chunks: Iterable[glm.GenerateContentResponse] | None = None,
    ):
        """Store the response state.

        Args:
            done: Whether the full response has already been received.
            iterator: Source of any remaining chunks, or `None`.
            result: The response accumulated so far.
            chunks: Chunks already received; defaults to `[result]`.
        """
        self._done = done
        self._iterator = iterator
        self._result = result
        self._chunks = [result] if chunks is None else list(chunks)
        self._error = (
            BlockedPromptException(result) if result.prompt_feedback.block_reason else None
        )

    @property
    def candidates(self):
        """The list of candidate responses.

        Raises:
            IncompleteIterationError: If the response is not done yet, i.e. the stream
                has not been fully iterated.
        """
        if self._done:
            return self._result.candidates
        raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE)

    @property
    def parts(self):
        """A quick accessor equivalent to `self.candidates[0].content.parts`

        Raises:
            ValueError: If the candidate list does not contain exactly one candidate.
        """
        candidates = self.candidates
        if not candidates:
            raise ValueError(
                "The `response.parts` quick accessor only works for a single candidate, "
                "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked."
            )
        if len(candidates) <= 1:
            return candidates[0].content.parts
        raise ValueError(
            "The `response.parts` quick accessor only works with a "
            "single candidate. With multiple candidates use "
            "result.candidates[index].text"
        )

    @property
    def text(self):
        """A quick accessor equivalent to `self.parts[0].text`

        Raises:
            ValueError: If there is not exactly one candidate, or it has no parts.
        """
        parts = self.parts
        if parts:
            return parts[0].text
        raise ValueError(
            "The `response.text` quick accessor only works when the response contains a valid "
            "`Part`, but none was returned. Check the `candidate.safety_ratings` to see if the "
            "response was blocked."
        )

    @property
    def prompt_feedback(self):
        """The prompt feedback of the accumulated result."""
        return self._result.prompt_feedback

    def __str__(self) -> str:
        """Render the done flag, iterator type, result and any stored error."""
        iterator_repr = "None" if self._done else f"<{self._iterator.__class__.__name__}>"
        result_repr = f"glm.GenerateContentResponse({type(self._result).to_dict(self._result)})"
        error_suffix = (
            f",\nerror=<{self._error.__class__.__name__}> {self._error}" if self._error else ""
        )

        return (
            textwrap.dedent(
                f"""\
                response:
                {type(self).__name__}(
                    done={self._done},
                    iterator={iterator_repr},
                    result={result_repr},
                )"""
            )
            + error_suffix
        )

    __repr__ = __str__


@contextlib.contextmanager
def rewrite_stream_error():
    """Context manager that reports low-level streaming failures as `BadRequest`.

    A `google.protobuf.json_format.ParseError` or `AttributeError` raised inside
    the `with` body is replaced by `google.api_core.exceptions.BadRequest`. The
    original exception is attached as the explicit cause so it stays visible in
    tracebacks. Any other exception propagates unchanged.

    Raises:
        google.api_core.exceptions.BadRequest: In place of the errors listed above.
    """
    try:
        yield
    except (google.protobuf.json_format.ParseError, AttributeError) as e:
        raise google.api_core.exceptions.BadRequest(
            "Unknown error trying to retrieve streaming response. "
            "Please retry with `stream=False` for more details."
        ) from e


# Documentation text passed to `string_utils.set_doc` for `GenerateContentResponse`.
GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method.

    These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`.
    This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback`
    and `candidates` attributes. This class adds several quick accessors for common use cases.

    The same object type is returned for both `stream=True/False`.

    ### Streaming

    When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`,
    iterate over this object to receive chunks of the response:

    ```
    response = model.generate_content(..., stream=True)
    for chunk in response:
      print(chunk.text)
    ```

    `GenerateContentResponse.prompt_feedback` is available immediately but
    `GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`),
    are only available after the iteration is complete.
    """

# Documentation text passed to `string_utils.set_doc` for `AsyncGenerateContentResponse`.
ASYNC_GENERATE_CONTENT_RESPONSE_DOC = (
    """This is the async version of `genai.GenerateContentResponse`."""
)


@string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC)
class GenerateContentResponse(BaseGenerateContentResponse):
    # No class docstring is written here: GENERATE_CONTENT_RESPONSE_DOC is
    # handed to `string_utils.set_doc` above (presumably it installs it as
    # `__doc__` -- confirm in `string_utils`).

    @classmethod
    def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]):
        """Build a not-yet-done response from a stream of chunks.

        The first chunk is fetched eagerly, inside `rewrite_stream_error`, and
        becomes the initial result; the remaining iterator is kept for later.
        """
        iterator = iter(iterator)
        with rewrite_stream_error():
            response = next(iterator)

        return cls(
            done=False,
            iterator=iterator,
            result=response,
        )

    @classmethod
    def from_response(cls, response: glm.GenerateContentResponse):
        """Build an already-complete response (`done=True`, no iterator)."""
        return cls(
            done=True,
            iterator=None,
            result=response,
        )

    def __iter__(self):
        """Yield each chunk wrapped in a completed `GenerateContentResponse`.

        When the response is already done, the cached chunks are replayed.
        Otherwise chunks are pulled from the underlying iterator one step
        ahead of what is yielded; each new chunk is cached and merged into the
        accumulated result.

        Raises:
            Exception: The stored error, if any -- either the
                `BlockedPromptException` set at construction or an exception
                raised by the underlying iterator. An iterator failure is
                raised on the step after the last good chunk was yielded.
        """
        # This is not thread safe.
        if self._done:
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(next(self._iterator))

        for n in itertools.count():
            # Surface a stored error before producing anything further.
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = next(self._iterator)
                except StopIteration:
                    self._done = True
                except Exception as e:
                    # Remember the failure; it is raised at the top of the next step,
                    # after the current chunk has been yielded.
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    # Fold the new chunk into the accumulated result.
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item

    def resolve(self):
        """Consume the rest of the stream so the accumulated attributes become available."""
        if self._done:
            return

        for _ in self:
            pass


@string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC)
class AsyncGenerateContentResponse(BaseGenerateContentResponse):
    # No class docstring is written here: ASYNC_GENERATE_CONTENT_RESPONSE_DOC is
    # handed to `string_utils.set_doc` above (presumably it installs it as
    # `__doc__` -- confirm in `string_utils`).

    @classmethod
    async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]):
        """Build a not-yet-done response from an async stream of chunks.

        The first chunk is awaited eagerly, inside `rewrite_stream_error`, and
        becomes the initial result; the remaining iterator is kept for later.
        """
        iterator = aiter(iterator)  # type: ignore
        with rewrite_stream_error():
            response = await anext(iterator)  # type: ignore

        return cls(
            done=False,
            iterator=iterator,
            result=response,
        )

    @classmethod
    def from_response(cls, response: glm.GenerateContentResponse):
        """Build an already-complete response (`done=True`, no iterator)."""
        return cls(
            done=True,
            iterator=None,
            result=response,
        )

    async def __aiter__(self):
        """Asynchronously yield each chunk wrapped in a completed response.

        When the response is already done, the cached chunks are replayed.
        Otherwise chunks are awaited from the underlying iterator one step
        ahead of what is yielded; each new chunk is cached and merged into the
        accumulated result.

        Raises:
            Exception: The stored error, if any -- either the
                `BlockedPromptException` set at construction or an exception
                raised by the underlying iterator. An iterator failure is
                raised on the step after the last good chunk was yielded.
        """
        # NOTE(review): the yielded chunks are instances of the synchronous
        # `GenerateContentResponse`, not of this class -- confirm this is intended.
        # This is not thread safe.
        if self._done:
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(await anext(self._iterator))  # type: ignore

        for n in itertools.count():
            # Surface a stored error before producing anything further.
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = await anext(self._iterator)  # type: ignore
                except StopAsyncIteration:
                    self._done = True
                except Exception as e:
                    # Remember the failure; it is raised at the top of the next step,
                    # after the current chunk has been yielded.
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    # Fold the new chunk into the accumulated result.
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item

    async def resolve(self):
        """Consume the rest of the stream so the accumulated attributes become available."""
        if self._done:
            return

        async for _ in self:
            pass