Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
import torch.nn as nn
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
|
|
6 |
import logging
|
7 |
|
8 |
# Set up logging for debugging
|
@@ -37,14 +38,11 @@ class BacterialMorphologyClassifier(nn.Module):
|
|
37 |
# Load the model and weights
|
38 |
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
|
39 |
logging.debug("Starting model loading...")
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
logging.debug("Model loaded successfully.")
|
46 |
-
except Exception as e:
|
47 |
-
logging.error(f"Error loading the model: {str(e)}")
|
48 |
|
49 |
# Define image preprocessing transformations
|
50 |
transform = transforms.Compose([
|
@@ -56,10 +54,8 @@ transform = transforms.Compose([
|
|
56 |
# Define the prediction function
|
57 |
def predict(image):
|
58 |
try:
|
59 |
-
logging.debug("Starting prediction...")
|
60 |
# Preprocess the image
|
61 |
image_tensor = transform(image).unsqueeze(0)
|
62 |
-
logging.debug("Image preprocessing completed.")
|
63 |
|
64 |
# Make prediction
|
65 |
output = model(image_tensor)
|
@@ -78,14 +74,8 @@ def predict(image):
|
|
78 |
return "Error", 0.0
|
79 |
|
80 |
# Create a Gradio interface
|
81 |
-
gr.
|
82 |
-
|
83 |
-
inputs=gr.Image(type="pil", label="Upload an image"),
|
84 |
-
outputs=["text", "number"],
|
85 |
-
examples=[
|
86 |
-
["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
|
87 |
-
["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
|
88 |
-
["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
|
89 |
-
]
|
90 |
-
).launch(debug=True)
|
91 |
|
|
|
|
|
|
3 |
import torch.nn as nn
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
+
import io
|
7 |
import logging
|
8 |
|
9 |
# Set up logging for debugging
|
|
|
38 |
# Load the model and weights
|
39 |
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
|
40 |
logging.debug("Starting model loading...")
|
41 |
+
model = BacterialMorphologyClassifier()
|
42 |
+
state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
|
43 |
+
model.load_state_dict(state_dict, strict=False)
|
44 |
+
model.eval()
|
45 |
+
logging.debug("Model loaded successfully.")
|
|
|
|
|
|
|
46 |
|
47 |
# Define image preprocessing transformations
|
48 |
transform = transforms.Compose([
|
|
|
54 |
# Define the prediction function
|
55 |
def predict(image):
|
56 |
try:
|
|
|
57 |
# Preprocess the image
|
58 |
image_tensor = transform(image).unsqueeze(0)
|
|
|
59 |
|
60 |
# Make prediction
|
61 |
output = model(image_tensor)
|
|
|
74 |
return "Error", 0.0
|
75 |
|
76 |
# Create a Gradio interface
|
77 |
+
inputs = gr.Image(type="pil", label="Upload an image")
|
78 |
+
outputs = gr.Label(num_top_classes=3, label="Predicted Class")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
# Launch the Gradio app
|
81 |
+
gr.Interface(fn=predict, inputs=inputs, outputs=outputs, live=True).launch(server_name="0.0.0.0", server_port=7861, debug=True)
|