Image Classifier

This repository contains a pre-trained PyTorch image classifier: a Convolutional Neural Network (CNN) that assigns an image to one of 10 categories (airplane, bird, car, cat, deer, dog, horse, monkey, ship, or truck).

Model Overview

The model is a simple CNN classifier with two convolutional blocks followed by a fully connected layer. It was trained on an image dataset, and its output indices map to categories as follows:

  • 0: Airplane
  • 1: Bird
  • 2: Car
  • 3: Cat
  • 4: Deer
  • 5: Dog
  • 6: Horse
  • 7: Monkey
  • 8: Ship
  • 9: Truck
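
For use in code, the index-to-category table above can be written as a list whose positions match the class indices. This is a minimal sketch; the names `CLASS_NAMES` and `label_for` are illustrative and not part of this repository:

```python
# Category names ordered by class index, as listed above.
CLASS_NAMES = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]


def label_for(index: int) -> str:
    """Return the category name for a predicted class index (0-9)."""
    if not 0 <= index < len(CLASS_NAMES):
        raise ValueError(f"class index out of range: {index}")
    return CLASS_NAMES[index]


print(label_for(7))  # monkey
```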

Model Architecture

The model consists of the following layers:

  1. Conv Block 1: Two 3×3 convolutional layers, each followed by a ReLU activation, then 2×2 max pooling.
  2. Conv Block 2: Two more 3×3 convolutional layers, each followed by a ReLU activation, then 2×2 max pooling.
  3. Fully Connected Classifier: A flatten step and a linear layer that maps the features to the 10 output categories.

Here’s the architecture of the model:

from torch import nn


class CNNV0(nn.Module):
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 576 = 24 * 24, which matches a 96x96 input after the two 2x2 poolings
            nn.Linear(in_features=hidden_units*576, out_features=output_shape)
        )

    def forward(self, x):
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x
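
The `hidden_units*576` input size of the linear layer constrains the image resolution the model accepts. Each 3×3 convolution with stride 1 and padding 1 preserves height and width, and each `MaxPool2d(kernel_size=2)` halves them. The sketch below checks that arithmetic in plain Python. It assumes square inputs, and `conv_out`, `pool_out`, and `flattened_features` are illustrative helpers, not part of this repository:

```python
def conv_out(side: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Output side length of a Conv2d layer (standard formula, no dilation)."""
    return (side + 2 * padding - kernel) // stride + 1


def pool_out(side: int, kernel: int = 2) -> int:
    """Output side length of MaxPool2d with stride equal to the kernel size."""
    return side // kernel


def flattened_features(side: int, hidden_units: int) -> int:
    """Features entering the linear layer for a square input of the given side."""
    for _ in range(2):                   # two conv blocks
        side = conv_out(conv_out(side))  # two convolutions per block
        side = pool_out(side)            # one max pooling per block
    return hidden_units * side * side


# hidden_units=10 is an arbitrary example value, not the trained model's setting.
print(flattened_features(96, hidden_units=10) == 10 * 576)   # True
print(flattened_features(224, hidden_units=10) == 10 * 576)  # False: 56 * 56 = 3136 per channel
```

Under these assumptions a 96×96 input matches the classifier, while a 224×224 input would cause a shape mismatch in the linear layer.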

Requirements

  • Python 3.7 or higher
  • PyTorch 1.8 or higher
  • torchvision (for loading and preprocessing images)

Usage

  1. Clone this repository and install dependencies:

    git clone <repository-url>
    cd <repository-folder>
    pip install torch torchvision
    
  2. Load and use the model in your Python script:

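    import torch
    from torchvision import transforms
    from PIL import Image
    
    # Load the model. This assumes model_0.pth stores the full pickled model, so the
    # CNNV0 class must be defined or importable before loading. On PyTorch 2.6+,
    # torch.load defaults to weights_only=True; pass weights_only=False for a full model.
    model = torch.load('model_0.pth')
    model.eval()  # Set to evaluation mode
    
    # Load and preprocess the image
    transform = transforms.Compose([
        transforms.Resize((96, 96)),  # hidden_units*576 in the classifier implies 96x96 inputs
        transforms.ToTensor(),
    ])
    img = Image.open('path_to_image.jpg').convert('RGB')  # Ensure 3 channels
    img = transform(img).unsqueeze(0)  # Add a batch dimension: (1, 3, 96, 96)
    
    # Predict
    with torch.no_grad():
        output = model(img)
        _, predicted = torch.max(output, 1)
        print("Predicted class index:", predicted.item())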
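
The model's final layer is linear, so its outputs are raw scores (logits). Applying a softmax turns them into probabilities that can be paired with category names. The following is a minimal sketch in plain Python, assuming the logits have been extracted as a list (for example with `output.squeeze(0).tolist()`). `CLASS_NAMES`, `softmax`, and `top_k` are illustrative names, and the logits are made-up values for demonstration only:

```python
import math

# Category names ordered by class index, as listed in the Model Overview.
CLASS_NAMES = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most probable (category, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(CLASS_NAMES[i], probs[i]) for i in ranked[:k]]


# Made-up logits standing in for one row of the model output.
example_logits = [0.1, 2.5, -1.0, 0.3, 0.0, 1.2, -0.5, 0.7, -2.0, 0.4]
for name, prob in top_k(example_logits):
    print(f"{name}: {prob:.3f}")
```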
    

Dataset used to train Zahaab/object-classification