yolac commited on
Commit
9efa9b5
·
verified ·
1 Parent(s): 5409e53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -44
app.py CHANGED
@@ -1,56 +1,100 @@
1
- import gradio as gr
2
  import torch
 
3
  from torchvision import transforms
4
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Load the model
7
- def load_model():
8
- model_path = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
9
- model = torch.load(model_path, map_location=torch.device('cpu'))
10
- model.eval() # Set model to evaluation mode
11
- return model
12
 
13
- model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Define class labels
16
- class_labels = {0: "Cocci", 1: "Bacilli", 2: "Spirilla"}
 
 
 
 
17
 
18
  # Prediction function
19
- def predict_image(image):
20
- transform = transforms.Compose([
21
- transforms.Resize((224, 224)),
22
- transforms.ToTensor(),
23
- transforms.Lambda(lambda x: x / 255.0) # Rescale pixel values to [0, 1]
24
- ])
25
- image_tensor = transform(image).unsqueeze(0) # Add batch dimension
26
-
27
- with torch.no_grad():
28
- outputs = model(image_tensor)
29
- probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
30
- predicted_class = probabilities.argmax().item()
31
- return class_labels[predicted_class], {class_labels[i]: float(prob) for i, prob in enumerate(probabilities)}
32
-
33
- # Example images (these should be in the same directory as your app.py)
34
- example_images = ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg", "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg", "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"]
35
-
36
- # Create a Gradio interface
37
- def create_gradio_interface():
38
- iface = gr.Interface(
39
- fn=predict_image,
40
- inputs=gr.inputs.Image(type="pil", label="Upload an image"),
41
- outputs=[
42
- gr.outputs.Label(num_top_classes=3, label="Predicted Class"),
43
- gr.outputs.JSON(label="Class Probabilities")
44
- ],
45
- examples=example_images,
46
- title="Bacterial Morphology Classification",
47
- description="This app classifies bacterial morphology into **Cocci**, **Bacilli**, or **Spirilla** using a fine-tuned PyTorch model.",
48
- )
49
- return iface
 
 
 
 
50
 
51
  # Launch the app
52
  if __name__ == "__main__":
53
- app = create_gradio_interface()
54
- app.launch()
55
-
56
-
 
 
1
  import torch
2
+ import torch.nn as nn
3
  from torchvision import transforms
4
  from PIL import Image
5
+ import requests
6
+ import gradio as gr
7
+ import os
8
+
9
+ # Define the model architecture
10
+ class BacterialMorphologyClassifier(nn.Module):
11
+ def __init__(self):
12
+ super(BacterialMorphologyClassifier, self).__init__()
13
+ self.feature_extractor = nn.Sequential(
14
+ nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
15
+ nn.ReLU(),
16
+ nn.MaxPool2d(kernel_size=2, stride=2),
17
+ nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(kernel_size=2, stride=2),
20
+ )
21
+ self.fc = nn.Sequential(
22
+ nn.Flatten(),
23
+ nn.Linear(64 * 56 * 56, 128),
24
+ nn.ReLU(),
25
+ nn.Dropout(0.5),
26
+ nn.Linear(128, 3),
27
+ nn.Softmax(dim=1),
28
+ )
29
+
30
+ def forward(self, x):
31
+ x = self.feature_extractor(x)
32
+ x = self.fc(x)
33
+ return x
34
 
35
  # Load the model
36
+ MODEL_PATH = "model.pth"
37
+ model = BacterialMorphologyClassifier()
 
 
 
38
 
39
+ try:
40
+ # Download the model if it doesn't exist
41
+ if not os.path.exists(MODEL_PATH):
42
+ print("Downloading the model...")
43
+ url = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
44
+ response = requests.get(url)
45
+ with open(MODEL_PATH, "wb") as f:
46
+ f.write(response.content)
47
+ # Load the model weights
48
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
49
+ model.eval()
50
+ print("Model loaded successfully.")
51
+ except Exception as e:
52
+ print(f"Error loading the model: {e}")
53
 
54
+ # Define image preprocessing to match training preprocessing
55
+ transform = transforms.Compose([
56
+ transforms.Resize((224, 224)), # Resize to match model input size
57
+ transforms.ToTensor(), # Convert to a tensor
58
+ transforms.Normalize(mean=[0, 0, 0], std=[1/255, 1/255, 1/255]), # Scale pixel values to [0, 1]
59
+ ])
60
 
61
  # Prediction function
62
+ def predict(image):
63
+ try:
64
+ # Convert the image to a tensor
65
+ image_tensor = transform(image).unsqueeze(0)
66
+
67
+ # Perform prediction
68
+ with torch.no_grad(): # Ensure no gradients are calculated
69
+ output = model(image_tensor)
70
+
71
+ # Class mapping
72
+ class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
73
+
74
+ # Return the predicted class and confidence
75
+ predicted_class = class_labels[output.argmax().item()]
76
+ confidence = output.max().item() # Softmax value as confidence
77
+ return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
78
+ except Exception as e:
79
+ return f"Error: {str(e)}"
80
+
81
+ # Define example images
82
+ examples = [
83
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
84
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
85
+ ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"],
86
+ ]
87
+
88
+ # Set up Gradio interface
89
+ interface = gr.Interface(
90
+ fn=predict,
91
+ inputs=gr.Image(type="pil"),
92
+ outputs=gr.Text(label="Prediction"),
93
+ title="Bacterial Morphology Classification",
94
+ description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
95
+ examples=examples,
96
+ )
97
 
98
  # Launch the app
99
  if __name__ == "__main__":
100
+ interface.launch()