Create streamer.py
streamer.py +98 -0
streamer.py
ADDED
@@ -0,0 +1,98 @@
from transformers import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class ByteStreamer(BaseStreamer):
    """
    Simple text streamer that prints the decoded text to stdout one character at a time.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> from streamer import ByteStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = ByteStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout one character at a time.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("ByteStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Adds the new token to the cache and decodes the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # Otherwise, emits at most one new character per step; `end()` flushes whatever remains.
        else:
            printable_text = text[self.print_len : self.print_len + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)
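Since `.generate()` only interacts with a streamer through `put()` and `end()`, any `BaseStreamer` subclass can redirect the stream somewhere other than stdout by overriding `on_finalized_text`. Below is a minimal sketch of that pattern, assuming this file is importable as `streamer`; the `CollectingStreamer` name and its buffering behavior are illustrative, not part of this commit.

```python
from streamer import ByteStreamer
from transformers import AutoModelForCausalLM, AutoTokenizer


class CollectingStreamer(ByteStreamer):
    """Hypothetical subclass: accumulates the finalized text instead of printing it."""

    def __init__(self, tokenizer, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.generated_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Redirect the stream into a buffer rather than stdout.
        self.generated_text += text


tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")

# skip_prompt=True keeps the prompt out of the stream, e.g. for chatbots.
collector = CollectingStreamer(tok, skip_prompt=True)
_ = model.generate(**inputs, streamer=collector, max_new_tokens=20)
print(collector.generated_text)
```

Because `end()` flushes whatever the character-by-character loop has not yet emitted, the buffer holds the complete continuation once `generate()` returns.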