|
import os
|
|
import transformers
|
|
from transformers import pipeline
|
|
|
|
|
|
import gradio as gr
|
|
from gradio.themes.base import Base
|
|
from gradio.themes.utils import colors, fonts, sizes
|
|
from typing import Union, Iterable
|
|
import time
|
|
|
|
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import pydicom
|
|
import re
|
|
|
|
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from functools import partial
|
|
from torchvision import transforms
|
|
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad
|
|
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
|
|
from pytorch_grad_cam.ablation_layer import AblationLayerVit
|
|
from transformers import VisionEncoderDecoderModel
|
|
|
|
|
|
from transformers import AutoTokenizer
|
|
import transformers
|
|
import torch
|
|
|
|
from openai import OpenAI
|
|
# Shared OpenAI client; xray_report_generator uses it for the chat completion
# that answers the doctor's query against the generated report.
# NOTE(review): no credentials are passed here, so they presumably come from the
# environment (e.g. OPENAI_API_KEY) — confirm the deployment configuration.
client = OpenAI()
|
|
|
|
import spaces
|
|
|
|
|
|
@spaces.GPU
def generate_gradcam(image_path, model_path, output_path, method='gradcam', use_cuda=True, aug_smooth=False, eigen_smooth=False):
    """Render a class-activation heat-map for the captioning model's image encoder.

    The VisionEncoderDecoder checkpoint at ``model_path`` is loaded, and only its
    encoder is explained. The image is resized to 384x384, normalised with
    mean/std 0.5 and passed to the chosen pytorch-grad-cam method. The overlay is
    written to ``<output_path>/gradcam_result.png``.

    Args:
        image_path: Path of the image file to explain (anything ``cv2.imread`` reads).
        model_path: Directory or hub id of the VisionEncoderDecoder checkpoint.
        output_path: Directory for the result; it is created if missing.
        method: One of the keys of ``methods`` below (default ``'gradcam'``).
        use_cuda: Run on the GPU when one is available.
        aug_smooth: Forwarded to the CAM call (test-time augmentation smoothing).
        eigen_smooth: Forwarded to the CAM call (first-principal-component smoothing).

    Returns:
        str: Path of the written heat-map image.

    Raises:
        ValueError: If ``method`` is unknown or ``image_path`` cannot be read.
        IOError: If the result image cannot be written.
    """
    methods = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
        "layercam": LayerCAM,
        "fullgrad": FullGrad
    }

    if method not in methods:
        raise ValueError(f"Method should be one of {list(methods.keys())}")

    # Read the image before the (slow) model load so a bad path fails fast.
    # cv2.imread signals failure by returning None rather than raising.
    bgr_img = cv2.imread(image_path, 1)
    if bgr_img is None:
        raise ValueError(f"Could not read image: {image_path}")

    model = VisionEncoderDecoderModel.from_pretrained(model_path)
    model.encoder.eval()

    if use_cuda and torch.cuda.is_available():
        model.encoder = model.encoder.cuda()
    else:
        use_cuda = False

    # Explain the first and last blocks of the encoder's final stage.
    # NOTE(review): the layers[..].blocks[..].layernorm_after layout suggests a
    # Swin-style encoder — confirm against the checkpoint in model_path.
    # (The original spelled the first index as `blocks[-0]`, which is blocks[0].)
    last_stage = model.encoder.encoder.layers[-1]
    target_layers = [last_stage.blocks[0].layernorm_after, last_stage.blocks[-1].layernorm_after]

    cam_kwargs = dict(
        model=model.encoder,
        target_layers=target_layers,
        use_cuda=use_cuda,
        reshape_transform=reshape_transform,
    )
    if method == "ablationcam":
        # Ablation CAM needs a transformer-aware ablation layer.
        cam_kwargs["ablation_layer"] = AblationLayerVit()
    cam = methods[method](**cam_kwargs)

    # BGR -> RGB, fixed 384x384 input size, float in [0, 1].
    rgb_img = bgr_img[:, :, ::-1]
    rgb_img = cv2.resize(rgb_img, (384, 384))
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    cam.batch_size = 16
    # targets=None lets the library pick the highest-scoring category.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None, eigen_smooth=eigen_smooth, aug_smooth=aug_smooth)
    grayscale_cam = grayscale_cam[0, :]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)

    # cv2.imwrite returns False (no exception) when the directory is missing.
    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, 'gradcam_result.png')
    if not cv2.imwrite(output_file, cam_image):
        raise IOError(f"Failed to write Grad-CAM image to {output_file}")
    return output_file
