anicolson committed on
Commit 1084f69
1 Parent(s): 2ac23b6

Upload folder using huggingface_hub

Files changed (1):
  1. modelling_medicap.py +21 -264
modelling_medicap.py CHANGED
@@ -136,98 +136,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
         self.encoder.config = self.config.encoder
         self.decoder.config = self.config.decoder
 
-    @classmethod
-    def from_encoder_decoder_pretrained(
-        cls,
-        encoder_pretrained_model_name_or_path: str = None,
-        decoder_pretrained_model_name_or_path: str = None,
-        *model_args,
-        **kwargs,
-    ) -> PreTrainedModel:
-        kwargs_encoder = {
-            argument[len("encoder_"):]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
-        }
-
-        kwargs_decoder = {
-            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
-        }
-
-        # Remove the encoder and decoder kwargs from kwargs:
-        for key in kwargs_encoder.keys():
-            del kwargs["encoder_" + key]
-        for key in kwargs_decoder.keys():
-            del kwargs["decoder_" + key]
-
-        # Load and initialize the encoder and decoder. The distinction between encoder
-        # and decoder at the model level is made by the value of the flag `is_decoder`,
-        # which we need to set correctly.
-        encoder = kwargs_encoder.pop("model", None)
-        if encoder is None:
-            if encoder_pretrained_model_name_or_path is None:
-                raise ValueError(
-                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
-                    "to be defined."
-                )
-
-            if "config" not in kwargs_encoder:
-                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
-                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
-                )
-
-                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
-                    logger.info(
-                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
-                        "from a decoder model. Cross-attention and causal mask are disabled."
-                    )
-                    encoder_config.is_decoder = False
-                    encoder_config.add_cross_attention = False
-
-                kwargs_encoder["config"] = encoder_config
-
-            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
-
-        decoder = kwargs_decoder.pop("model", None)
-        if decoder is None:
-            if decoder_pretrained_model_name_or_path is None:
-                raise ValueError(
-                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
-                    "to be defined."
-                )
-
-            if "config" not in kwargs_decoder:
-                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
-                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
-                )
-
-                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
-                    logger.info(
-                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross-attention"
-                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
-                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross-attention layers."
-                    )
-                    decoder_config.is_decoder = True
-                    decoder_config.add_cross_attention = True
-
-                kwargs_decoder["config"] = decoder_config
-
-            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
-                logger.warning(
-                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
-                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
-                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
-                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True`, or do not pass a "
-                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`."
-                )
-
-            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
-
-        # Instantiate the config with the corresponding kwargs:
-        config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
-
-        # Make sure that the input and output embeddings are not tied:
-        config.tie_word_embeddings = False
-        return cls(encoder=encoder, decoder=decoder, config=config)
-
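The classmethod removed above routes keyword arguments by prefix before loading each side of the model: anything starting with encoder_ goes to the encoder load, anything starting with decoder_ to the decoder load. A minimal, self-contained sketch of that prefix-splitting pattern (the function name and example kwargs are hypothetical):

def split_prefixed_kwargs(kwargs):
    # Route kwargs by prefix, then strip the routed entries from the original dict:
    kwargs_encoder = {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    for key in kwargs_encoder:
        del kwargs["encoder_" + key]
    for key in kwargs_decoder:
        del kwargs["decoder_" + key]
    return kwargs_encoder, kwargs_decoder, kwargs

enc, dec, rest = split_prefixed_kwargs(
    {"encoder_output_attentions": True, "decoder_use_cache": False, "tie_word_embeddings": False}
)
print(enc)   # {'output_attentions': True}
print(dec)   # {'use_cache': False}
print(rest)  # {'tie_word_embeddings': False}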
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
@@ -265,11 +173,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
         elif isinstance(encoder_outputs, tuple):
             encoder_outputs = BaseModelOutput(*encoder_outputs)
 
-        # encoder_hidden_states = encoder_outputs[0]
-        # encoder_attention_mask = None
-
-        # image_features = self.encoder(images).projected_last_hidden_state
-
         embeddings = self.decoder.transformer.wte(decoder_input_ids)
         embeddings = torch.cat([encoder_outputs[0], embeddings], dim=1)
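Here the encoder's output features are simply prepended to the word-token embeddings, so the image conditions the GPT-style decoder as a prefix. A shape sketch of that concatenation (batch size, token counts, and hidden size are hypothetical):

import torch

image_features = torch.randn(2, 50, 768)    # encoder_outputs[0]: (batch, image tokens, hidden)
token_embeddings = torch.randn(2, 20, 768)  # wte(decoder_input_ids): (batch, text tokens, hidden)
embeddings = torch.cat([image_features, token_embeddings], dim=1)
print(embeddings.shape)                     # torch.Size([2, 70, 768])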
 
@@ -314,143 +217,43 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            # encoder_hidden_states=encoder_outputs.hidden_states,
-            # encoder_attentions=encoder_outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        special_token_ids,
-        past_key_values=None,
-        attention_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs,
+    def tokenize_captions_teacher_forcing(
+        self,
+        captions: str,
+        tokenizer: PreTrainedTokenizerFast,
+        max_len: int,
     ):
         """
-        Modification of:
-        https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
-        """
-
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
-        decoder_attention_mask = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
-
-        if not past_key_values:
-            token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids)
-        else:
-            token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids)
-
-        input_dict = {
-            'attention_mask': attention_mask,
-            'decoder_attention_mask': decoder_attention_mask,
-            'decoder_input_ids': decoder_inputs['input_ids'],
-            'decoder_token_type_ids': token_type_ids,
-            'encoder_outputs': encoder_outputs,
-            'past_key_values': decoder_inputs['past_key_values'],
-            'use_cache': use_cache,
-        }
-        return input_dict
-
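During generation, the removed method delegated input trimming to the decoder: on the first step the whole prefix is fed, and once past_key_values is populated only the newest token is, with matching token type ids computed by the two helpers below. A sketch of that trimming pattern (shapes and token ids are hypothetical):

import torch

input_ids = torch.tensor([[50256, 7, 8, 9]])
past_key_values = None  # populated by the decoder after the first step

if past_key_values is None:
    step_input_ids = input_ids          # first step: the full prefix
else:
    step_input_ids = input_ids[:, -1:]  # later steps: only the newest token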
-    def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
-        """
-        Extract token type identifiers from the token identifiers.
-
-        Argument/s:
-            token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the separation between sections.
-            token_type_id_sections - token type identifier for each section.
-
-        Returns:
-            token_type_ids - token type identifiers.
-        """
-
-        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
-
-        mbatch_size, seq_len = token_ids.shape
-        token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
-
-        for i, j in enumerate(special_token_ids):
-            # Find the first occurrence of the special token that indicates the boundary between sections:
-            cols = (token_ids == j).int().argmax(dim=1)
-            rows = torch.arange(mbatch_size, device=token_ids.device)
-
-            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
-            cols += 1
-
-            # Ensure that the column index is not out of bounds. If 0, the token_id is not present.
-            # This is safe, as index 0 is always a special token (now equal to 1 due to the +1):
-            rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
-            cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
-
-            # Indices that correspond to the second sequence:
-            if rows.nelement() != 0:
-                ids = torch.stack([
-                    torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
-                        y, seq_len, device=token_ids.device,
-                    )
-                ])
-
-                token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
-
-        return token_type_ids
-
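The method removed above labels every position up to and including the first occurrence of each separator with the current section id, and every position after it with the next one. A compact worked example of that labelling (the token ids and the single separator are hypothetical; 2 stands in for a SEP token):

import torch

token_ids = torch.tensor([
    [1, 11, 12, 2, 21, 22],
    [1, 11, 12, 13, 2, 21],
])
special_token_ids = [2]

token_type_ids = torch.zeros_like(token_ids)
for section, sep in enumerate(special_token_ids):
    # First occurrence of the separator, shifted one position right:
    boundary = (token_ids == sep).int().argmax(dim=1) + 1
    for row in range(token_ids.shape[0]):
        col = int(boundary[row])
        if col > 1:  # argmax of 0 (+1) would mean the separator was not found
            token_type_ids[row, col:] = section + 1

print(token_type_ids)
# tensor([[0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1]])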
-    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
-        """
-        Extract token type identifiers from the token identifiers if past != None.
+        Tokenizes the captions and creates the inputs and targets for teacher forcing.
 
         Argument/s:
-            token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the separation between sections.
-
-        Returns:
-            token_type_ids - token type identifiers.
-        """
-
-        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
-        token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
-
-        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
-        token_ids = token_ids[:, :-1]
-
-        for i, j in enumerate(special_token_ids):
-
-            # Find the first occurrence of the special token, which indicates the boundary between sections:
-            exists = torch.any(token_ids == j, dim=1, keepdim=True)
-            token_type_ids[exists] = token_type_id_sections[i + 1]
-
-        return token_type_ids
-
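The cached-generation variant removed above only needs the type id of the newest token, so it is enough to check which separators have already been generated. A sketch of that check (token ids are hypothetical; 2 again stands in for SEP):

import torch

token_ids = torch.tensor([
    [1, 11, 12, 2, 21],   # SEP already generated -> the newest token is in section 1
    [1, 11, 12, 13, 14],  # no SEP yet            -> the newest token is in section 0
])
special_token_ids = [2]

token_type_ids = torch.zeros([token_ids.shape[0], 1], dtype=torch.long)
for section, sep in enumerate(special_token_ids):
    exists = torch.any(token_ids[:, :-1] == sep, dim=1, keepdim=True)
    token_type_ids[exists] = section + 1

print(token_type_ids)  # tensor([[1], [0]])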
-    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
-        """
-        Tokenizes the reports and creates the inputs and targets for teacher forcing.
-
-        Argument/s:
-            findings - findings section.
-            impression - impression section.
-            return_token_type_ids - return the token type identifiers.
+            captions - the captions.
             tokenizer - Hugging Face tokenizer.
             max_len - maximum number of tokens.
 
         Returns:
-            decoder_input_ids - the token identifiers for the input of the decoder.
-            decoder_attention_mask - the attention mask for the decoder_input_ids.
-            label_ids - the label token identifiers for the decoder.
+            batch_dict = {
+                decoder_input_ids - the token identifiers for the input of the decoder.
+                decoder_attention_mask - the attention mask for the decoder_input_ids.
+                decoder_token_type_ids - the token type identifiers for the decoder_input_ids.
+                label_ids - the label token identifiers for the decoder.
+            }
         """
 
-        # Prepare the sections for the tokenizer by placing special tokens between each section:
-        report = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
-                  zip(findings, impression)]
+        # Prepare the caption for the tokenizer by placing the special tokens:
+        caption = [f'{tokenizer.bos_token}{i}{tokenizer.eos_token}' for i in captions]
 
-        # Tokenize the report:
-        tokenized = tokenizer(
-            report,
+        # Tokenize the caption:
+        tokenized = self.tokenizer(
+            caption,
             padding='longest',
             truncation=True,
-            max_length=max_len + 1,  # +1 to account for the bias between input and target.
+            max_length=max_len + 1,  # +1 to account for the shift between input and target.
             return_tensors='pt',
             return_token_type_ids=False,
-            add_special_tokens=False,
+            add_special_tokens=False,  # Done in prepare_sections_for_tokenizer()
         ).to(self.device)
 
         # Modify for language modelling:
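Both the removed report tokenizer and its caption replacement tokenize to max_len + 1 and then offset the decoder input and the labels by one position, which is what the "Modify for language modelling" step and the [:, 1:] slice in the next hunk do. The shift in miniature (token ids are hypothetical; 50256 stands in for BOS/EOS):

import torch

input_ids = torch.tensor([[50256, 7, 8, 9, 50256]])  # [BOS] w1 w2 w3 [EOS]

decoder_input_ids = input_ids[:, :-1]  # [[50256, 7, 8, 9]]: the decoder sees [BOS] w1 w2 w3
label_ids = input_ids[:, 1:]           # [[7, 8, 9, 50256]]: and must predict w1 w2 w3 [EOS]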
@@ -466,50 +269,4 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
             'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
         }
 
-        return batch_dict
-
-    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
-        """
-        Split the token identifiers into sections, then convert the token identifiers into strings.
-
-        Argument/s:
-            token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the end of each section.
-            tokenizer - Hugging Face tokenizer.
-
-        Returns:
-            sections - a tuple of decoded section strings.
-        """
-
-        _, seq_len = token_ids.shape
-
-        # The number of sections is the same as the number of special_token_ids:
-        num_sections = len(special_token_ids)
-
-        sections = {k: [] for k in range(num_sections)}
-
-        for i in token_ids:
-            prev_col = 0
-            for j, k in enumerate(special_token_ids):
-
-                # The maximum sequence length was exceeded, thus there are no more tokens:
-                if prev_col >= seq_len:
-                    sections[j].append('')
-                    continue
-
-                # Find the first occurrence of the special token that indicates the boundary between sections:
-                col = (i == k).int().argmax().item()
-
-                # If equal to 0, the token was not found; set the column to the sequence length (as the decoder
-                # exceeded the maximum sequence length):
-                if col == 0:
-                    col = seq_len
-
-                # Extract the section token identifiers:
-                section_token_ids = i[prev_col:col]
-                prev_col = col
-                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
-
-                sections[j].append(section_string)
-
-        return tuple(sections.values())
+        return batch_dict
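split_and_decode_sections, removed here, cuts the generated ids at the first occurrence of each section-terminating special token and decodes each span separately. A toy version with a stand-in decoder (all ids and the vocabulary are hypothetical; 2 ends section 0 and 3 ends section 1):

import torch

vocab = {7: 'no', 8: 'acute', 9: 'findings', 21: 'normal'}

def decode(ids):
    # Stand-in for tokenizer.decode(..., skip_special_tokens=True):
    return ' '.join(vocab[int(t)] for t in ids if int(t) in vocab)

token_ids = torch.tensor([[7, 8, 9, 2, 21, 3]])
special_token_ids = [2, 3]

sections = {j: [] for j in range(len(special_token_ids))}
for row in token_ids:
    prev_col = 0
    for j, sep in enumerate(special_token_ids):
        col = (row == sep).int().argmax().item()  # first occurrence; 0 means not found
        col = col if col != 0 else len(row)
        sections[j].append(decode(row[prev_col:col]))
        prev_col = col

print(tuple(sections.values()))  # (['no acute findings'], ['normal'])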
 