"""Gradio demo that classifies a bacterial sample image by morphology.

On import the script downloads the model weights, builds the preprocessing
pipeline, defines the Gradio interface and launches the web server.
"""
import torch
import torch.nn as nn
import gradio as gr
from torchvision import transforms
from PIL import Image


# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by a two-layer head.

    The first linear layer is sized for 224x224 inputs: each of the two
    stride-2 max-pools halves the spatial size (224 -> 112 -> 56), and the
    last convolution emits 64 channels, hence ``64 * 56 * 56`` features.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 56 * 56, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (pre-softmax) class scores for a batch of images."""
        x = self.feature_extractor(x)
        x = self.fc(x)
        return x


# Load the model
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
model = BacterialMorphologyClassifier()

try:
    # Download and load model state_dict
    state_dict = torch.hub.load_state_dict_from_url(
        MODEL_PATH, map_location=torch.device('cpu')
    )
    # strict=False tolerates key mismatches, so report them explicitly:
    # otherwise a checkpoint that matches nothing would "load" silently and
    # leave the affected layers at their random initial weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: checkpoint does not fully match the model "
            f"(missing keys: {load_result.missing_keys}, "
            f"unexpected keys: {load_result.unexpected_keys})."
        )
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Bare raise keeps the original traceback intact.
    raise

# Inference only: disables dropout.
model.eval()

# Define image preprocessing transformations
transform = transforms.Compose([
    # Must stay 224x224 to match the classifier head's input size.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class labels
class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}


# Prediction function
def predict(image: Image.Image) -> dict:
    """Classify one image and return a probability for every class.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        Mapping of class label to softmax probability, covering every entry
        in ``class_labels``. This is the shape the ``Label`` output expects.

    Raises:
        gr.Error: If no image was supplied or preprocessing/inference fails,
            so the message is shown in the UI instead of being passed to the
            label component as if it were a prediction.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    try:
        # Force three channels: RGBA or grayscale uploads would otherwise
        # not match the 3-channel normalisation and first convolution.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = model(image_tensor)

        # Batch size is 1, so take the single row of class probabilities.
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
    except Exception as e:
        raise gr.Error(f"Prediction failed: {e}") from e

    return {
        class_labels[index]: probability
        for index, probability in enumerate(probabilities.tolist())
    }


# Example input images (provide paths or URLs)
# NOTE(review): these look like dataset-viewer page URLs rather than direct
# image files - confirm Gradio can load them, or replace with real image paths.
example_images = [
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer?row=0&image-viewer=52B421CB70A43313B278D5DD2C58CECE56343012",  # Replace with the actual paths to your example images
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=558EA847F2267CECF4E2CFF6352F9D8888E9A72F",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/viewer/default/train?p=2&row=201&image-viewer=8FBAF2C52C256A392660811C5659788734821C3A",
]

# Set up Gradio interface with examples
# gr.Image / gr.Label are the top-level component classes; the older
# gr.inputs / gr.outputs namespaces are deprecated and gone in Gradio 4.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Label(num_top_classes=3, label="Predicted class and confidence"),
    title="Bacterial Morphology Classifier",
    description="Upload an image of a bacterial sample to classify it as 'cocci', 'bacilli', or 'spirilla'.",
    examples=example_images,  # Provide the example image paths
)

# Launch the app
iface.launch(server_name="0.0.0.0", server_port=5000, share=True)