aizip-dev committed
Commit 217c4d4 · verified · 1 Parent(s): c1f1ebf

Update inference interruption

Files changed (1)
utils/models.py +9 -10
utils/models.py CHANGED
@@ -1,4 +1,7 @@
 import os
+# Add Dynamo error suppression
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 
 os.environ["MKL_THREADING_LAYER"] = "GNU"
 import spaces
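The new header turns on global TorchDynamo error suppression. A minimal sketch of what that flag changes, assuming some model in utils/models.py is wrapped with torch.compile (not shown in this diff); add_one below is purely illustrative:

import torch
import torch._dynamo

# With suppress_errors=True a Dynamo compilation failure is logged and the frame
# falls back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

@torch.compile
def add_one(x):
    return x + 1

print(add_one(torch.ones(2)))  # still runs (eagerly) even if compilation fails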
@@ -209,10 +212,8 @@ def run_inference(model_name, context, question):
     result = pipe(
         text_input,
         max_new_tokens=512,
-        generation_kwargs={
-            "skip_special_tokens": True,
-            "stopping_criteria": [interrupt_criteria] # ADD INTERRUPT SUPPORT
-        },
+        stopping_criteria=[interrupt_criteria],  # Direct parameter for pipelines
+        generation_kwargs={"skip_special_tokens": True}
     )[0]["generated_text"]
 
     result = result[-1]["content"]
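The hunk above only shows where interrupt_criteria is passed; its definition lives outside this diff. A minimal sketch of what an interrupt-style stopping criterion can look like with the transformers API (the class name InterruptCriteria and its flag are illustrative, not taken from utils/models.py):

import torch
from transformers import StoppingCriteria

class InterruptCriteria(StoppingCriteria):
    """Stops generation as soon as an external interrupt flag is set."""

    def __init__(self):
        self.interrupted = False

    def interrupt(self):
        self.interrupted = True

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # generate() evaluates the criteria after each new token; returning True stops decoding.
        return self.interrupted

interrupt_criteria = InterruptCriteria()

The text-generation pipeline forwards extra call-time arguments to model.generate, which is why the commit can pass stopping_criteria directly instead of nesting it inside generation_kwargs.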
@@ -245,7 +246,7 @@ def run_inference(model_name, context, question):
         max_new_tokens=512,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        stopping_criteria=[interrupt_criteria] # ADD INTERRUPT SUPPORT
+        stopping_criteria=[interrupt_criteria] # Direct parameter for model.generate
     )
 
     generated_token_ids = output_sequences[0][prompt_tokens_length:]
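Continuing the sketch above, the same object plugs into a bare model.generate call. StoppingCriteriaList is the documented container type (the diff passes a bare list), and the tiny model name below is only a placeholder, not the model used in this repo:

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
output_sequences = model.generate(
    **inputs,
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    stopping_criteria=StoppingCriteriaList([interrupt_criteria]),  # same object as above
)
generated_token_ids = output_sequences[0][inputs["input_ids"].shape[-1]:]
print(tokenizer.decode(generated_token_ids, skip_special_tokens=True))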
@@ -265,7 +266,7 @@ def run_inference(model_name, context, question):
     # output_sequences = bitnet_model.generate(
     #     **formatted,
     #     max_new_tokens=512,
-    #     stopping_criteria=[interrupt_criteria] # ADD INTERRUPT SUPPORT
+    #     stopping_criteria=[interrupt_criteria]
     # )
 
     # result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
@@ -282,10 +283,8 @@ def run_inference(model_name, context, question):
     outputs = pipe(
         formatted,
         max_new_tokens=512,
-        generation_kwargs={
-            "skip_special_tokens": True,
-            "stopping_criteria": [interrupt_criteria] # ADD INTERRUPT SUPPORT
-        },
+        stopping_criteria=[interrupt_criteria],  # Direct parameter for pipelines
+        generation_kwargs={"skip_special_tokens": True}
     )
     # print(outputs[0]['generated_text'])
     result = outputs[0]["generated_text"][input_length:]
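Putting the pieces together, a hedged end-to-end sketch of the interrupt pattern with a pipeline, mirroring the new call shape in this commit. The placeholder model, prompt, and 100 ms timer are illustrative only; in the actual app the interrupt presumably comes from a UI event rather than a timer:

import threading
from transformers import pipeline, StoppingCriteriaList

pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")  # placeholder model

criteria = InterruptCriteria()                     # class sketched above
threading.Timer(0.1, criteria.interrupt).start()   # simulate an external "stop" signal

outputs = pipe(
    "Write a long story about inference servers.",
    max_new_tokens=512,
    stopping_criteria=StoppingCriteriaList([criteria]),  # direct parameter, as in the diff
)
print(outputs[0]["generated_text"])  # ends early once the flag flips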
 