|
--- |
|
license: apache-2.0 |
|
inference: false |
|
tags: |
|
- medical |
|
- radiology report generation |
|
- medical chatbot |
|
datasets: |
|
- mimic-cxr |
|
library_name: transformers |
|
--- |
|
|
|
<!-- markdownlint-disable first-line-h1 --> |
|
<!-- markdownlint-disable html --> |
|
|
|
<div align="center"> |
|
<h1> |
|
RaDialog |
|
</h1> |
|
</div> |
|
|
|
<p align="center"> |
|
📝 <a href="https://arxiv.org/abs/2311.18681" target="_blank">Paper</a> • 🖥️ <a href="https://github.com/ChantalMP/RaDialog_LLaVA" target="_blank">GitHub</a> • 🗂️ <a href="https://physionet.org/content/radialog-instruct-dataset/1.0.0/" target="_blank">Dataset</a> |
|
• 🌐️ <a href="https://chantalmp.github.io/RaDialog/" target="_blank">Project Page</a> </p> |
|
|
|
<div align="center"> |
|
</div> |
|
|
|
## Get Started |
|
|
|
Clone repository: |
|
```bash |
|
git clone https://huggingface.co/ChantalPellegrini/RaDialog-interactive-radiology-report-generation |
|
``` |
|
|
|
Install requirements: |
|
```bash |
|
conda create -n llava_hf python=3.10 |
|
conda activate llava_hf |
|
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia |
|
pip install -r requirements.txt |
|
``` |
|
|
|
Run RaDialog inference: |
|
```python |
|
from pathlib import Path |
|
|
|
import io |
|
|
|
import requests |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
from huggingface_hub import snapshot_download |
|
|
|
from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8 |
|
from LLAVA_Biovil.llava.model.builder import load_pretrained_model |
|
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1 |
|
|
|
from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX |
|
from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor |
|
|
|
|
|
# Defaults reproduce the checkpoint this example was written for.
_DEFAULT_MODEL_BASE = 'liuhaotian/llava-v1.5-7b'
_DEFAULT_MODEL_NAME = "llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000"


def load_model_from_huggingface(repo_id, *, model_base=_DEFAULT_MODEL_BASE, model_name=_DEFAULT_MODEL_NAME,
                                revision="main", force_download=True):
    """Download a model repository from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id (``"<user>/<repo>"``) holding the model files.
        model_base: Base model passed to ``load_pretrained_model`` as ``model_base``.
        model_name: Name passed to ``load_pretrained_model`` as ``model_name``.
        revision: Repository revision (branch, tag or commit) to download.
        force_download: Forwarded to ``snapshot_download``. The default ``True``
            keeps the original behaviour of re-fetching the files on every call;
            pass ``False`` to reuse files already present in the local cache.

    Returns:
        The ``(tokenizer, model, image_processor, context_len)`` tuple returned
        by ``load_pretrained_model``.
    """
    # Download model files; snapshot_download returns the local snapshot folder.
    model_path = Path(snapshot_download(repo_id=repo_id, revision=revision, force_download=force_download))

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        model_base=model_base,
        model_name=model_name,
        load_8bit=False,
        load_4bit=False,
    )

    return tokenizer, model, image_processor, context_len
|
|
|
|
|
|
|
def _generate_reply(model, tokenizer, conv, image_tensor, max_new_tokens=300):
    """Generate the assistant reply for the current prompt of ``conv``.

    The conversation is expected to end with an open ASSISTANT turn (a message
    whose content is ``None``). The conversation itself is not modified; the
    caller decides how to store the returned text.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        conv: Conversation object providing ``get_prompt``, ``sep``, ``sep2``
            and ``sep_style``.
        image_tensor: Pre-processed image batch passed to ``generate`` as ``images``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply with surrounding whitespace and ``</s>`` removed.
    """
    text_input = conv.get_prompt()
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # The criteria object is constructed from this prompt's input_ids, so it is
    # rebuilt for every prompt instead of being reused across turns.
    # NOTE(review): presumably it records the prompt length - confirm in
    # LLAVA_Biovil.llava.mm_utils.
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the tokens that follow the prompt.
    return tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")


if __name__ == '__main__':
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"

    # Fail early with a clear HTTP error instead of handing an error page to PIL.
    response = requests.get(sample_img_path, timeout=60)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image = remap_to_uint8(np.array(image))
    image = Image.fromarray(image).convert("L")

    # Same repository as the one cloned in the "Get Started" instructions.
    tokenizer, model, _image_processor, _context_len = load_model_from_huggingface(repo_id="ChantalPellegrini/RaDialog-interactive-radiology-report-generation")
    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

    model.config.tokenizer_padding_side = "left"

    # Predict findings with the CheXpert classifier; no gradients are needed.
    cp_image = cp_transforms(image)
    with torch.inference_mode():
        logits = cp_model(cp_image[None].half().cuda())
        preds_probs = torch.sigmoid(logits)
    finding_mask = (preds_probs > 0.5)[0].cpu().numpy()
    findings = cp_class_names[finding_mask].tolist()
    findings = ', '.join(findings).lower().strip()

    # First turn: ask for the findings section of the report.
    conv = conv_vicuna_v1.copy()
    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
    print("USER: ", REPORT_GEN_PROMPT)
    conv.append_message("USER", REPORT_GEN_PROMPT)
    conv.append_message("ASSISTANT", None)

    # get the image
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)
    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)

    # generate a report
    report = _generate_reply(model, tokenizer, conv, image_tensor)
    print("ASSISTANT: ", report)

    # add prediction to conversation: replace the open ASSISTANT turn with the report
    conv.messages.pop()
    conv.append_message("ASSISTANT", report)

    # Second turn: ask for a patient-friendly version of the report.
    follow_up = "Translate this report to easy language for a patient to understand."
    conv.append_message("USER", follow_up)
    conv.append_message("ASSISTANT", None)
    print("USER: ", follow_up)

    # generate easy language report
    easy_report = _generate_reply(model, tokenizer, conv, image_tensor)
    print("ASSISTANT: ", easy_report)
|
|
|
``` |
|
|
|
## ✏️ Citation |
|
|
|
``` |
|
@article{pellegrini2023radialog, |
|
title={RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance}, |
|
author={Pellegrini, Chantal and {\"O}zsoy, Ege and Busam, Benjamin and Navab, Nassir and Keicher, Matthias}, |
|
journal={arXiv preprint arXiv:2311.18681}, |
|
year={2023} |
|
} |
|
``` |
|
|