maybe fp16 works?
utils/inference.py +1 -1
utils/inference.py CHANGED
@@ -45,7 +45,7 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
         )
     else:
         model = LlamaForCausalLM.from_pretrained(
-            base_model, device_map={"": device}, low_cpu_mem_usage=True
+            base_model, device_map={"": device}, low_cpu_mem_usage=True, torch_dtype=torch.float16
         )
     if adapter_model is not None:
         model = PeftModel.from_pretrained(
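The question in the commit title ("maybe fp16 works?") hinges on the added torch_dtype=torch.float16 argument: without it, from_pretrained loads the checkpoint in float32, so half precision roughly halves resident weight memory (about 2 bytes per parameter instead of 4). Below is a minimal sketch of the full function this hunk sits in; only the changed call is confirmed by the diff. The 8-bit branch's arguments, the device selection, and the tokenizer setup are assumptions inferred from the visible context lines.

# Minimal sketch; only the "+" line of the diff is confirmed by the commit.
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed device selection

def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
    tokenizer = LlamaTokenizer.from_pretrained(base_model)  # assumed tokenizer setup
    if load_8bit:
        # Branch hidden above the hunk: 8-bit quantized load (arguments assumed).
        model = LlamaForCausalLM.from_pretrained(
            base_model, device_map={"": device}, load_in_8bit=True
        )
    else:
        # The changed call: torch_dtype=torch.float16 loads the weights in
        # half precision instead of the float32 default, halving weight memory.
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )
    if adapter_model is not None:
        # Wrap the base model with the PEFT (e.g. LoRA) adapter weights.
        model = PeftModel.from_pretrained(model, adapter_model)
    return tokenizer, model

One caveat worth noting: fp16 is the usual choice for CUDA inference, but on CPU it can be slow or unsupported, so the hedged "maybe" in the commit title is apt when device resolves to "cpu".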