ChayanDeb committed
Commit fe2122a · verified · 1 Parent(s): a0055c5

Update Chest_Xray_Report_Generator-Web-V2.py

Files changed (1):
  1. Chest_Xray_Report_Generator-Web-V2.py  +577 -537

Chest_Xray_Report_Generator-Web-V2.py  CHANGED
import os
import transformers
from transformers import pipeline

### Gradio
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from typing import Union, Iterable
import time
#####


import cv2
import numpy as np
import pydicom
import re

##### Libraries For Grad-CAM View
import os
import cv2
import numpy as np
import torch
from functools import partial
from torchvision import transforms
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit
from transformers import VisionEncoderDecoderModel


from transformers import AutoTokenizer
import transformers
import torch

from openai import OpenAI
client = OpenAI()

# (If re-enabled, the block below would additionally need imports such as
#  `from huggingface_hub import login, Repository` and `import datetime`.)

# === SET THESE ===
# REPO_ID = "ChayanDeb/AURA-CXR_Feedback"
# LOCAL_DIR = "feedback_repo"
# FEEDBACK_FILE = os.path.join(LOCAL_DIR, "feedback.txt")

# # Only once at app start
# def setup_repo():
#     token = os.getenv("HF_TOKEN")
#     if not token:
#         raise ValueError("HF_TOKEN not set in environment variables")

#     login(token=token)

#     if not os.path.exists(LOCAL_DIR):
#         print("Cloning feedback repo...")
#         Repository(local_dir=LOCAL_DIR, clone_from=REPO_ID, use_auth_token=token)
#     else:
#         print("Repo already exists locally.")

# setup_repo()

# # Call this on user feedback submission
# def save_feedback(feedback):
#     try:
#         now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
#         entry = f"{now} | {feedback.strip()}\n"

#         # Append to local file
#         with open(FEEDBACK_FILE, "a") as f:
#             f.write(entry)

#         # Push to hub
#         repo = Repository(local_dir=LOCAL_DIR, clone_from=REPO_ID, use_auth_token=os.getenv("HF_TOKEN"))
#         repo.push_to_hub(commit_message="Add new feedback")

#         return "Feedback submitted!"
#     except Exception as e:
#         print("Error:", e)
#         return "Failed to submit feedback."

import spaces  # Import the spaces module for ZeroGPU

@spaces.GPU
def generate_gradcam(image_path, model_path, output_path, method='gradcam', use_cuda=True, aug_smooth=False, eigen_smooth=False):
    methods = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
        "layercam": LayerCAM,
        "fullgrad": FullGrad
    }

    if method not in methods:
        raise ValueError(f"Method should be one of {list(methods.keys())}")

    model = VisionEncoderDecoderModel.from_pretrained(model_path)
    model.encoder.eval()

    if use_cuda and torch.cuda.is_available():
        model.encoder = model.encoder.cuda()
    else:
        use_cuda = False

    #target_layers = [model.blocks[-1].norm1]  ## For ViT model
    #target_layers = model.blocks[-1].norm1  ## For EfficientNet-B7 model
    #target_layers = [model.encoder.encoder.layer[-1].layernorm_before]  ## For ViT-based VisionEncoderDecoder model
    ## For a Swin-based VisionEncoderDecoder model: hook the post-attention LayerNorms
    ## of the first and last blocks in the encoder's final stage.
    target_layers = [model.encoder.encoder.layers[-1].blocks[0].layernorm_after,
                     model.encoder.encoder.layers[-1].blocks[-1].layernorm_after]

    if method == "ablationcam":
        cam = methods[method](model=model.encoder,
                              target_layers=target_layers,
                              use_cuda=use_cuda,
                              reshape_transform=reshape_transform,
                              ablation_layer=AblationLayerVit())
    else:
        cam = methods[method](model=model.encoder,
                              target_layers=target_layers,
                              use_cuda=use_cuda,
                              reshape_transform=reshape_transform)

    rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]  # BGR -> RGB
    rgb_img = cv2.resize(rgb_img, (384, 384))  ## (224, 224) for ViT-based models
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    targets = None
    cam.batch_size = 16

    grayscale_cam = cam(input_tensor=input_tensor, targets=targets, eigen_smooth=eigen_smooth, aug_smooth=aug_smooth)
    grayscale_cam = grayscale_cam[0, :]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)
    output_file = os.path.join(output_path, 'gradcam_result.png')
    cv2.imwrite(output_file, cam_image)
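
# Minimal usage sketch (hypothetical image path; model_path/output_path are defined below):
#   generate_gradcam("./Test-Images/example_cxr.png", "./Model/", "./CAM-Result/", method="gradcam")
# writes ./CAM-Result/gradcam_result.png, a heat-map overlay on the resized input.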


def reshape_transform(tensor, height=12, width=12):  ### height=14, width=14 for ViT-based Model
    batch_size, token_number, embed_dim = tensor.size()
    if token_number < height * width:
        pad = torch.zeros(batch_size, height * width - token_number, embed_dim, device=tensor.device)
        tensor = torch.cat([tensor, pad], dim=1)
    elif token_number > height * width:
        tensor = tensor[:, :height * width, :]

    result = tensor.reshape(batch_size, height, width, embed_dim)
    result = result.transpose(2, 3).transpose(1, 2)
    return result
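
# For illustration (assumed Swin encoder output shape): a (1, 144, 1024) token tensor
# (12x12 patches) is reshaped to (1, 12, 12, 1024) and transposed into the
# (1, 1024, 12, 12) channels-first feature map that pytorch_grad_cam expects:
#   reshape_transform(torch.rand(1, 144, 1024)).shape  # -> torch.Size([1, 1024, 12, 12])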


# Example usage:
#image_path = "/home/chayan/CGI_Net/images/images/CXR1353_IM-0230-1001.png"
model_path = "./Model/"
output_path = "./CAM-Result/"


def sentence_case(paragraph):
    sentences = paragraph.split('. ')
    formatted_sentences = [sentence.capitalize() for sentence in sentences if sentence]
    formatted_paragraph = '. '.join(formatted_sentences)
    return formatted_paragraph
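
# e.g. sentence_case("the heart is normal. no pleural effusion.")
#   -> "The heart is normal. No pleural effusion."
# Note: str.capitalize() also lowercases the rest of each sentence, so an acronym
# such as "AP" would come out as "Ap".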

def num2sym_bullets(text, bullet='-'):
    """
    Replaces '<num>.' bullet points with a specified symbol and formats the text as a bullet list.

    Args:
        text (str): Input text containing '<num>.' bullet points.
        bullet (str): The symbol to replace '<num>.' with.

    Returns:
        str: Modified text with '<num>.' replaced and formatted as a bullet list.
    """
    sentences = re.split(r'<num>\.\s', text)
    formatted_text = '\n'.join(f'{bullet} {sentence.strip()}' for sentence in sentences if sentence.strip())
    return formatted_text
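
# e.g. num2sym_bullets("<num>. Heart size is normal. <num>. No pleural effusion.")
#   -> "- Heart size is normal.\n- No pleural effusion."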

def is_cxr(image_path):
    """
    Checks if the uploaded image is a Chest X-ray using basic image processing.

    Args:
        image_path (str): Path to the uploaded image.

    Returns:
        bool: True if the image is likely a Chest X-ray, False otherwise.
    """
    try:
        image = cv2.imread(image_path)

        if image is None:
            raise ValueError("Invalid image path.")

        # Grayscale heuristic: for a grayscale image saved as 3-channel BGR, the
        # per-pixel standard deviation across the channels is 0 everywhere.
        color_std = np.std(image, axis=2).mean()

        if color_std > 0:
            return False

        return True

    except Exception as e:
        print(f"Error processing image: {e}")
        return False
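
# This is a coarse screen rather than a classifier: any strictly grayscale upload
# passes, while a chest X-ray saved with even slight colour noise would be rejected.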

def dicom_to_png(dicom_file, png_file):
    # Load DICOM file
    dicom_data = pydicom.dcmread(dicom_file)
    dicom_data.PhotometricInterpretation = 'MONOCHROME1'

    # Normalize pixel values to 0-255
    img = dicom_data.pixel_array
    img = img.astype(np.float32)

    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
    img = img.astype(np.uint8)

    # Save as PNG
    cv2.imwrite(png_file, img)
    return img
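
# Sketch (hypothetical file names):
#   dicom_to_png("study.dcm", "./CAM-Result/DCM2PNG.png")
# returns the min-max-normalized 8-bit array and writes the PNG to the given path.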


Image_Captioner = pipeline("image-to-text", model="./Model/", device=0)

data_dir = "./CAM-Result"

@spaces.GPU(duration=300)
def xray_report_generator(Image_file, Query):
    if Image_file.lower().endswith('.dcm'):
        # Convert the DICOM to a PNG inside data_dir, then caption that PNG.
        png_file = os.path.join(data_dir, 'DCM2PNG.png')
        dicom_to_png(Image_file, png_file)
        Image_file = png_file

    output = Image_Captioner(Image_file, max_new_tokens=512)

    result = output[0]['generated_text']
    output_paragraph = sentence_case(result)

    final_response = num2sym_bullets(output_paragraph, bullet='-')

    query_prompt = """You are analyzing the doctor's query based on the patient's history and the generated chest X-ray report. Extract only the information relevant to the query.
    If the report mentions the queried condition, write only the exact wording without any introduction. If the condition is not mentioned, respond with: 'No relevant findings related to [query condition].'.
    """

    #If the condition is negated, respond with: 'There is no [query condition].'.

    completion = client.chat.completions.create(
        model="gpt-4-turbo",  ### gpt-4-turbo ### gpt-3.5-turbo-0125
        messages=[
            {"role": "system", "content": query_prompt},
            {"role": "user", "content": f"Generated Report: {final_response}\nHistory/Doctor's Query: {Query}"}
        ],
        temperature=0.2)
    query_response = completion.choices[0].message.content

    generate_gradcam(Image_file, model_path, output_path, method='gradcam', use_cuda=True)

    grad_cam_image = os.path.join(output_path, 'gradcam_result.png')

    return grad_cam_image, final_response, query_response
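
# End-to-end flow, in brief: caption the X-ray with the local VisionEncoderDecoder
# pipeline, post-process the text into bullets, answer the doctor's query via the
# OpenAI chat model, and render a Grad-CAM overlay for the same image.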


# def save_feedback(feedback):
#     feedback_dir = "Chayan/Feedback/"  # Update this to your desired directory
#     if not os.path.exists(feedback_dir):
#         os.makedirs(feedback_dir)
#     feedback_file = os.path.join(feedback_dir, "feedback.txt")
#     with open(feedback_file, "a") as f:
#         f.write(feedback + "\n")
#     return "Feedback submitted successfully!"


def save_feedback(feedback):
    feedback_dir = "Chayan/Feedback/"  # Update this to your desired directory
    os.makedirs(feedback_dir, exist_ok=True)
    feedback_file = os.path.join(feedback_dir, "feedback.txt")

    try:
        with open(feedback_file, "a") as f:
            f.write(feedback + "\n")
        print(f"Feedback saved at: {feedback_file}")
        return "Feedback submitted successfully!"
    except Exception as e:
        print(f"Error saving feedback: {e}")
        return "Failed to submit feedback!"


# Custom Theme Definition
class Seafoam(Base):
    def __init__(
        self,
        *,
        primary_hue: Union[colors.Color, str] = colors.emerald,
        secondary_hue: Union[colors.Color, str] = colors.blue,
        neutral_hue: Union[colors.Color, str] = colors.gray,
        spacing_size: Union[sizes.Size, str] = sizes.spacing_md,
        radius_size: Union[sizes.Size, str] = sizes.radius_md,
        text_size: Union[sizes.Size, str] = sizes.text_lg,
        font: Union[fonts.Font, str, Iterable[Union[fonts.Font, str]]] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: Union[fonts.Font, str, Iterable[Union[fonts.Font, str]]] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        self.set(
            body_background_fill="linear-gradient(114.2deg, rgba(184,215,21,1) -15.3%, rgba(21,215,98,1) 14.5%, rgba(21,215,182,1) 38.7%, rgba(129,189,240,1) 58.8%, rgba(219,108,205,1) 77.3%, rgba(240,129,129,1) 88.5%)"
        )

# Initialize the theme
seafoam = Seafoam()



# Custom CSS styles (gr.Blocks(css=...) expects plain CSS, without <style> tags)
custom_css = """
/* Set background color for the entire Gradio app */
body, .gradio-container {
    background-color: #f2f7f5 !important;
}

/* Optional: Add padding or margin for aesthetics */
.gradio-container {
    padding: 20px;
}

#title {
    color: green;
    font-size: 36px;
    font-weight: bold;
}
#description {
    color: green;
    font-size: 22px;
}

#title-row {
    display: flex;
    align-items: center;
    gap: 10px;
    margin-bottom: 0px;
}
#title-header h1 {
    margin: 0;
}


#submit-btn {
    background-color: #f5dec6; /* Banana leaf */
    color: green;
    padding: 15px 32px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 30px;
    margin: 4px 2px;
    cursor: pointer;
}
#submit-btn:hover {
    background-color: #00FFFF;
}


.intext textarea {
    color: green;
    font-size: 20px;
    font-weight: bold;
}


.small-button {
    color: green;
    padding: 5px 10px;
    font-size: 20px;
}
"""

# Sample image paths
sample_images = [
    "./Test-Images/0d930f0a-46f813a9-db3b137b-05142eef-eca3c5a7.jpg",
    "./Test-Images/93681764-ec39480e-0518b12c-199850c2-f15118ab.jpg",
    "./Test-Images/6ff741e9-6ea01eef-1bf10153-d1b6beba-590b6620.jpg"
    #"sample4.png",
    #"sample5.png"
]

def set_input_image(image_path):
    return gr.update(value=image_path)

def show_contact_info():
    yield gr.update(visible=True, value="""
    **Contact Us:**
    - Chayan Mondal
    - Email: [email protected]
    - Associate Prof. Sonny Pham
    - Email: [email protected]
    - Dr. Ashu Gupta
    - Email: [email protected]
    """)
    # Wait for 20 seconds (you can adjust the time as needed)
    time.sleep(20)
    # Hide the content after 20 seconds
    yield gr.update(visible=False)

def show_acknowledgment():
    yield gr.update(visible=True, value="""
    **Acknowledgment:**
    This research has been supported by the Western Australian Future Health Research and Innovation Fund.
    """)
    # Wait for 20 seconds
    time.sleep(20)
    # Hide the acknowledgment
    yield gr.update(visible=False)
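
# Each handler is a generator: the first yield shows the Markdown, and after the
# 20 s sleep the second yield hides it again, giving a self-dismissing notice.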


with gr.Blocks(theme=seafoam, css=custom_css) as demo:

    #gr.HTML(custom_css)  # Inject custom CSS

    with gr.Row(elem_id="title-row"):
        with gr.Column(scale=0):
            gr.Image(
                value="./AURA-CXR-Logo.png",
                show_label=False,
                width=60,
                container=False
            )
        with gr.Column():
            gr.Markdown(
                """
                <h1 style="color:blue; font-size: 32px; font-weight: bold; margin: 0;">
                AURA-CXR: Explainable Diagnosis of Chest Diseases from X-rays
                </h1>
                """,
                elem_id="title-header"
            )

    gr.Markdown(
        "<p id='description'>Upload an X-ray image and get its report with heat-map visualization.</p>"
    )

    # gr.Markdown(
    #     """
    #     <h1 style="color:blue; font-size: 36px; font-weight: bold; margin: 0;">AURA-CXR: Explainable Diagnosis of Chest Diseases from X-rays</h1>
    #     <p id="description">Upload an X-ray image and get its report with heat-map visualization.</p>
    #     """
    # )

    #<h1 style="color:blue; font-size: 36px; font-weight: bold">AURA-CXR: Explainable Diagnosis of Chest Diseases from X-rays</h1>

    with gr.Row():
        inputs = gr.File(label="Upload Chest X-ray Image File", type="filepath")

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            outputs1 = gr.Image(label="Image Viewer")
            history_query = gr.Textbox(label="History/Doctor's Query", elem_classes="intext")
        with gr.Column(scale=1, min_width=300):
            outputs2 = gr.Image(label="Grad-CAM Visualization")
        with gr.Column(scale=1, min_width=300):
            outputs3 = gr.Textbox(label="Generated Report", elem_classes="intext")
            outputs4 = gr.Textbox(label="Query's Response", elem_classes="intext")

    submit_btn = gr.Button("Generate Report", elem_id="submit-btn", variant="primary")

    def show_image(file_path):
        if is_cxr(file_path):  # Check if it's a valid Chest X-ray
            return file_path, "Valid Image"  # Show the image in the Image Viewer
        else:
            return None, "Invalid image. Please upload a proper Chest X-ray."

    # Show the uploaded image immediately in the Image Viewer
    inputs.change(
        fn=show_image,  # Echoes the validated file path back to the viewer
        inputs=inputs,
        outputs=[outputs1, outputs3]
    )

    submit_btn.click(
        fn=xray_report_generator,
        inputs=[inputs, history_query],
        outputs=[outputs2, outputs3, outputs4])
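
    # One click drives three outputs: the Grad-CAM overlay (outputs2), the bullet
    # report (outputs3), and the query answer (outputs4), matching the function's
    # three return values.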

    gr.Markdown(
        """
        <h2 style="color:green; font-size: 24px;">Or choose a sample image:</h2>
        """
    )

    with gr.Row():
        for idx, sample_image in enumerate(sample_images):
            with gr.Column(scale=1):
                #sample_image_component = gr.Image(value=sample_image, interactive=False)
                select_button = gr.Button(f"Select Sample Image {idx+1}")
                select_button.click(
                    fn=set_input_image,
                    inputs=gr.State(value=sample_image),
                    outputs=inputs
                )

    # Feedback section
    gr.Markdown(
        """
        <h2 style="color:green; font-size: 24px;">Provide Your Valuable Feedback:</h2>
        """
    )

    with gr.Row():
        feedback_input = gr.Textbox(label="Your Feedback", lines=4, placeholder="Enter your feedback here...")
        feedback_submit_btn = gr.Button("Submit Feedback", elem_classes="small-button", variant="secondary")
        feedback_output = gr.Textbox(label="Feedback Status", interactive=False)

    feedback_submit_btn.click(
        fn=save_feedback,
        inputs=feedback_input,
        outputs=feedback_output
    )

    # Buttons and Markdown for Contact Us and Acknowledgment
    with gr.Row():
        contact_btn = gr.Button("Contact Us", elem_classes="small-button", variant="secondary")
        ack_btn = gr.Button("Acknowledgment", elem_classes="small-button", variant="secondary")

    contact_info = gr.Markdown(visible=False)  # Initially hidden
    acknowledgment_info = gr.Markdown(visible=False)  # Initially hidden

    # Update the content and make it visible when the buttons are clicked
    contact_btn.click(fn=show_contact_info, outputs=contact_info, show_progress=False)
    ack_btn.click(fn=show_acknowledgment, outputs=acknowledgment_info, show_progress=False)


demo.launch(share=True)