Feature Extraction
Transformers
Safetensors
vision-encoder-decoder
custom_code
anicolson commited on
Commit
df24aad
·
verified ·
1 Parent(s): fc8cc82

Update modelling_cxrrg.py

Browse files
Files changed (1) hide show
  1. modelling_cxrrg.py +1 -1
modelling_cxrrg.py CHANGED
@@ -541,4 +541,4 @@ class CXRRGModel(VisionEncoderDecoderModel):
541
  causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
542
 
543
  mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
544
- return mixed_causality_4d_attention_mask
 
541
  causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
542
 
543
  mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
544
+ return mixed_causality_4d_attention_mask.to(dtype=torch.float)