import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from TumorModel import TumorClassification, GliomaStageModel
# ─── Load Tumor Classification Model ───────────────────────────────────────────
tumor_model = TumorClassification()
sd = torch.load("BTD_model.pth", map_location="cpu")
renamed_sd = {}
# The checkpoint was saved with attribute-style layer names (con1d, fc1, ...);
# remap each key onto the indexed "model.N" names the model class expects.
for k, v in sd.items():
    new_key = (k
               .replace("con1d.", "model.0.")
               .replace("con2d.", "model.3.")
               .replace("con3d.", "model.6.")
               .replace("fc1.", "model.8.")
               .replace("fc2.", "model.10.")
               .replace("output.", "model.12."))
    renamed_sd[new_key] = v
tumor_model.load_state_dict(renamed_sd)
tumor_model.eval()
# ─── Load Glioma Stage Model ───────────────────────────────────────────────────
glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location="cpu"))
glioma_model.eval()
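# Both checkpoints are loaded onto CPU via map_location="cpu"; this is assumed
# to match the runtime (e.g. a CPU-only Space). If a GPU is available, move
# each model and its input tensors to the same device before inference.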
# ─── Labels and Image Transform ────────────────────────────────────────────────
tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
stage_labels = ['Stage 1', 'Stage 2']  # adjust to match GliomaStageModel's output classes
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((208, 208)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
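# This pipeline is assumed to mirror the training preprocessing: one-channel
# 208x208 inputs normalized to [-1, 1]. Changing the resize target would break
# the flattened feature size expected by TumorClassification's linear layers.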
# ─── Gradio Prediction Functions ───────────────────────────────────────────────
def predict_tumor(image):
    tensor = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        out = tumor_model(tensor)
        pred = torch.argmax(out, dim=1).item()
    return tumor_labels[pred]
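# Example usage (hypothetical file name):
#   img = Image.open("sample_mri.jpg")
#   print(predict_tumor(img))  # e.g. "glioma"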
def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
    gender_val = 0 if gender == "Male" else 1  # encode gender as 0/1
    features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
    x = torch.tensor(features).float().unsqueeze(0)
    with torch.no_grad():
        out = glioma_model(x)
        pred = torch.argmax(out, dim=1).item()
    return stage_labels[pred]
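# Example usage (hypothetical values; mutation inputs are 0/1 flags):
#   print(predict_stage("Female", 47, 1, 0, 0, 1, 0, 0, 0))  # e.g. "Stage 2"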
# ─── Gradio UI ─────────────────────────────────────────────────────────────────
tumor_tab = gr.Interface(
    fn=predict_tumor,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Brain Tumor Detector"
)
stage_tab = gr.Interface(
    fn=predict_stage,
    inputs=[
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Slider(0, 100, label="Age"),
        gr.Slider(0, 1, step=1, label="IDH1"),
        gr.Slider(0, 1, step=1, label="TP53"),
        gr.Slider(0, 1, step=1, label="ATRX"),
        gr.Slider(0, 1, step=1, label="PTEN"),
        gr.Slider(0, 1, step=1, label="EGFR"),
        gr.Slider(0, 1, step=1, label="CIC"),
        gr.Slider(0, 1, step=1, label="PIK3CA")
    ],
    outputs=gr.Label(),
    title="Glioma Stage Predictor"
)
demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])
demo.launch()
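# launch() serves at Gradio's default address (http://127.0.0.1:7860) when run
# locally; on Hugging Face Spaces the platform supplies host/port automatically.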