Neurolingua commited on
Commit
d876bf1
1 Parent(s): d9b7a74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -20
app.py CHANGED
@@ -206,27 +206,49 @@ def initialize_chroma():
206
  initialize_chroma()
207
 
208
  def query_rag(query_text: str):
209
- embedding_function = get_embedding_function()
210
- db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
211
- results = db.similarity_search_with_score(query_text, k=5)
212
- context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
213
-
214
- prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
215
- prompt = prompt_template.format(context=context_text, question=query_text)
216
-
217
- response = ''
218
- for chunk in AI71(AI71_API_KEY).chat.completions.create(
219
- model="tiiuae/falcon-180b-chat",
220
- messages=[
221
- {"role": "system", "content": "You are the best agricultural assistant. Remember to give a response in not more than 2 sentences."},
222
- {"role": "user", "content": f'''Answer the following query based on the given context: {prompt}'''},
223
- ],
224
- stream=True,
225
- ):
226
- if chunk.choices[0].delta.content:
227
- response += chunk.choices[0].delta.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
- return response.replace("###", '').replace('\nUser:', '')
 
 
 
230
 
231
  def download_file(url, extension):
232
  try:
 
206
  initialize_chroma()
207
 
208
  def query_rag(query_text: str):
209
+ try:
210
+ # Ensure query_text is a string
211
+ if not isinstance(query_text, str):
212
+ raise ValueError("Query text must be a string.")
213
+
214
+ # Initialize the embedding function and Chroma DB
215
+ embedding_function = get_embedding_function()
216
+ db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
217
+
218
+ # Perform similarity search
219
+ results = db.similarity_search_with_score(query_text, k=5)
220
+
221
+ # Extract and clean context text
222
+ context_texts = [doc.page_content for doc, _score in results]
223
+ if not all(isinstance(text, str) for text in context_texts):
224
+ raise ValueError("All context texts must be strings.")
225
+
226
+ context_text = "\n\n---\n\n".join(context_texts)
227
+
228
+ # Create prompt
229
+ prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
230
+ prompt = prompt_template.format(context=context_text, question=query_text)
231
+
232
+ # Generate response using AI71
233
+ response = ''
234
+ for chunk in AI71(AI71_API_KEY).chat.completions.create(
235
+ model="tiiuae/falcon-180b-chat",
236
+ messages=[
237
+ {"role": "system", "content": "You are the best agricultural assistant. Remember to give a response in not more than 2 sentences."},
238
+ {"role": "user", "content": f'Answer the following query based on the given context: {prompt}'},
239
+ ],
240
+ stream=True,
241
+ ):
242
+ if chunk.choices[0].delta.content:
243
+ response += chunk.choices[0].delta.content
244
+
245
+ # Return cleaned response
246
+ return response.replace("###", '').replace('\nUser:', '')
247
 
248
+ except Exception as e:
249
+ # Log the error and return a user-friendly message
250
+ print(f"Error in query_rag: {e}")
251
+ return "Sorry, there was an error processing your query."
252
 
253
  def download_file(url, extension):
254
  try: