---
library_name: transformers
license: apache-2.0
datasets:
  - StanfordAIMI/interpret-cxr-test-public
  - StanfordAIMI/interpret-cxr-test-hidden
---

# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation

This model is an evolution of [CXRMate](https://huggingface.co/aehrc/cxrmate), developed for the Radiology Report Generation (RRG24) task of BioNLP @ ACL 2024.

For this task, we propose Entropy-Augmented Self-critical sequence Training (EAST). EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation. This helps maintain a higher entropy in the token distribution, preventing overfitting to common phrases and ensuring a broader exploration of the vocabulary during training, which is essential for handling the diversity of the radiology reports in the RRG24 datasets. We apply EAST to a multimodal language model with RadGraph as the reward.
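As a rough illustration, the sketch below shows one plausible form of an entropy-augmented SCST objective: the standard SCST policy-gradient term weighted by the reward advantage, plus an entropy bonus over the token distribution. The `entropy_weight` value and the exact formulation are assumptions, not the model's published training code; the source only specifies that RadGraph provides the reward.

```python
import torch
import torch.nn.functional as F

def east_loss(logits, sampled_ids, sample_reward, baseline_reward, entropy_weight=0.05):
    """Sketch of an entropy-augmented SCST loss (assumed form).

    logits: (batch, seq_len, vocab) scores for the sampled report.
    sampled_ids: (batch, seq_len) token ids drawn from the model.
    sample_reward / baseline_reward: (batch,) RadGraph rewards for the
    sampled report and the greedy (baseline) report.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability of each sampled token under the model.
    token_log_probs = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
    # SCST: reinforce tokens in proportion to the reward advantage.
    advantage = (sample_reward - baseline_reward).unsqueeze(-1)
    scst_loss = -(advantage * token_log_probs).mean()
    # Entropy bonus: keeps the token distribution from collapsing
    # onto common phrases, encouraging broader vocabulary exploration.
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()
    return scst_loss - entropy_weight * entropy
```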

Additionally, the model incorporates several other features. We use token type embeddings to differentiate between findings-section tokens, impression-section tokens, and image embeddings. To handle missing sections, we employ special tokens. We also use an attention mask that is non-causal over the image embeddings and causal over the report token embeddings; a sketch of such a mixed mask follows below.
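The following is a minimal sketch of how such a mixed attention mask could be constructed. The function name and layout (image tokens first, report tokens after) are assumptions for illustration; the model's actual implementation may differ.

```python
import torch

def mixed_attention_mask(num_image_tokens: int, num_report_tokens: int) -> torch.Tensor:
    """Boolean attention mask: True means the row position may attend to the column."""
    n = num_image_tokens + num_report_tokens
    mask = torch.zeros(n, n, dtype=torch.bool)
    # Image embeddings attend to each other bidirectionally (non-causal).
    mask[:num_image_tokens, :num_image_tokens] = True
    # Report tokens attend to all image embeddings...
    mask[num_image_tokens:, :num_image_tokens] = True
    # ...and causally to report tokens (themselves and earlier positions only).
    mask[num_image_tokens:, num_image_tokens:] = torch.tril(
        torch.ones(num_report_tokens, num_report_tokens, dtype=torch.bool)
    )
    return mask
```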

## How to use

```python
import torch
from PIL import Image
from torchvision.transforms import v2
import transformers


tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),  # replicate the greyscale CXR across three channels
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

image = Image.open('chest_x_ray.jpg')  # path to your chest X-ray image (placeholder)
image = transforms(image)  # (3, H, W)

# Batch the study; shape assumed to be (batch_size, images_per_study, 3, H, W).
images = image.unsqueeze(0).unsqueeze(0)

output_ids = model.generate(
    pixel_values=images,
    max_length=512,
    # Ban the special tokens for missing sections (likely "no findings"/"no impression"),
    # so that both the findings and impression sections are generated.
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
    num_beams=4,
    use_cache=True,
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
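Based on the method name, `split_and_decode_sections` presumably returns the decoded findings and impression sections, one string per study (an assumption), which can then be inspected directly:

```python
for f, i in zip(findings, impression):
    print(f'Findings: {f}')
    print(f'Impression: {i}')
```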

## Paper

[More Information Needed]

## Citation

[More Information Needed]