Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -31,21 +31,20 @@ class BacterialMorphologyClassifier(nn.Module):
|
|
31 |
x = self.fc(x)
|
32 |
return x
|
33 |
|
34 |
-
# Load the model and weights
|
35 |
model = BacterialMorphologyClassifier()
|
36 |
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
|
37 |
state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
|
38 |
-
model.load_state_dict(state_dict)
|
39 |
model.eval()
|
40 |
|
|
|
|
|
|
|
|
|
41 |
# Set up Flask app
|
42 |
app = Flask(__name__)
|
43 |
|
44 |
-
# Add a basic route
|
45 |
-
@app.route('/')
|
46 |
-
def home():
|
47 |
-
return "Flask app is running!"
|
48 |
-
|
49 |
# Define image preprocessing transformations
|
50 |
transform = transforms.Compose([
|
51 |
transforms.Resize((224, 224)),
|
@@ -61,7 +60,7 @@ def predict():
|
|
61 |
image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
|
62 |
|
63 |
# Preprocess the image
|
64 |
-
image_tensor = transform(image).unsqueeze(0)
|
65 |
|
66 |
# Make prediction
|
67 |
output = model(image_tensor)
|
|
|
31 |
x = self.fc(x)
|
32 |
return x
|
33 |
|
34 |
+
# Load the model and weights at app startup
|
35 |
model = BacterialMorphologyClassifier()
|
36 |
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
|
37 |
state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
|
38 |
+
model.load_state_dict(state_dict, strict=False)
|
39 |
model.eval()
|
40 |
|
41 |
+
# Move model to GPU if available
|
42 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
43 |
+
model.to(device)
|
44 |
+
|
45 |
# Set up Flask app
|
46 |
app = Flask(__name__)
|
47 |
|
|
|
|
|
|
|
|
|
|
|
48 |
# Define image preprocessing transformations
|
49 |
transform = transforms.Compose([
|
50 |
transforms.Resize((224, 224)),
|
|
|
60 |
image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
|
61 |
|
62 |
# Preprocess the image
|
63 |
+
image_tensor = transform(image).unsqueeze(0).to(device)
|
64 |
|
65 |
# Make prediction
|
66 |
output = model(image_tensor)
|