# tejacherukuri · commit 2747d6a "Update app.py" (verified)
import os
import tempfile

import streamlit as st

from gcg.pipelines import predict
# Root directory for uploaded images awaiting inference; created eagerly so
# later writes beneath it never have to check for its existence.
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)

# ---- Page header -----------------------------------------------------------
st.title("Diabetic Retinopathy Classifier (Guided Context Gating Attention)")
st.subheader("Upload retinal images and get predictions with heatmaps")
st.write(
    "This app is the demo of our ICIP'24 paper "
    "[[arXiv]](https://arxiv.org/pdf/2406.13126) | "
    "[[Github]](https://github.com/tejacherukuri/guided-context-gating)"
)

# ---- Image input -----------------------------------------------------------
# Image extensions the uploader will accept.
ACCEPTED_IMAGE_TYPES = ["jpg", "jpeg", "png"]

# Several images may be selected at once; the result is checked for
# truthiness before inference runs.
uploaded_files = st.file_uploader(
    "Upload Retinal Images",
    type=ACCEPTED_IMAGE_TYPES,
    accept_multiple_files=True,
)
def _save_uploads(files, directory):
    """Write each Streamlit upload into *directory* and return the saved paths.

    Only the final component of the client-supplied filename is kept, so a
    crafted name such as ``../../app.py`` cannot write outside *directory*.
    Paths are returned in the same order as *files*.

    NOTE: two uploads with the same name in one batch still overwrite each
    other; the returned list then contains that path twice.
    """
    saved_paths = []
    for uploaded_file in files:
        safe_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join(directory, safe_name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        saved_paths.append(file_path)
    return saved_paths


def _show_result(img_path, predicted_class):
    """Render one image's predicted class and, if present, its attention map."""
    image_name = os.path.basename(img_path)
    st.write(f"**Image**: {image_name}")
    st.write(f"**Predicted Class**: {predicted_class}")
    # Assumes `predict` writes heatmaps as heatmaps/heatmap_<image basename>
    # relative to the working directory (naming inherited from the original
    # lookup) — TODO confirm against gcg.pipelines.
    heatmap_path = os.path.join("heatmaps", f"heatmap_{image_name}")
    if os.path.exists(heatmap_path):
        st.image(heatmap_path, caption="Attention Map", use_container_width=True)
    else:
        st.warning("No attention map was found for this image.")


if st.button("Run Inference"):
    if uploaded_files:
        # Give each run its own sub-directory of TEMP_DIR. Concurrent sessions
        # that upload identically named files no longer clobber each other,
        # and the images are deleted once results are rendered instead of
        # piling up on disk. Basenames are preserved, so the heatmap naming
        # used by _show_result is unaffected.
        with tempfile.TemporaryDirectory(dir=TEMP_DIR) as run_dir:
            img_paths = _save_uploads(uploaded_files, run_dir)
            # A spinner disappears when prediction finishes. The previous
            # st.info message stayed on screen after completion.
            with st.spinner("Running predictions..."):
                predictions = predict(img_paths)
            st.success("Inference completed! Here are the results:")
            for img_path, predicted_class in zip(img_paths, predictions):
                _show_result(img_path, predicted_class)
    else:
        st.error("Please upload at least one image.")