Codewithsalty commited on
Commit
342c341
·
verified ·
1 Parent(s): 374a9d4

Update newapi.py

Browse files
Files changed (1) hide show
  1. newapi.py +36 -45
newapi.py CHANGED
@@ -1,68 +1,59 @@
1
- import os
2
- from fastapi import FastAPI, UploadFile, File
3
  from fastapi.responses import JSONResponse
4
- from PIL import Image
5
  import torch
6
  import torchvision.transforms as transforms
7
- from utils import BrainTumorModel, GliomaStageModel
 
8
 
9
- app = FastAPI()
10
 
11
- # === Use exact filenames from the Space directory ===
12
- btd_model_path = "brain_tumor_model.pth"
13
- glioma_model_path = "glioma_stage_model.pth"
14
 
15
- # === Load Brain Tumor Model ===
16
- btd_model = BrainTumorModel()
 
 
 
 
 
 
 
 
 
 
17
  btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
18
  btd_model.eval()
19
 
20
- # === Load Glioma Stage Model ===
21
- glioma_model = GliomaStageModel()
22
- glioma_model.load_state_dict(torch.load(glioma_model_path, map_location=torch.device('cpu')))
23
- glioma_model.eval()
24
-
25
- # === Image Transform ===
26
  transform = transforms.Compose([
27
- transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
 
29
  ])
30
 
31
- # === Routes ===
 
 
32
 
33
  @app.post("/predict/")
34
  async def predict(file: UploadFile = File(...)):
35
  try:
36
- image = Image.open(file.file).convert("RGB")
37
- image = transform(image).unsqueeze(0)
38
-
39
- with torch.no_grad():
40
- output = btd_model(image)
41
- predicted = torch.argmax(output, dim=1).item()
42
-
43
- classes = ['No Tumor', 'Pituitary', 'Meningioma', 'Glioma']
44
- result = classes[predicted]
45
-
46
- return JSONResponse(content={"prediction": result})
47
-
48
- except Exception as e:
49
- return JSONResponse(content={"error": str(e)})
50
-
51
-
52
- @app.post("/glioma-stage/")
53
- async def glioma_stage(file: UploadFile = File(...)):
54
- try:
55
- image = Image.open(file.file).convert("RGB")
56
  image = transform(image).unsqueeze(0)
57
 
 
58
  with torch.no_grad():
59
- output = glioma_model(image)
60
- predicted = torch.argmax(output, dim=1).item()
61
-
62
- stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
63
- result = stages[predicted]
64
 
65
- return JSONResponse(content={"glioma_stage": result})
 
 
66
 
 
 
67
  except Exception as e:
68
- return JSONResponse(content={"error": str(e)})
 
1
+ from fastapi import FastAPI, File, UploadFile
 
2
  from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  import torch
5
  import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import io
8
 
9
+ from utils import YourModelClass # Make sure this matches your actual model class
10
 
11
+ app = FastAPI()
 
 
12
 
13
+ # CORS Middleware (optional but good for frontend API usage)
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"],
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ # Load model
23
+ btd_model_path = "models/BTD_model.pth" # ✅ Correct filename and folder
24
+ btd_model = YourModelClass()
25
  btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
26
  btd_model.eval()
27
 
28
+ # Image transformation (adjust according to how your model was trained)
 
 
 
 
 
29
  transform = transforms.Compose([
30
+ transforms.Resize((224, 224)), # Adjust to your model's expected input size
31
  transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.5], std=[0.5]) # Adjust for grayscale or RGB
33
  ])
34
 
35
+ @app.get("/")
36
+ def root():
37
+ return {"message": "Brain Tumor Detection API is up and running!"}
38
 
39
  @app.post("/predict/")
40
  async def predict(file: UploadFile = File(...)):
41
  try:
42
+ # Read image
43
+ contents = await file.read()
44
+ image = Image.open(io.BytesIO(contents)).convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  image = transform(image).unsqueeze(0)
46
 
47
+ # Run model
48
  with torch.no_grad():
49
+ outputs = btd_model(image)
50
+ _, predicted = torch.max(outputs, 1)
 
 
 
51
 
52
+ # Class mapping (adjust according to your model's labels)
53
+ classes = ['No Tumor', 'Glioma', 'Meningioma', 'Pituitary']
54
+ prediction = classes[predicted.item()]
55
 
56
+ return JSONResponse({"prediction": prediction})
57
+
58
  except Exception as e:
59
+ return JSONResponse(status_code=500, content={"error": str(e)})