File size: 3,320 Bytes
d9f1830
9047e05
1b9be77
2551488
d3b5926
1b76d00
d9d91ec
2551488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b76d00
c8491fd
1b9be77
2551488
c8491fd
1b9be77
 
 
c8491fd
 
1b9be77
 
 
2551488
1b9be77
d3b5926
9047e05
1b76d00
2551488
1b76d00
 
1b9be77
 
 
c8491fd
 
2551488
1b9be77
d9d91ec
1b9be77
 
 
 
 
 
 
 
2551488
1b9be77
c8491fd
21e99b3
 
 
 
 
 
 
 
1b9be77
c8491fd
1b9be77
 
 
21e99b3
 
c8491fd
2551488
c8491fd
1b9be77
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import gradio as gr
from torchvision import transforms
from PIL import Image

# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small two-stage CNN that classifies bacterial cell morphology.

    Two conv -> ReLU -> 2x2 max-pool stages (3 -> 32 -> 64 channels) feed a
    fully connected head with dropout that emits one logit per class.

    Args:
        num_classes: Number of output logits. Defaults to 3, matching the
            three labels used by this app.
        input_size: Side length of the square input images. The 3x3 convs use
            padding=1 so they keep the spatial size, and each stride-2 pool
            floor-halves it, so the flattened feature vector has
            ``64 * (input_size // 4) ** 2`` elements. Defaults to 224, which
            reproduces the original ``64 * 56 * 56`` head.
    """

    def __init__(self, num_classes=3, input_size=224):
        super().__init__()
        # Layer order and nesting are kept identical so state_dict keys
        # (feature_extractor.0/.3, fc.1/.4) still match checkpoints saved
        # from the original definition.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Two stride-2 pools, each floor-dividing the spatial size by 2.
        feature_side = (input_size // 2) // 2
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (un-softmaxed) class logits, shape ``(batch, num_classes)``.

        ``x`` must be a ``(batch, 3, input_size, input_size)`` float tensor.
        """
        x = self.feature_extractor(x)
        return self.fc(x)

# Load the model
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
model = BacterialMorphologyClassifier()
try:
    # Download the checkpoint via torch.hub and map all tensors onto the CPU.
    # NOTE(review): this unpickles a file fetched from a remote URL; consider
    # passing weights_only=True if the installed torch version supports it.
    state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches instead of raising, so the result
    # is inspected below rather than silently ignored.
    incompatible = model.load_state_dict(state_dict, strict=False)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
if incompatible.missing_keys or incompatible.unexpected_keys:
    # Any missing key means that layer keeps its random initialisation.
    print(
        "Warning: checkpoint does not match the model architecture; "
        f"missing keys: {incompatible.missing_keys}, "
        f"unexpected keys: {incompatible.unexpected_keys}"
    )
else:
    print("Model loaded successfully.")
# Inference mode: disables dropout.
model.eval()

# Image preprocessing pipeline: resize to the 224x224 input the classifier's
# head is sized for, convert the PIL image to a float tensor, then normalise
# each channel with the ImageNet mean/std statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)

# Class labels: model output index -> morphology name.
class_labels = dict(enumerate(("cocci", "bacilli", "spirilla")))

# Prediction function
def predict(image):
    """Classify an uploaded bacterial-morphology image.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping every class label to its softmax probability, so a
        ``Label(num_top_classes=3)`` output can rank all classes. On failure,
        a one-entry dict ``{'error': <message>}``.
    """
    if image is None:
        return {'error': 'No image provided.'}
    try:
        # The first conv layer and Normalize expect exactly 3 channels;
        # uploads may be RGBA, grayscale or palette images, so force RGB.
        image_tensor = transform(image.convert('RGB')).unsqueeze(0)

        # Perform inference without building an autograd graph.
        with torch.no_grad():
            output = model(image_tensor)
            # Single-image batch: take row 0 of the (1, num_classes) result.
            probabilities = torch.nn.functional.softmax(output, dim=1)[0]

        return {
            class_labels[index]: probability
            for index, probability in enumerate(probabilities.tolist())
        }
    except Exception as e:
        return {'error': str(e)}

# Example inputs shown beneath the Gradio UI (paths or URLs).
# NOTE(review): these are Hugging Face dataset *viewer* page URLs, not direct
# image-file URLs, and the second and third both point at row=201; Gradio
# presumably cannot load them as images. Replace with real image paths/URLs.
_DATASET_VIEWER_URL = (
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer"
)
example_images = [
    _DATASET_VIEWER_URL + query
    for query in (
        "?row=0&image-viewer=52B421CB70A43313B278D5DD2C58CECE56343012",
        "/default/train?p=2&row=201&image-viewer=558EA847F2267CECF4E2CFF6352F9D8888E9A72F",
        "/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A",
    )
]

# Set up Gradio interface with examples.
# The legacy gr.inputs.* / gr.outputs.* namespaces were deprecated in Gradio 3
# and removed in Gradio 4; the top-level gr.Image / gr.Label components are
# their replacements.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Label(num_top_classes=3, label="Predicted class and confidence"),
    title="Bacterial Morphology Classifier",
    description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'.",
    examples=example_images,  # one entry per example for the single image input
)

# Launch the app only when this file is executed as a script, so importing it
# (e.g. to reuse `predict` or `model`) does not start a server.
# server_name="0.0.0.0" listens on all interfaces; share=True also requests a
# public Gradio share link.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=5000, share=True)