File size: 397 Bytes
4adef30
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

def show_heatmap(tensor_path: str = "tensor.pt") -> None:
    """Display a heatmap of pairwise cosine similarities for a saved tensor.

    The tensor stored at ``tensor_path`` is loaded, the cosine similarity
    between every pair of its rows is computed, and the resulting square
    matrix is rendered with matplotlib.

    Args:
        tensor_path: Location of a file written by ``torch.save`` holding a
            single tensor. Presumably a 2-D embedding matrix with one row per
            token (going by the plot title) -- confirm against the producer.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    # NOTE(security): torch.load unpickles the file; only point this at
    # tensor files from a trusted source.
    # map_location="cpu" lets a tensor that was saved from an accelerator be
    # loaded on a machine without that device.
    loaded = torch.load(tensor_path, map_location="cpu")

    # .numpy() refuses tensors that require grad (e.g. a saved nn.Parameter)
    # and tensors living off-CPU, so strip autograd tracking and force host
    # memory before converting.
    weights = loaded.detach().cpu().numpy()

    # Row-by-row cosine similarity -> square (n_rows x n_rows) matrix.
    sim = cosine_similarity(weights)

    # Plot
    plt.imshow(sim, cmap="viridis")
    plt.colorbar()
    plt.title("Token Similarity Heatmap")
    plt.show()