Commit
·
9578f22
1
Parent(s):
5e9bb10
Fix forward method for long audios.
Browse files- ced_model/modeling_ced.py +6 -12
ced_model/modeling_ced.py
CHANGED
@@ -453,19 +453,13 @@ class CedModel(CedPreTrainedModel):
|
|
453 |
splits = torch.stack(splits[:-1], dim=0)
|
454 |
n_splits = len(splits)
|
455 |
x = torch.flatten(splits, 0, 1) # spl b c f t-> (spl b) c f t
|
456 |
-
x = self.forward_head(self.ced(x))
|
457 |
-
x = torch.reshape(
|
458 |
-
x, (n_splits, -1, self.outputdim)
|
459 |
-
) # (spl b) d -> spl b d, spl=n_splits
|
460 |
-
|
461 |
-
if self.config.eval_avg == "mean":
|
462 |
-
x = x.mean(0)
|
463 |
-
elif self.config.eval_avg == "max":
|
464 |
-
x = x.max(0)[0]
|
465 |
-
else:
|
466 |
-
raise ValueError(f"Unknown Eval average function ({self.eval_avg})")
|
467 |
else:
|
468 |
-
|
|
|
|
|
|
|
|
|
|
|
469 |
|
470 |
return SequenceClassifierOutput(logits=x)
|
471 |
|
|
|
453 |
splits = torch.stack(splits[:-1], dim=0)
|
454 |
n_splits = len(splits)
|
455 |
x = torch.flatten(splits, 0, 1) # spl b c f t-> (spl b) c f t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
else:
|
457 |
+
n_splits = 1
|
458 |
+
|
459 |
+
x = self.forward_features(x)
|
460 |
+
if n_splits > 1:
|
461 |
+
x = torch.flatten(x, 0, 1)
|
462 |
+
x = torch.unsqueeze(x, 0)
|
463 |
|
464 |
return SequenceClassifierOutput(logits=x)
|
465 |
|