
Hiera Model (Base, fine-tuned on IN1K)

Hiera is a hierarchical vision transformer that is fast, powerful, and, above all, simple. It was introduced in the paper "Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles" (Ryali et al., ICML 2023), where it outperforms the prior state of the art across a wide array of image and video tasks while also being much faster.

How does it work?

A diagram of Hiera's architecture.

Vision transformers like ViT use the same spatial resolution and number of features throughout the whole network. But this is inefficient: the early layers don't need that many features, and the later layers don't need that much spatial resolution. Prior hierarchical models like ResNet accounted for this by using fewer features at the start and less spatial resolution at the end.
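To make this trade-off concrete, the sketch below tabulates (tokens, channels) per stage for a plain ViT and for a hierarchical layout. The configurations are assumptions for illustration, not values stated in this card: ViT-B/16 (16x16 patches, 768 channels everywhere) and a Hiera-B-like layout (stride-4 patch embedding with 96 channels, then 2x2 token pooling and channel doubling at each of the three stage transitions), both at 224x224 input.

```python
from typing import List, Tuple


def vit_stages(image_size: int = 224, patch: int = 16, dim: int = 768,
               num_stages: int = 4) -> List[Tuple[int, int]]:
    """(tokens, channels) per stage for a plain ViT: nothing changes between stages."""
    side = image_size // patch
    return [(side * side, dim)] * num_stages


def hierarchical_stages(image_size: int = 224, stride: int = 4, dim: int = 96,
                        num_stages: int = 4) -> List[Tuple[int, int]]:
    """(tokens, channels) per stage when each transition pools 2x2 tokens and doubles channels.

    The default values are an assumed Hiera-B-like configuration.
    """
    side = image_size // stride
    stages = []
    for _ in range(num_stages):
        stages.append((side * side, dim))
        side //= 2  # 2x2 pooling halves height and width (4x fewer tokens)
        dim *= 2    # ...while the channel width doubles
    return stages


if __name__ == "__main__":
    for i, (v, h) in enumerate(zip(vit_stages(), hierarchical_stages()), start=1):
        print(f"stage{i}: ViT {v[0]} tokens x {v[1]} ch | "
              f"hierarchical {h[0]} tokens x {h[1]} ch")
```

Under these assumptions the hierarchical layout keeps many tokens with few channels early on and few tokens with many channels at the end, while the ViT carries 196 tokens x 768 channels through every stage.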

Several domain-specific vision transformers that employ this hierarchical design have been introduced, such as Swin or MViT. But in the pursuit of state-of-the-art results using fully supervised training on ImageNet-1K, these models have become more and more complicated as they add specialized modules to make up for the spatial biases that ViTs lack. While these changes produce effective models with attractive FLOP counts, under the hood the added complexity makes these models slower overall.

We show that much of this bulk is actually unnecessary. Instead of manually adding spatial biases through architectural changes, we opt to teach the model these biases. By pretraining with MAE (masked autoencoders), we can simplify or remove all of these bulky modules in existing transformers and increase accuracy in the process. The result is Hiera, an extremely efficient and simple architecture that outperforms the state of the art in several image and video recognition tasks.
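MAE pretraining hides a random subset of the input and trains the model to reconstruct what was hidden. The stdlib-only sketch below illustrates just the random masking step in a generic, MAE-style way. The grid and ratio are illustrative assumptions rather than values from this card: a 224x224 image split into a 7x7 grid of 32x32-pixel units, with 60% of the units masked. Only the visible units would be passed to the encoder.

```python
import random
from typing import List, Optional


def random_unit_mask(num_units: int, mask_ratio: float,
                     seed: Optional[int] = None) -> List[bool]:
    """Generic MAE-style random masking: True marks a unit hidden from the encoder."""
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)
    num_masked = int(num_units * mask_ratio)
    masked = set(rng.sample(range(num_units), num_masked))
    return [i in masked for i in range(num_units)]


def visible_indices(mask: List[bool]) -> List[int]:
    """Indices of the units the encoder would actually process."""
    return [i for i, is_masked in enumerate(mask) if not is_masked]


if __name__ == "__main__":
    grid = 224 // 32  # assumed 32x32-pixel units -> 7x7 grid
    mask = random_unit_mask(grid * grid, mask_ratio=0.6, seed=0)
    print(f"{sum(mask)} of {len(mask)} units masked; "
          f"encoder sees {len(visible_indices(mask))}")
```

Because the encoder only processes the visible units, a high mask ratio also makes pretraining cheaper; the reconstruction target and decoder are outside the scope of this sketch.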

Intended uses & limitations

Hiera can be used for image classification, feature extraction, or masked image modeling. This checkpoint in particular is intended for feature extraction.

How to use

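from transformers import AutoImageProcessor, HieraModel
import torch
from PIL import Image
import requests

# Load an example image from the COCO validation set
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-base-224-hf")
model = HieraModel.from_pretrained("facebook/hiera-base-224-hf")

# Preprocess the image into a batch of pixel values
inputs = image_processor(images=image, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# Token features from the final stage
last_hidden_state = outputs.last_hidden_state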

You can also extract feature maps from different stages of the model by using HieraBackbone and setting out_features when loading the model. This is how you would extract feature maps from all four stages:

from transformers import AutoImageProcessor, HieraBackbone
import torch
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-base-224-hf")
# `out_features` should be a subset of ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
# HieraBackbone introduces new LayerNorm layers that are not part of the pretrained
# checkpoint, so the backbone should probably be fine-tuned on a downstream task
model = HieraBackbone.from_pretrained("facebook/hiera-base-224-hf", out_features=['stage1', 'stage2', 'stage3', 'stage4'])

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# One feature map per entry in `out_features`
feature_maps = outputs.feature_maps
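Before loading the backbone, it can help to know which maps to expect. The sketch below uses a hypothetical helper, `expected_feature_shapes`, that checks a requested out_features list against the stage names given in the comment of the example and predicts each map's shape. The shapes rest on assumptions, not on this card: a Hiera-B-like layout at 224x224 input (a stride-4 stem with 96 channels, stage 1 at the stem's resolution, then 2x downsampling and channel doubling at each later stage) and channels-first (batch, channels, height, width) outputs. Compare them against `outputs.feature_maps` rather than relying on them.

```python
from typing import Dict, List, Tuple

# Stage names as listed in the example's comment
STAGE_NAMES = ["stem", "stage1", "stage2", "stage3", "stage4"]


def expected_feature_shapes(out_features: List[str], batch: int = 1,
                            image_size: int = 224, stride: int = 4,
                            dim: int = 96) -> Dict[str, Tuple[int, int, int, int]]:
    """Hypothetical helper: validate out_features and predict (N, C, H, W) per stage.

    Defaults describe an assumed Hiera-B-like configuration.
    """
    unknown = [name for name in out_features if name not in STAGE_NAMES]
    if unknown:
        raise ValueError(f"not a valid stage name: {unknown}")
    side = image_size // stride
    shapes = {"stem": (batch, dim, side, side)}
    for i in range(1, 5):
        if i > 1:  # stages 2-4 pool 2x2 tokens and double the channel width
            side //= 2
            dim *= 2
        shapes[f"stage{i}"] = (batch, dim, side, side)
    return {name: shapes[name] for name in out_features}


if __name__ == "__main__":
    for name, shape in expected_feature_shapes(
            ["stage1", "stage2", "stage3", "stage4"]).items():
        print(name, shape)
```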

BibTeX entry and citation info

If you use Hiera or this code in your work, please cite:

@article{ryali2023hiera,
  title={Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles},
  author={Ryali, Chaitanya and Hu, Yuan-Ting and Bolya, Daniel and Wei, Chen and Fan, Haoqi and Huang, Po-Yao and Aggarwal, Vaibhav and Chowdhury, Arkabandhu and Poursaeed, Omid and Hoffman, Judy and Malik, Jitendra and Li, Yanghao and Feichtenhofer, Christoph},
  journal={ICML},
  year={2023}
}
Model size: 50.8M params (Safetensors, tensor type F32)