Neurolingua
commited on
Commit
•
d876bf1
1
Parent(s):
d9b7a74
Update app.py
Browse files
app.py
CHANGED
@@ -206,27 +206,49 @@ def initialize_chroma():
|
|
206 |
initialize_chroma()
|
207 |
|
208 |
def query_rag(query_text: str):
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
-
|
|
|
|
|
|
|
230 |
|
231 |
def download_file(url, extension):
|
232 |
try:
|
|
|
206 |
initialize_chroma()
|
207 |
|
208 |
def query_rag(query_text: str):
|
209 |
+
try:
|
210 |
+
# Ensure query_text is a string
|
211 |
+
if not isinstance(query_text, str):
|
212 |
+
raise ValueError("Query text must be a string.")
|
213 |
+
|
214 |
+
# Initialize the embedding function and Chroma DB
|
215 |
+
embedding_function = get_embedding_function()
|
216 |
+
db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
|
217 |
+
|
218 |
+
# Perform similarity search
|
219 |
+
results = db.similarity_search_with_score(query_text, k=5)
|
220 |
+
|
221 |
+
# Extract and clean context text
|
222 |
+
context_texts = [doc.page_content for doc, _score in results]
|
223 |
+
if not all(isinstance(text, str) for text in context_texts):
|
224 |
+
raise ValueError("All context texts must be strings.")
|
225 |
+
|
226 |
+
context_text = "\n\n---\n\n".join(context_texts)
|
227 |
+
|
228 |
+
# Create prompt
|
229 |
+
prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
|
230 |
+
prompt = prompt_template.format(context=context_text, question=query_text)
|
231 |
+
|
232 |
+
# Generate response using AI71
|
233 |
+
response = ''
|
234 |
+
for chunk in AI71(AI71_API_KEY).chat.completions.create(
|
235 |
+
model="tiiuae/falcon-180b-chat",
|
236 |
+
messages=[
|
237 |
+
{"role": "system", "content": "You are the best agricultural assistant. Remember to give a response in not more than 2 sentences."},
|
238 |
+
{"role": "user", "content": f'Answer the following query based on the given context: {prompt}'},
|
239 |
+
],
|
240 |
+
stream=True,
|
241 |
+
):
|
242 |
+
if chunk.choices[0].delta.content:
|
243 |
+
response += chunk.choices[0].delta.content
|
244 |
+
|
245 |
+
# Return cleaned response
|
246 |
+
return response.replace("###", '').replace('\nUser:', '')
|
247 |
|
248 |
+
except Exception as e:
|
249 |
+
# Log the error and return a user-friendly message
|
250 |
+
print(f"Error in query_rag: {e}")
|
251 |
+
return "Sorry, there was an error processing your query."
|
252 |
|
253 |
def download_file(url, extension):
|
254 |
try:
|