zetavg committed
Commit: 00263ef
1 parent: 1a203ff

fix inference output

Files changed:
- llama_lora/lib/inference.py  (+25 -4)
- llama_lora/ui/inference_ui.py  (+0 -78)
- llama_lora/utils/prompter.py  (+6 -1)
llama_lora/lib/inference.py
CHANGED
@@ -4,7 +4,6 @@ import transformers
 from .get_device import get_device
 from .streaming_generation_utils import Iteratorize, Stream
 
-
 def generate(
     # model
     model,
@@ -30,18 +29,34 @@ def generate(
         "stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
     }
 
+    skip_special_tokens = True
+
+    if '/dolly' in tokenizer.name_or_path:
+        # dolly has additional_special_tokens as ['### End', '### Instruction:', '### Response:'], skipping them will break the prompter's reply extraction.
+        skip_special_tokens = False
+        # Ensure generation stops once it generates "### End"
+        end_key_token_id = tokenizer.encode("### End")
+        end_key_token_id = end_key_token_id[0]  # 50277
+        if isinstance(generate_params['generation_config'].eos_token_id, str):
+            generate_params['generation_config'].eos_token_id = [generate_params['generation_config'].eos_token_id]
+        elif not generate_params['generation_config'].eos_token_id:
+            generate_params['generation_config'].eos_token_id = []
+        generate_params['generation_config'].eos_token_id.append(end_key_token_id)
+
     if stream_output:
         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator,
         # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+        generation_output = None
 
         def generate_with_callback(callback=None, **kwargs):
+            nonlocal generation_output
            kwargs["stopping_criteria"].insert(
                 0,
                 Stream(callback_func=callback)
             )
             with torch.no_grad():
-                model.generate(**kwargs)
+                generation_output = model.generate(**kwargs)
 
         def generate_with_streaming(**kwargs):
             return Iteratorize(
@@ -50,16 +65,22 @@ def generate(
 
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
-                decoded_output = tokenizer.decode(output, skip_special_tokens=True)
+                decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
                 yield decoded_output, output
                 if output[-1] in [tokenizer.eos_token_id]:
                     break
+
+        if generation_output:
+            output = generation_output.sequences[0]
+            decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
+            yield decoded_output, output
+
         return  # early return for stream_output
 
     # Without streaming
     with torch.no_grad():
         generation_output = model.generate(**generate_params)
     output = generation_output.sequences[0]
-    decoded_output = tokenizer.decode(output, skip_special_tokens=True)
+    decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
     yield decoded_output, output
     return
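The Dolly-specific branch added above appends an extra end-of-sequence token id so generation halts at "### End". Below is a minimal sketch of that normalization, assuming a hypothetical helper name (append_stop_token_id) and a plain transformers GenerationConfig rather than the commit's generate_params dict; it is not part of the commit.

# Minimal sketch, not part of the commit: GenerationConfig.eos_token_id may be
# None, a single id, or a list of ids; normalize it to a list and append one
# extra stop token id. The helper name append_stop_token_id is hypothetical.
from transformers import GenerationConfig

def append_stop_token_id(generation_config: GenerationConfig, extra_token_id: int) -> None:
    eos = generation_config.eos_token_id
    if eos is None:
        eos = []
    elif not isinstance(eos, list):
        eos = [eos]
    eos.append(extra_token_id)
    generation_config.eos_token_id = eos

# Example usage (mirroring the hunk above, where the id comes from the tokenizer):
# append_stop_token_id(generation_config, tokenizer.encode("### End")[0])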
llama_lora/ui/inference_ui.py
CHANGED
@@ -160,84 +160,6 @@ def do_inference(
                     None)
 
         return
-
-
-        if stream_output:
-            # Stream the reply 1 token at a time.
-            # This is based on the trick of using 'stopping_criteria' to create an iterator,
-            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
-
-            def generate_with_callback(callback=None, **kwargs):
-                kwargs.setdefault(
-                    "stopping_criteria", transformers.StoppingCriteriaList()
-                )
-                kwargs["stopping_criteria"].append(
-                    Stream(callback_func=callback)
-                )
-                with torch.no_grad():
-                    model.generate(**kwargs)
-
-            def generate_with_streaming(**kwargs):
-                return Iteratorize(
-                    generate_with_callback, kwargs, callback=None
-                )
-
-            with generate_with_streaming(**generate_params) as generator:
-                for output in generator:
-                    # new_tokens = len(output) - len(input_ids[0])
-                    decoded_output = tokenizer.decode(output)
-
-                    if output[-1] in [tokenizer.eos_token_id]:
-                        break
-
-                    raw_output = None
-                    if show_raw:
-                        raw_output = str(output)
-                    response = prompter.get_response(decoded_output)
-
-                    if Global.should_stop_generating:
-                        return
-
-                    yield (
-                        gr.Textbox.update(
-                            value=response, lines=inference_output_lines),
-                        raw_output)
-
-                    if Global.should_stop_generating:
-                        # If the user stops the generation, and then clicks the
-                        # generation button again, they may mysteriously landed
-                        # here, in the previous, should-be-stopped generation
-                        # function call, with the new generation function not be
-                        # called at all. To workaround this, we yield a message
-                        # and setting lines=1, and if the front-end JS detects
-                        # that lines has been set to 1 (rows="1" in HTML),
-                        # it will automatically click the generate button again
-                        # (gr.Textbox.update() does not support updating
-                        # elem_classes or elem_id).
-                        # [WORKAROUND-UI01]
-                        yield (
-                            gr.Textbox.update(
-                                value="Please retry", lines=1),
-                            None)
-            return  # early return for stream_output
-
-        # Without streaming
-        with torch.no_grad():
-            generation_output = model.generate(**generate_params)
-        s = generation_output.sequences[0]
-        output = tokenizer.decode(s)
-        raw_output = None
-        if show_raw:
-            raw_output = str(s)
-
-        response = prompter.get_response(output)
-        if Global.should_stop_generating:
-            return
-
-        yield (
-            gr.Textbox.update(value=response, lines=inference_output_lines),
-            raw_output)
-
     except Exception as e:
         raise gr.Error(e)
 
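Both the code kept in llama_lora/lib/inference.py and the dead block deleted above rely on the same trick: a StoppingCriteria callback pushes each partial sequence out of model.generate(), and a wrapper turns that callback into an iterator. The sketch below is a simplified stand-in for the shape of the repo's Stream and Iteratorize (the real streaming_generation_utils also handles early stops and exceptions); it is illustrative only.

import queue
import threading

import transformers

class Stream(transformers.StoppingCriteria):
    # Invoked by model.generate() after every generated token; forwards the
    # sequence generated so far and never stops generation itself.
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False

class Iteratorize:
    # Runs a callback-style function in a background thread and exposes the
    # values passed to the callback as an iterator.
    def __init__(self, func, kwargs):
        self.q = queue.Queue()
        self.sentinel = object()

        def run():
            try:
                func(callback=self.q.put, **kwargs)
            finally:
                self.q.put(self.sentinel)  # signal end of generation

        self.thread = threading.Thread(target=run, daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.q.get()
        if item is self.sentinel:
            raise StopIteration
        return item

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.thread.join()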
llama_lora/utils/prompter.py
CHANGED
@@ -131,8 +131,13 @@ class Prompter(object):
     def get_response(self, output: str) -> str:
         if self.template_name == "None":
             return output
+
+        splitted_output = output.split(self.template["response_split"])
+        # if len(splitted_output) <= 1:
+        #     return output.strip()
+
         return self.template["response_split"].join(
-            output.split(self.template["response_split"])[1:]
+            splitted_output[1:]
         ).strip()
 
     def get_variable_names(self) -> List[str]:
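A small worked example of the updated get_response behaviour, using a made-up template value rather than one taken from the repo: everything after the first response_split marker is treated as the reply, and rejoining on the marker keeps any later occurrences of it inside the reply intact.

# Minimal sketch with hypothetical values, not taken from the repo's templates.
template = {"response_split": "### Response:"}
output = "### Instruction:\nSay hi.\n\n### Response:\nHello!"

splitted_output = output.split(template["response_split"])
response = template["response_split"].join(splitted_output[1:]).strip()
print(response)  # -> "Hello!"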