X-iZhang committed
Commit 37ae470 (verified)
Parent: d57ad7c

Update libra/model/builder.py

Files changed (1):
  libra/model/builder.py (+7, -14)
libra/model/builder.py CHANGED
@@ -24,20 +24,13 @@ from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, D
 
 
 def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.float16,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type='nf4'
-    )
+
     device_map = {"": device}
     kwargs = {
         "device_map": device_map,
-        "torch_dtype": torch.float16
+        "torch_dtype": torch.bfloat16
     }
-
-
-
+
     if 'libra' in model_name.lower():
         # Load Libra model
         if 'lora' in model_name.lower() and model_base is None:
@@ -83,24 +76,24 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
             model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
 
             mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
-            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+            mm_projector_weights = {k: v.to(torch.bfloat16) for k, v in mm_projector_weights.items()}
             model.load_state_dict(mm_projector_weights, strict=False)
         else:
             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-            model = LibraLlamaForCausalLM.from_pretrained(model_path, quantization_config=quantization_config, low_cpu_mem_usage=True, **kwargs)
+            model = LibraLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
     else:
         # Load language model
         if model_base is not None:
             # PEFT model
             from peft import PeftModel
             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
-            model = AutoModelForCausalLM.from_pretrained(model_base, quantization_config=quantization_config, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
+            model = AutoModelForCausalLM.from_pretrained(model_base, quantization_config=quantization_config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
             print(f"Loading LoRA weights from {model_path}")
             model = PeftModel.from_pretrained(model, model_path)
             print(f"Merging weights")
             model = model.merge_and_unload()
             print('Convert to FP16...')
-            model.to(torch.float16)
+            model.to(torch.bfloat16)
         else:
             use_fast = False
             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
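The substantive change is the compute dtype: every torch.float16 becomes torch.bfloat16, while the 'Convert to FP16...' log message is left as-is, so it now under-describes the cast it precedes. As a minimal sketch of why bfloat16 is the safer half-precision default here (this rationale is mine, not stated in the commit): bf16 keeps float32's 8-bit exponent, so values that overflow fp16's ~65504 ceiling survive the cast, at the cost of a coarser mantissa.

import torch

x = torch.tensor([1e5])      # larger than fp16's max normal value (~65504)
print(x.to(torch.float16))   # tensor([inf], dtype=torch.float16)  -- overflows
print(x.to(torch.bfloat16))  # tensor([99840.], dtype=torch.bfloat16)  -- range kept, precision coarser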
 
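Note that only the definition of quantization_config is deleted; the non-Libra PEFT branch still passes quantization_config= to AutoModelForCausalLM.from_pretrained, so that path now depends on the name being defined elsewhere. For callers who still want 4-bit loading, the removed block is reproduced below as an opt-in, using the standard transformers BitsAndBytesConfig API; setting the compute dtype to bfloat16 is an adjustment to match the commit's new default, not something the original code did.

import torch
from transformers import BitsAndBytesConfig

# The deleted 4-bit NF4 config, reconstructed as an opt-in for callers.
# The original used bnb_4bit_compute_dtype=torch.float16; bfloat16 here
# is an assumption chosen to match the commit's new torch_dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)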
 
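For completeness, a hypothetical call to the updated loader. The diff shows only the function's signature, so the four return values below follow the LLaVA-style convention this builder resembles, and the checkpoint id is a placeholder; both are assumptions rather than facts from the commit.

from libra.model.builder import load_pretrained_model

# Placeholder checkpoint id and assumed LLaVA-style return convention.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="X-iZhang/libra-v1.0-7b",  # hypothetical repo id
    model_base=None,
    model_name="libra-v1.0-7b",
    device="cuda",
)
print(next(model.parameters()).dtype)  # expected: torch.bfloat16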