Update README.md
README.md (changed)
@@ -21,7 +21,7 @@ EAST was applied to a multimodal language model with RadGraph as the reward. Oth
 - Special tokens (`NF` and `NI`) to handle missing *findings* and *impression* sections.
 - Non-causal attention masking for the image embeddings and a causal attention masking for the report token embeddings.
 
-##
+## Example:
 
 ```python
 import torch
@@ -42,14 +42,23 @@ transforms = v2.Compose(
     ]
 )
 
-
+dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']
+
+def transform_batch(batch):
+    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
+    batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
+    return batch
+
+dataset = dataset.with_transform(transform_batch)
+dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
+batch = next(iter(dataloader))
 
 output_ids = model.generate(
-    pixel_values=images,
+    pixel_values=batch['images'],
     max_length=512,
-    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
     num_beams=4,
     use_cache=True,
+    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
 )
 findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
 ```
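For context when reading the second hunk: the lines collapsed between the two hunks (the remaining imports and the `v2.Compose(...)` transform pipeline) mean the example assumes `datasets`, `DataLoader`, `transforms`, `model`, `tokenizer`, and `mbatch_size` are already defined. Below is a minimal sketch of what that setup could look like; the checkpoint name, image size, and normalisation constants are placeholders, not values taken from this diff.

```python
# Hypothetical setup for the collapsed portion of the README example.
# The checkpoint name, resize/crop size, and normalisation statistics are
# placeholders and depend on the actual checkpoint and image encoder.
import datasets
import torch
import transformers
from torch.utils.data import DataLoader
from torchvision.transforms import v2

ckpt_name = 'org/placeholder-checkpoint'  # placeholder; use the repository's checkpoint
model = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_name)

mbatch_size = 4  # mini-batch size for the DataLoader used in the example

# Per-image preprocessing applied inside transform_batch.
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=384, antialias=True),
        v2.CenterCrop(size=[384, 384]),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
```

With definitions along these lines in place, the added hunk lines (the `load_dataset` call, `transform_batch`, and the `DataLoader`) run as written, provided the checkpoint's `generate` and `split_and_decode_sections` behave as shown in the README.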
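On the second bullet (non-causal attention over the image embeddings, causal attention over the report tokens), the diff does not show how such a mask is constructed. As a rough illustration only, not the repository's implementation, a combined mask over the concatenation of image embeddings followed by report token embeddings could be built with the hypothetical helper below.

```python
import torch

def mixed_attention_mask(num_image_embeds: int, num_report_tokens: int) -> torch.Tensor:
    """Illustrative sketch: mask[i, j] == True means position i may attend to j.

    Image-embedding positions attend to every image-embedding position
    (non-causal block); report-token positions attend causally to all earlier
    positions, including the image embeddings.
    """
    total = num_image_embeds + num_report_tokens
    # Start from a standard causal (lower-triangular) mask over the full sequence.
    mask = torch.tril(torch.ones(total, total, dtype=torch.bool))
    # Lift the causal constraint inside the image block.
    mask[:num_image_embeds, :num_image_embeds] = True
    return mask

mask = mixed_attention_mask(num_image_embeds=3, num_report_tokens=4)
# Image rows see all image columns; report rows see the image columns plus
# previously generated report tokens only.
```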