|
|
|
|
|
|
|
|
def reshape_transform(tensor, height=12, width=12):
    """Turn a token sequence into a channels-first feature map for CAM.

    The ``(batch, tokens, dim)`` input is fitted to exactly ``height * width``
    tokens and returned as ``(batch, dim, height, width)``:

    * short sequences are right-padded with zero tokens;
    * long sequences are truncated to the first ``height * width`` tokens.

    The padding uses the input's dtype and device, so half-precision
    activations are not silently promoted to float32.

    NOTE(review): the 12x12 default presumably matches the final encoder stage
    for the 384x384 inputs used by generate_gradcam — confirm for other sizes.
    """
    batch_size, token_number, embed_dim = tensor.size()
    grid_tokens = height * width

    if token_number < grid_tokens:
        pad = tensor.new_zeros(batch_size, grid_tokens - token_number, embed_dim)
        tensor = torch.cat([tensor, pad], dim=1)
    elif token_number > grid_tokens:
        tensor = tensor[:, :grid_tokens, :]

    result = tensor.reshape(batch_size, height, width, embed_dim)
    # (B, H, W, D) -> (B, D, H, W), the same as the original double transpose.
    return result.permute(0, 3, 1, 2)
|
|
|
|
|
|
|
|
|
|
# Checkpoint directory passed to generate_gradcam by xray_report_generator.
model_path = "./Model/"
# Directory where generate_gradcam writes gradcam_result.png. The trailing slash
# matters: xray_report_generator concatenates the file name onto this string.
output_path = "./CAM-Result/"
|
|
|
|
|
|
|
|
def sentence_case(paragraph):
    """Capitalise each '. '-separated sentence of ``paragraph``.

    ``str.capitalize`` upper-cases the first character and lower-cases the
    rest of each sentence. Empty fragments (e.g. from '. . ') are dropped
    before the sentences are re-joined with '. '.
    """
    return '. '.join(
        chunk.capitalize()
        for chunk in paragraph.split('. ')
        if chunk
    )
|
|
|
|
def num2sym_bullets(text, bullet='-'):
    """
    Convert '<num>.'-delimited items into a newline-separated bullet list.

    Args:
        text (str): Text whose items are separated by the literal marker '<num>. '.
        bullet (str): Symbol placed before each item.

    Returns:
        str: One '<bullet> <item>' line per non-blank item, joined with newlines.
    """
    items = [piece.strip() for piece in re.split(r'<num>\.\s', text)]
    return '\n'.join(f'{bullet} {item}' for item in items if item)
|
|
|
|
def is_cxr(image_path, max_color_std=0.0):
    """
    Heuristically check whether the uploaded image looks like a chest X-ray.

    The image is loaded as 3-channel BGR. For each pixel, the standard deviation
    across its channels is computed, then averaged over the image. A true
    grayscale image has identical channels, so this value is 0.

    Args:
        image_path (str): Path to the uploaded image. None or empty gives False.
        max_color_std (float): Largest mean per-pixel channel std still treated
            as grayscale. The default 0.0 accepts only perfectly gray images.
            Raise it if colour-encoded JPEGs of radiographs (slight chroma noise)
            are being rejected.

    Returns:
        bool: True if the image is likely a Chest X-ray, False otherwise
        (including when the file cannot be read).
    """
    if not image_path:
        # e.g. the upload widget was cleared; nothing to check.
        return False

    try:
        image = cv2.imread(image_path)

        if image is None:
            raise ValueError("Invalid image path.")

        color_std = np.std(image, axis=2).mean()

        if color_std > max_color_std:
            return False

        return True

    except Exception as e:
        # Best-effort check for the UI: report and treat as "not a CXR".
        print(f"Error processing image: {e}")
        return False
|
|
|
|
def dicom_to_png(dicom_file, png_file):
    """Convert a DICOM file to an 8-bit grayscale PNG.

    The pixel data is min-max normalised to 0..255. If the file's
    PhotometricInterpretation is MONOCHROME1 (lowest value = white), the
    result is inverted so bright structures render bright, as in MONOCHROME2.

    The original code overwrote PhotometricInterpretation with 'MONOCHROME1'.
    That did not change the decoded pixels, threw away the real tag, and left
    MONOCHROME1 images inverted.

    Args:
        dicom_file: Path (or file-like object) accepted by ``pydicom.dcmread``.
        png_file: Destination path for the PNG.

    Returns:
        numpy.ndarray: The uint8 image that was written.

    Raises:
        IOError: If the PNG cannot be written.
    """
    dicom_data = pydicom.dcmread(dicom_file)
    photometric = str(dicom_data.get('PhotometricInterpretation', '')).strip().upper()

    img = dicom_data.pixel_array.astype(np.float32)

    # Stretch the stored value range onto the full 8-bit range.
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
    img = img.astype(np.uint8)

    if photometric == 'MONOCHROME1':
        img = 255 - img

    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(png_file, img):
        raise IOError(f"Failed to write PNG to {png_file}")

    return img
