Feature Extraction
Transformers
Safetensors
vision-encoder-decoder
custom_code
anicolson committed on
Commit
1ee4361
1 Parent(s): c79748a

Upload model

README.md ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

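Since the card's code section is still a placeholder, here is a minimal loading sketch rather than the author's documented usage. The repository id below is hypothetical, and loading assumes the `modules.transformers.uniformer` package imported by `modelling_mimic_cxr_rev_d.py` is also available, since `config.json` maps `AutoModel` to the custom class in that file:

```python
from transformers import AutoModel

# Hypothetical repository id; replace with the actual Hub path of this repository.
repo_id = "anicolson/mimic-cxr-rev-d"

# config.json maps AutoModel to modelling_mimic_cxr_rev_d.MIMICCXRMultimodalModel,
# so trust_remote_code=True is required to execute the custom modelling file.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

print(type(model).__name__)  # MIMICCXRMultimodalModel
```
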
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
config.json ADDED
@@ -0,0 +1,200 @@
{
  "architectures": [
    "MIMICCXRMultimodalModel"
  ],
  "auto_map": {
    "AutoModel": "modelling_mimic_cxr_rev_d.MIMICCXRMultimodalModel"
  },
  "decoder": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 1,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "silu",
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "img_token_id": 0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 2048,
    "min_length": 0,
    "model_type": "llama",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 6,
    "num_key_value_heads": 12,
    "num_return_sequences": 1,
    "num_token_types": 3,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 4,
    "prefix": null,
    "pretraining_tp": 1,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 10000.0,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "token_type_embeddings": "add",
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 30000
  },
  "encoder": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.0,
    "attn_drop_rate": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "conv_stem": false,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "depth": [
      5,
      8,
      20,
      7
    ],
    "diversity_penalty": 0.0,
    "do_sample": false,
    "drop_path_rate": 0.3,
    "drop_rate": 0.0,
    "early_stopping": false,
    "embed_dim": [
      64,
      128,
      320,
      512
    ],
    "encoder_no_repeat_ngram_size": 0,
    "encoder_stride": 16,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "head_dim": 64,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 384,
    "in_chans": 3,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-06,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "mlp_ratio": 4,
    "model_type": "vit",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_classes": 1000,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": [
      4,
      2,
      2,
      2
    ],
    "prefix": null,
    "problem_type": null,
    "projection_size": 768,
    "pruned_heads": {},
    "qk_scale": null,
    "qkv_bias": true,
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "representation_size": null,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "is_encoder_decoder": true,
  "model_type": "vision-encoder-decoder",
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.39.0"
}
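The nested encoder/decoder configuration, including the non-standard keys that `MIMICCXRMultimodalModel` asserts on (`img_token_id`, `num_token_types`, `token_type_embeddings`), can be inspected once the config is loaded. A small sketch, assuming a hypothetical repository id:

```python
from transformers import AutoConfig

# Hypothetical repository id; replace with the actual Hub path of this repository.
config = AutoConfig.from_pretrained("anicolson/mimic-cxr-rev-d")

print(config.model_type)          # vision-encoder-decoder
print(config.encoder.model_type)  # vit (declared as vit, consumed by the custom UniFormer encoder class)
print(config.decoder.model_type)  # llama

# Extra keys in the decoder sub-config are kept as plain attributes:
print(config.decoder.img_token_id, config.decoder.num_token_types, config.decoder.token_type_embeddings)
```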
generation_config.json ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 4,
  "transformers_version": "4.39.0"
}
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff2867bd60bbde19061a835cff6d05b67cbfc89249fcfec039329a4a8c5b5e23
size 609603000
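The file above is a Git LFS pointer rather than the weights themselves; the `oid sha256` is the hash of the actual `model.safetensors` payload. A small sketch, with a hypothetical repository id, of checking that a downloaded copy matches the pointer:

```python
import hashlib

from huggingface_hub import hf_hub_download

# Hypothetical repository id; substitute the actual Hub path of this repository.
path = hf_hub_download("anicolson/mimic-cxr-rev-d", filename="model.safetensors")

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

# Should match the oid recorded in the LFS pointer (ff2867bd...5e23).
print(digest.hexdigest())
```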
modelling_mimic_cxr_rev_d.py ADDED
@@ -0,0 +1,557 @@
import functools
import os
from typing import Optional, Tuple, Union

import torch
import transformers
from torch.nn import CrossEntropyLoss, Linear
from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
    VisionEncoderDecoderConfig,
)
from transformers.utils import logging

from modules.transformers.uniformer.modelling_uniformer import (
    MultiUniFormerWithProjectionHead,
)

logger = logging.get_logger(__name__)


class MIMICCXRMultimodalModel(VisionEncoderDecoderModel):

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        DefaultEncoderClass = MultiUniFormerWithProjectionHead,
        DefaultDecoderClass = transformers.LlamaForCausalLM,
    ):

        if decoder:
            assert not decoder.config.add_cross_attention, '"add_cross_attention" must be False for the given decoder'
            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'

        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        config.tie_word_embeddings = False

        # Initialize with config:
        PreTrainedModel.__init__(self, config)

        # Encoder:
        if encoder is None:
            encoder = DefaultEncoderClass(config=config.encoder)

        # Decoder:
        if decoder is None:
            assert not config.decoder.add_cross_attention
            decoder = DefaultDecoderClass(config=config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        assert config.decoder.is_decoder
        assert 'img_token_id' in self.decoder.config.__dict__
        assert 'pad_token_id' in self.decoder.config.__dict__
        assert 'token_type_embeddings' in self.decoder.config.__dict__

        if self.decoder.config.token_type_embeddings == 'add':
            self.token_type_embeddings = torch.nn.Embedding(self.decoder.config.num_token_types, self.decoder.config.hidden_size)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        decoder_token_type_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if decoder_inputs_embeds is None:
            decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)

        if encoder_outputs is None:  # This is for when generate() is not called; for generation, see prepare_inputs_for_generation():
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )  # UniFormer does not support output_attentions.

            assert decoder_inputs_embeds is not None
            decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)

            # Add image token type identifiers:
            decoder_token_type_ids = torch.cat(
                [
                    torch.full(
                        encoder_outputs[0].shape[:-1],
                        self.decoder.config.img_token_id,
                        dtype=decoder_token_type_ids.dtype,
                        device=decoder_token_type_ids.device,
                    ),
                    decoder_token_type_ids
                ],
                dim=1,
            )

            # Position identifiers accounting for padding:
            report_position_ids = decoder_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            report_position_ids.masked_fill_(decoder_attention_mask == 0, 1)
            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(encoder_outputs[1], decoder_attention_mask)

        assert decoder_position_ids is not None
        assert decoder_attention_mask is not None
        assert decoder_token_type_ids is not None

        if self.decoder.config.token_type_embeddings == 'add':
            decoder_inputs_embeds += self.token_type_embeddings(decoder_token_type_ids)
        elif self.decoder.config.token_type_embeddings == 'inbuilt':
            kwargs_decoder['token_type_ids'] = decoder_token_type_ids

        # Forward:
        decoder_outputs = self.decoder(
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Loss:
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        encoder_hidden_states = encoder_outputs[0]

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        special_token_ids,
        token_type_id_sections=None,
        past_key_values=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Modification of:
            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
        """

        report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()

        if past_key_values is None:

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(encoder_outputs[1], report_attention_mask)

            # Position identifiers accounting for padding:
            report_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            report_position_ids.masked_fill_(report_attention_mask == 0, 1)
            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)

            # `inputs_embeds` are only to be used in the 1st generation step:
            inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)

            decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids, token_type_id_sections)
            decoder_token_type_ids = torch.cat(
                [
                    torch.full(
                        encoder_outputs[0].shape[:-1],
                        self.decoder.config.img_token_id,
                        dtype=decoder_token_type_ids.dtype,
                        device=decoder_token_type_ids.device,
                    ),
                    decoder_token_type_ids,
                ],
                dim=1,
            )  # Add image token type identifiers.

            input_dict = {
                'decoder_input_ids': input_ids,
                'decoder_inputs_embeds': inputs_embeds,
                'decoder_token_type_ids': decoder_token_type_ids,
            }
        else:

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(encoder_outputs[1], report_attention_mask)

            # Position identifiers accounting for padding:
            decoder_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)

            # Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids, token_type_id_sections)
            decoder_position_ids = decoder_position_ids[:, -1:]

            past_length = past_key_values[0][0].shape[2]

            # Some generation methods only pass the last input ID:
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Keep only the final ID:
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

            input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}

        input_dict.update(
            {
                'decoder_attention_mask': decoder_attention_mask,
                'decoder_position_ids': decoder_position_ids,
                'encoder_outputs': encoder_outputs,
                'past_key_values': past_key_values,
                'use_cache': use_cache,
            }
        )
        return input_dict

    def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
        """
        Extract token type identifiers from the token identifiers.

        Argument/s:
            token_ids - token identifiers.
            special_token_ids - special token identifiers that indicate the separation between sections.
            token_type_id_sections - token type identifier for each section.

        Returns:
            token_type_ids - token type identifiers.
        """

        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))

        mbatch_size, seq_len = token_ids.shape
        token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)

        for i, j in enumerate(special_token_ids):
            # Find first occurrence of special tokens that indicate the boundary between sections:
            cols = (token_ids == j).int().argmax(dim=1)
            rows = torch.arange(mbatch_size, device=token_ids.device)

            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
            cols += 1

            # Ensure that the column index is not out of bounds. If 0, then token_id not present.
            # This is safe as index 0 is always a special token (now equal to 1 due to +1):
            rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
            cols = cols[torch.logical_and(cols != 1, cols < seq_len)]

            # Indices that correspond to the second sequence:
            if rows.nelement() != 0:
                ids = torch.stack([
                    torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
                        y, seq_len, device=token_ids.device,
                    )
                ])

                token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]

        return token_type_ids

    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
        """
        Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
        token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).

        Argument/s:
            token_ids - token identifiers.
            special_token_ids - special token identifiers that indicate the separation between sections.

        Returns:
            token_type_ids - token type identifiers.
        """

        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
        token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)

        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
        token_ids = token_ids[:, :-1]

        for i, j in enumerate(special_token_ids):

            # Find first occurrence of the special token, which indicates the boundary between sections:
            exists = torch.any(token_ids == j, dim=1, keepdim=True)
            token_type_ids[exists] = token_type_id_sections[i + 1]

        return token_type_ids

    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            findings - findings sections.
            impression - impression sections.
            return_token_type_ids - return the token type identifiers.
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.

        Returns:
            decoder_input_ids - the token identifiers for the input of the decoder.
            decoder_attention_mask - the attention mask for the decoder_input_ids.
            label_ids - the label token identifiers for the decoder.
        """

        # Prepare the sections for the tokenizer by placing special tokens between each section:
        reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
                   zip(findings, impression)]

        # Tokenize the report:
        tokenized = tokenizer(
            reports,
            padding='longest',
            truncation=True,
            max_length=max_len + 1,  # +1 to account for the bias between input and target.
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # Modify for language modelling:
        batch_dict = {

            # Labels for the decoder (shifted right by one for autoregression):
            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),

            # Remove last token identifier to match the sequence length of the labels:
            'decoder_input_ids': tokenized['input_ids'][:, :-1],

            # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
        }

        return batch_dict

    def tokenize_report_teacher_forcing_rev_a(self, tokenizer: PreTrainedTokenizerFast, max_len: int, findings: Optional[str] = None, impression: Optional[str] = None, reports: Optional[str] = None):
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.
            findings - findings sections.
            impression - impression sections.
            reports - prepared reports, with special tokens and report sections.

        Returns:
            decoder_input_ids - the token identifiers for the input of the decoder.
            decoder_attention_mask - the attention mask for the decoder_input_ids.
            label_ids - the label token identifiers for the decoder.
        """

        # Prepare the sections for the tokenizer by placing special tokens between each section:
        if reports is None:
            assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
            reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
                       zip(findings, impression)]

        # Tokenize the report:
        tokenized = tokenizer(
            reports,
            padding='longest',
            truncation=True,
            max_length=max_len + 1,  # +1 to account for the bias between input and target.
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # Modify for language modelling:
        batch_dict = {

            # Labels for the decoder (shifted right by one for autoregression):
            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),

            # Remove last token identifier to match the sequence length of the labels:
            'decoder_input_ids': tokenized['input_ids'][:, :-1],

            # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
        }

        return batch_dict

    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
        """
        Split the token identifiers into sections, then convert the token identifiers into strings.

        Argument/s:
            token_ids - token identifiers.
            special_token_ids - special token identifiers that indicate the end of each section.
            tokenizer - Hugging Face tokenizer.

        Returns:
            sections - the decoded strings of each section.
        """

        _, seq_len = token_ids.shape

        # The number of sections is the same as the number of special_token_ids:
        num_sections = len(special_token_ids)

        sections = {k: [] for k in range(num_sections)}

        for i in token_ids:
            prev_col = 0
            for j, k in enumerate(special_token_ids):

                # The maximum sequence length was exceeded, thus no more tokens:
                if prev_col >= seq_len:
                    sections[j].append('')
                    continue

                # Find first occurrence of special tokens that indicate the boundary between sections:
                col = (i == k).int().argmax().item()

                # If equal to 0, the token was not found; set the column to the sequence length (as the decoder exceeded
                # the maximum sequence length):
                if col == 0:
                    col = seq_len

                # Extract section token identifiers:
                section_token_ids = i[prev_col:col]
                prev_col = col
                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)

                sections[j].append(section_string)

        return tuple(sections.values())

    @staticmethod
    def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):

        prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
        report_seq_len = causal_2d_attention_mask.shape[-1]

        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        # Upper left of attention matrix:
        upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
        upper_left = upper_left * non_causal_2d_attention_mask
        upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)

        causal_mask = torch.tril(
            torch.ones(
                (
                    report_seq_len,
                    report_seq_len,
                ),
                dtype=torch.long,
                device=causal_2d_attention_mask.device,
            ),
        )

        # Lower right of attention matrix:
        lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
        lower_right = lower_right * causal_mask

        # Upper right of attention matrix:
        upper_right = torch.zeros(
            causal_2d_attention_mask.shape[0],
            1,
            prompt_seq_len,
            report_seq_len,
            dtype=torch.long,
            device=causal_2d_attention_mask.device,
        )

        # Lower left of attention matrix:
        lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)

        left = torch.cat((upper_left, lower_left), dim=2)
        right = torch.cat((upper_right, lower_right), dim=2)

        mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
        return mixed_causality_4d_attention_mask

    @staticmethod
    def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):

        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
        return mixed_causality_4d_attention_mask
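To illustrate the mixed-causality mask built by `create_4d_attention_mask_mixed_causality`, here is a small sketch. It assumes `modelling_mimic_cxr_rev_d.py` and the `modules.transformers.uniformer` package it imports are on the Python path; the toy masks are illustrative only:

```python
import torch

from modelling_mimic_cxr_rev_d import MIMICCXRMultimodalModel

# Toy example: 3 image (prompt) positions and a 4-token report with one padded position.
prompt_mask = torch.tensor([[1, 1, 1]])      # non-causal block (image tokens attend to each other)
report_mask = torch.tensor([[1, 1, 1, 0]])   # causal block with padding

mask = MIMICCXRMultimodalModel.create_4d_attention_mask_mixed_causality(prompt_mask, report_mask)

print(mask.shape)  # torch.Size([1, 1, 7, 7])
print(mask[0, 0])
# Upper-left 3x3 block: full attention among image tokens.
# Upper-right 3x4 block: zeros, so image tokens never attend to report tokens.
# Lower-left 4x3 block: non-padded report tokens attend to the image tokens.
# Lower-right 4x4 block: causal attention over the report, with the padded position masked out.
```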