yolac commited on
Commit
02c0ae0
·
verified ·
1 Parent(s): 21e99b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -38
app.py CHANGED
@@ -1,10 +1,11 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import gradio as gr
4
  from torchvision import transforms
5
  from PIL import Image
 
6
 
7
- # Define the model architecture
8
  class BacterialMorphologyClassifier(nn.Module):
9
  def __init__(self):
10
  super(BacterialMorphologyClassifier, self).__init__()
@@ -22,6 +23,7 @@ class BacterialMorphologyClassifier(nn.Module):
22
  nn.ReLU(),
23
  nn.Dropout(0.5),
24
  nn.Linear(128, 3),
 
25
  )
26
 
27
  def forward(self, x):
@@ -29,17 +31,11 @@ class BacterialMorphologyClassifier(nn.Module):
29
  x = self.fc(x)
30
  return x
31
 
32
- # Load the model
33
- MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
34
  model = BacterialMorphologyClassifier()
35
- try:
36
- # Download and load model state_dict
37
- state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
- model.load_state_dict(state_dict, strict=False)
39
- print("Model loaded successfully.")
40
- except Exception as e:
41
- print(f"Error loading model: {e}")
42
- raise e
43
  model.eval()
44
 
45
  # Define image preprocessing transformations
@@ -49,41 +45,36 @@ transform = transforms.Compose([
49
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
50
  ])
51
 
52
- # Class labels
53
- class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
54
-
55
- # Prediction function
56
  def predict(image):
57
  try:
58
  # Preprocess the image
59
  image_tensor = transform(image).unsqueeze(0)
60
-
61
- # Perform inference
62
- with torch.no_grad():
63
- output = model(image_tensor)
64
- prediction = output.argmax().item()
65
- confidence = torch.nn.functional.softmax(output, dim=1).max().item()
66
-
67
- return {class_labels[prediction]: confidence}
 
 
68
  except Exception as e:
69
  return {'error': str(e)}
70
 
71
- # Example input images (provide paths or URLs)
72
- example_images = [
73
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer?row=0&image-viewer=52B421CB70A43313B278D5DD2C58CECE56343012", # Replace with the actual paths to your example images
74
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=558EA847F2267CECF4E2CFF6352F9D8888E9A72F",
75
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A"
76
- ]
77
-
78
- # Set up Gradio interface with examples
79
  iface = gr.Interface(
80
  fn=predict,
81
- inputs=gr.inputs.Image(type="pil", label="Upload an image"),
82
- outputs=gr.outputs.Label(num_top_classes=3, label="Predicted class and confidence"),
83
- title="Bacterial Morphology Classifier",
84
- description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'.",
85
- examples=example_images # Provide the example image paths
 
 
86
  )
87
 
88
  # Launch the app
89
- iface.launch(server_name="0.0.0.0", server_port=5000, share=True)
 
 
1
+ import gradio as gr
2
  import torch
3
  import torch.nn as nn
 
4
  from torchvision import transforms
5
  from PIL import Image
6
+ import io
7
 
8
+ # Define the model architecture that matches the saved .pth file
9
  class BacterialMorphologyClassifier(nn.Module):
10
  def __init__(self):
11
  super(BacterialMorphologyClassifier, self).__init__()
 
23
  nn.ReLU(),
24
  nn.Dropout(0.5),
25
  nn.Linear(128, 3),
26
+ nn.Softmax(dim=1),
27
  )
28
 
29
  def forward(self, x):
 
31
  x = self.fc(x)
32
  return x
33
 
34
+ # Load the model and weights
 
35
  model = BacterialMorphologyClassifier()
36
+ MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
37
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
+ model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
39
  model.eval()
40
 
41
  # Define image preprocessing transformations
 
45
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
46
  ])
47
 
48
+ # Define Gradio interface
 
 
 
49
  def predict(image):
50
  try:
51
  # Preprocess the image
52
  image_tensor = transform(image).unsqueeze(0)
53
+
54
+ # Make prediction
55
+ output = model(image_tensor)
56
+ prediction = output.argmax().item()
57
+
58
+ # Class mapping
59
+ class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
60
+
61
+ # Return prediction result
62
+ return class_labels[prediction], output.max().item()
63
  except Exception as e:
64
  return {'error': str(e)}
65
 
66
+ # Create the Gradio interface
 
 
 
 
 
 
 
67
  iface = gr.Interface(
68
  fn=predict,
69
+ inputs=gr.Image(type="pil", label="Upload an image"),
70
+ outputs=[gr.Label(num_top_classes=3, label="Predicted Class"), gr.Number(label="Confidence")],
71
+ examples=[
72
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A",
73
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&image-viewer=AEF1AA2978EEB77362DA9CCC8792473666F7CDC6",
74
+ "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?image-viewer=C98E6CFAB26ECC3808C63185F6CCE90DE4E7C442"
75
+ ]
76
  )
77
 
78
  # Launch the app
79
+ if __name__ == "__main__":
80
+ iface.launch()