|
|
|
|
|
|
# Image-to-text pipeline over the local checkpoint in ./Model/ (the same
# directory as model_path). device=0 selects the first CUDA device.
# NOTE(review): this runs at import time, so a GPU presumably must be visible
# when the module loads — confirm for CPU-only deployments.
Image_Captioner = pipeline("image-to-text", model = "./Model/", device = 0)


# Directory used in xray_report_generator's DICOM branch to locate the converted PNG.
data_dir = "./CAM-Result"
|
|
|
|
@spaces.GPU(duration=300)
def xray_report_generator(Image_file, Query):
    """Generate a report, a query answer and a Grad-CAM overlay for one upload.

    Steps:
    1. DICOM (``.dcm``, any case) uploads are first converted to
       ``<data_dir>/DCM2PNG.png``.
    2. The image-to-text pipeline produces the raw report.
    3. The report is sentence-cased and its '<num>.' items become '-' bullets.
    4. gpt-4-turbo extracts the part of the report relevant to ``Query``.
    5. A Grad-CAM heat-map is written under ``output_path``.

    Args:
        Image_file (str | None): Path of the uploaded image (Gradio ``filepath``).
        Query (str): Patient history or the doctor's question.

    Returns:
        tuple: (heat-map image path, formatted report, query response).
        With no upload: (None, an instruction message, "").
    """
    if not Image_file:
        return None, "Please upload a chest X-ray image first.", ""

    if Image_file.lower().endswith('.dcm'):
        # Write the PNG to the same path it is read from (the original wrote it
        # to ./DCM2PNG.png but then read ./CAM-Result/DCM2PNG.png).
        os.makedirs(data_dir, exist_ok=True)
        png_file = os.path.join(data_dir, 'DCM2PNG.png')
        dicom_to_png(Image_file, png_file)
        Image_file = png_file

    output = Image_Captioner(Image_file, max_new_tokens=512)

    result = output[0]['generated_text']

    output_paragraph = sentence_case(result)

    final_response = num2sym_bullets(output_paragraph, bullet='-')

    query_prompt = f""" You are analyzing the doctor's query based on the patient's history and the generated chest X-ray report. Extract only the information relevant to the query.
If the report mentions the queried condition, write only the exact wording without any introduction. If the condition is not mentioned, respond with: 'No relevant findings related to [query condition].'.
"""

    completion = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[
            {"role": "system", "content": query_prompt},
            {"role": "user", "content": f"Generated Report: {final_response}\nHistory/Doctor's Query: {Query}"}
        ],
        temperature=0.2)

    query_response = completion.choices[0].message.content

    generate_gradcam(Image_file, model_path, output_path, method='gradcam', use_cuda=True)

    grad_cam_image = os.path.join(output_path, 'gradcam_result.png')

    return grad_cam_image, final_response, query_response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_feedback(feedback):
    """Append one feedback entry to ``Chayan/Feedback/feedback.txt``.

    The directory is created on demand. The file is written as UTF-8 so
    non-ASCII feedback does not depend on the platform's default encoding.

    Args:
        feedback (str): Text from the feedback box. Blank or non-string values
            are rejected without touching the file.

    Returns:
        str: Status message shown in the "Feedback Status" box.
    """
    if not isinstance(feedback, str) or not feedback.strip():
        return "Please enter some feedback before submitting."

    feedback_dir = "Chayan/Feedback/"
    feedback_file = os.path.join(feedback_dir, "feedback.txt")

    try:
        os.makedirs(feedback_dir, exist_ok=True)
        with open(feedback_file, "a", encoding="utf-8") as f:
            f.write(feedback + "\n")
    except OSError as e:
        print(f"Error saving feedback: {e}")
        return "Failed to submit feedback!"

    print(f"Feedback saved at: {feedback_file}")
    return "Feedback submitted successfully!"
