anicolson commited on
Commit
65cbcfd
1 Parent(s): 1084f69

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modelling_medicap.py +8 -7
modelling_medicap.py CHANGED
@@ -176,13 +176,14 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
176
  embeddings = self.decoder.transformer.wte(decoder_input_ids)
177
  embeddings = torch.cat([encoder_outputs[0], embeddings], dim=1)
178
 
179
- decoder_attention_mask = torch.cat(
180
- [
181
- torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
182
- decoder_attention_mask
183
- ],
184
- dim=1,
185
- )
 
186
 
187
  decoder_outputs = self.decoder(
188
  input_ids=decoder_input_ids,
 
176
  embeddings = self.decoder.transformer.wte(decoder_input_ids)
177
  embeddings = torch.cat([encoder_outputs[0], embeddings], dim=1)
178
 
179
+ if decoder_attention_mask:
180
+ decoder_attention_mask = torch.cat(
181
+ [
182
+ torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
183
+ decoder_attention_mask
184
+ ],
185
+ dim=1,
186
+ )
187
 
188
  decoder_outputs = self.decoder(
189
  input_ids=decoder_input_ids,