File size: 3,320 Bytes
d9f1830 9047e05 1b9be77 2551488 d3b5926 1b76d00 d9d91ec 2551488 1b76d00 c8491fd 1b9be77 2551488 c8491fd 1b9be77 c8491fd 1b9be77 2551488 1b9be77 d3b5926 9047e05 1b76d00 2551488 1b76d00 1b9be77 c8491fd 2551488 1b9be77 d9d91ec 1b9be77 2551488 1b9be77 c8491fd 21e99b3 1b9be77 c8491fd 1b9be77 21e99b3 c8491fd 2551488 c8491fd 1b9be77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
import torch.nn as nn
import gradio as gr
from torchvision import transforms
from PIL import Image
# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Compact CNN that maps a batch of RGB images to 3 morphology logits.

    Two conv -> ReLU -> 2x2 max-pool stages extract features. A dropout-regularized
    MLP head turns them into unnormalized class scores; no softmax is applied here.

    The attribute names and the layer order inside both ``nn.Sequential``
    containers define the ``state_dict`` keys (e.g. ``feature_extractor.0.weight``,
    ``fc.1.weight``). Keep them stable or saved checkpoints stop matching.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # A 3x3 conv with padding=1 keeps H and W; the 2x2 pool then halves them.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]

        self.feature_extractor = nn.Sequential(*conv_stage(3, 32), *conv_stage(32, 64))

        # Two halvings take a 224x224 input (the size the preprocessing Resize
        # produces) down to 64 channels of 56x56, which the first Linear expects.
        flattened_size = 64 * 56 * 56
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 3); assumes x is (batch, 3, 224, 224)."""
        return self.fc(self.feature_extractor(x))
# Load the pretrained weights from the Hugging Face Hub and switch to inference mode.
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
model = BacterialMorphologyClassifier()
try:
    # Download and deserialize the state_dict onto the CPU.
    state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches instead of raising; the returned
    # result lists them so they can be reported below rather than ignored.
    load_result = model.load_state_dict(state_dict, strict=False)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
if load_result.missing_keys or load_result.unexpected_keys:
    # Any parameter listed as missing keeps its random initialization, so
    # predictions would be meaningless; make that visible in the logs.
    print(
        "Warning: checkpoint does not match the model architecture "
        f"(missing keys: {load_result.missing_keys}, "
        f"unexpected keys: {load_result.unexpected_keys}); "
        "missing layers keep their random initialization."
    )
else:
    print("Model loaded successfully.")
# Disable dropout for inference.
model.eval()
# Preprocessing pipeline applied to each input image:
#   1. resize to 224x224, the spatial size the classifier's first Linear layer is sized for;
#   2. convert the PIL image to a float tensor;
#   3. normalize per channel with the common ImageNet mean/std values
#      (presumably the same normalization used at training time; confirm).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Map from the model's output logit index to the morphology name.
class_labels = dict(enumerate(("cocci", "bacilli", "spirilla")))
# Prediction function
def predict(image):
    """Classify the bacterial morphology shown in an uploaded image.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        Dict mapping every label in ``class_labels`` to its softmax
        probability. This is the confidence-map format ``gr.Label`` ranks and
        displays.

    Raises:
        gr.Error: if no image was provided or preprocessing/inference fails.
            Gradio shows the message to the user instead of a bogus label.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    try:
        # The network's first conv expects exactly 3 channels; uploads may be
        # RGBA, grayscale or palette images, which would break Normalize.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = model(image_tensor)
        # Batch of one: take the single row of class probabilities.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    except Exception as e:
        raise gr.Error(f"Prediction failed: {e}") from e
    return {class_labels[idx]: prob for idx, prob in enumerate(probabilities)}
# Example inputs shown under the interface. Each entry must be a local file
# path or a direct URL to an image file.
# NOTE(review): these URLs point at Hugging Face dataset *viewer* web pages
# (HTML), not image files. They are skipped by the filter below until they are
# replaced with the actual paths to your example images.
example_images = [
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer?row=0&image-viewer=52B421CB70A43313B278D5DD2C58CECE56343012",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=558EA847F2267CECF4E2CFF6352F9D8888E9A72F",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A",
]

_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp")


def _looks_like_image_file(path):
    """Return True if *path* (local path or URL) ends in a known image file extension."""
    # Drop any URL fragment and query string before checking the extension.
    bare = path.split("#", 1)[0].split("?", 1)[0]
    return bare.lower().endswith(_IMAGE_SUFFIXES)


# Set up the Gradio interface. gr.Image / gr.Label replace the deprecated
# gr.inputs / gr.outputs namespaces (removed in Gradio 4).
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Label(num_top_classes=3, label="Predicted class and confidence"),
    title="Bacterial Morphology Classifier",
    description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'.",
    # Only pass examples Gradio can actually load as images; None hides the section.
    examples=[p for p in example_images if _looks_like_image_file(p)] or None,
)

# Launch the app only when run as a script, not when this module is imported.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=5000, share=True)
|