---
title: 'GCViT: Global Context Vision Transformer'
colorFrom: indigo
sdk: gradio
sdk_version: 3.0.15
emoji: π
pinned: false
license: apache-2.0
app_file: app.py
---
<h1 align="center">
<p><a href='https://arxiv.org/pdf/2206.09959v1.pdf'>GCViT: Global Context Vision Transformer</a></p>
</h1>
<div align=center><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG" width=800></div>
<p align="center">
<a href="https://github.com/awsaf49/gcvit-tf/blob/main/LICENSE.md">
<img src="https://img.shields.io/badge/License-MIT-yellow.svg">
</a>
<img alt="python" src="https://img.shields.io/badge/python-%3E%3D3.6-blue?logo=python">
<img alt="tensorflow" src="https://img.shields.io/badge/tensorflow-%3E%3D2.4.1-orange?logo=tensorflow">
<div align=center><p>
<a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-yellow.svg"></a>
<a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>
</p></div>
<h2 align="center">
<p>TensorFlow 2.0 Implementation of GCViT</p>
</h2>
</p>
<p align="center">
This library implements <b>GCViT</b> in TensorFlow 2.0, written in the <code>tf.keras.Model</code> style for a PyTorch-like feel.
</p>
## Update
* **15 Jan 2023** : `GCViTLarge` model added with ckpt.
* **3 Sept 2022** : Annotated [kaggle-notebook](https://www.kaggle.com/code/awsaf49/gcvit-global-context-vision-transformer) based on this project won [Kaggle ML Research Spotlight: August 2022](https://www.kaggle.com/discussions/general/349817).
* **19 Aug 2022** : This project was acknowledged by the [official](https://github.com/NVlabs/GCVit) repo [here](https://github.com/NVlabs/GCVit#third-party-implementations-and-resources).
## Model
* Architecture:
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch.PNG">
* Local Vs Global Attention:
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG">
## Result
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/result.PNG" width=900>
The official codebase had some issues, which have since been fixed (12 August 2022). Here are the results of the ported weights on **ImageNetV2-Test** data:
| Model | Acc@1 | Acc@5 | #Params |
|--------------|-------|-------|---------|
| GCViT-XXTiny | 0.663 | 0.873 | 12M |
| GCViT-XTiny | 0.685 | 0.885 | 20M |
| GCViT-Tiny | 0.708 | 0.899 | 28M |
| GCViT-Small | 0.720 | 0.901 | 51M |
| GCViT-Base | 0.731 | 0.907 | 90M |
| GCViT-Large | 0.734 | 0.913 | 202M |
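
Acc@1 and Acc@5 in the table are top-1 and top-5 accuracy: the fraction of images whose true label is the highest-scoring class, or among the five highest-scoring classes. The evaluation script is not shown here, so the following is only a minimal NumPy sketch of that metric; `topk_accuracy` is a hypothetical helper and the scores are made-up toy values, not model output.

```py
import numpy as np

def topk_accuracy(probs, labels, k=1):
    """Fraction of samples whose true label is among the k highest scores."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    # Indices of the k largest scores per row (order inside the top k is irrelevant).
    topk = np.argsort(probs, axis=1)[:, -k:]
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Toy scores for 4 samples over 6 classes (illustrative values only).
probs = np.array([
    [0.60, 0.10, 0.12, 0.08, 0.06, 0.04],  # true class 0 is ranked first
    [0.30, 0.40, 0.12, 0.08, 0.06, 0.04],  # true class 0 is ranked second
    [0.04, 0.06, 0.08, 0.12, 0.30, 0.40],  # true class 5 is ranked first
    [0.30, 0.25, 0.20, 0.15, 0.07, 0.03],  # true class 5 is ranked last
])
labels = np.array([0, 0, 5, 5])

acc1 = topk_accuracy(probs, labels, k=1)
acc5 = topk_accuracy(probs, labels, k=5)
print(acc1, acc5)
```

With real predictions, `probs` would be the `(num_images, 1000)` output of the model and `labels` the integer ImageNet class indices.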
## Installation
```bash
pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf
```
## Usage
Load a model with the following code:
```py
from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)
```
Simple code to check the model's prediction:
```py
import tensorflow as tf
from skimage.data import chelsea
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
```
Prediction:
```py
[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623),
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297),
('n02883205', 'bow_tie', 0.00042479983)]
```
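The `mode='torch'` argument of `preprocess_input` used above scales pixels to `[0, 1]` and then normalizes each channel with the ImageNet mean and standard deviation. As a rough illustration of that step, here is a NumPy sketch; `preprocess_torch_mode` is a hypothetical stand-in written for this example, not part of `gcvit` or Keras.

```py
import numpy as np

# Per-channel ImageNet statistics (RGB order) used by torch-style preprocessing.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_torch_mode(img_uint8):
    """Scale RGB pixels to [0, 1], then standardize each channel."""
    x = np.asarray(img_uint8, dtype=np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD

white = np.full((4, 4, 3), 255, dtype=np.uint8)  # dummy all-white image
black = np.zeros((4, 4, 3), dtype=np.uint8)      # dummy all-black image
out_white = preprocess_torch_mode(white)
out_black = preprocess_torch_mode(black)
print(out_white[0, 0], out_black[0, 0])
```

This is useful mainly when feeding images from a pipeline that does not go through `tf.keras.applications.imagenet_utils`, for example a custom `tf.data` or NumPy loader.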
For feature extraction:
```py
model = GCViTTiny(pretrain=True) # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)
```
Feature:
```py
(None, 512)
```
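One common use of such feature vectors is comparing images by cosine similarity. The sketch below assumes features of width 512 as printed above; the random array only stands in for real model output, and `cosine_similarity` is a hypothetical helper, not a `gcvit` function.

```py
import numpy as np

def cosine_similarity(a, b, eps=1e-12):
    """Pairwise cosine similarity between rows of a and rows of b."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    # L2-normalize each row; eps guards against division by zero.
    a = a / (np.linalg.norm(a, axis=-1, keepdims=True) + eps)
    b = b / (np.linalg.norm(b, axis=-1, keepdims=True) + eps)
    return a @ b.T

# Stand-in for `model(img).numpy()` on a batch of 3 images (random values).
rng = np.random.default_rng(0)
features = rng.normal(size=(3, 512))

sim = cosine_similarity(features, features)
print(sim.shape)
```

Each entry `sim[i, j]` lies in `[-1, 1]`, and the diagonal is 1 because every vector is maximally similar to itself.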
For feature map:
```py
model = GCViTTiny(pretrain=True) # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)
```
Feature map:
```py
(None, 7, 7, 512)
```
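A `7 x 7 x 512` feature map keeps a coarse spatial layout: with a `224 x 224` input, each grid cell nominally corresponds to a `32 x 32` pixel patch (224 / 7 = 32). The NumPy sketch below shows two things one might do with it, under those assumptions: average-pool it into one vector per image, and map a grid cell back to its nominal pixel box. The random array stands in for real output, `cell_to_box` is a hypothetical helper, and the box is only the nominal patch, not the true receptive field.

```py
import numpy as np

# Stand-in for `model.forward_features(img).numpy()` on a batch of 2 images.
feature_map = np.random.default_rng(0).normal(size=(2, 7, 7, 512))

# Global average pooling over the two spatial axes -> one vector per image.
pooled = feature_map.mean(axis=(1, 2))

def cell_to_box(row, col, input_size=224, grid=7):
    """Nominal (top, left, bottom, right) pixel box covered by a grid cell."""
    stride = input_size // grid
    return (row * stride, col * stride, (row + 1) * stride, (col + 1) * stride)

first_box = cell_to_box(0, 0)
last_box = cell_to_box(6, 6)
print(pooled.shape, first_box, last_box)
```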
## Live-Demo
* For a live demo of image classification & Grad-CAM with **ImageNet** weights, click <a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/Try%20on-Gradio-orange"></a>, powered by 🤗 Space and Gradio. Here's an example:
<a href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="image/gradio_demo.JPG" height=500></a>
## Example
For a working training example, check out these notebooks on **Google Colab** <a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> & **Kaggle** <a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>.
Here is the Grad-CAM result after training on the Flower Classification dataset:
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/flower_gradcam.PNG" height=500>
## To Do
- [ ] Segmentation Pipeline
- [x] New updated weights have been added.
- [x] Working training example in Colab & Kaggle.
- [x] GradCAM showcase.
- [x] Gradio Demo.
- [x] Build model with `tf.keras.Model`.
- [x] Port weights from official repo.
- [x] Support for `TPU`.
## Acknowledgement
* [GCVit](https://github.com/NVlabs/GCVit) (Official)
* [Swin-Transformer-TF](https://github.com/rishigami/Swin-Transformer-TF)
* [tfgcvit](https://github.com/shkarupa-alex/tfgcvit/tree/develop/tfgcvit)
* [keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_models)
## Citation
```bibtex
@article{hatamizadeh2022global,
title={Global Context Vision Transformers},
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
journal={arXiv preprint arXiv:2206.09959},
year={2022}
}
```