X-iZhang commited on
Commit
943804b
·
verified ·
1 Parent(s): a46b590

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +3 -3
libra/model/builder.py CHANGED
@@ -77,7 +77,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
77
  model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
78
 
79
  mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
80
- mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
81
  model.load_state_dict(mm_projector_weights, strict=False)
82
  else:
83
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
@@ -94,11 +94,11 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
94
  print(f"Merging weights")
95
  model = model.merge_and_unload()
96
  print('Convert to FP16...')
97
- model.to(torch.float16)
98
  else:
99
  use_fast = False
100
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
101
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True,torch_dtype=torch.bfloat16, **kwargs)
102
 
103
  image_processor = None
104
 
 
77
  model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
78
 
79
  mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
80
+ mm_projector_weights = {k: v.to(torch.bfloat16) for k, v in mm_projector_weights.items()}
81
  model.load_state_dict(mm_projector_weights, strict=False)
82
  else:
83
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
 
94
  print(f"Merging weights")
95
  model = model.merge_and_unload()
96
  print('Convert to FP16...')
97
+ model.to(torch.bfloat16)
98
  else:
99
  use_fast = False
100
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
101
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
102
 
103
  image_processor = None
104