yolac commited on
Commit
31017b6
·
verified ·
1 Parent(s): 7acf092

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -27
app.py CHANGED
@@ -3,12 +3,13 @@ import torch
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
 
6
  import logging
7
 
8
- # Set up logging
9
  logging.basicConfig(level=logging.DEBUG)
10
 
11
- # Define the model architecture
12
  class BacterialMorphologyClassifier(nn.Module):
13
  def __init__(self):
14
  super(BacterialMorphologyClassifier, self).__init__()
@@ -34,51 +35,48 @@ class BacterialMorphologyClassifier(nn.Module):
34
  x = self.fc(x)
35
  return x
36
 
37
- # Load the model
38
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
 
39
  model = BacterialMorphologyClassifier()
40
- try:
41
- state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
42
- model.load_state_dict(state_dict, strict=False)
43
- model.eval()
44
- logging.info("Model loaded successfully.")
45
- except Exception as e:
46
- logging.error(f"Error loading the model: {e}")
47
- raise
48
 
49
- # Image preprocessing transformations
50
  transform = transforms.Compose([
51
  transforms.Resize((224, 224)),
52
  transforms.ToTensor(),
53
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
54
  ])
55
 
 
56
  def predict(image):
57
  try:
58
- logging.info("Received image for prediction.")
59
  image_tensor = transform(image).unsqueeze(0)
60
 
61
  # Make prediction
62
  output = model(image_tensor)
63
  prediction = output.argmax().item()
64
- confidence = output.max().item()
65
-
66
- logging.debug(f"Model output: {output}, Prediction: {prediction}, Confidence: {confidence}")
67
-
68
  # Class mapping
69
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
70
- return class_labels[prediction], confidence
71
-
 
 
 
 
72
  except Exception as e:
73
- logging.error(f"Error during prediction: {e}")
74
- return {'error': str(e)}
75
 
76
- # Create Gradio app
77
- gr.Interface(
78
- fn=predict,
79
- inputs=gr.Image(type="pil", label="Upload an image"),
80
- outputs=["text", "number"],
81
- examples=[
82
  ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
83
  ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
84
  ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
 
3
  import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
6
+ import io
7
  import logging
8
 
9
+ # Set up logging for debugging
10
  logging.basicConfig(level=logging.DEBUG)
11
 
12
+ # Define the model architecture that matches the saved .pth file
13
  class BacterialMorphologyClassifier(nn.Module):
14
  def __init__(self):
15
  super(BacterialMorphologyClassifier, self).__init__()
 
35
  x = self.fc(x)
36
  return x
37
 
38
+ # Load the model and weights
39
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
40
+ logging.debug("Starting model loading...")
41
  model = BacterialMorphologyClassifier()
42
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
+ model.load_state_dict(state_dict, strict=False)
44
+ model.eval()
45
+ logging.debug("Model loaded successfully.")
 
 
 
 
46
 
47
+ # Define image preprocessing transformations
48
  transform = transforms.Compose([
49
  transforms.Resize((224, 224)),
50
  transforms.ToTensor(),
51
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
52
  ])
53
 
54
+ # Define the prediction function
55
  def predict(image):
56
  try:
57
+ # Preprocess the image
58
  image_tensor = transform(image).unsqueeze(0)
59
 
60
  # Make prediction
61
  output = model(image_tensor)
62
  prediction = output.argmax().item()
63
+
 
 
 
64
  # Class mapping
65
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
66
+
67
+ # Log prediction details
68
+ logging.debug(f"Predicted class: {class_labels[prediction]}, Confidence: {output.max().item()}")
69
+
70
+ # Return prediction result
71
+ return class_labels[prediction], float(output.max().item())
72
  except Exception as e:
73
+ logging.error(f"Error during prediction: {str(e)}")
74
+ return "Error", 0.0
75
 
76
+ # Create a Gradio interface
77
+ inputs = gr.Image(type="pil", label="Upload an image")
78
+ outputs = gr.Label(num_top_classes=3, label="Predicted Class")
79
+ examples=[
 
 
80
  ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
81
  ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
82
  ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]