Jiaqi-hkust committed on
Commit
d38bd7b
·
verified ·
1 Parent(s): b2ed81d

Update hawk/models/video_llama.py

Browse files
Files changed (1) hide show
  1. hawk/models/video_llama.py +4 -4
hawk/models/video_llama.py CHANGED
@@ -403,11 +403,11 @@ class VideoLLAMA(Blip2Base):
403
 
404
  def encode_videoQformer_visual(self, image, motion=False):
405
  if motion is False:
406
- device = image.device
407
 
408
  # input shape b,c,t,h,w
409
  batch_size,_,time_length,_,_ = image.size()
410
- image = einops.rearrange(image, 'b c t h w -> (b t) c h w')
411
  with self.maybe_autocast():
412
 
413
  # embed image features with blip2, out: (b t) q h
@@ -454,12 +454,12 @@ class VideoLLAMA(Blip2Base):
454
 
455
  else:
456
  # Motion Encoder
457
- device = image.device
458
 
459
  # input shape b,c,t,h,w
460
  batch_size,_,time_length,_,_ = image.size()
461
 
462
- image = einops.rearrange(image, 'b c t h w -> (b t) c h w')
463
 
464
  with self.maybe_autocast():
465
 
 
403
 
404
  def encode_videoQformer_visual(self, image, motion=False):
405
  if motion is False:
406
+ device = "cuda:0"
407
 
408
  # input shape b,c,t,h,w
409
  batch_size,_,time_length,_,_ = image.size()
410
+ image = einops.rearrange(image, 'b c t h w -> (b t) c h w').to(device)
411
  with self.maybe_autocast():
412
 
413
  # embed image features with blip2, out: (b t) q h
 
454
 
455
  else:
456
  # Motion Encoder
457
+ device = "cuda:0"
458
 
459
  # input shape b,c,t,h,w
460
  batch_size,_,time_length,_,_ = image.size()
461
 
462
+ image = einops.rearrange(image, 'b c t h w -> (b t) c h w').to(device)
463
 
464
  with self.maybe_autocast():
465