Spaces: Running on Zero
Commit · 593a8e7
1 Parent(s): 0226e6c
separate bitnet from generic pipeline
utils/models.py CHANGED (+26 -2)
@@ -11,7 +11,7 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     StoppingCriteria,
-
+    BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
@@ -156,7 +156,14 @@ def run_inference(model_name, context, question):
 
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
-    if "icecream" not in model_name.lower():
+    if "bitnet" in model_name.lower():
+        bitnet_model = BitNetForCausalLM.from_pretrained(
+            model_name,
+            device_map="cuda",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+    elif "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
             model=model_name,
@@ -226,7 +233,24 @@ def run_inference(model_name, context, question):
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
         result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
+    elif "bitnet" in model_name.lower():
+        formatted = tokenizer.apply_chat_template(
+            text_input,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+            **tokenizer_kwargs,
+        ).to(device)
+        with torch.inference_mode():
+            # Check interrupt before generation
+            if generation_interrupt.is_set():
+                return ""
+            output_sequences = bitnet_model.generate(
+                **formatted,
+                max_new_tokens=512,
+            )
 
+        result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,