import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from TumorModel import TumorClassification, GliomaStageModel
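# Note: TumorClassification and GliomaStageModel are assumed to be custom
# torch.nn.Module subclasses defined in TumorModel.py alongside this app;
# their architectures must match the checkpoints loaded below.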

# Load tumor classification model
tumor_model = TumorClassification()
tumor_model.load_state_dict(torch.load("BTD_model.pth", map_location=torch.device("cpu")))
tumor_model.eval()

# Load glioma stage classification model
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location=torch.device("cpu")))
glioma_model.eval()

# Labels
tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']

# Transform (resize to 208x208 to match training)
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((208, 208)),  # <-- important for matching FC input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Tumor prediction function
def predict_tumor(image):
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        out = tumor_model(image)
        pred = torch.argmax(out, dim=1).item()
    return tumor_labels[pred]

# Glioma stage prediction function
def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
    gender_val = 0 if gender == "Male" else 1
    features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
    x = torch.tensor(features).float().unsqueeze(0)
    with torch.no_grad():
        out = glioma_model(x)
        pred = torch.argmax(out, dim=1).item()
    return stage_labels[pred]

# Tumor detection tab
tumor_tab = gr.Interface(
    fn=predict_tumor,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="🧠 Brain Tumor Detector",
    description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary."
)

# Glioma stage prediction tab
stage_tab = gr.Interface(
    fn=predict_stage,
    inputs=[
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Slider(0, 100, label="Age"),
        gr.Slider(0, 1, step=1, label="IDH1 Mutation"),
        gr.Slider(0, 1, step=1, label="TP53 Mutation"),
        gr.Slider(0, 1, step=1, label="ATRX Mutation"),
        gr.Slider(0, 1, step=1, label="PTEN Mutation"),
        gr.Slider(0, 1, step=1, label="EGFR Mutation"),
        gr.Slider(0, 1, step=1, label="CIC Mutation"),
        gr.Slider(0, 1, step=1, label="PIK3CA Mutation")
    ],
    outputs=gr.Label(),
    title="🧬 Glioma Stage Classifier",
    description="Enter mutation and demographic data to classify glioma stage."
)

# Combine into a tabbed interface
demo = gr.TabbedInterface(
    [tumor_tab, stage_tab],
    tab_names=["Tumor Detector", "Glioma Stage Predictor"]
)

# Launch the app
demo.launch()
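
# Quick local sanity check (illustrative only; assumes an MRI image file
# "sample_mri.jpg" exists next to this script):
#   print(predict_tumor(Image.open("sample_mri.jpg")))
#   print(predict_stage("Male", 45, 1, 0, 0, 1, 0, 0, 0))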