RioJune committed
Commit 826447b
1 Parent(s): 69b60ec
app.py ADDED
@@ -0,0 +1,313 @@
+ import streamlit as st
+ from PIL import Image
+ import torch
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ import numpy as np
+ import supervision as sv
+ import albumentations as A
+ import cv2
+ from transformers import AutoConfig
+ import yaml
+
+ # Set Streamlit page configuration for a wide layout
+ st.set_page_config(layout="wide")
+
+ # Custom CSS for better layout and mobile responsiveness
+ st.markdown("""
+     <style>
+     .main {
+         max-width: 1200px; /* Max width for content */
+         margin: 0 auto;
+     }
+     .block-container {
+         padding-top: 2rem;
+         padding-bottom: 2rem;
+         padding-left: 3rem;
+         padding-right: 3rem;
+     }
+     .title {
+         font-size: 2.5rem;
+         text-align: center;
+         color: #FF6347;
+     }
+     .subheader {
+         font-size: 1.5rem;
+         margin-bottom: 20px;
+     }
+     .btn {
+         font-size: 1.1rem;
+         padding: 10px 20px;
+         background-color: #FF6347;
+         color: white;
+         border-radius: 5px;
+         border: none;
+         cursor: pointer;
+     }
+     .btn:hover {
+         background-color: #FF4500;
+     }
+     .column-spacing {
+         display: flex;
+         justify-content: space-between;
+     }
+     .col-half {
+         width: 48%;
+     }
+     .col-full {
+         width: 100%;
+     }
+     .instructions {
+         padding: 20px;
+         background-color: #f9f9f9;
+         border-radius: 8px;
+         box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
+     }
+     </style>
+ """, unsafe_allow_html=True)
+
+ # Load Model and Processor
+ @st.cache_resource
+ def load_model():
+     REVISION = 'refs/pr/6'
+     MODEL_NAME = "RioJune/AD-KD-MICCAI25"
+     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     config_model = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
+     config_model.vision_config.model_type = "davit"
+
+     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True, config=config_model).to(DEVICE)
+
+     BASE_PROCESSOR = "microsoft/Florence-2-base-ft"
+     processor = AutoProcessor.from_pretrained(BASE_PROCESSOR, trust_remote_code=True)
+     processor.image_processor.size = 512
+     processor.image_processor.crop_size = 512
+
+     return model, processor, DEVICE
+
+ model, processor, DEVICE = load_model()
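(Note: the base Florence-2 config is loaded with vision_config.model_type forced to "davit", and the processor's resize/crop size is set to 512, presumably to match the 512x512 inputs produced by apply_transform() below.)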
+
+ # Load Definitions
+ @st.cache_resource
+ def load_definitions():
+     vindr_path = 'configs/vindr_definition.yaml'
+     padchest_path = 'configs/padchest_definition.yaml'
+     prompt_path = 'examples/prompt.yaml'
+
+     with open(vindr_path, 'r') as file:
+         vindr_definitions = yaml.safe_load(file)
+     with open(padchest_path, 'r') as file:
+         padchest_definitions = yaml.safe_load(file)
+     with open(prompt_path, 'r') as file:
+         prompt_definitions = yaml.safe_load(file)
+
+     return vindr_definitions, padchest_definitions, prompt_definitions
+
+ vindr_definitions, padchest_definitions, prompt_definitions = load_definitions()
+
+ dataset_options = {"Vindr": vindr_definitions, "PadChest": padchest_definitions}
+
+ def load_example_images():
+     return list(prompt_definitions.keys())
+
+ example_images = load_example_images()
+
+ def apply_transform(image, size_mode=512):
+     pad_resize_transform = A.Compose([
+         A.LongestMaxSize(max_size=size_mode, interpolation=cv2.INTER_AREA),
+         A.PadIfNeeded(min_height=size_mode, min_width=size_mode, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
+         A.Resize(height=512, width=512, interpolation=cv2.INTER_AREA),
+     ])
+     image_np = np.array(image)
+     transformed = pad_resize_transform(image=image_np)
+     return transformed["image"]
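(Note: apply_transform() letterboxes an input of any aspect ratio: the longest side is scaled to size_mode, the image is zero-padded to a square, and the result is resized to 512x512.)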
+
+ # Streamlit UI with Colorful Title and Emojis
+ st.markdown("<h1 class='title'>🩺 Enhancing Abnormality Grounding for Vision Language Models with Knowledge Descriptions 🚀</h1>", unsafe_allow_html=True)
+ st.markdown(
+     "<p style='text-align: center; font-size: 18px;'>Welcome to a simple demo of our work! 🎉 Choose an example or upload your own image to get started! 👇</p>",
+     unsafe_allow_html=True
+ )
+
+ # Display Example Images First
+ st.subheader("🌄 Example Images")
+ selected_example = st.selectbox("Choose an example", example_images)
+ image = Image.open(selected_example).convert("RGB")
+ example_diseases = prompt_definitions.get(selected_example, [])
+ st.write("**Associated Diseases:**", ", ".join(example_diseases))
+
+ # Layout for Original Image and Instructions
+ col1, col2 = st.columns([1, 2])
+
+ # Left column for original image
+ with col1:
+     st.image(image, caption=f"Original Example Image: {selected_example}", width=400)
+
+ # Right column for Instructions and Run Inference Button
+ with col2:
+     st.subheader("⚙️ Instructions to Get Started:")
+     st.write("""
+     - **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
+     - **Choose an Example**: 🌄 Select an example image from the dataset to view its associated diseases.
+     - **Upload Your Own Image**: 📤 Upload an image of your choice to analyze it for diseases.
+     - **Select Dataset**: 📚 Choose between available datasets (Vindr or PadChest) for disease information.
+     - **Select Disease**: 🦠 Pick the disease to be analyzed from the list of diseases in the selected dataset.
+     """)
+
+     st.subheader("⚠️ Warning:")
+     st.write("""
+     - **🚫 Please avoid uploading non-frontal chest X-ray images**. Our model has been specifically trained on **frontal chest X-ray images**.
+     - This demo is intended for **🔬 research purposes only** and should **❌ not be used for medical diagnoses**.
+     - The model’s responses may contain **🤖 hallucinations or incorrect information**. Always consult a **👨‍⚕️ medical professional** for accurate diagnosis and advice.
+     """)
+
+ st.markdown("</div>", unsafe_allow_html=True)
+
+ # Run Inference Button
+ if st.button("Run Inference on Example", key="example"):
+     if image is None:
+         st.error("❌ Please select an example image first.")
+     else:
+         # Use the selected example's disease and definition for inference
+         disease_choice = example_diseases[0] if example_diseases else ""
+         definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
+
+         # Generate the prompt for the model
+         det_obj = f"{disease_choice} means {definition}."
+         st.write(f"**Definition:** {definition}")
+         prompt = f"Locate the phrases in the caption: {det_obj}."
+         prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{prompt}"
+
+         # Prepare the image and input
+         np_image = np.array(image)
+         inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
+
+         with st.spinner("Processing... ⏳"):
+             # Generate the result
+             generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
+             generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+
+             predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=(np_image.shape[1], np_image.shape[0]))  # (width, height)
+
+             detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=(np_image.shape[1], np_image.shape[0]))
+
+             # Annotate the image with bounding boxes and labels
+             bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
+             label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+             image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
+             image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
+             annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))
+
+             # Display the original and result images side by side
+             col1, col2 = st.columns([1, 1])
+
+             with col1:
+                 st.image(image, caption=f"Original Image: {selected_example}", width=400)
+
+             with col2:
+                 st.image(annotated_image, caption="Inference Results 🖼️", width=400)
+
+             # Display the generated text
+             st.write("**Generated Text:**", generated_text)
+
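(Illustration, not part of the diff: with the cardiomegaly example image, the prompt assembled above is roughly "<CAPTION_TO_PHRASE_GROUNDING>Locate the phrases in the caption: cardiomegaly means Enlargement of the heart seen when the heart appears larger than normal." The decoded text then carries location tokens, e.g. <loc_123>, which post_process_generation() converts into the boxes drawn on the image.)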
+ # Upload Image section
+ st.subheader("📤 Upload Your Own Image")
+
+ col1, col2 = st.columns([1, 1])
+ with col1:
+     dataset_choice = st.selectbox("Select Dataset 📚", options=list(dataset_options.keys()))
+     disease_options = list(dataset_options[dataset_choice].keys())
+ with col2:
+     disease_choice = st.selectbox("Select Disease 🦠", options=disease_options)
+
+ uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
+
+ # if uploaded_file:
+ #     image = Image.open(uploaded_file).convert("RGB")
+ #     image = apply_transform(image)  # Ensure the uploaded image is transformed correctly
+ #     st.image(image, caption="Uploaded Image", width=400)
+
+ #     # Let user select dataset and disease dynamically
+ #     disease_choice = disease_choice if disease_choice else example_diseases[0]
+
+ #     # Get Definition Priority: Dataset -> User Input
+ #     definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
+ #     if not definition:
+ #         definition = st.text_input("Enter Definition Manually 📝", value="")
+
+ # Fallback: make sure `definition` exists even when nothing has been uploaded yet,
+ # so the "Run Inference" button below does not hit an undefined variable.
+ definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
+
+ col1, col2 = st.columns([1, 2])
+
+ with col1:
+     # Handle file upload
+     if uploaded_file:
+         image = Image.open(uploaded_file).convert("RGB")
+         image = apply_transform(image)  # Ensure the uploaded image is transformed correctly
+         st.image(image, caption="Uploaded Image", width=400)
+
+         # Let user select dataset and disease dynamically
+         disease_choice = disease_choice if disease_choice else example_diseases[0]
+
+         # Get Definition Priority: Dataset -> User Input
+         definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
+         if not definition:
+             definition = st.text_input("Enter Definition Manually 📝", value="")
+
+ with col2:
+     # Instructions and warnings
+     st.subheader("⚙️ Instructions to Get Started:")
+     st.write("""
+     - **Run Inference**: Click the "Run Inference" button to process the image and display the results.
+     - **Choose an Example**: 🌄 Select an example image from the dataset to view its associated diseases.
+     - **Upload Your Own Image**: 📤 Upload an image of your choice to analyze it for diseases.
+     - **Select Dataset**: 📚 Choose between available datasets (Vindr or PadChest) for disease information.
+     - **Select Disease**: 🦠 Pick the disease to be analyzed from the list of diseases in the selected dataset.
+     """)
+
+     st.subheader("⚠️ Warning:")
+     st.write("""
+     - **🚫 Please avoid uploading non-frontal chest X-ray images**. Our model has been specifically trained on **frontal chest X-ray images**.
+     - This demo is intended for **🔬 research purposes only** and should **❌ not be used for medical diagnoses**.
+     - The model’s responses may contain **🤖 hallucinations or incorrect information**. Always consult a **👨‍⚕️ medical professional** for accurate diagnosis and advice.
+     """)
+
+ # Run inference after upload
+ if st.button("Run Inference 🏃‍♂️"):
+     if image is None:
+         st.error("❌ Please upload an image or select an example.")
+     else:
+         det_obj = f"{disease_choice} means {definition}."
+         st.write(f"**Definition:** {definition}")
+
+         # Construct Prompt with Disease Definition
+         prompt = f"Locate the phrases in the caption: {det_obj}."
+         prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{prompt}"
+
+         np_image = np.array(image)
+         inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
+
+         with st.spinner("Processing... ⏳"):
+             generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
+             generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+
+             predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=(np_image.shape[1], np_image.shape[0]))  # (width, height)
+
+             detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=(np_image.shape[1], np_image.shape[0]))
+
+             bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
+             label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+             image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
+             image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
+             annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))
+
+             # Create two columns to display the original and the results side by side
+             col1, col2 = st.columns([1, 1])
+
+             # Left column for original image
+             with col1:
+                 st.image(image, caption="Uploaded Image", width=400)
+
+             # Right column for result image
+             with col2:
+                 st.image(annotated_image, caption="Inference Results 🖼️", width=400)
+
+             # Display the generated text
+             st.write("**Generated Text:**", generated_text)
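A minimal sketch of the same inference path outside Streamlit, distilled from app.py above (model name, processor, prompt format, example image path, and the cardiomegaly definition all come from files in this commit); it is untested here, so treat it as illustrative rather than authoritative:

import numpy as np
import torch
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same model loading as load_model() in app.py
config = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
config.vision_config.model_type = "davit"
model = AutoModelForCausalLM.from_pretrained("RioJune/AD-KD-MICCAI25", trust_remote_code=True, config=config).to(DEVICE)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
processor.image_processor.size = 512
processor.image_processor.crop_size = 512

# Cardiomegaly example image and definition shipped with this commit
image = np.array(Image.open("./examples/fb4dfacc089f4b5550f03f52e706b6f2.png").convert("RGB"))
definition = "Enlargement of the heart seen when the heart appears larger than normal."
prompt = f"<CAPTION_TO_PHRASE_GROUNDING>Locate the phrases in the caption: cardiomegaly means {definition}."

inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(DEVICE)
generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# image_size is (width, height)
print(processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=(image.shape[1], image.shape[0])))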
configs/experiment.yaml ADDED
@@ -0,0 +1,36 @@
+ # Experiment 1 Configuration
+
+ model:
+   model_type: "microsoft/Florence-2-base-ft"
+   lora_config: "configs/lora_config.yaml"
+   init_checkpoint: "checkpoints/mimic_model_init.pt"
+   processor:
+     image_size: 512
+     crop_size: 512
+   peft:
+     use_peft: False
+     lora_checkpoint: None
+   finetune: true # true
+
+ trainer:
+   checkpoint_dir: "../outputs"
+   project_name: "Knowledge-AG" # change to your own wandb project name
+   entity_name: "compai" # change to your own wandb entity name
+   max_epochs: 50
+   train_batch_size: 16
+   valid_batch_size: 16
+   num_workers: 28
+   log_every_n_steps: 100
+   gpu: 0
+   ddp: true
+   optimizer: "adamw"
+   learning_rate: 3e-6 #5e-6
+   weight_decay: 0.01
+
+ dataset:
+   vindr:
+     img_root: "/vol/ciamspace/datasets/X-ray/vindr-cxr/processed/images_512/"
+     annotation_csv: "/u/home/lj0/Code/AG-KD-miccai25/annotations/vindr_dataset.csv"
+     data_pct: 1.0
+
+
configs/padchest_definition.yaml ADDED
@@ -0,0 +1,24 @@
+ pleural thickening: "Increased thickness of the pleura seen as a dense layer around the lung."
+ atelectasis: "Collapsed lung tissue causing darkened or shrunken areas in the lung."
+ pleural effusion: "Excess fluid in the pleural space appearing as a shadow around the lungs."
+ cardiomegaly: "Enlargement of the heart seen when the heart appears larger than normal."
+ aortic elongation: "Lengthened and tortuous aorta, visible as an elongated curving structure."
+ vertebral degenerative changes: "Irregular vertebral margins with bony sclerosis and osteophytes."
+ aortic atheromatosis: "Calcified deposits in the aortic wall appearing as bright, irregular opacities."
+ nodule: "A growth or lump in the lung which may appear as a well-defined or irregular shape."
+ alveolar pattern: "Cloud-like, patchy opacities representing fluid or cellular accumulation in alveoli."
+ hiatal hernia: "A soft-tissue mass or air-fluid level above the diaphragm, near the midline."
+ scoliosis: "Sideways curvature of the spine causing misalignment of vertebral bodies."
+ hemidiaphragm elevation: "One side of the diaphragm appearing higher than the other, with convex shape."
+ hyperinflated lung: "Abnormally increased lung volume with expanded air spaces."
+ interstitial pattern: "Fine reticular or nodular opacities spread across the lung, indicating interstitial involvement."
+ fracture: "A break in the bone appearing as a radiolucent line or displacement."
+ vascular hilar enlargement: "Increased prominence of the pulmonary vessels near the lung hila."
+ nsg tube: "A thin radiopaque tube extending from the nasal cavity into the stomach."
+ endotracheal tube: "A thin or opaque line in the middle of the trachea."
+ hypoexpansion: "Reduced lung inflation with increased density and narrow intercostal spaces."
+ central venous catheter: "A visible line inside large vein."
+ electrical device: "A dense, well-defined metallic opacity, typically a pacemaker or defibrillator."
+ bronchiectasis: "Dilated bronchi with thick walls, appearing as tubular or cystic opacities."
+ goiter: "A soft tissue mass in the anterior neck, sometimes displacing the trachea."
+ other entities: "An unusual mass or area in the lung with irregular borders or density."
configs/vindr_definition.yaml ADDED
@@ -0,0 +1,22 @@
+ lung opacity: "An area of increased density in the lung fields typically appearing as a white or grayish patch."
+ infiltration: "Accumulation of substances or cells in the lung tissue visible as increased density or nodules."
+ consolidation: "Lung tissue filled with fluid or cells causing dense solid areas on imaging."
+ nodule or mass: "A growth or lump in the lung which may appear as a well-defined or irregular shape."
+ pleural thickening: "Increased thickness of the pleura seen as a dense layer around the lung."
+ aortic enlargement: "Widening of the aorta visible as an enlarged artery on imaging."
+ pulmonary fibrosis: "Scarring of the lung tissue creating a dense fibrous appearance."
+ ild: "Scarring or inflammation of the lung’s interstitial tissue creating a reticular or nodular pattern."
+ cardiomegaly: "Enlargement of the heart seen when the heart appears larger than normal."
+ other lesion: "An unusual mass or area in the lung with irregular borders or density."
+ pleural effusion: "Excess fluid in the pleural space appearing as a shadow around the lungs."
+ calcification: "Calcium deposits in lung tissue visible as bright white spots."
+ enlarged pa: "Widening of the pulmonary artery seen as an enlarged artery in the chest."
+ lung cavity: "Air-filled spaces within the lung often surrounded by dense tissue."
+ atelectasis: "Collapsed lung tissue causing darkened or shrunken areas in the lung."
+ mediastinal shift: "Displacement of central chest structures like the heart to one side."
+ lung cyst: "Fluid-filled spaces in the lung often round with thin walls."
+ pneumothorax: "Air trapped in the pleural space creating a gap or absence of lung tissue."
+ emphysema: "Enlarged air spaces in the lungs appearing over-expanded or damaged."
+ clavicle fracture: "A break in the collarbone seen as a gap or irregularity in the bone."
+ rib fracture: "A break in one or more ribs appearing as a visible crack or displacement."
+ edema: "Fluid accumulation in the lungs creating a hazy or clouded area."
examples/26746130963764173994750391023442607773-2_mukhp1.png ADDED
examples/f1eb2216d773ced6330b1f31e18f04f8.png ADDED
examples/fb4dfacc089f4b5550f03f52e706b6f2.png ADDED
examples/prompt.yaml ADDED
@@ -0,0 +1,8 @@
+ ./examples/26746130963764173994750391023442607773-2_mukhp1.png:
+ - electrical device
+
+ ./examples/f1eb2216d773ced6330b1f31e18f04f8.png:
+ - pulmonary fibrosis
+
+ ./examples/fb4dfacc089f4b5550f03f52e706b6f2.png:
+ - cardiomegaly
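(Note: each top-level key is the path of a bundled example image; the list beneath it names the finding(s) the demo grounds for that image, consumed by load_example_images() in app.py.)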
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ streamlit
+ torch
+ transformers
+ pillow
+ numpy
+ supervision
+ albumentations
+ opencv-python
+ pyyaml
+ einops
+ timm
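(To try the demo locally, assuming the RioJune/AD-KD-MICCAI25 weights are downloadable from the Hub: install the dependencies with `pip install -r requirements.txt`, then start the app with `streamlit run app.py`.)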