import torch
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity


def get_heatmap_figure(
    tensor_path="tensor.pt",
    *,
    cmap="viridis",
    title="Token Similarity Heatmap",
):
    """Return a heatmap figure of pairwise cosine similarity between tensor rows.

    The tensor stored at ``tensor_path`` is loaded, converted to a NumPy array,
    and passed to ``cosine_similarity``. That function compares every row with
    every other row, so the tensor must be 2-D with one vector per row.

    Args:
        tensor_path: Path to a file written with ``torch.save`` that holds a
            single tensor.
        cmap: Matplotlib colormap name used for the heatmap.
        title: Title placed on the axes.

    Returns:
        The ``matplotlib.figure.Figure`` holding the heatmap and a colorbar.
        It is created through ``plt.subplots``, so pyplot keeps a reference
        to it. Call ``plt.close(fig)`` when you are done with it.

    Raises:
        ValueError: Raised by ``cosine_similarity`` if the loaded data is not
            2-D.
    """
    # NOTE(review): torch.load unpickles the file, so only load files from
    # trusted sources.
    # map_location="cpu" lets CUDA-saved tensors load on CPU-only hosts, and
    # .numpy() only works on CPU tensors.
    tensor = torch.load(tensor_path, map_location="cpu").detach()

    # NumPy has no bfloat16 dtype, so .numpy() would raise. Upcast first.
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    weights = tensor.numpy()

    # Row-vs-row similarity: an (n_rows, n_rows) matrix.
    sim = cosine_similarity(weights)

    fig, ax = plt.subplots()
    image = ax.imshow(sim, cmap=cmap)
    fig.colorbar(image, ax=ax)
    ax.set_title(title)
    return fig