Update model.py
Browse files
model.py
CHANGED
@@ -339,8 +339,10 @@ class StripedHyena(nn.Module):
|
|
339 |
self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
|
340 |
|
341 |
if config.get("use_flashfft", "False"):
|
342 |
-
|
343 |
-
|
|
|
|
|
344 |
self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
|
345 |
else:
|
346 |
self.flash_fft = None
|
|
|
339 |
self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
|
340 |
|
341 |
if config.get("use_flashfft", "False"):
|
342 |
+
try:
|
343 |
+
from flashfftconv import FlashFFTConv
|
344 |
+
except:
|
345 |
+
raise ImportError
|
346 |
self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
|
347 |
else:
|
348 |
self.flash_fft = None
|