mouaddb committed on
Commit 3ec9877
1 Parent(s): a57bd0c

Create app.py

Files changed (1):
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
import torch
import numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler

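# Pick the GPU when available; DINOv2 inference is slow on CPU.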
device = "cuda" if torch.cuda.is_available() else "cpu"

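# Load the pretrained DINOv2 ViT-B/14 backbone from torch.hub
# (weights are downloaded on the first run).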
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dino.eval()
dino.to(device)

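# A 3-component PCA projects patch features to RGB; MinMaxScaler with
# clip=True keeps the projected values inside [0, 1] for display.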
pca = PCA(n_components=3)
scaler = MinMaxScaler(clip=True)

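# Render an RGB array as a Plotly figure with the axis ticks hidden.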
def plot_img(img_array: np.ndarray) -> go.Figure:
    fig = px.imshow(img_array)
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False)
    )

    return fig


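# Extract DINOv2 patch features, split foreground from background with a
# threshold on the first PCA component, then color the foreground patches
# with a second, foreground-only PCA.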
def app_fn(
    img: np.ndarray,
    threshold: float,
    object_larger_than_bg: bool
) -> go.Figure:
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

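    # ViT-B/14 uses 14x14-pixel patches, so resizing to 14 * patch_h by
    # 14 * patch_w yields a patch_h x patch_w grid of feature tokens.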
    patch_h = 40
    patch_w = 40

    transform = T.Compose([
        T.Resize((14 * patch_h, 14 * patch_w)),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

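    # HWC uint8 image -> normalized CHW float tensor with a batch dimension.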
    img = torch.from_numpy(img).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = transform(img).unsqueeze(0).to(device)

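    # Single forward pass; "x_prenorm" holds the token sequence, and the
    # leading token is the [CLS] token, so keep only the patch tokens.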
    with torch.no_grad():
        out = dino.forward_features(img_tensor)

    features = out["x_prenorm"][:, 1:, :]
    features = features.squeeze(0)
    features = features.cpu().numpy()

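    # First PCA over all patches: the first component tends to separate the
    # object from the background.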
    pca_features = pca.fit_transform(features)
    pca_features = scaler.fit_transform(pca_features)

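    # The checkbox flips which side of the threshold counts as background.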
    if object_larger_than_bg:
        pca_features_bg = pca_features[:, 0] > threshold
    else:
        pca_features_bg = pca_features[:, 0] < threshold

    pca_features_fg = ~pca_features_bg

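    # Second PCA fit on the foreground patches only, rescaled to [0, 1] so
    # the three components can be used directly as RGB channels.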
    pca_features_fg_seg = pca.fit_transform(features[pca_features_fg])
    pca_features_fg_seg = scaler.fit_transform(pca_features_fg_seg)

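    # Paint background patches black, color foreground patches by their PCA
    # components, and reshape the flat patch list back into the patch grid.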
    pca_features_rgb = np.zeros((patch_h * patch_w, 3))
    pca_features_rgb[pca_features_bg] = 0
    pca_features_rgb[pca_features_fg] = pca_features_fg_seg
    pca_features_rgb = pca_features_rgb.reshape(patch_h, patch_w, 3)

    fig_pca = plot_img(pca_features_rgb)

    return fig_pca


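# Gradio UI: a threshold slider and a checkbox control the foreground split;
# the uploaded image and the PCA visualization are shown side by side.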
if __name__ == "__main__":
    title = "DINOv2"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
            btn = gr.Button("Visualize")
        with gr.Row():
            img = gr.Image()
            fig_pca = gr.Plot(label="PCA Features")

        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])
        examples = gr.Examples(
            examples=[
                ["assets/neca-the-cat.jpeg", 0.6, True],
                ["assets/dog.png", 0.7, False]
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True
        )

    demo.queue(max_size=5).launch()