# app.py — Hugging Face Space by yolac
# Last commit: 9efa9b5 "Update app.py" (verified)
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
import gradio as gr
import os
# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small CNN that classifies a bacteria image as one of three morphologies.

    Two conv -> ReLU -> 2x2 max-pool stages (3 -> 32 -> 64 channels) feed a
    fully connected head. The head is sized for 64 feature maps of 56x56,
    i.e. 224x224 inputs halved twice by the pooling layers. The head ends in
    a Softmax over dim 1, so ``forward`` returns per-class probabilities of
    shape ``(batch, 3)`` rather than raw logits.

    The attribute names ``feature_extractor`` / ``fc`` and the position of
    every layer inside those ``nn.Sequential`` containers determine the
    ``state_dict`` keys (e.g. ``fc.1.weight``). Keep them unchanged, or
    previously saved weights will no longer load.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one conv -> ReLU -> 2x2 max-pool stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Halves height and width.
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            *self._conv_stage(3, 32),
            *self._conv_stage(32, 64),
        )
        # 64 channels at 56x56 (224 / 2 / 2) once flattened.
        flat_features = 64 * 56 * 56
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 3),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map an image batch ``x`` to class probabilities of shape (batch, 3)."""
        return self.fc(self.feature_extractor(x))
# Load the model
MODEL_PATH = "model.pth"
MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"


def _download_model(url, path, timeout=60):
    """Download ``url`` to ``path`` without ever leaving a bad file at ``path``.

    The body is streamed into ``<path>.part``. That file is renamed onto
    ``path`` only after the whole response has been written successfully.
    A failed or interrupted download therefore cannot leave a truncated
    file or an HTTP error page behind. The caller's existence check would
    otherwise treat such a file as a valid cached model on every later start.

    Args:
        url: address of the weights file.
        path: destination file path.
        timeout: seconds passed to ``requests`` as its connect/read timeout.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.RequestException: connection failure or timeout.
        OSError: the file could not be written or renamed.
    """
    tmp_path = f"{path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an error page (e.g. a 404 HTML body)
            # would be saved as if it were the model file.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, path)
    finally:
        # Only still present if something above failed; drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


model = BacterialMorphologyClassifier()
try:
    # Download the model if it doesn't exist
    if not os.path.exists(MODEL_PATH):
        print("Downloading the model...")
        _download_model(MODEL_URL, MODEL_PATH)
    # The file comes from the network, and torch.load unpickles it.
    # weights_only=True (torch >= 1.13) limits unpickling to tensors and
    # plain containers, so a tampered file cannot run code at load time.
    state_dict = torch.load(
        MODEL_PATH, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except Exception as e:
    # Best effort: the app still starts and reports the problem. Note that
    # predictions from the randomly initialised weights are meaningless.
    print(f"Error loading the model: {e}")
# Always switch to inference mode so Dropout is disabled and outputs are
# deterministic, even if loading failed above.
model.eval()
# Define image preprocessing to match training preprocessing.
#
# The resulting tensor is in [0, 255], NOT [0, 1]:
#   * ToTensor() maps an 8-bit PIL image to a float tensor in [0, 1].
#   * Normalize computes (x - mean) / std. With mean=0 and std=1/255 that
#     is x * 255, which scales the values back up to [0, 255].
# NOTE(review): presumably the model was trained with this same pipeline.
# Confirm that before "fixing" the scaling, since changing it without
# retraining would shift every input the network sees.
transform = transforms.Compose([
    # The fc head is sized for 224x224 inputs (64 * 56 * 56 after two 2x pools).
    transforms.Resize((224, 224)),
    # PIL (H, W, C) uint8 -> float tensor (C, H, W) in [0, 1].
    transforms.ToTensor(),
    # Multiplies by 255 (see note above): [0, 1] -> [0, 255].
    transforms.Normalize(mean=[0, 0, 0], std=[1/255, 1/255, 1/255]),
])
# Prediction function
# Labels for the model's output units, in output-index order.
CLASS_LABELS = ("cocci", "bacilli", "spirilla")


def predict(image):
    """Classify a bacteria image and return a human-readable result string.

    Args:
        image: the ``PIL.Image.Image`` supplied by the Gradio ``Image``
            input (``type="pil"``), or ``None`` when the user submits
            without uploading anything.

    Returns:
        On success, two lines: ``Predicted Class: <label>`` and
        ``Confidence: <p>``. Here ``<p>`` is the winning class's softmax
        probability with two decimals. On failure, an ``"Error: ..."``
        string: errors are shown in the UI rather than raised.
    """
    if image is None:
        return "Error: no image provided."
    try:
        # Uploads may be RGBA (PNG with alpha), grayscale or palette
        # images. The first conv layer requires exactly 3 input channels.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            probabilities = model(image_tensor)[0]
        # The model ends in Softmax, so the maximum entry is a probability.
        # On ties, torch returns the first maximal index (same as argmax).
        confidence, index = probabilities.max(dim=0)
        predicted_class = CLASS_LABELS[index.item()]
        return f"Predicted Class: {predicted_class}\nConfidence: {confidence.item():.2f}"
    except Exception as e:
        return f"Error: {str(e)}"
# Example images shown under the Gradio input. Each entry is a one-element
# list because the interface has a single input component.
_EXAMPLES_BASE_URL = "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/"
examples = [
    [_EXAMPLES_BASE_URL + file_name]
    for file_name in ("img%20290.jpg", "img%20565.jpg", "img%208.jpg")
]
# Set up the Gradio interface: one PIL image in, one text result out.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Text(label="Prediction")
interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_output,
    title="Bacterial Morphology Classification",
    description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
    examples=examples,
)
def main():
    """Start the Gradio web server for the classifier demo."""
    interface.launch()


# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    main()