Feature Extraction
Transformers
Safetensors
vision-encoder-decoder
custom_code
anicolson committed
Commit a72eb06
1 Parent(s): c6de590

Upload model

Files changed (2)
  1. config.json +2 -2
  2. modelling_cxrrg.py +554 -0
config.json CHANGED
@@ -1,9 +1,9 @@
 {
   "architectures": [
-    "MIMICCXRMultimodalModel"
+    "CXRRGModel"
   ],
   "auto_map": {
-    "AutoModel": "modelling_mimic_cxr_rev_d.MIMICCXRMultimodalModel"
+    "AutoModel": "modelling_cxrrg.CXRRGModel"
   },
   "decoder": {
     "_name_or_path": "",
modelling_cxrrg.py ADDED
@@ -0,0 +1,554 @@
+import functools
+import os
+from typing import Optional, Tuple, Union
+
+import torch
+import transformers
+from modelling_uniformer import MultiUniFormerWithProjectionHead
+from torch.nn import CrossEntropyLoss, Linear
+from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
+    VisionEncoderDecoderConfig,
+)
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class CXRRGModel(VisionEncoderDecoderModel):
+
+    config_class = VisionEncoderDecoderConfig
+    base_model_prefix = "vision_encoder_decoder"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+        DefaultEncoderClass=MultiUniFormerWithProjectionHead,
+        DefaultDecoderClass=transformers.LlamaForCausalLM,
+    ):
+
+        if decoder:
+            assert not decoder.config.add_cross_attention, '"add_cross_attention" must be False for the given decoder'
+            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
+
+        if config is None and (encoder is None or decoder is None):
+            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+        if config is None:
+            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+        else:
+            if not isinstance(config, self.config_class):
+                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+        config.tie_word_embeddings = False
+
+        # Initialize with config:
+        PreTrainedModel.__init__(self, config)
+
+        # Encoder:
+        if encoder is None:
+            encoder = DefaultEncoderClass(config=config.encoder)
+
+        # Decoder:
+        if decoder is None:
+            assert not config.decoder.add_cross_attention
+            decoder = DefaultDecoderClass(config=config.decoder)
+
+        self.encoder = encoder
+        self.decoder = decoder
+
+        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+            logger.warning(
+                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+                f" {self.config.encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                f" {self.config.decoder}"
+            )
+
+        self.encoder.config = self.config.encoder
+        self.decoder.config = self.config.decoder
+
+        assert config.decoder.is_decoder
+        assert 'img_token_id' in self.decoder.config.__dict__
+        assert 'pad_token_id' in self.decoder.config.__dict__
+        assert 'token_type_embeddings' in self.decoder.config.__dict__
+
+        if self.decoder.config.token_type_embeddings == 'add':
+            self.token_type_embeddings = torch.nn.Embedding(self.decoder.config.num_token_types, self.decoder.config.hidden_size)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_token_type_ids: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+        kwargs_decoder = {
+            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        if decoder_inputs_embeds is None:
+            decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)
+
+        if encoder_outputs is None:  # This is for when generate() is not called; for generation, see prepare_inputs_for_generation():
+            if pixel_values is None:
+                raise ValueError("You have to specify pixel_values")
+
+            encoder_outputs = self.encoder(
+                pixel_values,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                **kwargs_encoder,
+            )  # UniFormer does not support output_attentions.
+
+        assert decoder_inputs_embeds is not None
+        decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)
+
+        # Add image token type identifiers:
+        decoder_token_type_ids = torch.cat(
+            [
+                torch.full(
+                    encoder_outputs[0].shape[:-1],
+                    self.decoder.config.img_token_id,
+                    dtype=decoder_token_type_ids.dtype,
+                    device=decoder_token_type_ids.device,
+                ),
+                decoder_token_type_ids,
+            ],
+            dim=1,
+        )
+
+        # Position identifiers accounting for padding:
+        report_position_ids = decoder_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
+        report_position_ids.masked_fill_(decoder_attention_mask == 0, 1)
+        decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)
+
+        # 4D attention mask:
+        decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(encoder_outputs[1], decoder_attention_mask)
+
+        assert decoder_position_ids is not None
+        assert decoder_attention_mask is not None
+        assert decoder_token_type_ids is not None
+
+        if self.decoder.config.token_type_embeddings == 'add':
+            decoder_inputs_embeds += self.token_type_embeddings(decoder_token_type_ids)
+        elif self.decoder.config.token_type_embeddings == 'inbuilt':
+            kwargs_decoder['token_type_ids'] = decoder_token_type_ids
+
+        # Forward:
+        decoder_outputs = self.decoder(
+            inputs_embeds=decoder_inputs_embeds,
+            attention_mask=decoder_attention_mask,
+            position_ids=decoder_position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        # Loss:
+        loss = None
+        if labels is not None:
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+
+        if not return_dict:
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            encoder_last_hidden_state=encoder_hidden_states,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        special_token_ids,
+        token_type_id_sections=None,
+        past_key_values=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        """
+        Modification of:
+            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
+        """
+
+        report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()
+
+        if past_key_values is None:
+
+            # 4D attention mask:
+            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(encoder_outputs[1], report_attention_mask)
+
+            # Position identifiers accounting for padding:
+            report_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
+            report_position_ids.masked_fill_(report_attention_mask == 0, 1)
+            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)
+
+            # `inputs_embeds` are only to be used in the 1st generation step:
+            inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)
+
+            decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids, token_type_id_sections)
+            decoder_token_type_ids = torch.cat(
+                [
+                    torch.full(
+                        encoder_outputs[0].shape[:-1],
+                        self.decoder.config.img_token_id,
+                        dtype=decoder_token_type_ids.dtype,
+                        device=decoder_token_type_ids.device,
+                    ),
+                    decoder_token_type_ids,
+                ],
+                dim=1,
+            )  # Add image token type identifiers.
+
+            input_dict = {
+                'decoder_input_ids': input_ids,
+                'decoder_inputs_embeds': inputs_embeds,
+                'decoder_token_type_ids': decoder_token_type_ids,
+            }
+        else:
+
+            # 4D attention mask:
+            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(encoder_outputs[1], report_attention_mask)
+
+            # Position identifiers accounting for padding:
+            decoder_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
+            decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
+
+            # Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
+            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids, token_type_id_sections)
+            decoder_position_ids = decoder_position_ids[:, -1:]
+
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods only pass the last input ID:
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Keep only the final ID:
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+
+            input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}
+
+        input_dict.update(
+            {
+                'decoder_attention_mask': decoder_attention_mask,
+                'decoder_position_ids': decoder_position_ids,
+                'encoder_outputs': encoder_outputs,
+                'past_key_values': past_key_values,
+                'use_cache': use_cache,
+            }
+        )
+        return input_dict
+
+    def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
+        """
+        Extract token type identifiers from the token identifiers.
+
+        Argument/s:
+            token_ids - token identifiers.
+            special_token_ids - special token identifiers that indicate the separation between sections.
+            token_type_id_sections - token type identifier for each section.
+
+        Returns:
+            token_type_ids - token type identifiers.
+        """
+
+        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
+
+        mbatch_size, seq_len = token_ids.shape
+        token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
+
+        for i, j in enumerate(special_token_ids):
+            # Find first occurrence of special tokens that indicate the boundary between sections:
+            cols = (token_ids == j).int().argmax(dim=1)
+            rows = torch.arange(mbatch_size, device=token_ids.device)
+
+            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
+            cols += 1
+
+            # Ensure that the column index is not out of bounds. If 0, then token_id not present.
+            # This is safe as index 0 is always a special token (now equal to 1 due to +1):
+            rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
+            cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
+
+            # Indices that correspond to the second sequence:
+            if rows.nelement() != 0:
+                ids = torch.stack([
+                    torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
+                        y, seq_len, device=token_ids.device,
+                    )
+                ])
+
+                token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
+
+        return token_type_ids
+
+    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
+        """
+        Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
+            token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
+
+        Argument/s:
+            token_ids - token identifiers.
+            special_token_ids - special token identifiers that indicate the separation between sections.
+
+        Returns:
+            token_type_ids - token type identifiers.
+        """
+
+        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
+        token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
+
+        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
+        token_ids = token_ids[:, :-1]
+
+        for i, j in enumerate(special_token_ids):
+
+            # Find first occurrence of special token, which indicates the boundary between sections:
+            exists = torch.any(token_ids == j, dim=1, keepdim=True)
+            token_type_ids[exists] = token_type_id_sections[i + 1]
+
+        return token_type_ids
+
+    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
+        """
+        Tokenize the reports and create the inputs and targets for teacher forcing.
+
+        Argument/s:
+            findings - findings sections.
+            impression - impression sections.
+            tokenizer - Hugging Face tokenizer.
+            max_len - maximum number of tokens.
+
+        Returns:
+            decoder_input_ids - the token identifiers for the input of the decoder.
+            decoder_attention_mask - the attention mask for the decoder_input_ids.
+            label_ids - the label token identifiers for the decoder.
+        """
+
+        # Prepare the sections for the tokenizer by placing special tokens between each section:
+        reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
+                   zip(findings, impression)]
+
+        # Tokenize the report:
+        tokenized = tokenizer(
+            reports,
+            padding='longest',
+            truncation=True,
+            max_length=max_len + 1,  # +1 to account for the bias between input and target.
+            return_tensors='pt',
+            return_token_type_ids=False,
+            add_special_tokens=False,
+        ).to(self.device)
+
+        # Modify for language modelling:
+        batch_dict = {
+
+            # Labels for the decoder (shifted right by one for autoregression):
+            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
+
+            # Remove last token identifier to match the sequence length of the labels:
+            'decoder_input_ids': tokenized['input_ids'][:, :-1],
+
+            # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
+            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
+        }
+
+        return batch_dict
+
+    def tokenize_report_teacher_forcing_rev_a(self, tokenizer: PreTrainedTokenizerFast, max_len: int, findings: Optional[str] = None, impression: Optional[str] = None, reports: Optional[str] = None):
+        """
+        Tokenize the reports and create the inputs and targets for teacher forcing.
+
+        Argument/s:
+            tokenizer - Hugging Face tokenizer.
+            max_len - maximum number of tokens.
+            findings - findings sections.
+            impression - impression sections.
+            reports - prepared reports, with special tokens and report sections.
+
+        Returns:
+            decoder_input_ids - the token identifiers for the input of the decoder.
+            decoder_attention_mask - the attention mask for the decoder_input_ids.
+            label_ids - the label token identifiers for the decoder.
+        """
+
+        # Prepare the sections for the tokenizer by placing special tokens between each section:
+        if reports is None:
+            assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
+            reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
+                       zip(findings, impression)]
+
+        # Tokenize the report:
+        tokenized = tokenizer(
+            reports,
+            padding='longest',
+            truncation=True,
+            max_length=max_len + 1,  # +1 to account for the bias between input and target.
+            return_tensors='pt',
+            return_token_type_ids=False,
+            add_special_tokens=False,
+        ).to(self.device)
+
+        # Modify for language modelling:
+        batch_dict = {
+
+            # Labels for the decoder (shifted right by one for autoregression):
+            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
+
+            # Remove last token identifier to match the sequence length of the labels:
+            'decoder_input_ids': tokenized['input_ids'][:, :-1],
+
+            # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
+            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
+        }
+
+        return batch_dict
+
+    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
+        """
+        Split the token identifiers into sections, then convert the token identifiers into strings.
+
+        Argument/s:
+            token_ids - token identifiers.
+            special_token_ids - special token identifiers that indicate the end of each section.
+            tokenizer - Hugging Face tokenizer.
+
+        Returns:
+            sections - a tuple with one list of decoded strings per section.
+        """
+
+        _, seq_len = token_ids.shape
+
+        # The number of sections is the same as the number of special_token_ids:
+        num_sections = len(special_token_ids)
+
+        sections = {k: [] for k in range(num_sections)}
+
+        for i in token_ids:
+            prev_col = 0
+            for j, k in enumerate(special_token_ids):
+
+                # The maximum sequence length was exceeded, thus no more tokens:
+                if prev_col >= seq_len:
+                    sections[j].append('')
+                    continue
+
+                # Find first occurrence of special tokens that indicate the boundary between sections:
+                col = (i == k).int().argmax().item()
+
+                # If equal to 0, token was not found, set the column to the sequence length (as the decoder exceeded
+                # the maximum sequence length):
+                if col == 0:
+                    col = seq_len
+
+                # Extract section token identifiers:
+                section_token_ids = i[prev_col:col]
+                prev_col = col
+                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
+
+                sections[j].append(section_string)
+
+        return tuple(sections.values())
+
+    @staticmethod
+    def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
+
+        prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
+        report_seq_len = causal_2d_attention_mask.shape[-1]
+
+        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
+        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
+
+        # Upper left of attention matrix:
+        upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
+        upper_left = upper_left * non_causal_2d_attention_mask
+        upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)
+
+        causal_mask = torch.tril(
+            torch.ones(
+                (
+                    report_seq_len,
+                    report_seq_len,
+                ),
+                dtype=torch.long,
+                device=causal_2d_attention_mask.device,
+            ),
+        )
+
+        # Lower right of attention matrix:
+        lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
+        lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
+        lower_right = lower_right * causal_mask
+
+        # Upper right of attention matrix:
+        upper_right = torch.zeros(
+            causal_2d_attention_mask.shape[0],
+            1,
+            prompt_seq_len,
+            report_seq_len,
+            dtype=torch.long,
+            device=causal_2d_attention_mask.device,
+        )
+
+        # Lower left of attention matrix:
+        lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
+        lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)
+
+        left = torch.cat((upper_left, lower_left), dim=2)
+        right = torch.cat((upper_right, lower_right), dim=2)
+
+        mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
+        return mixed_causality_4d_attention_mask
+
+    @staticmethod
+    def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):
+
+        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
+        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
+
+        mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
+        return mixed_causality_4d_attention_mask
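
The class above overrides prepare_inputs_for_generation(), so inference is driven through generate() with the extra special_token_ids keyword, and the generated report can then be split back into its sections with split_and_decode_sections(). A rough usage sketch follows, assuming model is a loaded CXRRGModel, tokenizer is its paired tokenizer, and pixel_values has already been produced by the image processor; the keyword choices (sep token as the findings/impression boundary, beam search settings, max length) mirror the bos + findings + sep + impression + eos layout used in tokenize_report_teacher_forcing_rev_a but are assumptions, not taken from this commit:

import torch

with torch.no_grad():
    output = model.generate(
        pixel_values=pixel_values,
        special_token_ids=[tokenizer.sep_token_id],  # findings/impression boundary
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_length=256,
        num_beams=4,
        use_cache=True,
        return_dict_in_generate=True,
    )

# Split the generated token identifiers at the sep/eos boundaries and decode each section:
findings, impression = model.split_and_decode_sections(
    output.sequences,
    [tokenizer.sep_token_id, tokenizer.eos_token_id],
    tokenizer,
)

For training, tokenize_report_teacher_forcing_rev_a() produces the decoder_input_ids, decoder_attention_mask, and label_ids expected by forward() under teacher forcing.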