aizip-dev committed (verified)
Commit bb6bbaf · 1 Parent(s): 217c4d4

Update interruption method
Files changed (1)
  1. utils/models.py +37 -9
utils/models.py CHANGED
@@ -62,6 +62,7 @@ def generate_summaries(example, model_a_name, model_b_name):
     Generates summaries for the given example using the assigned models sequentially.
     """
     if generation_interrupt.is_set():
+        print("Generation interrupted before starting")
         return "", ""

     context_text = ""
@@ -69,6 +70,11 @@ def generate_summaries(example, model_a_name, model_b_name):

     if "full_contexts" in example and example["full_contexts"]:
         for i, ctx in enumerate(example["full_contexts"]):
+            # Check interrupt during context processing
+            if generation_interrupt.is_set():
+                print("Generation interrupted during context processing")
+                return "", ""
+
             content = ""

             # Extract content from either dict or string
@@ -92,17 +98,22 @@ def generate_summaries(example, model_a_name, model_b_name):
     question = example.get("question", "")

     if generation_interrupt.is_set():
+        print("Generation interrupted before model A")
         return "", ""

+    print(f"Starting inference for Model A: {model_a_name}")
     # Run model A
     summary_a = run_inference(models[model_a_name], context_text, question)

     if generation_interrupt.is_set():
+        print("Generation interrupted after model A, before model B")
         return summary_a, ""

+    print(f"Starting inference for Model B: {model_b_name}")
     # Run model B
     summary_b = run_inference(models[model_b_name], context_text, question)

+    print("Both models completed successfully")
     return summary_a, summary_b


@@ -114,6 +125,7 @@ def run_inference(model_name, context, question):
     """
     # Check interrupt at the beginning
     if generation_interrupt.is_set():
+        print(f"Inference interrupted before starting for {model_name}")
         return ""

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -134,6 +146,11 @@ def run_inference(model_name, context, question):
     if model_name in tokenizer_cache:
         tokenizer = tokenizer_cache[model_name]
     else:
+        # Check interrupt before loading tokenizer
+        if generation_interrupt.is_set():
+            print(f"Inference interrupted before loading tokenizer for {model_name}")
+            return ""
+
         # Common arguments for tokenizer loading
         tokenizer_load_args = {"padding_side": "left", "token": True}

@@ -155,6 +172,7 @@ def run_inference(model_name, context, question):

     # Check interrupt before loading the model
     if generation_interrupt.is_set():
+        print(f"Inference interrupted before loading model {model_name}")
         return ""

     # Create interrupt criteria for this generation
@@ -162,19 +180,21 @@ def run_inference(model_name, context, question):

     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
+
+    # Check interrupt before model loading
+    if generation_interrupt.is_set():
+        print(f"Inference interrupted during model loading for {model_name}")
+        return ""
+
     if "bitnet" in model_name.lower():
         bitnet_model = BitNetForCausalLM.from_pretrained(
             model_name,
-            #device_map="auto",
             torch_dtype=torch.bfloat16,
-            #trust_remote_code=True,
         )
         pipe = pipeline(
             "text-generation",
             model=bitnet_model,
             tokenizer=tokenizer,
-            #device_map="auto",
-            #trust_remote_code=True,
             torch_dtype=torch.bfloat16,
             model_kwargs={
                 "attn_implementation": "eager",
@@ -206,13 +226,20 @@ def run_inference(model_name, context, question):
             torch_dtype=torch.bfloat16,
         )

+    # Final interrupt check before generation
+    if generation_interrupt.is_set():
+        print(f"Inference interrupted before generation for {model_name}")
+        return ""
+
     text_input = format_rag_prompt(question, context, accepts_sys)
+
+    print(f"Starting generation for {model_name}")
     if "Gemma-3".lower() in model_name.lower():
         print("REACHED HERE BEFORE GEN")
         result = pipe(
             text_input,
             max_new_tokens=512,
-            stopping_criteria=[interrupt_criteria], # Direct parameter for pipelines
+            stopping_criteria=[interrupt_criteria],
             generation_kwargs={"skip_special_tokens": True}
         )[0]["generated_text"]

@@ -238,6 +265,7 @@ def run_inference(model_name, context, question):
         with torch.inference_mode():
             # Check interrupt before generation
             if generation_interrupt.is_set():
+                print(f"Inference interrupted before torch generation for {model_name}")
                 return ""

             output_sequences = model.generate(
@@ -246,7 +274,7 @@ def run_inference(model_name, context, question):
                 max_new_tokens=512,
                 eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id,
-                stopping_criteria=[interrupt_criteria] # Direct parameter for model.generate
+                stopping_criteria=[interrupt_criteria]
             )

             generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -278,17 +306,17 @@ def run_inference(model_name, context, question):
         )

         input_length = len(formatted)
-        # Check interrupt before generation

         outputs = pipe(
             formatted,
             max_new_tokens=512,
-            stopping_criteria=[interrupt_criteria], # Direct parameter for pipelines
+            stopping_criteria=[interrupt_criteria],
             generation_kwargs={"skip_special_tokens": True}
         )
-        # print(outputs[0]['generated_text'])
         result = outputs[0]["generated_text"][input_length:]

+        print(f"Generation completed for {model_name}")
+
     except Exception as e:
         print(f"Error in inference for {model_name}: {e}")
         print(traceback.format_exc())
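
The checks above rely on a module-level generation_interrupt event and an interrupt_criteria stopping-criteria object; both are defined elsewhere in utils/models.py and are not touched by this commit. As a rough reference only, the sketch below shows one common way such a hook is wired with a threading.Event and a custom transformers.StoppingCriteria subclass. The class name InterruptCriteria and its constructor are assumptions for illustration, not the file's actual definitions.

# Hypothetical sketch -- the real definitions live elsewhere in utils/models.py.
import threading

import torch
from transformers import StoppingCriteria

# Shared flag; the UI layer calls generation_interrupt.set() to request cancellation.
generation_interrupt = threading.Event()


class InterruptCriteria(StoppingCriteria):
    """Stop generation as soon as the shared interrupt event is set."""

    def __init__(self, interrupt_event: threading.Event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True tells generate()/pipeline() to stop producing further tokens.
        return self.interrupt_event.is_set()


# Created per call in run_inference (see the "Create interrupt criteria" comment above)
# and passed as stopping_criteria=[interrupt_criteria] to pipe(...) and model.generate(...).
interrupt_criteria = InterruptCriteria(generation_interrupt)

With wiring along these lines, setting the event makes every generation_interrupt.is_set() check in the diff return early, the stopping criteria halts any generation already in progress, and clearing the event re-arms the pipeline for the next request.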