Tumor-Detection / app.py
Hammad712's picture
Update app.py
f01141e verified
raw
history blame
6.47 kB
import streamlit as st
from PIL import Image
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import logging
import base64
from io import BytesIO
from groq import Groq # Import the Groq client for Deepseek R1 API
# ------------------ Setup Logging ------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# ------------------ Load the ViT Model ------------------
repository_id = "EnDevSols/brainmri-vit-model"


# Streamlit re-executes this script from the top on every user interaction.
# Caching the loader keeps a single model/processor pair alive per server
# process instead of re-instantiating both on every rerun.
# show_spinner=False: nothing may be rendered before st.set_page_config(),
# which is called later in this script.
@st.cache_resource(show_spinner=False)
def _load_vit(repo_id):
    """
    Load the ViT classifier and its image processor from `repo_id`.

    Returns a `(model, processor)` tuple created with `from_pretrained`.
    """
    vit_model = ViTForImageClassification.from_pretrained(repo_id)
    vit_processor = ViTImageProcessor.from_pretrained(repo_id)
    return vit_model, vit_processor


# Module-level names used by predict().
model, feature_extractor = _load_vit(repository_id)
# ------------------ ViT Inference Function ------------------
def predict(image):
    """
    Classify an MRI image with the ViT model and phrase the result for the user.

    Args:
        image: An image object exposing ``convert("RGB")`` (the app passes the
            PIL image opened from the upload).

    Returns:
        A sentence stating whether the model predicts a brain tumor.

    Raises:
        KeyError: If the predicted class index is neither 0 nor 1.
    """
    # Preprocess: force three channels, then build the model's tensor inputs.
    encoded = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Run on the GPU when one is present, otherwise on the CPU.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)
    batch = {name: tensor.to(target) for name, tensor in encoded.items()}

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(**batch).logits

    # Class 0 means "no tumor", class 1 means "tumor".
    class_index = logits.argmax(-1).item()
    sentences = {
        0: "The diagnosis indicates that you do not have a brain tumor.",
        1: "The diagnosis indicates that you have a brain tumor.",
    }
    return sentences[class_index]
# ------------------ Deepseek R1 Assistance Function ------------------
def get_assistance_from_deepseek(diagnosis_text):
    """
    Ask the Deepseek R1 model (via the Groq API) for next steps after a diagnosis.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable so
    that no secret is stored in the source file.

    Args:
        diagnosis_text: The diagnosis sentence to embed in the prompt.

    Returns:
        The text of the first completion choice returned by the API.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is not set or is empty.
    """
    import os  # local import: keeps this block self-contained

    # NOTE(review): any key that was previously committed to this file should
    # be treated as leaked and revoked in the Groq console.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; "
            "cannot request guidance from the Groq API."
        )
    client = Groq(api_key=api_key)

    # Construct a prompt that includes the diagnosis and asks for detailed guidance
    prompt = (
        f"Based on the following diagnosis: '{diagnosis_text}', please provide next steps and "
        "recommendations for the patient. Include whether to consult a specialist, if further tests "
        "are needed, and any other immediate actions or lifestyle recommendations."
    )
    messages = [
        {
            "role": "system",
            "content": "You are a helpful medical assistant providing guidance after a brain tumor diagnosis."
        },
        {"role": "user", "content": prompt}
    ]

    # Create the completion using the Deepseek R1 model (non-streaming for simplicity)
    completion = client.chat.completions.create(
        model="deepseek-r1-distill-llama-70b",
        messages=messages,
        temperature=0.6,
        max_completion_tokens=4096,
        top_p=0.95,
        stream=False,
        stop=None,
    )

    # Extract the response text from the first choice.
    first_choice = completion.choices[0]
    try:
        assistance_text = first_choice.message.content
    except AttributeError:
        # Fallback in case the response exposes plain text instead of a message
        assistance_text = first_choice.text
    return assistance_text
# ------------------ Custom CSS for Styling ------------------
# Stylesheet text for the page: dark background for the main area and sidebar,
# gradient-styled buttons, and the classes used by the title/subtitle markup
# (.title, .colorful-text, .black-white-text, .custom-text).
# The string holds raw CSS only; it carries no <style> wrapper of its own.
combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
.stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
.stSpinner { color: #4CAF50; }
.title {
font-size: 3rem;
font-weight: bold;
display: flex;
align-items: center;
justify-content: center;
}
.colorful-text {
background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.black-white-text {
color: black;
}
.custom-text {
font-size: 1.2rem;
color: #feb47b;
text-align: center;
margin-top: -20px;
margin-bottom: 20px;
}
"""
# ------------------ Streamlit App Configuration ------------------
# Use the wide layout, then inject the custom stylesheet into the page.
st.set_page_config(layout="wide")
_style_tag = "<style>" + combined_css + "</style>"
st.markdown(_style_tag, unsafe_allow_html=True)

# Page heading followed by a one-line description, both rendered as raw HTML.
_title_html = (
    '<div class="title"><span class="colorful-text">Brain MRI</span> '
    '<span class="black-white-text">Tumor Detection</span></div>'
)
_subtitle_html = (
    '<div class="custom-text">Upload an MRI image to detect a brain tumor '
    'and receive next steps and recommendations.</div>'
)
for _html in (_title_html, _subtitle_html):
    st.markdown(_html, unsafe_allow_html=True)
# ------------------ Image Upload Section ------------------
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so a corrupt file fails here rather
        # than in the middle of the preview or inference code.
        image.load()
    except OSError:
        # Pillow reports unidentified or truncated images as OSError subclasses.
        logging.exception("Could not read the uploaded image")
        st.error("The uploaded file could not be read as an image. Please upload a valid JPG or PNG file.")
    else:
        # Resize image for display purposes
        resized_image = image.resize((150, 150))
        # JPEG cannot store alpha or palette modes (e.g. RGBA/P/LA PNG uploads),
        # so convert those before encoding; RGB and grayscale pass through as-is.
        if resized_image.mode not in ("RGB", "L"):
            resized_image = resized_image.convert("RGB")
        # Convert image to base64 for HTML display
        buffered = BytesIO()
        resized_image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        # Display the uploaded image in the center
        st.markdown(
            f"<div style='text-align: center;'><img src='data:image/jpeg;base64,{img_str}' alt='Uploaded Image' width='300'></div>",
            unsafe_allow_html=True
        )
        st.write("")
        st.write("Processing the image...")
        # ------------------ Step 1: Get Diagnosis from the ViT Model ------------------
        diagnosis = predict(image)
        st.markdown("### Diagnosis:")
        st.write(diagnosis)
        # ------------------ Step 2: Get Further Assistance from Deepseek R1 ------------------
        with st.spinner("Fetching additional guidance based on your diagnosis..."):
            try:
                assistance = get_assistance_from_deepseek(diagnosis)
            except Exception:
                # Top-level UI boundary: a failed guidance request must not
                # take down the page after the diagnosis has been shown.
                logging.exception("Failed to fetch guidance for the diagnosis")
                assistance = None
        st.markdown("### Next Steps and Recommendations:")
        if assistance is None:
            st.error("Additional guidance could not be retrieved right now. Please try again later.")
        else:
            st.write(assistance)