MEDIGPT / app.py
FawadHaider2's picture
Update app.py
ad369cf verified
raw
history blame
6.25 kB
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1sjyLFLqBccpUzaUi4eyyP3NYE3gDtHfs
"""
import streamlit as st
from streamlit_option_menu import option_menu
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import time
from efficientnet_pytorch import EfficientNet
from fastai.vision.all import load_learner
# Hide every CUDA device from the frameworks ('-1' matches no GPU index), forcing CPU inference.
# NOTE(review): this runs after tensorflow/torch are imported; it only takes effect if neither
# framework has initialised CUDA yet — confirm on the deployment target.
os.environ.update({'CUDA_VISIBLE_DEVICES': '-1'})
# Skin-lesion model loader; st.cache_resource keeps one instance across Streamlit reruns.
@st.cache_resource
def load_skin_model():
    """Build the 9-output MelanomaModel, load its weights on CPU and return it in eval mode.

    The checkpoint file ``multi_weight.pth`` must be a mapping that holds the network
    weights under the ``"model_state_dict"`` key.
    """
    net = MelanomaModel(out_size=9)
    weights_file = "multi_weight.pth"
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects —
    # only ever point this at a trusted, locally shipped checkpoint.
    ckpt = torch.load(weights_file, map_location=torch.device('cpu'), weights_only=False)
    net.load_state_dict(ckpt["model_state_dict"])
    # Module.eval() returns the module itself, so it can be returned directly.
    return net.eval()
# Preprocessing for skin-lesion images: resize the shorter side to 256 px, center-crop to
# 224x224, convert to a float tensor, then normalize each of the 3 channels with the
# commonly used ImageNet mean/std statistics.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Output index -> label for the 9 logits produced by the skin-lesion model
# (position in this list is the class index).
DIAGNOSIS_MAP = dict(enumerate([
    'Melanoma',
    'Melanocytic nevus',
    'Basal cell carcinoma',
    'Actinic keratosis',
    'Benign keratosis',
    'Dermatofibroma',
    'Vascular lesion',
    'Squamous cell carcinoma',
    'Unknown',
]))
# Model for skin lesion classification
class MelanomaModel(nn.Module):
    """EfficientNet-B0 feature extractor followed by a 3-layer MLP head.

    Args:
        out_size: number of output logits (one per diagnosis class).
        dropout_prob: dropout probability applied after each hidden layer.
    """

    def __init__(self, out_size, dropout_prob=0.5):
        super().__init__()
        # Backbone; its final classifier is swapped for Identity so forward() yields features.
        # NOTE(review): from_pretrained presumably fetches ImageNet weights; load_skin_model
        # overwrites all parameters from the checkpoint afterwards anyway.
        self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')
        self.efficient_net._fc = nn.Identity()
        # Head: 1280 -> 512 -> 256 -> out_size. Attribute names are state-dict keys; keep them.
        self.fc1 = nn.Linear(1280, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, out_size)
        # A single Dropout module, reused after both hidden layers.
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape (batch, out_size)."""
        feats = self.efficient_net(x)
        feats = feats.view(feats.shape[0], -1)
        for hidden in (self.fc1, self.fc2):
            feats = self.dropout(F.relu(hidden(feats)))
        return self.fc3(feats)
# Alzheimer's Prediction
@st.cache_resource
def load_alzheimer_model():
    """Load the Keras Alzheimer classifier from 'alzheimer_99.5.h5' (cached across reruns)."""
    weights_file = 'alzheimer_99.5.h5'
    return keras.models.load_model(weights_file)
# Brain Tumor Prediction
@st.cache_resource
def load_brain_tumor_model(classes):
    """Load the Keras brain-tumor classifier that matches a class-count option.

    Args:
        classes: '44 Classes', '17 Classes' or '15 Classes'; any other value
            (including '2 Classes') selects the binary model.

    Returns:
        The loaded Keras model; Streamlit caches one per distinct ``classes`` value.
    """
    options = (
        ('44 Classes', '44class_96.5.h5'),
        ('17 Classes', '17class_98.1.h5'),
        ('15 Classes', '15class_99.8.h5'),
    )
    # Fallback path spelling ("2calss_lagre") is reproduced verbatim from the original.
    fallback = '2calss_lagre_dataset_99.1.h5'
    weights_file = next((path for label, path in options if classes == label), fallback)
    return keras.models.load_model(weights_file)
# Prediction for Skin Disease
def predict_skin_lesion(img: Image.Image, model: nn.Module, top_k: int = 3):
    """Return the model's top-k skin-lesion diagnoses for one image.

    Args:
        img: PIL image of the lesion. Any mode is accepted; it is converted to RGB so the
            tensor has the 3 channels that the module-level ``transform`` normalizes.
        model: classifier returning one logit per ``DIAGNOSIS_MAP`` index. In this app it
            comes from ``load_skin_model``, which has already switched it to eval mode.
        top_k: number of predictions to return (default 3, the previous fixed value);
            capped at the number of model outputs.

    Returns:
        List of ``(label, probability_percent)`` tuples, most likely first. Indices
        missing from ``DIAGNOSIS_MAP`` are reported as ``"Unknown"``.
    """
    # ToTensor keeps the image's channel count (RGBA -> 4, grayscale/palette -> 1), which
    # would break the 3-channel Normalize step and the network, so force RGB first.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)  # add batch axis
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # torch.topk raises if k exceeds the number of classes.
        k = min(top_k, probs.size(1))
        top_probs, top_idxs = torch.topk(probs, k, dim=1)
    return [
        (DIAGNOSIS_MAP.get(idx.item(), "Unknown"), prob.item() * 100)
        for prob, idx in zip(top_probs[0], top_idxs[0])
    ]
