|
|
|
|
|
import streamlit as st
|
|
from PIL import Image
|
|
import torch
|
|
from torchvision import transforms
|
|
from transformers import AutoModelForImageClassification
|
|
import pandas as pd
|
|
|
|
|
|
@st.cache_data
def load_dataset(dataset_path="./Data_Entry_2017_v2020.csv"):
    """Load the ground-truth label table from a CSV file.

    Decorated with ``st.cache_data`` so Streamlit parses the file once instead
    of on every script rerun.

    Args:
        dataset_path: Path of the CSV to read. Defaults to the file the app
            was originally written against, so existing callers are unaffected.

    Returns:
        The CSV contents as a pandas DataFrame. Elsewhere the app looks rows up
        by the 'Image Index' column and displays the 'Finding Labels' column.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist (from
            ``pd.read_csv``).
    """
    return pd.read_csv(dataset_path)


# The ground-truth table is only used to show reference labels next to the
# model's predictions, so a missing file should not take the whole app down.
# Fall back to an empty table with the two columns the per-image lookup reads;
# every lookup then reports that no matching label was found.
try:
    data = load_dataset()
except FileNotFoundError:
    st.warning(
        "Ground-truth file not found; predictions will be shown "
        "without reference labels."
    )
    data = pd.DataFrame(columns=["Image Index", "Finding Labels"])
|
|
|
|
@st.cache_resource
def load_model(checkpoint_path="best_model_new_retrain.pth"):
    """Build the image classifier and load the fine-tuned weights on the CPU.

    Decorated with ``st.cache_resource`` so the model is constructed once and
    shared across Streamlit reruns rather than rebuilt for every interaction.

    Args:
        checkpoint_path: Path of the saved ``state_dict`` to load. Defaults to
            the checkpoint the app was originally written against.

    Returns:
        The model, switched to evaluation mode and ready for inference.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match the model
            (raised by the strict ``load_state_dict``).
    """
    # The Hugging Face base model fixes the architecture; num_labels=15 sizes
    # the classification head to the 15 output classes.
    model = AutoModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k", num_labels=15
    )

    # map_location puts every tensor on the CPU regardless of the device the
    # checkpoint was saved from.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust; newer torch versions support weights_only=True to restrict this.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # load_state_dict is strict by default: every model parameter must be
    # present in the checkpoint, so all base weights loaded above are replaced.
    # NOTE(review): that means the base *weights* download is wasted work;
    # building from the config alone would avoid it.
    model.load_state_dict(state_dict)

    # Evaluation mode: dropout and similar layers behave deterministically.
    model.eval()
    return model


# Without a model the app cannot do anything useful, so report the problem in
# the page instead of a raw traceback and halt this script run.
try:
    model = load_model()
except FileNotFoundError as err:
    st.error(f"Model could not be loaded: {err}")
    st.stop()
|
|
|
|
|
|
# Per-channel (R, G, B) normalization statistics. These are the standard
# ImageNet mean/std values.
# NOTE(review): they must match the normalization used when the loaded
# checkpoint was trained -- confirm against the training pipeline.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each uploaded PIL image before inference:
#   1. resize to the fixed 224x224 input size,
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize each channel with the statistics above.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
|
def predict_image(image):
    """Classify one image and return an independent probability per class.

    Args:
        image: A PIL image. The app passes uploads already converted to RGB.

    Returns:
        A tensor holding a batch of one: shape (1, number of model labels).
        Each entry is ``sigmoid(logit)``. Classes are scored independently
        (multi-label), so the values do not sum to 1.
    """
    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(image)[None, ...]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch).logits
        return logits.sigmoid()
|
|
|
|
|
|
# --- Page header ---------------------------------------------------------------
st.title("Chest Xray Disease Prediction App")
st.write("Upload single or multiple images to get predictions.")

# --- Image picker --------------------------------------------------------------
# accept_multiple_files=True makes Streamlit return a list of uploaded files,
# so several X-rays can be scored in one go.
_ACCEPTED_IMAGE_TYPES = ["jpg", "png", "jpeg"]
uploaded_files = st.file_uploader(
    label="Upload Image(s)",
    type=_ACCEPTED_IMAGE_TYPES,
    accept_multiple_files=True,
)
|
|
|
|
|
|
# Output classes, in the same order as the model's logits. The list length
# must equal the number of labels the model was built with.
LABEL_COLUMNS = [
    "No Finding", "Infiltration", "Effusion", "Atelectasis", "Nodule",
    "Mass", "Pneumothorax", "Consolidation", "Pleural_Thickening",
    "Cardiomegaly", "Emphysema", "Edema", "Fibrosis", "Pneumonia", "Hernia",
]

# A class whose probability exceeds this value is flagged in the results table.
PREDICTION_THRESHOLD = 0.5


def _ground_truth_for(filename):
    """Return the 'Finding Labels' of the first row whose 'Image Index' is filename.

    Falls back to a placeholder string when the table has no such row.
    """
    matches = data.loc[data["Image Index"] == filename, "Finding Labels"]
    if matches.empty:
        return "No matching label found"
    return matches.iloc[0]


def _highlight_cell(val):
    """Return a yellow-background CSS rule for truthy cells, else no style."""
    return "background-color: yellow;" if val else ""


def _style_predictions(prediction_df):
    """Format 'Probability' to 2 decimals and highlight True 'Highlight' cells."""
    styler = prediction_df.style.format({"Probability": "{:.2f}"})
    # pandas 2.1 renamed Styler.applymap to Styler.map and deprecated the old
    # name. Prefer the new method when present and fall back for older pandas.
    if hasattr(styler, "map"):
        return styler.map(_highlight_cell, subset=["Highlight"])
    return styler.applymap(_highlight_cell, subset=["Highlight"])


if uploaded_files:
    for uploaded_file in uploaded_files:
        # Image.open is lazy and convert() forces the decode, so both calls
        # sit inside the try. A file with an image extension but unreadable
        # or truncated contents is reported and skipped instead of aborting
        # the remaining uploads.
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError as err:
            st.error(f"Could not read {uploaded_file.name} as an image: {err}")
            continue

        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favor of use_container_width; switch once the pinned version allows.
        st.image(
            image,
            caption=f"Uploaded Image: {uploaded_file.name}",
            use_column_width=True,
        )

        truth = _ground_truth_for(uploaded_file.name)
        st.write(f"**Truth (Ground Truth Labels):** {truth}")

        probabilities = predict_image(image)
        prediction_df = pd.DataFrame({
            "Class": LABEL_COLUMNS,
            # predict_image returns a batch of one; take its single row.
            "Probability": probabilities[0].tolist(),
        })
        prediction_df["Highlight"] = (
            prediction_df["Probability"] > PREDICTION_THRESHOLD
        )

        st.write("**Prediction (Model Probabilities):**")
        st.dataframe(_style_predictions(prediction_df))
|
|
|
|
|