
Overview

This experiment analyzes the training dynamics of vision transformers as part of the Prisma project. Small Vision Transformers (ViTs) were trained and evaluated on shape classification using the dSprites dataset, which consists of procedurally generated 2D shapes defined by six independent latent factors. The specific task is to classify the three distinct shapes in dSprites. All training checkpoints are available on the Hugging Face Hub and are summarised in the following table, with links to the models on the Hub:

| Size | No. of layers | AttentionOnly | Attention-and-MLP |
|--------|---|------|------|
| tiny | 1 | link | link |
| base | 2 | link | link |
| small | 3 | link | link |
| medium | 4 | link | link |

Each repository contains multiple intermediate checkpoints. Each checkpoint is stored as `checkpoint_{i}.pth`, where `i` is the number of training samples the model has been trained on.
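Because `i` counts training samples, sorting the filenames as plain strings would place `checkpoint_1000.pth` before `checkpoint_200.pth`. The sketch below orders checkpoints numerically by parsing `i` out of each name. The helper `sort_checkpoints` and its `prefix` parameter are hypothetical, not part of ViT-Prisma; the prefix is configurable because the loading example under How to Use refers to a file named `model_0.pth`.

```python
import re
from typing import Iterable, List, Tuple


def sort_checkpoints(filenames: Iterable[str], prefix: str = "checkpoint") -> List[Tuple[int, str]]:
    """Return (num_samples, filename) pairs for files named '{prefix}_{i}.pth',
    sorted by the number of training samples i. Non-matching files are skipped."""
    pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)\.pth$")
    found = []
    for name in filenames:
        # Match on the final path component so files inside subfolders still count.
        match = pattern.match(name.rsplit("/", 1)[-1])
        if match:
            found.append((int(match.group(1)), name))
    # Tuples sort by their first element, i.e. numerically by training samples.
    return sorted(found)


files = ["README.md", "checkpoint_1000.pth", "checkpoint_200.pth", "checkpoint_0.pth"]
print(sort_checkpoints(files))
# [(0, 'checkpoint_0.pth'), (200, 'checkpoint_200.pth'), (1000, 'checkpoint_1000.pth')]
```

With network access, the list of filenames can come from `huggingface_hub.list_repo_files(REPO_ID)`, and each sorted name can then be passed to `hf_hub_download` to step through training in order.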

Further details on training and results are described here.
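dSprites enumerates every combination of its six latent factors exactly once (737,280 images of 64×64 pixels), so an image's position in the dataset array determines its latent values, including the shape label that serves as the classification target here. The sketch below assumes the factor order and sizes of DeepMind's public dSprites release (color, shape, scale, orientation, posX, posY); the helper names are illustrative, and this is not a description of how ViT-Prisma's data pipeline loads the data.

```python
from math import prod
from typing import Dict, List, Sequence

# Latent factors and the number of values each takes, in the order of the
# public dSprites release (assumption: the standard DeepMind .npz file).
LATENT_SIZES: Dict[str, int] = {
    "color": 1,
    "shape": 3,        # 0 = square, 1 = ellipse, 2 = heart
    "scale": 6,
    "orientation": 40,
    "posX": 32,
    "posY": 32,
}
SHAPE_NAMES = ("square", "ellipse", "heart")


def latent_bases(sizes: Sequence[int]) -> List[int]:
    """Mixed-radix place values: each factor's base is the product of the sizes after it."""
    return [prod(sizes[i + 1:]) for i in range(len(sizes))]


def latents_to_index(classes: Sequence[int]) -> int:
    """Map one value per latent factor to the flat index of the corresponding image."""
    sizes = list(LATENT_SIZES.values())
    if len(classes) != len(sizes):
        raise ValueError(f"expected {len(sizes)} latent values, got {len(classes)}")
    for value, size in zip(classes, sizes):
        if not 0 <= value < size:
            raise ValueError(f"latent value {value} out of range [0, {size})")
    return sum(value * base for value, base in zip(classes, latent_bases(sizes)))


def index_to_shape_label(index: int) -> int:
    """Recover the shape class (the classification target) from a flat image index."""
    sizes = list(LATENT_SIZES.values())
    pos = list(LATENT_SIZES).index("shape")
    return (index // latent_bases(sizes)[pos]) % sizes[pos]


last = latents_to_index([0, 2, 5, 39, 31, 31])
print(last, SHAPE_NAMES[index_to_shape_label(last)])  # 737279 heart
```

Using the flat index this way means three-way shape labels can be read off without loading the `latents_classes` array, as long as the assumed ordering holds.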

How to Use

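```
# In a Jupyter or Colab notebook cell. Use %cd rather than !cd: each `!` command
# runs in its own subshell, so `!cd` would not change the directory for `!pip`.
!git clone https://github.com/soniajoseph/ViT-Prisma
%cd ViT-Prisma
!pip install -e .
```

```python
import torch
from huggingface_hub import hf_hub_download

from vit_prisma.models.base_vit import BaseViT
from vit_prisma.configs.DSpritesConfig import GlobalConfig
from vit_prisma.utils.wandb_utils import update_dataclass_from_dict

# Repository and checkpoint file to download; change FILENAME to load a
# different point in training.
REPO_ID = "IamYash/dSprites-tiny-AttentionOnly"
FILENAME = "model_0.pth"

# Download the checkpoint (cached locally) and load it onto the CPU.
checkpoint = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location="cpu",
)

# Start from the default dSprites config and override the transformer
# settings to match the downloaded model (tiny: 1 layer, attention-only).
config = GlobalConfig()
update_dict = {
    'transformer': {
        'attention_only': True,
        'hidden_dim': 512,
        'num_heads': 8,
        'num_layers': 1,
    }
}
update_dataclass_from_dict(config, update_dict)
print(config)

# Build the model, load the trained weights, and switch to inference mode.
model = BaseViT(config)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```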
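To load a different model from the table, the `transformer` overrides need the layer count for the chosen size and an `attention_only` flag matching the column (True for AttentionOnly, False for Attention-and-MLP). The helper `make_update_dict` below is hypothetical and not part of ViT-Prisma; its `hidden_dim=512` and `num_heads=8` defaults are copied from the tiny example, and whether the larger models share those values is an assumption to check against each checkpoint.

```python
from typing import Dict

# Number of layers per model size, taken from the table in the Overview.
NUM_LAYERS: Dict[str, int] = {"tiny": 1, "base": 2, "small": 3, "medium": 4}


def make_update_dict(size: str, attention_only: bool,
                     hidden_dim: int = 512, num_heads: int = 8) -> dict:
    """Build the 'transformer' overrides for one model in the table.

    hidden_dim and num_heads default to the values used for the tiny model;
    assuming the other sizes use the same values is unverified.
    """
    if size not in NUM_LAYERS:
        raise ValueError(f"unknown size {size!r}; expected one of {sorted(NUM_LAYERS)}")
    return {
        "transformer": {
            "attention_only": attention_only,
            "hidden_dim": hidden_dim,
            "num_heads": num_heads,
            "num_layers": NUM_LAYERS[size],
        }
    }


print(make_update_dict("small", attention_only=False))
```

The result can be passed to `update_dataclass_from_dict(config, ...)` in place of the hand-written `update_dict`, with `REPO_ID` pointing at the matching repository from the table.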

License: MIT