import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from TumorModel import TumorClassification, GliomaStageModel
# ─── Load Tumor Classification Model ───────────────────────────────────────────
tumor_model = TumorClassification()
sd = torch.load("BTD_model.pth", map_location="cpu")
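# The checkpoint stores weights under attribute-style names (con1d, con2d, fc1, ...);
# the loop below remaps them to the nn.Sequential indices that TumorClassification
# expects. The index mapping is assumed from the model definition in TumorModel.py.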
renamed_sd = {}
for k, v in sd.items():
    new_key = (k
               .replace("con1d.", "model.0.")
               .replace("con2d.", "model.3.")
               .replace("con3d.", "model.6.")
               .replace("fc1.", "model.8.")
               .replace("fc2.", "model.10.")
               .replace("output.", "model.12."))
    renamed_sd[new_key] = v
tumor_model.load_state_dict(renamed_sd)
tumor_model.eval()
# ─── Load Glioma Stage Model ───────────────────────────────────────────────────
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location="cpu"))
glioma_model.eval()
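# Both models stay on CPU and are switched to eval() so dropout/batch-norm layers
# behave deterministically at inference time.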
# ─── Labels and Image Transform ────────────────────────────────────────────────
tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
stage_labels = ['Stage 1', 'Stage 2']  # adjust if GliomaStageModel predicts more stages
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((208, 208)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
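# These preprocessing settings (single-channel 208x208 input, pixels scaled to
# [-1, 1]) are assumed to match the pipeline TumorClassification was trained with.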
# ─── Gradio Prediction Functions ───────────────────────────────────────────────
def predict_tumor(image):
    tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        out = tumor_model(tensor)
        pred = torch.argmax(out, dim=1).item()
    return tumor_labels[pred]
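# The feature order below is assumed to match the column order GliomaStageModel
# was trained on (gender first, then age, then the mutation flags).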
def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
    gender_val = 0 if gender == "Male" else 1
    features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
    x = torch.tensor(features).float().unsqueeze(0)
    with torch.no_grad():
        out = glioma_model(x)
        pred = torch.argmax(out, dim=1).item()
    return stage_labels[pred]
# ─── Gradio UI ──────────────────────────────────────────────────────────────────
tumor_tab = gr.Interface(
    fn=predict_tumor,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Brain Tumor Detector"
)
stage_tab = gr.Interface(
    fn=predict_stage,
    inputs=[
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Slider(0, 100, label="Age"),
        gr.Slider(0, 1, step=1, label="IDH1"),
        gr.Slider(0, 1, step=1, label="TP53"),
        gr.Slider(0, 1, step=1, label="ATRX"),
        gr.Slider(0, 1, step=1, label="PTEN"),
        gr.Slider(0, 1, step=1, label="EGFR"),
        gr.Slider(0, 1, step=1, label="CIC"),
        gr.Slider(0, 1, step=1, label="PIK3CA")
    ],
    outputs=gr.Label(),
    title="Glioma Stage Predictor"
)
demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
demo.launch()
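# To run locally (assuming this file is saved as app.py): `python app.py`, then open
# the local URL Gradio prints (http://127.0.0.1:7860 by default).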