Jiaqi-hkust commited on
Commit
eb8007e
·
verified ·
1 Parent(s): 6fb09d0

Update hawk/models/video_llama.py

Browse files
Files changed (1) hide show
  1. hawk/models/video_llama.py +4 -4
hawk/models/video_llama.py CHANGED
@@ -94,11 +94,11 @@ class VideoLLAMA(Blip2Base):
94
  if freeze_vit:
95
  for name, param in self.visual_encoder.named_parameters():
96
  param.requires_grad = False
97
- self.visual_encoder = self.visual_encoder.float().eval()
98
  self.visual_encoder.train = disabled_train
99
  for name, param in self.ln_vision.named_parameters():
100
  param.requires_grad = False
101
- self.ln_vision = self.ln_vision.float().eval()
102
  self.ln_vision.train = disabled_train
103
  logging.info("freeze vision encoder")
104
  logging.info('Loading VIT Done')
@@ -114,11 +114,11 @@ class VideoLLAMA(Blip2Base):
114
  if freeze_vit:
115
  for name, param in self.visual_encoder_motion.named_parameters():
116
  param.requires_grad = False
117
- self.visual_encoder_motion = self.visual_encoder_motion.float().eval()
118
  self.visual_encoder_motion.train = disabled_train
119
  for name, param in self.ln_vision_motion.named_parameters():
120
  param.requires_grad = False
121
- self.ln_vision_motion = self.ln_vision_motion.float().eval()
122
  self.ln_vision_motion.train = disabled_train
123
  logging.info("freeze vision encoder")
124
  logging.info('Loading VIT motion Done')
 
94
  if freeze_vit:
95
  for name, param in self.visual_encoder.named_parameters():
96
  param.requires_grad = False
97
+ self.visual_encoder = self.visual_encoder.eval()
98
  self.visual_encoder.train = disabled_train
99
  for name, param in self.ln_vision.named_parameters():
100
  param.requires_grad = False
101
+ self.ln_vision = self.ln_vision.eval()
102
  self.ln_vision.train = disabled_train
103
  logging.info("freeze vision encoder")
104
  logging.info('Loading VIT Done')
 
114
  if freeze_vit:
115
  for name, param in self.visual_encoder_motion.named_parameters():
116
  param.requires_grad = False
117
+ self.visual_encoder_motion = self.visual_encoder_motion.eval()
118
  self.visual_encoder_motion.train = disabled_train
119
  for name, param in self.ln_vision_motion.named_parameters():
120
  param.requires_grad = False
121
+ self.ln_vision_motion = self.ln_vision_motion.eval()
122
  self.ln_vision_motion.train = disabled_train
123
  logging.info("freeze vision encoder")
124
  logging.info('Loading VIT motion Done')