Spaces: Running on Zero
Commit · 116a714
1 Parent(s): 83d0454

try using pipe but explicitly set model to be BitNetForCausalLM

utils/models.py  +29 -18
utils/models.py
CHANGED
@@ -163,6 +163,17 @@ def run_inference(model_name, context, question):
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
         )
+        pipe = pipeline(
+            "text-generation",
+            model=bitnet_model,
+            tokenizer=tokenizer,
+            device_map="cuda",
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            model_kwargs={
+                "attn_implementation": "eager",
+            },
+        )
     elif "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
@@ -233,24 +244,24 @@ def run_inference(model_name, context, question):
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
         result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
-    elif "bitnet" in model_name.lower():
-        formatted = tokenizer.apply_chat_template(
-            text_input,
-            tokenize=True,
-            return_tensors="pt",
-            return_dict=True,
-            **tokenizer_kwargs,
-        ).to(bitnet_model.device)
-        with torch.inference_mode():
-            # Check interrupt before generation
-            if generation_interrupt.is_set():
-                return ""
-            output_sequences = bitnet_model.generate(
-                **formatted,
-                max_new_tokens=512,
-            )
-
-        result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
+    # elif "bitnet" in model_name.lower():
+    #     formatted = tokenizer.apply_chat_template(
+    #         text_input,
+    #         tokenize=True,
+    #         return_tensors="pt",
+    #         return_dict=True,
+    #         **tokenizer_kwargs,
+    #     ).to(bitnet_model.device)
+    #     with torch.inference_mode():
+    #         # Check interrupt before generation
+    #         if generation_interrupt.is_set():
+    #             return ""
+    #         output_sequences = bitnet_model.generate(
+    #             **formatted,
+    #             max_new_tokens=512,
+    #         )
+
+    #     result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else:  # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
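
For context, a minimal standalone sketch of the pattern this commit moves to: load the BitNet checkpoint explicitly, so the resulting object is the checkpoint's BitNetForCausalLM class, and hand the instantiated model to the text-generation pipeline instead of letting the pipeline resolve it from a string. The model id, variable names, prompt, and generation settings below are illustrative assumptions and are not taken from the repo.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "microsoft/bitnet-b1.58-2B-4T"  # assumption: any BitNet checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
bitnet_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",   # set at load time, since the model is built explicitly
    trust_remote_code=True,        # lets the hub code resolve to BitNetForCausalLM
)

pipe = pipeline(
    "text-generation",
    model=bitnet_model,            # pass the already-instantiated model, not a repo id
    tokenizer=tokenizer,
    device="cuda",                 # assumption: a CUDA device is available
)

messages = [{"role": "user", "content": "Summarize what 1.58-bit quantization means."}]
out = pipe(messages, max_new_tokens=64)
print(out[0]["generated_text"])

Passing the model object lets the app pin the exact model class and dtype up front, while the pipeline only handles chat templating, generation, and decoding, which is what the commented-out manual generate() path used to do by hand.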