File size: 2,010 Bytes
4bcd70b
a75fe7e
 
 
 
 
 
e4acaca
60d002b
e4acaca
a75fe7e
 
 
e4acaca
a75fe7e
 
 
 
e4acaca
a75fe7e
 
 
 
e4acaca
a75fe7e
e4acaca
 
a75fe7e
e4acaca
 
a75fe7e
60d002b
a75fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60d002b
a75fe7e
 
 
60d002b
a75fe7e
 
60d002b
a75fe7e
e4acaca
a75fe7e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
import torch
import torchvision.transforms as transforms
from utils import BrainTumorModel, GliomaStageModel

# ASGI application object; the route decorators further down attach to it.
app = FastAPI()

# Checkpoint locations, given as plain relative file names.
btd_model_path = "brain_tumor_model.pth"
glioma_model_path = "glioma_stage_model.pth"


def _load_for_inference(model, checkpoint_path):
    """Load a state dict from *checkpoint_path* into *model* and return it.

    The checkpoint is mapped onto the CPU when it is read, and the model is
    switched to evaluation mode before being handed back.
    """
    weights = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    model.eval()
    return model


# Brain tumor detection model, used by the /predict/ endpoint.
btd_model = _load_for_inference(BrainTumorModel(), btd_model_path)

# Glioma stage model, used by the /glioma-stage/ endpoint.
glioma_model = _load_for_inference(GliomaStageModel(), glioma_model_path)

# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, then convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the brain tumor detection model.

    The upload is opened with PIL and converted to RGB, passed through the
    module-level ``transform``, given a leading dimension with
    ``unsqueeze(0)``, and fed to ``btd_model``. The arg-max over dimension 1
    of the model output selects the label from a fixed list.

    Args:
        file: The uploaded image file.

    Returns:
        A ``JSONResponse``. On success the body is
        ``{"prediction": <label>}`` with the default 200 status. If the
        upload cannot be opened as an image the body is
        ``{"error": <message>}`` with status 400. If preprocessing, the
        model, or the label lookup fails, the same error body is returned
        with status 500.
    """
    classes = ['No Tumor', 'Pituitary', 'Meningioma', 'Glioma']

    # Decoding is handled on its own so that a bad upload is reported as a
    # client error instead of being mixed up with server-side failures.
    try:
        image = Image.open(file.file).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

    try:
        batch = transform(image).unsqueeze(0)

        # Gradients are not needed for a forward-only pass.
        with torch.no_grad():
            output = btd_model(batch)
            predicted = torch.argmax(output, dim=1).item()

        result = classes[predicted]
    except Exception as e:
        # Request boundary: report the failure as JSON with an error status
        # rather than a 200 response that looks like success.
        return JSONResponse(status_code=500, content={"error": str(e)})

    return JSONResponse(content={"prediction": result})

@app.post("/glioma-stage/")
async def glioma_stage(file: UploadFile = File(...)):
    """Predict a glioma stage label for an uploaded image.

    The upload is opened with PIL and converted to RGB, passed through the
    module-level ``transform``, given a leading dimension with
    ``unsqueeze(0)``, and fed to ``glioma_model``. The arg-max over
    dimension 1 of the model output selects the label from a fixed list.

    Args:
        file: The uploaded image file.

    Returns:
        A ``JSONResponse``. On success the body is
        ``{"glioma_stage": <label>}`` with the default 200 status. If the
        upload cannot be opened as an image the body is
        ``{"error": <message>}`` with status 400. If preprocessing, the
        model, or the label lookup fails, the same error body is returned
        with status 500.
    """
    stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']

    # Decoding is handled on its own so that a bad upload is reported as a
    # client error instead of being mixed up with server-side failures.
    try:
        image = Image.open(file.file).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

    try:
        batch = transform(image).unsqueeze(0)

        # Gradients are not needed for a forward-only pass.
        with torch.no_grad():
            output = glioma_model(batch)
            predicted = torch.argmax(output, dim=1).item()

        result = stages[predicted]
    except Exception as e:
        # Request boundary: report the failure as JSON with an error status
        # rather than a 200 response that looks like success.
        return JSONResponse(status_code=500, content={"error": str(e)})

    return JSONResponse(content={"glioma_stage": result})