---
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
---
# CXRMate-RRG4: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation
This is an evolution of https://huggingface.co/aehrc/cxrmate, developed for the Radiology Report Generation (RRG24) task of BioNLP @ ACL 2024.
For this task, we proposed Entropy-Augmented Self-critical sequence Training (EAST):
- EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation (a sketch follows this list).
- It helps maintain a higher entropy in the token distribution.
- This prevents overfitting to common phrases and ensures broader exploration of the vocabulary during training.
- This was essential for handling the diversity of the radiology reports in the RRG24 datasets.
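As a rough sketch (not the actual training code), the EAST objective can be viewed as the SCST policy-gradient loss minus an entropy bonus. The function shape, tensor layout, and the coefficient `beta` below are assumptions for illustration:
```python
import torch

def east_loss(sample_log_probs, sample_reward, baseline_reward, token_entropy, beta=0.05):
    """Illustrative sketch of SCST with entropy regularisation (EAST).

    sample_log_probs: (batch, seq_len) log-probabilities of the sampled tokens.
    sample_reward:    (batch,) reward (e.g., RadGraph) for the sampled reports.
    baseline_reward:  (batch,) reward for the baseline (e.g., greedy) reports.
    token_entropy:    (batch, seq_len) entropy of the token distribution per step.
    beta:             hypothetical entropy coefficient.
    """
    # Standard SCST: reward advantage of the sample over the baseline.
    advantage = sample_reward - baseline_reward
    scst = -advantage * sample_log_probs.sum(dim=-1)
    # Entropy bonus keeps the token distribution from collapsing onto common phrases.
    entropy_bonus = token_entropy.sum(dim=-1)
    return (scst - beta * entropy_bonus).mean()
```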
EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:
- Token type embeddings to differentiate between findings and impression section tokens, as well as image embeddings.
- Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings (sketched below).
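A minimal sketch of this mixed masking scheme (illustrative only; the actual mask construction lives in the model's remote code):
```python
import torch

def build_attention_mask(num_image_tokens, num_report_tokens):
    # True = position may be attended to.
    n = num_image_tokens + num_report_tokens
    mask = torch.zeros(n, n, dtype=torch.bool)
    # Image embeddings attend to each other non-causally.
    mask[:num_image_tokens, :num_image_tokens] = True
    # Report tokens attend to all image embeddings...
    mask[num_image_tokens:, :num_image_tokens] = True
    # ...and causally to preceding report tokens (including themselves).
    causal = torch.tril(torch.ones(num_report_tokens, num_report_tokens, dtype=torch.bool))
    mask[num_image_tokens:, num_image_tokens:] = causal
    return mask
```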
## Example:
```python
import datasets
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import transformers

# Load the tokenizer and model (trust_remote_code is needed for the custom architecture):
tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)

# Image preprocessing:
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    # Transform each image of each study, then pad studies to the same number of images:
    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
    batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
    return batch

dataset = dataset.with_transform(transform_batch)

mbatch_size = 4  # Mini-batch size.
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
batch = next(iter(dataloader))

# Generate both sections; [NF] and [NI] are blocked so that neither section is omitted:
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
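Assuming `split_and_decode_sections` returns one decoded string per study for each section, the generated reports can then be inspected with:
```python
for i, (f, imp) in enumerate(zip(findings, impression)):
    print(f'Study {i} findings: {f}')
    print(f'Study {i} impression: {imp}')
```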
## Generate findings only:
```python
# Stop generation at the separator token so that only the findings section is produced:
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')]],
    eos_token_id=tokenizer.sep_token_id,
)
findings, _ = model.split_and_decode_sections(output_ids, tokenizer)
```
## Generate impression only:
```python
# Prompt with an empty findings section ([NF]) so that only the impression is generated:
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NI]')]],
    input_ids=torch.tensor(
        [[tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids('[NF]'), tokenizer.sep_token_id]] * mbatch_size,
        device=model.device,
        dtype=torch.long,
    ),
)
_, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
## Notebook example:
https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb
## Paper:
## Citation:
[More Information Needed]