lym0302 commited on
Commit
2771ada
·
1 Parent(s): 3f6c072
third_party/VideoLLaMA2/videollama2/model/__init__.py CHANGED
@@ -76,8 +76,8 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
76
  bnb_4bit_quant_type='nf4'
77
  )
78
  else:
79
- kwargs['torch_dtype'] = torch.float16
80
- # kwargs['torch_dtype'] = torch.bfloat16
81
 
82
  if use_flash_attn:
83
  kwargs['attn_implementation'] = 'flash_attention_2'
 
76
  bnb_4bit_quant_type='nf4'
77
  )
78
  else:
79
+ # kwargs['torch_dtype'] = torch.float16
80
+ kwargs['torch_dtype'] = torch.bfloat16
81
 
82
  if use_flash_attn:
83
  kwargs['attn_implementation'] = 'flash_attention_2'