Jiaqi-hkust committed (verified)
Commit 95b4753 · Parent: a56aadc

Update hawk/models/video_llama.py

Files changed (1):
  hawk/models/video_llama.py (+7 -2)
hawk/models/video_llama.py CHANGED
@@ -1,5 +1,7 @@
 import logging
 import random
+import requests
+import io
 
 import torch
 from torch.cuda.amp import autocast as autocast
@@ -164,7 +166,7 @@ class VideoLLAMA(Blip2Base):
         logging.info('Loading Q-Former Motion Done')
 
         logging.info('Loading LLAMA Tokenizer')
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False, subfolder="llama-2-7b-chat-hf")
         if self.llama_tokenizer.pad_token is None:
             self.llama_tokenizer.pad_token = self.llama_tokenizer.unk_token
         DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
@@ -179,6 +181,7 @@ class VideoLLAMA(Blip2Base):
         if self.low_resource:
             self.llama_model = LlamaForCausalLM.from_pretrained(
                 llama_model,
+                subfolder="llama-2-7b-chat-hf",
                 torch_dtype=torch.bfloat16,
                 load_in_8bit=True,
                 device_map={'': device_8bit}
@@ -186,6 +189,7 @@ class VideoLLAMA(Blip2Base):
         else:
             self.llama_model = LlamaForCausalLM.from_pretrained(
                 llama_model,
+                subfolder="llama-2-7b-chat-hf",
                 torch_dtype=torch.bfloat16,
             )
 
@@ -844,7 +848,8 @@ class VideoLLAMA(Blip2Base):
         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
         if ckpt_path:
             print("Load first Checkpoint: {}".format(ckpt_path))
-            ckpt = torch.load(ckpt_path, map_location="cpu")
+            response = requests.get(ckpt_path)
+            ckpt = torch.load(io.BytesIO(response.content), map_location="cpu")
             msg = model.load_state_dict(ckpt['model'], strict=False)
             ckpt_path_2 = cfg.get("ckpt_2", "")
             if ckpt_path_2:
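For context, `from_pretrained` in Hugging Face transformers accepts a `subfolder` argument that resolves tokenizer and model files inside a subdirectory of the repo (or local directory) rather than at its root. A minimal sketch of the loading pattern this commit switches to, assuming `llama_model` names a weights repo that nests everything under `llama-2-7b-chat-hf/` — the path below is a placeholder, not the repo's actual id:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Placeholder: in the real code this comes from the model config's
# `llama_model` field; the commit assumes the LLaMA files live in a
# `llama-2-7b-chat-hf/` subdirectory instead of the repo root.
llama_model = "path/to/weights-repo"

tokenizer = LlamaTokenizer.from_pretrained(
    llama_model,
    use_fast=False,
    subfolder="llama-2-7b-chat-hf",  # resolve tokenizer files under this subdirectory
)

model = LlamaForCausalLM.from_pretrained(
    llama_model,
    subfolder="llama-2-7b-chat-hf",  # same subdirectory for the model weights
    torch_dtype=torch.bfloat16,
)
```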
 
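The checkpoint change means `cfg["ckpt"]` is now treated as an HTTP(S) URL: the file is downloaded with `requests` and deserialized from an in-memory buffer instead of being read from disk. A minimal sketch of that pattern in isolation; the local-path fallback and the status check are added guards for illustration, not part of the commit:

```python
import io

import requests
import torch


def load_checkpoint(ckpt_path: str) -> dict:
    """Load a checkpoint from an HTTP(S) URL or a local file path.

    The commit itself always fetches over HTTP; the local-path branch
    below is a safeguard added here, not part of the original change.
    """
    if ckpt_path.startswith(("http://", "https://")):
        response = requests.get(ckpt_path)
        # Fail loudly on 4xx/5xx instead of handing an error page to torch.load.
        response.raise_for_status()
        return torch.load(io.BytesIO(response.content), map_location="cpu")
    return torch.load(ckpt_path, map_location="cpu")


# Usage mirroring the modified loading path:
# ckpt = load_checkpoint(cfg.get("ckpt", ""))
# msg = model.load_state_dict(ckpt["model"], strict=False)
```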