|
import streamlit as st |
|
from gcg.pipelines import predict |
|
import os |
|
|
|
|
|
# Scratch directory that uploaded images are written into before their paths
# are passed to the prediction pipeline.
TEMP_DIR = "temp"

# exist_ok=True: do not raise if the directory is already present.
os.makedirs(TEMP_DIR, exist_ok=True)
|
|
|
# Static page header: title, short usage hint, and links to the paper and code.
st.title("Diabetic Retinopathy Classifier (Guided Context Gating Attention)")
st.subheader("Upload retinal images and get predictions with heatmaps")
st.write("This app is the demo of our ICIP'24 paper [[arXiv]](https://arxiv.org/pdf/2406.13126) | [[Github]](https://github.com/tejacherukuri/guided-context-gating)")
|
|
|
|
|
# Multi-file upload widget restricted to the listed image extensions.
# With accept_multiple_files=True the widget yields a collection of uploads,
# which is empty (falsy) until the user selects at least one file.
uploaded_files = st.file_uploader(
    "Upload Retinal Images",
    type=["jpg", "jpeg", "png"],
    accept_multiple_files=True,
)
|
|
|
# Inference flow: persist the uploads to TEMP_DIR, run the model on the saved
# paths, then show each prediction together with its attention heatmap.
if st.button("Run Inference"):
    if uploaded_files:
        img_paths = []
        for uploaded_file in uploaded_files:
            # The file name is supplied by the client; keep only its final
            # path component so the write cannot land outside TEMP_DIR.
            safe_name = os.path.basename(uploaded_file.name)
            file_path = os.path.join(TEMP_DIR, safe_name)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            img_paths.append(file_path)

        st.info("Running predictions...")
        # NOTE(review): assumes predict() returns one result per input path,
        # in the same order -- zip() below silently truncates otherwise.
        predictions = predict(img_paths)

        st.success("Inference completed! Here are the results:")
        for img_path, predicted_class in zip(img_paths, predictions):
            image_name = os.path.basename(img_path)
            st.write(f"**Image**: {image_name}")
            st.write(f"**Predicted Class**: {predicted_class}")

            # Heatmaps are looked up under "heatmaps/heatmap_<image name>";
            # presumably written there by predict() -- confirm in gcg.pipelines.
            # An image without a heatmap file is shown without one.
            heatmap_path = os.path.join("heatmaps", f"heatmap_{image_name}")
            if os.path.exists(heatmap_path):
                st.image(heatmap_path, caption="Attention Map", use_container_width=True)
    else:
        # Button pressed with nothing uploaded.
        st.error("Please upload at least one image.")