File size: 467 Bytes
4adef30
 
 
 
612ef10
4adef30
b802c7c
4adef30
 
612ef10
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

def get_heatmap_figure(tensor_path="tensor.pt"):
    """Build a heatmap figure of pairwise cosine similarity between tensor rows.

    Args:
        tensor_path: Path to a file written by ``torch.save`` that holds a
            single tensor. Each row is treated as one vector (presumably one
            token embedding per row -- confirm against whoever writes the
            file). scikit-learn's ``cosine_similarity`` expects a 2-D array.

    Returns:
        A matplotlib ``Figure`` with one axes showing the (rows x rows)
        cosine-similarity matrix via ``imshow`` plus an attached colorbar.
        The figure is created through ``pyplot``, so it stays registered with
        pyplot until the caller closes it (e.g. ``plt.close(fig)``).
    """
    # Load onto the CPU: without map_location, a tensor saved from a GPU
    # process fails to load on a CPU-only machine, and .numpy() below only
    # works for CPU tensors anyway.
    # NOTE(review): torch.load unpickles the file -- only call this on trusted
    # files (consider weights_only=True where the installed torch supports it).
    weights = torch.load(tensor_path, map_location="cpu").detach()
    # NumPy has no bfloat16 dtype, so .numpy() rejects such tensors; upcast
    # them first. Other dtypes are passed through unchanged.
    if weights.dtype == torch.bfloat16:
        weights = weights.float()
    # Every row is compared against every row, giving a square matrix.
    sim = cosine_similarity(weights.numpy())
    fig, ax = plt.subplots()
    image = ax.imshow(sim, cmap="viridis")
    fig.colorbar(image, ax=ax)
    ax.set_title("Token Similarity Heatmap")
    return fig