yolac commited on
Commit
faa394e
·
verified ·
1 Parent(s): 94f9989

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -33,7 +33,7 @@ class BacterialMorphologyClassifier(nn.Module):
33
  return x
34
 
35
  # Load the model
36
- MODEL_PATH = "model.pth"
37
  model = BacterialMorphologyClassifier()
38
 
39
  try:
@@ -45,7 +45,7 @@ try:
45
  with open(MODEL_PATH, "wb") as f:
46
  f.write(response.content)
47
  # Load the model weights
48
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')), strict=False)
49
  model.eval()
50
  print("Model loaded successfully.")
51
  except Exception as e:
@@ -65,15 +65,15 @@ def predict(image):
65
  image_tensor = transform(image).unsqueeze(0)
66
 
67
  # Perform prediction
68
- output = model(image_tensor)
69
- prediction = output.argmax().item()
70
 
71
  # Class mapping
72
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
 
74
  # Return the predicted class and confidence
75
- predicted_class = class_labels[prediction]
76
- confidence = output.max().item()
77
  return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
78
  except Exception as e:
79
  return f"Error: {str(e)}"
 
33
  return x
34
 
35
  # Load the model
36
+ MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
37
  model = BacterialMorphologyClassifier()
38
 
39
  try:
 
45
  with open(MODEL_PATH, "wb") as f:
46
  f.write(response.content)
47
  # Load the model weights
48
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
49
  model.eval()
50
  print("Model loaded successfully.")
51
  except Exception as e:
 
65
  image_tensor = transform(image).unsqueeze(0)
66
 
67
  # Perform prediction
68
+ with torch.no_grad(): # Ensure no gradients are calculated
69
+ output = model(image_tensor)
70
 
71
  # Class mapping
72
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
 
74
  # Return the predicted class and confidence
75
+ predicted_class = class_labels[output.argmax().item()]
76
+ confidence = output.max().item() # Softmax value as confidence
77
  return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
78
  except Exception as e:
79
  return f"Error: {str(e)}"