---
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
---
# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation
This model is an evolution of CXRMate (https://huggingface.co/aehrc/cxrmate), developed for the Radiology Report Generation (RRG24) task of BioNLP @ ACL 2024.

For this task, we proposed Entropy-Augmented Self-critical sequence Training (EAST):

- EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation.
- The entropy term helps maintain a higher entropy in the token distribution, preventing overfitting to common phrases and ensuring a broader exploration of the vocabulary during training.
- This was essential for handling the diversity of the radiology reports in the RRG24 datasets (a minimal sketch of the objective follows this list).
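
To make the objective concrete, here is a minimal, illustrative PyTorch sketch of an entropy-augmented SCST loss. It is a simplification under stated assumptions, not the model's actual training code: `east_loss`, its signature, and the `entropy_weight` value are hypothetical.

```python
import torch
import torch.nn.functional as F

def east_loss(logits, sampled_ids, sample_reward, baseline_reward, entropy_weight=0.05):
    # Hypothetical sketch: logits [B, T, V] for a sampled report, sampled_ids
    # [B, T], and scalar rewards [B] (e.g., RadGraph scores) for the sampled
    # and greedy (baseline) reports.
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)

    # SCST: reinforce the tokens of the sampled report, weighted by how much
    # its reward exceeds that of the greedy baseline.
    advantage = (sample_reward - baseline_reward).unsqueeze(-1)  # [B, 1]
    scst_loss = -(advantage * token_log_probs).mean()

    # Entropy regularisation: subtracting the entropy term from the loss
    # encourages a higher-entropy token distribution.
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

    return scst_loss - entropy_weight * entropy
```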
EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:

- Token type embeddings to differentiate between findings section tokens, impression section tokens, and image embeddings.
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings (sketched after this list).
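
The masking scheme can be pictured with a small sketch. This is an illustrative reconstruction, not the model's implementation; `build_attention_mask` and the single image-then-report sequence layout are assumptions.

```python
import torch

def build_attention_mask(num_image_tokens: int, num_report_tokens: int) -> torch.Tensor:
    # Boolean mask over the full sequence: True = may attend, False = masked.
    n = num_image_tokens + num_report_tokens
    mask = torch.zeros(n, n, dtype=torch.bool)

    # Non-causal over the image embeddings: every image position attends to
    # every image position.
    mask[:num_image_tokens, :num_image_tokens] = True

    # Report tokens attend to all image positions...
    mask[num_image_tokens:, :num_image_tokens] = True

    # ...and causally to the report tokens (current and earlier positions only).
    causal = torch.tril(torch.ones(num_report_tokens, num_report_tokens, dtype=torch.bool))
    mask[num_image_tokens:, num_image_tokens:] = causal

    return mask
```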
## Example:
```python
import datasets
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import transformers

device = 'cuda' if torch.cuda.is_available() else 'cpu'
mbatch_size = 4  # Mini-batch size.

# Load the checkpoint and tokenizer:
tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)
model = model.to(device)
model.eval()

# Image transforms matching the encoder's expected input:
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    # Transform and stack the images of each study, then pad across studies so
    # that every study in the mini-batch has the same number of images:
    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
    batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
    return batch

dataset = dataset.with_transform(transform_batch)
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
batch = next(iter(dataloader))

# Generate both sections; [NF] and [NI] are banned so that the findings and
# impression sections are always generated:
output_ids = model.generate(
    pixel_values=batch['images'].to(device),
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
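
Assuming `split_and_decode_sections` returns one decoded string per study for each section, the generated reports can be inspected with:

```python
# Print the findings and impression sections for each study in the mini-batch:
for f, i in zip(findings, impression):
    print(f'Findings: {f}\nImpression: {i}\n')
```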
## Generate findings only:
```python
# Generate the findings section only. Setting the end-of-sequence token to the
# separator token ends generation once the findings section is complete:
output_ids = model.generate(
    pixel_values=batch['images'].to(device),
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')]],
    eos_token_id=tokenizer.sep_token_id,
)
findings, _ = model.split_and_decode_sections(output_ids, tokenizer)
```
## Generate impression only:
```python
# Generate the impression section only. The prompt for each study is
# [BOS][NF][SEP], indicating that the findings section is absent:
prompt_ids = torch.tensor(
    [[tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids('[NF]'), tokenizer.sep_token_id]] * mbatch_size,
    device=device,
    dtype=torch.long,
)
output_ids = model.generate(
    pixel_values=batch['images'].to(device),
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NI]')]],
    input_ids=prompt_ids,
)
_, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
## Notebook example:

https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb
## Paper:
## Citation:

[More Information Needed]