oliver-aizip committed
Commit 593a8e7 · 1 Parent(s): 0226e6c

separate bitnet from generic pipeline

Files changed (1)
  1. utils/models.py +26 -2
utils/models.py CHANGED

@@ -11,7 +11,7 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     StoppingCriteria,
-    StoppingCriteriaList,
+    BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
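
BitNetForCausalLM is only exported by transformers builds that ship BitNet support, so the new import fails on older installs. A guarded import is one way to surface that more clearly; this is a sketch of an optional pattern, not something the commit adds:

# Optional sketch (not part of the commit): fail with a clearer message when the
# installed transformers build does not export BitNetForCausalLM.
try:
    from transformers import BitNetForCausalLM
except ImportError as exc:
    raise ImportError(
        "BitNetForCausalLM is unavailable in this transformers install; "
        "upgrade to a BitNet-enabled transformers build to run bitnet models."
    ) from exc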
 
@@ -156,7 +156,14 @@ def run_inference(model_name, context, question):
 
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
-    if "icecream" not in model_name.lower():
+    if "bitnet" in model_name.lower():
+        bitnet_model = BitNetForCausalLM.from_pretrained(
+            model_name,
+            device_map="cuda",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+    elif "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
             model=model_name,
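
Taken out of run_inference(), the new branch loads the model class directly instead of building a text-generation pipeline. A minimal sketch of that load path, assuming a CUDA device and using microsoft/bitnet-b1.58-2B-4T purely as an example checkpoint id (the commit never names one):

# Minimal sketch of the load path added above, with a placeholder checkpoint id.
import torch
from transformers import AutoTokenizer, BitNetForCausalLM

model_name = "microsoft/bitnet-b1.58-2B-4T"  # example id, not from the commit

tokenizer = AutoTokenizer.from_pretrained(model_name)
bitnet_model = BitNetForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",           # place the weights on the GPU
    torch_dtype=torch.bfloat16,  # bf16, matching the commit
    trust_remote_code=True,      # allow custom modeling code from the model repo
)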
 
@@ -226,7 +233,24 @@ def run_inference(model_name, context, question):
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
         result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
+    elif "bitnet" in model_name.lower():
+        formatted = tokenizer.apply_chat_template(
+            text_input,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+            **tokenizer_kwargs,
+        ).to(device)
+        with torch.inference_mode():
+            # Check interrupt before generation
+            if generation_interrupt.is_set():
+                return ""
+            output_sequences = bitnet_model.generate(
+                **formatted,
+                max_new_tokens=512,
+            )
 
+        result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
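
The new generation branch mirrors the existing icecream path: tokenize through the chat template, generate inside torch.inference_mode(), then decode only the tokens that come after the prompt. A self-contained sketch with local stand-ins for the values run_inference() normally supplies (text_input, device, tokenizer_kwargs, the interrupt event), again with an example checkpoint id:

# Sketch of the BitNet generation branch with local stand-ins for the values
# that run_inference() normally builds (text_input, device, tokenizer_kwargs).
import torch
from transformers import AutoTokenizer, BitNetForCausalLM

model_name = "microsoft/bitnet-b1.58-2B-4T"  # example id, not from the commit
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name)
bitnet_model = BitNetForCausalLM.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16, trust_remote_code=True
)

# Stand-in for the RAG-formatted chat messages built earlier in run_inference().
text_input = [{"role": "user", "content": "Using the context, answer the question."}]

# apply_chat_template with tokenize=True and return_dict=True yields a dict of
# input_ids / attention_mask tensors ready for generate(); the commit passes
# extra options via **tokenizer_kwargs, of which add_generation_prompt is a
# typical one.
formatted = tokenizer.apply_chat_template(
    text_input,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
    add_generation_prompt=True,
).to(device)

with torch.inference_mode():
    output_sequences = bitnet_model.generate(**formatted, max_new_tokens=512)

# Keep only the newly generated tokens (everything after the prompt) and decode.
prompt_len = formatted["input_ids"].shape[-1]
result = tokenizer.decode(output_sequences[0][prompt_len:], skip_special_tokens=True)
print(result)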