Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,14 +6,19 @@ from pydantic import BaseModel
|
|
6 |
from io import BytesIO
|
7 |
from PIL import Image
|
8 |
|
|
|
9 |
app = FastAPI()
|
10 |
|
11 |
-
#
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
model.eval()
|
15 |
|
16 |
-
# Define
|
17 |
preprocess = transforms.Compose([
|
18 |
transforms.Resize((224, 224)),
|
19 |
transforms.ToTensor(),
|
@@ -24,7 +29,7 @@ preprocess = transforms.Compose([
|
|
24 |
def preprocess_image(image: Image.Image):
|
25 |
image = image.convert("RGB")
|
26 |
image = preprocess(image)
|
27 |
-
image = image.unsqueeze(0)
|
28 |
return image
|
29 |
|
30 |
@app.post("/predict/")
|
@@ -40,7 +45,7 @@ async def predict(file: UploadFile = File(...)):
|
|
40 |
outputs = model(image_tensor)
|
41 |
predicted_class_idx = torch.argmax(outputs, dim=1).item()
|
42 |
|
43 |
-
# Map
|
44 |
class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
|
45 |
predicted_class = class_labels[predicted_class_idx]
|
46 |
|
@@ -49,4 +54,5 @@ async def predict(file: UploadFile = File(...)):
|
|
49 |
except Exception as e:
|
50 |
return {"error": str(e)}
|
51 |
|
52 |
-
#
|
|
|
|
6 |
from io import BytesIO
|
7 |
from PIL import Image
|
8 |
|
9 |
+
# Initialize the FastAPI app
|
10 |
app = FastAPI()
|
11 |
|
12 |
+
# Path to the trained model checkpoint
|
13 |
+
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/BacterialMorphologyClassification_model.ipynb"
|
14 |
+
|
15 |
+
# Load the pretrained MobileNetV2 model and modify the classifier
|
16 |
+
model = models.mobilenet_v2(weights=None) # No initial weights
|
17 |
+
model.classifier[1] = nn.Linear(model.last_channel, 3) # Replace with 3 classes for your dataset
|
18 |
+
model.load_state_dict(torch.load(MODEL_PATH)) # Load your trained model weights
|
19 |
model.eval()
|
20 |
|
21 |
+
# Define image preprocessing transformations
|
22 |
preprocess = transforms.Compose([
|
23 |
transforms.Resize((224, 224)),
|
24 |
transforms.ToTensor(),
|
|
|
29 |
def preprocess_image(image: Image.Image):
|
30 |
image = image.convert("RGB")
|
31 |
image = preprocess(image)
|
32 |
+
image = image.unsqueeze(0) # Add a batch dimension
|
33 |
return image
|
34 |
|
35 |
@app.post("/predict/")
|
|
|
45 |
outputs = model(image_tensor)
|
46 |
predicted_class_idx = torch.argmax(outputs, dim=1).item()
|
47 |
|
48 |
+
# Map class index to class label
|
49 |
class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
|
50 |
predicted_class = class_labels[predicted_class_idx]
|
51 |
|
|
|
54 |
except Exception as e:
|
55 |
return {"error": str(e)}
|
56 |
|
57 |
+
# To run the app, use the following command in the terminal:
|
58 |
+
# uvicorn app:app --host 0.0.0.0 --port 8000 --reload
|