File size: 6,774 Bytes
b00d5c3 51af647 30d9f56 b00d5c3 51af647 8fdae9d d9d891d 8fdae9d 051e985 8fdae9d 4a02079 d9d891d 285e0fb 8fdae9d ca57734 8fdae9d 285e0fb 8fdae9d ca57734 8fdae9d 285e0fb ca57734 8fdae9d ca57734 8fdae9d 285e0fb 8fdae9d 285e0fb 8fdae9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
---
license: apache-2.0
tags:
- medical
- radiology report generation
- medical chatbot
datasets:
- mimic-cxr
library_name: transformers
---
<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->
<div align="center">
<h1>
RaDialog
</h1>
</div>
<p align="center">
📝 <a href="https://arxiv.org/abs/2311.18681" target="_blank">Paper</a> • 🖥️ <a href="https://github.com/ChantalMP/RaDialog" target="_blank">GitHub</a> • 🗂️ <a href="https://physionet.org/content/radialog-instruct-dataset/1.0.0/" target="_blank">Dataset</a>
• 🌐️ <a href="https://chantalmp.github.io/RaDialog/" target="_blank">Project Page</a> </p>
<div align="center">
</div>
## Get Started
Clone the repository:
```bash
git clone https://huggingface.co/Chantal/RaDialog-interactive-radiology-report-generation
cd RaDialog-interactive-radiology-report-generation
```
Install the requirements:
```bash
conda create -n llava_hf python=3.10
conda activate llava_hf
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install -r requirements.txt
```
Run RaDialog inference (the example below requires a CUDA-capable GPU):
```python
from pathlib import Path
import io
import requests
import torch
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download
from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor
def load_model_from_huggingface(repo_id):
    """Download the RaDialog checkpoint from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id passed to ``snapshot_download``.

    Returns:
        The tuple ``(tokenizer, model, image_processor, context_len)`` returned by
        ``load_pretrained_model``.
    """
    # Download model files
    # NOTE(review): force_download=True re-downloads every file on each call even when a
    # cached copy already exists; consider dropping it so the local Hub cache is reused.
    model_path = snapshot_download(repo_id=repo_id, revision="main", force_download=True)
    model_path = Path(model_path)
    # Load the checkpoint named below, passing 'liuhaotian/llava-v1.5-7b' as model_base;
    # 8-bit loading is disabled.
    # NOTE(review): the continuation line below is truncated in this copy (trailing `$`, call
    # never closed); restore the remaining arguments and ")" from the upstream README.
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
        model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, $
    return tokenizer, model, image_processor, context_len
if __name__ == '__main__':
    # Fetch a sample chest X-ray from the Open-i collection and decode it with PIL.
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"
    # NOTE(review): no timeout and no HTTP status check; an error response would only fail
    # later, inside Image.open.
    response = requests.get(sample_img_path)
    image = Image.open(io.BytesIO(response.content))
    # remap_to_uint8 presumably rescales the pixel array to the uint8 range (confirm in
    # mm_utils); the result is converted to a single-channel 8-bit grayscale ("L") image.
    image = remap_to_uint8(np.array(image))
    image = Image.fromarray(image).convert("L")
    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
    # Auxiliary CheXpert classifier; its predicted findings are inserted into the prompt.
    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()
    model.config.tokenizer_padding_side = "left"
    # Multi-label classification: independent sigmoid per class, thresholded at 0.5.
    # The input gets a batch dimension, is cast to fp16 and moved to CUDA (GPU required).
    cp_image = cp_transforms(image)
    logits = cp_model(cp_image[None].half().cuda())
    preds_probs = torch.sigmoid(logits)
    preds = preds_probs > 0.5
    pred = preds[0].cpu().numpy()
    # Boolean-mask indexing; assumes cp_class_names is a NumPy array (it exposes .tolist()).
    findings = cp_class_names[pred].tolist()
    # Positive findings become one lower-case, comma-separated string for the prompt.
    findings = ', '.join(findings).lower().strip()
    # Start from a copy of the conv_vicuna_v1 template: a USER turn with the prompt, then an
    # empty ASSISTANT turn that the model completes.
    conv = conv_vicuna_v1.copy()
    # NOTE(review): the next line is truncated in this copy (trailing `$`, unterminated
    # f-string); restore the full prompt text from the upstream README before running.
    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predi$
    print("USER: ", REPORT_GEN_PROMPT)
    conv.append_message("USER", REPORT_GEN_PROMPT)
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()
    # get the image
    # Image transform for the vision encoder, built with size 512 and center_crop_size=448
    # (presumably resize then center crop; confirm in utils).
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    # Add a batch dimension and move to the model's device in bfloat16.
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)
    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
    # Tokenize the prompt; IMAGE_TOKEN_INDEX is passed so the <image> placeholder is
    # presumably mapped to that special index.
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    # Stop generating at the template's separator (sep2 for SeparatorStyle.TWO templates).
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    # generate a report
    # Greedy decoding (do_sample=False), capped at 300 new tokens, without autograd tracking.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )
    # Decode only the newly generated tokens (skip the prompt prefix) and drop the "</s>" marker.
    # Note: this rebinds `pred`, which previously held the CheXpert boolean predictions.
    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)
# add prediction to conversation
conv.messages.pop()
conv.append_message("ASSISTANT", pred)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
# generate a report
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=False,
use_cache=True,
max_new_tokens=300,
stopping_criteria=[stopping_criteria],
pad_token_id=tokenizer.pad_token_id
)
pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
print("ASSISTANT: ", pred)
# add prediction to conversation
conv.messages.pop()
conv.append_message("ASSISTANT", pred)
conv.append_message("USER", "Translate this report to easy language for a patient to understand.")
conv.append_message("ASSISTANT", None)
text_input = conv.get_prompt()
print("USER: ", "Translate this report to easy language for a patient to understand.")
# generate easy language report
input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=False,
use_cache=True,
max_new_tokens=300,
stopping_criteria=[stopping_criteria],
pad_token_id=tokenizer.pad_token_id
)
pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
print("ASSISTANT: ", pred)
```
## ✏️ Citation
```bibtex
@article{pellegrini2023radialog,
title={RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance},
author={Pellegrini, Chantal and {\"O}zsoy, Ege and Busam, Benjamin and Navab, Nassir and Keicher, Matthias},
journal={arXiv preprint arXiv:2311.18681},
year={2023}
}
```
|