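# RaDialog demo: predict chest X-ray findings with a CheXpert classifier, generate the
# findings section of a radiology report with LLaVA-BioViL, and rewrite it for patients.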
import io
from pathlib import Path

import numpy as np
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image

from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1

from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor


def load_model_from_huggingface(repo_id):
    # Download model files
    model_path = snapshot_download(repo_id=repo_id, revision="main", force_download=True)
    model_path = Path(model_path)

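    # load the LoRA fine-tuned RaDialog weights on top of the llava-v1.5-7b base model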
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        model_base='liuhaotian/llava-v1.5-7b',
        model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000",
        load_8bit=False,
        load_4bit=False,
    )

    return tokenizer, model, image_processor, context_len


if __name__ == '__main__':
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"

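    # download a sample chest X-ray from Open-i and convert it to an 8-bit grayscale PIL image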
    response = requests.get(sample_img_path)
    image = Image.open(io.BytesIO(response.content))
    image = remap_to_uint8(np.array(image))
    image = Image.fromarray(image).convert("L")

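    # load the RaDialog model from the Hugging Face Hub and the CheXpert classifier used to predict findings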
    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

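    # decoder-only LLMs are typically padded on the left for generation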
    model.config.tokenizer_padding_side = "left"

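    # run the CheXpert classifier and keep findings whose sigmoid probability exceeds 0.5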
    cp_image = cp_transforms(image)
    logits = cp_model(cp_image[None].half().cuda())
    preds_probs = torch.sigmoid(logits)
    preds = preds_probs > 0.5
    pred = preds[0].cpu().numpy()
    findings = cp_class_names[pred].tolist()
    findings = ', '.join(findings).lower().strip()

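    # build a Vicuna-v1 conversation containing the image token and the predicted findings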
    conv = conv_vicuna_v1.copy()
    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
    print("USER: ", REPORT_GEN_PROMPT)
    conv.append_message("USER", REPORT_GEN_PROMPT)
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()

    # apply the BioViL inference transform (resize to 512, center-crop to 448) and add a batch dimension
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)

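    # move the image to the model device and tokenize the prompt, inserting the image token index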
    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

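    # stop generation at the conversation separator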
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # generate the findings section of the report with greedy decoding
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

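    # decode only the newly generated tokens and strip the end-of-sequence marker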
    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)

    # replace the placeholder assistant turn with the generated report
    conv.messages.pop()
    conv.append_message("ASSISTANT", pred)
    conv.append_message("USER", "Translate this report to easy language for a patient to understand.")
    conv.append_message("ASSISTANT", None)
    text_input = conv.get_prompt()
    print("USER: ", "Translate this report to easy language for a patient to understand.")

    # generate the easy-language report; refresh the stopping criteria for the new prompt length
    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=300,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id
        )

    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
    print("ASSISTANT: ", pred)