Rohit Rajpoot
Add heatmap button to Streamlit UI
612ef10
raw
history blame contribute delete
467 Bytes
import torch
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
def get_heatmap_figure(tensor_path="tensor.pt"):
    """Build a cosine-similarity heatmap figure from a saved tensor.

    The object stored at *tensor_path* is loaded with ``torch.load``,
    converted to a NumPy array, and passed to ``cosine_similarity`` so that
    every row is compared with every other row. The resulting square matrix
    is drawn with ``imshow`` (``viridis`` colormap) next to a colorbar.

    Args:
        tensor_path: Location of the serialized tensor, handed straight to
            ``torch.load``. Presumably a 2-D embedding matrix with one row
            per token -- TODO confirm against the code that writes the file.

    Returns:
        The Matplotlib figure holding the heatmap. It is created through
        ``pyplot``, so it stays registered there until it is closed; callers
        that invoke this repeatedly (e.g. on every button click) should call
        ``plt.close(fig)`` once the figure has been rendered.
    """
    # SECURITY: torch.load is pickle-based and, depending on the installed
    # torch version's `weights_only` default, may execute code embedded in
    # the file. Only point this at trusted files.
    #
    # map_location="cpu" lets a tensor that was saved from a GPU be loaded on
    # a CPU-only host and guarantees .numpy() is called on a CPU tensor.
    loaded = torch.load(tensor_path, map_location="cpu")
    weights = loaded.detach().numpy()

    # Pairwise cosine similarity between the rows of `weights`.
    sim = cosine_similarity(weights)

    # Render the matrix as an image with an attached colour scale.
    fig, ax = plt.subplots()
    cax = ax.imshow(sim, cmap="viridis")
    fig.colorbar(cax, ax=ax)
    ax.set_title("Token Similarity Heatmap")
    return fig