from typing import Dict, List, Optional, Tuple, Union

import torch
import transformers
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import \
    VisionEncoderDecoderConfig
from transformers.utils import logging

from .modelling_uniformer import MultiUniFormerWithProjectionHead

logger = logging.get_logger(__name__)


class CXRRGModel(VisionEncoderDecoderModel):
    """
    Vision encoder-decoder model that generates sectioned reports (e.g., findings and impression) from images.

    The decoder does not use cross-attention. Instead, the encoder's output embeddings are prepended to the
    report token embeddings, and the decoder runs over the combined sequence with a mixed-causality 4D
    attention mask:
        - image tokens attend to each other (subject to the image mask) but not to report tokens;
        - report tokens attend to the image tokens and causally to earlier report tokens.
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        DefaultEncoderClass=MultiUniFormerWithProjectionHead,
        DefaultDecoderClass=transformers.LlamaForCausalLM,
    ):
        """
        Argument/s:
            config - encoder-decoder configuration; required unless both an encoder and a decoder are given.
            encoder - optional pre-built encoder; otherwise built from config.encoder with DefaultEncoderClass.
            decoder - optional pre-built decoder; otherwise built from config.decoder with DefaultDecoderClass.
                Its config must have add_cross_attention=False and is_decoder=True.
            DefaultEncoderClass - class used to build the encoder when encoder is None.
            DefaultDecoderClass - class used to build the decoder when decoder is None.

        Raises:
            ValueError - if neither a config nor both an encoder and a decoder are given, or if config is not a
                VisionEncoderDecoderConfig.
        """
        if decoder is not None:
            assert not decoder.config.add_cross_attention, '"add_cross_attention" must be False for the given decoder'
            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'

        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        elif not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        # NOTE(review): presumably disables Hugging Face weight tying for the combined model — confirm.
        config.tie_word_embeddings = False

        # Initialize with config:
        PreTrainedModel.__init__(self, config)

        # Encoder:
        if encoder is None:
            encoder = DefaultEncoderClass(config=config.encoder)

        # Decoder:
        if decoder is None:
            assert not config.decoder.add_cross_attention
            decoder = DefaultDecoderClass(config=config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # The sub-models share the combined model's config objects from here on:
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # Decoder config attributes that the rest of this class reads:
        assert config.decoder.is_decoder
        assert 'img_token_id' in self.decoder.config.__dict__
        assert 'pad_token_id' in self.decoder.config.__dict__
        assert 'token_type_embeddings' in self.decoder.config.__dict__

        # Learned token type embeddings that forward() adds to the decoder input embeddings:
        if self.decoder.config.token_type_embeddings == 'add':
            self.token_type_embeddings = torch.nn.Embedding(
                self.decoder.config.num_token_types, self.decoder.config.hidden_size,
            )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        decoder_token_type_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """
        Run the encoder (if needed) and the decoder over the combined [image tokens, report tokens] sequence.

        When encoder_outputs is None (i.e., not called from generate()), pixel_values are encoded. The image
        embeddings (encoder_outputs[0]) are then prepended to the report embeddings. encoder_outputs[1] is
        used both as the image tokens' 2D attention mask and as their position identifiers. The report-level
        decoder_attention_mask is expanded into the mixed-causality 4D mask. When encoder_outputs is given,
        the decoder inputs are expected to already be combined (see prepare_inputs_for_generation()).

        Argument/s:
            pixel_values - images for the encoder; required when encoder_outputs is None.
            decoder_input_ids - report token identifiers.
            decoder_attention_mask - 2D report attention mask (training path) or prepared 4D mask (generation).
            decoder_token_type_ids - report token types (training path; derived from decoder_input_ids if
                None) or combined token types (generation).
            encoder_outputs - precomputed encoder outputs.
            past_key_values - decoder cache.
            decoder_inputs_embeds - decoder input embeddings; computed from decoder_input_ids if None.
            decoder_position_ids - position identifiers; required when encoder_outputs is given.
            labels - target token identifiers. May cover only the trailing report positions.
            use_cache, output_attentions, output_hidden_states, return_dict - as in Hugging Face models.
            kwargs - "decoder_"-prefixed entries go to the decoder (prefix stripped), the rest to the encoder.

        Returns:
            Seq2SeqLMOutput, or a tuple when return_dict is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if decoder_inputs_embeds is None:
            decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)

        # This is for when generate() is not called; for generation, see prepare_inputs_for_generation():
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            # UniFormer does not support output_attentions:
            encoder_outputs = self.encoder(
                pixel_values,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            image_embeddings, image_attention_mask = encoder_outputs[0], encoder_outputs[1]

            # Prepend the image embeddings to the report embeddings:
            assert decoder_inputs_embeds is not None
            decoder_inputs_embeds = torch.cat([image_embeddings, decoder_inputs_embeds], dim=1)

            # Derive the report token types from the token identifiers when they are not given:
            if decoder_token_type_ids is None and decoder_input_ids is not None:
                decoder_token_type_ids = self.token_ids_to_token_type_ids(decoder_input_ids)

            # Add image token type identifiers:
            decoder_token_type_ids = self._prepend_image_token_type_ids(image_embeddings, decoder_token_type_ids)

            # Position identifiers accounting for padding:
            report_position_ids = self._report_position_ids(image_attention_mask, decoder_attention_mask)
            decoder_position_ids = torch.cat([image_attention_mask, report_position_ids], dim=1)

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(
                image_attention_mask, decoder_attention_mask,
            )

        assert decoder_position_ids is not None
        assert decoder_attention_mask is not None
        assert decoder_token_type_ids is not None

        if self.decoder.config.token_type_embeddings == 'add':
            # Out-of-place, so that a caller-provided decoder_inputs_embeds tensor is not mutated:
            decoder_inputs_embeds = decoder_inputs_embeds + self.token_type_embeddings(decoder_token_type_ids)
        elif self.decoder.config.token_type_embeddings == 'inbuilt':
            kwargs_decoder['token_type_ids'] = decoder_token_type_ids

        # Forward:
        decoder_outputs = self.decoder(
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Loss:
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            # The decoder sequence is [image tokens, report tokens], whereas labels (e.g., the label_ids from
            # tokenize_report_teacher_forcing()) cover only the report tokens. Align the logits with the labels
            # using the trailing positions; this is a no-op when the labels span the whole sequence:
            logits = logits[:, logits.shape[1] - labels.shape[1]:]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs[0],
        )

    def _prepend_image_token_type_ids(self, image_embeddings, token_type_ids):
        """
        Prepend the image token type identifier (decoder config img_token_id) for every image token.

        Argument/s:
            image_embeddings - image embeddings; all dimensions but the last give the image token layout.
            token_type_ids - report token type identifiers.

        Returns:
            Token type identifiers for the combined [image tokens, report tokens] sequence.
        """
        image_token_type_ids = torch.full(
            image_embeddings.shape[:-1],
            self.decoder.config.img_token_id,
            dtype=token_type_ids.dtype,
            device=token_type_ids.device,
        )
        return torch.cat([image_token_type_ids, token_type_ids], dim=1)

    @staticmethod
    def _report_position_ids(image_attention_mask, report_attention_mask):
        """
        Position identifiers for the report tokens that account for padding.

        Report positions continue from the largest image position identifier of each row. Padded report
        positions are set to 1.

        Argument/s:
            image_attention_mask - the second encoder output, also used as the image tokens' position ids.
            report_attention_mask - 2D report attention mask.

        Returns:
            Report position identifiers with the same shape as report_attention_mask.
        """
        position_ids = report_attention_mask.cumsum(-1) + image_attention_mask.max(dim=1).values[:, None]
        position_ids.masked_fill_(report_attention_mask == 0, 1)
        return position_ids

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Modification of:
            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660

        On the first step (no cache), the combined [image, report] embeddings, token types, position ids and
        mixed-causality 4D mask are built. On later steps, the following are passed:
            - only the newest input ID(s);
            - the newest token's type and position;
            - a single-row mask over all image and report tokens.

        Argument/s:
            input_ids - report token identifiers generated so far.
            past_key_values - decoder cache (None on the first step).
            use_cache - whether the decoder should return its cache.
            encoder_outputs - encoder outputs; [0] image embeddings, [1] image attention mask/position ids.

        Returns:
            Keyword arguments for forward().

        Raises:
            ValueError - if encoder_outputs is None.
        """
        if encoder_outputs is None:
            raise ValueError("encoder_outputs has to be provided to prepare_inputs_for_generation().")

        image_embeddings, image_attention_mask = encoder_outputs[0], encoder_outputs[1]
        report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()

        # Position identifiers accounting for padding:
        report_position_ids = self._report_position_ids(image_attention_mask, report_attention_mask)

        if past_key_values is None:
            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(
                image_attention_mask, report_attention_mask,
            )
            decoder_position_ids = torch.cat([image_attention_mask, report_position_ids], dim=1)

            # `inputs_embeds` are only to be used in the 1st generation step:
            inputs_embeds = torch.cat([image_embeddings, self.decoder.get_input_embeddings()(input_ids)], dim=1)

            # Report token types with the image token type identifiers prepended:
            decoder_token_type_ids = self._prepend_image_token_type_ids(
                image_embeddings, self.token_ids_to_token_type_ids(input_ids),
            )

            input_dict = {
                'decoder_input_ids': input_ids,
                'decoder_inputs_embeds': inputs_embeds,
                'decoder_token_type_ids': decoder_token_type_ids,
            }
        else:
            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(
                image_attention_mask, report_attention_mask,
            )

            # Only the newest token's position is needed:
            decoder_position_ids = report_position_ids[:, -1:]

            # Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids)

            past_length = past_key_values[0][0].shape[2]

            # Some generation methods only pass the last input ID:
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Keep only the final ID:
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

            input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}

        input_dict.update(
            {
                'decoder_attention_mask': decoder_attention_mask,
                'decoder_position_ids': decoder_position_ids,
                'encoder_outputs': encoder_outputs,
                'past_key_values': past_key_values,
                'use_cache': use_cache,
            }
        )
        return input_dict

    def token_ids_to_token_type_ids(self, token_ids):
        """
        Extract token type identifiers from the token identifiers.

        Every token starts with type config.decoder.section_ids[0]. For the i-th separator token, the tokens
        that follow its first occurrence in a row get type config.decoder.section_ids[i + 1]. The separator
        itself keeps the type of the preceding section. A separator at column 0 or in the last column changes
        nothing.

        Argument/s:
            token_ids - token identifiers (batch_size, seq_len).

        Returns:
            token_type_ids - token type identifiers (batch_size, seq_len).
        """
        _, seq_len = token_ids.shape
        token_type_ids = torch.full_like(
            token_ids, self.config.decoder.section_ids[0], dtype=torch.long, device=token_ids.device,
        )
        columns = torch.arange(seq_len, device=token_ids.device)

        for i, j in enumerate(self.config.decoder.separator_token_ids):
            # First occurrence of the separator; argmax returns 0 when it is absent. Start one column after it:
            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
            start = (token_ids == j).int().argmax(dim=1) + 1

            # A start of 1 means "not found". This is safe as index 0 is always a special token. A start of
            # seq_len means nothing follows the separator:
            valid = torch.logical_and(start != 1, start < seq_len)

            # Mark every column from the start onwards, for the valid rows only (vectorised):
            in_next_section = valid[:, None] & (columns[None, :] >= start[:, None])
            token_type_ids[in_next_section] = self.config.decoder.section_ids[i + 1]

        return token_type_ids

    def token_ids_to_token_type_ids_past(self, token_ids):
        """
        Extract the token type identifier of the newest token when past_key_values is not None.

        Make sure to input all the token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:]
        from prepare_inputs_for_generation). The newest token takes the section of the last separator (in
        config.decoder.separator_token_ids order) that appears anywhere before it.

        Argument/s:
            token_ids - token identifiers (batch_size, seq_len).

        Returns:
            token_type_ids - token type identifiers (batch_size, 1).
        """
        token_type_ids = torch.full(
            [token_ids.shape[0], 1], self.config.decoder.section_ids[0], dtype=torch.long, device=token_ids.device,
        )

        # Exclude the newest token: a separator belongs to the section before it.
        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
        token_ids = token_ids[:, :-1]

        for i, j in enumerate(self.config.decoder.separator_token_ids):
            # Any occurrence of the special token indicates the boundary between sections:
            exists = torch.any(token_ids == j, dim=1, keepdim=True)
            token_type_ids[exists] = self.config.decoder.section_ids[i + 1]

        return token_type_ids

    @staticmethod
    def _format_reports(findings, impression, tokenizer):
        """Join each findings/impression pair into a report: bos + findings + sep + impression + eos."""
        return [
            f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}'
            for i, j in zip(findings, impression)
        ]

    def _tokenize_reports_for_teacher_forcing(self, reports, tokenizer, max_len):
        """Tokenize prepared reports and build the teacher-forcing inputs, labels and attention mask."""
        tokenized = tokenizer(
            reports,
            padding='longest',
            truncation=True,
            max_length=max_len + 1,  # +1 to account for the bias between input and target.
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # Modify for language modelling:
        return {
            # Labels for the decoder (shifted right by one for autoregression):
            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),

            # Remove last token identifier to match the sequence length of the labels:
            'decoder_input_ids': tokenized['input_ids'][:, :-1],

            # Attention mask for the decoder_input_ids. It is shifted by one so that an input position is masked
            # when its target is padding, e.g., the eos_token_id input position:
            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
        }

    def tokenize_report_teacher_forcing(
        self, findings: List[str], impression: List[str], tokenizer: PreTrainedTokenizerFast, max_len: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            findings - findings sections (one string per report).
            impression - impression sections (one string per report).
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.

        Returns:
            A dict with:
                decoder_input_ids - the token identifiers for the input of the decoder.
                decoder_attention_mask - the attention mask for the decoder_input_ids.
                label_ids - the label token identifiers for the decoder.
        """
        # Prepare the sections for the tokenizer by placing special tokens between each section:
        reports = self._format_reports(findings, impression, tokenizer)
        return self._tokenize_reports_for_teacher_forcing(reports, tokenizer, max_len)

    def tokenize_report_teacher_forcing_rev_a(
        self,
        tokenizer: PreTrainedTokenizerFast,
        max_len: int,
        findings: Optional[List[str]] = None,
        impression: Optional[List[str]] = None,
        reports: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.
            findings - findings sections (used when reports is None).
            impression - impression sections (used when reports is None).
            reports - prepared reports, with special tokens and report sections.

        Returns:
            A dict with:
                decoder_input_ids - the token identifiers for the input of the decoder.
                decoder_attention_mask - the attention mask for the decoder_input_ids.
                label_ids - the label token identifiers for the decoder.
        """
        # Prepare the sections for the tokenizer by placing special tokens between each section:
        if reports is None:
            assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
            reports = self._format_reports(findings, impression, tokenizer)

        return self._tokenize_reports_for_teacher_forcing(reports, tokenizer, max_len)

    def split_and_decode_sections(self, token_ids, tokenizer: PreTrainedTokenizerFast):
        """
        Split the token identifiers into sections, then convert the token identifiers into strings.

        Section j of a row ends at the first occurrence, anywhere in the row, of
        config.decoder.end_of_section_token_ids[j]. If that token is absent, or found only at column 0, the
        section runs to the end of the row and any remaining sections are empty strings. Special tokens are
        skipped when decoding.

        Argument/s:
            token_ids - token identifiers (batch_size, seq_len).
            tokenizer - Hugging Face tokenizer.

        Returns:
            A tuple with one list per section, each holding that section's decoded string for every row.
        """
        _, seq_len = token_ids.shape

        # The number of sections is the same as the number of end_of_section_token_ids:
        num_sections = len(self.config.decoder.end_of_section_token_ids)
        sections = {k: [] for k in range(num_sections)}

        for i in token_ids:
            prev_col = 0
            for j, k in enumerate(self.config.decoder.end_of_section_token_ids):

                # The maximum sequence length was exceeded, thus no more tokens:
                if prev_col >= seq_len:
                    sections[j].append('')
                    continue

                # Find first occurrence of special tokens that indicate the boundary between sections:
                col = (i == k).int().argmax().item()

                # If equal to 0, token was not found; set the column to the sequence length (as the decoder
                # exceeded the maximum sequence length):
                if col == 0:
                    col = seq_len

                # Extract section token identifiers:
                section_token_ids = i[prev_col:col]
                prev_col = col
                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)

                sections[j].append(section_string)

        return tuple(sections.values())

    @staticmethod
    def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
        """
        Build a mixed-causality 4D attention mask for the combined [prompt, report] sequence.

        Block layout of the (prompt + report) x (prompt + report) matrix (rows attend to columns):
            upper left  - prompt tokens attend to prompt tokens, masked by the non-causal mask on both sides;
            upper right - zeros: prompt tokens never attend to report tokens;
            lower left  - non-padded report tokens attend to non-masked prompt tokens;
            lower right - lower-triangular (causal) among report tokens, masked by the causal mask on both sides.

        Argument/s:
            non_causal_2d_attention_mask - 2D mask for the prompt tokens (batch_size, prompt_seq_len).
            causal_2d_attention_mask - 2D mask for the report tokens (batch_size, report_seq_len).

        Returns:
            Mask of shape (batch_size, 1, prompt_seq_len + report_seq_len, prompt_seq_len + report_seq_len).
        """
        prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
        report_seq_len = causal_2d_attention_mask.shape[-1]

        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        # Upper left of attention matrix:
        upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
        upper_left = upper_left * non_causal_2d_attention_mask
        upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)

        causal_mask = torch.tril(
            torch.ones(
                (report_seq_len, report_seq_len),
                dtype=torch.long,
                device=causal_2d_attention_mask.device,
            ),
        )

        # Lower right of attention matrix:
        lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
        lower_right = lower_right * causal_mask

        # Upper right of attention matrix:
        upper_right = torch.zeros(
            causal_2d_attention_mask.shape[0],
            1,
            prompt_seq_len,
            report_seq_len,
            dtype=torch.long,
            device=causal_2d_attention_mask.device,
        )

        # Lower left of attention matrix:
        lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)

        left = torch.cat((upper_left, lower_left), dim=2)
        right = torch.cat((upper_right, lower_right), dim=2)

        mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
        return mixed_causality_4d_attention_mask

    @staticmethod
    def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):
        """
        Build the single-row 4D attention mask used with a cache: the newest token's view of all tokens.

        Argument/s:
            non_causal_2d_attention_mask - 2D mask for the prompt tokens (batch_size, prompt_seq_len).
            causal_2d_attention_mask - 2D mask for the report tokens (batch_size, report_seq_len).

        Returns:
            Mask of shape (batch_size, 1, 1, prompt_seq_len + report_seq_len).
        """
        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
        return mixed_causality_4d_attention_mask