---
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
library_name: transformers
license: apache-2.0
---

# CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation

This is an evolution of [CXRMate](https://huggingface.co/aehrc/cxrmate), developed for the [Radiology Report Generation](https://stanford-aimi.github.io/RRG24/) task of [BioNLP @ ACL 2024](https://aclweb.org/aclwiki/BioNLP_Workshop).

The leaderboard for the task can be found [here](https://vilmedic.app/misc/bionlp24/leaderboard).

For this, we proposed Entropy-Augmented Self-critical sequence Training (EAST):
 - EAST modifies [Self-Critical Sequence Training (SCST)](https://openaccess.thecvf.com/content_cvpr_2017/papers/Rennie_Self-Critical_Sequence_Training_CVPR_2017_paper.pdf) by adding entropy regularisation.
 - The entropy term helps maintain a higher entropy in the token distribution.
 - This prevents overfitting to common phrases and encourages broader exploration of the vocabulary during training.
 - This was essential for handling the diversity of the radiology reports in the RRG24 datasets; a minimal sketch of the objective is shown below.
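
The following is a minimal sketch of the EAST objective, not the training code used for this model; the function and argument names, and the entropy weight `beta`, are assumptions for illustration. It combines the usual SCST advantage (sampled reward minus greedy-baseline reward) with an entropy bonus on the sampled token distributions:

```python
import torch

def east_loss(sample_log_probs, sample_reward, baseline_reward, token_entropy, beta=0.05):
    """Hypothetical EAST objective (illustrative only).

    sample_log_probs: summed log-probabilities of each sampled report, shape [batch].
    sample_reward:    reward (e.g. RadGraph) of each sampled report, shape [batch].
    baseline_reward:  reward of the greedy (test-time) decode, shape [batch].
    token_entropy:    mean entropy of the token distributions per sample, shape [batch].
    beta:             entropy-regularisation weight (assumed hyperparameter).
    """
    advantage = (sample_reward - baseline_reward).detach()  # SCST advantage
    scst = -(advantage * sample_log_probs).mean()           # REINFORCE with greedy baseline
    return scst - beta * token_entropy.mean()               # entropy bonus keeps entropy high
```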

EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:
 - Token type embeddings to differentiate between findings section tokens, impression section tokens, and image embeddings.
 - Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
 - Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings; this masking scheme is sketched below.
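
As an illustration of this masking scheme only (the model's remote code implements its own version), the following hypothetical helper builds a combined self-attention mask in which image positions attend to all image positions, while report tokens attend to the image embeddings and to earlier report tokens:

```python
import torch

def build_attention_mask(num_image_tokens: int, num_report_tokens: int) -> torch.Tensor:
    """Illustrative mask: True means the row position may attend to the column position."""
    n = num_image_tokens + num_report_tokens
    mask = torch.zeros(n, n, dtype=torch.bool)
    # Every position attends to the image embeddings (non-causal over images).
    mask[:, :num_image_tokens] = True
    # Report tokens attend causally to themselves and to earlier report tokens.
    causal = torch.tril(torch.ones(num_report_tokens, num_report_tokens, dtype=torch.bool))
    mask[num_image_tokens:, num_image_tokens:] = causal
    return mask
```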

## Example:

```python
import datasets
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import transformers


tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
    batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)  
    return batch

dataset = dataset.with_transform(transform_batch)
mbatch_size = 4  # mini-batch size; any small value works for this example
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
batch = next(iter(dataloader))

output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],  # ban [NF] and [NI] so that both sections are generated
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
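
Assuming `split_and_decode_sections` returns the decoded findings and impression sections as lists of strings (one per study in the batch), the generated reports can be inspected with:

```python
for findings_i, impression_i in zip(findings, impression):
    print(f'Findings: {findings_i}')
    print(f'Impression: {impression_i}')
```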

## Generate findings only:

```python
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')]],
    eos_token_id=tokenizer.sep_token_id,  # stop at the separator that ends the findings section
)
findings, _ = model.split_and_decode_sections(output_ids, tokenizer)
```
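
Banning `[NF]` forces a non-empty findings section, while setting `eos_token_id` to the separator token ends generation before the impression section begins.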

## Generate impression only:

```python
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NI]')]],
    input_ids=torch.tensor(
        [[tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids('[NF]'), tokenizer.sep_token_id]]*mbatch_size,
        device=model.device,
        dtype=torch.long,
    ),  # prime the findings section as empty ([NF] + separator) so that only the impression is generated
)
_, impression = model.split_and_decode_sections(output_ids, tokenizer)
```

## Notebook example:
[demo.ipynb](https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb)

## Citation:

[More Information Needed]