# Prediction for Brain Tumor and Alzheimer
def predict(img_path, model, result_classes):
    """Classify one image with a Keras model and return the top label.

    Args:
        img_path: file path or file-like object (e.g. a Streamlit upload) accepted by
            ``tf.keras.utils.load_img``.
        model: Keras model whose ``predict`` takes a batch of 224x224x3 images.
        result_classes: labels indexed by the model's output position.

    Returns:
        The label at the arg-max output index, or ``"Unlabeled class #<index>"`` when
        ``result_classes`` is shorter than the model's output layer (previously this
        raised IndexError).
    """
    # Rewind file-like inputs in case an earlier consumer already read from them.
    if hasattr(img_path, "seek"):
        img_path.seek(0)
    img = tf.keras.utils.load_img(img_path, target_size=(224, 224))
    # Add a leading batch axis -> (1, 224, 224, 3). Pixels are passed unscaled, as before.
    # NOTE(review): confirm this matches the preprocessing the models were trained with.
    img_array = np.array(img).reshape(-1, 224, 224, 3)
    pred = model.predict(img_array)
    idx = int(np.argmax(pred, axis=1)[0])
    if idx < len(result_classes):
        return result_classes[idx]
    return f"Unlabeled class #{idx}"
# Navigation menu for the disease categories (rendered horizontally, not in the sidebar).
def spr_sidebar():
    """Render the horizontal navigation menu and return the label of the selected option."""
    pages = ["Brain Tumor", "Alzheimer", "Skin Disease", "Eye Disease", "About"]
    page_icons = ["house", "brain", "microscope", "eye", "info-square"]
    selected = option_menu(
        menu_title="Navigation",
        options=pages,
        icons=page_icons,
        menu_icon="cast",
        default_index=0,
        orientation="horizontal",
    )
    return selected
# Home Page Content
def home_page(selected_category):
    """Render the upload-and-classify page for one disease category.

    Args:
        selected_category: label chosen in the navigation menu; it selects which model
            is loaded and how the result is displayed.
    """
    st.title("Disease Detection Web App")
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    # Create option widgets BEFORE the button. st.button is True only on the rerun caused
    # by the click; a selectbox nested under it disappears on the next interaction, and on
    # the click rerun it always reports its first option, so the user could never choose.
    classes = None
    if selected_category == "Brain Tumor":
        classes = st.selectbox(
            "Select Number of Classes",
            ['44 Classes', '17 Classes', '15 Classes', '2 Classes'],
        )

    if not st.button("Classify"):
        return

    if selected_category == "Brain Tumor":
        model = load_brain_tumor_model(classes)
        # TODO: placeholder labels — this list must name every output of the selected
        # model (44/17/15/2 classes) in output order.
        result_classes = ['Astrocitoma', 'Carcinoma', 'Ependimoma', '_NORMAL', 'etc...']
        result = predict(uploaded_file, model, result_classes)
        st.success(f"Prediction: {result}")
    elif selected_category == "Alzheimer":
        model = load_alzheimer_model()
        result_classes = ['Mild_Demented', 'Moderate_Demented', 'Non_Demented']
        result = predict(uploaded_file, model, result_classes)
        st.success(f"Prediction: {result}")
    elif selected_category == "Skin Disease":
        model = load_skin_model()
        # RGB conversion: PNG uploads may carry an alpha channel or be grayscale/palette.
        img = Image.open(uploaded_file).convert("RGB")
        predictions = predict_skin_lesion(img, model)
        for idx, (label, confidence) in enumerate(predictions, 1):
            st.write(f"{idx}. {label}: {confidence:.2f}%")
    elif selected_category == "Eye Disease":
        # No eye-disease model is wired up yet; tell the user instead of doing nothing.
        st.info("Eye disease prediction is not available yet.")
# Main Function to Run the App
def main():
    """Render the navigation menu and dispatch to the page for the selected option."""
    selected_category = spr_sidebar()
    if selected_category in ("Brain Tumor", "Alzheimer", "Skin Disease", "Eye Disease"):
        home_page(selected_category)
    elif selected_category == "About":
        # about_page is not defined in this file, so calling it directly raised NameError.
        # Use it if it is added later; otherwise fall back to a minimal built-in page.
        about_page_fn = globals().get("about_page")
        if callable(about_page_fn):
            about_page_fn()
        else:
            st.title("About")
            st.write(
                "Upload a medical image on one of the disease pages to get a model "
                "prediction. Results are not a medical diagnosis."
            )


if __name__ == '__main__':
    main()