aizip-dev committed on
Commit b867be1 · verified · 1 Parent(s): 693f0cb

Roll back interruption changes

Files changed (1):
  1. utils/models.py +13 -69
utils/models.py CHANGED

@@ -1,5 +1,5 @@
 import os
-# Add Dynamo error suppression
+# Keep Dynamo error suppression
 import torch._dynamo
 torch._dynamo.config.suppress_errors = True
 
@@ -17,7 +17,8 @@ from transformers import (
     BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
-from .shared import generation_interrupt
+# Remove interrupt import
+# from .shared import generation_interrupt
 
 models = {
     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
@@ -47,13 +48,13 @@ tokenizer_cache = {}
 model_names = list(models.keys())
 
 
-# Custom stopping criteria that checks the interrupt flag
-class InterruptCriteria(StoppingCriteria):
-    def __init__(self, interrupt_event):
-        self.interrupt_event = interrupt_event
-
-    def __call__(self, input_ids, scores, **kwargs):
-        return self.interrupt_event.is_set()
+# Remove interrupt criteria class since we're not using it
+# class InterruptCriteria(StoppingCriteria):
+#     def __init__(self, interrupt_event):
+#         self.interrupt_event = interrupt_event
+#
+#     def __call__(self, input_ids, scores, **kwargs):
+#         return self.interrupt_event.is_set()
 
 
 @spaces.GPU
@@ -61,20 +62,12 @@ def generate_summaries(example, model_a_name, model_b_name):
     """
     Generates summaries for the given example using the assigned models sequentially.
     """
-    if generation_interrupt.is_set():
-        print("Generation interrupted before starting")
-        return "", ""
-
+    # Remove interrupt checks
     context_text = ""
     context_parts = []
 
     if "full_contexts" in example and example["full_contexts"]:
         for i, ctx in enumerate(example["full_contexts"]):
-            # Check interrupt during context processing
-            if generation_interrupt.is_set():
-                print("Generation interrupted during context processing")
-                return "", ""
-
             content = ""
 
             # Extract content from either dict or string
@@ -97,18 +90,10 @@ def generate_summaries(example, model_a_name, model_b_name):
 
     question = example.get("question", "")
 
-    if generation_interrupt.is_set():
-        print("Generation interrupted before model A")
-        return "", ""
-
     print(f"Starting inference for Model A: {model_a_name}")
     # Run model A
     summary_a = run_inference(models[model_a_name], context_text, question)
 
-    if generation_interrupt.is_set():
-        print("Generation interrupted after model A, before model B")
-        return summary_a, ""
-
     print(f"Starting inference for Model B: {model_b_name}")
     # Run model B
     summary_b = run_inference(models[model_b_name], context_text, question)
@@ -121,13 +106,8 @@ def generate_summaries(example, model_a_name, model_b_name):
 def run_inference(model_name, context, question):
     """
     Run inference using the specified model.
-    Returns the generated text or empty string if interrupted.
+    Returns the generated text.
     """
-    # Check interrupt at the beginning
-    if generation_interrupt.is_set():
-        print(f"Inference interrupted before starting for {model_name}")
-        return ""
-
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     result = ""
     tokenizer_kwargs = {
@@ -146,11 +126,6 @@ def run_inference(model_name, context, question):
     if model_name in tokenizer_cache:
         tokenizer = tokenizer_cache[model_name]
     else:
-        # Check interrupt before loading tokenizer
-        if generation_interrupt.is_set():
-            print(f"Inference interrupted before loading tokenizer for {model_name}")
-            return ""
-
         # Common arguments for tokenizer loading
         tokenizer_load_args = {"padding_side": "left", "token": True}
 
@@ -170,21 +145,8 @@ def run_inference(model_name, context, question):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # Check interrupt before loading the model
-    if generation_interrupt.is_set():
-        print(f"Inference interrupted before loading model {model_name}")
-        return ""
-
-    # Create interrupt criteria for this generation
-    interrupt_criteria = InterruptCriteria(generation_interrupt)
-
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
-
-    # Check interrupt before model loading
-    if generation_interrupt.is_set():
-        print(f"Inference interrupted during model loading for {model_name}")
-        return ""
 
     if "bitnet" in model_name.lower():
         bitnet_model = BitNetForCausalLM.from_pretrained(
@@ -226,11 +188,6 @@ def run_inference(model_name, context, question):
             torch_dtype=torch.bfloat16,
         )
 
-        # Final interrupt check before generation
-        if generation_interrupt.is_set():
-            print(f"Inference interrupted before generation for {model_name}")
-            return ""
-
         text_input = format_rag_prompt(question, context, accepts_sys)
 
         print(f"Starting generation for {model_name}")
@@ -239,7 +196,6 @@ def run_inference(model_name, context, question):
         result = pipe(
             text_input,
             max_new_tokens=512,
-            stopping_criteria=[interrupt_criteria],
             generation_kwargs={"skip_special_tokens": True}
         )[0]["generated_text"]
 
@@ -263,18 +219,12 @@ def run_inference(model_name, context, question):
         prompt_tokens_length = input_ids.shape[1]
 
         with torch.inference_mode():
-            # Check interrupt before generation
-            if generation_interrupt.is_set():
-                print(f"Inference interrupted before torch generation for {model_name}")
-                return ""
-
            output_sequences = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=512,
                eos_token_id=tokenizer.eos_token_id,
-               pad_token_id=tokenizer.pad_token_id,
-               stopping_criteria=[interrupt_criteria]
+               pad_token_id=tokenizer.pad_token_id
            )
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
@@ -288,15 +238,10 @@ def run_inference(model_name, context, question):
         #     **tokenizer_kwargs,
         # ).to(bitnet_model.device)
         # with torch.inference_mode():
-        #     # Check interrupt before generation
-        #     if generation_interrupt.is_set():
-        #         return ""
         #     output_sequences = bitnet_model.generate(
        #         **formatted,
        #         max_new_tokens=512,
-       #         stopping_criteria=[interrupt_criteria]
        #     )
-
         # result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
@@ -310,7 +255,6 @@ def run_inference(model_name, context, question):
         outputs = pipe(
             formatted,
             max_new_tokens=512,
-            stopping_criteria=[interrupt_criteria],
             generation_kwargs={"skip_special_tokens": True}
         )
         result = outputs[0]["generated_text"][input_length:]
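
For reference, the interruption mechanism this commit rolls back was built on a custom StoppingCriteria that polls a shared flag. Below is a minimal, self-contained sketch of that pattern, not part of the commit: it assumes generation_interrupt is a threading.Event, as the removed `from .shared import generation_interrupt` import suggests, and the model id and helper name are illustrative.

import threading

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Shared flag; another thread calls generation_interrupt.set() to request a stop.
generation_interrupt = threading.Event()


class InterruptCriteria(StoppingCriteria):
    """Stop generation as soon as the shared event is set."""

    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        # Consulted once per decoding step; returning True ends generation early.
        return self.interrupt_event.is_set()


def generate_with_interrupt(model_id, prompt):
    # Illustrative helper, not from the repository.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
        )
    # Return only the newly generated tokens.
    return tokenizer.decode(
        output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

Because a StoppingCriteria is only consulted between decoding steps, it cannot abort tokenizer or model loading; that is why the original code also scattered explicit generation_interrupt.is_set() checks through generate_summaries and run_inference, and why the rollback removes both layers together.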