File size: 3,439 Bytes
1422569
9efa9b5
1422569
d1439d9
9efa9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdcefa1
5409e53
9efa9b5
 
d1439d9
9efa9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
5409e53
9efa9b5
 
 
 
 
 
5409e53
 
9efa9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5409e53
 
 
9efa9b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
import gradio as gr
import os

# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Two-stage CNN mapping an RGB image batch to probabilities over 3 classes.

    Attribute names (``feature_extractor``, ``fc``) and the position of every
    layer inside each ``nn.Sequential`` define the ``state_dict`` keys, so the
    layout must not change or pretrained weights will no longer load.

    The head's ``64 * 56 * 56`` input width implies 224 x 224 inputs: the 3x3
    convolutions use padding 1 (size-preserving) and the two stride-2 max-pools
    reduce 224 -> 112 -> 56.
    """

    def __init__(self):
        super().__init__()

        # Indices 0-5: conv(3->32), ReLU, pool, conv(32->64), ReLU, pool.
        conv_layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.feature_extractor = nn.Sequential(*conv_layers)

        # Indices 0-5: flatten, linear, ReLU, dropout, linear(->3 classes), softmax.
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 56 * 56, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 3),
            nn.Softmax(dim=1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for the batch ``x``."""
        return self.fc(self.feature_extractor(x))

# Load the model
MODEL_PATH = "model.pth"
MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"


def _download_model(url, path, timeout=60, chunk_size=1 << 20):
    """Stream ``url`` into ``path``; ``path`` only appears if the download succeeds.

    Raises ``requests.HTTPError`` on a non-2xx response, so an HTML error page is
    never saved as the model file. The body is written to ``<path>.part`` and then
    atomically renamed. The partial file is removed on any failure, so a broken
    download is not mistaken for a cached model on the next start.
    """
    tmp_path = path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream in chunks instead of holding the whole file in memory.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, path)
    finally:
        # After a successful os.replace the temp file no longer exists.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


model = BacterialMorphologyClassifier()

try:
    # Download the model if it doesn't exist
    if not os.path.exists(MODEL_PATH):
        print("Downloading the model...")
        _download_model(MODEL_URL, MODEL_PATH)
    # NOTE(review): torch.load unpickles the file, and this file comes from a
    # remote repository. Consider weights_only=True (torch >= 1.13) so that
    # only tensors can be deserialized.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort: the app still starts, but `model` keeps its random initial
    # weights, so its predictions are meaningless until loading succeeds.
    print(f"Error loading the model: {e}")

# Inference only: disable dropout even if loading failed.
model.eval()

# Define image preprocessing to match training preprocessing
transform = transforms.Compose([
    # The classifier head's 64*56*56 input width requires 224x224 images.
    transforms.Resize((224, 224)),
    # PIL image -> float CHW tensor, already scaled to [0, 1].
    transforms.ToTensor(),
    # Normalize computes (x - mean) / std. With mean 0 and std 1/255 this
    # MULTIPLIES by 255, so the model receives values in [0, 255], not [0, 1].
    # NOTE(review): presumably this mirrors the training pipeline. Confirm
    # against the training script before "fixing" it to [0, 1]; changing the
    # input scale without retraining would break the pretrained weights.
    transforms.Normalize(mean=[0, 0, 0], std=[1/255, 1/255, 1/255]),
])

# Prediction function
def predict(image):
    """Classify a bacteria image and return a human-readable result string.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``), or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``"Predicted Class: <label>\\nConfidence: <p>"`` on success, where ``p``
        is the largest value of the model's softmax output. On failure it returns
        ``"Error: <message>"``; errors are returned rather than raised so they
        are shown in the UI.
    """
    if image is None:
        return "Error: no image provided."
    try:
        # The first conv layer expects 3 channels. Grayscale, palette or RGBA
        # images would otherwise become 1- or 4-channel tensors and crash.
        rgb_image = image.convert("RGB")
        image_tensor = transform(rgb_image).unsqueeze(0)  # add a batch dimension

        # Perform prediction
        with torch.no_grad():  # Ensure no gradients are calculated
            output = model(image_tensor)

        # Class mapping
        class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}

        # One-image batch: take the top probability and its class index from row 0.
        confidence, class_index = output[0].max(dim=0)
        predicted_class = class_labels[class_index.item()]
        return f"Predicted Class: {predicted_class}\nConfidence: {confidence.item():.2f}"
    except Exception as e:
        return f"Error: {str(e)}"

# Example images shown under the Gradio input, all hosted in the same dataset repo.
_EXAMPLE_BASE_URL = (
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/"
)
# File names are URL-encoded ("%20" is a space).
_EXAMPLE_FILES = ("img%20290.jpg", "img%20565.jpg", "img%208.jpg")

# One row per example; each row has one value because the interface has one input.
examples = [[_EXAMPLE_BASE_URL + file_name] for file_name in _EXAMPLE_FILES]

def _build_interface():
    """Create the Gradio UI: one PIL image input -> ``predict`` -> text output."""
    return gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),  # predict() receives a PIL image (or None)
        outputs=gr.Text(label="Prediction"),
        examples=examples,
        title="Bacterial Morphology Classification",
        description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
    )


interface = _build_interface()

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()