---
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
---

# CXRMate-RRG4: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation

This is an evolution of https://huggingface.co/aehrc/cxrmate, developed for the Radiology Report Generation (RRG24) shared task at BioNLP @ ACL 2024.

For this task, we proposed Entropy-Augmented Self-critical sequence Training (EAST):
 - EAST modifies Self-Critical Sequence Training (SCST) by adding entropy regularisation.
 - The entropy term maintains a higher entropy in the token distribution, preventing overfitting to common phrases and ensuring broader exploration of the vocabulary during training; a minimal sketch follows this list.
 - This was essential for handling the diversity of the radiology reports in the RRG24 datasets.
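
At a high level, EAST augments the SCST objective with a per-step entropy bonus. The following is a minimal, illustrative sketch only; the function name, signature, and the `entropy_weight` value are assumptions for exposition, not this repository's training code:

```python
import torch
import torch.nn.functional as F

def east_loss(logits, sampled_ids, sample_reward, baseline_reward, mask, entropy_weight=0.05):
    # logits:          (batch, seq_len, vocab) for the sampled report
    # sampled_ids:     (batch, seq_len) token ids drawn from the model
    # sample_reward:   (batch,) reward (e.g. RadGraph) of the sampled report
    # baseline_reward: (batch,) reward of the greedy baseline report
    # mask:            (batch, seq_len) 1 for report tokens, 0 for padding
    log_probs = F.log_softmax(logits, dim=-1)
    sampled_log_probs = log_probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)

    # SCST: advantage-weighted negative log-likelihood of the sampled report.
    advantage = (sample_reward - baseline_reward).unsqueeze(-1)
    scst = -(advantage * sampled_log_probs * mask).sum() / mask.sum()

    # Entropy bonus: keeps the token distribution from collapsing onto
    # common phrases, encouraging broader exploration of the vocabulary.
    entropy = -(log_probs.exp() * log_probs).sum(-1)
    entropy = (entropy * mask).sum() / mask.sum()

    return scst - entropy_weight * entropy
```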

EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:
 - Token type embeddings to differentiate between findings and impression section tokens, as well as image embeddings.
 - Special tokens (`[NF]` and `[NI]`) to handle missing *findings* and *impression* sections.
 - Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings, illustrated by the sketch after this list.
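
The hybrid attention mask can be pictured as a prefix-LM style mask: image embeddings attend to each other bidirectionally, while report tokens attend causally. The helper below is a hypothetical illustration of that idea, not this repository's implementation:

```python
import torch

def mixed_attention_mask(num_image_tokens, num_text_tokens):
    # True at (i, j) means position i may attend to position j.
    total = num_image_tokens + num_text_tokens
    mask = torch.zeros(total, total, dtype=torch.bool)

    # All positions attend to the image embeddings (non-causal).
    mask[:, :num_image_tokens] = True

    # Report tokens attend only to themselves and earlier report tokens (causal).
    mask[num_image_tokens:, num_image_tokens:] = torch.tril(
        torch.ones(num_text_tokens, num_text_tokens, dtype=torch.bool)
    )
    return mask

print(mixed_attention_mask(2, 3).int())
```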

## Example:

```python
import datasets
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import transformers


tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']

def transform_batch(batch):
    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
    batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)  
    return batch

dataset = dataset.with_transform(transform_batch)
mbatch_size = 4  # mini-batch size for this example
dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
batch = next(iter(dataloader))

output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
)
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
```
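
`split_and_decode_sections` returns the decoded findings and impression sections; assuming it yields one string per study in the batch, the outputs can be inspected with:

```python
for f, i in zip(findings, impression):
    print(f'Findings: {f}\nImpression: {i}\n')
```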

## Generate findings only:

```python
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')]],
    eos_token_id=tokenizer.sep_token_id,  # stop at the section separator so only the findings are generated
)
findings, _ = model.split_and_decode_sections(output_ids, tokenizer)
```

## Generate impression only:

```python
output_ids = model.generate(
    pixel_values=batch['images'],
    max_length=512,
    num_beams=4,
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NI]')]],
    input_ids=torch.tensor(
        [[tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids('[NF]'), tokenizer.sep_token_id]] * mbatch_size,
        device=model.device,
        dtype=torch.long,
    ),  # prompt with an empty findings section ([NF]) so that only the impression is generated
)
_, impression = model.split_and_decode_sections(output_ids, tokenizer)
```

## Notebook example:
https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb

## Paper:

## Citation:

[More Information Needed]