|
import torch |
|
import torch.nn as nn |
|
import gradio as gr |
|
from torchvision import transforms |
|
from PIL import Image |
|
|
|
|
|
class BacterialMorphologyClassifier(nn.Module):
    """Small two-stage CNN that maps an RGB image batch to class logits.

    Layout: two (Conv3x3 -> ReLU -> MaxPool2x2) stages (3 -> 32 -> 64
    channels), then a head Flatten -> Linear(., 128) -> ReLU -> Dropout(0.5)
    -> Linear(128, num_classes). Sub-module names/indices are unchanged, so
    state_dict keys stay compatible with checkpoints for the original layout.

    Args:
        num_classes: Number of output logits. Defaults to 3, matching the
            original three-way head.
        image_size: Side length of the square input the head is sized for.
            The padded 3x3 convs keep spatial size and each 2x2 pool floors
            it by half, so the flattened size is ``64 * (image_size // 4)**2``.
            Defaults to 224 (-> 64 * 56 * 56), matching the original head.

    Raises:
        ValueError: If ``num_classes < 1`` or ``image_size < 4`` (the latter
            would leave no spatial positions after the two pools).
    """

    def __init__(self, num_classes: int = 3, image_size: int = 224):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if image_size < 4:
            raise ValueError(f"image_size must be >= 4, got {image_size}")

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # floor(floor(s / 2) / 2) == s // 4 for the two stride-2 pools above.
        pooled_side = image_size // 4
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * pooled_side * pooled_side, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-softmaxed) logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, 3, image_size, image_size)``; any
        other spatial size makes the first Linear layer's input size mismatch.
        """
        x = self.feature_extractor(x)
        x = self.fc(x)
        return x
|
|
|
|
|
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"

model = BacterialMorphologyClassifier()

try:
    # NOTE(review): torch.hub downloads the file and deserialises it with
    # torch.load (pickle). Newer torch releases accept weights_only=True here
    # to restrict loading to tensors; consider it once the torch version is pinned.
    state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))

    # strict=False tolerates extra entries in the checkpoint, but a *missing*
    # key means that layer would keep its random initialisation and the app
    # would serve meaningless predictions while reporting success. Check the
    # returned key lists explicitly instead of ignoring them.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        raise RuntimeError(
            "Checkpoint does not provide weights for: "
            + ", ".join(incompatible.missing_keys)
        )
    if incompatible.unexpected_keys:
        print(
            "Warning: ignoring unexpected checkpoint keys: "
            + ", ".join(incompatible.unexpected_keys)
        )

    print("Model loaded successfully.")

except Exception as e:
    print(f"Error loading model: {e}")
    # Bare raise re-raises the active exception with its original traceback.
    raise

# Inference mode: disables Dropout in the classifier head.
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference.
# The input size must agree with the size the classifier head was built for
# (the model's first Linear layer assumes 224x224 -> 56x56 after pooling).
_INPUT_SIZE = (224, 224)
# Per-channel normalisation statistics (the common ImageNet values).
# NOTE(review): presumably these match the training-time preprocessing — confirm.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
|
|
|
|
|
# Logit index -> human-readable morphology name.
# NOTE(review): order must match the label encoding used in training — confirm.
class_labels = dict(enumerate(['cocci', 'bacilli', 'spirilla']))
|
|
|
|
|
def predict(image):
    """Classify one uploaded image and return per-class confidences.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        ``{label: probability}`` for every entry in ``class_labels`` (softmax
        over the model's logits), which lets the ``Label`` output show all its
        top classes; or ``{'error': message}`` if no image was given or
        preprocessing/inference failed.
    """
    if image is None:
        return {'error': 'No image provided.'}

    try:
        # Uploads are frequently RGBA (PNG) or single-channel; the transform's
        # 3-channel Normalize and the model's first conv need exactly RGB.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            logits = model(image_tensor)

        # Row 0 is the single image in this batch of one.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
        return {class_labels[idx]: prob for idx, prob in enumerate(probabilities)}

    except Exception as e:
        # UI boundary: report the failure in the output instead of crashing.
        return {'error': str(e)}
|
|
|
|
|
# Example inputs offered in the Gradio UI (passed as `examples=` to gr.Interface).
# NOTE(review): these URLs point at the Hugging Face dataset *viewer* pages
# (".../viewer?...&image-viewer=<hash>"), which presumably serve HTML rather
# than image bytes. An Image input generally needs image file paths or direct
# file URLs (e.g. ".../resolve/main/<file>"). Confirm these load, or replace them.
example_images = [
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer?row=0&image-viewer=52B421CB70A43313B278D5DD2C58CECE56343012",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=558EA847F2267CECF4E2CFF6352F9D8888E9A72F",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A"
]
|
|
|
|
|
# Gradio 3.x exposes components at the top level (gr.Image, gr.Label) and
# deprecates the gr.inputs / gr.outputs namespaces; Gradio 4 removes them, so
# the old calls fail with AttributeError there. Use the top-level classes and
# fall back to the legacy namespaces only on releases that lack them.
if hasattr(gr, "Image") and hasattr(gr, "Label"):
    image_input = gr.Image(type="pil", label="Upload an image")
    label_output = gr.Label(num_top_classes=3, label="Predicted class and confidence")
else:
    image_input = gr.inputs.Image(type="pil", label="Upload an image")
    label_output = gr.outputs.Label(num_top_classes=3, label="Predicted class and confidence")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Bacterial Morphology Classifier",
    description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'.",
    examples=example_images,
)

# NOTE(review): server_name="0.0.0.0" listens on all network interfaces, and
# share=True additionally requests a public share link; confirm both are intended.
iface.launch(server_name="0.0.0.0", server_port=5000, share=True)
|
|