"""Gradio demo: DINOv2 patch features -> PCA foreground segmentation/colouring."""
import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from torchvision import transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"

# DINOv2 ViT-B/14 backbone loaded through torch.hub.
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dino.eval()
dino.to(device)

# Shared estimators; both are re-fit (twice) on every call to ``app_fn``.
pca = PCA(n_components=3)
scaler = MinMaxScaler(clip=True)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Input is resized to an exact grid of 14x14-pixel patches so the model
# yields PATCH_H * PATCH_W patch tokens, which are reshaped back to a grid.
PATCH_SIZE = 14
PATCH_H = 40
PATCH_W = 40

preprocess = T.Compose([
    T.Resize((PATCH_SIZE * PATCH_H, PATCH_SIZE * PATCH_W)),
    T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])


def plot_img(img_array: np.ndarray) -> go.Figure:
    """Render ``img_array`` with ``px.imshow``, hiding the axis tick labels."""
    fig = px.imshow(img_array)
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
    )
    return fig


def _extract_patch_features(img: np.ndarray) -> np.ndarray:
    """Run DINOv2 on ``img`` and return its patch-token features.

    Assumes ``img`` is an (H, W, 3) array with values in 0..255, as delivered
    by ``gr.Image()`` in its default RGB mode -- TODO confirm for other modes.

    Returns:
        Array of shape (PATCH_H * PATCH_W, embed_dim).
    """
    img_tensor = torch.from_numpy(img).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = preprocess(img_tensor).unsqueeze(0).to(device)
    with torch.no_grad():
        out = dino.forward_features(img_tensor)
    # Drop the first (CLS) token so that only the patch tokens remain.
    features = out["x_prenorm"][:, 1:, :]
    return features.squeeze(0).cpu().numpy()


def _segment_and_colorize(
    features: np.ndarray, threshold: float, object_larger_than_bg: bool
) -> np.ndarray:
    """Split patches into background/foreground and colour the foreground.

    Pass 1: PCA over all patches; the first component, min-max scaled to
    [0, 1], is compared against ``threshold`` to mark background patches.
    Pass 2: PCA over the foreground patches only; its three scaled components
    become the RGB colour of each foreground patch. Background stays black.

    Returns:
        Float array of shape (PATCH_H, PATCH_W, 3) with values in [0, 1].

    Raises:
        gr.Error: if too few foreground patches remain for the second PCA.
    """
    first_component = scaler.fit_transform(pca.fit_transform(features))[:, 0]
    if object_larger_than_bg:
        is_bg = first_component > threshold
    else:
        is_bg = first_component < threshold
    is_fg = ~is_bg

    # PCA needs at least n_components samples; extreme thresholds can leave
    # (almost) no foreground patches, which would otherwise crash sklearn.
    if np.count_nonzero(is_fg) < pca.n_components:
        raise gr.Error(
            "Too few foreground patches at this threshold. Try a different "
            "threshold or toggle 'Object Larger than Background'."
        )

    fg_rgb = scaler.fit_transform(pca.fit_transform(features[is_fg]))
    rgb = np.zeros((PATCH_H * PATCH_W, 3))  # background patches stay black
    rgb[is_fg] = fg_rgb
    return rgb.reshape(PATCH_H, PATCH_W, 3)


def app_fn(
    img: np.ndarray, threshold: float, object_larger_than_bg: bool
) -> go.Figure:
    """Gradio callback: visualise PCA-coloured DINOv2 features of ``img``.

    Args:
        img: Uploaded image from ``gr.Image`` (``None`` if nothing uploaded).
        threshold: Cut-off in [0, 1] on the scaled first principal component.
        object_larger_than_bg: If True, patches *above* ``threshold`` are
            background; otherwise patches *below* it are.

    Returns:
        Plotly figure of the (PATCH_H, PATCH_W) RGB feature map.

    Raises:
        gr.Error: if no image is given or the threshold leaves too few
            foreground patches.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    features = _extract_patch_features(img)
    rgb = _segment_and_colorize(features, threshold, object_larger_than_bg)
    return plot_img(rgb)


if __name__ == "__main__":
    title = "DINOv2"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
            # A button's caption is its ``value``; ``label`` is not displayed.
            btn = gr.Button(value="Visualize")
        with gr.Row():
            img = gr.Image()
            fig_pca = gr.Plot(label="PCA Features")

        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])

        gr.Examples(
            examples=[
                ["assets/photo-1.jpg", 0.6, False],
                ["assets/photo-2.jpg", 0.7, True],
                ["assets/photo-3.jpg", 0.8, False],
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True,
        )

    demo.queue(max_size=5).launch()