ruggsea committed on
Commit 173708b · 1 Parent(s): f4ce675
Files changed (2):
  1. app.py +23 -20
  2. data_preparation.ipynb +1 -1
app.py CHANGED
@@ -56,7 +56,8 @@ Settings.llm = Groq(
     model="llama3-8b-8192",
     api_key=os.getenv("GROQ_API_KEY"),
     max_tokens=6000,
-    context_window=6000
+    context_window=6000,
+    stream=True # Enable streaming
 )
 
 @st.cache_resource
@@ -92,13 +93,11 @@ def load_indices():
 index, vector_retriever, bm25_retriever, hybrid_retriever = load_indices()
 
 # Function to process chat with RAG
-def chat_with_rag(message, history, retriever):
-    # Get context from the index if RAG is enabled
+def chat_with_rag(message, history, retriever, response_placeholder):
+    """Modified to handle streaming"""
     if st.session_state.get('use_rag', True):
         nodes = retriever.retrieve(message)
-        # sort nodes by score
         nodes = sorted(nodes, key=lambda x: x.score, reverse=True)
-        # nodes up to slider value
         nodes = nodes[:st.session_state.get('num_chunks', 1)]
         context = "\n\n".join([node.text for node in nodes])
         system_prompt = f"""{st.session_state.system_prompt}
@@ -108,26 +107,29 @@ def chat_with_rag(message, history, retriever):
         {context}
         """
 
-        # Store sources in session state for this message
-        # Calculate the correct message index (total number of messages)
         message_index = len(st.session_state.messages)
         st.session_state.sources[message_index] = nodes
     else:
         system_prompt = st.session_state.system_prompt
         nodes = []
 
-    # Prepare messages for the API call
     messages = [ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)]
     for h in history:
         role = MessageRole.ASSISTANT if h["role"] == "assistant" else MessageRole.USER
         messages.append(ChatMessage(role=role, content=h["content"]))
     messages.append(ChatMessage(role=MessageRole.USER, content=message))
 
-    # Call Groq via LiteLLM (replace with LlamaIndex's Groq)
-    response = Settings.llm.chat(messages)
-    assistant_response = response.message.content
+    # Stream the response
+    response_text = ""
+    for response in Settings.llm.stream_chat(messages):
+        if response.delta is not None:
+            response_text += response.delta
+            # Update the placeholder with the accumulated text
+            response_placeholder.markdown(response_text + "▌")
 
-    return assistant_response
+    # Remove the cursor and return the complete response
+    response_placeholder.markdown(response_text)
+    return response_text
 
 # Move the title to the top, before tabs
 st.title("Freud Explorer")
@@ -272,14 +274,15 @@ with tab2:
 
 with chat_container:
     with st.chat_message("assistant"):
-        with st.spinner("Thinking..."):
-            response = chat_with_rag(
-                prompt,
-                st.session_state.messages[:-1],
-                hybrid_retriever if st.session_state.use_rag else None
-            )
-            st.markdown(response)
-            st.session_state.messages.append({"role": "assistant", "content": response})
+        # Create a placeholder for the streaming response
+        response_placeholder = st.empty()
+        response = chat_with_rag(
+            prompt,
+            st.session_state.messages[:-1],
+            hybrid_retriever if st.session_state.use_rag else None,
+            response_placeholder
+        )
+        st.session_state.messages.append({"role": "assistant", "content": response})
 
     st.rerun()
 
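Note on the streaming change in app.py: LlamaIndex LLMs expose stream_chat(), which yields ChatResponse chunks whose .delta field carries only the newest text, so the app can rewrite a Streamlit placeholder as tokens arrive instead of waiting on a blocking chat() call. A minimal, self-contained sketch of that pattern, assuming the same llama-index-llms-groq and Streamlit dependencies as the app (the prompt text and variable names below are illustrative, not part of the commit):

    # Standalone sketch of the streaming pattern used in chat_with_rag (illustrative only)
    import os
    import streamlit as st
    from llama_index.llms.groq import Groq
    from llama_index.core.llms import ChatMessage, MessageRole

    llm = Groq(model="llama3-8b-8192", api_key=os.getenv("GROQ_API_KEY"))
    placeholder = st.empty()  # placeholder that gets rewritten as tokens arrive

    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
        ChatMessage(role=MessageRole.USER, content="Summarise Freud's concept of the ego."),
    ]

    text = ""
    for chunk in llm.stream_chat(messages):  # generator of ChatResponse objects
        if chunk.delta:                       # .delta holds the newest tokens only
            text += chunk.delta
            placeholder.markdown(text + "▌")  # show partial answer with a cursor
    placeholder.markdown(text)                # final render without the cursor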
 
data_preparation.ipynb CHANGED
@@ -66,7 +66,7 @@
 "path=\"txt\\Freud_Complete_en.txt\"\n",
 "\n",
 "if os.path.exists(path):\n",
-" print(load_txt(path)[:1000])"
+" print(load_txt(path)[:1000]) "
 ]
 },
 {