Codewithsalty committed (verified)
Commit 28addcf · 1 Parent(s): 7903056

Update newapi.py

Files changed (1)
  1. newapi.py +32 -40
newapi.py CHANGED
@@ -1,59 +1,51 @@
-from fastapi import FastAPI, File, UploadFile
+from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import JSONResponse
-from fastapi.middleware.cors import CORSMiddleware
+from PIL import Image
 import torch
 import torchvision.transforms as transforms
-from PIL import Image
-import io
-
-from utils import YourModelClass  # Make sure this matches your actual model class
+from utils import BrainTumorModel, get_precautions_from_gemini

 app = FastAPI()

-# CORS Middleware (optional but good for frontend API usage)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Load model
-btd_model_path = "models/BTD_model.pth"  # Correct filename and folder
-btd_model = YourModelClass()
-btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
-btd_model.eval()
-
-# Image transformation (adjust according to how your model was trained)
+# Load the model
+btd_model = BrainTumorModel()
+btd_model_path = "brain_tumor_model.pth"
+
+try:
+    btd_model.load_state_dict(torch.load(btd_model_path, map_location=torch.device('cpu')))
+    btd_model.eval()
+except Exception as e:
+    print(f"❌ Error loading model: {e}")
+
+# Define image transform
 transform = transforms.Compose([
-    transforms.Resize((224, 224)),  # Adjust to your model's expected input size
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust for grayscale or RGB
+    transforms.Resize((224, 224)),
+    transforms.ToTensor()
 ])

+# Class labels (adjust if your model uses different labels)
+classes = ['glioma', 'meningioma', 'notumor', 'pituitary']
+
 @app.get("/")
-def root():
-    return {"message": "Brain Tumor Detection API is up and running!"}
+def read_root():
+    return {"message": "Brain Tumor Detection API is running 🚀"}

-@app.post("/predict/")
+@app.post("/predict")
 async def predict(file: UploadFile = File(...)):
     try:
-        # Read image
-        contents = await file.read()
-        image = Image.open(io.BytesIO(contents)).convert('RGB')
-        image = transform(image).unsqueeze(0)
+        image = Image.open(file.file).convert("RGB")
+        image = transform(image).unsqueeze(0)  # Shape: [1, 3, 224, 224]

-        # Run model
         with torch.no_grad():
             outputs = btd_model(image)
-            _, predicted = torch.max(outputs, 1)
+            _, predicted = torch.max(outputs.data, 1)
+        predicted_class = classes[predicted.item()]
+        precautions = get_precautions_from_gemini(predicted_class)

-        # Class mapping (adjust according to your model's labels)
-        classes = ['No Tumor', 'Glioma', 'Meningioma', 'Pituitary']
-        prediction = classes[predicted.item()]
+        return JSONResponse(content={
+            "prediction": predicted_class,
+            "precautions": precautions
+        })

-        return JSONResponse({"prediction": prediction})
-
     except Exception as e:
-        return JSONResponse(status_code=500, content={"error": str(e)})
+        return JSONResponse(content={"error": str(e)}, status_code=500)
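Note: the updated file imports BrainTumorModel and get_precautions_from_gemini from utils, and neither is part of this commit. The following is only a hypothetical sketch of compatible shapes (the two names come from the diff; the CNN architecture, the 224x224 RGB input, the four-class output, and the canned precaution text are assumptions, not the repository's actual utils.py):

# utils.py — hypothetical stubs, NOT the repository's actual implementation.
import torch
import torch.nn as nn


class BrainTumorModel(nn.Module):
    """Assumed CNN: takes a [N, 3, 224, 224] tensor and returns logits for 4 classes."""

    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # 224 -> 112 -> 56 after two 2x2 poolings, so the flattened size is 32 * 56 * 56.
        self.classifier = nn.Linear(32 * 56 * 56, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


def get_precautions_from_gemini(predicted_class: str) -> str:
    """Placeholder: the real helper presumably queries Gemini; here it returns a canned string."""
    return f"Consult a medical professional about a suspected {predicted_class} finding."

The checkpoint brain_tumor_model.pth will only load cleanly if the real class defines matching layers, so treat these stubs purely as import-time placeholders for local smoke testing.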
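To exercise the new route locally, one option is FastAPI's bundled TestClient, which calls the app in-process without starting a server. This is a minimal sketch assuming the updated file is importable as newapi and a sample image exists at scan.jpg (both assumptions; note the route is now "/predict", the trailing slash having been dropped in this commit):

# smoke_test.py — minimal sketch; assumes newapi.py and scan.jpg exist locally.
from fastapi.testclient import TestClient

from newapi import app

client = TestClient(app)


def main() -> None:
    # The root route should confirm the service is up.
    print(client.get("/").json())

    # Upload an image as multipart form data under the "file" field.
    with open("scan.jpg", "rb") as f:
        resp = client.post("/predict", files={"file": ("scan.jpg", f, "image/jpeg")})
    print(resp.status_code, resp.json())


if __name__ == "__main__":
    main()

Alternatively, the same request can be sent to a running server (for example one started with uvicorn newapi:app) using any HTTP client that supports multipart uploads.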