Spaces:
Sleeping
Sleeping
import numpy as np | |
from torch import Tensor | |
import matplotlib.pyplot as plt | |
from s_multimae.model.multimae import build_2d_sincos_posemb | |
def visualize_2d_posemb(
    nh: int = 14,
    nw: int = 14,
    dim_tokens: int = 768,
    channel_step: int = 100,
    out_dir: str = "posemb_visualization",
    cmap: str = "viridis",
    show: bool = True,
) -> None:
    """Plot selected channels of a 2D sin-cos positional embedding as 3D surfaces.

    Builds the embedding with ``build_2d_sincos_posemb(nh, nw, dim_tokens)``.
    For every ``channel_step``-th channel (0, step, 2*step, ...) it:

    - draws that channel's ``nh x nw`` grid of values as a surface;
    - saves the figure to ``<out_dir>/test_<channel>.png``;
    - optionally displays it.

    Args:
        nh: Number of grid rows passed to the embedding builder.
        nw: Number of grid columns passed to the embedding builder.
        dim_tokens: Embedding dimension passed to the embedding builder.
        channel_step: Stride between plotted channel indices.
        out_dir: Directory for the PNG files; created if it does not exist.
        cmap: Matplotlib colormap name used for every surface.
        show: If True, call ``plt.show()`` after saving each figure.
    """
    import os

    pos_emb: Tensor = build_2d_sincos_posemb(nh, nw, dim_tokens)
    # NOTE(review): assumes the builder returns (1, dim_tokens, nh, nw), as the
    # squeeze/permute implies; the result here is (nh, nw, dim_tokens).
    pos_emb_numpy: np.ndarray = (
        pos_emb.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    )

    # Both grids have shape (nh, nw) so they line up with each channel slice:
    # X holds the column index, Y the row index. (np.meshgrid's default "xy"
    # indexing on row/column vectors would yield (nw, nh) and fail when
    # nh != nw.)
    Y, X = np.mgrid[0:nh, 0:nw]

    os.makedirs(out_dir, exist_ok=True)
    for i in range(0, pos_emb_numpy.shape[2], channel_step):
        # A fresh figure per channel, so surfaces never stack on shared axes.
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")
        ax.plot_surface(
            X,
            Y,
            pos_emb_numpy[:, :, i],
            cmap=cmap,
            edgecolor="none",
        )
        # Save before showing: once the window opened by plt.show() is
        # closed, the figure is gone and a later savefig writes a blank image.
        fig.savefig(os.path.join(out_dir, f"test_{i}.png"))
        if show:
            plt.show()
        plt.close(fig)