nagesh5 committed
Commit 108fee4 · verified · 1 parent: 2651933

Upload app.py with huggingface_hub

Files changed (1):
  1. app.py +752 -0
app.py ADDED
@@ -0,0 +1,752 @@
+
+ # Import necessary libraries
+ import os  # Interacting with the operating system (reading/writing files)
+ import chromadb  # High-performance vector database for storing/querying dense vectors
+ from dotenv import load_dotenv  # Loading environment variables from a .env file
+ import json  # Parsing and handling JSON data
+
+ # LangChain imports
+ from langchain_core.documents import Document  # Document data structure
+ from langchain_core.runnables import RunnablePassthrough  # LangChain core utility for running pipelines
+ from langchain_core.output_parsers import StrOutputParser  # String output parser
+ from langchain.prompts import ChatPromptTemplate  # Template for chat prompts
+ from langchain.chains.query_constructor.base import AttributeInfo  # Base classes for query construction
+ from langchain.retrievers.self_query.base import SelfQueryRetriever  # Self-querying retriever
+ from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker  # Document compressors
+ from langchain.retrievers import ContextualCompressionRetriever  # Contextual compression retriever
+
+ # LangChain community & experimental imports
+ from langchain_community.vectorstores import Chroma  # Chroma vector store integration
+ from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader  # Document loaders for PDFs
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder  # Cross-encoders from Hugging Face
+ from langchain_experimental.text_splitter import SemanticChunker  # Experimental semantic text splitter
+ from langchain.text_splitter import (
+     CharacterTextSplitter,  # Splitting text by characters
+     RecursiveCharacterTextSplitter  # Recursive splitting of text by characters
+ )
+ from langchain_core.tools import tool
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
+
+ # LangChain OpenAI imports
+ from langchain_openai import ChatOpenAI, AzureOpenAIEmbeddings, AzureChatOpenAI  # OpenAI chat models and embeddings
+ from langchain.embeddings.openai import OpenAIEmbeddings  # OpenAI embeddings for text vectors
+
+ # LlamaParse & LlamaIndex imports
+ from llama_parse import LlamaParse  # Document parsing library
+ from llama_index.core import Settings, SimpleDirectoryReader  # Core functionalities of LlamaIndex
+
+ # LangGraph import
+ from langgraph.graph import StateGraph, END, START  # State graph for the agent workflow
+
+ # Pydantic import
+ from pydantic import BaseModel  # Pydantic for data validation
+
+ # Typing imports
+ from typing import Dict, List, Tuple, Any, TypedDict  # Python typing for function annotations
+
+ # Other utilities
+ import numpy as np  # NumPy for numerical operations
+ from groq import Groq
+ from mem0 import MemoryClient
+ import streamlit as st
+ from datetime import datetime
+
+ #====================================SETUP=====================================#
+ # Fetch secrets from Hugging Face Spaces (exposed as environment variables)
+ api_key = os.environ['API_KEY']
+ endpoint = os.environ['OPENAI_API_BASE']
+ llama_api_key = os.environ['LLAMA_API_KEY']
+ MEM0_api_key = os.environ['mem0']
+
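+ # For local development, the same secrets can come from a .env file via the
+ # imported load_dotenv (a sketch, assuming a .env with these keys exists):
+ #   load_dotenv()  # populates os.environ from .env
+ #   api_key = os.environ['API_KEY']
+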
+ # Initialize the OpenAI embedding function for Chroma
+ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
+     api_base=endpoint,  # API base endpoint
+     api_key=api_key,  # API key for authentication
+     model_name='text-embedding-ada-002'  # Embedding model served at the endpoint
+ )
+
+ # This initializes the OpenAI embedding function for the Chroma vector store, using the provided endpoint and API key.
+
+ # Initialize the OpenAI embeddings for LangChain
+ embedding_model = OpenAIEmbeddings(
+     openai_api_base=endpoint,
+     openai_api_key=api_key,
+     model='text-embedding-ada-002'
+ )
+
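+ # Quick sanity check (hypothetical usage): text-embedding-ada-002 returns
+ # 1536-dimensional vectors, so a smoke test could be:
+ #   vec = embedding_model.embed_query("vitamin D deficiency")
+ #   assert len(vec) == 1536
+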
+ # Initialize the Chat OpenAI model
+ llm = ChatOpenAI(
+     openai_api_base=endpoint,
+     openai_api_key=api_key,
+     model="gpt-4o-mini",
+     streaming=False
+ )
+ # This initializes the Chat OpenAI model with the provided endpoint and API key; streaming is disabled so responses arrive in full.
+
+ # Set the LLM and embedding model in the LlamaIndex settings.
+ Settings.llm = llm
+ Settings.embed_model = embedding_model
+
+ #================================Creating Langgraph agent======================#
+
+ class AgentState(TypedDict):
+     query: str  # The current user query
+     expanded_query: str  # The expanded version of the user query
+     context: List[Dict[str, Any]]  # Retrieved documents (content and metadata)
+     response: str  # The generated response to the user query
+     precision_score: float  # The precision score of the response
+     groundedness_score: float  # The groundedness score of the response
+     groundedness_loop_count: int  # Counter for groundedness refinement loops
+     precision_loop_count: int  # Counter for precision refinement loops
+     feedback: str  # Suggestions for improving the response
+     query_feedback: str  # Suggestions for improving the expanded query
+     groundedness_check: bool  # Whether the groundedness check passed
+     loop_max_iter: int  # Maximum number of refinement iterations
+
+ def expand_query(state):
+     """
+     Expands the user query to improve retrieval of nutrition disorder-related information.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the user query.
+
+     Returns:
+         Dict: The updated state with the expanded query.
+     """
+     print("---------Expanding Query---------")
+     system_message = """
+     You are a domain expert assisting in answering questions related to Nutritional Disorders.
+     Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
+     or common synonyms for key words in the question, make sure to return multiple versions \
+     of the query with the different phrasings.
+
+     If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.
+
+     If there are acronyms or words you are not familiar with, do not try to rephrase them.
+
+     Return only 3 versions of the question as a list.
+     Generate only a list of questions. Do not mention anything before or after the list.
+
+     Question:
+     {query}
+     """
+
+     expand_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Expand this query: {query} using the feedback: {query_feedback}")
+     ])
+
+     chain = expand_prompt | llm | StrOutputParser()
+     expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
+     print("expanded_query", expanded_query)
+     state["expanded_query"] = expanded_query
+     return state
+
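+ # Illustrative only (not actual model output): for "What causes rickets?" the
+ # chain typically returns three paraphrases, e.g.
+ #   ["What causes rickets?",
+ #    "What leads to the development of rickets?",
+ #    "What are the causes of vitamin D deficiency rickets?"]
+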
+
+ # Initialize the Chroma vector store for retrieving documents
+ vector_store = Chroma(
+     collection_name="nutritional_hypotheticals",
+     persist_directory="./nutritional_db",
+     embedding_function=embedding_model
+ )
+
+ # Create a retriever from the vector store
+ retriever = vector_store.as_retriever(
+     search_type='similarity',
+     search_kwargs={'k': 3}
+ )
+
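+ # A sketch of what the retriever hands back (shape only, hypothetical values):
+ #   docs = retriever.invoke("What causes iron-deficiency anemia?")
+ #   docs[0].page_content  # chunk text
+ #   docs[0].metadata      # e.g. {"source": "...", "page": 3}
+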
+ def retrieve_context(state):
+     """
+     Retrieves context from the vector store using the expanded or original query.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and expanded query.
+
+     Returns:
+         Dict: The updated state with the retrieved context.
+     """
+     print("---------retrieve_context---------")
+     query = state['expanded_query']  # Use the expanded query for retrieval
+     #print("Query used for retrieval:", query)  # Debugging: Print the query
+
+     # Retrieve documents from the vector store
+     docs = retriever.invoke(query)
+     print("Retrieved documents:", docs)  # Debugging: Print the raw docs object
+
+     # Extract both page_content and metadata from each document
+     context = [
+         {
+             "content": doc.page_content,  # The actual content of the document
+             "metadata": doc.metadata  # The metadata (e.g., source, page number, etc.)
+         }
+         for doc in docs
+     ]
+     state['context'] = context  # Store the retrieved context in the state
+     print("Extracted context with metadata:", context)  # Debugging: Print the extracted context
+     #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
+     return state
+
+
+ def craft_response(state: Dict) -> Dict:
+     """
+     Generates a response using the retrieved context, focusing on nutrition disorders.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and retrieved context.
+
+     Returns:
+         Dict: The updated state with the generated response.
+     """
+     print("---------craft_response---------")
+     system_message = """You are a helpful domain expert in nutritional disorders, well versed in crafting responses from the information \
+     you have gathered from various sources. Summarize the information you have gathered and present it in a coherent manner.
+     Add citations to the sources using numbers in the response and end the response with the list of source links used."""
+
+     response_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
+     ])
+
+     chain = response_prompt | llm | StrOutputParser()
+     response = chain.invoke({
+         "query": state['query'],
+         "context": "\n".join([doc["content"] for doc in state['context']]),
+         "feedback": state['feedback']  # Feed refinement suggestions (empty on the first pass) back into the prompt
+     })
+     state['response'] = response
+     print("intermediate response: ", response)
+
+     return state
+
+
+ def score_groundedness(state: Dict) -> Dict:
+     """
+     Checks whether the response is grounded in the retrieved context.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the response and context.
+
+     Returns:
+         Dict: The updated state with the groundedness score.
+     """
+     print("---------check_groundedness---------")
+     system_message = '''You are an expert in evaluating if a response is grounded in the provided context.
+     You will be given a context and a response.
+     Your task is to evaluate the groundedness of the response based on the context.
+     Return a score between 0 and 1, where 1 means the response is fully grounded in the context and 0 means the response is not grounded in the context.
+     Return only the score as a float.'''
+
+     groundedness_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
+     ])
+
+     chain = groundedness_prompt | llm | StrOutputParser()
+     groundedness_score = float(chain.invoke({
+         "context": "\n".join([doc["content"] for doc in state['context']]),
+         "response": state['response']  # The response to evaluate
+     }))
+     print("groundedness_score: ", groundedness_score)
+     state['groundedness_loop_count'] += 1
+     print("#########Groundedness Incremented###########")
+     state['groundedness_score'] = groundedness_score
+
+     return state
+
+
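+ # Note: float() assumes the LLM judge returns a bare number. A more defensive
+ # parse (a sketch, not wired in above) could extract the first numeric token:
+ #   import re
+ #   m = re.search(r"\d*\.?\d+", raw)
+ #   score = float(m.group()) if m else 0.0
+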
+ def check_precision(state: Dict) -> Dict:
+     """
+     Checks whether the response precisely addresses the user's query.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and response.
+
+     Returns:
+         Dict: The updated state with the precision score.
+     """
+     print("---------check_precision---------")
+     system_message = '''You are an expert in evaluating if a response is precise and directly addresses the user's query.
+     You will be given a user query and a response.
+     Your task is to evaluate how well the response answers the user's query.
+     Return a score between 0 and 1, where 1 means the response is perfectly precise and directly answers the query, and 0 means the response is not at all precise or relevant to the query.
+     Return only the score as a float.'''
+
+     precision_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
+     ])
+
+     chain = precision_prompt | llm | StrOutputParser()
+     precision_score = float(chain.invoke({
+         "query": state['query'],
+         "response": state['response']  # The response to evaluate
+     }))
+     state['precision_score'] = precision_score
+     print("precision_score:", precision_score)
+     state['precision_loop_count'] += 1
+     print("#########Precision Incremented###########")
+     return state
+
+
+ def refine_response(state: Dict) -> Dict:
+     """
+     Suggests improvements for the generated response.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and response.
+
+     Returns:
+         Dict: The updated state with response refinement suggestions.
+     """
+     print("---------refine_response---------")
+
+     system_message = '''You are an expert in providing constructive feedback on a generated response.
+     You will be given a user query and a response.
+     Your task is to identify potential gaps, ambiguities, or missing details in the response and suggest improvements to enhance accuracy and completeness.
+     Do not rewrite the response, only provide suggestions for improvement.'''
+
+     refine_response_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Query: {query}\nResponse: {response}\n\n"
+                  "What improvements can be made to enhance accuracy and completeness?")
+     ])
+
+     chain = refine_response_prompt | llm | StrOutputParser()
+
+     # Store response suggestions in a structured format
+     feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
+     print("feedback: ", feedback)
+     print(f"State: {state}")
+     state['feedback'] = feedback
+     return state
+
+
+ def refine_query(state: Dict) -> Dict:
+     """
+     Suggests improvements for the expanded query.
+
+     Args:
+         state (Dict): The current state of the workflow, containing the query and expanded query.
+
+     Returns:
+         Dict: The updated state with query refinement suggestions.
+     """
+     print("---------refine_query---------")
+     system_message = '''You are an expert in providing constructive feedback on a generated expanded query.
+     You will be given the original user query and the expanded query.
+     Your task is to identify missing details, specific keywords, or scope refinements that can enhance search precision.
+     Do not rewrite the expanded query, only provide structured suggestions for improvement.'''
+
+     refine_query_prompt = ChatPromptTemplate.from_messages([
+         ("system", system_message),
+         ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
+                  "What improvements can be made for a better search?")
+     ])
+
+     chain = refine_query_prompt | llm | StrOutputParser()
+
+     # Store refinement suggestions without modifying the original expanded query
+     query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
+     print("query_feedback: ", query_feedback)
+     print(f"Groundedness loop count: {state['groundedness_loop_count']}")
+     state['query_feedback'] = query_feedback
+     return state
+
+
+ def should_continue_groundedness(state):
+     """Decides if groundedness is sufficient or needs improvement."""
+     print("---------should_continue_groundedness---------")
+     print("groundedness loop count: ", state['groundedness_loop_count'])
+     if state['groundedness_score'] >= 0.7:  # Threshold for groundedness
+         print("Moving to precision")
+         return "check_precision"
+     else:
+         if state["groundedness_loop_count"] > state['loop_max_iter']:
+             return "max_iterations_reached"
+         else:
+             print("---------Groundedness Score Threshold Not met. Refining Response-----------")
+             return "refine_response"
+
+ def should_continue_precision(state: Dict) -> str:
+     """Decides if precision is sufficient or needs improvement."""
+     print("---------should_continue_precision---------")
+     print("precision loop count: ", state['precision_loop_count'])
+     if state['precision_score'] >= 0.7:  # Threshold for precision
+         return "pass"  # Complete the workflow
+     else:
+         if state["precision_loop_count"] > state['loop_max_iter']:  # Maximum allowed loops
+             return "max_iterations_reached"
+         else:
+             print("---------Precision Score Threshold Not met. Refining Query-----------")  # Debugging
+             return "refine_query"  # Refine the query
+
+
+ def max_iterations_reached(state: Dict) -> Dict:
+     """Handles the case when the maximum number of iterations is reached."""
+     print("---------max_iterations_reached---------")
+     response = "I'm unable to refine the response further. Please provide more context or clarify your question."
+     state['response'] = response
+     return state
+
+
+ def create_workflow() -> StateGraph:
+     """Creates the workflow for the AI nutrition agent."""
+     workflow = StateGraph(AgentState)  # The graph operates over AgentState
+
+     # Add processing nodes
+     workflow.add_node("expand_query", expand_query)  # Step 1: Expand the user query
+     workflow.add_node("retrieve_context", retrieve_context)  # Step 2: Retrieve relevant documents
+     workflow.add_node("craft_response", craft_response)  # Step 3: Generate a response from retrieved data
+     workflow.add_node("score_groundedness", score_groundedness)  # Step 4: Evaluate response grounding
+     workflow.add_node("refine_response", refine_response)  # Step 5: Improve the response if weakly grounded
+     workflow.add_node("check_precision", check_precision)  # Step 6: Evaluate response precision
+     workflow.add_node("refine_query", refine_query)  # Step 7: Improve the query if the response lacks precision
+     workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: Handle max iterations
+
+     # Main flow edges
+     workflow.add_edge(START, "expand_query")
+     workflow.add_edge("expand_query", "retrieve_context")
+     workflow.add_edge("retrieve_context", "craft_response")
+     workflow.add_edge("craft_response", "score_groundedness")
+
+     # Conditional edges based on the groundedness check (values map to node names)
+     workflow.add_conditional_edges(
+         "score_groundedness",
+         should_continue_groundedness,  # Conditional routing function
+         {
+             "check_precision": "check_precision",  # If well grounded, proceed to the precision check.
+             "refine_response": "refine_response",  # If not, refine the response.
+             "max_iterations_reached": "max_iterations_reached"  # If max loops reached, exit.
+         }
+     )
+
+     workflow.add_edge("refine_response", "craft_response")  # Refined responses are reprocessed.
+
+     # Conditional edges based on the precision check
+     workflow.add_conditional_edges(
+         "check_precision",
+         should_continue_precision,  # Conditional routing function
+         {
+             "pass": END,  # If precise, complete the workflow.
+             "refine_query": "refine_query",  # If imprecise, refine the query.
+             "max_iterations_reached": "max_iterations_reached"  # If max loops reached, exit.
+         }
+     )
+
+     workflow.add_edge("refine_query", "expand_query")  # Refined queries go through expansion again.
+
+     workflow.add_edge("max_iterations_reached", END)
+
+     return workflow
+
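+ # During development the compiled graph can be inspected, e.g. (a sketch; the
+ # ASCII rendering needs the optional grandalf package):
+ #   print(create_workflow().compile().get_graph().draw_ascii())
+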
+
+ #=========================== Defining the agentic rag tool ====================#
+ WORKFLOW_APP = create_workflow().compile()
+ @tool
+ def agentic_rag(query: str):
+     """
+     Runs the RAG-based agent with conversation history for context-aware responses.
+
+     Args:
+         query (str): The current user query.
+
+     Returns:
+         Dict[str, Any]: The updated state with the generated response and conversation history.
+     """
+     # Initialize state with necessary parameters
+     inputs = {
+         "query": query,  # Current user query
+         "expanded_query": "",  # Expanded version of the query (filled in by the workflow)
+         "context": [],  # Retrieved documents (initially empty)
+         "response": "",  # AI-generated response (filled in by the workflow)
+         "precision_score": 0.0,  # Precision score of the response
+         "groundedness_score": 0.0,  # Groundedness score of the response
+         "groundedness_loop_count": 0,  # Counter for groundedness refinement loops
+         "precision_loop_count": 0,  # Counter for precision refinement loops
+         "feedback": "",  # Response-refinement feedback
+         "query_feedback": "",  # Query-refinement feedback
+         "loop_max_iter": 3  # Maximum number of refinement iterations
+     }
+
+     output = WORKFLOW_APP.invoke(inputs)
+
+     return output
+
+
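+ # Because agentic_rag is a LangChain tool, a direct smoke test goes through
+ # .invoke (hypothetical query shown):
+ #   result = agentic_rag.invoke({"query": "What are the symptoms of scurvy?"})
+ #   print(result["response"])
+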
+ #================================ Guardrails ===========================#
+ llama_guard_client = Groq(api_key=llama_api_key)
+
+ # Function to filter user input with Llama Guard
+ def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
+     """
+     Filters user input using Llama Guard to ensure it is safe.
+
+     Parameters:
+     - user_input: The input provided by the user.
+     - model: The Llama Guard model to be used for filtering (default is "meta-llama/llama-guard-4-12b").
+
+     Returns:
+     - The Llama Guard verdict for the input, or None on error.
+     """
+     try:
+         # Create a request to Llama Guard to classify the user input
+         response = llama_guard_client.chat.completions.create(
+             messages=[{"role": "user", "content": user_input}],
+             model=model,
+         )
+         # Return the verdict string
+         return response.choices[0].message.content.strip()
+     except Exception as e:
+         print(f"Error with Llama Guard: {e}")
+         return None
+
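+ # Llama Guard answers with a verdict such as "safe" or "unsafe" followed by a
+ # category code (e.g. "unsafe\nS6"); the UI below flattens newlines before
+ # checking the verdict against its allow-list. Illustrative only:
+ #   filter_input_with_llama_guard("What is anorexia nervosa?")  # -> "safe"
+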
+
+ #============================= Adding Memory to the agent using mem0 ===============================#
+
+ class NutritionBot:
+     def __init__(self):
+         """
+         Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
+         """
+
+         # Initialize a memory client to store and retrieve customer interactions
+         self.memory = MemoryClient(api_key=MEM0_api_key)  # API key loaded from the environment above
+
+         # Initialize the OpenAI client using the provided credentials
+         self.client = ChatOpenAI(
+             model_name="gpt-4o-mini",  # Specify the model to use
+             openai_api_key=api_key,  # API key for authentication
+             openai_api_base=endpoint,  # Base URL of the OpenAI-compatible endpoint
+             temperature=0  # Controls randomness in responses; 0 ensures deterministic results
+         )
+
+         # Define tools available to the chatbot
+         tools = [agentic_rag]
+
+         # Define the system prompt to set the behavior of the chatbot
+         system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
+         Guidelines for Interaction:
+         Maintain a polite, professional, and reassuring tone.
+         Show genuine empathy for customer concerns and health challenges.
+         Reference past interactions to provide personalized and consistent advice.
+         Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
+         Ensure consistent and accurate information across conversations.
+         If any detail is unclear or missing, proactively ask for clarification.
+         Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
+         Keep track of ongoing issues and follow-ups to ensure continuity in support.
+         Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
+         """
+
+         # Build the prompt template for the agent
+         prompt = ChatPromptTemplate.from_messages([
+             ("system", system_prompt),  # System instructions
+             ("human", "{input}"),  # Placeholder for human input
+             ("placeholder", "{agent_scratchpad}")  # Placeholder for intermediate reasoning steps
+         ])
+
+         # Create an agent capable of interacting with tools and executing tasks
+         agent = create_tool_calling_agent(self.client, tools, prompt)
+
+         # Wrap the agent in an executor to manage tool interactions and execution flow
+         self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
+
+
+     def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
+         """
+         Store customer interaction in memory for future reference.
+
+         Args:
+             user_id (str): Unique identifier for the customer.
+             message (str): Customer's query or message.
+             response (str): Chatbot's response.
+             metadata (Dict, optional): Additional metadata for the interaction.
+         """
+         if metadata is None:
+             metadata = {}
+
+         # Add a timestamp to the metadata for tracking purposes
+         metadata["timestamp"] = datetime.now().isoformat()
+
+         # Format the conversation for storage
+         conversation = [
+             {"role": "user", "content": message},
+             {"role": "assistant", "content": response}
+         ]
+
+         # Store the interaction in the memory client
+         self.memory.add(
+             conversation,
+             user_id=user_id,
+             output_format="v1.1",
+             metadata=metadata
+         )
+
+     def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
+         """
+         Retrieve past interactions relevant to the current query.
+
+         Args:
+             user_id (str): Unique identifier for the customer.
+             query (str): The customer's current query.
+
+         Returns:
+             List[Dict]: A list of relevant past interactions.
+         """
+         return self.memory.search(
+             query=query,  # Search for interactions related to the query
+             user_id=user_id,  # Restrict search to the specific user
+             limit=5  # Cap the number of retrieved interactions
+         )
+
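+     # mem0's search returns a list of hits, each carrying a consolidated
+     # "memory" string plus metadata; a sketch of the shape as consumed below:
+     #   [{"memory": "User prefers vegetarian meal plans", ...}, ...]
+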
+     def handle_customer_query(self, user_id: str, query: str) -> str:
+         """
+         Process a customer's query and provide a response, taking into account past interactions.
+
+         Args:
+             user_id (str): Unique identifier for the customer.
+             query (str): Customer's query.
+
+         Returns:
+             str: Chatbot's response.
+         """
+
+         # Retrieve relevant past interactions for context
+         relevant_history = self.get_relevant_history(user_id, query)
+
+         # Build a context string from the relevant history
+         context = "Previous relevant interactions:\n"
+         for memory in relevant_history:
+             context += f"Memory: {memory['memory']}\n"  # Consolidated memory of the past exchange
+             context += "---\n"
+
+         # Print context for debugging purposes
+         print("Context: ", context)
+
+         # Prepare a prompt combining past context and the current query
+         prompt = f"""
+         Context:
+         {context}
+
+         Current customer query: {query}
+
+         Provide a helpful response that takes into account any relevant past interactions.
+         """
+
+         # Generate a response using the agent
+         response = self.agent_executor.invoke({"input": prompt})
+
+         # Store the current interaction for future reference
+         self.store_customer_interaction(
+             user_id=user_id,
+             message=query,
+             response=response["output"],
+             metadata={"type": "support_query"}
+         )
+
+         # Return the chatbot's response
+         return response["output"]
+
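+ # Standalone usage outside Streamlit (hypothetical example):
+ #   bot = NutritionBot()
+ #   print(bot.handle_customer_query("alice", "Which foods help with low iron?"))
+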
+
+ #=====================User Interface using streamlit ===========================#
+ def nutrition_disorder_streamlit():
+     """
+     A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
+     """
+     st.title("Nutrition Disorder Specialist")
+     st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
+     st.write("Type 'exit' to end the conversation.")
+
+     # Initialize session state for chat history and user_id if they don't exist
+     if 'chat_history' not in st.session_state:
+         st.session_state.chat_history = []
+     if 'user_id' not in st.session_state:
+         st.session_state.user_id = None
+
+     # Login form: only shown if the user is not logged in
+     if st.session_state.user_id is None:
+         with st.form("login_form", clear_on_submit=True):
+             user_id = st.text_input("Please enter your name to begin:")
+             submit_button = st.form_submit_button("Login")
+             if submit_button and user_id:
+                 st.session_state.user_id = user_id
+                 st.session_state.chat_history.append({
+                     "role": "assistant",
+                     "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
+                 })
+                 st.session_state.login_submitted = True  # Set flag to trigger rerun
+         if st.session_state.get("login_submitted", False):
+             st.session_state.pop("login_submitted")
+             st.rerun()
+     else:
+         # Display chat history
+         for message in st.session_state.chat_history:
+             with st.chat_message(message["role"]):
+                 st.write(message["content"])
+
+         # Chat input with custom placeholder text
+         user_query = st.chat_input("Type your question here (or 'exit' to end)...")
+         if user_query:
+             if user_query.lower() == "exit":
+                 st.session_state.chat_history.append({"role": "user", "content": "exit"})
+                 with st.chat_message("user"):
+                     st.write("exit")
+                 goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
+                 st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
+                 with st.chat_message("assistant"):
+                     st.write(goodbye_msg)
+                 st.session_state.user_id = None
+                 st.rerun()
+                 return
+
+             st.session_state.chat_history.append({"role": "user", "content": user_query})
+             with st.chat_message("user"):
+                 st.write(user_query)
+
+             # Filter input using Llama Guard
+             filtered_result = filter_input_with_llama_guard(user_query)
+             filtered_result = (filtered_result or "").replace("\n", " ")  # Normalize the verdict (None on guard error)
+
+             # Check if input is safe based on allowed statuses
+             if filtered_result in ["safe", "unsafe S6", "unsafe S7"]:
+                 try:
+                     if 'chatbot' not in st.session_state:
+                         st.session_state.chatbot = NutritionBot()
+                     response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
+                     st.write(response)
+                     st.session_state.chat_history.append({"role": "assistant", "content": response})
+                 except Exception as e:
+                     error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
+                     st.write(error_msg)
+                     st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
+             else:
+                 inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
+                 st.write(inappropriate_msg)
+                 st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
+
+ if __name__ == "__main__":
+     nutrition_disorder_streamlit()