X-iZhang committed on
Commit
31fc3ad
·
verified ·
1 Parent(s): 1253c31

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +3 -3
libra/model/builder.py CHANGED
@@ -81,14 +81,14 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
81
  model.load_state_dict(mm_projector_weights, strict=False)
82
  else:
83
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
84
- model = LibraLlamaForCausalLM.from_pretrained(model_path, **kwargs)
85
  else:
86
  # Load language model
87
  if model_base is not None:
88
  # PEFT model
89
  from peft import PeftModel
90
  tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
- model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
92
  print(f"Loading LoRA weights from {model_path}")
93
  model = PeftModel.from_pretrained(model, model_path)
94
  print(f"Merging weights")
@@ -98,7 +98,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
98
  else:
99
  use_fast = False
100
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
101
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
102
 
103
  image_processor = None
104
 
 
81
  model.load_state_dict(mm_projector_weights, strict=False)
82
  else:
83
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
84
+ model = LibraLlamaForCausalLM.from_pretrained(model_path, device_map={"": "cpu"}, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **kwargs)
85
  else:
86
  # Load language model
87
  if model_base is not None:
88
  # PEFT model
89
  from peft import PeftModel
90
  tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
92
  print(f"Loading LoRA weights from {model_path}")
93
  model = PeftModel.from_pretrained(model, model_path)
94
  print(f"Merging weights")
 
98
  else:
99
  use_fast = False
100
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
101
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True,torch_dtype=torch.bfloat16, **kwargs)
102
 
103
  image_processor = None
104