---
library_name: aim
pipeline_tag: image-classification
license: other
license_name: apple-sample-code-license
license_link: LICENSE
datasets:
- imagenet-1k
metrics:
- accuracy
tags:
- large-scale-vision-models
- pytorch
- mlx
- jax
- vision
- ssl
- pre-training
- DFN
---
# AIM: Autoregressive Image Models
*Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar,
Joshua M Susskind, and Armand Joulin*
This software project accompanies the research paper, Scalable Pre-training of Large Autoregressive Image Models.
We introduce **AIM**, a collection of vision models pre-trained with an autoregressive generative objective.
We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its
textual counterpart (i.e. Large Language Models). Specifically, we highlight two findings:
1. the model capacity can be trivially scaled to billions of parameters, and
2. AIM effectively leverages large collections of uncurated image data.
## Installation
Please install PyTorch using the official [installation instructions](https://pytorch.org/get-started/locally/).
Afterward, install the package as:
```commandline
pip install git+https://git@github.com/apple/ml-aim.git
```
We also offer [MLX](https://github.com/ml-explore/mlx) backend support for research and experimentation on Apple silicon.
To enable MLX support, simply run:
```commandline
pip install mlx
```
## Usage
Below we provide an example of usage in [PyTorch](https://pytorch.org/):
```python
from PIL import Image
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
logits, _ = model(inp)
```
<details>
<summary>and in <a href="https://ml-explore.github.io/mlx/">MLX</a></summary>
```python
from PIL import Image
import mlx.core as mx
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, _ = model(inp)
```
</details>
<details>
<summary>and <a href="https://jax.readthedocs.io/">JAX</a></summary>
```python
from PIL import Image
import jax.numpy as jnp
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, _), _ = model.apply(params, inp, mutable=['batch_stats'])
```
</details>
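
The examples above stop at the raw logits. As a minimal sketch of one way to turn them into top-k class predictions, the helper below applies a softmax and keeps the k most likely classes. It assumes the logits have shape `(batch, num_classes)` and have already been converted to a NumPy array (for example `logits.detach().cpu().numpy()` in PyTorch, or `np.asarray(logits)` for MLX and JAX). The name `top_k_predictions` is hypothetical and not part of the `aim` package, and a random array stands in for the model output.

```python
import numpy as np


def top_k_predictions(logits: np.ndarray, k: int = 5):
    """Return (indices, probabilities) of the k most likely classes per row.

    Hypothetical helper, not part of the `aim` package. `logits` is assumed
    to have shape (batch, num_classes).
    """
    # Numerically stable softmax over the class dimension.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Sort classes by descending probability and keep the first k.
    idx = np.argsort(-probs, axis=-1)[:, :k]
    return idx, np.take_along_axis(probs, idx, axis=-1)


# Stand-in for the model output: one image, 1000 classes (assumed for ImageNet-1k).
rng = np.random.default_rng(0)
fake_logits = rng.normal(size=(1, 1000)).astype(np.float32)
indices, scores = top_k_predictions(fake_logits, k=5)
```

Here `indices[0]` holds the five class ids for the first image in descending order of probability, and `scores[0]` the matching probabilities. Mapping ids to human-readable labels requires a class list that this package is not assumed to ship.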
## Pre-trained checkpoints
The pre-trained models can be used via the [Hugging Face Hub](https://huggingface.co/collections/apple/aim-65aa3ce948c718a574f09eb7) as follows:
```python
from PIL import Image
from aim.torch.models import AIMForImageClassification
from aim.torch.data import val_transforms
img = Image.open(...)
model = AIMForImageClassification.from_pretrained("apple/aim-7B")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
logits, features = model(inp)
```
### Pre-trained backbones
The following table contains pre-trained backbones used in our paper.
<table style="margin: auto">
<thead>
<tr>
<th>model</th>
<th>#params</th>
<th>top-1 IN-1k, attn probe (best layer)</th>
<th>backbone, SHA256</th>
</tr>
</thead>
<tbody>
<tr>
<td>AIM-0.6B</td>
<td>0.6B</td>
<td>79.4%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_backbone.pth">link</a>, 0d6f6b8f</td>
</tr>
<tr>
<td>AIM-1B</td>
<td>1B</td>
<td>82.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_backbone.pth">link</a>, d254ecd3</td>
</tr>
<tr>
<td>AIM-3B</td>
<td>3B</td>
<td>83.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_backbone.pth">link</a>, 8475ce4e</td>
</tr>
<tr>
<td>AIM-7B</td>
<td>7B</td>
<td>84.0%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_backbone.pth">link</a>, 184ed94c</td>
</tr>
</tbody>
</table>
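
The SHA256 entries in the table above are eight hex characters long, which suggests they are prefixes of each file's full SHA256 digest. Under that assumption, a downloaded checkpoint could be checked with a sketch like the following. The helper names `sha256_prefix` and `verify_checkpoint` are hypothetical and not part of the `aim` package.

```python
import hashlib
from pathlib import Path


def sha256_prefix(path, length: int = 8, chunk_size: int = 1 << 20) -> str:
    """Return the first `length` hex characters of a file's SHA256 digest."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        # Read in chunks so multi-gigabyte checkpoints need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def verify_checkpoint(path, expected_prefix: str) -> bool:
    """Compare a file's SHA256 prefix with the value listed in the table.

    Assumes the listed value is a prefix of the hex digest (an assumption,
    not something the table states explicitly).
    """
    return sha256_prefix(path, len(expected_prefix)) == expected_prefix.lower()
```

For example, `verify_checkpoint("aim_600m_2bimgs_attnprobe_backbone.pth", "0d6f6b8f")` would be expected to return `True` for an intact download of the AIM-0.6B backbone, if the prefix assumption holds.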
### Pre-trained attention heads
The table below contains the classification results on the ImageNet-1k validation set.
<table style="margin: auto">
<thead>
<tr>
<th rowspan="2">model</th>
<th colspan="2">top-1 IN-1k</th>
<th colspan="2">attention head, SHA256</th>
</tr>
<tr>
<th>last layer</th>
<th>best layer</th>
<th>last layer</th>
<th>best layer</th>
</tr>
</thead>
<tbody>
<tr>
<td>AIM-0.6B</td>
<td>78.5%</td>
<td>79.4%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_head_last_layers.pth">link</a>, 5ce5a341</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_head_best_layers.pth">link</a>, ebd45c05</td>
</tr>
<tr>
<td>AIM-1B</td>
<td>80.6%</td>
<td>82.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_head_last_layers.pth">link</a>, db3be2ad</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_head_best_layers.pth">link</a>, f1ed7852</td>
</tr>
<tr>
<td>AIM-3B</td>
<td>82.2%</td>
<td>83.3%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_head_last_layers.pth">link</a>, 5c057b30</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_head_best_layers.pth">link</a>, ad380e16</td>
</tr>
<tr>
<td>AIM-7B</td>
<td>82.4%</td>
<td>84.0%</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_head_last_layers.pth">link</a>, 1e5c99ba</td>
<td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_head_best_layers.pth">link</a>, 73ecd732</td>
</tr>
</tbody>
</table>
## Reproducing the IN-1k classification results
The command below reproduces the [attention probe results](#pre-trained-attention-heads) on the ImageNet-1k
validation set. We run the evaluation using 1 node with 8 GPUs:
```commandline
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
--model=aim-7B \
--batch-size=64 \
--data-path=/path/to/imagenet \
--probe-layers=last \
--backbone-ckpt-path=/path/to/backbone_ckpt.pth \
--head-ckpt-path=/path/to/head_ckpt.pth
```
By default, we probe the last 6 layers. To change this, simply pass `--probe-layers=best`.
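
The reported metric is top-1 accuracy: the fraction of validation images whose highest-scoring class matches the ground-truth label. The sketch below illustrates the metric on a toy batch. It is not the code in `main_attnprobe.py`, and the helper name `top1_accuracy` is hypothetical.

```python
import numpy as np


def top1_accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = logits.argmax(axis=-1)
    return float((predictions == labels).mean())


# Toy example: 4 samples, 3 classes; the third prediction is wrong.
toy_logits = np.array([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.9, 0.8, 0.1],
    [0.0, 0.2, 3.0],
])
toy_labels = np.array([0, 1, 1, 2])
accuracy = top1_accuracy(toy_logits, toy_labels)  # 0.75
```

In a multi-GPU run, the per-process counts of correct predictions and samples would need to be summed across processes before dividing.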