# app.py -- header captured from the Hugging Face file page:
# author "yolac", commit 21e99b3 ("Update app.py", verified), file size 3.32 kB.
import torch
import torch.nn as nn
import gradio as gr
from torchvision import transforms
from PIL import Image
# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small convolutional network for bacterial morphology classification.

    The network is two ``Conv2d -> ReLU -> MaxPool2d`` stages followed by a
    two-layer fully connected head. Layer order and attribute names
    (``feature_extractor``, ``fc``) are part of the checkpoint format: the
    keys of ``state_dict()`` depend on them, so they must not be reordered.

    Args:
        num_classes: Number of output logits. Defaults to 3.
        input_size: Height and width (in pixels) of the square input images.
            Defaults to 224, which yields ``64 * 56 * 56`` flattened features.

    Raises:
        ValueError: If ``num_classes`` is less than 1 or ``input_size`` is
            too small to survive the two 2x2 pooling stages.
    """

    def __init__(self, num_classes: int = 3, input_size: int = 224) -> None:
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # The padded 3x3 convolutions keep the spatial size; each of the two
        # 2x2/stride-2 pools halves it (floor division), hence size // 4.
        pooled_size = input_size // 4
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * pooled_size * pooled_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-softmax) class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.feature_extractor(x)
        x = self.fc(x)
        return x
# Load the model
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
model = BacterialMorphologyClassifier()
try:
    # Download the checkpoint and map every tensor onto the CPU.
    state_dict = torch.hub.load_state_dict_from_url(
        MODEL_PATH, map_location=torch.device('cpu')
    )
    # strict=False tolerates key mismatches, which would otherwise leave
    # layers at their random initial values without any notice. Surface the
    # mismatches so a bad checkpoint is visible in the logs.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Bare raise keeps the original traceback; the app cannot run unloaded.
    raise
# Inference mode: disables Dropout so predictions are deterministic.
model.eval()
# Preprocessing applied to every input image before inference: resize to
# 224x224, convert to a tensor, then normalize each channel (the mean/std
# values match the widely used ImageNet channel statistics).
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)

# Maps the index of a model output logit to its human-readable class name.
class_labels = dict(enumerate(('cocci', 'bacilli', 'spirilla')))
# Prediction function
def predict(image):
    """Classify one image and return per-class confidences.

    Args:
        image: A PIL image (the Gradio input is configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping every class label to its softmax probability, so the
        label output can rank the classes itself. On failure, a dict of the
        form ``{'error': <message>}`` is returned instead of raising.
    """
    if image is None:
        return {'error': 'No image provided.'}
    try:
        # The model expects exactly 3 channels; normalize grayscale/RGBA/
        # palette inputs to RGB before the tensor conversion.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        # Perform inference without tracking gradients.
        with torch.no_grad():
            output = model(image_tensor)
        # Single-image batch: take row 0 of the (1, num_classes) softmax.
        probabilities = torch.nn.functional.softmax(output, dim=1)[0].tolist()
        return {
            class_labels[index]: probability
            for index, probability in enumerate(probabilities)
        }
    except Exception as e:
        # Best-effort boundary: report the failure to the UI rather than crash.
        return {'error': str(e)}
# Example input images (provide paths or URLs) shown in the Gradio UI.
# NOTE(review): every URL below contains "/viewer", i.e. it looks like a
# dataset-viewer web page rather than a direct image file -- confirm that
# Gradio can load these as images, or replace them with direct file links.
example_images = [
"https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer?row=0&image-viewer=52B421CB70A43313B278D5DD2C58CECE56343012", # TODO: replace with the actual paths to your example images
"https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=558EA847F2267CECF4E2CFF6352F9D8888E9A72F",
"https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A"
]
# Set up Gradio interface with examples.
# Components come from the top-level gradio namespace: the legacy
# gr.inputs / gr.outputs namespaces are deprecated and absent from newer
# Gradio releases.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Label(num_top_classes=3, label="Predicted class and confidence"),
    title="Bacterial Morphology Classifier",
    description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'.",
    examples=example_images,  # Provide the example image paths
)
# Launch the app: listen on all network interfaces on port 5000 and request
# a public share link.
iface.launch(server_name="0.0.0.0", server_port=5000, share=True)