Upload folder using huggingface_hub

modelling_medicap.py (+55 -15) CHANGED
@@ -160,6 +160,9 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }
 
+        if decoder_inputs_embeds is None:
+            decoder_inputs_embeds = self.decoder.transformer.wte(decoder_input_ids)
+
         if encoder_outputs is None:
             if pixel_values is None:
                 raise ValueError("You have to specify pixel_values")
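Note: the new fallback embeds decoder_input_ids with the decoder's own token-embedding table, which assumes a GPT-2-style decoder exposing transformer.wte. A minimal standalone sketch of that lookup, using stock GPT-2 purely for illustration (not part of the commit):

    from transformers import GPT2LMHeadModel, GPT2TokenizerFast

    decoder = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    input_ids = tokenizer("a chest x-ray", return_tensors="pt").input_ids  # (1, seq_len)
    inputs_embeds = decoder.transformer.wte(input_ids)  # (1, seq_len, n_embd)
    assert inputs_embeds.shape[-1] == decoder.config.n_embd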
@@ -170,23 +173,18 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
                 return_dict=return_dict,
                 **kwargs_encoder,
             ) # CvT does not support output_attentions.
-
-
-
-
-
-
-
-
-
-
-                    decoder_attention_mask
-                ],
-                dim=1,
-            )
+        assert decoder_inputs_embeds.shape[1] == 1
+        decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)
+        if decoder_attention_mask is not None:
+            decoder_attention_mask = torch.cat(
+                [
+                    torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
+                    decoder_attention_mask
+                ],
+                dim=1,
+            )
 
         decoder_outputs = self.decoder(
-            input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
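Note: the encoder's image features are now prepended to the single token embedding, and the attention mask is padded with ones so the image positions are always attended to. A shape sketch with illustrative values (batch size, token count, and hidden size are assumptions, not from the commit):

    import torch

    batch, num_image_tokens, hidden = 2, 49, 768
    encoder_hidden = torch.randn(batch, num_image_tokens, hidden)  # stands in for encoder_outputs[0]
    token_embeds = torch.randn(batch, 1, hidden)                   # one caption token embedding
    decoder_inputs_embeds = torch.cat([encoder_hidden, token_embeds], dim=1)  # (2, 50, 768)

    decoder_attention_mask = torch.ones(batch, 1, dtype=torch.long)
    decoder_attention_mask = torch.cat(
        [torch.ones(encoder_hidden.shape[:-1], dtype=torch.long), decoder_attention_mask],
        dim=1,
    )  # (2, 50): image positions get an all-ones mask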
@@ -220,6 +218,48 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
         )
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        """
+        Modification of:
+            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
+
+        This can help with managing input_embeds and input_ids:
+            https://github.com/huggingface/transformers/issues/6535
+        """
+        input_dict = {'use_cache': use_cache, 'encoder_outputs': encoder_outputs, 'attention_mask': attention_mask}
+
+        if past_key_values is None:
+            decoder_inputs = self.decoder.prepare_inputs_for_generation(
+                input_ids, inputs_embeds=encoder_outputs[0], past_key_values=past_key_values,
+            )
+            input_dict['decoder_inputs_embeds'] = decoder_inputs['inputs_embeds']
+        else:
+            decoder_inputs = self.decoder.prepare_inputs_for_generation(
+                input_ids, past_key_values=past_key_values,
+            )
+            input_dict['decoder_input_ids'] = decoder_inputs['input_ids']
+        input_dict['past_key_values'] = decoder_inputs['past_key_values']
+        input_dict['decoder_attention_mask'] = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
+
+        # if torch.is_tensor(decoder_attention_mask):
+        #     decoder_attention_mask = torch.cat(
+        #         [
+        #             torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
+        #             decoder_attention_mask
+        #         ],
+        #         dim=1,
+        #     )
+
+        return input_dict
+
     def tokenize_captions_teacher_forcing(
         self,
         captions: str,
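Note: with this override, the first decoding step (past_key_values is None) routes encoder_outputs[0] through the decoder's prepare_inputs_for_generation and forwards the result as decoder_inputs_embeds, so the decoder consumes the image features as a prompt; later steps pass only the newest token id plus the cache. A hypothetical end-to-end usage sketch (the checkpoint name and image size are assumptions, not from this commit):

    import torch
    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained("aehrc/medicap", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("aehrc/medicap")

    pixel_values = torch.randn(1, 3, 384, 384)  # one preprocessed image (shape assumed)
    output_ids = model.generate(pixel_values=pixel_values, max_length=128, use_cache=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))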