0ndr3 committed on
Commit
e1cc806
·
verified ·
1 Parent(s): a95a961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # IMPORTS
 
2
  import warnings
3
  warnings.filterwarnings("ignore", message="Failed to load HostKeys")
4
  warnings.filterwarnings("ignore", message="The 'tuples' format for chatbot messages is deprecated")
@@ -17,7 +18,7 @@ from langchain_core.runnables import RunnableMap, RunnableLambda
17
  from langchain.memory import ConversationBufferMemory
18
  from langchain_groq import ChatGroq
19
 
20
- # ─── Secrets & Paths ────────────────────────────────────────────────────────────
21
 
22
  SFTP_HOST = os.getenv("SFTP_HOST")
23
  SFTP_USER = os.getenv("SFTP_USER")
@@ -26,11 +27,11 @@ SFTP_ALERTS_DIR = "/home/birkbeck/alerts"
26
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
27
  HISTORICAL_JSON = "data/big_prize_data.json"
28
 
29
- # ─── Chat Memory ────────────────────────────────────────────────────────────────
30
 
31
  memory = ConversationBufferMemory(memory_key="chat_history", input_key="question")
32
 
33
- # ─── 1) Build Historical Chroma DB ───────────────────────────────────────────────
34
 
35
  def build_chroma_db():
36
  with open(HISTORICAL_JSON) as f:
@@ -60,7 +61,7 @@ def build_chroma_db():
60
  emb = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
61
  return Chroma.from_documents(docs, emb)
62
 
63
- # ─── 2) Load Live Alerts via SFTP ───────────────────────────────────────────────
64
 
65
  def load_live_alerts():
66
  cnopts = pysftp.CnOpts()
@@ -111,22 +112,23 @@ def load_live_alerts():
111
  alerts.append(Document(page_content=content, metadata=md))
112
  return alerts
113
 
114
- # ─── 3) Retriever ───────────────────────────────────────────────────────────────
115
 
116
- db = build_chroma_db()
117
  live_docs = load_live_alerts()
118
 
119
  def combined_docs(q: str):
120
  hist = db.similarity_search(q, k=8)
121
  return hist + live_docs
122
 
123
- # ─── 4) Prompt + Filter Chain ───────────────────────────────────────────────────
124
 
