---
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
library_name: transformers
license: apache-2.0
---
# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation
This is an evolution of https://huggingface.co/aehrc/cxrmate developed for the [Radiology Report Generation](https://stanford-aimi.github.io/RRG24/) task of [BioNLP @ ACL 2024](https://aclweb.org/aclwiki/BioNLP_Workshop).
The leaderboard for the task can be found [here](https://vilmedic.app/misc/bionlp24/leaderboard).
For this, we proposed Entropy-Augmented Self-critical sequence Training (EAST):
- EAST modifies [Self-Critical Sequence Training (SCST)](https://openaccess.thecvf.com/content_cvpr_2017/papers/Rennie_Self-Critical_Sequence_Training_CVPR_2017_paper.pdf) by adding entropy regularisation.
- Entropy regularisation helps maintain a higher entropy in the token distribution.
- This prevents overfitting to common phrases and encourages broader exploration of the vocabulary during training.
- This was essential to handle the diversity of the radiology reports in the RRG24 datasets.
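The EAST objective described above can be sketched in PyTorch as follows. This is a minimal sketch, not the authors' implementation. It assumes a length-normalised SCST policy-gradient term that uses the greedy-decoded report as the baseline, as in the original SCST paper. It also assumes a per-token entropy bonus weighted by a hypothetical `entropy_coef`. In this work the rewards would be RadGraph scores; here they are passed in as plain tensors.

```python
import torch


def east_loss(
    sample_logits: torch.Tensor,    # [B, T, V] logits for the sampled reports
    sample_ids: torch.Tensor,       # [B, T] token ids of the sampled reports
    mask: torch.Tensor,             # [B, T] 1 for real tokens, 0 for padding
    sample_reward: torch.Tensor,    # [B] reward of each sampled report (e.g. RadGraph)
    baseline_reward: torch.Tensor,  # [B] reward of the greedy (baseline) report
    entropy_coef: float = 0.01,     # hypothetical weight of the entropy term
) -> torch.Tensor:
    log_probs = torch.log_softmax(sample_logits, dim=-1)
    # Log-probability of each sampled token.
    token_log_probs = log_probs.gather(-1, sample_ids.unsqueeze(-1)).squeeze(-1)
    mask = mask.to(log_probs.dtype)
    lengths = mask.sum(-1).clamp(min=1)
    # SCST: the advantage is the sampled reward minus the greedy baseline reward.
    advantage = (sample_reward - baseline_reward).detach()
    seq_log_probs = (token_log_probs * mask).sum(-1) / lengths
    pg_loss = -(advantage * seq_log_probs).mean()
    # Entropy of the full token distribution at each step, averaged over real tokens.
    token_entropy = -(log_probs.exp() * log_probs).sum(-1)
    entropy = ((token_entropy * mask).sum(-1) / lengths).mean()
    # Subtracting the entropy term rewards keeping the token distribution spread out.
    return pg_loss - entropy_coef * entropy
```

With uniform logits and equal sampled and baseline rewards, the policy-gradient term vanishes, and the loss reduces to `-entropy_coef * log(V)`, its maximum-entropy value.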
EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:
- Token type embeddings to differentiate between findings and impression section tokens, as well as image embeddings.
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings.
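The token type embeddings and the mixed attention masking listed above can be illustrated with a small sketch. It is a hypothetical illustration, not the model's code. The mask is boolean (`True` means the query position may attend to the key position) rather than the additive form many libraries use. It assumes image positions attend only to each other, and never to report tokens. It also assumes the hypothetical token type ids 0 (image), 1 (findings) and 2 (impression).

```python
import torch


def build_attention_mask(num_image_tokens: int, num_report_tokens: int) -> torch.Tensor:
    """Return an [L, L] boolean mask; True means query i may attend to key j."""
    n = num_image_tokens + num_report_tokens
    mask = torch.zeros(n, n, dtype=torch.bool)
    # Image embeddings are non-causal: every image position sees every image position.
    mask[:num_image_tokens, :num_image_tokens] = True
    # Report tokens see all image embeddings...
    mask[num_image_tokens:, :num_image_tokens] = True
    # ...but only the current and earlier report tokens (causal).
    mask[num_image_tokens:, num_image_tokens:] = torch.tril(
        torch.ones(num_report_tokens, num_report_tokens, dtype=torch.bool)
    )
    return mask


def build_token_type_ids(
    num_image_tokens: int, num_findings_tokens: int, num_impression_tokens: int
) -> torch.Tensor:
    # Hypothetical ids: 0 = image, 1 = findings, 2 = impression.
    return torch.cat(
        [
            torch.zeros(num_image_tokens, dtype=torch.long),
            torch.ones(num_findings_tokens, dtype=torch.long),
            torch.full((num_impression_tokens,), 2, dtype=torch.long),
        ]
    )
```

Each position would then receive the embedding for its token type, added to its image or token embedding, so the model can tell the three kinds of input apart.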
## Example:
```python
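import datasets
import torch
import transformers
from torch.utils.data import DataLoader
from torchvision.transforms import v2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mbatch_size = 2  # Mini-batch size; adjust to fit your memory.

tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)
model = model.to(device).eval()

# Image pre-processing matching the encoder's expected input.
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    # A study can contain several images: transform each one and stack them per study.
    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
    return batch

def collate_fn(samples):
    # Studies can differ in their number of images, so pad to the largest study in the mini-batch.
    images = [sample['images'] for sample in samples]
    return {'images': torch.nn.utils.rnn.pad_sequence(images, batch_first=True, padding_value=0.0)}

dataset = dataset.with_transform(transform_batch)
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True, collate_fn=collate_fn)
batch = next(iter(dataloader))
batch['images'] = batch['images'].to(device)

output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    # Ban [NF] and [NI] so that both a findings and an impression section are generated.
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)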
```
## Generate findings only:
```python
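output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    # Ban [NF] so that a findings section is generated.
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')]],
    # Stop at the separator token, i.e. once the findings section is complete.
    eos_token_id=tokenizer.sep_token_id,
)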
findings, _ = model.split_and_decode_sections(output_ids, tokenizer)
```
## Generate impression only:
```python
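output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    # Ban [NI] so that an impression section is generated.
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NI]')]],
    # Prompt every study with [BOS] [NF] [SEP] so that the findings section is skipped.
    input_ids=torch.tensor(
        [[tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids('[NF]'), tokenizer.sep_token_id]] * batch['images'].shape[0],
        device=batch['images'].device,
        dtype=torch.long,
    ),
)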
_, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
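The findings-only and impression-only recipes rely on a report layout of `[BOS] findings [SEP] impression [EOS]`. This layout is inferred from the impression-only prompt (`[BOS] [NF] [SEP]`) and is not documented here. Under that assumption, a hypothetical `split_sections` helper over word-level token strings could look like the sketch below. It is not the model's `split_and_decode_sections`. The `[BOS]`, `[SEP]` and `[EOS]` strings are placeholders, and the real tokenizer's strings and subword tokens may differ.

```python
def split_sections(
    tokens: list[str],
    bos: str = "[BOS]",  # placeholder special-token strings; the real tokenizer's may differ
    sep: str = "[SEP]",
    eos: str = "[EOS]",
    no_findings: str = "[NF]",
    no_impression: str = "[NI]",
) -> tuple[str, str]:
    """Split a generated token sequence into (findings, impression) strings."""
    # Drop the leading BOS token, if present.
    if tokens and tokens[0] == bos:
        tokens = tokens[1:]
    # Ignore everything from the first EOS token onwards.
    if eos in tokens:
        tokens = tokens[: tokens.index(eos)]
    # Findings precede the first SEP token; the impression follows it.
    if sep in tokens:
        i = tokens.index(sep)
        findings, impression = tokens[:i], tokens[i + 1 :]
    else:
        findings, impression = tokens, []
    # A lone placeholder token marks a missing section.
    if findings == [no_findings]:
        findings = []
    if impression == [no_impression]:
        impression = []
    return " ".join(findings), " ".join(impression)
```

For example, a findings-only generation that stops at the separator yields an empty impression. An impression-only generation prompted with `[NF]` yields empty findings.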
## Notebook example:
https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb
## Citation:
[More Information Needed]