|
|
|
|
|
|
|
|
class Seafoam(Base):
    """Gradio theme: emerald/blue/gray palette, Quicksand and IBM Plex Mono fonts,
    and a multi-stop diagonal gradient as the page background."""

    def __init__(
        self,
        *,
        primary_hue: Union[colors.Color, str] = colors.emerald,
        secondary_hue: Union[colors.Color, str] = colors.blue,
        neutral_hue: Union[colors.Color, str] = colors.gray,
        spacing_size: Union[sizes.Size, str] = sizes.spacing_md,
        radius_size: Union[sizes.Size, str] = sizes.radius_md,
        text_size: Union[sizes.Size, str] = sizes.text_lg,
        font: Union[fonts.Font, str, Iterable[Union[fonts.Font, str]]] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: Union[fonts.Font, str, Iterable[Union[fonts.Font, str]]] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Forward every constructor option unchanged to the base theme.
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Override only the page background with a fixed gradient.
        background_gradient = "linear-gradient(114.2deg, rgba(184,215,21,1) -15.3%, rgba(21,215,98,1) 14.5%, rgba(21,215,182,1) 38.7%, rgba(129,189,240,1) 58.8%, rgba(219,108,205,1) 77.3%, rgba(240,129,129,1) 88.5%)"
        self.set(body_background_fill=background_gradient)
|
|
|
|
# Single theme instance with default options; passed to gr.Blocks(theme=...) below.
seafoam = Seafoam()
|
|
|
|
|
|
|
|
|
|
# Extra stylesheet passed to gr.Blocks(css=...). Gradio expects raw CSS here.
# The original wrapped the rules in HTML <style>...</style> tags. The leading
# "<style>" token made the first rule's selector invalid, so CSS parsers drop
# that rule (the background colour). The tags are removed; the rules are unchanged.
custom_css = """
/* Set background color for the entire Gradio app */
body, .gradio-container {
    background-color: #f2f7f5 !important;
}

/* Optional: Add padding or margin for aesthetics */
.gradio-container {
    padding: 20px;
}

#title {
    color: green;
    font-size: 36px;
    font-weight: bold;
}
#description {
    color: green;
    font-size: 22px;
}

#title-row {
    display: flex;
    align-items: center;
    gap: 10px;
    margin-bottom: 0px;
}
#title-header h1 {
    margin: 0;
}

#submit-btn {
    background-color: #f5dec6; /* Banana leaf */
    color: green;
    padding: 15px 32px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 30px;
    margin: 4px 2px;
    cursor: pointer;
}
#submit-btn:hover {
    background-color: #00FFFF;
}

.intext textarea {
    color: green;
    font-size: 20px;
    font-weight: bold;
}

.small-button {
    color: green;
    padding: 5px 10px;
    font-size: 20px;
}
"""
|
|
|
|
|
|
# Example X-rays offered by the "Select Sample Image N" buttons, in display order.
sample_images = [
    f"./Test-Images/{image_id}.jpg"
    for image_id in (
        "0d930f0a-46f813a9-db3b137b-05142eef-eca3c5a7",
        "93681764-ec39480e-0518b12c-199850c2-f15118ab",
        "6ff741e9-6ea01eef-1bf10153-d1b6beba-590b6620",
    )
]
|
|
|
|
def set_input_image(image_path):
    """Return a Gradio update that sets the file-upload component to ``image_path``."""
    file_update = gr.update(value=image_path)
    return file_update
|
|
|
|
def show_contact_info():
    """Reveal the contact panel for 20 seconds, then hide it.

    This is a generator: the first update makes the Markdown visible with the
    contact text, and the second (after a blocking 20 s sleep) hides it again.
    """
    # NOTE(review): every address reads "[email protected]", which looks like
    # an e-mail-obfuscation artefact rather than real addresses — confirm.
    contact_markdown = """
**Contact Us:**
- Chayan Mondal
- Email: [email protected]
- Associate Prof. Sonny Pham
- Email: [email protected]
- Dr. Ashu Gupta
- Email: [email protected]
"""
    yield gr.update(value=contact_markdown, visible=True)

    time.sleep(20)

    yield gr.update(visible=False)
|
|
|
|
def show_acknowledgment():
    """Reveal the funding acknowledgment for 20 seconds, then hide it.

    This is a generator: it yields a visible update carrying the text, sleeps
    (blocking) for 20 s, then yields an update that hides the panel.
    """
    acknowledgment_markdown = """
**Acknowledgment:**
This Research has been supported by the Western Australian Future Health Research and Innovation Fund.
"""
    yield gr.update(value=acknowledgment_markdown, visible=True)

    time.sleep(20)

    yield gr.update(visible=False)
|
|
|
|
|
|
with gr.Blocks(theme=seafoam, css=custom_css) as demo:

    # Header: logo next to the application title.
    with gr.Row(elem_id="title-row"):
        with gr.Column(scale=0):
            gr.Image(
                value="./AURA-CXR-Logo.png",
                show_label=False,
                width=60,
                container=False
            )
        with gr.Column():
            gr.Markdown(
                """
<h1 style="color:blue; font-size: 32px; font-weight: bold; margin: 0;">
AURA-CXR: Explainable Diagnosis of Chest Diseases from X-rays
</h1>
""",
                elem_id="title-header"
            )

    gr.Markdown(
        "<p id='description'>Upload an X-ray image and get its report with heat-map visualization.</p>"
    )

    with gr.Row():
        inputs = gr.File(label="Upload Chest X-ray Image File", type="filepath")

    # Three columns: preview + query, Grad-CAM overlay, report + query answer.
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            outputs1 = gr.Image(label="Image Viewer")
            history_query = gr.Textbox(label="History/Doctor's Query", elem_classes="intext")
        with gr.Column(scale=1, min_width=300):
            outputs2 = gr.Image(label="Grad_CAM-Visualization")
        with gr.Column(scale=1, min_width=300):
            outputs3 = gr.Textbox(label="Generated Report", elem_classes = "intext")
            outputs4 = gr.Textbox(label = "Query's Response", elem_classes = "intext")

    submit_btn = gr.Button("Generate Report", elem_id="submit-btn", variant="primary")

    def show_image(file_path):
        """Preview an upload and report whether it looks like a chest X-ray.

        Returns (image for the viewer, status text for the report box).
        - A cleared upload (None) resets both outputs instead of showing an error.
        - DICOM files, which cv2 cannot read, are converted to a PNG preview
          instead of being rejected.
        """
        if not file_path:
            return None, ""

        if file_path.lower().endswith('.dcm'):
            try:
                os.makedirs(data_dir, exist_ok=True)
                # Separate name so the preview never overwrites the report's PNG.
                preview_file = os.path.join(data_dir, 'DCM2PNG-preview.png')
                dicom_to_png(file_path, preview_file)
            except Exception as e:
                print(f"Error converting DICOM: {e}")
                return None, "Invalid image. Please upload a proper Chest X-ray."
            return preview_file, "Valid Image"

        if is_cxr(file_path):
            return file_path, "Valid Image"
        return None, "Invalid image. Please upload a proper Chest X-ray."

    inputs.change(
        fn=show_image,
        inputs=inputs,
        outputs=[outputs1, outputs3]
    )

    submit_btn.click(
        fn=xray_report_generator,
        inputs=[inputs,history_query],
        outputs=[outputs2, outputs3, outputs4])

    gr.Markdown(
        """
<h2 style="color:green; font-size: 24px;">Or choose a sample image:</h2>
"""
    )

    with gr.Row():
        for idx, sample_image in enumerate(sample_images):
            with gr.Column(scale=1):
                select_button = gr.Button(f"Select Sample Image {idx+1}")
                # gr.State captures this iteration's path, avoiding late binding.
                select_button.click(
                    fn=set_input_image,
                    inputs=gr.State(value=sample_image),
                    outputs=inputs
                )

    gr.Markdown(
        """
<h2 style="color:green; font-size: 24px;">Provide Your Valuable Feedback:</h2>
"""
    )

    with gr.Row():
        feedback_input = gr.Textbox(label="Your Feedback", lines=4, placeholder="Enter your feedback here...")
        feedback_submit_btn = gr.Button("Submit Feedback", elem_classes="small-button", variant="secondary")
        feedback_output = gr.Textbox(label="Feedback Status", interactive=False)

    feedback_submit_btn.click(
        fn=save_feedback,
        inputs=feedback_input,
        outputs=feedback_output
    )

    with gr.Row():
        contact_btn = gr.Button("Contact Us", elem_classes="small-button", variant="secondary")
        ack_btn = gr.Button("Acknowledgment", elem_classes="small-button", variant="secondary")

    contact_info = gr.Markdown(visible=False)
    acknowledgment_info = gr.Markdown(visible=False)

    contact_btn.click(fn=show_contact_info, outputs=contact_info, show_progress=False)
    ack_btn.click(fn=show_acknowledgment, outputs=acknowledgment_info, show_progress=False)

demo.launch(share=True)
|
|
|
|
|