|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
from PIL import Image |
|
import requests |
|
import gradio as gr |
|
import os |
|
|
|
|
|
class BacterialMorphologyClassifier(nn.Module):
    """Small two-stage CNN mapping a 3-channel image batch to 3 class scores.

    The network ends in ``nn.Softmax(dim=1)``, so each output row is a
    probability distribution over the three classes.

    The attribute names (``feature_extractor``, ``fc``) and the position of
    every layer inside each ``nn.Sequential`` determine the ``state_dict``
    keys (e.g. ``feature_extractor.3.weight``, ``fc.4.bias``); keep them
    stable or previously saved checkpoints will no longer load.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one stage: 3x3 same-padded conv, ReLU, 2x2 max-pool."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    def __init__(self):
        super().__init__()
        # Indices 0-2 and 3-5: each stage's max-pool halves height and width.
        self.feature_extractor = nn.Sequential(
            *self._conv_stage(3, 32),
            *self._conv_stage(32, 64),
        )
        # 64 channels * 56 * 56 corresponds to a 224x224 input (224 / 2 / 2 = 56);
        # other input sizes will not match this Linear layer.
        flattened_features = 64 * 56 * 56
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 3),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Run the conv feature extractor, then the classifier head."""
        return self.fc(self.feature_extractor(x))
|
|
|
|
|
MODEL_PATH = "model.pth"
MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"

model = BacterialMorphologyClassifier()


def _download_model(url, path, timeout=60):
    """Download the checkpoint at *url* to *path*.

    The response body is streamed into ``<path>.part`` and only renamed to
    *path* once it has been fully written. This means an interrupted or failed
    download never leaves a truncated file at *path*, which the
    ``os.path.exists`` check below would otherwise accept as a cached
    checkpoint on every later run.

    Raises:
        requests.RequestException: on connection errors, timeouts, or an HTTP
            error status (raised by ``raise_for_status``).
    """
    tmp_path = f"{path}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as f:
            # Stream in 1 MiB chunks instead of buffering the whole file in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, path)


# Best-effort load: on any failure the error is printed and execution
# continues, leaving the model with its random initial weights.
try:
    if not os.path.exists(MODEL_PATH):
        print("Downloading the model...")
        _download_model(MODEL_URL, MODEL_PATH)
    # The file comes from the network; weights_only=True restricts unpickling to
    # tensors and plain containers (enough for a state_dict) instead of
    # arbitrary objects. Requires torch >= 1.13.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading the model: {e}")

# Always switch to inference mode (disables Dropout), even if loading failed,
# so outputs are at least deterministic.
model.eval()
|
|
|
|
|
# Preprocessing applied to every input image before it reaches the model:
#   1. Resize to 224x224 -- the classifier's first Linear layer (64 * 56 * 56
#      inputs) only fits this spatial size.
#   2. ToTensor -- PIL image -> float CHW tensor with values in [0, 1].
#   3. Normalize with mean 0 and std 1/255, i.e. (x - 0) / (1/255) = x * 255,
#      which rescales the values back to [0, 255].
# NOTE(review): the [0, 255] input range presumably mirrors the training
# pipeline -- confirm before switching to conventional mean/std normalization.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0, 0, 0], std=[1 / 255] * 3),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
# Output index -> label, in the same order the original mapping used
# (0: cocci, 1: bacilli, 2: spirilla).
CLASS_LABELS = ("cocci", "bacilli", "spirilla")


def predict(image):
    """Classify one bacterial image and return a human-readable summary.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input. ``None``
            (nothing uploaded) is handled explicitly.

    Returns:
        On success, ``"Predicted Class: <label>"`` followed by a newline and
        ``"Confidence: <p>"``, where ``p`` is the highest model output (a softmax
        probability for BacterialMorphologyClassifier), formatted to two
        decimals. On failure, a string starting with ``"Error: "``. Errors are
        returned rather than raised so they show up in the UI output box.
    """
    if image is None:
        return "Error: no image provided."
    try:
        # The first Conv2d expects exactly 3 channels; RGBA, grayscale or
        # palette-mode uploads would otherwise fail there.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
        with torch.no_grad():
            probabilities = model(image_tensor)[0]
        # One reduction gives both the winning score and its index.
        confidence, class_index = probabilities.max(dim=0)
        predicted_class = CLASS_LABELS[class_index.item()]
        return f"Predicted Class: {predicted_class}\nConfidence: {confidence.item():.2f}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
# Sample images shown under the Gradio input; each example is a
# one-element list because the interface has a single input component.
_EXAMPLE_IMAGE_URLS = (
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg",
)
examples = [[image_url] for image_url in _EXAMPLE_IMAGE_URLS]
|
|
|
|
|
# UI wiring: a PIL image goes in, the text summary produced by `predict`
# comes out. Components are created in the same order as inline kwargs would be.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Text(label="Prediction")

interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_output,
    title="Bacterial Morphology Classification",
    description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
    examples=examples,
)
|
|
|
|
|
def main():
    """Launch the Gradio app defined by ``interface``."""
    interface.launch()


if __name__ == "__main__":
    main()
|
|