Rohit Rajpoot commited on
Commit
612ef10
·
1 Parent(s): b802c7c

Add heatmap button to Streamlit UI

Browse files
Files changed (2) hide show
  1. app.py +8 -4
  2. assist/heatmap.py +7 -6
app.py CHANGED
@@ -1,13 +1,17 @@
1
  import streamlit as st
2
  from assist.chat import chat as chat_plugin
 
3
 
4
  st.title("RepoSage Chatbot Demo")
5
 
6
- # 1) Change the label to make it obvious we're asking a question
7
- question = st.text_input("Ask RepoSage a question:", "")
 
 
 
8
 
9
- # 2) Only run when clicked
 
10
  if st.button("Ask RepoSage"):
11
- # 3) Pass that question into your stub
12
  response = chat_plugin(question)
13
  st.write(response)
 
1
  import streamlit as st
2
  from assist.chat import chat as chat_plugin
3
+ from assist.heatmap import get_heatmap_figure
4
 
5
  st.title("RepoSage Chatbot Demo")
6
 
7
+ # Tab or expander for heatmap
8
+ with st.expander("🔢 View Token Similarity Heatmap"):
9
+ if st.button("Show Heatmap"):
10
+ fig = get_heatmap_figure("tensor.pt")
11
+ st.pyplot(fig)
12
 
13
+ # Chat UI
14
+ question = st.text_input("Ask RepoSage a question:", "")
15
  if st.button("Ask RepoSage"):
 
16
  response = chat_plugin(question)
17
  st.write(response)
assist/heatmap.py CHANGED
@@ -2,13 +2,14 @@ import torch
2
  import matplotlib.pyplot as plt
3
  from sklearn.metrics.pairwise import cosine_similarity
4
 
5
- def show_heatmap(tensor_path="tensor.pt"):
6
  # Load embeddings
7
  weights = torch.load(tensor_path).detach().numpy()
8
  # Compute similarity
9
  sim = cosine_similarity(weights)
10
- # Plot
11
- plt.imshow(sim, cmap="viridis")
12
- plt.colorbar()
13
- plt.title("Token Similarity Heatmap")
14
- plt.show()
 
 
2
  import matplotlib.pyplot as plt
3
  from sklearn.metrics.pairwise import cosine_similarity
4
 
5
+ def get_heatmap_figure(tensor_path="tensor.pt"):
6
  # Load embeddings
7
  weights = torch.load(tensor_path).detach().numpy()
8
  # Compute similarity
9
  sim = cosine_similarity(weights)
10
+ # Build figure
11
+ fig, ax = plt.subplots()
12
+ cax = ax.imshow(sim, cmap="viridis")
13
+ fig.colorbar(cax, ax=ax)
14
+ ax.set_title("Token Similarity Heatmap")
15
+ return fig