Update inference interruption
utils/models.py (+9 -10)
@@ -1,4 +1,7 @@
 import os
+# Add Dynamo error suppression
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 
 os.environ["MKL_THREADING_LAYER"] = "GNU"
 import spaces
@@ -209,10 +212,8 @@ def run_inference(model_name, context, question):
     result = pipe(
         text_input,
         max_new_tokens=512,
-
-
-            "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
-        },
+        stopping_criteria=[interrupt_criteria],  # Direct parameter for pipelines
+        generation_kwargs={"skip_special_tokens": True}
     )[0]["generated_text"]
 
     result = result[-1]["content"]
@@ -245,7 +246,7 @@ def run_inference(model_name, context, question):
         max_new_tokens=512,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        stopping_criteria=[interrupt_criteria]  #
+        stopping_criteria=[interrupt_criteria]  # Direct parameter for model.generate
     )
 
     generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -265,7 +266,7 @@ def run_inference(model_name, context, question):
     # output_sequences = bitnet_model.generate(
     # **formatted,
     # max_new_tokens=512,
-    # stopping_criteria=[interrupt_criteria]
+    # stopping_criteria=[interrupt_criteria]
     # )
 
     # result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
@@ -282,10 +283,8 @@ def run_inference(model_name, context, question):
     outputs = pipe(
         formatted,
         max_new_tokens=512,
-
-
-            "stopping_criteria": [interrupt_criteria]  # ADD INTERRUPT SUPPORT
-        },
+        stopping_criteria=[interrupt_criteria],  # Direct parameter for pipelines
+        generation_kwargs={"skip_special_tokens": True}
     )
     # print(outputs[0]['generated_text'])
     result = outputs[0]["generated_text"][input_length:]
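
The commit does two things: torch._dynamo.config.suppress_errors = True makes TorchDynamo fall back to eager execution instead of raising when compilation fails, and every generation path now receives the interrupt hook directly through the stopping_criteria keyword argument instead of a nested kwargs dict. The interrupt_criteria object itself is defined outside these hunks; the following is a minimal sketch of how such an object could look, assuming it is a transformers StoppingCriteria that polls an externally settable flag (the InterruptCriteria class and interrupt_event names are hypothetical, not taken from the repository):

import threading

import torch
from transformers import StoppingCriteria


class InterruptCriteria(StoppingCriteria):
    """Stop generation as soon as an external interrupt flag is set."""

    def __init__(self, interrupt_event: threading.Event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True asks generate() to stop after the current decoding step.
        return self.interrupt_event.is_set()


# Shared flag; a "Stop" handler elsewhere in the app would call interrupt_event.set().
interrupt_event = threading.Event()
interrupt_criteria = InterruptCriteria(interrupt_event)

Passing the criteria as a plain keyword argument works for model.generate(), and text-generation pipelines forward extra generation keyword arguments on to generate(), which is what the "Direct parameter for pipelines" comments in the diff rely on.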