Tags: Feature Extraction · Transformers · Safetensors · vision-encoder-decoder · custom_code
anicolson committed
Commit c6de590
1 Parent(s): 0a19e9f

Delete modelling_mimic_cxr_rev_d.py

Files changed (1)
  1. modelling_mimic_cxr_rev_d.py +0 -557
modelling_mimic_cxr_rev_d.py DELETED
@@ -1,557 +0,0 @@
import functools
import os
from typing import Optional, Tuple, Union

import torch
import transformers
from torch.nn import CrossEntropyLoss, Linear
from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
    VisionEncoderDecoderConfig,
)
from transformers.utils import logging

from modules.transformers.uniformer.modelling_uniformer import (
    MultiUniFormerWithProjectionHead,
)

logger = logging.get_logger(__name__)


class MIMICCXRMultimodalModel(VisionEncoderDecoderModel):

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        DefaultEncoderClass=MultiUniFormerWithProjectionHead,
        DefaultDecoderClass=transformers.LlamaForCausalLM,
    ):

        if decoder:
            assert not decoder.config.add_cross_attention, '"add_cross_attention" must be False for the given decoder'
            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'

        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        config.tie_word_embeddings = False

        # Initialize with config:
        PreTrainedModel.__init__(self, config)

        # Encoder:
        if encoder is None:
            encoder = DefaultEncoderClass(config=config.encoder)

        # Decoder:
        if decoder is None:
            assert not config.decoder.add_cross_attention
            decoder = DefaultDecoderClass(config=config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        assert config.decoder.is_decoder
        assert 'img_token_id' in self.decoder.config.__dict__
        assert 'pad_token_id' in self.decoder.config.__dict__
        assert 'token_type_embeddings' in self.decoder.config.__dict__

        # Learned token type embeddings, added to the input embeddings in forward():
        if self.decoder.config.token_type_embeddings == 'add':
            self.token_type_embeddings = torch.nn.Embedding(self.decoder.config.num_token_types, self.decoder.config.hidden_size)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        decoder_token_type_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if decoder_inputs_embeds is None:
            decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)

        if encoder_outputs is None:  # This is for when generate() is not called; for generation, see prepare_inputs_for_generation():
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )  # UniFormer does not support output_attentions.

            assert decoder_inputs_embeds is not None
            decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)

            # Add image token type identifiers:
            decoder_token_type_ids = torch.cat(
                [
                    torch.full(
                        encoder_outputs[0].shape[:-1],
                        self.decoder.config.img_token_id,
                        dtype=decoder_token_type_ids.dtype,
                        device=decoder_token_type_ids.device,
                    ),
                    decoder_token_type_ids,
                ],
                dim=1,
            )

            # Position identifiers accounting for padding:
            report_position_ids = decoder_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            report_position_ids.masked_fill_(decoder_attention_mask == 0, 1)
            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)
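
            # Note added for clarity: encoder_outputs[1] holds per-image-token
            # values that double as the prompt's position identifiers here and
            # as its non-causal attention mask below; the report positions
            # continue from its per-example maximum, while padded report
            # positions are pinned to 1 by the masked_fill_ above.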

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(encoder_outputs[1], decoder_attention_mask)

        assert decoder_position_ids is not None
        assert decoder_attention_mask is not None
        assert decoder_token_type_ids is not None

        if self.decoder.config.token_type_embeddings == 'add':
            decoder_inputs_embeds += self.token_type_embeddings(decoder_token_type_ids)
        elif self.decoder.config.token_type_embeddings == 'inbuilt':
            kwargs_decoder['token_type_ids'] = decoder_token_type_ids

        # Forward:
        decoder_outputs = self.decoder(
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Loss:
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        encoder_hidden_states = encoder_outputs[0]

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        special_token_ids,
        token_type_id_sections=None,
        past_key_values=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Modification of:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
        """

        report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()

        if past_key_values is None:

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(encoder_outputs[1], report_attention_mask)

            # Position identifiers accounting for padding:
            report_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            report_position_ids.masked_fill_(report_attention_mask == 0, 1)
            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)

            # `inputs_embeds` are only to be used in the 1st generation step:
            inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)

            decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids, token_type_id_sections)
            decoder_token_type_ids = torch.cat(
                [
                    torch.full(
                        encoder_outputs[0].shape[:-1],
                        self.decoder.config.img_token_id,
                        dtype=decoder_token_type_ids.dtype,
                        device=decoder_token_type_ids.device,
                    ),
                    decoder_token_type_ids,
                ],
                dim=1,
            )  # Add image token type identifiers.

            input_dict = {
                'decoder_input_ids': input_ids,
                'decoder_inputs_embeds': inputs_embeds,
                'decoder_token_type_ids': decoder_token_type_ids,
            }
        else:

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(encoder_outputs[1], report_attention_mask)

            # Position identifiers accounting for padding:
            decoder_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)

            # Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids, token_type_id_sections)
            decoder_position_ids = decoder_position_ids[:, -1:]

            past_length = past_key_values[0][0].shape[2]

            # Some generation methods only pass the last input ID:
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Keep only the final ID:
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

            input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}

        input_dict.update(
            {
                'decoder_attention_mask': decoder_attention_mask,
                'decoder_position_ids': decoder_position_ids,
                'encoder_outputs': encoder_outputs,
                'past_key_values': past_key_values,
                'use_cache': use_cache,
            }
        )
        return input_dict

    def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
        """
        Extract token type identifiers from the token identifiers.

        Argument/s:
            token_ids - token identifiers.
            special_token_ids - special token identifiers that indicate the separation between sections.
            token_type_id_sections - token type identifier for each section.

        Returns:
            token_type_ids - token type identifiers.
        """

        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))

        mbatch_size, seq_len = token_ids.shape
        token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)

        for i, j in enumerate(special_token_ids):
            # Find the first occurrence of the special token that indicates the boundary between sections:
            cols = (token_ids == j).int().argmax(dim=1)
            rows = torch.arange(mbatch_size, device=token_ids.device)

            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
            cols += 1

            # Ensure that the column index is not out of bounds. If 0, then the token_id is not present.
            # This is safe as index 0 is always a special token (now equal to 1 due to +1):
            rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
            cols = cols[torch.logical_and(cols != 1, cols < seq_len)]

            # Indices that correspond to the section after the special token:
            if rows.nelement() != 0:
                ids = torch.stack([
                    torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
                        y, seq_len, device=token_ids.device,
                    )
                ])

                token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]

        return token_type_ids

    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
        """
        Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
        token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).

        Argument/s:
            token_ids - token identifiers.
            special_token_ids - special token identifiers that indicate the separation between sections.
            token_type_id_sections - token type identifier for each section.

        Returns:
            token_type_ids - token type identifiers.
        """

        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
        token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)

        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
        token_ids = token_ids[:, :-1]

        for i, j in enumerate(special_token_ids):

            # Find the first occurrence of the special token, which indicates the boundary between sections:
            exists = torch.any(token_ids == j, dim=1, keepdim=True)
            token_type_ids[exists] = token_type_id_sections[i + 1]

        return token_type_ids

    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            findings - findings sections.
            impression - impression sections.
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.

        Returns:
            decoder_input_ids - the token identifiers for the input of the decoder.
            decoder_attention_mask - the attention mask for the decoder_input_ids.
            label_ids - the label token identifiers for the decoder.
        """

        # Prepare the sections for the tokenizer by placing special tokens between each section:
        reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
                   zip(findings, impression)]

        # Tokenize the report:
        tokenized = tokenizer(
            reports,
            padding='longest',
            truncation=True,
            max_length=max_len + 1,  # +1 to account for the offset between input and target.
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # Modify for language modelling:
        batch_dict = {

            # Labels for the decoder (offset by one position for next-token prediction):
            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),

            # Remove the last token identifier to match the sequence length of the labels:
            'decoder_input_ids': tokenized['input_ids'][:, :-1],

            # Attention mask for the decoder_input_ids (remove the first token so that the eos_token_id is not considered):
            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
        }
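
        # Worked example, added for clarity: for the tokenized report
        # [BOS, f1, f2, SEP, i1, EOS], teacher forcing uses
        #   decoder_input_ids: [BOS, f1, f2, SEP, i1]
        #   label_ids:         [f1,  f2, SEP, i1, EOS]
        # so that each input position is trained to predict the next token.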

        return batch_dict

    def tokenize_report_teacher_forcing_rev_a(self, tokenizer: PreTrainedTokenizerFast, max_len: int, findings: Optional[str] = None, impression: Optional[str] = None, reports: Optional[str] = None):
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.
            findings - findings sections.
            impression - impression sections.
            reports - prepared reports, with special tokens and report sections.

        Returns:
            decoder_input_ids - the token identifiers for the input of the decoder.
            decoder_attention_mask - the attention mask for the decoder_input_ids.
            label_ids - the label token identifiers for the decoder.
        """

        # Prepare the sections for the tokenizer by placing special tokens between each section:
        if reports is None:
            assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
            reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
                       zip(findings, impression)]

        # Tokenize the report:
        tokenized = tokenizer(
            reports,
            padding='longest',
            truncation=True,
            max_length=max_len + 1,  # +1 to account for the offset between input and target.
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # Modify for language modelling:
        batch_dict = {

            # Labels for the decoder (offset by one position for next-token prediction):
            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),

            # Remove the last token identifier to match the sequence length of the labels:
            'decoder_input_ids': tokenized['input_ids'][:, :-1],

            # Attention mask for the decoder_input_ids (remove the first token so that the eos_token_id is not considered):
            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
        }

        return batch_dict

    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
        """
        Split the token identifiers into sections, then convert the token identifiers into strings.

        Argument/s:
            token_ids - token identifiers.
            special_token_ids - special token identifiers that indicate the end of each section.
            tokenizer - Hugging Face tokenizer.

        Returns:
            sections - a tuple containing the decoded strings of each section.
        """

        _, seq_len = token_ids.shape

        # The number of sections is the same as the number of special_token_ids:
        num_sections = len(special_token_ids)

        sections = {k: [] for k in range(num_sections)}

        for i in token_ids:
            prev_col = 0
            for j, k in enumerate(special_token_ids):

                # The maximum sequence length was exceeded, thus there are no more tokens:
                if prev_col >= seq_len:
                    sections[j].append('')
                    continue

                # Find the first occurrence of the special token that indicates the boundary between sections:
                col = (i == k).int().argmax().item()

                # If equal to 0, the token was not found; set the column to the sequence length (as the decoder exceeded
                # the maximum sequence length):
                if col == 0:
                    col = seq_len

                # Extract section token identifiers:
                section_token_ids = i[prev_col:col]
                prev_col = col
                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)

                sections[j].append(section_string)

        return tuple(sections.values())

    @staticmethod
    def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
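
        # Note added for clarity: the returned 4D mask (1 = attend, 0 = masked)
        # has the block structure below, where the image prompt attends
        # bidirectionally to itself and the report attends fully to the prompt
        # but causally to itself:
        #
        #                prompt            report
        #   prompt  [ non-causal  |        zeros        ]
        #   report  [    full     | causal (lower tri.) ]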

        prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
        report_seq_len = causal_2d_attention_mask.shape[-1]

        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        # Upper left of attention matrix:
        upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
        upper_left = upper_left * non_causal_2d_attention_mask
        upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)

        causal_mask = torch.tril(
            torch.ones(
                (
                    report_seq_len,
                    report_seq_len,
                ),
                dtype=torch.long,
                device=causal_2d_attention_mask.device,
            ),
        )

        # Lower right of attention matrix:
        lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
        lower_right = lower_right * causal_mask

        # Upper right of attention matrix:
        upper_right = torch.zeros(
            causal_2d_attention_mask.shape[0],
            1,
            prompt_seq_len,
            report_seq_len,
            dtype=torch.long,
            device=causal_2d_attention_mask.device,
        )

        # Lower left of attention matrix:
        lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)

        left = torch.cat((upper_left, lower_left), dim=2)
        right = torch.cat((upper_right, lower_right), dim=2)

        mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
        return mixed_causality_4d_attention_mask

    @staticmethod
    def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):

        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
        return mixed_causality_4d_attention_mask
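
Note: checkpoints that ship custom modelling code like this file (the repository is tagged custom_code) are typically loaded with trust_remote_code=True. A minimal sketch, with an illustrative placeholder repository identifier:

from transformers import AutoModel

# The repository id below is a placeholder, not taken from this commit:
model = AutoModel.from_pretrained("user/checkpoint-id", trust_remote_code=True)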