Update app.py

app.py CHANGED
@@ -1,4 +1,5 @@
 # IMPORTS
+
 import warnings
 warnings.filterwarnings("ignore", message="Failed to load HostKeys")
 warnings.filterwarnings("ignore", message="The 'tuples' format for chatbot messages is deprecated")
@@ -17,7 +18,7 @@ from langchain_core.runnables import RunnableMap, RunnableLambda
 from langchain.memory import ConversationBufferMemory
 from langchain_groq import ChatGroq

-#
+# SECRETS & PATHS

 SFTP_HOST = os.getenv("SFTP_HOST")
 SFTP_USER = os.getenv("SFTP_USER")
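All of these settings come from environment variables (Space secrets), so a missing one otherwise only surfaces later, mid-request. A minimal fail-fast sketch; the REQUIRED tuple and the startup check are illustrative additions, not part of this commit:

import os

REQUIRED = ("SFTP_HOST", "SFTP_USER", "GROQ_API_KEY")  # the variables read above
missing = [name for name in REQUIRED if not os.getenv(name)]
if missing:
    # fail at import time with a clear message instead of a None sneaking into pysftp/ChatGroq
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")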
@@ -26,11 +27,11 @@ SFTP_ALERTS_DIR = "/home/birkbeck/alerts"
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 HISTORICAL_JSON = "data/big_prize_data.json"

-#
+# CHAT MEMORY

 memory = ConversationBufferMemory(memory_key="chat_history", input_key="question")

-#
+# BUILD HISTORICAL CHROMA DB

 def build_chroma_db():
     with open(HISTORICAL_JSON) as f:
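ConversationBufferMemory with memory_key="chat_history" hands the whole transcript back as one string, which is what the chains below splitlines() and truncate. A quick round-trip under LangChain's default string buffer:

memory.save_context({"question": "hi"}, {"answer": "hello"})
print(memory.load_memory_variables({})["chat_history"])
# -> Human: hi
#    AI: hello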
@@ -60,7 +61,7 @@ def build_chroma_db():
     emb = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
     return Chroma.from_documents(docs, emb)

-#
+# LOAD RECENT/LIVE ALERTS VIA SFTP

 def load_live_alerts():
     cnopts = pysftp.CnOpts()
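load_live_alerts opens its SFTP session through pysftp.CnOpts, and the ignored "Failed to load HostKeys" warning at the top suggests host-key checking is relaxed. A sketch of that usual pattern; password auth via a hypothetical SFTP_PASS secret and the listdir loop are assumptions, since the function body is not shown in this hunk:

import os
import pysftp

cnopts = pysftp.CnOpts()
cnopts.hostkeys = None  # disable host-key verification (the source of the HostKeys warning)

with pysftp.Connection(os.getenv("SFTP_HOST"),
                       username=os.getenv("SFTP_USER"),
                       password=os.getenv("SFTP_PASS"),  # hypothetical secret name
                       cnopts=cnopts) as sftp:
    for name in sftp.listdir("/home/birkbeck/alerts"):
        print(name)  # each alert file becomes a Document in the real function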
@@ -111,22 +112,23 @@ def load_live_alerts():
         alerts.append(Document(page_content=content, metadata=md))
     return alerts

-#
+# RETRIEVER

-db
+db = build_chroma_db()
 live_docs = load_live_alerts()

 def combined_docs(q: str):
     hist = db.similarity_search(q, k=8)
     return hist + live_docs

-#
+# PROMPT + FILTER CHAIN

 prompt = PromptTemplate(
     input_variables=["chat_history","context","question"],
     template="""
 You are **Rafael The Raffler**, a calm friendly expert in instant-win raffle analysis.
-
+**Only** describe your strengths (raffle timing, value insights, patterns) when the user explicitly asks "what do you do?" or "what you good at?".
+If they merely greet you or ask anything else, do **not** list your strengths; just answer the question.
 Reasoning Rules:
 1. **Interpreting "When":** Whenever the user asks "When...?", interpret that as "At what tickets-sold count and percent did the prize win occur?" Do *not* give calendar dates or times.
 --- Conversation So Far ---
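combined_docs bolts every live alert onto the eight nearest historical hits, so live documents reach the LLM regardless of vector similarity. A toy illustration of that merge; the two hand-built Documents stand in for the real Chroma index and SFTP alerts:

from langchain_core.documents import Document

hist_hits = [Document(page_content="raffle 86: £1,000 won at 74% sold",
                      metadata={"source": "historical", "value": 1000})]
live_docs = [Document(page_content="live: £500 instant win",
                      metadata={"source": "recent and live"})]

def combined(hist, live):
    # mirrors combined_docs: similarity hits first, then every live alert
    return hist + live

print(len(combined(hist_hits, live_docs)))  # -> 2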
@@ -140,23 +142,21 @@ Reasoning Rules:

 def filter_docs(inputs):
     docs, q = inputs["documents"], inputs["question"].lower()
-    #
+    # RECENT/LIVE
     if ("live" in q or "latest" in q or "recent" in q) and any(w in q for w in ("prize","raffle","won")):
         live = [d for d in docs if d.metadata["source"]=="recent and live"]
         if live:
             recent = max(live, key=lambda d: parser.isoparse(d.metadata["timestamp"]))
             return {"documents":[recent], "question":q}
-    #
+    # THRESHOLD
     m = re.search(r"(?:above|over|greater than)\s*£?([\d,]+)", q)
     if m:
         thr = float(m.group(1).replace(",",""))
         docs = [d for d in docs if d.metadata["value"] > thr]
     return {"documents":docs, "question":q}

-#
+# FOLLOW-UP QUESTION REWRITING

-# This template will turn "How many big prizes...?" into
-# "In raffle 86, how many big prizes in total were won?"
 question_rewrite_template = PromptTemplate(
     input_variables=["chat_history","question"],
     template="""
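The THRESHOLD branch pulls a numeric floor out of phrases like "over £2,500" in the lowercased question. The same pattern, checked in isolation:

import re

m = re.search(r"(?:above|over|greater than)\s*£?([\d,]+)", "any prizes over £2,500?")
assert m is not None
print(float(m.group(1).replace(",", "")))  # -> 2500.0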
@@ -172,30 +172,27 @@ Rewritten standalone question:"""
 )

 rewrite_chain = (
-    # bundle history + raw question
     RunnableLambda(lambda q: {
         "chat_history": memory.load_memory_variables({})["chat_history"],
         "question": q
     })
-    # build the rewrite prompt
     | RunnableLambda(lambda inp: question_rewrite_template.format(**inp))
-    # call the LLM to rewrite
     | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
     | StrOutputParser()
 )

-#
+# RAG + CHATGROQ CHAIN (WITH REWRITE) ────────────────────────────────────

 retrieval_chain = (
-    # 1
+    # 1. REWRITE QUESTION FIRST
     rewrite_chain
-    # 2
+    # 2. RETRIEVE DOCS AGAINST REWRITTEN QUESTION
     | RunnableMap({
         "documents": lambda rewritten_q: combined_docs(rewritten_q),
         "question": lambda rewritten_q: rewritten_q
     })
     | RunnableLambda(filter_docs)
-    # 3
+    # 3. BUILD FINAL INPUTS AND TRUNCATE HISTORY
     | RunnableLambda(lambda d: {
         "chat_history": "\n".join(
             memory.load_memory_variables({})["chat_history"].splitlines()[-4:]
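Step 2's RunnableMap fans the single rewritten question out into a dict, so filter_docs receives both the retrieved documents and the question text. A self-contained toy version of that fan-out:

from langchain_core.runnables import RunnableMap, RunnableLambda

toy = RunnableMap({
    "documents": lambda q: [q.upper()],  # stand-in for combined_docs(q)
    "question": lambda q: q,
}) | RunnableLambda(lambda d: f"{len(d['documents'])} doc(s) for: {d['question']}")

print(toy.invoke("latest prize"))  # -> 1 doc(s) for: latest prize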
@@ -203,13 +200,13 @@ retrieval_chain = (
         "context": "\n".join(doc.page_content for doc in d["documents"]),
         "question": d["question"]
     })
-    # 4
+    # 4. FORMAT FINAL PROMPT AND CALL LLM
     | RunnableLambda(lambda inp: prompt.format(**inp))
     | ChatGroq(api_key=GROQ_API_KEY, model="llama3-8b-8192")
     | StrOutputParser()
 )

-#
+# GRADIO

 WELCOME = """
 **Welcome to Rafael The Raffler**
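Step 4 renders the PromptTemplate to a plain string before the ChatGroq call; PromptTemplate.format is ordinary named substitution, so it can be smoke-tested without the LLM (the sample values below are made up):

text = prompt.format(
    chat_history="Human: hi\nAI: hello",
    context="raffle 86: £1,000 won at 74% sold",
    question="When was the £1,000 prize won?",
)
print(text[:60])  # start of the rendered Rafael prompt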
@@ -217,10 +214,22 @@ Your raffle-analysis assistant with RAG.
 Ask about raffle wins, ticket timing, prize values or the latest live raffle.
 """

+# GREETING HANDLING
+
+def handle_greeting(question: str):
+    if re.match(r'^(hi|hello|hey)[.!?]*$', question.strip(), re.I):
+        return "Hello! How can I help you with your raffle analysis today?"
+
 def gradio_chat(question: str) -> str:
-    #
+    # 1. GREETING ONLY?
+    greet = handle_greeting(question)
+    if greet:
+        # SAVE GREETING
+        memory.save_context({"question": question}, {"answer": greet})
+        return greet
+
+    # 2. OTHERWISE > RAG CHAIN
     answer = retrieval_chain.invoke(question)
-    # store in memory for multi-turn
     memory.save_context({"question": question}, {"answer": answer})
     return answer

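handle_greeting returns None for anything that is not a bare greeting, so the if greet: guard falls through to the RAG chain. The Gradio wiring itself is not part of this hunk; a minimal sketch of how gradio_chat might be mounted (the gr.Interface arguments are assumptions):

import gradio as gr

assert handle_greeting("Hey!")                   # bare greeting -> canned reply
assert handle_greeting("hey, who won?") is None  # mixed input -> falls through to RAG

demo = gr.Interface(fn=gradio_chat, inputs="text", outputs="text",
                    title="Rafael The Raffler", description=WELCOME)
demo.launch()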