File size: 6,467 Bytes
d7767a0
 
 
 
 
 
 
f01141e
d7767a0
f01141e
d7767a0
 
f01141e
d7767a0
 
 
 
f01141e
d7767a0
f01141e
 
 
 
 
d7767a0
 
f01141e
 
d7767a0
 
 
f01141e
 
d7767a0
 
f01141e
 
d7767a0
 
 
 
f01141e
d7767a0
 
 
 
 
f01141e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7767a0
f01141e
d7767a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f01141e
d7767a0
 
 
f01141e
 
 
 
 
 
 
 
 
d7767a0
f01141e
 
d7767a0
 
 
 
f01141e
d7767a0
 
f01141e
d7767a0
 
 
 
f01141e
 
 
 
 
d7767a0
 
f01141e
d7767a0
f01141e
d7767a0
f01141e
d7767a0
f01141e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import base64
import logging
import os
import re
from io import BytesIO

import streamlit as st
import torch
from groq import Groq  # Import the Groq client for Deepseek R1 API
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# ------------------ Setup Logging ------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ------------------ Load the ViT Model ------------------
repository_id = "EnDevSols/brainmri-vit-model"


# Streamlit re-runs this whole script on every interaction; without caching the
# ViT weights would be deserialized again on each rerun. cache_resource keeps one
# shared instance per process. show_spinner=False so no element is rendered
# before st.set_page_config() further down.
@st.cache_resource(show_spinner=False)
def _load_vit(repo_id):
    """Load and return the (model, image processor) pair for a ViT checkpoint.

    Args:
        repo_id: Hugging Face Hub repository id (or local path) of the checkpoint.

    Returns:
        A tuple ``(ViTForImageClassification, ViTImageProcessor)``.
    """
    return (
        ViTForImageClassification.from_pretrained(repo_id),
        ViTImageProcessor.from_pretrained(repo_id),
    )


model, feature_extractor = _load_vit(repository_id)

# ------------------ ViT Inference Function ------------------
def predict(image):
    """
    Run the ViT brain-tumor classifier on one image and phrase the verdict.

    Args:
        image: A PIL image; it is converted to RGB before preprocessing.

    Returns:
        An English sentence stating whether the model predicts a brain tumor.
    """
    rgb_image = image.convert("RGB")
    batch = feature_extractor(images=rgb_image, return_tensors="pt")

    # Move the model and the preprocessed tensors to CUDA when available, else CPU.
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(target)
    batch = {name: tensor.to(target) for name, tensor in batch.items()}

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        scores = model(**batch).logits

    class_index = scores.argmax(-1).item()
    # Class 0 means "No" and class 1 means "Yes"; any other index raises KeyError.
    verdict = {0: "No", 1: "Yes"}[class_index]

    if verdict != "Yes":
        return "The diagnosis indicates that you do not have a brain tumor."
    return "The diagnosis indicates that you have a brain tumor."

# ------------------ Deepseek R1 Assistance Function ------------------
def get_assistance_from_deepseek(diagnosis_text):
    """
    Ask the Deepseek R1 model (via the Groq API) for next steps after a diagnosis.

    The API key is read from the ``GROQ_API_KEY`` environment variable. Streamlit
    also exposes root-level entries of ``.streamlit/secrets.toml`` as environment
    variables, so either mechanism can supply it.

    Args:
        diagnosis_text: The human-readable diagnosis sentence (e.g. from ``predict``).

    Returns:
        The model's recommendations as a string, with any ``<think>...</think>``
        reasoning section removed and surrounding whitespace stripped. Returns an
        empty string when the model produced no content.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is not set.
        Errors raised by the Groq client (network, authentication, rate limits)
        propagate unchanged.
    """
    # Credentials must not live in source. A key used to be hard-coded here;
    # treat it as leaked and rotate it.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY is not set. Export it or add it to .streamlit/secrets.toml."
        )
    client = Groq(api_key=api_key)

    # Construct a prompt that includes the diagnosis and asks for detailed guidance
    prompt = (
        f"Based on the following diagnosis: '{diagnosis_text}', please provide next steps and "
        "recommendations for the patient. Include whether to consult a specialist, if further tests "
        "are needed, and any other immediate actions or lifestyle recommendations."
    )

    messages = [
        {
            "role": "system",
            "content": "You are a helpful medical assistant providing guidance after a brain tumor diagnosis."
        },
        {"role": "user", "content": prompt}
    ]

    # Create the completion using the Deepseek R1 model (non-streaming for simplicity)
    completion = client.chat.completions.create(
        model="deepseek-r1-distill-llama-70b",
        messages=messages,
        temperature=0.6,
        max_completion_tokens=4096,
        top_p=0.95,
        stream=False,
        stop=None,
    )

    # Chat completions carry the text in choice.message.content; fall back to a
    # legacy-style choice.text if no message object is present.
    choice = completion.choices[0]
    message = getattr(choice, "message", None)
    if message is not None:
        assistance_text = message.content
    else:
        assistance_text = getattr(choice, "text", None)
    if not assistance_text:
        return ""

    # DeepSeek-R1 models typically emit their chain of thought inside
    # <think>...</think> before the answer; that is not patient-facing advice.
    assistance_text = re.sub(r"<think>.*?</think>", "", assistance_text, flags=re.DOTALL)
    return assistance_text.strip()

