multimodalart HF Staff commited on
Commit
6e7c357
·
verified ·
1 Parent(s): ec96039

Update wan/modules/causal_model.py

Browse files
Files changed (1) hide show
  1. wan/modules/causal_model.py +1 -0
wan/modules/causal_model.py CHANGED
@@ -236,6 +236,7 @@ class CausalWanSelfAttention(nn.Module):
236
 
237
  # output
238
  x = x.flatten(2)
 
239
  x = self.o(x)
240
  return x
241
 
 
236
 
237
  # output
238
  x = x.flatten(2)
239
+ x = x.to(self.o.weight.dtype)
240
  x = self.o(x)
241
  return x
242