Codewithsalty commited on
Commit
a75fe7e
·
verified ·
1 Parent(s): 3a92ab0

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +49 -56
newapi.py CHANGED
@@ -1,72 +1,65 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
- import torch
5
- from torchvision import transforms
6
- from PIL import Image
7
- import io
8
  import os
 
 
 
 
 
 
9
 
10
- # ✅ Set Hugging Face model cache directory to a writable path
11
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
12
-
13
- from huggingface_hub import hf_hub_download
14
- from models.TumorModel import TumorClassification, GliomaStageModel
15
- from utils import get_precautions_from_gemini
16
-
17
- # Define your app
18
  app = FastAPI()
19
 
20
- # Enable CORS
21
- app.add_middleware(
22
- CORSMiddleware,
23
- allow_origins=["*"],
24
- allow_credentials=True,
25
- allow_methods=["*"],
26
- allow_headers=["*"],
27
- )
28
 
29
- # Load your models from the Hugging Face Hub
30
- btd_model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-detection", filename="brain_tumor_model.pt")
31
- glioma_model_path = hf_hub_download(repo_id="Codewithsalty/brain-tumor-detection", filename="glioma_stage_model.pt")
 
32
 
33
- btd_model = TumorClassification(model_path=btd_model_path)
34
- glioma_model = GliomaStageModel(model_path=glioma_model_path)
 
 
35
 
36
- # Image preprocessing
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
39
- transforms.ToTensor()
40
  ])
41
 
42
- class DiagnosisResponse(BaseModel):
43
- tumor: str
44
- stage: str
45
- precautions: list
46
-
47
- @app.post("/predict", response_model=DiagnosisResponse)
48
  async def predict(file: UploadFile = File(...)):
49
- contents = await file.read()
50
- image = Image.open(io.BytesIO(contents)).convert("RGB")
51
- image_tensor = transform(image).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- tumor_result = btd_model.predict(image_tensor)
54
- if tumor_result == "No Tumor":
55
- return DiagnosisResponse(
56
- tumor="No Tumor Detected",
57
- stage="N/A",
58
- precautions=[]
59
- )
60
 
61
- stage_result = glioma_model.predict(image_tensor)
62
- precautions = get_precautions_from_gemini(tumor_result, stage_result)
63
 
64
- return DiagnosisResponse(
65
- tumor=tumor_result,
66
- stage=stage_result,
67
- precautions=precautions
68
- )
69
 
70
- @app.get("/")
71
- def root():
72
- return {"message": "Brain Tumor API is running."}
 
 
 
 
 
 
 
 
1
  import os
2
+ from fastapi import FastAPI, UploadFile, File
3
+ from fastapi.responses import JSONResponse
4
+ from PIL import Image
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ from utils import BrainTumorModel, GliomaStageModel
8
 
 
 
 
 
 
 
 
 
9
  app = FastAPI()
10
 
11
+ # Load models (updated to local .pth files)
12
+ btd_model_path = "brain_tumor_model.pth"
13
+ glioma_model_path = "glioma_stage_model.pth"
 
 
 
 
 
14
 
15
+ # Initialize and load Brain Tumor Detection Model
16
+ btd_model = BrainTumorModel()
17
+ btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
18
+ btd_model.eval()
19
 
20
+ # Initialize and load Glioma Stage Detection Model
21
+ glioma_model = GliomaStageModel()
22
+ glioma_model.load_state_dict(torch.load(glioma_model_path, map_location=torch.device('cpu')))
23
+ glioma_model.eval()
24
 
25
+ # Define preprocessing
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
  ])
30
 
31
+ @app.post("/predict/")
 
 
 
 
 
32
  async def predict(file: UploadFile = File(...)):
33
+ try:
34
+ image = Image.open(file.file).convert("RGB")
35
+ image = transform(image).unsqueeze(0)
36
+
37
+ with torch.no_grad():
38
+ output = btd_model(image)
39
+ predicted = torch.argmax(output, dim=1).item()
40
+
41
+ classes = ['No Tumor', 'Pituitary', 'Meningioma', 'Glioma']
42
+ result = classes[predicted]
43
+
44
+ return JSONResponse(content={"prediction": result})
45
+
46
+ except Exception as e:
47
+ return JSONResponse(content={"error": str(e)})
48
+
49
+ @app.post("/glioma-stage/")
50
+ async def glioma_stage(file: UploadFile = File(...)):
51
+ try:
52
+ image = Image.open(file.file).convert("RGB")
53
+ image = transform(image).unsqueeze(0)
54
 
55
+ with torch.no_grad():
56
+ output = glioma_model(image)
57
+ predicted = torch.argmax(output, dim=1).item()
 
 
 
 
58
 
59
+ stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
60
+ result = stages[predicted]
61
 
62
+ return JSONResponse(content={"glioma_stage": result})
 
 
 
 
63
 
64
+ except Exception as e:
65
+ return JSONResponse(content={"error": str(e)})