Rohit Rajpoot committed on
Commit
d1a29d0
·
1 Parent(s): 6b3fcc5

Deploy transformer demo to Space

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. app.py +9 -8
  3. assist/main.py +6 -0
  4. assist/transformer_demo.py +50 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -1,18 +1,19 @@
1
  import streamlit as st
2
  from assist.chat import chat as embed_chat
3
  from assist.bayes_chat import bayes_chat
 
4
 
5
  st.title("RepoSage Chatbot Demo")
6
 
7
  question = st.text_input("Enter your question below:")
8
 
9
- col1, col2 = st.columns(2)
10
  with col1:
11
- if st.button("Ask Embedding RepoSage"):
12
- answer = embed_chat(question)
13
- st.write(answer)
14
-
15
  with col2:
16
- if st.button("Ask Bayesian RepoSage"):
17
- answer = bayes_chat(question)
18
- st.write(answer)
 
 
 
import streamlit as st
from assist.chat import chat as embed_chat
from assist.bayes_chat import bayes_chat
from assist.transformer_demo import transformer_next

# RepoSage demo UI: one text input fanned out to three Q&A backends,
# each behind its own button in a three-column layout.
st.title("RepoSage Chatbot Demo")

question = st.text_input("Enter your question below:")

col1, col2, col3 = st.columns(3)
with col1:
    # Embedding-similarity answerer.
    if st.button("Embedding Q&A"):
        st.write(embed_chat(question))
with col2:
    # Naive-Bayes answerer.
    if st.button("Bayesian Q&A"):
        st.write(bayes_chat(question))
with col3:
    # Single-block transformer next-token demo.
    if st.button("Transformer Demo"):
        st.write(transformer_next(question))
assist/main.py CHANGED
@@ -22,5 +22,11 @@ def chat(question: str = typer.Argument(..., help="Question to ask RepoSage")):
22
  response = chat_plugin(question)
23
  print(response)
24
 
 
 
 
 
 
 
25
  if __name__ == "__main__":
26
  app()
 
22
  response = chat_plugin(question)
23
  print(response)
24
 
25
@app.command()
def transform(prompt: str = typer.Argument(..., help="Prompt for transformer demo")):
    """Invoke the single-block transformer next-token demo.

    Prints the demo's user-facing prediction string for *prompt*.
    """
    # Imported lazily so the heavier torch dependency is only loaded
    # when this subcommand is actually invoked.
    from .transformer_demo import transformer_next
    print(transformer_next(prompt))


if __name__ == "__main__":
    app()
assist/transformer_demo.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from .chat import TOKEN2IDX, IDX2TOKEN # reuse your vocab maps
5
+ from .chat import WEIGHTS # reuse your embedding weights
6
+
7
class SingleTransformerBlock(nn.Module):
    """A single post-norm transformer encoder block.

    Multi-head self-attention followed by a position-wise feed-forward
    network, each wrapped in a residual connection and LayerNorm.
    Input and output shapes are identical: (batch, seq_len, embed_dim).
    """

    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        # batch_first=True -> tensors are (batch, seq, dim) throughout.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        # Conventional 4x expansion inside the feed-forward network.
        hidden = embed_dim * 4
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply attention then feed-forward; shape is preserved."""
        attended, _ = self.attn(x, x, x)   # self-attention: q = k = v = x
        x = self.norm1(x + attended)       # residual + norm
        x = self.norm2(x + self.ff(x))     # feed-forward, residual + norm
        return x
27
+
28
# Instantiate the demo model once at import time.  Its weights are randomly
# initialised (never trained), so predictions are illustrative only.
_EMB = torch.tensor(WEIGHTS, dtype=torch.float32)  # vocab embeddings, V x D
_block = SingleTransformerBlock(embed_dim=_EMB.size(1), num_heads=2)
_block.eval()  # inference mode: a no-op today, but safe if dropout is added


def transformer_next(prompt: str) -> str:
    """
    Predict a pseudo "next token" for *prompt*.

    Lower-cases and whitespace-splits the prompt, embeds every token found
    in the vocabulary, runs the sequence through one transformer block, and
    returns the vocab token whose embedding is most cosine-similar to the
    output vector at the last position.

    Always returns a user-facing message string, including when no prompt
    token is in the vocabulary.
    """
    tokens = prompt.lower().split()
    idxs = [TOKEN2IDX[t] for t in tokens if t in TOKEN2IDX]
    if not idxs:
        return "🤔 No known tokens to predict from."
    # Inference only: skip autograd graph construction entirely.
    with torch.no_grad():
        x = _EMB[idxs].unsqueeze(0)     # 1 x seq_len x D
        out = _block(x)                 # 1 x seq_len x D
        last = out[0, -1].unsqueeze(0)  # 1 x D
        # Cosine similarity of the final hidden state against every
        # vocab embedding; broadcasting gives one score per vocab entry.
        sims = nn.functional.cosine_similarity(last, _EMB)
        best = int(torch.argmax(sims))
    return f"🔮 Next‐token prediction: **{IDX2TOKEN[best]}**"