Sbnos committed
Commit 35fbc0b · verified · 1 Parent(s): d1601e3

chatgpt updates, reranking plus copy button plus streaming

Files changed (1)
app.py  +53 -16
app.py CHANGED
@@ -46,6 +46,15 @@ llmc = Together(
     together_api_key=os.environ['pilotikval']
 )
 
+# Load the reranking model
+reranker = Together(
+    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
+    temperature=0.2,
+    max_tokens=512,
+    top_k=10,
+    together_api_key=os.environ['pilotikval']
+)
+
 msgs = StreamlitChatMessageHistory(key="langchain_messages")
 memory = ConversationBufferMemory(chat_memory=msgs)
 
@@ -63,6 +72,26 @@ def store_chat_history(role: str, content: str):
     # Append the new message to the chat history
     chistory.append({"role": role, "content": content})
 
+def render_message_with_copy_button(role: str, content: str, key: str):
+    html_code = f"""
+    <div class="message" style="position: relative; padding-right: 40px;">
+        <div class="message-content">{content}</div>
+        <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0;">Copy</button>
+    </div>
+    <textarea id="{key}" style="display:none;">{content}</textarea>
+    <script>
+    function copyToClipboard(key) {{
+        var copyText = document.getElementById(key);
+        copyText.style.display = "block";
+        copyText.select();
+        document.execCommand("copy");
+        copyText.style.display = "none";
+        alert("Copied to clipboard");
+    }}
+    </script>
+    """
+    st.write(html_code, unsafe_allow_html=True)
+
 # Define the Streamlit app
 def app():
     with st.sidebar:
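Note on the copy button added above: st.write(html_code, unsafe_allow_html=True) renders the HTML, but Streamlit does not execute inline <script> tags injected this way, so the Copy button's onclick handler is unlikely to fire. A minimal sketch of an alternative, assuming the same markup from the commit, renders it through streamlit.components.v1.html, which runs the snippet (scripts included) inside a sandboxed iframe; the height value and the behaviour of alert() inside that iframe are assumptions that may need adjusting.

import streamlit.components.v1 as components

def render_message_with_copy_button(role: str, content: str, key: str):
    # Same markup as in the commit; only the rendering call at the end changes.
    html_code = f"""
    <div class="message" style="position: relative; padding-right: 40px;">
        <div class="message-content">{content}</div>
        <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0;">Copy</button>
    </div>
    <textarea id="{key}" style="display:none;">{content}</textarea>
    <script>
    function copyToClipboard(key) {{
        var copyText = document.getElementById(key);
        copyText.style.display = "block";
        copyText.select();
        document.execCommand("copy");
        copyText.style.display = "none";
        alert("Copied to clipboard");
    }}
    </script>
    """
    # components.html executes the snippet in an iframe, so the onclick handler can run.
    components.html(html_code, height=120, scrolling=True)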
@@ -123,9 +152,9 @@ def app():
     conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | llm
 
     st.header("Ask Away!")
-    for message in st.session_state.messages:
+    for i, message in enumerate(st.session_state.messages):
         with st.chat_message(message["role"]):
-            st.write(message["content"])
+            render_message_with_copy_button(message["role"], message["content"], key=f"message-{i}")
             store_chat_history(message["role"], message["content"])
 
     prompts2 = st.chat_input("Say something")
@@ -140,29 +169,37 @@ def app():
         with st.spinner("Thinking..."):
             for _ in range(3):  # Retry up to 3 times
                 try:
-                    response_generator = stream_conversational_qa_chain(
+                    responses = generate_multiple_responses(
+                        conversational_qa_chain,
                         {
                             "question": prompts2,
                             "chat_history": chistory,
-                        }
+                        },
+                        num_responses=5
                     )
-                    message_content = ""
-                    for response_part in response_generator:
-                        message_content += response_part
-                        st.chat_message("assistant", message_content)
-                    st.session_state.messages.append({"role": "assistant", "content": message_content})
+                    best_response = rerank_responses(reranker, responses)
+                    st.write(best_response)
+                    message = {"role": "assistant", "content": best_response}
+                    st.session_state.messages.append(message)
                     break
                 except Exception as e:
                     st.error(f"An error occurred: {e}")
                     time.sleep(2)  # Wait 2 seconds before retrying
 
-    def stream_conversational_qa_chain(inputs):
-        try:
-            response = conversational_qa_chain.invoke(inputs)
-            for part in response:
-                yield part
-        except Exception as e:
-            raise e
+def generate_multiple_responses(chain, inputs, num_responses=5):
+    responses = []
+    for _ in range(num_responses):
+        response = chain.invoke(inputs)
+        responses.append(response)
+    return responses
+
+def rerank_responses(reranker, responses):
+    scores = []
+    for response in responses:
+        score = reranker.invoke({"input": response})
+        scores.append(score)
+    best_response_idx = scores.index(max(scores))
+    return responses[best_response_idx]
 
 if __name__ == '__main__':
     app()
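Note on rerank_responses in the hunk above: the Together wrapper behind reranker is an ordinary completion LLM, and in LangChain its invoke() takes a prompt string and returns free-form text, so reranker.invoke({"input": response}) passes an unexpected input type, and max(scores) over text replies compares them lexicographically rather than by quality. A possible variant, sketched here as an assumption about the intent rather than the committed code (the question argument, prompt wording, and parsing are all assumptions), asks the judge model for an explicit numeric score and ranks on that:

import re

def rerank_responses(reranker, question, responses):
    # Sketch only: ask the judge LLM for a 1-10 rating per candidate and keep the best.
    def score(answer: str) -> float:
        prompt = (
            "Rate how well the answer addresses the question on a scale of 1 to 10. "
            "Reply with the number only.\n"
            f"Question: {question}\n"
            f"Answer: {answer}\n"
            "Score:"
        )
        raw = reranker.invoke(prompt)  # plain prompt string in, completion text out
        match = re.search(r"\d+(?:\.\d+)?", raw)  # pull the first number from the reply
        return float(match.group()) if match else 0.0

    return max(responses, key=score)

The call site would then pass the user prompt alongside the candidates, e.g. rerank_responses(reranker, prompts2, responses).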
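One more note on the new submit path: the fresh reply appears to be written with a bare st.write(best_response), while messages replayed from st.session_state.messages go through st.chat_message plus the copy-button helper, so the newest answer only gets its chat bubble and Copy button on the next rerun. A small sketch of how those two committed lines could reuse the same helper immediately (the key scheme is an assumption):

# Sketch: show the fresh assistant reply the same way history messages are shown,
# using a key that stays unique per message.
message = {"role": "assistant", "content": best_response}
st.session_state.messages.append(message)
with st.chat_message("assistant"):
    render_message_with_copy_button(
        "assistant",
        best_response,
        key=f"message-{len(st.session_state.messages) - 1}",
    )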