"""Gradio app serving a brain-tumor classifier and a glioma stage predictor.

Two tabs are exposed:

* "Tumor Detector" classifies an uploaded image with ``TumorClassification``
  into one of ``tumor_labels``.
* "Glioma Stage" predicts one of ``stage_labels`` from gender, age and
  binary mutation flags with ``GliomaStageModel``.
"""

import torch
import gradio as gr
from PIL import Image
from torchvision import transforms

from TumorModel import TumorClassification, GliomaStageModel

# ─── Load Tumor Classification Model ───────────────────────────────────────────

# The checkpoint stores layers under attribute-style names (con1d, fc1, ...),
# but TumorClassification's state dict expects indexed ``model.<i>.`` names.
# Each fragment is rewritten with str.replace, applied in this order.
_LEGACY_KEY_MAP = {
    "con1d.": "model.0.",
    "con2d.": "model.3.",
    "con3d.": "model.6.",
    "fc1.": "model.8.",
    "fc2.": "model.10.",
    "output.": "model.12.",
}


def _rename_legacy_keys(state_dict):
    """Return a new dict with every key rewritten via ``_LEGACY_KEY_MAP``.

    Matching is substring-based (``str.replace``), not prefix-only, so a
    fragment anywhere in a key is rewritten. Values are passed through as-is.
    """
    renamed = {}
    for key, value in state_dict.items():
        for old, new in _LEGACY_KEY_MAP.items():
            key = key.replace(old, new)
        renamed[key] = value
    return renamed


# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
tumor_model = TumorClassification()
sd = torch.load("BTD_model.pth", map_location="cpu")
tumor_model.load_state_dict(_rename_legacy_keys(sd))
tumor_model.eval()

# ─── Load Glioma Stage Model ───────────────────────────────────────────────────

glioma_model = GliomaStageModel()
glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location="cpu"))
glioma_model.eval()

# ─── Labels and Image Transform ────────────────────────────────────────────────

# Index i must correspond to output logit i of the respective model.
tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
# NOTE(review): placeholder names — adjust to match GliomaStageModel's classes.
stage_labels = ['Stage 1', 'Stage 2']

# Single-channel 208x208 tensor scaled from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((208, 208)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# ─── Gradio Prediction Functions ───────────────────────────────────────────────


def predict_tumor(image):
    """Classify an uploaded image into one of ``tumor_labels``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The predicted label string.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        # Without this guard the transform raises an opaque TypeError.
        raise gr.Error("Please upload an image first.")
    tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        out = tumor_model(tensor)
    pred = torch.argmax(out, dim=1).item()
    return tumor_labels[pred]


def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca):
    """Predict a glioma stage label from demographic and mutation features.

    Args:
        gender: "Male" or "Female" from the radio input; ``None`` if unselected.
        age: Age value from the slider.
        idh1, tp53, atrx, pten, egfr, cic, pik3ca: 0/1 mutation flags.

    Returns:
        One of ``stage_labels``.

    Raises:
        gr.Error: if no gender was selected. Previously an unselected radio
            was silently encoded as "Female".
    """
    if gender is None:
        raise gr.Error("Please select a gender.")
    gender_val = 0 if gender == "Male" else 1
    # NOTE(review): order must match the feature order used to train
    # GliomaStageModel — confirm against the training code.
    features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca]
    x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        out = glioma_model(x)
    pred = torch.argmax(out, dim=1).item()
    return stage_labels[pred]

# ─── Gradio UI ─────────────────────────────────────────────────────────────────


tumor_tab = gr.Interface(
    fn=predict_tumor,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Brain Tumor Detector",
)

stage_tab = gr.Interface(
    fn=predict_stage,
    inputs=[
        gr.Radio(["Male", "Female"], label="Gender"),
        gr.Slider(0, 100, label="Age"),
        gr.Slider(0, 1, step=1, label="IDH1"),
        gr.Slider(0, 1, step=1, label="TP53"),
        gr.Slider(0, 1, step=1, label="ATRX"),
        gr.Slider(0, 1, step=1, label="PTEN"),
        gr.Slider(0, 1, step=1, label="EGFR"),
        gr.Slider(0, 1, step=1, label="CIC"),
        gr.Slider(0, 1, step=1, label="PIK3CA"),
    ],
    outputs=gr.Label(),
    title="Glioma Stage Predictor",
)

demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"])

if __name__ == "__main__":
    # Guarded so importing this module (e.g. `gradio app.py` reload mode)
    # does not start a second server; `python app.py` still launches.
    demo.launch()