Veda0718 committed on
Commit
6cae404
·
verified ·
1 Parent(s): f2f1a96

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Install the project in editable mode, then swap the PyPI bitsandbytes
# for the prebuilt wheel bundled with this Space.
os.system('pip install -q -e .')
# BUG FIX: without -y, `pip uninstall` prompts for confirmation and aborts
# in a non-interactive environment, so the old bitsandbytes was never removed.
os.system('pip uninstall -y bitsandbytes')
os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')

# Must be set before torch is imported: lets the CUDA caching allocator grow
# segments instead of failing on fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
print(torch.cuda.is_available())

# Sanity check that the bitsandbytes CUDA setup is functional (prints its
# self-diagnostic; the printed value is the shell exit status).
print(os.system('python -m bitsandbytes'))
13
+
14
+ import gradio as gr
15
+ import io
16
+ from contextlib import redirect_stdout
17
+ import openai
18
+ import torch
19
+ from transformers import AutoTokenizer, BitsAndBytesConfig
20
+ from llava.model import LlavaMistralForCausalLM
21
+ from llava.eval.run_llava import eval_model
22
+
23
# LLaVa-Med model setup: load the fine-tuned checkpoint 4-bit quantized (NF4,
# fp16 compute, double quantization) with automatic device placement.
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
kwargs = {
    "device_map": "auto",
    "load_in_4bit": True,
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    ),
}
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
# Slow tokenizer: the fast one is not compatible with this checkpoint's vocab files.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
35
+
36
def query_gpt(api_key, llava_med_result, user_question, model="gpt-3.5-turbo"):
    """
    Ask an OpenAI chat model to turn the raw LLaVa-Med output into a
    detailed, medically framed answer.

    Args:
        api_key: OpenAI API key supplied by the user at request time.
        llava_med_result: Raw text captured from the LLaVa-Med model.
        user_question: The user's original question about the scan.
        model: OpenAI chat model name (default "gpt-3.5-turbo").

    Returns:
        The assistant message content of the chat completion.
    """
    # BUG FIX: use a per-call client instead of mutating the global
    # openai.api_key — Gradio serves concurrent requests, and a module-level
    # key races between users who supplied different keys.
    client = openai.OpenAI(api_key=api_key)
    prompt = f"""
    You are an AI Medical Assistant specializing in radiology, trained to analyze radiology scan findings (e.g., MRI, CT, X-ray) and provide clear, medically accurate explanations.
    Based on the scan analysis {llava_med_result} and the question {user_question}, provide a concise summary of the radiology findings, explain their clinical significance in relation to
    the question, and offer relevant recommendations such as follow-up imaging, specialist consultations, or further tests. Use clear, professional language, and if uncertain,
    recommend consulting a licensed radiologist or healthcare provider.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    return response.choices[0].message.content
52
+
53
# Gradio UI: API-key field, image + question inputs, answer output, and a
# button that runs LLaVa-Med inference followed by GPT refinement.
with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    with gr.Column(scale=1):
        gr.Markdown("<center><h1>LLaVa-Med</h1></center>")

    with gr.Row():
        api_key_input = gr.Textbox(
            placeholder="Enter OpenAI API Key",
            label="API Key",
            type="password",
            scale=3
        )

    with gr.Row():
        image = gr.Image(type="filepath", scale=2)
        question = gr.Textbox(placeholder="Enter a question", label="Question", scale=3)

    with gr.Row():
        answer = gr.Textbox(placeholder="Answer pops up here", label="Answer", scale=1)

    def run_inference(api_key, image, question):
        """Run LLaVa-Med on the uploaded scan, then refine the answer with GPT.

        Args:
            api_key: OpenAI API key from the password textbox.
            image: Filepath of the uploaded image (None when nothing uploaded).
            question: Free-text question about the scan.

        Returns:
            The GPT-refined answer, or a validation message if input is missing.
        """
        # ROBUSTNESS FIX: guard against missing inputs so the UI shows a
        # helpful message instead of a stack trace from eval_model / OpenAI.
        if not api_key:
            return "Please enter your OpenAI API key."
        if image is None:
            return "Please upload an image before running inference."
        if not question:
            return "Please enter a question about the scan."

        # Arguments for the model (eval_model expects an object with these attrs).
        args = type('Args', (), {
            "model_path": model_path,
            "model_base": None,
            "image_file": image,
            "query": question,
            "conv_mode": None,
            "sep": ",",
            "temperature": 0,
            "top_p": None,
            "num_beams": 1,
            "max_new_tokens": 512
        })()

        # eval_model prints its answer instead of returning it, so capture stdout.
        f = io.StringIO()
        with redirect_stdout(f):
            eval_model(args)
        llava_med_result = f.getvalue()
        print(llava_med_result)

        # Generate a more descriptive answer with GPT.
        descriptive_answer = query_gpt(api_key, llava_med_result, question)

        return descriptive_answer

    with gr.Row():
        btn = gr.Button("Run Inference", scale=1)

    btn.click(fn=run_inference, inputs=[api_key_input, image, question], outputs=answer)

app.launch(share=True, debug=True, height=800, width="100%")