---
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
---

# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation

This is an evolution of https://huggingface.co/aehrc/cxrmate, developed for the Radiology Report Generation (RRG24) shared task at BioNLP @ ACL 2024.

For this, we propose Entropy-Augmented Self-critical sequence Training (EAST). EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation. This helps maintain a higher entropy in the token distribution, which discourages overfitting to common phrases and encourages broader exploration of the vocabulary during training; this is essential for handling the diversity of the radiology reports in the RRG24 datasets. We apply EAST to a multimodal language model, using RadGraph as the reward.

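The self-critical objective with an entropy bonus can be illustrated with a minimal PyTorch sketch. It is not the EAST training code: it assumes the reward of a greedy-decoded report as the SCST baseline (as in the original SCST formulation), one scalar reward per report (for example a RadGraph score, whose computation is omitted), length-normalised log-probabilities, and a hypothetical entropy weight `entropy_coef`. The function name `scst_entropy_loss` is likewise hypothetical, and the exact EAST objective and hyperparameters may differ.

```python
import torch
import torch.nn.functional as F


def scst_entropy_loss(
    sample_logits: torch.Tensor,    # (batch, seq_len, vocab): logits at each step of the sampled report
    sample_ids: torch.Tensor,       # (batch, seq_len): token ids of the sampled report
    mask: torch.Tensor,             # (batch, seq_len): 1.0 for report tokens, 0.0 for padding
    sample_reward: torch.Tensor,    # (batch,): reward of the sampled report, e.g. a RadGraph score
    baseline_reward: torch.Tensor,  # (batch,): reward of the greedy-decoded report (SCST baseline)
    entropy_coef: float = 0.01,     # hypothetical weight of the entropy bonus
) -> torch.Tensor:
    """Hypothetical sketch: SCST policy-gradient loss minus an entropy bonus."""
    log_probs = F.log_softmax(sample_logits, dim=-1)
    # Log-probability of each token that was actually sampled.
    token_log_probs = log_probs.gather(-1, sample_ids.unsqueeze(-1)).squeeze(-1)
    lengths = mask.sum(dim=1).clamp(min=1.0)
    seq_log_probs = (token_log_probs * mask).sum(dim=1) / lengths

    # SCST advantage: sampled reward minus greedy baseline, treated as a constant.
    advantage = (sample_reward - baseline_reward).detach()
    policy_loss = -(advantage * seq_log_probs).mean()

    # Entropy of the full next-token distribution at every report position.
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    mean_entropy = ((token_entropy * mask).sum(dim=1) / lengths).mean()

    # Subtracting the bonus means that minimising the loss pushes entropy up.
    return policy_loss - entropy_coef * mean_entropy


# Usage with random tensors standing in for model outputs and rewards.
torch.manual_seed(0)
logits = torch.randn(2, 5, 10, requires_grad=True)
ids = torch.randint(0, 10, (2, 5))
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 0.0]])
loss = scst_entropy_loss(logits, ids, mask, torch.tensor([0.6, 0.2]), torch.tensor([0.4, 0.3]))
loss.backward()  # gradients now flow back to the logits
```

Sampled reports that beat the greedy baseline have their log-probability raised and worse ones lowered; the entropy term pulls against the tendency of the policy to concentrate on a few high-reward phrasings.
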

Additionally, the model incorporates several other components. We use token type embeddings to distinguish the image embeddings from the findings-section and impression-section tokens. To handle missing sections, we employ special tokens. We also use an attention mask that is non-causal over the image embeddings and causal over the report token embeddings.

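The token types and the mixed attention mask can be sketched in PyTorch as follows. This is a hypothetical illustration rather than the model's remote code: the ids `IMAGE`, `FINDINGS` and `IMPRESSION`, the helper names, and the sequence layout (all image embeddings first, then the report tokens) are assumptions. It also assumes that a missing section is represented by a single special token, presumably the `[NF]` or `[NI]` tokens that appear in the usage example. Image embeddings attend to each other bidirectionally but not to report tokens, while each report token attends to every image embedding and to the report tokens before it.

```python
import torch
import torch.nn.functional as F

# Hypothetical token type ids for the three kinds of positions.
IMAGE, FINDINGS, IMPRESSION = 0, 1, 2


def build_token_type_ids(num_image: int, num_findings: int, num_impression: int) -> torch.Tensor:
    """Label each position as an image embedding, a findings token, or an impression token."""
    return torch.cat([
        torch.full((num_image,), IMAGE, dtype=torch.long),
        torch.full((num_findings,), FINDINGS, dtype=torch.long),
        torch.full((num_impression,), IMPRESSION, dtype=torch.long),
    ])


def build_attention_mask(num_image: int, num_report: int) -> torch.Tensor:
    """Boolean mask where mask[i, j] is True if position i may attend to position j."""
    total = num_image + num_report
    # Causal (lower-triangular) mask over the whole sequence: report tokens see
    # every image embedding and only the report tokens before them.
    mask = torch.tril(torch.ones(total, total, dtype=torch.bool))
    # Non-causal image block: every image embedding sees every other image embedding.
    mask[:num_image, :num_image] = True
    return mask


# Example: 3 image embeddings, a 2-token findings section, and a missing
# impression section represented by 1 special token.
token_type_ids = build_token_type_ids(num_image=3, num_findings=2, num_impression=1)
mask = build_attention_mask(num_image=3, num_report=3)

# The boolean mask plugs directly into PyTorch attention (True = may attend).
q = k = v = torch.randn(1, 1, 6, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

`F.scaled_dot_product_attention` requires PyTorch 2.0 or later. Because the image embeddings come first in this layout, a plain causal mask already lets every report token see all of them; only the image block needs to be opened up.
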

## How to use:

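```python
import torch
from PIL import Image
from torchvision.transforms import v2
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)

# Convert to a 3-channel tensor, resize and centre-crop to the encoder's input size,
# scale to [0, 1], and normalise with the encoder's statistics.
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

# Placeholder paths: replace with the chest X-ray images of one study.
image_paths = ['path/to/image_1.jpg', 'path/to/image_2.jpg']
images = torch.stack([transforms(Image.open(path)) for path in image_paths], dim=0)

# Add a batch dimension. The layout (batch, images per study, channels, height, width)
# is an assumption; check the model's remote code if generation fails.
images = images.unsqueeze(0)

output_ids = model.generate(
    pixel_values=images,
    max_length=512,
    # Stop the model from emitting the [NF] and [NI] special tokens
    # (presumably the missing-findings and missing-impression markers).
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
    num_beams=4,
    use_cache=True,
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```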

## Paper:

## Citation:

[More Information Needed]
|