Update hawk/models/video_llama.py
hawk/models/video_llama.py
CHANGED
@@ -1,5 +1,7 @@
 import logging
 import random
+import requests
+import io
 
 import torch
 from torch.cuda.amp import autocast as autocast
@@ -164,7 +166,7 @@ class VideoLLAMA(Blip2Base):
         logging.info('Loading Q-Former Motion Done')
 
         logging.info('Loading LLAMA Tokenizer')
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False, subfolder="llama-2-7b-chat-hf")
         if self.llama_tokenizer.pad_token is None:
             self.llama_tokenizer.pad_token = self.llama_tokenizer.unk_token
         DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
@@ -179,6 +181,7 @@ class VideoLLAMA(Blip2Base):
         if self.low_resource:
             self.llama_model = LlamaForCausalLM.from_pretrained(
                 llama_model,
+                subfolder="llama-2-7b-chat-hf",
                 torch_dtype=torch.bfloat16,
                 load_in_8bit=True,
                 device_map={'': device_8bit}
@@ -186,6 +189,7 @@ class VideoLLAMA(Blip2Base):
         else:
             self.llama_model = LlamaForCausalLM.from_pretrained(
                 llama_model,
+                subfolder="llama-2-7b-chat-hf",
                 torch_dtype=torch.bfloat16,
             )
 
@@ -844,7 +848,8 @@ class VideoLLAMA(Blip2Base):
         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
         if ckpt_path:
             print("Load first Checkpoint: {}".format(ckpt_path))
-            ckpt = torch.load(ckpt_path, map_location="cpu")
+            response = requests.get(ckpt_path)
+            ckpt = torch.load(io.BytesIO(response.content), map_location="cpu")
             msg = model.load_state_dict(ckpt['model'], strict=False)
             ckpt_path_2 = cfg.get("ckpt_2", "")
             if ckpt_path_2:
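
The recurring change threads subfolder="llama-2-7b-chat-hf" through every from_pretrained call, which tells Transformers to resolve config and weight files from that subdirectory of the repo rather than its root. A minimal sketch of the tokenizer side, with a hypothetical repo id standing in for whatever llama_model resolves to:

from transformers import LlamaTokenizer

# Hypothetical repo id that nests its files under llama-2-7b-chat-hf/;
# `subfolder` makes from_pretrained read from that directory instead of
# the repo root.
tokenizer = LlamaTokenizer.from_pretrained(
    "your-org/llama-weights",
    use_fast=False,
    subfolder="llama-2-7b-chat-hf",
)

# LLaMA tokenizers ship without a pad token; reusing <unk>, as the
# surrounding code does, keeps padded batching from failing.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token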
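The same argument goes into both model-loading branches. The low-resource branch additionally quantizes the weights to 8 bits at load time; a sketch under the assumption of a single GPU at index 0:

import torch
from transformers import LlamaForCausalLM

# Hypothetical repo id again; `load_in_8bit=True` requires the
# `bitsandbytes` package, and device_map={"": 0} pins the whole model
# to CUDA device 0 (the diff passes `device_8bit` for that index).
model = LlamaForCausalLM.from_pretrained(
    "your-org/llama-weights",
    subfolder="llama-2-7b-chat-hf",
    torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    device_map={"": 0},
)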
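The last hunk replaces a local torch.load on ckpt_path with an HTTP fetch, so the "ckpt" config value can now be a URL; note that requests.get buffers the entire file in memory before deserialization. A self-contained sketch of the same pattern, with a hypothetical URL and a stand-in model:

import io

import requests
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the real model
ckpt_url = "https://example.com/checkpoint.pth"  # hypothetical URL

# Fetch the whole checkpoint into memory and hand torch.load a
# file-like object.
response = requests.get(ckpt_url)
response.raise_for_status()  # fail loudly on HTTP errors
ckpt = torch.load(io.BytesIO(response.content), map_location="cpu")

# strict=False tolerates keys present in the model but absent from the
# checkpoint (and vice versa), as in the diff.
msg = model.load_state_dict(ckpt["model"], strict=False)

For large checkpoints, torch.hub.load_state_dict_from_url(ckpt_url, map_location="cpu") is a cache-backed alternative that avoids re-downloading the file on every startup.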