Codewithsalty commited on
Commit
f99073d
·
verified ·
1 Parent(s): ad92d1d

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +61 -66
newapi.py CHANGED
@@ -1,21 +1,22 @@
1
- import os
2
- from fastapi import FastAPI, File, UploadFile, HTTPException
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
6
- from PIL import Image
7
  import torch
8
  import torchvision.transforms as transforms
 
 
 
9
 
10
- from utils import BrainTumorModel, GliomaStageModel, get_precautions_from_gemini
 
 
11
 
12
- # ——— Constants ———
13
- MODEL_DIR = "models"
14
- BTD_FILENAME = "BTD_model.pth"
15
- GLIO_FILENAME = "glioma_stages.pth"
16
 
17
- # ——— App setup ———
18
- app = FastAPI(title="Brain Tumor Detection API")
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
@@ -23,68 +24,62 @@ app.add_middleware(
23
  allow_headers=["*"],
24
  )
25
 
26
- DEVICE = torch.device("cpu")
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
30
- transforms.Normalize([0.5]*3, [0.5]*3),
31
  ])
32
 
33
- def load_model(cls, fname):
34
- path = os.path.join(MODEL_DIR, fname)
35
- if not os.path.isfile(path):
36
- raise FileNotFoundError(f"Model file not found: {path}")
37
- m = cls().to(DEVICE)
38
- m.load_state_dict(torch.load(path, map_location=DEVICE))
39
- return m.eval()
40
 
41
- try:
42
- tumor_model = load_model(BrainTumorModel, BTD_FILENAME)
43
- glioma_model = load_model(GliomaStageModel, GLIO_FILENAME)
44
- except Exception as e:
45
- print("❌ Error loading model:", e)
 
 
 
 
 
 
 
 
46
 
47
- @app.get("/")
48
- async def health():
49
- return {"status": "ok", "message": "API is up"}
50
 
51
- @app.post("/predict-image/")
52
- async def predict_image(file: UploadFile = File(...)):
53
- if not file.content_type.startswith("image/"):
54
- raise HTTPException(400, "Upload an image")
55
- img = Image.open(file.file).convert("RGB")
56
- t = transform(img).unsqueeze(0).to(DEVICE)
57
- with torch.no_grad():
58
- out = tumor_model(t)
59
- idx = int(out.argmax(1))
60
- labels = ["glioma","meningioma","notumor","pituitary"]
61
- tumor_type = labels[idx]
62
- if tumor_type == "glioma":
63
- return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
64
- return {
65
- "tumor_type": tumor_type,
66
- "precaution": get_precautions_from_gemini(tumor_type)
67
- }
68
 
69
- class MutationInput(BaseModel):
70
- gender: str
71
- age: float
72
- idh1: int
73
- tp53: int
74
- atrx: int
75
- pten: int
76
- egfr: int
77
- cic: int
78
- pik3ca: int
79
 
80
- @app.post("/predict-glioma-stage/")
81
- async def predict_glioma_stage(data: MutationInput):
82
- gen = 0 if data.gender.lower().startswith("m") else 1
83
- feats = [gen, data.age, data.idh1, data.tp53,
84
- data.atrx, data.pten, data.egfr, data.cic, data.pik3ca]
85
- t = torch.tensor(feats, dtype=torch.float32).unsqueeze(0).to(DEVICE)
86
- with torch.no_grad():
87
- out = glioma_model(t)
88
- idx = int(out.argmax(1))
89
- stages = ["Stage 1","Stage 2","Stage 3","Stage 4"]
90
- return {"glioma_stage": stages[idx]}
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # newapi.py
2
+
3
+ from fastapi import FastAPI, File, UploadFile
4
  from fastapi.middleware.cors import CORSMiddleware
 
5
  from pydantic import BaseModel
 
6
  import torch
7
  import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ import io
10
+ import os
11
 
12
+ # Use a writable directory in Hugging Face Spaces
13
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache"
14
+ os.environ["HF_HOME"] = "/tmp/.cache"
15
 
16
+ # Define FastAPI app
17
+ app = FastAPI(title="🧠 Brain Tumor Detection API")
 
 
18
 
19
+ # Enable CORS
 
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
 
24
  allow_headers=["*"],
25
  )
26
 
27
+ # Image transform
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
32
  ])
33
 
34
+ # Define your model directly inside this file (to avoid import errors)
35
+ import torch.nn as nn
 
 
 
 
 
36
 
37
+ class BrainTumorModel(nn.Module):
38
+ def __init__(self):
39
+ super(BrainTumorModel, self).__init__()
40
+ self.model = nn.Sequential(
41
+ nn.Conv2d(3, 16, kernel_size=3),
42
+ nn.ReLU(),
43
+ nn.MaxPool2d(2),
44
+ nn.Conv2d(16, 32, kernel_size=3),
45
+ nn.ReLU(),
46
+ nn.MaxPool2d(2),
47
+ nn.Flatten(),
48
+ nn.Linear(32 * 54 * 54, 2),
49
+ )
50
 
51
+ def forward(self, x):
52
+ return self.model(x)
 
53
 
54
+ # Load model
55
+ model_path = "BTD_model.pth"
56
+ if not os.path.exists(model_path):
57
+ from huggingface_hub import hf_hub_download
58
+ model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-models", filename="BTD_model.pth", cache_dir="/tmp/.cache")
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ btd_model = BrainTumorModel()
61
+ btd_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
62
+ btd_model.eval()
 
 
 
 
 
 
 
63
 
64
+ # Define prediction endpoint
65
+ @app.post("/predict/")
66
+ async def predict(file: UploadFile = File(...)):
67
+ try:
68
+ contents = await file.read()
69
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
70
+ image_tensor = transform(image).unsqueeze(0)
71
+
72
+ with torch.no_grad():
73
+ output = btd_model(image_tensor)
74
+ prediction = torch.argmax(output, dim=1).item()
75
+
76
+ result = {0: "No tumor", 1: "Tumor detected"}[prediction]
77
+ return {"prediction": result}
78
+
79
+ except Exception as e:
80
+ return {"error": str(e)}
81
+
82
+ # Health check endpoint
83
+ @app.get("/")
84
+ def root():
85
+ return {"message": "🧠 Brain Tumor Detection API is running!"}