125
  prompt = PromptTemplate(
126
  input_variables=["chat_history","context","question"],
127
  template="""
128
  You are **Rafael The Raffler**, a calm friendly expert in instant-win raffle analysis.
129
- If asked β€œwhat do you do?”, give a bullet list of your strengths (raffle timing, value insights, patterns).
 
130
  Reasoning Rules:
131
  1. **Interpreting β€œWhen”:** Whenever the user asks β€œWhen…?”, interpret that as β€œAt what tickets-sold count and percent did the prize win occur?” Do *not* give calendar dates or times.
132
  --- Conversation So Far ---
@@ -140,23 +142,21 @@ Reasoning Rules:
140
 
141
  def filter_docs(inputs):
142
  docs, q = inputs["documents"], inputs["question"].lower()
143
- # live‐prize query?
144
  if ("live" in q or "latest" in q or "recent" in q) and any(w in q for w in ("prize","raffle","won")):
145
  live = [d for d in docs if d.metadata["source"]=="recent and live"]
146
  if live:
147
  recent = max(live, key=lambda d: parser.isoparse(d.metadata["timestamp"]))
148
  return {"documents":[recent], "question":q}
149
- # threshold filter
150
  m = re.search(r"(?:above|over|greater than)\s*Β£?([\d,]+)", q)
151
  if m:
152
  thr = float(m.group(1).replace(",",""))
153
  docs = [d for d in docs if d.metadata["value"] > thr]
154
  return {"documents":docs, "question":q}
155
 
156
- # ─── Follow-up Question Rewriting ───────────────────────────────────────────────
157
 
158
- # This template will turn "How many big prizes...?" into
159
- # "In raffle 86, how many big prizes in total were won?"
160
  question_rewrite_template = PromptTemplate(
161
  input_variables=["chat_history","question"],
162
  template="""
@@ -172,30 +172,27 @@ Rewritten standalone question:"""
172
  )
173
 
174
  rewrite_chain = (
175
- # bundle history + raw question
176
  RunnableLambda(lambda q: {
177
  "chat_history": memory.load_memory_variables({})["chat_history"],
178
  "question": q
179
  })
180
- # build the rewrite prompt
181
  | RunnableLambda(lambda inp: question_rewrite_template.format(**inp))
182
- # call the LLM to rewrite
183
  | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
184
  | StrOutputParser()
185
  )
186
 
187
- # ─── 5) RAG + ChatGroq Chain (with rewrite) ────────────────────────────────────
188
 
189
  retrieval_chain = (
190
- # 1) Rewrite the question first
191
  rewrite_chain
192
- # 2) Retrieve docs against the rewritten question
193
  | RunnableMap({
194
  "documents": lambda rewritten_q: combined_docs(rewritten_q),
195
  "question": lambda rewritten_q: rewritten_q
196
  })
197
  | RunnableLambda(filter_docs)
198
- # 3) Build final inputs and truncate history
199
  | RunnableLambda(lambda d: {
200
  "chat_history": "\n".join(
201
  memory.load_memory_variables({})["chat_history"].splitlines()[-4:]
@@ -203,13 +200,13 @@ retrieval_chain = (
203
  "context": "\n".join(doc.page_content for doc in d["documents"]),
204
  "question": d["question"]
205
  })
206
- # 4) Format final prompt and call LLM
207
  | RunnableLambda(lambda inp: prompt.format(**inp))
208
  | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
209
  | StrOutputParser()
210
  )
211
 
212
- # ─── 6) Gradio Interface ───────────────────────────────────────────────────────
213
 
214
  WELCOME = """
215
  πŸ‘‹ **Welcome to Rafael The Raffler**
@@ -217,10 +214,22 @@ Your raffle-analysis assistant with RAG.
217
  Ask about raffle wins, ticket timing, prize values or the latest live raffle.
218
  """
219
 
 
 
 
 
 
 
220
  def gradio_chat(question: str) -> str:
221
- # run RAG chain
 
 
 
 
 
 
 
222
  answer = retrieval_chain.invoke(question)
223
- # store in memory for multi‐turn
224
  memory.save_context({"question": question}, {"answer": answer})
225
  return answer
226
 
 
1
  # IMPORTS
2
+
3
  import warnings
4
  warnings.filterwarnings("ignore", message="Failed to load HostKeys")
5
  warnings.filterwarnings("ignore", message="The 'tuples' format for chatbot messages is deprecated")
 
18
  from langchain.memory import ConversationBufferMemory
19
  from langchain_groq import ChatGroq
20
 
21
+ # SECRETS & PATHS
22
 
23
  SFTP_HOST = os.getenv("SFTP_HOST")
24
  SFTP_USER = os.getenv("SFTP_USER")
 
27
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
28
  HISTORICAL_JSON = "data/big_prize_data.json"
29
 
30
+ # CHAT MEMORY
31
 
32
  memory = ConversationBufferMemory(memory_key="chat_history", input_key="question")
33
 
34
+ # BUILD HISTORICAL CHROMA DB
35
 
36
  def build_chroma_db():
37
  with open(HISTORICAL_JSON) as f:
 
61
  emb = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
62
  return Chroma.from_documents(docs, emb)
63
 
64
+ # LOAD RECENT/LIVE ALERTS VIA SFTP
65
 
66
  def load_live_alerts():
67
  cnopts = pysftp.CnOpts()
 
112
  alerts.append(Document(page_content=content, metadata=md))
113
  return alerts
114
 
115
+ # RETRIEVER
116
 
117
+ db = build_chroma_db()
118
  live_docs = load_live_alerts()
119
 
120
def combined_docs(q: str):
    """Return retrieval context for *q*: top historical matches plus all live alerts.

    Historical documents come from the Chroma similarity index; the live
    SFTP alerts are always appended so the LLM can answer "latest" queries.
    """
    # Top-8 nearest historical documents from the vector store.
    historical = db.similarity_search(q, k=8)
    return [*historical, *live_docs]
123
 
124
+ # PROMPT + FILTER CHAIN
125
 
126
  prompt = PromptTemplate(
127
  input_variables=["chat_history","context","question"],
128
  template="""
129
  You are **Rafael The Raffler**, a calm friendly expert in instant-win raffle analysis.
130
+ **Only** describe your strengths (raffle timing, value insights, patterns) when the user explicitly asks β€œwhat do you do?” or "what you good at?".
131
+ If they merely greet you or ask anything else, do **not** list your strengthsβ€”just answer the question.
132
  Reasoning Rules:
133
  1. **Interpreting β€œWhen”:** Whenever the user asks β€œWhen…?”, interpret that as β€œAt what tickets-sold count and percent did the prize win occur?” Do *not* give calendar dates or times.
134
  --- Conversation So Far ---
 
142
 
143
def filter_docs(inputs):
    """Post-filter retrieved documents before prompting the LLM.

    Two special cases:
    * a "live/latest/recent prize" question keeps only the single most
      recent live alert (by ISO timestamp);
    * "above/over/greater than £N" drops documents whose value is <= N.
    Otherwise the documents pass through unchanged. Returns a dict with
    the (lower-cased) question and the surviving documents.
    """
    docs = inputs["documents"]
    q = inputs["question"].lower()

    # Live-prize question? Reduce to the newest live alert only.
    wants_live = any(t in q for t in ("live", "latest", "recent"))
    about_prize = any(t in q for t in ("prize", "raffle", "won"))
    if wants_live and about_prize:
        live = [d for d in docs if d.metadata["source"] == "recent and live"]
        if live:
            newest = max(live, key=lambda d: parser.isoparse(d.metadata["timestamp"]))
            return {"documents": [newest], "question": q}

    # "above/over/greater than £N" -> strict value threshold filter.
    match = re.search(r"(?:above|over|greater than)\s*£?([\d,]+)", q)
    if match:
        threshold = float(match.group(1).replace(",", ""))
        docs = [d for d in docs if d.metadata["value"] > threshold]
    return {"documents": docs, "question": q}
157
 
158
+ # FOLLOW-UP QUESTION REWRITING
159
 
 
 
160
  question_rewrite_template = PromptTemplate(
161
  input_variables=["chat_history","question"],
162
  template="""
 
172
  )
173
 
174
  rewrite_chain = (
 
175
  RunnableLambda(lambda q: {
176
  "chat_history": memory.load_memory_variables({})["chat_history"],
177
  "question": q
178
  })
 
179
  | RunnableLambda(lambda inp: question_rewrite_template.format(**inp))
 
180
  | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
181
  | StrOutputParser()
182
  )
183
 
184
+ # RAG + CHATGROQ CHAIN (WITH REWRITE)
185
 
186
  retrieval_chain = (
187
+ # 1. REWRITE QUESTION FIRST
188
  rewrite_chain
189
+ # 2. RETRIEVE DOCS AGAINST REWRITTEN QUESTION
190
  | RunnableMap({
191
  "documents": lambda rewritten_q: combined_docs(rewritten_q),
192
  "question": lambda rewritten_q: rewritten_q
193
  })
194
  | RunnableLambda(filter_docs)
195
+ # 3. BUILD FINAL INPUTS AND TRUNCATE HISTORY
196
  | RunnableLambda(lambda d: {
197
  "chat_history": "\n".join(
198
  memory.load_memory_variables({})["chat_history"].splitlines()[-4:]
 
200
  "context": "\n".join(doc.page_content for doc in d["documents"]),
201
  "question": d["question"]
202
  })
203
+ # 4. FORMAT FINAL PROMPT AND CALL LLM
204
  | RunnableLambda(lambda inp: prompt.format(**inp))
205
  | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
206
  | StrOutputParser()
207
  )
208
 
209
+ # GRADIO
210
 
211
  WELCOME = """
212
  πŸ‘‹ **Welcome to Rafael The Raffler**
 
214
  Ask about raffle wins, ticket timing, prize values or the latest live raffle.
215
  """
216
 
217
+ # GREETING HANDLING
218
+
219
def handle_greeting(question: str):
    """Return a canned reply when *question* is nothing but a salutation.

    Matches a whole-message "hi", "hello" or "hey" (any case, optional
    trailing punctuation). Returns None for anything else so the caller
    falls through to the RAG chain.
    """
    stripped = question.strip()
    if not re.match(r"^(hi|hello|hey)[.!?]*$", stripped, re.IGNORECASE):
        return None
    return "Hello! How can I help you with your raffle analysis today?"
222
+
223
def gradio_chat(question: str) -> str:
    """Answer one user turn: canned reply for a pure greeting, RAG otherwise.

    Every exchange is written into the conversation memory so follow-up
    questions can be rewritten against the chat history.
    """
    # handle_greeting returns a non-empty string or None, so `or`
    # short-circuits straight past the RAG chain for greetings.
    reply = handle_greeting(question) or retrieval_chain.invoke(question)
    memory.save_context({"question": question}, {"answer": reply})
    return reply
235