yolac commited on
Commit
a2830d3
·
verified ·
1 Parent(s): b884821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -55
app.py CHANGED
@@ -1,78 +1,92 @@
1
  import torch
2
  import torch.nn as nn
3
- import gradio as gr
4
  from torchvision import transforms
5
  from PIL import Image
6
  import requests
7
- from io import BytesIO
 
8
 
9
- # Define the PyTorch model architecture
10
- class MyPyTorchModel(nn.Module):
11
  def __init__(self):
12
- super(MyPyTorchModel, self).__init__()
13
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
14
- self.relu = nn.ReLU()
15
- self.flatten = nn.Flatten()
16
- self.fc1 = nn.Linear(32 * 224 * 224, 3) # Adjust output size for 3 classes
17
-
 
 
 
 
 
 
 
 
 
 
 
 
18
  def forward(self, x):
19
- x = self.conv1(x)
20
- x = self.relu(x)
21
- x = self.flatten(x)
22
- x = self.fc1(x)
23
  return x
24
 
25
  # Load the model
26
- model_path = 'https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth'
27
- model = MyPyTorchModel()
28
- model.load_state_dict(torch.load(model_path))
29
- model.eval()
30
 
31
- # Define image transformations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  transform = transforms.Compose([
33
  transforms.Resize((224, 224)),
34
  transforms.ToTensor(),
 
35
  ])
36
 
37
- # Define a function to predict the class of the image
38
  def predict(image):
39
- image = Image.fromarray(image).convert('RGB')
40
- image = transform(image).unsqueeze(0)
41
-
42
- with torch.no_grad():
43
- outputs = model(image)
44
- _, predicted = torch.max(outputs, 1)
45
-
46
- class_labels = ['cocci', 'bacilli', 'spirilla']
47
- predicted_label = class_labels[predicted.item()]
48
-
49
- return predicted_label
50
-
51
- # URLs for 3 example images
52
- example_image_urls = [
53
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg", # Replace with the actual URL
54
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg", # Replace with the actual URL
55
- "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg", # Replace with the actual URL
56
- ]
57
-
58
- # Function to download and open example images
59
- def load_example_image(url):
60
- response = requests.get(url)
61
- image = Image.open(BytesIO(response.content))
62
- return image
63
-
64
- # Load the example images
65
- example_images = [load_example_image(url) for url in example_image_urls]
66
 
67
- # Create a Gradio interface
68
- iface = gr.Interface(
69
  fn=predict,
70
- inputs=gr.inputs.Image(shape=(224, 224), label="Upload an image or use examples"),
71
- outputs="text",
72
- title="Bacterial Morphology Classifier",
73
- description="Classify images of bacteria into cocci, bacilli, or spirilla.",
74
- examples=example_images # Add example images to the Gradio app
75
  )
76
 
77
  # Launch the app
78
- iface.launch()
 
 
1
  import torch
2
  import torch.nn as nn
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  import requests
6
+ import gradio as gr
7
+ import os
8
 
9
+ # Define the model architecture
10
+ class BacterialMorphologyClassifier(nn.Module):
11
  def __init__(self):
12
+ super(BacterialMorphologyClassifier, self).__init__()
13
+ self.feature_extractor = nn.Sequential(
14
+ nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
15
+ nn.ReLU(),
16
+ nn.MaxPool2d(kernel_size=2, stride=2),
17
+ nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(kernel_size=2, stride=2),
20
+ )
21
+ self.fc = nn.Sequential(
22
+ nn.Flatten(),
23
+ nn.Linear(64 * 56 * 56, 128),
24
+ nn.ReLU(),
25
+ nn.Dropout(0.5),
26
+ nn.Linear(128, 3),
27
+ nn.Softmax(dim=1),
28
+ )
29
+
30
  def forward(self, x):
31
+ x = self.feature_extractor(x)
32
+ x = self.fc(x)
 
 
33
  return x
34
 
35
  # Load the model
36
+ MODEL_PATH = "model.pth"
37
+ model = BacterialMorphologyClassifier()
 
 
38
 
39
+ try:
40
+ # Download the model if it doesn't exist
41
+ if not os.path.exists(MODEL_PATH):
42
+ print("Downloading the model...")
43
+ url = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
44
+ response = requests.get(url)
45
+ with open(MODEL_PATH, "wb") as f:
46
+ f.write(response.content)
47
+ # Load the model weights
48
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')), strict=False)
49
+ model.eval()
50
+ print("Model loaded successfully.")
51
+ except Exception as e:
52
+ print(f"Error loading the model: {e}")
53
+
54
+ # Define image preprocessing
55
  transform = transforms.Compose([
56
  transforms.Resize((224, 224)),
57
  transforms.ToTensor(),
58
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
59
  ])
60
 
61
+ # Prediction function
62
  def predict(image):
63
+ try:
64
+ # Convert the image to a tensor
65
+ image_tensor = transform(image).unsqueeze(0)
66
+
67
+ # Perform prediction
68
+ output = model(image_tensor)
69
+ prediction = output.argmax().item()
70
+
71
+ # Class mapping
72
+ class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
+
74
+ # Return the predicted class and confidence
75
+ predicted_class = class_labels[prediction]
76
+ confidence = output.max().item()
77
+ return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
78
+ except Exception as e:
79
+ return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
80
 
81
+ # Set up Gradio interface
82
+ interface = gr.Interface(
83
  fn=predict,
84
+ inputs=gr.Image(type="pil"),
85
+ outputs=gr.Text(label="Prediction"),
86
+ title="Bacterial Morphology Classification",
87
+ description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
 
88
  )
89
 
90
  # Launch the app
91
+ if __name__ == "__main__":
92
+ interface.launch()