Fine-Tuning ResNet50 for Alzheimer's MRI Classification

This repository contains a Jupyter notebook that fine-tunes a ResNet50 model to classify Alzheimer's disease stages from MRI images. The notebook uses PyTorch, and the dataset is loaded with the Hugging Face Datasets library.

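Table of Contents

  • Introduction
  • Dataset
  • Model Architecture
  • Setup
  • Training
  • Evaluation
  • Usage
  • Results
  • Contributing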

Introduction

This notebook fine-tunes a pre-trained ResNet50 model to classify MRI images into one of four dementia categories:

  • Mild Demented
  • Moderate Demented
  • Non-Demented
  • Very Mild Demented
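The model's four outputs are indexed 0 to 3, so predictions need a mapping back to these class names. A small sketch of that mapping, assuming the indices follow the order listed above (the dataset's own label metadata is authoritative; CLASS_NAMES, ID2LABEL, LABEL2ID, and predicted_class are hypothetical names):

```python
# Hypothetical mapping between class indices (model outputs) and class names.
# Assumption: index order matches the list above; verify it against the
# dataset's label metadata before relying on it.
CLASS_NAMES = [
    "Mild Demented",
    "Moderate Demented",
    "Non-Demented",
    "Very Mild Demented",
]

ID2LABEL = dict(enumerate(CLASS_NAMES))
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}


def predicted_class(logits):
    """Return the class name for the largest of the four logits."""
    if len(logits) != len(CLASS_NAMES):
        raise ValueError(f"expected {len(CLASS_NAMES)} logits, got {len(logits)}")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return ID2LABEL[best]


print(predicted_class([0.1, -1.2, 2.5, 0.3]))  # Non-Demented
```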

Dataset

The notebook uses the Falah/Alzheimer_MRI dataset, hosted on the Hugging Face Hub and loaded with the Datasets library. It consists of MRI images, each labeled with one of the four classes listed above.

Model Architecture

The model is based on a pre-trained ResNet50. Its final fully connected layer (fc) is replaced with a new linear layer that outputs predictions for the 4 classes.

Setup

To run the notebook locally, follow these steps:

  1. Clone the repository:
    git clone https://github.com/your_username/alzheimer_mri_classification.git
    cd alzheimer_mri_classification
    
  2. Install the required dependencies:
    pip install -r requirements.txt
    
  3. Open the notebook:
    jupyter notebook fine-tuning.ipynb
    

Training

The notebook includes sections for:

  • Loading and preprocessing the dataset
  • Defining the model architecture
  • Setting up the training loop with a learning rate scheduler and optimizer
  • Training the model for a specified number of epochs
  • Saving the trained model weights
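The training steps listed above can be sketched as follows. This is not the notebook's code: the loss (cross-entropy), optimizer (Adam, lr=1e-4), and scheduler (StepLR, decaying the rate 10x every 5 epochs) are assumptions, as is a loader that yields (images, labels) tuples; a Hugging Face dataset would need a collate function to produce that shape. train and save_path are hypothetical names.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, train_loader, num_epochs=10, lr=1e-4, device="cpu", save_path=None):
    """Fine-tune `model` with cross-entropy, Adam, and a StepLR schedule.

    Returns the mean training loss of each epoch.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Assumed schedule: multiply the learning rate by 0.1 every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()  # enable BatchNorm statistic updates and Dropout
        running_loss, seen = 0.0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
            seen += images.size(0)
        scheduler.step()  # advance the schedule once per epoch
        epoch_losses.append(running_loss / seen)
        print(f"epoch {epoch + 1}/{num_epochs}  loss={epoch_losses[-1]:.4f}")

    if save_path is not None:
        # Save only the state_dict, matching how the weights are loaded for inference.
        torch.save(model.state_dict(), save_path)
    return epoch_losses


# Smoke run on random data with a tiny stand-in model (not ResNet50).
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
toy_data = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 4, (32,)))
losses = train(toy_model, DataLoader(toy_data, batch_size=8, shuffle=True), num_epochs=2)
```

For the real model, passing the ResNet50 with its 4-class head and save_path="alzheimer_model_resnet50.pth" would produce a weights file in the format the Usage section loads.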

Evaluation

The notebook includes a section for evaluating the trained model on the validation set. It calculates and prints the validation loss and accuracy.
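The evaluation step can be sketched as below, assuming the validation loader yields (images, labels) batches; evaluate is a hypothetical helper. Summing per-sample losses and dividing by the sample count keeps the mean loss correct even when the last batch is smaller. Accuracy is returned as a percentage, the same form used under Results.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, val_loader, device="cpu"):
    """Return (mean cross-entropy loss, accuracy in %) over a validation loader."""
    model.to(device)
    model.eval()  # use running BatchNorm statistics and disable Dropout
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, 100.0 * correct / total


# Sanity check with an identity "model": the inputs are used directly as logits.
inputs = 10.0 * torch.eye(4)
labels = torch.tensor([0, 1, 0, 0])  # last two deliberately mismatched
loss, acc = evaluate(nn.Identity(), DataLoader(TensorDataset(inputs, labels), batch_size=2))
print(f"val_loss={loss:.4f}  val_acc={acc:.2f}%")  # val_acc=50.00%
```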

Usage

After training, the notebook saves the model weights as alzheimer_model_resnet50.pth. These weights can then be loaded to run inference on new MRI images.

To load the model architecture and trained weights:

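```python
import torch
import torch.nn as nn
from torchvision import models

# Rebuild ResNet50 with a 4-class head, then load the fine-tuned weights on CPU
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, 4)
model.load_state_dict(torch.load("alzheimer_model_resnet50.pth", map_location=torch.device("cpu")))
model.eval()  # switch BatchNorm and Dropout to inference mode
```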

Results

The model achieved an accuracy of 95.9375% on the validation set.

Contributing

Contributions are welcome! If you have any suggestions, bug reports, or feature requests, please open an issue or submit a pull request.