Codewithsalty commited on
Commit
c73e079
·
verified ·
1 Parent(s): a0abc97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -79
app.py CHANGED
@@ -1,79 +1,79 @@
1
- import torch
2
- import gradio as gr
3
- from PIL import Image
4
- from torchvision import transforms
5
- from TumorModel import TumorClassification, GliomaStageModel
6
-
7
- # Load tumor classification model
8
- tumor_model = TumorClassification()
9
- tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
10
- tumor_model.eval()
11
-
12
- # Load glioma stage model
13
- glioma_model = GliomaStageModel()
14
- glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
15
- glioma_model.eval()
16
-
17
- # Labels
18
- tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
19
- stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
20
-
21
- # Image transform
22
- transform = transforms.Compose([
23
- transforms.Grayscale(),
24
- transforms.Resize((224, 224)),
25
- transforms.ToTensor(),
26
- transforms.Normalize(mean=[0.5], std=[0.5])
27
- ])
28
-
29
- # Tumor type prediction
30
- def predict_tumor(image):
31
- image = transform(image).unsqueeze(0)
32
- with torch.no_grad():
33
- out = tumor_model(image)
34
- pred = torch.argmax(out, dim=1).item()
35
- return tumor_labels[pred]
36
-
37
- # Glioma stage prediction
38
- def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
39
- gender_val = 0 if gender == "Male" else 1
40
- features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
41
- x = torch.tensor(features).float().unsqueeze(0)
42
- with torch.no_grad():
43
- out = glioma_model(x)
44
- pred = torch.argmax(out, dim=1).item()
45
- return stage_labels[pred]
46
-
47
- # Interface 1: Tumor Detector
48
- tumor_tab = gr.Interface(
49
- fn=predict_tumor,
50
- inputs=gr.Image(type="pil"),
51
- outputs=gr.Label(),
52
- title="🧠 Brain Tumor Detection",
53
- description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
54
- )
55
-
56
- # Interface 2: Glioma Stage Prediction
57
- stage_tab = gr.Interface(
58
- fn=predict_stage,
59
- inputs=[
60
- gr.Radio(["Male", "Female"], label="Gender"),
61
- gr.Slider(0, 100, label="Age"),
62
- gr.Slider(0, 1, step=1, label="IDH1 Mutation"),
63
- gr.Slider(0, 1, step=1, label="TP53 Mutation"),
64
- gr.Slider(0, 1, step=1, label="ATRX Mutation"),
65
- gr.Slider(0, 1, step=1, label="PTEN Mutation"),
66
- gr.Slider(0, 1, step=1, label="EGFR Mutation"),
67
- gr.Slider(0, 1, step=1, label="CIC Mutation"),
68
- gr.Slider(0, 1, step=1, label="PIK3CA Mutation")
69
- ],
70
- outputs=gr.Label(),
71
- title="🧬 Glioma Stage Predictor",
72
- description="Enter patient’s mutation and demographic data to predict glioma stage."
73
- )
74
-
75
- # Combine in tabs
76
- demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
77
-
78
- # Run app
79
- demo.launch()
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from TumorModel import TumorClassification, GliomaStageModel
6
+
7
+ # Load tumor classification model
8
+ tumor_model = TumorClassification()
9
+ tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
10
+ tumor_model.eval()
11
+
12
+ # Load glioma stage model
13
+ glioma_model = GliomaStageModel()
14
+ glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
15
+ glioma_model.eval()
16
+
17
+ # Labels
18
+ tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
19
+ stage_labels = ['Stage 1', 'Stage 2'] # Only 2 classes in your model
20
+
21
+ # Transform for image input
22
+ transform = transforms.Compose([
23
+ transforms.Grayscale(),
24
+ transforms.Resize((224, 224)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.5], std=[0.5])
27
+ ])
28
+
29
+ # Predict tumor type
30
+ def predict_tumor(image):
31
+ image = transform(image).unsqueeze(0)
32
+ with torch.no_grad():
33
+ out = tumor_model(image)
34
+ pred = torch.argmax(out, dim=1).item()
35
+ return tumor_labels[pred]
36
+
37
+ # Predict glioma stage
38
+ def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
39
+ gender_val = 0 if gender == "Male" else 1
40
+ features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
41
+ x = torch.tensor(features).float().unsqueeze(0)
42
+ with torch.no_grad():
43
+ out = glioma_model(x)
44
+ pred = torch.argmax(out, dim=1).item()
45
+ return stage_labels[pred]
46
+
47
+ # Interface 1: Tumor Classification
48
+ tumor_tab = gr.Interface(
49
+ fn=predict_tumor,
50
+ inputs=gr.Image(type="pil"),
51
+ outputs=gr.Label(),
52
+ title="🧠 Brain Tumor Detection",
53
+ description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
54
+ )
55
+
56
+ # Interface 2: Glioma Stage Prediction
57
+ stage_tab = gr.Interface(
58
+ fn=predict_stage,
59
+ inputs=[
60
+ gr.Radio(["Male", "Female"], label="Gender"),
61
+ gr.Slider(0, 100, label="Age"),
62
+ gr.Slider(0, 1, step=1, label="IDH1 Mutation"),
63
+ gr.Slider(0, 1, step=1, label="TP53 Mutation"),
64
+ gr.Slider(0, 1, step=1, label="ATRX Mutation"),
65
+ gr.Slider(0, 1, step=1, label="PTEN Mutation"),
66
+ gr.Slider(0, 1, step=1, label="EGFR Mutation"),
67
+ gr.Slider(0, 1, step=1, label="CIC Mutation"),
68
+ gr.Slider(0, 1, step=1, label="PIK3CA Mutation")
69
+ ],
70
+ outputs=gr.Label(),
71
+ title="🧬 Glioma Stage Classifier",
72
+ description="Enter patient mutation and demographic data to classify glioma stage."
73
+ )
74
+
75
+ # Combine both into a tabbed interface
76
+ demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
77
+
78
+ # Launch
79
+ demo.launch()