# ------------------ Custom CSS for Styling ------------------
# Stylesheet injected into the page via st.markdown("<style>...</style>",
# unsafe_allow_html=True) in the app configuration section.
# - .main / .sidebar / .block-container / .stButton / .stDownloadButton / .stSpinner
#   restyle Streamlit's own elements (dark background, gradient buttons).
# - .title, .colorful-text, .black-white-text and .custom-text are applied by the
#   title and description markup rendered by the app.
# NOTE(review): these selectors target Streamlit's internal class names, which can
#   change between Streamlit releases; verify they still match the installed version.
# NOTE(review): .black-white-text sets black text while .main/.block-container use
#   dark backgrounds (#1c1c1c / #333), so "Tumor Detection" may be hard to read.
#   Confirm the intended theme.
combined_css = """
    .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
    .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
    .stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
    .stSpinner { color: #4CAF50; }
    .title {
        font-size: 3rem;
        font-weight: bold;
        display: flex; 
        align-items: center; 
        justify-content: center;
    }
    .colorful-text {
        background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .black-white-text {
        color: black;
    }
    .custom-text {
        font-size: 1.2rem;
        color: #feb47b;
        text-align: center;
        margin-top: -20px;
        margin-bottom: 20px;
    }
"""

# ------------------ Streamlit App Configuration ------------------
st.set_page_config(layout="wide")
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)

# App Title and Description
st.markdown(
    '<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>',
    unsafe_allow_html=True
)
st.markdown(
    '<div class="custom-text">Upload an MRI image to detect a brain tumor and receive next steps and recommendations.</div>',
    unsafe_allow_html=True
)

# ------------------ Image Upload Section ------------------
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Image.open is lazy; load() forces decoding so a corrupt or non-image upload
    # is reported here as a user-facing error instead of a traceback later.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("The uploaded file could not be read as an image.")
        st.stop()

    # The preview is encoded as JPEG, which cannot store alpha or palette modes
    # (e.g. RGBA or P-mode PNG uploads), so convert to RGB first. `image` itself
    # is passed to the model unchanged.
    resized_image = image.convert("RGB").resize((150, 150))

    # Convert image to base64 for HTML display
    buffered = BytesIO()
    resized_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # Display the uploaded image in the center
    st.markdown(
        f"<div style='text-align: center;'><img src='data:image/jpeg;base64,{img_str}' alt='Uploaded Image' width='300'></div>",
        unsafe_allow_html=True
    )

    st.write("")
    st.write("Processing the image...")

    # ------------------ Step 1: Get Diagnosis from the ViT Model ------------------
    diagnosis = predict(image)
    st.markdown("### Diagnosis:")
    st.write(diagnosis)

    # ------------------ Step 2: Get Further Assistance from Deepseek R1 ------------------
    # Top-level UI boundary: a network/auth/config failure in the remote call
    # should not crash the page after the diagnosis has already been shown.
    try:
        with st.spinner("Fetching additional guidance based on your diagnosis..."):
            assistance = get_assistance_from_deepseek(diagnosis)
    except Exception:
        logging.exception("Failed to fetch guidance from Deepseek R1")
        st.error("Could not fetch additional guidance right now. Please try again later.")
    else:
        st.markdown("### Next Steps and Recommendations:")
        st.write(assistance)