# Feature Extraction
# Transformers
# Safetensors
# vision-encoder-decoder
# custom_code
# cxrmate-rrg24 / modelling_cxrrg.py
# anicolson's picture
# Upload model
# 3079483 verified
from typing import Optional, Tuple, Union
import torch
import transformers
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import \
VisionEncoderDecoderConfig
from transformers.utils import logging
from .modelling_uniformer import MultiUniFormerWithProjectionHead
# Module-level logger from the transformers logging utilities; CXRRGModel.__init__ uses it to warn when the
# encoder/decoder configs are overwritten by the shared config.
logger = logging.get_logger(name=__name__)
class CXRRGModel(VisionEncoderDecoderModel):
    """
    Vision-encoder/decoder model for chest X-ray report generation.

    The decoder does not use cross-attention (``add_cross_attention`` must be False). Instead, the encoder outputs
    are prepended to the report token embeddings, and the decoder attends over the joint sequence with a 4D attention
    mask of mixed causality: non-causal over the encoder (prompt) positions and causal over the report positions.

    ``encoder_outputs[0]`` holds the embeddings that are prepended to the decoder input, and ``encoder_outputs[1]``
    is a 2D mask over those embeddings, which is also used as their position identifiers.
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        DefaultEncoderClass = MultiUniFormerWithProjectionHead,
        DefaultDecoderClass = transformers.LlamaForCausalLM,
    ):
        """
        Argument/s:
            config - VisionEncoderDecoderConfig with the encoder and decoder configurations. If None, it is built
                from ``encoder.config`` and ``decoder.config`` (both must then be given).
            encoder - optional pre-built encoder. If None, ``DefaultEncoderClass(config=config.encoder)`` is used.
            decoder - optional pre-built decoder; must have ``is_decoder=True`` and ``add_cross_attention=False``.
                If None, ``DefaultDecoderClass(config=config.decoder)`` is used.
            DefaultEncoderClass - encoder class instantiated when ``encoder`` is None.
            DefaultDecoderClass - decoder class instantiated when ``decoder`` is None.

        Raises:
            ValueError - if neither a config nor both an encoder and a decoder are given, or if the config is not a
                VisionEncoderDecoderConfig.
        """
        if decoder is not None:
            # The encoder outputs are prepended to the decoder input instead of being attended via cross-attention:
            assert not decoder.config.add_cross_attention, '"add_cross_attention" must be False for the given decoder'
            assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'

        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        config.tie_word_embeddings = False

        # Initialize with config. NOTE(review): PreTrainedModel.__init__ is called directly instead of the parent
        # class's __init__, presumably to skip the parent's own encoder/decoder construction — confirm.
        PreTrainedModel.__init__(self, config)

        # Encoder:
        if encoder is None:
            encoder = DefaultEncoderClass(config=config.encoder)

        # Decoder:
        if decoder is None:
            assert not config.decoder.add_cross_attention
            decoder = DefaultDecoderClass(config=config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # Share the sub-configs with the composite config:
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # Attributes read elsewhere in this class (forward / generation helpers):
        assert config.decoder.is_decoder
        assert 'img_token_id' in self.decoder.config.__dict__
        assert 'pad_token_id' in self.decoder.config.__dict__
        assert 'token_type_embeddings' in self.decoder.config.__dict__

        # 'add': token type embeddings are a separate table added to the input embeddings in forward();
        # 'inbuilt': token type identifiers are passed to the decoder as ``token_type_ids``.
        if self.decoder.config.token_type_embeddings == 'add':
            self.token_type_embeddings = torch.nn.Embedding(
                self.decoder.config.num_token_types, self.decoder.config.hidden_size,
            )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        decoder_token_type_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """
        Prepend the encoder outputs to the report embeddings and run the decoder over the joint sequence.

        When ``encoder_outputs`` is None (e.g., training), the encoder is run on ``pixel_values``, and the joint
        embeddings, image token type identifiers, position identifiers and 4D mixed-causality attention mask are
        built here. When ``encoder_outputs`` is given (generation), these are expected to have been built already
        by ``prepare_inputs_for_generation``.

        Argument/s:
            pixel_values - encoder input (required when ``encoder_outputs`` is None).
            decoder_input_ids - report token identifiers; embedded when ``decoder_inputs_embeds`` is None.
            decoder_attention_mask - 2D report attention mask (encoder run here) or a precomputed 4D mask.
            decoder_token_type_ids - report token type identifiers; image token types are prepended when the
                encoder is run here.
            encoder_outputs - precomputed encoder outputs: index 0 the embeddings, index 1 their 2D mask.
            past_key_values - decoder cache.
            decoder_inputs_embeds - report embeddings, or precomputed joint embeddings during generation.
            decoder_position_ids - position identifiers; overwritten when the encoder is run here.
            labels - target token identifiers. They may cover the whole joint sequence or only its trailing
                (report) positions; the logits are aligned to the last ``labels.shape[-1]`` positions.
            use_cache, output_attentions, output_hidden_states, return_dict - forwarded to the decoder (the
                encoder receives output_hidden_states and return_dict).
            kwargs - ``decoder_``-prefixed entries go to the decoder (prefix removed); the rest go to the encoder.

        Returns:
            Seq2SeqLMOutput when ``return_dict``, otherwise a tuple (with the loss first, if computed).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if decoder_inputs_embeds is None:
            decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)

        # This is for when generate() is not called; for generation, see prepare_inputs_for_generation():
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )  # UniFormer does not support output_attentions.

            assert decoder_inputs_embeds is not None
            decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)

            # Add image token type identifiers:
            decoder_token_type_ids = torch.cat(
                [
                    torch.full(
                        encoder_outputs[0].shape[:-1],
                        self.decoder.config.img_token_id,
                        dtype=decoder_token_type_ids.dtype,
                        device=decoder_token_type_ids.device,
                    ),
                    decoder_token_type_ids,
                ],
                dim=1,
            )

            # Position identifiers accounting for padding (report positions continue after the largest value of
            # the encoder mask; padded report positions are set to 1):
            report_position_ids = decoder_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            report_position_ids.masked_fill_(decoder_attention_mask == 0, 1)
            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(
                encoder_outputs[1], decoder_attention_mask,
            )

        assert decoder_position_ids is not None
        assert decoder_attention_mask is not None
        assert decoder_token_type_ids is not None

        if self.decoder.config.token_type_embeddings == 'add':
            # Out-of-place addition so that a caller-supplied embeddings tensor is not modified in place:
            decoder_inputs_embeds = decoder_inputs_embeds + self.token_type_embeddings(decoder_token_type_ids)
        elif self.decoder.config.token_type_embeddings == 'inbuilt':
            kwargs_decoder['token_type_ids'] = decoder_token_type_ids

        # Forward:
        decoder_outputs = self.decoder(
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Loss:
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            # The logits also cover the prepended encoder positions, whereas the labels may only cover the report;
            # align them on the trailing positions (a no-op when the labels span the whole sequence):
            logits = logits[:, logits.shape[1] - labels.shape[-1]:, :]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        encoder_hidden_states = encoder_outputs[0]

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Modification of:
            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660

        Build the decoder inputs for one generation step. ``encoder_outputs`` must be given (index 0 the encoder
        embeddings, index 1 their 2D mask).

        First step (no cache): the joint embeddings (encoder outputs + report embeddings), the full 4D
        mixed-causality attention mask, position identifiers and token type identifiers are built for the whole
        sequence. Later steps: only the last report token, its position identifier and token type identifier, and
        a single-row attention mask are passed.

        Argument/s:
            input_ids - report token identifiers generated so far (padding identified by ``pad_token_id``).
            past_key_values - decoder cache, or None on the first step.
            use_cache - whether the decoder should return a cache.
            encoder_outputs - precomputed encoder outputs.
            kwargs - unused.

        Returns:
            Dictionary of keyword arguments for forward().
        """
        report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()

        if past_key_values is None:

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(
                encoder_outputs[1], report_attention_mask,
            )

            # Position identifiers accounting for padding:
            report_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            report_position_ids.masked_fill_(report_attention_mask == 0, 1)
            decoder_position_ids = torch.cat([encoder_outputs[1], report_position_ids], dim=1)

            # `inputs_embeds` are only to be used in the 1st generation step:
            inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)

            decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids)
            # Add image token type identifiers:
            decoder_token_type_ids = torch.cat(
                [
                    torch.full(
                        encoder_outputs[0].shape[:-1],
                        self.decoder.config.img_token_id,
                        dtype=decoder_token_type_ids.dtype,
                        device=decoder_token_type_ids.device,
                    ),
                    decoder_token_type_ids,
                ],
                dim=1,
            )

            input_dict = {
                'decoder_input_ids': input_ids,
                'decoder_inputs_embeds': inputs_embeds,
                'decoder_token_type_ids': decoder_token_type_ids,
            }
        else:

            # 4D attention mask:
            decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(
                encoder_outputs[1], report_attention_mask,
            )

            # Position identifiers accounting for padding (only the last one is needed):
            decoder_position_ids = report_attention_mask.cumsum(-1) + encoder_outputs[1].max(dim=1).values[:, None]
            decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)

            # Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]
            # (it needs the full history of token identifiers):
            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids)
            decoder_position_ids = decoder_position_ids[:, -1:]

            past_length = past_key_values[0][0].shape[2]

            # Some generation methods only pass the last input ID:
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Keep only the final ID:
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

            input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}

        input_dict.update(
            {
                'decoder_attention_mask': decoder_attention_mask,
                'decoder_position_ids': decoder_position_ids,
                'encoder_outputs': encoder_outputs,
                'past_key_values': past_key_values,
                'use_cache': use_cache,
            }
        )
        return input_dict

    def token_ids_to_token_type_ids(self, token_ids):
        """
        Extract token type identifiers from the token identifiers.

        Every position starts as ``section_ids[0]``. For the i-th separator token, all positions strictly after its
        first occurrence in a row are set to ``section_ids[i + 1]`` (later separators override earlier ones). A
        separator that is absent, at position 0, or at the last position leaves the row unchanged.

        Argument/s:
            token_ids - token identifiers, shape (batch, seq_len).

        Returns:
            token_type_ids - token type identifiers (torch.long), same shape as token_ids.
        """
        _, seq_len = token_ids.shape
        token_type_ids = torch.full_like(
            token_ids, self.config.decoder.section_ids[0], dtype=torch.long, device=token_ids.device,
        )
        positions = torch.arange(seq_len, device=token_ids.device)

        for i, separator_id in enumerate(self.config.decoder.separator_token_ids):

            # First occurrence of the separator in each row (argmax returns 0 when absent); the next section
            # starts one position later, see:
            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
            cols = (token_ids == separator_id).int().argmax(dim=1) + 1

            # cols == 1 means absent (index 0 is always a special token); cols == seq_len means nothing follows:
            valid_rows = torch.logical_and(cols != 1, cols < seq_len)

            # Vectorized equivalent of assigning token_type_ids[row, col:] for every valid row:
            section_mask = valid_rows[:, None] & (positions[None, :] >= cols[:, None])
            token_type_ids[section_mask] = self.config.decoder.section_ids[i + 1]

        return token_type_ids

    def token_ids_to_token_type_ids_past(self, token_ids):
        """
        Extract the token type identifier of the last token when past != None. Make sure to input all the
        token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from
        prepare_inputs_for_generation).

        Argument/s:
            token_ids - token identifiers, shape (batch, seq_len).

        Returns:
            token_type_ids - token type identifiers (torch.long), shape (batch, 1).
        """
        token_type_ids = torch.full(
            [token_ids.shape[0], 1], self.config.decoder.section_ids[0], dtype=torch.long, device=token_ids.device,
        )

        # A separator belongs to the section it closes, so only the tokens before the last one are inspected, see:
        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
        token_ids = token_ids[:, :-1]

        for i, j in enumerate(self.config.decoder.separator_token_ids):

            # Rows in which the separator has already occurred:
            exists = torch.any(token_ids == j, dim=1, keepdim=True)
            token_type_ids[exists] = self.config.decoder.section_ids[i + 1]

        return token_type_ids

    def _teacher_forcing_batch(self, reports, tokenizer: PreTrainedTokenizerFast, max_len: int):
        """
        Tokenize prepared report strings and build the shifted inputs and targets for teacher forcing.

        Argument/s:
            reports - report strings that already contain their special tokens.
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens of the decoder inputs/labels.

        Returns:
            Dictionary with 'label_ids', 'decoder_input_ids' and 'decoder_attention_mask' on self.device.
        """
        tokenized = tokenizer(
            reports,
            padding='longest',
            truncation=True,
            max_length=max_len + 1,  # +1 to account for the bias between input and target.
            return_tensors='pt',
            return_token_type_ids=False,
            add_special_tokens=False,
        ).to(self.device)

        # Modify for language modelling:
        return {
            # Labels for the decoder (the inputs shifted left by one for autoregression):
            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),

            # Remove last token identifier to match the sequence length of the labels:
            'decoder_input_ids': tokenized['input_ids'][:, :-1],

            # Attention mask aligned with the labels (first position dropped), so that an eos_token_id in the
            # decoder inputs is masked out:
            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
        }

    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            findings - findings sections, one per report (iterated in parallel with impression).
            impression - impression sections, one per report.
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.

        Returns:
            Dictionary with:
                decoder_input_ids - the token identifiers for the input of the decoder.
                decoder_attention_mask - the attention mask for the decoder_input_ids.
                label_ids - the label token identifiers for the decoder.
        """
        # Prepare the sections for the tokenizer by placing special tokens between each section:
        reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
                   zip(findings, impression)]

        return self._teacher_forcing_batch(reports, tokenizer, max_len)

    def tokenize_report_teacher_forcing_rev_a(self, tokenizer: PreTrainedTokenizerFast, max_len: int, findings: Optional[str] = None, impression: Optional[str] = None, reports: Optional[str] = None):
        """
        Tokenize the reports and create the inputs and targets for teacher forcing.

        Argument/s:
            tokenizer - Hugging Face tokenizer.
            max_len - maximum number of tokens.
            findings - findings sections, one per report (used when reports is None).
            impression - impression sections, one per report (used when reports is None).
            reports - prepared reports, with special tokens and report sections.

        Returns:
            Dictionary with:
                decoder_input_ids - the token identifiers for the input of the decoder.
                decoder_attention_mask - the attention mask for the decoder_input_ids.
                label_ids - the label token identifiers for the decoder.
        """
        # Prepare the sections for the tokenizer by placing special tokens between each section:
        if reports is None:
            assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
            reports = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
                       zip(findings, impression)]

        return self._teacher_forcing_batch(reports, tokenizer, max_len)

    def split_and_decode_sections(self, token_ids, tokenizer: PreTrainedTokenizerFast):
        """
        Split the token identifiers into sections, then convert the token identifiers into strings.

        Each section ends at the first occurrence (in the whole row) of its end-of-section token. If that token is
        not found, or is at index 0, the section extends to the end of the sequence; sections after the end of the
        sequence are empty strings.

        Argument/s:
            token_ids - token identifiers, shape (batch, seq_len).
            tokenizer - Hugging Face tokenizer.

        Returns:
            Tuple with one list per section (in end_of_section_token_ids order), each holding one decoded string
            (special tokens skipped) per row of token_ids.
        """
        _, seq_len = token_ids.shape

        # The number of sections is the same as the number of end_of_section_token_ids:
        num_sections = len(self.config.decoder.end_of_section_token_ids)

        sections = {k: [] for k in range(num_sections)}

        for i in token_ids:
            prev_col = 0
            for j, k in enumerate(self.config.decoder.end_of_section_token_ids):

                # The maximum sequence length was exceeded, thus no more tokens:
                if prev_col >= seq_len:
                    sections[j].append('')
                    continue

                # Find first occurrence of special tokens that indicate the boundary between sections:
                col = (i == k).int().argmax().item()

                # If equal to 0, token was not found, set the column to the sequence length (as the decoder exceeded
                # the maximum sequence length):
                if col == 0:
                    col = seq_len

                # Extract section token identifiers:
                section_token_ids = i[prev_col:col]
                prev_col = col
                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)

                sections[j].append(section_string)

        return tuple(sections.values())

    @staticmethod
    def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
        """
        Build a 4D attention mask over the joint [prompt, report] sequence with mixed causality.

        Layout of the (prompt_len + report_len) x (prompt_len + report_len) matrix:
            upper left  - prompt queries to prompt keys, non-causal (both positions unmasked);
            upper right - prompt queries to report keys, all zeros;
            lower left  - report queries to prompt keys (unmasked report query and unmasked prompt key);
            lower right - report queries to report keys, lower-triangular (causal) and both positions unmasked.
        NOTE(review): entries are products of the input mask values, so they are 0/1 only if the inputs are 0/1.

        Argument/s:
            non_causal_2d_attention_mask - prompt mask, shape (batch, prompt_len).
            causal_2d_attention_mask - report mask, shape (batch, report_len).

        Returns:
            Mask of shape (batch, 1, prompt_len + report_len, prompt_len + report_len).
        """
        prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
        report_seq_len = causal_2d_attention_mask.shape[-1]

        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        # Upper left of attention matrix:
        upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
        upper_left = upper_left * non_causal_2d_attention_mask
        upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)

        causal_mask = torch.tril(
            torch.ones(
                (
                    report_seq_len,
                    report_seq_len,
                ),
                dtype=torch.long,
                device=causal_2d_attention_mask.device,
            ),
        )

        # Lower right of attention matrix:
        lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
        lower_right = lower_right * causal_mask

        # Upper right of attention matrix:
        upper_right = torch.zeros(
            causal_2d_attention_mask.shape[0],
            1,
            prompt_seq_len,
            report_seq_len,
            dtype=torch.long,
            device=causal_2d_attention_mask.device,
        )

        # Lower left of attention matrix:
        lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
        lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)

        left = torch.cat((upper_left, lower_left), dim=2)
        right = torch.cat((upper_right, lower_right), dim=2)

        mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
        return mixed_causality_4d_attention_mask

    @staticmethod
    def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):
        """
        Build the single-query-row 4D attention mask used when a cache is present.

        Argument/s:
            non_causal_2d_attention_mask - prompt mask, shape (batch, prompt_len).
            causal_2d_attention_mask - report mask so far, shape (batch, report_len).

        Returns:
            Concatenation of both masks, shape (batch, 1, 1, prompt_len + report_len).
        """
        non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
        causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

        mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
        return mixed_causality_4d_attention_mask