yolac commited on
Commit
fdcefa1
·
verified ·
1 Parent(s): cf86671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -4,8 +4,12 @@ import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
 
7
 
8
- # Define the model architecture that matches the saved .pth file
 
 
 
9
  class BacterialMorphologyClassifier(nn.Module):
10
  def __init__(self):
11
  super(BacterialMorphologyClassifier, self).__init__()
@@ -31,56 +35,53 @@ class BacterialMorphologyClassifier(nn.Module):
31
  x = self.fc(x)
32
  return x
33
 
34
- # Load the model and weights
35
- model = BacterialMorphologyClassifier()
36
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
37
- state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
- model.load_state_dict(state_dict, strict=False)
39
- model.eval()
 
 
 
 
 
 
40
 
41
- # Define image preprocessing transformations
42
  transform = transforms.Compose([
43
  transforms.Resize((224, 224)),
44
  transforms.ToTensor(),
45
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
46
  ])
47
 
48
- # Define Gradio interface
49
  def predict(image):
50
  try:
51
- # Preprocess the image
52
  image_tensor = transform(image).unsqueeze(0)
53
 
54
  # Make prediction
55
  output = model(image_tensor)
56
  prediction = output.argmax().item()
57
  confidence = output.max().item()
58
-
59
- # Print debugging information
60
- print(f"Predicted: {prediction}, Confidence: {confidence}")
61
- print(f"Model Output: {output}")
62
 
63
  # Class mapping
64
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
65
-
66
- # Return prediction result
67
  return class_labels[prediction], confidence
 
68
  except Exception as e:
 
69
  return {'error': str(e)}
70
 
71
-
72
- # Create the Gradio interface
73
- iface = gr.Interface(
74
  fn=predict,
75
- inputs=gr.Image(type="pil", label="Upload an image"),
76
- outputs=[gr.Label(num_top_classes=3, label="Predicted Class"), gr.Number(label="Confidence")],
77
  examples=[
78
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg",
79
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg",
80
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"
81
  ]
82
- )
83
-
84
- # Launch the app
85
- if __name__ == "__main__":
86
- iface.launch()
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
7
+ import logging
8
 
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.DEBUG)
11
+
12
+ # Define the model architecture
13
  class BacterialMorphologyClassifier(nn.Module):
14
  def __init__(self):
15
  super(BacterialMorphologyClassifier, self).__init__()
 
35
  x = self.fc(x)
36
  return x
37
 
38
+ # Load the model
 
39
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
40
+ model = BacterialMorphologyClassifier()
41
+ try:
42
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
43
+ model.load_state_dict(state_dict, strict=False)
44
+ model.eval()
45
+ logging.info("Model loaded successfully.")
46
+ except Exception as e:
47
+ logging.error(f"Error loading the model: {e}")
48
+ raise
49
 
50
+ # Image preprocessing transformations
51
  transform = transforms.Compose([
52
  transforms.Resize((224, 224)),
53
  transforms.ToTensor(),
54
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
55
  ])
56
 
 
57
  def predict(image):
58
  try:
59
+ logging.info("Received image for prediction.")
60
  image_tensor = transform(image).unsqueeze(0)
61
 
62
  # Make prediction
63
  output = model(image_tensor)
64
  prediction = output.argmax().item()
65
  confidence = output.max().item()
66
+
67
+ logging.debug(f"Model output: {output}, Prediction: {prediction}, Confidence: {confidence}")
 
 
68
 
69
  # Class mapping
70
  class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
 
 
71
  return class_labels[prediction], confidence
72
+
73
  except Exception as e:
74
+ logging.error(f"Error during prediction: {e}")
75
  return {'error': str(e)}
76
 
77
+ # Create Gradio app
78
+ gr.Interface(
 
79
  fn=predict,
80
+ inputs=gr.inputs.Image(type="pil", label="Upload an image"),
81
+ outputs=["text", "number"],
82
  examples=[
83
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
84
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
85
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
86
  ]
87
+ ).launch(debug=True)