File size: 4,018 Bytes
bd5d744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
from PIL import Image
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import logging
import base64
from io import BytesIO

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face repository the model and its image processor are loaded from
repository_id = "EnDevSols/brainmri-vit-model"


# Streamlit re-executes this script from the top on every widget interaction
# (including each file upload), so loading must be cached or the model would
# be re-read on every rerun. show_spinner=False keeps the cache layer from
# emitting a UI element before st.set_page_config() is called further down.
@st.cache_resource(show_spinner=False)
def _load_model_and_processor(repo_id):
    """Load the ViT classifier and its image processor from *repo_id*.

    Args:
        repo_id: Hugging Face repository identifier passed to
            ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple. Because of ``st.cache_resource`` the
        same objects are handed back on subsequent script reruns.
    """
    logging.info("Loading model and image processor from %s", repo_id)
    return (
        ViTForImageClassification.from_pretrained(repo_id),
        ViTImageProcessor.from_pretrained(repo_id),
    )


model, feature_extractor = _load_model_and_processor(repository_id)

# Run the classifier on one image and phrase the outcome as a sentence
def predict(image) -> str:
    """Classify *image* and return a human-readable diagnosis sentence.

    The image is converted to RGB, preprocessed with the module-level
    ``feature_extractor`` and passed through the module-level ``model`` on
    CUDA when available, otherwise on CPU. The highest-scoring class index
    is mapped as 0 -> "No" and 1 -> "Yes".

    Args:
        image: Object exposing ``convert("RGB")``
            (presumably a PIL image -- confirm against callers).

    Returns:
        One of two fixed sentences stating whether a tumor was detected.

    Raises:
        KeyError: If the predicted class index is neither 0 nor 1.
    """
    rgb_image = image.convert("RGB")
    encoded = feature_extractor(images=rgb_image, return_tensors="pt")

    # Model and input tensors must live on the same device.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)
    batch = {name: tensor.to(target) for name, tensor in encoded.items()}

    # Gradients are not needed for inference.
    with torch.no_grad():
        result = model(**batch)

    class_index = result.logits.argmax(-1).item()
    answer = {0: "No", 1: "Yes"}[class_index]

    sentences = {
        "Yes": "The diagnosis indicates that you have a brain tumor.",
        "No": "The diagnosis indicates that you do not have a brain tumor.",
    }
    return sentences[answer]

# Inject custom CSS into the page
def set_css(style):
    """Render *style* inside a ``<style>`` element via ``st.markdown``.

    Args:
        style: CSS text to embed; it is interpolated into the tag unchanged.
    """
    style_tag = "<style>{}</style>".format(style)
    st.markdown(style_tag, unsafe_allow_html=True)

# Combined dark mode styles.
# One CSS string that is embedded in a <style> element when the page is built.
# It covers the page/sidebar background, the main block container, regular and
# download buttons, the spinner colour, and the helper classes used by the
# header markup (.title, .colorful-text, .black-white-text, .custom-text).
# NOTE(review): .small-input and .small-file-uploader are not referenced by any
# markup visible in this file -- confirm whether they are still needed.
combined_css = """
    .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
    .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
    .stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
    .stSpinner { color: #4CAF50; }
    .title {
        font-size: 3rem;
        font-weight: bold;
        display: flex; 
        align-items: center; 
        justify-content: center;
    }
    .colorful-text {
        background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .black-white-text {
        color: black;
    }
    .small-input .stTextInput>div>input {
        height: 2rem;
        font-size: 0.9rem;
    }
    .small-file-uploader .stFileUploader>div>div {
        height: 2rem;
        font-size: 0.9rem;
    }
    .custom-text {
        font-size: 1.2rem;
        color: #feb47b;
        text-align: center;
        margin-top: -20px;
        margin-bottom: 20px;
    }
"""

# Streamlit application: page setup, header, upload widget, preview and result.
st.set_page_config(layout="wide")

# Use the shared helper instead of repeating the <style> wrapping inline.
set_css(combined_css)

st.markdown('<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>', unsafe_allow_html=True)
st.markdown('<div class="custom-text">Upload an MRI image to detect brain tumor</div>', unsafe_allow_html=True)

# Uploading image
uploaded_file = st.file_uploader("Choose an image...", type="jpg")

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open only reads the header; force a full decode here so a
        # truncated or corrupt upload is reported cleanly instead of raising
        # later while the preview or the prediction is being produced.
        image.load()
    except OSError:
        # PIL's UnidentifiedImageError is an OSError subclass, so this covers
        # both "not an image" and "image data is damaged".
        logging.exception("Could not decode uploaded file %r", uploaded_file.name)
        st.error("The uploaded file could not be read as an image. Please upload a valid JPG file.")
        st.stop()

    # Build the preview on an RGB copy so the image passed to predict() is
    # left untouched. thumbnail() keeps the aspect ratio and never upscales,
    # so the preview is neither squashed nor blurred.
    preview = image.convert("RGB")
    preview.thumbnail((300, 300))

    # Convert the preview to base64 so it can be embedded as a data URI
    buffered = BytesIO()
    preview.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # Display the image in the center at its actual preview width
    st.markdown(f"<div style='text-align: center;'><img src='data:image/jpeg;base64,{img_str}' alt='Uploaded Image' width='{preview.width}'></div>", unsafe_allow_html=True)

    st.write("")
    st.write("Result...")

    diagnosis = predict(image)
    st.write(diagnosis)