Jiaqi-hkust commited on
Commit
a44c0f5
·
verified ·
1 Parent(s): f7aa0ce

Update hawk/conversation/conversation_video.py

Browse files
hawk/conversation/conversation_video.py CHANGED
@@ -21,6 +21,8 @@ from hawk.processors.video_processor import ToTHWC,ToUint8,load_video,load_video
21
  from hawk.processors import Blip2ImageEvalProcessor
22
 
23
  from hawk.models.ImageBind.data import load_and_transform_audio_data
 
 
24
 
25
  class SeparatorStyle(Enum):
26
  """Different separator style."""
@@ -311,15 +313,16 @@ class Chat:
311
  else:
312
  raise NotImplementedError
313
 
314
-
315
- # conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
316
- image_emb, _, _ = self.model.encode_videoQformer_visual(video) # 1,32,4096
317
- image_motion_emb, _, _ = self.model.encode_videoQformer_visual(video_motion, motion=True) # 1,32,4096
318
- image_emb = image_emb.clone().detach()
319
- image_motion_emb = image_motion_emb.clone().detach()
320
- img_list.append(torch.cat((image_emb, image_motion_emb), dim=1))
321
- # img_list.append(image_motion_emb)
322
- conv.append_message(conv.roles[0], "<Video><ImageHere></Video> ")
 
323
  return "Received."
324
 
325
  def upload_img(self, image, conv, img_list):
 
21
  from hawk.processors import Blip2ImageEvalProcessor
22
 
23
  from hawk.models.ImageBind.data import load_and_transform_audio_data
24
+ from torch.cuda.amp import autocast
25
+
26
 
27
  class SeparatorStyle(Enum):
28
  """Different separator style."""
 
313
  else:
314
  raise NotImplementedError
315
 
316
+ with autocast():
317
+ # conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
318
+ image_emb, _, _ = self.model.encode_videoQformer_visual(video) # 1,32,4096
319
+ image_motion_emb, _, _ = self.model.encode_videoQformer_visual(video_motion, motion=True) # 1,32,4096
320
+ image_emb = image_emb.clone().detach()
321
+ image_motion_emb = image_motion_emb.clone().detach()
322
+ img_list.append(torch.cat((image_emb, image_motion_emb), dim=1))
323
+ # img_list.append(image_motion_emb)
324
+ conv.append_message(conv.roles[0], "<Video><ImageHere></Video> ")
325
+
326
  return "Received."
327
 
328
  def upload_img(self, image, conv, img_list):