File size: 6,774 Bytes
b00d5c3
 
51af647
 
 
 
 
 
30d9f56
b00d5c3
51af647
 
 
 
 
 
 
 
 
8fdae9d
 
d9d891d
 
8fdae9d
 
 
 
 
 
 
 
 
051e985
8fdae9d
 
 
 
4a02079
 
d9d891d
285e0fb
8fdae9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca57734
 
8fdae9d
 
 
 
 
 
 
285e0fb
8fdae9d
ca57734
8fdae9d
 
 
 
 
285e0fb
ca57734
 
 
 
 
 
 
 
 
8fdae9d
 
ca57734
 
 
 
 
 
 
8fdae9d
 
285e0fb
8fdae9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285e0fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fdae9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
---
license: apache-2.0
tags:
- medical
- radiology report generation
- medical chatbot
datasets:
- mimic-cxr
library_name: transformers
---

<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->

<div align="center">
<h1>
  RaDialog
</h1>
</div>

<p align="center">
📝 <a href="https://arxiv.org/abs/2311.18681" target="_blank">Paper</a> • 🖥️ <a href="https://github.com/ChantalMP/RaDialog" target="_blank">GitHub</a> • 🗂️ <a href="https://physionet.org/content/radialog-instruct-dataset/1.0.0/" target="_blank">Dataset</a>
 • 🌐️ <a href="https://chantalmp.github.io/RaDialog/" target="_blank">Project Page</a> </p>

<div align="center">
</div>

## Get Started

Clone the repository:
```bash
git clone https://huggingface.co/Chantal/RaDialog-interactive-radiology-report-generation
cd RaDialog-interactive-radiology-report-generation
```

Install the requirements:
```bash
conda create -n llava_hf python=3.10
conda activate llava_hf
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install -r requirements.txt
```

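Run RaDialog inference from the root of the cloned repository. The script imports the local `LLAVA_Biovil` package and `utils` module, and it requires a CUDA GPU: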
```python
from pathlib import Path

import io

import requests
import torch
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download

from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1

from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor


def load_model_from_huggingface(repo_id: str):
    """Download a RaDialog checkpoint from the Hugging Face Hub and load it.

    The repository snapshot for ``repo_id`` (revision ``"main"``) is fetched
    with ``snapshot_download``. It is then passed to LLaVA-BioViL's
    ``load_pretrained_model`` together with the ``liuhaotian/llava-v1.5-7b``
    base model.

    Args:
        repo_id: Hugging Face Hub repository id, e.g.
            ``"Chantal/RaDialog-interactive-radiology-report-generation"``.

    Returns:
        The 4-tuple ``(tokenizer, model, image_processor, context_len)``,
        exactly as returned by ``load_pretrained_model``.
    """
    # Download model files. force_download=True fetches the files again even
    # if they are already in the local Hugging Face cache, so every call pays
    # the full download cost.
    model_path = snapshot_download(repo_id=repo_id, revision="main", force_download=True)
    model_path = Path(model_path)

    # The checkpoint name contains "lora". Presumably load_pretrained_model
    # applies these adapter weights on top of model_base; confirm against
    # LLAVA_Biovil/llava/model/builder.py.
    # NOTE(review): the line below is truncated in this copy of the README.
    # It ends in "$" and the call is never closed. Restore the remaining
    # keyword arguments from the upstream README before running.
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
                                                                           model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, $


    return tokenizer, model, image_processor, context_len



if __name__ == '__main__':
    # End-to-end demo:
    #   1. fetch one sample chest X-ray;
    #   2. predict findings with the CheXpert classifier;
    #   3. ask RaDialog for a findings report conditioned on those findings;
    #   4. ask it to rewrite that report in patient-friendly language.
    # A CUDA device is required because the CheXpert input is moved with .cuda().
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"

    # NOTE(review): there is no timeout and no response.raise_for_status().
    # A failed request would likely only surface as an image-decoding error
    # in Image.open.
    response = requests.get(sample_img_path)
    image = Image.open(io.BytesIO(response.content))
    # remap_to_uint8 presumably rescales pixel values into the uint8 range
    # (per its name; confirm in mm_utils). The result is then converted to
    # single-channel ("L") grayscale.
    image = remap_to_uint8(np.array(image))
    image = Image.fromarray(image).convert("L")

    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
    # CheXpert classifier; its positive predictions are injected into the prompt below.
    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

    # Request left padding. This is set on model.config, not on the tokenizer object.
    model.config.tokenizer_padding_side = "left"

    # Multi-label findings prediction: a per-class sigmoid thresholded at 0.5.
    # The input gets a batch dimension ([None]) and is cast to fp16 on CUDA.
    cp_image = cp_transforms(image)
    logits = cp_model(cp_image[None].half().cuda())
    preds_probs = torch.sigmoid(logits)
    preds = preds_probs > 0.5
    pred = preds[0].cpu().numpy()
    # Boolean-mask indexing. This assumes cp_class_names is a numpy array
    # aligned with the logit columns; TODO confirm in utils.init_chexpert_predictor.
    findings = cp_class_names[pred].tolist()
    findings = ', '.join(findings).lower().strip()

    # Vicuna-v1 style conversation: one USER turn holding the image
    # placeholder, the predicted findings and the instruction, followed by an
    # empty ASSISTANT turn for the model to complete.
    conv = conv_vicuna_v1.copy()
    # NOTE(review): the prompt line below is truncated in this copy of the
    # README. It ends in "$" and the f-string is never closed. Restore the
    # full instruction text from the upstream README before running.
    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predi$
    print("USER: ", REPORT_GEN_PROMPT)
    conv.append_message("USER", REPORT_GEN_PROMPT)
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()

    # Preprocess the image for the BioViL vision encoder. The transform is
    # built with 512 and center_crop_size=448, and a batch dimension is added.
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)

    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
    # Tokenize the prompt. The <image> placeholder is presumably mapped to
    # IMAGE_TOKEN_INDEX by tokenizer_image_token.
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    # Stop on the conversation separator (sep2 for SeparatorStyle.TWO, otherwise sep).
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # generate a report (do_sample=False, at most 300 new tokens)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens and strip the "</s>" marker text.
    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)

    # add prediction to conversation: replace the empty ASSISTANT turn with the report
    conv.messages.pop()
    conv.append_message("ASSISTANT", pred)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # NOTE(review): this second generate call reuses input_ids from the first
    # prompt. text_input and input_ids are not rebuilt after the conversation
    # update above, so with do_sample=False it presumably regenerates the same
    # report. It looks like leftover code; consider removing it.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)

    # add prediction to conversation, then ask for a patient-friendly rewrite
    conv.messages.pop()
    conv.append_message("ASSISTANT", pred)
    conv.append_message("USER", "Translate this report to easy language for a patient to understand.")
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()
    print("USER: ", "Translate this report to easy language for a patient to understand.")

    # generate easy language report
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    # NOTE(review): stopping_criteria was constructed with the previous
    # prompt's input_ids, not the ones just built. KeywordsStoppingCriteria
    # takes input_ids at construction, so consider rebuilding it here.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)
```

## ✏️ Citation

```bibtex
@article{pellegrini2023radialog,
  title={RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance},
  author={Pellegrini, Chantal and {\"O}zsoy, Ege and Busam, Benjamin and Navab, Nassir and Keicher, Matthias},
  journal={arXiv preprint arXiv:2311.18681},
  year={2023}
}
```