noumanjavaid committed
Commit bb5747b · verified · 1 Parent(s): 04fc03b

Create app.py

Files changed (1):
  1. app.py +545 -0
app.py ADDED
@@ -0,0 +1,545 @@
+ from __future__ import annotations
+
+ import collections
+ import contextlib
+ import sys
+ from collections.abc import Iterable, AsyncIterable
+ import dataclasses
+ import itertools
+ import textwrap
+ from typing import TypedDict, Union
+
+ import google.protobuf.json_format
+ import google.api_core.exceptions
+
+ from google.ai import generativelanguage as glm
+ from google.generativeai import string_utils
+
+ __all__ = [
+     "AsyncGenerateContentResponse",
+     "BlockedPromptException",
+     "StopCandidateException",
+     "IncompleteIterationError",
+     "BrokenResponseError",
+     "GenerationConfigDict",
+     "GenerationConfigType",
+     "GenerationConfig",
+     "GenerateContentResponse",
+ ]
+
+ if sys.version_info < (3, 10):
+     # Backports of the `aiter`/`anext` builtins added in Python 3.10.
+     # A sentinel distinguishes "no default given" from an explicit
+     # `default=None`, matching the builtin `anext` semantics.
+     _ANEXT_DEFAULT = object()
+
+     def aiter(obj):
+         return obj.__aiter__()
+
+     async def anext(obj, default=_ANEXT_DEFAULT):
+         try:
+             return await obj.__anext__()
+         except StopAsyncIteration:
+             if default is not _ANEXT_DEFAULT:
+                 return default
+             raise
+
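+ # Illustrative check of the backport semantics (mirrors the 3.10+ builtins;
+ # this module itself never passes a default):
+ #
+ #     it = aiter(stream)             # `stream` is any async iterable
+ #     chunk = await anext(it)        # raises StopAsyncIteration when exhausted
+ #     chunk = await anext(it, None)  # returns None instead of raising
+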
+
+ class BlockedPromptException(Exception):
+     """Raised when the prompt was blocked; see `response.prompt_feedback.block_reason`."""
+
+
+ class StopCandidateException(Exception):
+     """Raised when a candidate stopped generating for an abnormal reason."""
+
+
+ class IncompleteIterationError(Exception):
+     """Raised when accumulated attributes are accessed before the stream is fully iterated."""
+
+
+ class BrokenResponseError(Exception):
+     """Raised when a streaming response ended with an error and cannot be used further."""
+
+
+ class GenerationConfigDict(TypedDict):
+     # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/
+     candidate_count: int
+     stop_sequences: Iterable[str]
+     max_output_tokens: int
+     temperature: float
+
+
+ @dataclasses.dataclass
+ class GenerationConfig:
+     """A simple dataclass used to configure the generation parameters of `GenerativeModel.generate_content`.
+
+     Attributes:
+         candidate_count:
+             Number of generated responses to return.
+         stop_sequences:
+             The set of character sequences (up
+             to 5) that will stop output generation. If
+             specified, the API will stop at the first
+             appearance of a stop sequence. The stop sequence
+             will not be included as part of the response.
+         max_output_tokens:
+             The maximum number of tokens to include in a
+             candidate.
+
+             If unset, this will default to the output_token_limit specified
+             in the model's specification.
+         temperature:
+             Controls the randomness of the output. Note: The
+             default value varies by model, see the `Model.temperature`
+             attribute of the `Model` returned by the `genai.get_model`
+             function.
+
+             Values can range from [0.0, 1.0], inclusive. A value closer
+             to 1.0 will produce responses that are more varied and
+             creative, while a value closer to 0.0 will typically result
+             in more straightforward responses from the model.
+         top_p:
+             Optional. The maximum cumulative probability of tokens to
+             consider when sampling.
+
+             The model uses combined Top-k and nucleus sampling.
+
+             Tokens are sorted based on their assigned probabilities so
+             that only the most likely tokens are considered. Top-k
+             sampling directly limits the maximum number of tokens to
+             consider, while nucleus sampling limits the number of tokens
+             based on their cumulative probability.
+
+             Note: The default value varies by model, see the
+             `Model.top_p` attribute of the `Model` returned by the
+             `genai.get_model` function.
+         top_k (int):
+             Optional. The maximum number of tokens to consider when
+             sampling.
+
+             The model uses combined Top-k and nucleus sampling.
+
+             Top-k sampling considers the set of `top_k` most probable
+             tokens. Defaults to 40.
+
+             Note: The default value varies by model, see the
+             `Model.top_k` attribute of the `Model` returned by the
+             `genai.get_model` function.
+     """
+
+     candidate_count: int | None = None
+     stop_sequences: Iterable[str] | None = None
+     max_output_tokens: int | None = None
+     temperature: float | None = None
+     top_p: float | None = None
+     top_k: int | None = None
+
+
+ GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig]
+
+
+ def to_generation_config_dict(generation_config: GenerationConfigType):
+     """Normalizes any accepted `generation_config` form into a plain dict."""
+     if generation_config is None:
+         return {}
+     elif isinstance(generation_config, glm.GenerationConfig):
+         return type(generation_config).to_dict(generation_config)  # pytype: disable=attribute-error
+     elif isinstance(generation_config, GenerationConfig):
+         generation_config = dataclasses.asdict(generation_config)
+         return {key: value for key, value in generation_config.items() if value is not None}
+     elif hasattr(generation_config, "keys"):
+         return dict(generation_config)
+     else:
+         raise TypeError(
+             "Did not understand `generation_config`, expected a `dict` or"
+             f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:"
+             f" {generation_config}"
+         )
+
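+ # Usage sketch (illustrative, not part of the library logic): the accepted
+ # forms normalize to the same dict, with unset dataclass fields dropped.
+ #
+ #     >>> to_generation_config_dict(GenerationConfig(temperature=0.5, top_k=40))
+ #     {'temperature': 0.5, 'top_k': 40}
+ #     >>> to_generation_config_dict({"temperature": 0.5, "top_k": 40})
+ #     {'temperature': 0.5, 'top_k': 40}
+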
+
+ def _join_citation_metadatas(
+     citation_metadatas: Iterable[glm.CitationMetadata],
+ ):
+     # Later chunks are assumed to carry the accumulated citations,
+     # so keeping the last metadata keeps the complete set.
+     citation_metadatas = list(citation_metadatas)
+     return citation_metadatas[-1]
+
+
+ def _join_safety_ratings_lists(
+     safety_ratings_lists: Iterable[list[glm.SafetyRating]],
+ ):
+     # For each category keep the most recent probability, and mark the
+     # category blocked if any chunk reported it blocked.
+     ratings = {}
+     blocked = collections.defaultdict(list)
+
+     for safety_ratings_list in safety_ratings_lists:
+         for rating in safety_ratings_list:
+             ratings[rating.category] = rating.probability
+             blocked[rating.category].append(rating.blocked)
+
+     blocked = {category: any(values) for category, values in blocked.items()}
+
+     safety_list = []
+     # `ratings` and `blocked` share the same keys in the same insertion
+     # order, so zipping their entries pairs each category correctly.
+     for (category, probability), category_blocked in zip(ratings.items(), blocked.values()):
+         safety_list.append(
+             glm.SafetyRating(category=category, probability=probability, blocked=category_blocked)
+         )
+
+     return safety_list
+
+
+ def _join_contents(contents: Iterable[glm.Content]):
+     contents = tuple(contents)
+     roles = [c.role for c in contents if c.role]
+     if roles:
+         role = roles[0]
+     else:
+         role = ""
+
+     parts = []
+     for content in contents:
+         parts.extend(content.parts)
+
+     # Merge consecutive text parts into one; non-text parts are kept
+     # as-is and act as boundaries between text runs.
+     merged_parts = [parts.pop(0)]
+     for part in parts:
+         if not merged_parts[-1].text:
+             merged_parts.append(part)
+             continue
+
+         if not part.text:
+             merged_parts.append(part)
+             continue
+
+         # Copy before mutating: proto-plus messages are mutable, and the
+         # original chunk should not be modified in place.
+         merged_part = glm.Part(merged_parts[-1])
+         merged_part.text += part.text
+         merged_parts[-1] = merged_part
+
+     return glm.Content(
+         role=role,
+         parts=merged_parts,
+     )
+
+
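+ # Illustrative merge behavior (doctest-style sketch): streamed text
+ # fragments join into a single part.
+ #
+ #     >>> a = glm.Content(role="model", parts=[glm.Part(text="Hel")])
+ #     >>> b = glm.Content(parts=[glm.Part(text="lo")])
+ #     >>> _join_contents([a, b]).parts[0].text
+ #     'Hello'
+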
+ def _join_candidates(candidates: Iterable[glm.Candidate]):
+     candidates = tuple(candidates)
+
+     index = candidates[0].index  # These should all be the same.
+
+     return glm.Candidate(
+         index=index,
+         content=_join_contents([c.content for c in candidates]),
+         finish_reason=candidates[-1].finish_reason,
+         safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
+         citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
+     )
+
+
+ def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]):
+     # Assuming that if a candidate ends, it is no longer returned in the list
+     # of candidates; that's why candidates carry an index.
+     candidates = collections.defaultdict(list)
+     for candidate_list in candidate_lists:
+         for candidate in candidate_list:
+             candidates[candidate.index].append(candidate)
+
+     new_candidates = []
+     for index, candidate_parts in sorted(candidates.items()):
+         new_candidates.append(_join_candidates(candidate_parts))
+
+     return new_candidates
+
+
+ def _join_prompt_feedbacks(
+     prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback],
+ ):
+     # Always return the first prompt feedback.
+     return next(iter(prompt_feedbacks))
+
+
+ def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]):
+     return glm.GenerateContentResponse(
+         candidates=_join_candidate_lists(c.candidates for c in chunks),
+         prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
+     )
+
+
+ _INCOMPLETE_ITERATION_MESSAGE = """\
+ Please let the response complete iteration before accessing the final accumulated
+ attributes (or call `response.resolve()`)"""
+
+
+ class BaseGenerateContentResponse:
+     def __init__(
+         self,
+         done: bool,
+         iterator: (
+             None
+             | Iterable[glm.GenerateContentResponse]
+             | AsyncIterable[glm.GenerateContentResponse]
+         ),
+         result: glm.GenerateContentResponse,
+         chunks: Iterable[glm.GenerateContentResponse] | None = None,
+     ):
+         self._done = done
+         self._iterator = iterator
+         self._result = result
+         if chunks is None:
+             self._chunks = [result]
+         else:
+             self._chunks = list(chunks)
+         if result.prompt_feedback.block_reason:
+             self._error = BlockedPromptException(result)
+         else:
+             self._error = None
+
+
+     @property
+     def candidates(self):
+         """The list of candidate responses.
+
+         Raises:
+             IncompleteIterationError: With `stream=True` if iteration over the stream was not completed.
+         """
+         if not self._done:
+             raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE)
+         return self._result.candidates
+
+     @property
+     def parts(self):
+         """A quick accessor equivalent to `self.candidates[0].parts`.
+
+         Raises:
+             ValueError: If the candidate list does not contain exactly one candidate.
+         """
+         candidates = self.candidates
+         if not candidates:
+             raise ValueError(
+                 "The `response.parts` quick accessor only works for a single candidate, "
+                 "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked."
+             )
+         if len(candidates) > 1:
+             raise ValueError(
+                 "The `response.parts` quick accessor only works with a "
+                 "single candidate. With multiple candidates use "
+                 "`result.candidates[index].content.parts`."
+             )
+         parts = candidates[0].content.parts
+         return parts
+
+     @property
+     def text(self):
+         """A quick accessor equivalent to `self.candidates[0].parts[0].text`.
+
+         Raises:
+             ValueError: If the candidate list or parts list does not contain exactly one entry.
+         """
+         parts = self.parts
+         if not parts:
+             raise ValueError(
+                 "The `response.text` quick accessor only works when the response contains a valid "
+                 "`Part`, but none was returned. Check the `candidate.safety_ratings` to see if the "
+                 "response was blocked."
+             )
+
+         return parts[0].text
+
+     @property
+     def prompt_feedback(self):
+         return self._result.prompt_feedback
+
+     def __str__(self) -> str:
+         if self._done:
+             _iterator = "None"
+         else:
+             _iterator = f"<{self._iterator.__class__.__name__}>"
+
+         _result = f"glm.GenerateContentResponse({type(self._result).to_dict(self._result)})"
+
+         if self._error:
+             _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"
+         else:
+             _error = ""
+
+         return (
+             textwrap.dedent(
+                 f"""\
+                 response:
+                 {type(self).__name__}(
+                     done={self._done},
+                     iterator={_iterator},
+                     result={_result},
+                 )"""
+             )
+             + _error
+         )
+
+     __repr__ = __str__
+
+
+ @contextlib.contextmanager
+ def rewrite_stream_error():
+     try:
+         yield
+     except (google.protobuf.json_format.ParseError, AttributeError) as e:
+         raise google.api_core.exceptions.BadRequest(
+             "Unknown error trying to retrieve streaming response. "
+             "Please retry with `stream=False` for more details."
+         ) from e
+
+
+ GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method.
+
+ These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`.
+ This object is based on the low level `glm.GenerateContentResponse` class, which just has `prompt_feedback`
+ and `candidates` attributes. This class adds several quick accessors for common use cases.
+
+ The same object type is returned for both `stream=True/False`.
+
+ ### Streaming
+
+ When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`,
+ iterate over this object to receive chunks of the response:
+
+ ```
+ response = model.generate_content(..., stream=True)
+ for chunk in response:
+     print(chunk.text)
+ ```
+
+ `GenerateContentResponse.prompt_feedback` is available immediately, but
+ `GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`),
+ are only available after the iteration is complete.
+ """
+
+ ASYNC_GENERATE_CONTENT_RESPONSE_DOC = (
+     """This is the async version of `genai.GenerateContentResponse`."""
+ )
+
+
+ @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC)
+ class GenerateContentResponse(BaseGenerateContentResponse):
+     @classmethod
+     def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]):
+         iterator = iter(iterator)
+         with rewrite_stream_error():
+             response = next(iterator)
+
+         return cls(
+             done=False,
+             iterator=iterator,
+             result=response,
+         )
+
+     @classmethod
+     def from_response(cls, response: glm.GenerateContentResponse):
+         return cls(
+             done=True,
+             iterator=None,
+             result=response,
+         )
+
+     def __iter__(self):
+         # This is not thread safe.
+         if self._done:
+             for chunk in self._chunks:
+                 yield GenerateContentResponse.from_response(chunk)
+             return
+
+         # Always have the next chunk available.
+         if len(self._chunks) == 0:
+             self._chunks.append(next(self._iterator))
+
+         for n in itertools.count():
+             if self._error:
+                 raise self._error
+
+             if n >= len(self._chunks) - 1:
+                 # Look ahead for a new item, so that you know the stream is done
+                 # when you yield the last item.
+                 if self._done:
+                     return
+
+                 try:
+                     item = next(self._iterator)
+                 except StopIteration:
+                     self._done = True
+                 except Exception as e:
+                     self._error = e
+                     self._done = True
+                 else:
+                     self._chunks.append(item)
+                     self._result = _join_chunks([self._result, item])
+
+             item = self._chunks[n]
+
+             item = GenerateContentResponse.from_response(item)
+             yield item
+
+     def resolve(self):
+         if self._done:
+             return
+
+         for _ in self:
+             pass
+
+
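+ # Usage sketch (illustrative; assumes a configured `model` such as
+ # `genai.GenerativeModel("gemini-pro")`, which is not part of this module):
+ #
+ #     response = model.generate_content("Tell me a story.", stream=True)
+ #     for chunk in response:       # chunks arrive incrementally
+ #         print(chunk.text, end="")
+ #     full_text = response.text    # accumulated attributes work once iteration is done
+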
+ @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC)
+ class AsyncGenerateContentResponse(BaseGenerateContentResponse):
+     @classmethod
+     async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]):
+         iterator = aiter(iterator)  # type: ignore
+         with rewrite_stream_error():
+             response = await anext(iterator)  # type: ignore
+
+         return cls(
+             done=False,
+             iterator=iterator,
+             result=response,
+         )
+
+     @classmethod
+     def from_response(cls, response: glm.GenerateContentResponse):
+         return cls(
+             done=True,
+             iterator=None,
+             result=response,
+         )
+
+     async def __aiter__(self):
+         # This is not thread safe.
+         if self._done:
+             for chunk in self._chunks:
+                 # Yield the async wrapper type, matching this class.
+                 yield AsyncGenerateContentResponse.from_response(chunk)
+             return
+
+         # Always have the next chunk available.
+         if len(self._chunks) == 0:
+             self._chunks.append(await anext(self._iterator))  # type: ignore
+
+         for n in itertools.count():
+             if self._error:
+                 raise self._error
+
+             if n >= len(self._chunks) - 1:
+                 # Look ahead for a new item, so that you know the stream is done
+                 # when you yield the last item.
+                 if self._done:
+                     return
+
+                 try:
+                     item = await anext(self._iterator)  # type: ignore
+                 except StopAsyncIteration:
+                     self._done = True
+                 except Exception as e:
+                     self._error = e
+                     self._done = True
+                 else:
+                     self._chunks.append(item)
+                     self._result = _join_chunks([self._result, item])
+
+             item = self._chunks[n]
+
+             item = AsyncGenerateContentResponse.from_response(item)
+             yield item
+
+     async def resolve(self):
+         if self._done:
+             return
+
+         async for _ in self:
+             pass
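+
+ # Async usage sketch (illustrative; `generate_content_async` is the async
+ # entry point on `GenerativeModel`, assumed configured elsewhere):
+ #
+ #     response = await model.generate_content_async("Tell me a story.", stream=True)
+ #     async for chunk in response:
+ #         print(chunk.text, end="")
+ #     await response.resolve()     # or let the loop finish; then use response.text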