Commit
·
548d485
1
Parent(s):
9578f22
Update ced_model/modeling_ced.py
Browse files
ced_model/modeling_ced.py
CHANGED
@@ -457,9 +457,7 @@ class CedModel(CedPreTrainedModel):
|
|
457 |
n_splits = 1
|
458 |
|
459 |
x = self.forward_features(x)
|
460 |
-
|
461 |
-
x = torch.flatten(x, 0, 1)
|
462 |
-
x = torch.unsqueeze(x, 0)
|
463 |
|
464 |
return SequenceClassifierOutput(logits=x)
|
465 |
|
|
|
457 |
n_splits = 1
|
458 |
|
459 |
x = self.forward_features(x)
|
460 |
+
x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
|
|
|
|
|
461 |
|
462 |
return SequenceClassifierOutput(logits=x)
|
463 |
|