Added safety measure
Browse files- models/vallex.py +5 -0
models/vallex.py
CHANGED
@@ -588,6 +588,11 @@ class VALLE(VALLF):
|
|
588 |
print(f"Current memory used: {memory_used:.2f} MB")
|
589 |
break
|
590 |
|
|
|
|
|
|
|
|
|
|
|
591 |
y = torch.concat([y, samples], dim=1)
|
592 |
|
593 |
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|
|
|
588 |
print(f"Current memory used: {memory_used:.2f} MB")
|
589 |
break
|
590 |
|
591 |
+
# safety measure, break if token sequence too long
|
592 |
+
if y.shape[1] > 2250:
|
593 |
+
print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
|
594 |
+
break
|
595 |
+
|
596 |
y = torch.concat([y, samples], dim=1)
|
597 |
|
598 |
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|