File size: 3,736 Bytes
6cae404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449b506
6cae404
 
 
 
 
 
1352457
6cae404
 
 
 
1352457
5aa6565
6cae404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
987ff34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os

os.system('pip install -q -e .')
os.system('pip uninstall bitsandbytes')
os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
print(torch.cuda.is_available()) 

print(os.system('python -m bitsandbytes'))

import gradio as gr
import io
from contextlib import redirect_stdout
import openai
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaMistralForCausalLM
from llava.eval.run_llava import eval_model

# LLaVa-Med model setup
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
kwargs = {"device_map": "auto"}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4'
)
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

def query_gpt(api_key, llava_med_result, user_question, model="gpt-4o"):
    """
    Queries GPT to generate a detailed and medically accurate response.
    """
    openai.api_key = api_key  # Set API key dynamically
    prompt = f"""
    You are an AI Medical Assistant specializing in radiology, trained to analyze radiology scan findings (e.g., MRI, CT, X-ray) and provide clear, medically accurate explanations. 
    Based on the scan analysis {llava_med_result} and the question {user_question}, provide a concise summary of the radiology findings. Use clear, professional language, and if uncertain, 
    recommend consulting a licensed radiologist or healthcare provider.
    """
    response = openai.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=256
    )
    return response.choices[0].message.content

with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    with gr.Column(scale=1):
        gr.Markdown("<center><h1>LLaVa-Med</h1></center>")

        with gr.Row():
            api_key_input = gr.Textbox(
                placeholder="Enter OpenAI API Key", 
                label="API Key", 
                type="password",
                scale=3
            )

        with gr.Row():
            image = gr.Image(type="filepath", scale=2)
            question = gr.Textbox(placeholder="Enter a question", label="Question", scale=3)

        with gr.Row():
            answer = gr.Textbox(placeholder="Answer pops up here", label="Answer", scale=1)

        def run_inference(api_key, image, question):
            # Arguments for the model
            args = type('Args', (), {
                "model_path": model_path,
                "model_base": None,
                "image_file": image,
                "query": question,
                "conv_mode": None,
                "sep": ",",
                "temperature": 0,
                "top_p": None,
                "num_beams": 1,
                "max_new_tokens": 512
            })()

            # Capture the printed output of eval_model
            f = io.StringIO()
            with redirect_stdout(f):
                eval_model(args)
            llava_med_result = f.getvalue()
            print(llava_med_result)

            # Generate more descriptive answer with GPT
            descriptive_answer = query_gpt(api_key, llava_med_result, question)

            return descriptive_answer

        with gr.Row():
            btn = gr.Button("Run Inference", scale=1)

        btn.click(fn=run_inference, inputs=[api_key_input, image, question], outputs=answer)

app.launch(debug=True, height=800, width="100%")