lorocksUMD commited on
Commit
77007c9
·
verified ·
1 Parent(s): 1a6c015

Update DenseAV/denseav/aggregators.py

Browse files
Files changed (1) hide show
  1. DenseAV/denseav/aggregators.py +1 -0
DenseAV/denseav/aggregators.py CHANGED
@@ -161,6 +161,7 @@ class BaseAggregator(torch.nn.Module):
161
  audio_mask.to(torch.float32))
162
 
163
  if self.use_cls:
 
164
  audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
165
  image_cls = preds[IMAGE_CLS].reshape(v, self.num_heads, new_c)
166
  cls_sims = torch.einsum(
 
161
  audio_mask.to(torch.float32))
162
 
163
  if self.use_cls:
164
+ print(preds[AUDIO_CLS].shape)
165
  audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
166
  image_cls = preds[IMAGE_CLS].reshape(v, self.num_heads, new_c)
167
  cls_sims = torch.einsum(