mikeleske committed on
Commit
92db50e
·
verified ·
1 Parent(s): 2512b3e

Create streamer.py

Browse files
Files changed (1) hide show
  1. streamer.py +98 -0
streamer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+
3
+
4
class BaseStreamer:
    """
    Abstract interface for `.generate()` streamers.

    Concrete streamers receive freshly generated tokens through `put` and are
    notified that generation finished through `end`; both must be overridden.
    """

    def put(self, value):
        """Called by `.generate()` each time new tokens are available."""
        raise NotImplementedError()

    def end(self):
        """Called by `.generate()` once to signal the end of generation."""
        raise NotImplementedError()
16
+
17
+
18
class ByteStreamer(BaseStreamer):
    """
    Simple text streamer that prints generated text to stdout one character at a
    time (rather than buffering until whole words are formed), flushing its
    token cache whenever the decoded text ends in a newline.

    <Tip warning={true}>
    The API for the streamer classes is still under development and may change in the future.
    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tok = AutoTokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
    >>> streamer = ByteStreamer(tok)

    >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
    >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
    An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
    ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Variables used in the streaming process.
        self.token_cache = []  # token ids received so far but not yet flushed
        self.print_len = 0  # number of decoded characters already printed
        self.next_tokens_are_prompt = True  # first put() call carries the prompt

    def put(self, value):
        """
        Receives tokens, decodes them, and prints the next available character to
        stdout (at most one character per call; a trailing newline flushes the
        whole cache).

        Raises:
            ValueError: if `value` carries a batch of size greater than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            # Fixed: the message used to name `TextStreamer`, which this class is not.
            raise ValueError("ByteStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token(s) to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            # Emit at most one new character per call; any remaining decoded
            # text is printed by later put() calls or by end().
            printable_text = text[self.print_len : self.print_len + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists.
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # The next put() after end() is treated as a fresh prompt again.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)