Codewithsalty commited on
Commit
91d7910
Β·
verified Β·
1 Parent(s): 81484f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -34
app.py CHANGED
@@ -4,37 +4,47 @@ from PIL import Image
4
  from torchvision import transforms
5
  from TumorModel import TumorClassification, GliomaStageModel
6
 
7
- # βœ… Load tumor classification model
8
  tumor_model = TumorClassification()
9
- tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
 
 
 
 
 
 
 
 
 
 
 
10
  tumor_model.eval()
11
 
12
- # βœ… Load glioma stage classification model
13
  glioma_model = GliomaStageModel()
14
- glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
15
  glioma_model.eval()
16
 
17
- # βœ… Labels
18
  tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
19
- stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
20
 
21
- # βœ… Transform (resize to 208x208 to match training)
22
  transform = transforms.Compose([
23
  transforms.Grayscale(),
24
- transforms.Resize((208, 208)), # <-- important for matching FC input
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.5], std=[0.5])
27
  ])
28
 
29
- # βœ… Tumor Prediction Function
 
30
  def predict_tumor(image):
31
- image = transform(image).unsqueeze(0)
32
  with torch.no_grad():
33
- out = tumor_model(image)
34
  pred = torch.argmax(out, dim=1).item()
35
- return tumor_labels[pred]
36
 
37
- # βœ… Glioma Stage Prediction Function
38
  def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
39
  gender_val = 0 if gender == "Male" else 1
40
  features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
@@ -42,41 +52,33 @@ def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
42
  with torch.no_grad():
43
  out = glioma_model(x)
44
  pred = torch.argmax(out, dim=1).item()
45
- return stage_labels[pred]
 
 
46
 
47
- # βœ… Tumor Detection Tab
48
  tumor_tab = gr.Interface(
49
  fn=predict_tumor,
50
  inputs=gr.Image(type="pil"),
51
  outputs=gr.Label(),
52
- title="🧠 Brain Tumor Detector",
53
- description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
54
  )
55
 
56
- # βœ… Glioma Stage Prediction Tab
57
  stage_tab = gr.Interface(
58
  fn=predict_stage,
59
  inputs=[
60
  gr.Radio(["Male", "Female"], label="Gender"),
61
  gr.Slider(0, 100, label="Age"),
62
- gr.Slider(0, 1, step=1, label="IDH1 Mutation"),
63
- gr.Slider(0, 1, step=1, label="TP53 Mutation"),
64
- gr.Slider(0, 1, step=1, label="ATRX Mutation"),
65
- gr.Slider(0, 1, step=1, label="PTEN Mutation"),
66
- gr.Slider(0, 1, step=1, label="EGFR Mutation"),
67
- gr.Slider(0, 1, step=1, label="CIC Mutation"),
68
- gr.Slider(0, 1, step=1, label="PIK3CA Mutation")
69
  ],
70
  outputs=gr.Label(),
71
- title="🧬 Glioma Stage Classifier",
72
- description="Enter mutation and demographic data to classify glioma stage."
73
- )
74
-
75
- # βœ… Combine into a tabbed interface
76
- demo = gr.TabbedInterface(
77
- [tumor_tab, stage_tab],
78
- tab_names=["Tumor Detector", "Glioma Stage Predictor"]
79
  )
80
 
81
- # βœ… Launch the app
82
  demo.launch()
 
4
  from torchvision import transforms
5
  from TumorModel import TumorClassification, GliomaStageModel
6
 
7
+ # ─── Load Tumor Classification Model ───────────────────────────────────────────
8
  tumor_model = TumorClassification()
9
+ sd = torch.load("BTD_model.pth", map_location="cpu")
10
+ renamed_sd = {}
11
+ for k, v in sd.items():
12
+ new_key = (k
13
+ .replace("con1d.", "model.0.")
14
+ .replace("con2d.", "model.3.")
15
+ .replace("con3d.", "model.6.")
16
+ .replace("fc1.", "model.8.")
17
+ .replace("fc2.", "model.10.")
18
+ .replace("output.", "model.12."))
19
+ renamed_sd[new_key] = v
20
+ tumor_model.load_state_dict(renamed_sd)
21
  tumor_model.eval()
22
 
23
+ # ─── Load Glioma Stage Model ───────────────────────────────────────────────────
24
  glioma_model = GliomaStageModel()
25
+ glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location="cpu"))
26
  glioma_model.eval()
27
 
28
+ # ─── Labels and Image Transform ───────────────────────────────────────────────
29
  tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
30
+ stage_labels = ['Stage 1', 'Stage 2'] # Or adjust to match your second model
31
 
 
32
  transform = transforms.Compose([
33
  transforms.Grayscale(),
34
+ transforms.Resize((208, 208)),
35
  transforms.ToTensor(),
36
  transforms.Normalize(mean=[0.5], std=[0.5])
37
  ])
38
 
39
+ # ─── Gradio Prediction Functions ───────────────────────────────────────────────
40
+
41
  def predict_tumor(image):
42
+ tensor = transform(image).unsqueeze(0)
43
  with torch.no_grad():
44
+ out = tumor_model(tensor)
45
  pred = torch.argmax(out, dim=1).item()
46
+ return tumor_labels[pred]
47
 
 
48
  def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
49
  gender_val = 0 if gender == "Male" else 1
50
  features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
 
52
  with torch.no_grad():
53
  out = glioma_model(x)
54
  pred = torch.argmax(out, dim=1).item()
55
+ return stage_labels[pred]
56
+
57
+ # ─── Gradio UI ────────────────────────────────────────────────────────────────
58
 
 
59
  tumor_tab = gr.Interface(
60
  fn=predict_tumor,
61
  inputs=gr.Image(type="pil"),
62
  outputs=gr.Label(),
63
+ title="Brain Tumor Detector"
 
64
  )
65
 
 
66
  stage_tab = gr.Interface(
67
  fn=predict_stage,
68
  inputs=[
69
  gr.Radio(["Male", "Female"], label="Gender"),
70
  gr.Slider(0, 100, label="Age"),
71
+ gr.Slider(0, 1, step=1, label="IDH1"),
72
+ gr.Slider(0, 1, step=1, label="TP53"),
73
+ gr.Slider(0, 1, step=1, label="ATRX"),
74
+ gr.Slider(0, 1, step=1, label="PTEN"),
75
+ gr.Slider(0, 1, step=1, label="EGFR"),
76
+ gr.Slider(0, 1, step=1, label="CIC"),
77
+ gr.Slider(0, 1, step=1, label="PIK3CA")
78
  ],
79
  outputs=gr.Label(),
80
+ title="Glioma Stage Predictor"
 
 
 
 
 
 
 
81
  )
82
 
83
+ demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
84
  demo.launch()