---
license: apache-2.0
inference: false
datasets:
- mnist
pipeline_tag: image-classification
---

# Perceiver IO image classifier (MNIST)

This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the [MNIST](https://huggingface.co/datasets/mnist) dataset. It is a [training example](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification) of the [perceiver-io](https://github.com/krasserm/perceiver-io) library.

## Model description

Like [krasserm/perceiver-io-img-clf](https://huggingface.co/krasserm/perceiver-io-img-clf), this model uses 2D Fourier features for position encoding and cross-attends to individual pixels of an input image, but it uses repeated cross-attention, a configuration that was described in the original [Perceiver paper](https://arxiv.org/abs/2103.03206) and dropped in the follow-up [Perceiver IO paper](https://arxiv.org/abs/2107.14795) (see [building blocks](https://github.com/krasserm/perceiver-io/blob/main/docs/building-blocks.md) for more details).

## Model training

The model was [trained](https://github.com/krasserm/perceiver-io/blob/main/docs/training-examples.md#image-classification) with randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized and data augmentations were turned off. Training was done with [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) and the resulting checkpoint was converted to this 🤗 model with a library-specific [conversion utility](#checkpoint-conversion).

## Intended use and limitations

The model can be used for MNIST handwritten digit classification.

## Usage examples

To use this model you first need to [install](https://github.com/krasserm/perceiver-io/blob/main/README.md#installation) the `perceiver-io` library with extension `vision`.

```shell
pip install perceiver-io[vision]
```

Then the model can be used with PyTorch.
Either use the model and image processor directly:

```python
from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoImageProcessor
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

model = AutoModelForImageClassification.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)

inputs = processor(images, return_tensors="pt")
logits = model(**inputs).logits

print(f"Labels: {labels}")
print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
```

```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

or use an `image-classification` pipeline:

```python
from datasets import load_dataset
from transformers import pipeline
from perceiver.model.vision import image_classifier  # auto-class registration

repo_id = "krasserm/perceiver-io-img-clf-mnist"

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model=repo_id)
predictions = [pred[0]["label"] for pred in classifier(images)]

print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
```

```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

## Checkpoint conversion

The `krasserm/perceiver-io-img-clf-mnist` model has been created from a training checkpoint with:

```python
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="krasserm/perceiver-io-img-clf-mnist",
    ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
    push_to_hub=True,
)
```
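
As background to the 2D Fourier feature position encoding mentioned in the model description, the following is a minimal, dependency-free sketch of the scheme described in the Perceiver paper: each pixel position is scaled to [-1, 1] and encoded with sine and cosine features at linearly spaced frequencies up to the Nyquist frequency, concatenated with the raw position. The function name `fourier_position_encoding` and the value `num_bands=4` are hypothetical and chosen for illustration; this is not the `perceiver-io` library's implementation, and the number of frequency bands actually used by this model is not stated here.

```python
import math


def fourier_position_encoding(height, width, num_bands):
    """Return one feature list per pixel: raw (y, x) position plus sin/cos features.

    Illustrative sketch only; `num_bands` is an assumed, hypothetical setting.
    """

    def axis_coords(n):
        # n positions linearly spaced in [-1, 1]
        return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]

    def axis_freqs(n):
        # num_bands frequencies linearly spaced from 1 to the Nyquist frequency n / 2
        if num_bands == 1:
            return [1.0]
        step = (n / 2 - 1.0) / (num_bands - 1)
        return [1.0 + k * step for k in range(num_bands)]

    ys, xs = axis_coords(height), axis_coords(width)
    freqs_y, freqs_x = axis_freqs(height), axis_freqs(width)

    encodings = []
    for y in ys:
        for x in xs:
            feats = [y, x]  # raw position
            for pos, freqs in ((y, freqs_y), (x, freqs_x)):
                feats += [math.sin(math.pi * f * pos) for f in freqs]
                feats += [math.cos(math.pi * f * pos) for f in freqs]
            encodings.append(feats)
    return encodings


# One encoding per pixel of a 28x28 MNIST image
enc = fourier_position_encoding(height=28, width=28, num_bands=4)
print(len(enc), len(enc[0]))  # 784 pixels, 2 * (2 * 4 + 1) = 18 features each
```

In a Perceiver-style model, such per-pixel encodings are concatenated with the pixel values before the latent array cross-attends to them.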