Codewithsalty commited on
Commit
e57c445
Β·
verified Β·
1 Parent(s): 33617dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -4,29 +4,29 @@ from PIL import Image
4
  from torchvision import transforms
5
  from TumorModel import TumorClassification, GliomaStageModel
6
 
7
- # Load tumor classification model
8
  tumor_model = TumorClassification()
9
  tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
10
  tumor_model.eval()
11
 
12
- # Load glioma stage model
13
  glioma_model = GliomaStageModel()
14
  glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
15
  glioma_model.eval()
16
 
17
- # Labels
18
  tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
19
- stage_labels = ['Stage 1', 'Stage 2'] # Only 2 classes in your model
20
 
21
- # Transform for image input
22
  transform = transforms.Compose([
23
  transforms.Grayscale(),
24
- transforms.Resize((224, 224)),
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.5], std=[0.5])
27
  ])
28
 
29
- # Predict tumor type
30
  def predict_tumor(image):
31
  image = transform(image).unsqueeze(0)
32
  with torch.no_grad():
@@ -34,7 +34,7 @@ def predict_tumor(image):
34
  pred = torch.argmax(out, dim=1).item()
35
  return tumor_labels[pred]
36
 
37
- # Predict glioma stage
38
  def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
39
  gender_val = 0 if gender == "Male" else 1
40
  features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
@@ -44,16 +44,16 @@ def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
44
  pred = torch.argmax(out, dim=1).item()
45
  return stage_labels[pred]
46
 
47
- # Interface 1: Tumor Classification
48
  tumor_tab = gr.Interface(
49
  fn=predict_tumor,
50
  inputs=gr.Image(type="pil"),
51
  outputs=gr.Label(),
52
- title="🧠 Brain Tumor Detection",
53
  description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
54
  )
55
 
56
- # Interface 2: Glioma Stage Prediction
57
  stage_tab = gr.Interface(
58
  fn=predict_stage,
59
  inputs=[
@@ -69,11 +69,14 @@ stage_tab = gr.Interface(
69
  ],
70
  outputs=gr.Label(),
71
  title="🧬 Glioma Stage Classifier",
72
- description="Enter patient mutation and demographic data to classify glioma stage."
73
  )
74
 
75
- # Combine both into a tabbed interface
76
- demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
 
 
 
77
 
78
- # Launch
79
  demo.launch()
 
4
  from torchvision import transforms
5
  from TumorModel import TumorClassification, GliomaStageModel
6
 
7
+ # βœ… Load tumor classification model
8
  tumor_model = TumorClassification()
9
  tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
10
  tumor_model.eval()
11
 
12
+ # βœ… Load glioma stage classification model
13
  glioma_model = GliomaStageModel()
14
  glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
15
  glioma_model.eval()
16
 
17
+ # βœ… Labels
18
  tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
19
+ stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
20
 
21
+ # βœ… Transform (resize to 208x208 to match training)
22
  transform = transforms.Compose([
23
  transforms.Grayscale(),
24
+ transforms.Resize((208, 208)), # <-- important for matching FC input
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.5], std=[0.5])
27
  ])
28
 
29
+ # βœ… Tumor Prediction Function
30
  def predict_tumor(image):
31
  image = transform(image).unsqueeze(0)
32
  with torch.no_grad():
 
34
  pred = torch.argmax(out, dim=1).item()
35
  return tumor_labels[pred]
36
 
37
+ # βœ… Glioma Stage Prediction Function
38
  def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
39
  gender_val = 0 if gender == "Male" else 1
40
  features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
 
44
  pred = torch.argmax(out, dim=1).item()
45
  return stage_labels[pred]
46
 
47
+ # βœ… Tumor Detection Tab
48
  tumor_tab = gr.Interface(
49
  fn=predict_tumor,
50
  inputs=gr.Image(type="pil"),
51
  outputs=gr.Label(),
52
+ title="🧠 Brain Tumor Detector",
53
  description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
54
  )
55
 
56
+ # βœ… Glioma Stage Prediction Tab
57
  stage_tab = gr.Interface(
58
  fn=predict_stage,
59
  inputs=[
 
69
  ],
70
  outputs=gr.Label(),
71
  title="🧬 Glioma Stage Classifier",
72
+ description="Enter mutation and demographic data to classify glioma stage."
73
  )
74
 
75
+ # βœ… Combine into a tabbed interface
76
+ demo = gr.TabbedInterface(
77
+ [tumor_tab, stage_tab],
78
+ tab_names=["Tumor Detector", "Glioma Stage Predictor"]
79
+ )
80
 
81
+ # βœ… Launch the app
82
  demo.launch()