Sreekanth Tangirala

ResNet50 Model Implementation

This implementation provides a customizable ResNet50 model for image classification tasks.

Model Architecture

The model uses the ResNet50 architecture, a 50-layer deep convolutional neural network built from residual blocks with skip connections. Key features include:

  • Based on the standard ResNet50 architecture
  • Customizable number of output classes
  • Modified final fully connected layer to match the desired number of classes
  • Initialized from scratch (no pre-training)

Functions

get_model(num_classes)

Initializes a new ResNet50 model with a custom number of output classes.

  • Input: Number of classes (integer)
  • Output: Initialized ResNet50 model
  • Note: The model is initialized without pre-trained weights

save_model(model, path)

Saves the model's state dictionary (its parameters and buffers, not the full model object) to the specified path.

  • Input:
    • model: Trained PyTorch model
    • path: File path for saving the model
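
A function with this behavior can be written in one line with `torch.save`. The sketch below is an assumption about how `save_model` might look, not a copy of the repository's code; the small `nn.Linear` stand-in and the temporary path are hypothetical and only keep the example quick to run.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_model(model: nn.Module, path: str) -> None:
    """Write only the model's state_dict (parameters and buffers) to `path`."""
    torch.save(model.state_dict(), path)


# Small stand-in module so the example runs quickly; any nn.Module is saved the same way.
demo = nn.Linear(4, 2)
demo_path = os.path.join(tempfile.mkdtemp(), "model.pth")
save_model(demo, demo_path)

# The file holds a mapping from parameter names to tensors.
print(sorted(torch.load(demo_path).keys()))  # ['bias', 'weight']
```

Because only the state dictionary is stored, the file does not record the architecture; the loader has to rebuild the same model before restoring the weights.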

load_model(num_classes, path)

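Loads a previously saved model. Because only the state dictionary is saved, num_classes must match the value used when the model was created.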

  • Input:
    • num_classes: Number of output classes
    • path: Path to the saved model file
  • Output: Loaded ResNet50 model

Usage Example

model = get_model(num_classes=1000)
save_model(model, 'model.pth')
loaded_model = load_model(num_classes=1000, path='model.pth')

Attachments

Trained on AWS. Training logs: Logs 1, Logs 2