Codewithsalty commited on
Commit
60d002b
·
verified ·
1 Parent(s): 119a255

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +44 -79
newapi.py CHANGED
@@ -6,17 +6,18 @@ from torchvision import transforms
6
  from PIL import Image
7
  import io
8
  import os
9
- from huggingface_hub import hf_hub_download
10
 
 
 
 
 
11
  from models.TumorModel import TumorClassification, GliomaStageModel
12
  from utils import get_precautions_from_gemini
13
 
14
- # Use /data as Hugging Face allows writing here only
15
- cache_dir = os.path.join("/data", "cache")
16
- os.makedirs(cache_dir, exist_ok=True)
17
-
18
- app = FastAPI(title="Brain Tumor Detection API")
19
 
 
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
@@ -25,83 +26,47 @@ app.add_middleware(
25
  allow_headers=["*"],
26
  )
27
 
28
- # ✅ Load Tumor Classification Model
29
- btd_model_path = hf_hub_download(
30
- repo_id="Codewithsalty/brain-tumor-models",
31
- filename="BTD_model.pth",
32
- cache_dir=cache_dir
33
- )
34
- tumor_model = TumorClassification()
35
- tumor_model.load_state_dict(torch.load(btd_model_path, map_location="cpu"))
36
- tumor_model.eval()
37
 
38
- # Load Glioma Stage Model
39
- glioma_model_path = hf_hub_download(
40
- repo_id="Codewithsalty/brain-tumor-models",
41
- filename="glioma_stages.pth",
42
- cache_dir=cache_dir
43
- )
44
- glioma_model = GliomaStageModel()
45
- glioma_model.load_state_dict(torch.load(glioma_model_path, map_location="cpu"))
46
- glioma_model.eval()
47
 
 
48
  transform = transforms.Compose([
49
- transforms.Grayscale(),
50
  transforms.Resize((224, 224)),
51
- transforms.ToTensor(),
52
- transforms.Normalize(mean=[0.5], std=[0.5]),
53
  ])
54
 
55
- @app.get("/")
56
- async def root():
57
- return {"message": "Brain Tumor Detection API is running."}
58
-
59
- labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- @app.post("/predict-image")
62
- async def predict_image(file: UploadFile = File(...)):
63
- img_bytes = await file.read()
64
- img = Image.open(io.BytesIO(img_bytes)).convert("L")
65
- x = transform(img).unsqueeze(0)
66
-
67
- with torch.no_grad():
68
- out = tumor_model(x)
69
- idx = torch.argmax(out, dim=1).item()
70
- tumor_type = labels[idx]
71
-
72
- if tumor_type == "glioma":
73
- return {"tumor_type": tumor_type, "next": "submit_mutation_data"}
74
- else:
75
- precautions = get_precautions_from_gemini(tumor_type)
76
- return {"tumor_type": tumor_type, "precaution": precautions}
77
-
78
- class MutationInput(BaseModel):
79
- gender: str
80
- age: float
81
- idh1: int
82
- tp53: int
83
- atrx: int
84
- pten: int
85
- egfr: int
86
- cic: int
87
- pik3ca: int
88
-
89
- @app.post("/predict-glioma-stage")
90
- async def predict_glioma_stage(data: MutationInput):
91
- gender_val = 0 if data.gender.lower() == 'm' else 1
92
- features = [
93
- gender_val, data.age, data.idh1, data.tp53, data.atrx,
94
- data.pten, data.egfr, data.cic, data.pik3ca
95
- ]
96
- x = torch.tensor(features).float().unsqueeze(0)
97
-
98
- with torch.no_grad():
99
- out = glioma_model(x)
100
- idx = torch.argmax(out, dim=1).item()
101
- stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
102
- return {"glioma_stage": stages[idx]}
103
-
104
- # Only used when running locally
105
- if __name__ == "__main__":
106
- import uvicorn
107
- uvicorn.run("newapi:app", host="0.0.0.0", port=10000)
 
6
  from PIL import Image
7
  import io
8
  import os
 
9
 
10
+ # ✅ Set Hugging Face model cache directory to a writable path
11
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
12
+
13
+ from huggingface_hub import hf_hub_download
14
  from models.TumorModel import TumorClassification, GliomaStageModel
15
  from utils import get_precautions_from_gemini
16
 
17
+ # Define your app
18
+ app = FastAPI()
 
 
 
19
 
20
+ # ✅ Enable CORS
21
  app.add_middleware(
22
  CORSMiddleware,
23
  allow_origins=["*"],
 
26
  allow_headers=["*"],
27
  )
28
 
29
+ # ✅ Load your models from the Hugging Face Hub
30
+ btd_model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-detection", filename="brain_tumor_model.pt")
31
+ glioma_model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-detection", filename="glioma_stage_model.pt")
 
 
 
 
 
 
32
 
33
+ btd_model = TumorClassification(model_path=btd_model_path)
34
+ glioma_model = GliomaStageModel(model_path=glioma_model_path)
 
 
 
 
 
 
 
35
 
36
+ # ✅ Image preprocessing
37
  transform = transforms.Compose([
 
38
  transforms.Resize((224, 224)),
39
+ transforms.ToTensor()
 
40
  ])
41
 
42
+ class DiagnosisResponse(BaseModel):
43
+ tumor: str
44
+ stage: str
45
+ precautions: list
46
+
47
+ @app.post("/predict", response_model=DiagnosisResponse)
48
+ async def predict(file: UploadFile = File(...)):
49
+ contents = await file.read()
50
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
51
+ image_tensor = transform(image).unsqueeze(0)
52
+
53
+ tumor_result = btd_model.predict(image_tensor)
54
+ if tumor_result == "No Tumor":
55
+ return DiagnosisResponse(
56
+ tumor="No Tumor Detected",
57
+ stage="N/A",
58
+ precautions=[]
59
+ )
60
+
61
+ stage_result = glioma_model.predict(image_tensor)
62
+ precautions = get_precautions_from_gemini(tumor_result, stage_result)
63
+
64
+ return DiagnosisResponse(
65
+ tumor=tumor_result,
66
+ stage=stage_result,
67
+ precautions=precautions
68
+ )
69
 
70
+ @app.get("/")
71
+ def root():
72
+ return {"message": "Brain Tumor API is running."}