oliver-aizip committed
Commit c0fdd5a · 1 Parent(s): e2b5d99

remove bitnet handling completely

Files changed (1)
  utils/models.py  +2 -46
utils/models.py CHANGED
@@ -11,7 +11,6 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     StoppingCriteria,
-    BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
@@ -156,25 +155,7 @@ def run_inference(model_name, context, question):
 
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
-    if "bitnet" in model_name.lower():
-        bitnet_model = BitNetForCausalLM.from_pretrained(
-            model_name,
-            #device_map="auto",
-            torch_dtype=torch.bfloat16,
-            #trust_remote_code=True,
-        )
-        pipe = pipeline(
-            "text-generation",
-            model=bitnet_model,
-            tokenizer=tokenizer,
-            #device_map="auto",
-            #trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-            model_kwargs={
-                "attn_implementation": "eager",
-            },
-        )
-    elif "icecream" not in model_name.lower():
+    if "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
             model=model_name,
@@ -221,12 +202,8 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )
 
-
         model_inputs = model_inputs.to(model.device)
-
         input_ids = model_inputs.input_ids
-        attention_mask = model_inputs.attention_mask
-
         prompt_tokens_length = input_ids.shape[1]
 
         with torch.inference_mode():
@@ -235,33 +212,12 @@ def run_inference(model_name, context, question):
                 return ""
 
             output_sequences = model.generate(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                **model_inputs,
                 max_new_tokens=512,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id # Addresses the warning
             )
 
             generated_token_ids = output_sequences[0][prompt_tokens_length:]
             result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
-    # elif "bitnet" in model_name.lower():
-    #     formatted = tokenizer.apply_chat_template(
-    #         text_input,
-    #         tokenize=True,
-    #         return_tensors="pt",
-    #         return_dict=True,
-    #         **tokenizer_kwargs,
-    #     ).to(bitnet_model.device)
-    #     with torch.inference_mode():
-    #         # Check interrupt before generation
-    #         if generation_interrupt.is_set():
-    #             return ""
-    #         output_sequences = bitnet_model.generate(
-    #             **formatted,
-    #             max_new_tokens=512,
-    #         )
-
-    #     result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
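
For reference, a minimal sketch of the direct-generation path this commit keeps (the bitnet branch is gone, and generate() now receives **model_inputs instead of separate input_ids/attention_mask arguments). The checkpoint id and chat message below are placeholders, not from the repo, and the repo's own prompt formatting (format_rag_prompt) and interrupt handling are omitted.

# Minimal sketch, assuming any Hugging Face chat model stands in for the
# repo's checkpoints; model id and message are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

messages = [{"role": "user", "content": "Answer using the given context: ..."}]
model_inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

prompt_tokens_length = model_inputs.input_ids.shape[1]

with torch.inference_mode():
    # **model_inputs forwards input_ids and attention_mask together, which is
    # what the diff switches to instead of passing them as named arguments.
    output_sequences = model.generate(**model_inputs, max_new_tokens=512)

result = tokenizer.decode(
    output_sequences[0][prompt_tokens_length:], skip_special_tokens=True
)
print(result)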