Update modelling_cxrrg.py
Browse files- modelling_cxrrg.py +1 -1
modelling_cxrrg.py
CHANGED
@@ -541,4 +541,4 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
541 |
causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
|
542 |
|
543 |
mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
|
544 |
-
return mixed_causality_4d_attention_mask
|
|
|
541 |
causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
|
542 |
|
543 |
mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
|
544 |
+
return mixed_causality_4d_attention_mask.to(dtype=torch.float)
|