yolac commited on
Commit
1b9be77
·
verified ·
1 Parent(s): 641df71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -39
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from torchvision import transforms
4
  from PIL import Image
5
- import requests
6
- import gradio as gr
7
- import os
8
 
9
  # Define the model architecture
10
  class BacterialMorphologyClassifier(nn.Module):
@@ -24,7 +22,6 @@ class BacterialMorphologyClassifier(nn.Module):
24
  nn.ReLU(),
25
  nn.Dropout(0.5),
26
  nn.Linear(128, 3),
27
- nn.Softmax(dim=1),
28
  )
29
 
30
  def forward(self, x):
@@ -33,60 +30,52 @@ class BacterialMorphologyClassifier(nn.Module):
33
  return x
34
 
35
  # Load the model
36
- MODEL_PATH = "model.pth"
37
  model = BacterialMorphologyClassifier()
38
-
39
  try:
40
- # Download the model if it doesn't exist
41
- if not os.path.exists(MODEL_PATH):
42
- print("Downloading the model...")
43
- url = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
44
- response = requests.get(url)
45
- with open(MODEL_PATH, "wb") as f:
46
- f.write(response.content)
47
- # Load the model weights
48
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')), strict=False)
49
- model.eval()
50
  print("Model loaded successfully.")
51
  except Exception as e:
52
- print(f"Error loading the model: {e}")
 
 
53
 
54
- # Define image preprocessing
55
  transform = transforms.Compose([
56
  transforms.Resize((224, 224)),
57
  transforms.ToTensor(),
58
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
59
  ])
60
 
 
 
 
61
  # Prediction function
62
  def predict(image):
63
  try:
64
- # Convert the image to a tensor
65
  image_tensor = transform(image).unsqueeze(0)
66
-
67
- # Perform prediction
68
- output = model(image_tensor)
69
- prediction = output.argmax().item()
70
-
71
- # Class mapping
72
- class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
-
74
- # Return the predicted class and confidence
75
- predicted_class = class_labels[prediction]
76
- confidence = output.max().item()
77
- return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
78
  except Exception as e:
79
- return f"Error: {str(e)}"
80
 
81
  # Set up Gradio interface
82
- interface = gr.Interface(
83
  fn=predict,
84
- inputs=gr.Image(type="pil"),
85
- outputs=gr.Text(label="Prediction"),
86
- title="Bacterial Morphology Classification",
87
- description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
88
  )
89
 
90
  # Launch the app
91
- if __name__ == "__main__":
92
- interface.launch()
 
1
  import torch
2
  import torch.nn as nn
3
+ import gradio as gr
4
  from torchvision import transforms
5
  from PIL import Image
 
 
 
6
 
7
  # Define the model architecture
8
  class BacterialMorphologyClassifier(nn.Module):
 
22
  nn.ReLU(),
23
  nn.Dropout(0.5),
24
  nn.Linear(128, 3),
 
25
  )
26
 
27
  def forward(self, x):
 
30
  return x
31
 
32
  # Load the model
33
+ MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
34
  model = BacterialMorphologyClassifier()
 
35
  try:
36
+ # Download and load model state_dict
37
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
+ model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
39
  print("Model loaded successfully.")
40
  except Exception as e:
41
+ print(f"Error loading model: {e}")
42
+ raise e
43
+ model.eval()
44
 
45
+ # Define image preprocessing transformations
46
  transform = transforms.Compose([
47
  transforms.Resize((224, 224)),
48
  transforms.ToTensor(),
49
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
50
  ])
51
 
52
+ # Class labels
53
+ class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
54
+
55
  # Prediction function
56
  def predict(image):
57
  try:
58
+ # Preprocess the image
59
  image_tensor = transform(image).unsqueeze(0)
60
+
61
+ # Perform inference
62
+ with torch.no_grad():
63
+ output = model(image_tensor)
64
+ prediction = output.argmax().item()
65
+ confidence = torch.nn.functional.softmax(output, dim=1).max().item()
66
+
67
+ return {class_labels[prediction]: confidence}
 
 
 
 
68
  except Exception as e:
69
+ return {'error': str(e)}
70
 
71
  # Set up Gradio interface
72
+ iface = gr.Interface(
73
  fn=predict,
74
+ inputs=gr.inputs.Image(type="pil", label="Upload an image"),
75
+ outputs=gr.outputs.Label(num_top_classes=3, label="Predicted class and confidence"),
76
+ title="Bacterial Morphology Classifier",
77
+ description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'."
78
  )
79
 
80
  # Launch the app
81
+ iface.launch(server_name="0.0.0.0", server_port=5000, share=True)