File size: 3,325 Bytes
463a6b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91


import streamlit as st
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification
import pandas as pd

DATASET_PATH = "./Data_Entry_2017_v2020.csv"  # Replace with your dataset path

# Load the label metadata table used to show ground truth for uploaded images.
@st.cache_data
def load_dataset(dataset_path=DATASET_PATH):
    """Read the ground-truth metadata CSV into a DataFrame.

    The app looks uploaded file names up in the ``'Image Index'`` column and
    reports the matching ``'Finding Labels'`` value.

    Args:
        dataset_path: Path to the CSV file. Defaults to ``DATASET_PATH``.

    Returns:
        The full CSV contents as a ``pandas.DataFrame``. The result is cached
        by Streamlit (``st.cache_data``) per distinct ``dataset_path``.
    """
    return pd.read_csv(dataset_path)

data = load_dataset()

@st.cache_resource
def load_model(
    checkpoint_path="best_model_new_retrain.pth",
    base_model_name="google/vit-base-patch16-224-in21k",
    num_labels=15,
):
    """Build the image classifier and load the fine-tuned weights on CPU.

    Args:
        checkpoint_path: File holding the fine-tuned ``state_dict``.
        base_model_name: Hugging Face model id whose config defines the
            architecture.
        num_labels: Size of the classification head. It must match both the
            checkpoint and the label list shown in the UI.

    Returns:
        The model in eval mode. Streamlit caches it (``st.cache_resource``),
        so it is built once per distinct argument set.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict mode) if the checkpoint
            keys or shapes do not match the architecture.
    """
    from transformers import AutoConfig

    # Only the architecture is needed here. Every parameter is overwritten by
    # the strict load below, so build from the config rather than downloading
    # and initialising the base checkpoint's pretrained weights.
    config = AutoConfig.from_pretrained(base_model_name, num_labels=num_labels)
    model = AutoModelForImageClassification.from_config(config)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, instead of executing arbitrary pickles.
    state_dict = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_model()

# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, convert to a tensor, then normalise each channel.
# NOTE(review): these are the ImageNet channel statistics; confirm they match
# the normalisation used when the loaded checkpoint was trained.
_resize = transforms.Resize((224, 224))  # Adjust based on your model's requirements
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)  # ImageNet stats
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

# Run the module-level model on a single image.
def predict_image(image):
    """Return independent per-class probabilities for one image.

    The image goes through the module-level ``transform`` and gains a leading
    batch axis of size 1. It is then passed to the module-level ``model``.
    Each logit is passed through a sigmoid on its own (multi-label output),
    so the returned values need not sum to 1.

    Args:
        image: The image to classify. The caller passes an RGB ``PIL.Image``.

    Returns:
        A tensor of probabilities with a leading batch dimension of 1.
    """
    batch = transform(image)[None]  # prepend a batch axis of size 1
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        return torch.sigmoid(model(batch).logits)

# Streamlit App
st.title("Chest Xray Disease Prediction App")
st.write("Upload single or multiple images to get predictions.")

# Class names, in the order of the model's output logits.
# NOTE(review): this order must match the label order used at training time.
LABEL_COLUMNS = [
    'No Finding', 'Infiltration', 'Effusion', 'Atelectasis', 'Nodule',
    'Mass', 'Pneumothorax', 'Consolidation', 'Pleural_Thickening',
    'Cardiomegaly', 'Emphysema', 'Edema', 'Fibrosis', 'Pneumonia', 'Hernia',
]
# Probabilities strictly above this value are flagged in the "Highlight" column.
HIGHLIGHT_THRESHOLD = 0.5


def _highlight_cell(flag):
    """Return the CSS for one ``Highlight`` cell: yellow when ``flag`` is truthy."""
    return 'background-color: yellow;' if flag else ''


# File uploader for single or bulk images
uploaded_files = st.file_uploader(
    "Upload Image(s)", type=["jpg", "png", "jpeg"], accept_multiple_files=True
)

# Process each uploaded file (None / empty list means nothing uploaded yet).
for uploaded_file in uploaded_files or []:
    uploaded_filename = uploaded_file.name

    # Load the image. A corrupt or non-image upload must not abort the
    # remaining files; PIL.UnidentifiedImageError is a subclass of OSError.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:
        st.error(f"Could not read {uploaded_filename}: {err}")
        continue
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the minimum version allows.
    st.image(image, caption=f"Uploaded Image: {uploaded_filename}", use_column_width=True)

    # Search for the filename in the dataset
    matching_row = data[data['Image Index'] == uploaded_filename]
    truth = matching_row.iloc[0]['Finding Labels'] if not matching_row.empty else "No matching label found"

    st.write(f"**Truth (Ground Truth Labels):** {truth}")

    # Get predictions; row 0 is the only image in the batch.
    probabilities = predict_image(image)[0]

    # Create a DataFrame to display probabilities
    prediction_df = pd.DataFrame({
        "Class": LABEL_COLUMNS,
        "Probability": probabilities.tolist(),
    })
    prediction_df['Highlight'] = prediction_df['Probability'] > HIGHLIGHT_THRESHOLD

    # Display predictions
    st.write("**Prediction (Model Probabilities):**")
    styler = prediction_df.style.format({"Probability": "{:.2f}"})
    # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map;
    # use the new name when present and fall back for older pandas.
    style_cells = styler.map if hasattr(styler, "map") else styler.applymap
    st.dataframe(style_cells(_highlight_cell, subset=['Highlight']))