Spaces:
Sleeping
Sleeping
File size: 6,467 Bytes
d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e d7767a0 f01141e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import streamlit as st
from PIL import Image
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import logging
import base64
from io import BytesIO
from groq import Groq # Import the Groq client for Deepseek R1 API
# ------------------ Setup Logging ------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ------------------ Load the ViT Model ------------------
repository_id = "EnDevSols/brainmri-vit-model"


@st.cache_resource(show_spinner=False)
def _load_vit_resources(repo_id):
    """
    Load the ViT classifier and its image processor from ``repo_id``.

    Streamlit re-executes this script on every user interaction, so an
    uncached module-level ``from_pretrained`` call would reload the model on
    each rerun. ``st.cache_resource`` keeps a single shared copy instead.
    ``show_spinner=False`` avoids rendering a UI element from this call, which
    runs before ``st.set_page_config`` further down the script.

    Returns:
        A ``(model, image_processor)`` tuple.
    """
    return (
        ViTForImageClassification.from_pretrained(repo_id),
        ViTImageProcessor.from_pretrained(repo_id),
    )


# Module-level names are kept because predict() reads them as globals.
model, feature_extractor = _load_vit_resources(repository_id)
# ------------------ ViT Inference Function ------------------
def predict(image):
    """
    Run the ViT classifier on an MRI image and phrase the result for the user.

    The image is converted to RGB, preprocessed with the module-level
    ``feature_extractor`` and classified by the module-level ``model`` on the
    GPU when one is available (CPU otherwise). Class index 0 is reported as
    "no tumor" and index 1 as "tumor"; any other index raises ``KeyError``.

    Returns:
        A human-readable diagnosis sentence.
    """
    # Preprocess: the processor receives an RGB image and returns PyTorch tensors.
    rgb_image = image.convert("RGB")
    encoded = feature_extractor(images=rgb_image, return_tensors="pt")

    # Put the model and every input tensor on the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        result = model(**encoded)

    # The highest-scoring class index selects the diagnosis.
    class_index = result.logits.argmax(-1).item()
    label_map = {0: "No", 1: "Yes"}
    if label_map[class_index] == "Yes":
        return "The diagnosis indicates that you have a brain tumor."
    return "The diagnosis indicates that you do not have a brain tumor."
# ------------------ Deepseek R1 Assistance Function ------------------
def get_assistance_from_deepseek(diagnosis_text):
    """
    Ask the Deepseek R1 model, via the Groq API, for next steps after a diagnosis.

    The API key is read from the ``GROQ_API_KEY`` environment variable rather
    than being embedded in the source, so the credential never ends up in the
    repository.

    Args:
        diagnosis_text: Diagnosis sentence to embed in the prompt.

    Returns:
        The text of the first completion choice.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is unset or empty.
    """
    # Local import: keeps this function self-contained without touching the
    # file-level import block.
    import os

    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY is not set. Provide the Groq API key through an "
            "environment variable or deployment secret."
        )
    client = Groq(api_key=api_key)

    # Construct a prompt that includes the diagnosis and asks for detailed guidance
    prompt = (
        f"Based on the following diagnosis: '{diagnosis_text}', please provide next steps and "
        "recommendations for the patient. Include whether to consult a specialist, if further tests "
        "are needed, and any other immediate actions or lifestyle recommendations."
    )
    messages = [
        {
            "role": "system",
            "content": "You are a helpful medical assistant providing guidance after a brain tumor diagnosis.",
        },
        {"role": "user", "content": prompt},
    ]

    # Create the completion using the Deepseek R1 model (non-streaming for simplicity)
    completion = client.chat.completions.create(
        model="deepseek-r1-distill-llama-70b",
        messages=messages,
        temperature=0.6,
        max_completion_tokens=4096,
        top_p=0.95,
        stream=False,
        stop=None,
    )

    # Chat responses carry the text in `message.content`; fall back to `.text`
    # in case the response object has a different structure.
    first_choice = completion.choices[0]
    try:
        return first_choice.message.content
    except AttributeError:
        return first_choice.text
# ------------------ Custom CSS for Styling ------------------
# Stylesheet text for the app: dark page/container colors, gradient buttons,
# and the .title / .colorful-text / .black-white-text / .custom-text classes
# used by the HTML header snippets. It is embedded in a <style> tag when the
# page is rendered, so the string content is runtime data and must stay as-is.
combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
.stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
.stSpinner { color: #4CAF50; }
.title {
font-size: 3rem;
font-weight: bold;
display: flex;
align-items: center;
justify-content: center;
}
.colorful-text {
background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.black-white-text {
color: black;
}
.custom-text {
font-size: 1.2rem;
color: #feb47b;
text-align: center;
margin-top: -20px;
margin-bottom: 20px;
}
"""
# ------------------ Streamlit App Configuration ------------------
# Use the wide layout, then inject the custom stylesheet into the page.
st.set_page_config(layout="wide")
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)

# App title and description, emitted as raw HTML so the CSS classes apply.
title_html = (
    '<div class="title"><span class="colorful-text">Brain MRI</span> '
    '<span class="black-white-text">Tumor Detection</span></div>'
)
description_html = (
    '<div class="custom-text">Upload an MRI image to detect a brain tumor '
    'and receive next steps and recommendations.</div>'
)
for header_snippet in (title_html, description_html):
    st.markdown(header_snippet, unsafe_allow_html=True)
# ------------------ Image Upload Section ------------------
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)

    # Build a small preview for display. JPEG cannot store alpha or palette
    # data, so the thumbnail is converted to RGB before saving; without this,
    # RGBA or palette PNG uploads make save() raise.
    resized_image = image.resize((150, 150)).convert("RGB")

    # Convert the preview to base64 so it can be embedded in an <img> tag
    buffered = BytesIO()
    resized_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    # Display the uploaded image in the center
    st.markdown(
        f"<div style='text-align: center;'><img src='data:image/jpeg;base64,{img_str}' alt='Uploaded Image' width='300'></div>",
        unsafe_allow_html=True
    )
    st.write("")
    st.write("Processing the image...")

    # ------------------ Step 1: Get Diagnosis from the ViT Model ------------------
    # The original (unresized) image is classified; predict() does its own
    # RGB conversion and preprocessing.
    diagnosis = predict(image)
    st.markdown("### Diagnosis:")
    st.write(diagnosis)

    # ------------------ Step 2: Get Further Assistance from Deepseek R1 ------------------
    # The guidance request depends on an external service. A failure there
    # should not take down the page after the diagnosis has been shown, so it
    # is logged and reported to the user instead of propagating.
    guidance_failed = False
    assistance = None
    with st.spinner("Fetching additional guidance based on your diagnosis..."):
        try:
            assistance = get_assistance_from_deepseek(diagnosis)
        except Exception:
            logging.exception("Failed to fetch guidance for the diagnosis")
            guidance_failed = True

    st.markdown("### Next Steps and Recommendations:")
    if guidance_failed:
        st.error("Additional guidance is currently unavailable. Please try again later.")
    else:
        st.write(assistance)
|