# Spaces: Runtime error
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from TumorModel import GliomaStageModel, TumorClassification
# Model weights are resolved relative to this file rather than the current
# working directory, so the app also starts when launched from elsewhere.
_BASE_DIR = Path(__file__).resolve().parent
# Inference runs on CPU; map_location remaps any GPU-saved tensors onto it.
_DEVICE = torch.device("cpu")

# Load tumor classification model (inference mode)
tumor_model = TumorClassification()
tumor_model.load_state_dict(torch.load(_BASE_DIR / "BTD_model.pth", map_location=_DEVICE))
tumor_model.eval()

# Load glioma stage classification model (inference mode)
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load(_BASE_DIR / "glioma_stages.pth", map_location=_DEVICE))
glioma_model.eval()
# Class labels, indexed by each model's argmax output. The order is assumed
# to match the class indices used during training — TODO confirm.
tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
stage_labels = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']

# Preprocessing for the tumor model: convert to single-channel grayscale,
# resize to 208x208 to match training (the model's fully-connected input size
# depends on it), convert to a [0, 1] tensor, then map to [-1, 1] via
# Normalize(mean=0.5, std=0.5).
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((208, 208)),  # must match the training resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# Tumor prediction function
def predict_tumor(image):
    """Classify a brain MRI image into one of ``tumor_labels``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The predicted tumor label string, e.g. ``"glioma"``.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a preprocessing stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image first.")
    # transform yields a (1, 208, 208) tensor; add the batch dim -> (1, 1, 208, 208)
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = tumor_model(batch)
    pred = torch.argmax(logits, dim=1).item()
    return tumor_labels[pred]
# Glioma stage prediction function
def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
    """Predict a glioma stage label from demographic and mutation inputs.

    Features are fed to ``glioma_model`` in the order
    [gender, age, IDH1, TP53, ATRX, PTEN, EGFR, CIC, PIK3CA], with gender
    encoded as Male=0, Female=1. This order/encoding is assumed to match the
    training data — TODO confirm against the training pipeline.

    Args:
        gender: ``"Male"`` or ``"Female"`` from the radio input; ``None`` if
            nothing was selected.
        age: Age value from the slider.
        idh1, tp53, atrx, pten, egfr, cic, pik3ca: 0/1 mutation flags.

    Returns:
        One of ``stage_labels``.

    Raises:
        gr.Error: If gender is not selected; previously an unselected radio
            was silently encoded as Female.
    """
    if gender not in ("Male", "Female"):
        raise gr.Error("Please select a gender.")
    gender_val = 0 if gender == "Male" else 1
    features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
    x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # shape (1, 9)
    with torch.no_grad():
        logits = glioma_model(x)
    pred = torch.argmax(logits, dim=1).item()
    return stage_labels[pred]
# Tumor detection tab: MRI image in, predicted tumor type out
tumor_tab = gr.Interface(
    fn=predict_tumor,
    inputs=gr.Image(type="pil"),  # predict_tumor expects a PIL image (or None)
    outputs=gr.Label(),
    title="🧠 Brain Tumor Detector",
    description="Upload an MRI image to classify tumor type: glioma, meningioma, notumor, or pituitary.",
)
# Glioma stage prediction tab. Input order must match predict_stage's
# parameter order: gender, age, then the seven 0/1 mutation flags.
stage_tab = gr.Interface(
    fn=predict_stage,
    inputs=[
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Slider(0, 100, label="Age"),
        gr.Slider(0, 1, step=1, label="IDH1 Mutation"),
        gr.Slider(0, 1, step=1, label="TP53 Mutation"),
        gr.Slider(0, 1, step=1, label="ATRX Mutation"),
        gr.Slider(0, 1, step=1, label="PTEN Mutation"),
        gr.Slider(0, 1, step=1, label="EGFR Mutation"),
        gr.Slider(0, 1, step=1, label="CIC Mutation"),
        gr.Slider(0, 1, step=1, label="PIK3CA Mutation"),
    ],
    outputs=gr.Label(),
    title="🧬 Glioma Stage Classifier",
    description="Enter mutation and demographic data to classify glioma stage.",
)
# Combine both interfaces into a single tabbed app
demo = gr.TabbedInterface(
    [tumor_tab, stage_tab],
    tab_names=["Tumor Detector", "Glioma Stage Predictor"],
)

# Launch the app (runs at module import time, as in the original script)
demo.launch()