Codewithsalty commited on
Commit
5760faf
Β·
verified Β·
1 Parent(s): a609c19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -77
app.py CHANGED
@@ -1,77 +1,81 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import transforms
5
- from PIL import Image
6
- from TumorModel import TumorClassification, GliomaStageModel
7
-
8
- # Title
9
- st.set_page_config(page_title="Brain Tumor Predictor", layout="centered")
10
- st.title("🧠 Brain Tumor Classifier & Glioma Stage Predictor")
11
-
12
- # Load models
13
- @st.cache_resource
14
- def load_models():
15
- tumor_model = TumorClassification()
16
- tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
17
- tumor_model.eval()
18
-
19
- glioma_model = GliomaStageModel()
20
- glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
21
- glioma_model.eval()
22
-
23
- return tumor_model, glioma_model
24
-
25
- tumor_model, glioma_model = load_models()
26
-
27
- # Labels
28
- tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
29
- stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
30
-
31
- # Image Preprocessing
32
- transform = transforms.Compose([
33
- transforms.Grayscale(),
34
- transforms.Resize((224, 224)),
35
- transforms.ToTensor(),
36
- transforms.Normalize([0.5], [0.5])
37
- ])
38
-
39
- # Tabs
40
- tab1, tab2 = st.tabs(["🧠 Tumor Type Detection", "🧬 Glioma Stage Prediction"])
41
-
42
- # Tab 1: Tumor Classification
43
- with tab1:
44
- st.header("Upload Brain MRI")
45
- image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
46
- if image_file:
47
- image = Image.open(image_file).convert("L")
48
- st.image(image, caption="Uploaded Image", use_column_width=True)
49
-
50
- image_tensor = transform(image).unsqueeze(0)
51
- with torch.no_grad():
52
- prediction = tumor_model(image_tensor)
53
- pred_label = tumor_labels[torch.argmax(prediction)]
54
- st.success(f"🧠 Tumor Prediction: **{pred_label.upper()}**")
55
-
56
- # Tab 2: Glioma Stage Prediction
57
- with tab2:
58
- st.header("Patient & Genetic Info")
59
- gender = st.radio("Gender", ["Male", "Female"])
60
- age = st.slider("Age", 1, 100, 25)
61
-
62
- st.subheader("Gene Mutations (0 = No, 1 = Yes)")
63
- idh1 = st.radio("IDH1", [0, 1], horizontal=True)
64
- tp53 = st.radio("TP53", [0, 1], horizontal=True)
65
- atrx = st.radio("ATRX", [0, 1], horizontal=True)
66
- pten = st.radio("PTEN", [0, 1], horizontal=True)
67
- egfr = st.radio("EGFR", [0, 1], horizontal=True)
68
- cic = st.radio("CIC", [0, 1], horizontal=True)
69
- pik3ca = st.radio("PIK3CA", [0, 1], horizontal=True)
70
-
71
- if st.button("πŸ” Predict Glioma Stage"):
72
- gender_val = 0 if gender == "Male" else 1
73
- features = torch.tensor([[gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]], dtype=torch.float32)
74
- with torch.no_grad():
75
- prediction = glioma_model(features)
76
- stage = stage_labels[torch.argmax(prediction)]
77
- st.success(f"🧬 Predicted Glioma Stage: **{stage}**")
 
 
 
 
 
1
+ import os
2
+
3
+ # βœ… Fix permission issues on Hugging Face Spaces
4
+ os.environ["STREAMLIT_HOME"] = "/home/user/app"
5
+ os.environ["XDG_CONFIG_HOME"] = "/home/user/app"
6
+
7
+ import streamlit as st
8
+ import torch
9
+ from torchvision import transforms
10
+ from PIL import Image
11
+ from TumorModel import TumorClassification, GliomaStageModel
12
+
13
+ # 🎯 Load classification model
14
+ tumor_model = TumorClassification()
15
+ tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
16
+ tumor_model.eval()
17
+
18
+ # 🎯 Load glioma stage model
19
+ glioma_model = GliomaStageModel()
20
+ glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
21
+ glioma_model.eval()
22
+
23
+ # πŸ“Œ Class labels
24
+ tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
25
+ stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
26
+
27
+ # πŸ” Image Transform
28
+ transform = transforms.Compose([
29
+ transforms.Grayscale(),
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.5], std=[0.5])
33
+ ])
34
+
35
+ # 🧠 Tumor prediction
36
+ def predict_tumor(image: Image.Image) -> str:
37
+ image = transform(image).unsqueeze(0)
38
+ with torch.no_grad():
39
+ output = tumor_model(image)
40
+ pred = torch.argmax(output, dim=1).item()
41
+ return tumor_labels[pred]
42
+
43
+ # 🧬 Glioma stage prediction
44
+ def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
45
+ gender_val = 0 if gender == "Male" else 1
46
+ features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
47
+ x = torch.tensor(features).float().unsqueeze(0)
48
+ with torch.no_grad():
49
+ out = glioma_model(x)
50
+ pred = torch.argmax(out, dim=1).item()
51
+ return stage_labels[pred]
52
+
53
+ # 🎨 Streamlit UI
54
+ st.title("🧠 Brain Tumor Detection & Glioma Stage Predictor")
55
+
56
+ tab1, tab2 = st.tabs(["Tumor Type Detector", "Glioma Stage Predictor"])
57
+
58
+ with tab1:
59
+ st.header("πŸ–ΌοΈ Upload an MRI Image")
60
+ uploaded_file = st.file_uploader("Choose an MRI image", type=["jpg", "jpeg", "png"])
61
+ if uploaded_file:
62
+ image = Image.open(uploaded_file).convert("RGB")
63
+ st.image(image, caption="Uploaded MRI", use_column_width=True)
64
+ prediction = predict_tumor(image)
65
+ st.success(f"🧠 Predicted Tumor Type: **{prediction.upper()}**")
66
+
67
+ with tab2:
68
+ st.header("🧬 Patient Genetic and Demographic Information")
69
+ gender = st.radio("Gender", ["Male", "Female"])
70
+ age = st.slider("Age", 1, 100, 30)
71
+ idh1 = st.slider("IDH1 Mutation", 0, 1, 0)
72
+ tp53 = st.slider("TP53 Mutation", 0, 1, 0)
73
+ atrx = st.slider("ATRX Mutation", 0, 1, 0)
74
+ pten = st.slider("PTEN Mutation", 0, 1, 0)
75
+ egfr = st.slider("EGFR Mutation", 0, 1, 0)
76
+ cic = st.slider("CIC Mutation", 0, 1, 0)
77
+ pik3ca = st.slider("PIK3CA Mutation", 0, 1, 0)
78
+
79
+ if st.button("Predict Glioma Stage"):
80
+ stage = predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca)
81
+ st.success(f"πŸ“Š Predicted Glioma Stage: **{stage}**")