Spaces: Build error
Commit 34d86b5 · Davidzhangyuanhan committed
1 Parent(s): cc8b572
Add application file
Files changed:
- Bamboo_v0-1_ViT-B16.pth.tar.convert +3 -0
- README copy.md +13 -0
- app.py +105 -0
- examples/Ferrari-F355.jpg +0 -0
- examples/basketball.jpg +0 -0
- examples/dribbler.jpg +0 -0
- examples/fratercula_arctica.jpg +0 -0
- examples/husky.jpg +0 -0
- examples/northern_oriole.jpg +0 -0
- examples/playing_mahjong.jpg +0 -0
- examples/taraxacum_erythrospermum.jpg +0 -0
- requirements.txt +6 -0
- timmvit.py +83 -0
- trainid2name.json +0 -0
Bamboo_v0-1_ViT-B16.pth.tar.convert
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937
+size 697651655
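Note that this file is a Git LFS pointer, not the checkpoint itself; the ~700 MB of actual weights live in LFS storage. A minimal sketch (not part of the commit) for checking that a fetched checkpoint matches the size and sha256 recorded in the pointer above:

```python
import hashlib
import os

CKPT = 'Bamboo_v0-1_ViT-B16.pth.tar.convert'
EXPECTED_SHA256 = 'a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937'
EXPECTED_SIZE = 697651655  # bytes, from the pointer above

sha = hashlib.sha256()
with open(CKPT, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):  # hash in 1 MiB chunks
        sha.update(chunk)

print('size matches:  ', os.path.getsize(CKPT) == EXPECTED_SIZE)
print('sha256 matches:', sha.hexdigest() == EXPECTED_SHA256)
```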
README copy.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Bamboo ViT-B16 Demo
+emoji: 🎋
+colorFrom: blue
+colorTo: blue
+sdk: gradio
+sdk_version: 3.0.17
+app_file: app.py
+pinned: false
+license: cc-by-4.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,105 @@
+import argparse
+import requests
+import gradio as gr
+import numpy as np
+import cv2
+import torch
+import torch.nn as nn
+from PIL import Image
+import torchvision
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import create_transform
+
+from timmvit import timmvit
+import json
+from timm.models.hub import download_cached_file
+from PIL import Image
+
+def pil_loader(filepath):
+    with Image.open(filepath) as img:
+        img = img.convert('RGB')
+    return img
+
+def build_transforms(input_size, center_crop=True):
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.ToPILImage(),
+        torchvision.transforms.Resize(input_size * 8 // 7),
+        torchvision.transforms.CenterCrop(input_size),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    return transform
+
+# Load human-readable labels for Bamboo.
+with open('./trainid2name.json') as f:
+    id2name = json.load(f)
+
+
+'''
+build model
+'''
+model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
+model.eval()
+
+'''
+code borrowed from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
+'''
+def show_cam_on_image(img: np.ndarray,
+                      mask: np.ndarray,
+                      use_rgb: bool = False,
+                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+    """Overlay the CAM mask on the image as a heatmap.
+    By default the heatmap is in BGR format.
+    :param img: The base image in RGB or BGR format.
+    :param mask: The cam mask.
+    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
+    :param colormap: The OpenCV colormap to be used.
+    :returns: The image with the cam overlay.
+    """
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+    if use_rgb:
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+    heatmap = np.float32(heatmap) / 255
+
+    if np.max(img) > 1:
+        raise Exception(
+            "The input image should be np.float32 in the range [0, 1]")
+
+    cam = 0.7 * heatmap + 0.3 * img
+    # cam = cam / np.max(cam)
+    return np.uint8(255 * cam)
+
+
+def recognize_image(image):
+    img_t = eval_transforms(image)
+    # compute output
+    output = model(img_t.unsqueeze(0))
+    prediction = output.softmax(-1).flatten()
+    _, top5_idx = torch.topk(prediction, 5)
+    return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
+
+eval_transforms = build_transforms(224)
+
+
+image = gr.inputs.Image()
+label = gr.outputs.Label(num_top_classes=5)
+
+gr.Interface(
+    description="Bamboo for Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo). Bamboo knows what an object is and what you are doing at a very fine granularity: e.g. fratercula arctica (fig. 5) and dribbler (fig. 2).",
+    fn=recognize_image,
+    inputs=["image"],
+    outputs=[
+        label,
+    ],
+    examples=[
+        ["./examples/playing_mahjong.jpg"],
+        ["./examples/dribbler.jpg"],
+        ["./examples/Ferrari-F355.jpg"],
+        ["./examples/northern_oriole.jpg"],
+        ["./examples/fratercula_arctica.jpg"],
+        ["./examples/husky.jpg"],
+        ["./examples/taraxacum_erythrospermum.jpg"],
+    ],
+).launch()
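app.py builds the model and preprocessing at import time and immediately calls gr.Interface(...).launch(). For a quick check outside the Space UI, here is a minimal sketch that mirrors the same pipeline without importing app.py (so no server is started); it assumes the LFS checkpoint has been pulled and that trainid2name.json and examples/ sit alongside it:

```python
# Local top-5 prediction sketch mirroring app.py (not part of the commit).
import json

import torch
import torchvision
from PIL import Image

from timmvit import timmvit

# Same preprocessing app.py builds, minus ToPILImage (we start from a PIL image here).
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224 * 8 // 7),   # resize so the 224x224 center crop fits
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

with open('./trainid2name.json') as f:
    id2name = json.load(f)

model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

img = Image.open('./examples/husky.jpg').convert('RGB')
with torch.no_grad():
    prediction = model(transform(img).unsqueeze(0)).softmax(-1).flatten()

probs, idx = torch.topk(prediction, 5)
for p, i in zip(probs.tolist(), idx.tolist()):
    print(f'{id2name[str(i)][0]}: {p:.4f}')
```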
examples/Ferrari-F355.jpg
ADDED
examples/basketball.jpg
ADDED
examples/dribbler.jpg
ADDED
examples/fratercula_arctica.jpg
ADDED
examples/husky.jpg
ADDED
examples/northern_oriole.jpg
ADDED
examples/playing_mahjong.jpg
ADDED
examples/taraxacum_erythrospermum.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+torchvision==0.11.2
+torch==1.10.1
+opencv-python-headless==4.5.3.56
+timm==0.4.12
+numpy
+
timmvit.py
ADDED
@@ -0,0 +1,83 @@
+# ------------------------------------------------------------------------
+# SenseTime VTAB
+# Copyright (c) 2021 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import timm
+import torch
+import copy
+import torch.nn as nn
+import torchvision
+import json
+from timm.models.hub import download_cached_file
+from PIL import Image
+
+
+class MyViT(nn.Module):
+    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
+        super().__init__()
+        print('initializing ViT model as backbone using ckpt:', pretrain_path)
+        self.model = timm.create_model('vit_base_patch16_224', checkpoint_path=pretrain_path, num_classes=num_classes)  # pretrained=True
+
+    # def forward_features(self, x):
+    #     x = self.model.patch_embed(x)
+    #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+    #     if self.model.dist_token is None:
+    #         x = torch.cat((cls_token, x), dim=1)
+    #     else:
+    #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+    #     x = self.model.pos_drop(x + self.model.pos_embed)
+    #     x = self.model.blocks(x)
+    #     x = self.model.norm(x)
+
+    #     return self.model.pre_logits(x[:, 0])
+
+    def forward(self, x):
+        x = self.model.forward(x)
+        return x
+
+
+def timmvit(**kwargs):
+    default_kwargs = {}
+    default_kwargs.update(**kwargs)
+    return MyViT(**default_kwargs)
+
+
+def build_transforms(input_size, center_crop=True):
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize(input_size * 8 // 7),
+        torchvision.transforms.CenterCrop(input_size),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    return transform
+
+def pil_loader(filepath):
+    with Image.open(filepath) as img:
+        img = img.convert('RGB')
+    return img
+
+def test_build():
+    with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
+        id2name = json.load(f)
+    img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
+    eval_transforms = build_transforms(224)
+    img_t = eval_transforms(img)
+    img_t = img_t[None, :]
+    model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
+    # image = torch.rand(1, 3, 224, 224)
+    output = model(img_t)
+    # import pdb;pdb.set_trace()
+    prediction = output.softmax(-1).flatten()
+    _, top5_idx = torch.topk(prediction, 5)
+    # import pdb;pdb.set_trace()
+    print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})
+
+
+if __name__ == '__main__':
+    test_build()
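For reference, MyViT above is simply timm's vit_base_patch16_224 with its classification head widened to 115,217 classes (one logit per Bamboo train id in trainid2name.json), and test_build() relies on hard-coded /mnt/lustre paths. A quick shape check that needs no checkpoint (weights are randomly initialized here, so only the tensor shapes are meaningful):

```python
import timm
import torch

# Same backbone/head configuration as MyViT, but randomly initialized.
m = timm.create_model('vit_base_patch16_224', num_classes=115217)
m.eval()

x = torch.rand(1, 3, 224, 224)  # one 224x224 RGB image, as produced by build_transforms
with torch.no_grad():
    logits = m(x)
print(logits.shape)  # torch.Size([1, 115217])
```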
trainid2name.json
ADDED
The diff for this file is too large to render.
See raw diff