import contextlib
import os
import tempfile

import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small two-block CNN classifying a 224x224 RGB image into 3 classes.

    The layer layout (and therefore ``state_dict`` key names such as
    ``feature_extractor.0.weight`` and ``fc.1.weight``) must stay exactly as
    is, otherwise the pretrained checkpoint will no longer load.

    Because ``fc`` ends in ``nn.Softmax(dim=1)``, ``forward`` returns
    per-class probabilities, not logits.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            # Two 2x2 max-pools shrink 224x224 to 56x56; this layer therefore
            # fixes the accepted input resolution at 224x224.
            nn.Linear(64 * 56 * 56, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 3),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map a (batch, 3, 224, 224) tensor to (batch, 3) class probabilities."""
        x = self.feature_extractor(x)
        return self.fc(x)


# Load the model
MODEL_PATH = "model.pth"
MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
DOWNLOAD_TIMEOUT_S = 60

# Output index -> label; the order must match the model's final Linear layer.
CLASS_LABELS = {0: "cocci", 1: "bacilli", 2: "spirilla"}


def _download_model(url, path, timeout=DOWNLOAD_TIMEOUT_S):
    """Download ``url`` to ``path`` atomically.

    The body is streamed into a temporary file in the target directory and
    renamed into place only once complete. A failed, partial or non-2xx
    download therefore never leaves a corrupt ``path`` behind that later runs
    would trust because the file "exists".

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on network errors or timeout.
    """
    directory = os.path.dirname(os.path.abspath(path))
    with requests.get(url, timeout=timeout, stream=True) as response:
        response.raise_for_status()
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            os.replace(tmp_path, path)
        except BaseException:
            # Clean up the partial file whatever went wrong, then propagate.
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise


def _load_weights(net, path):
    """Load a CPU-mapped state_dict from ``path`` into ``net``."""
    cpu = torch.device("cpu")
    try:
        # The checkpoint comes from the network: restrict unpickling to
        # tensors/plain containers instead of arbitrary Python objects.
        state_dict = torch.load(path, map_location=cpu, weights_only=True)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` parameter.
        state_dict = torch.load(path, map_location=cpu)
    net.load_state_dict(state_dict)


model = BacterialMorphologyClassifier()
# Holds the exception if loading failed, so predict() can report it instead of
# silently serving randomly initialised weights.
_model_load_error = None
try:
    # Download the model if it doesn't exist
    if not os.path.exists(MODEL_PATH):
        print("Downloading the model...")
        _download_model(MODEL_URL, MODEL_PATH)
    _load_weights(model, MODEL_PATH)
    print("Model loaded successfully.")
except Exception as e:
    _model_load_error = e
    print(f"Error loading the model: {e}")
# Always switch to inference mode; otherwise Dropout stays active and
# predictions become nondeterministic.
model.eval()

# Image preprocessing, intended to match the preprocessing used in training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # the fc layer hard-codes a 224x224 input
    transforms.ToTensor(),  # PIL uint8 -> float tensor (C, H, W) in [0, 1]
    # NOTE(review): (x - 0) / (1/255) multiplies by 255, so the model sees
    # values in [0, 255], not [0, 1]. Kept unchanged because the checkpoint may
    # have been trained on exactly this input — confirm against the training
    # pipeline before altering it.
    transforms.Normalize(mean=[0, 0, 0], std=[1 / 255, 1 / 255, 1 / 255]),
])


# Prediction function
def predict(image):
    """Classify a bacteria image and return a human-readable result string.

    Args:
        image: PIL image from ``gr.Image(type="pil")``, or ``None`` when the
            user submits without uploading anything.

    Returns:
        ``"Predicted Class: <label>\\nConfidence: <p>"`` on success, otherwise
        a string starting with ``"Error: "`` (shown in the UI, not raised).
    """
    if _model_load_error is not None:
        return f"Error: model failed to load ({_model_load_error})"
    if image is None:
        return "Error: no image provided"
    try:
        # The first Conv2d expects exactly 3 channels; uploads may be RGBA,
        # grayscale or palette images.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            probabilities = model(image_tensor)[0]
        class_index = int(probabilities.argmax())
        predicted_class = CLASS_LABELS[class_index]
        # The model ends in Softmax, so this entry is the class probability.
        confidence = float(probabilities[class_index])
        return f"Predicted Class: {predicted_class}\nConfidence: {confidence:.2f}"
    except Exception as e:  # UI boundary: report the problem to the user
        return f"Error: {e}"


# Define example images
examples = [
    ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg"],
    ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg"],
    ["https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg"],
]

# Set up Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(label="Prediction"),
    title="Bacterial Morphology Classification",
    description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
    examples=examples,
)

# Launch the app
if __name__ == "__main__":
    interface.launch()