Spaces:
Running
Running
Update libra/model/builder.py
Browse files
libra/model/builder.py +3 -3
libra/model/builder.py
CHANGED
@@ -77,7 +77,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
|
|
77 |
model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
78 |
|
79 |
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
80 |
-
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
81 |
model.load_state_dict(mm_projector_weights, strict=False)
|
82 |
else:
|
83 |
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
@@ -94,11 +94,11 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
|
|
94 |
print(f"Merging weights")
|
95 |
model = model.merge_and_unload()
|
96 |
print('Convert to FP16...')
|
97 |
-
model.to(torch.float16)
|
98 |
else:
|
99 |
use_fast = False
|
100 |
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
101 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
102 |
|
103 |
image_processor = None
|
104 |
|
|
|
77 |
model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
78 |
|
79 |
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
80 |
+
mm_projector_weights = {k: v.to(torch.bfloat16) for k, v in mm_projector_weights.items()}
|
81 |
model.load_state_dict(mm_projector_weights, strict=False)
|
82 |
else:
|
83 |
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
|
|
94 |
print(f"Merging weights")
|
95 |
model = model.merge_and_unload()
|
96 |
print('Convert to FP16...')
|
97 |
+
model.to(torch.bfloat16)
|
98 |
else:
|
99 |
use_fast = False
|
100 |
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
101 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
102 |
|
103 |
image_processor = None
|
104 |
|