oliver-aizip committed
Commit 116a714 · 1 Parent(s): 83d0454

try using pipe but explicitly set model to be BitNetForCausalLM

Files changed (1)
  1. utils/models.py +29 -18
utils/models.py CHANGED
@@ -163,6 +163,17 @@ def run_inference(model_name, context, question):
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
+        pipe = pipeline(
+            "text-generation",
+            model=bitnet_model,
+            tokenizer=tokenizer,
+            device_map="cuda",
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            model_kwargs={
+                "attn_implementation": "eager",
+            },
+        )
    elif "icecream" not in model_name.lower():
        pipe = pipeline(
            "text-generation",
@@ -233,24 +244,24 @@ def run_inference(model_name, context, question):

        generated_token_ids = output_sequences[0][prompt_tokens_length:]
        result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
-    elif "bitnet" in model_name.lower():
-        formatted = tokenizer.apply_chat_template(
-            text_input,
-            tokenize=True,
-            return_tensors="pt",
-            return_dict=True,
-            **tokenizer_kwargs,
-        ).to(bitnet_model.device)
-        with torch.inference_mode():
-            # Check interrupt before generation
-            if generation_interrupt.is_set():
-                return ""
-            output_sequences = bitnet_model.generate(
-                **formatted,
-                max_new_tokens=512,
-            )
-
-        result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
+    # elif "bitnet" in model_name.lower():
+    #     formatted = tokenizer.apply_chat_template(
+    #         text_input,
+    #         tokenize=True,
+    #         return_tensors="pt",
+    #         return_dict=True,
+    #         **tokenizer_kwargs,
+    #     ).to(bitnet_model.device)
+    #     with torch.inference_mode():
+    #         # Check interrupt before generation
+    #         if generation_interrupt.is_set():
+    #             return ""
+    #         output_sequences = bitnet_model.generate(
+    #             **formatted,
+    #             max_new_tokens=512,
+    #         )
+
+    #     result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
    else: # For other models
        formatted = pipe.tokenizer.apply_chat_template(
            text_input,
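
For context, a minimal sketch of the pattern this commit switches to: instantiate the BitNet model explicitly and hand the model object (rather than a repo id string) to pipeline, so this model goes through the same pipeline path as the other models. This is an illustration only: the checkpoint id, the Auto* loading path, and the chat message are assumptions, not part of this diff, and it presumes a transformers install where the checkpoint resolves to the BitNetForCausalLM class named in the commit message.

# Hedged usage sketch -- model id and loading path are assumptions, not taken from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "microsoft/bitnet-b1.58-2B-4T"  # placeholder BitNet checkpoint id

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Assumes the checkpoint's config maps to BitNetForCausalLM under this transformers install.
bitnet_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# As in the diff: pass the already-instantiated model object to the pipeline. The
# device_map / torch_dtype / model_kwargs arguments mainly matter when a string id is
# passed instead of a model object, but mirroring the diff keeps the call sites identical.
pipe = pipeline(
    "text-generation",
    model=bitnet_model,
    tokenizer=tokenizer,
    device_map="cuda",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    model_kwargs={"attn_implementation": "eager"},
)

# The text-generation pipeline accepts chat-style messages and applies the chat template
# itself, which replaces the manual apply_chat_template / generate / decode block that
# this commit comments out.
messages = [{"role": "user", "content": "Answer the question using the provided context: ..."}]
output = pipe(messages, max_new_tokens=512, return_full_text=False)
print(output[0]["generated_text"])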