S-MultiMAE / s_multimae /visualize_2d_posemb.py
thinh-researcher's picture
Init
6e9c433
import numpy as np
from torch import Tensor
import matplotlib.pyplot as plt
from s_multimae.model.multimae import build_2d_sincos_posemb
def visualize_2d_posemb():
NH, NW = 14, 14
dim_tokens = 768
colors = [
"Greys",
"Purples",
"Blues",
"Greens",
"Oranges",
"Reds",
"YlOrBr",
"YlOrRd",
"OrRd",
"PuRd",
"RdPu",
"BuPu",
"GnBu",
"PuBu",
"YlGnBu",
"PuBuGn",
"BuGn",
"YlGn",
]
pos_emb: Tensor = build_2d_sincos_posemb(NH, NW, dim_tokens)
pos_emb_numpy: np.ndarray = (
pos_emb.squeeze(0).permute(1, 2, 0).numpy()
) # 14 x 14 x 768
x = np.linspace(0, NH - 1, NH)
y = np.linspace(0, NW - 1, NW)
X, Y = np.meshgrid(x, y)
for color, i in zip(colors, range(0, pos_emb_numpy.shape[2], 100)):
ax = plt.axes(projection="3d")
Z = pos_emb_numpy[:, :, i]
# plt.imshow(Z, cmap='viridis')
# plt.savefig(f'posemb_visualization/test_{i}.png')
ax.plot_surface(
X,
Y,
Z,
# rstride=1, cstride=1,
cmap="viridis",
edgecolor="none",
)
plt.show()
plt.savefig(f"posemb_visualization/test_{i}.png")