anicolson committed
Commit 4dc39f4
Parent: a38680c

Upload folder using huggingface_hub

Files changed (1):
modelling_medicap.py (+55 -15)
modelling_medicap.py CHANGED
@@ -160,6 +160,9 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }
 
+        if decoder_inputs_embeds is None:
+            decoder_inputs_embeds = self.decoder.transformer.wte(decoder_input_ids)
+
         if encoder_outputs is None:
             if pixel_values is None:
                 raise ValueError("You have to specify pixel_values")
@@ -170,23 +173,18 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
                 return_dict=return_dict,
                 **kwargs_encoder,
             )  # CvT does not support output_attentions.
-        elif isinstance(encoder_outputs, tuple):
-            encoder_outputs = BaseModelOutput(*encoder_outputs)
-
-        embeddings = self.decoder.transformer.wte(decoder_input_ids)
-        embeddings = torch.cat([encoder_outputs[0], embeddings], dim=1)
-
-        if torch.is_tensor(decoder_attention_mask):
-            decoder_attention_mask = torch.cat(
-                [
-                    torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
-                    decoder_attention_mask
-                ],
-                dim=1,
-            )
+            assert decoder_inputs_embeds.shape[1] == 1
+            decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = torch.cat(
+                    [
+                        torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
+                        decoder_attention_mask
+                    ],
+                    dim=1,
+                )
 
         decoder_outputs = self.decoder(
-            input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
@@ -220,6 +218,48 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
         )
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        """
+        Modification of:
+            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
+
+        This can help with managing input_embeds and input_ids:
+            https://github.com/huggingface/transformers/issues/6535
+        """
+        input_dict = {'use_cache': use_cache, 'encoder_outputs': encoder_outputs, 'attention_mask': attention_mask}
+
+        if past_key_values is None:
+            decoder_inputs = self.decoder.prepare_inputs_for_generation(
+                input_ids, inputs_embeds=encoder_outputs[0], past_key_values=past_key_values,
+            )
+            input_dict['decoder_inputs_embeds'] = decoder_inputs['inputs_embeds']
+        else:
+            decoder_inputs = self.decoder.prepare_inputs_for_generation(
+                input_ids, past_key_values=past_key_values,
+            )
+            input_dict['decoder_input_ids'] = decoder_inputs['input_ids']
+        input_dict['past_key_values'] = decoder_inputs['past_key_values']
+        input_dict['decoder_attention_mask'] = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
+
+        # if torch.is_tensor(decoder_attention_mask):
+        #     decoder_attention_mask = torch.cat(
+        #         [
+        #             torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
+        #             decoder_attention_mask
+        #         ],
+        #         dim=1,
+        #     )
+
+        return input_dict
+
     def tokenize_captions_teacher_forcing(
         self,
         captions: str,
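
For orientation (not part of the commit): when forward computes the encoder outputs itself, the image features are prepended to the single starting-token embedding, and the text attention mask is extended with ones over the image positions so the decoder attends across both. A minimal sketch with assumed shapes; the 49 image tokens and hidden size 768 are illustrative, not taken from the model config:

import torch

# Assumed illustrative shapes: a batch of 2, 49 image tokens, hidden size 768.
batch, image_tokens, hidden = 2, 49, 768
encoder_hidden = torch.randn(batch, image_tokens, hidden)  # stand-in for encoder_outputs[0]
bos_embedding = torch.randn(batch, 1, hidden)              # stand-in for wte(decoder_input_ids), shape[1] == 1
text_mask = torch.ones(batch, 1, dtype=torch.long)

# Prepend the image features to the token embedding, as the diff does.
decoder_inputs_embeds = torch.cat([encoder_hidden, bos_embedding], dim=1)

# Extend the text mask with ones over the image positions.
decoder_attention_mask = torch.cat(
    [torch.ones(encoder_hidden.shape[:-1], dtype=text_mask.dtype), text_mask],
    dim=1,
)

assert decoder_inputs_embeds.shape[1] == decoder_attention_mask.shape[1] == image_tokens + 1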
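
Likewise, a standalone sketch of the two-phase routing the new prepare_inputs_for_generation sets up during generation. The route stub and token ids below are assumptions that only mirror the branch structure, not the model's actual API:

import torch

# Assumed illustrative shapes, as above.
batch, image_tokens, hidden = 1, 49, 768
encoder_hidden = torch.randn(batch, image_tokens, hidden)  # stand-in for encoder_outputs[0]

def route(input_ids, encoder_hidden, past_key_values=None):
    """Stub mirroring the branch structure of the new prepare_inputs_for_generation."""
    if past_key_values is None:
        # Step 1: no cache yet, so the image features are handed to the
        # decoder as decoder_inputs_embeds and act as the prompt.
        return {"decoder_inputs_embeds": encoder_hidden}
    # Steps 2..n: the cache covers everything already seen, so only the
    # newest token id is fed back as decoder_input_ids.
    return {"decoder_input_ids": input_ids[:, -1:], "past_key_values": past_key_values}

first = route(torch.tensor([[50256]]), encoder_hidden)  # 50256 is an illustrative BOS id
assert first["decoder_inputs_embeds"].shape[1] == image_tokens

later = route(torch.tensor([[50256, 318, 257]]), encoder_hidden, past_key_values=object())
assert later["decoder_input_ids"].shape == (1, 1)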