Michaeldavidstein committed on
Commit
041789a
·
verified ·
1 Parent(s): 653d796

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +700 -0
app.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Import necessary libraries
3
+ import os # Interacting with the operating system (reading/writing files)
4
+ import chromadb # High-performance vector database for storing/querying dense vectors
5
+ from dotenv import load_dotenv # Loading environment variables from a .env file
6
+ import json # Parsing and handling JSON data
7
+
8
+ # LangChain imports
9
+ from langchain_core.documents import Document # Document data structures
10
+ from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
11
+ from langchain_core.output_parsers import StrOutputParser # String output parser
12
+ from langchain.prompts import ChatPromptTemplate # Template for chat prompts
13
+ from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
14
+ from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
15
+ from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
16
+ from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
17
+
18
+ # LangChain community & experimental imports
19
+ from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
20
+ from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
21
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
22
+ from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
23
+ from langchain.text_splitter import (
24
+ CharacterTextSplitter, # Splitting text by characters
25
+ RecursiveCharacterTextSplitter # Recursive splitting of text by characters
26
+ )
27
+ from langchain_core.tools import tool
28
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
29
+ from langchain_core.prompts import ChatPromptTemplate
30
+
31
+ # LangChain OpenAI imports
32
+ from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models
33
+ from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
34
+
35
+ # LlamaParse & LlamaIndex imports
36
+ from llama_parse import LlamaParse # Document parsing library
37
+ from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex
38
+
39
+ # LangGraph import
40
+ from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain
41
+
42
+ # Pydantic import
43
+ from pydantic import BaseModel # Pydantic for data validation
44
+
45
+ # Typing imports
46
+ from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations
47
+
48
+ # Other utilities
49
+ import numpy as np # Numpy for numerical operations
50
+ from groq import Groq
51
+ from mem0 import MemoryClient
52
+ import streamlit as st
53
+ from datetime import datetime
54
+
55
#====================================SETUP=====================================#
# Names used by this section that the file's import header does not provide:
# `ChatOpenAI` is never imported there (using it raised NameError at startup),
# and the `embedding_functions` submodule is loaded explicitly instead of
# relying on `import chromadb` exposing it as a side effect.
import chromadb.utils.embedding_functions
from langchain_openai import ChatOpenAI

# Fetch secrets from Hugging Face Spaces.
# A missing variable raises KeyError here, so misconfiguration fails fast at startup.
api_key = os.environ['api_key']
endpoint = os.environ['OPENAI_API_BASE']
# api_version = os.environ['AZURE_OPENAI_APIVERSION']
model_name = os.environ['CHATGPT_MODEL']
emb_key = os.environ['EMB_MODEL_KEY']
emb_endpoint = os.environ['EMB_DEPLOYMENT']
llama_api_key = os.environ['LLAMA_GUARD_API_KEY']
mem0_api_key = os.environ['mem0_api_key']

# Chroma-native embedding function, built from the same endpoint and key as the chat model.
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base=endpoint,
    api_key=api_key,
    model_name='text-embedding-ada-002',  # This is a fixed value and does not need modification
)

# LangChain embeddings client; this is the object handed to the LangChain vector store.
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002',
)

# Chat model shared by every node of the workflow.
# No temperature is passed, so the client's default temperature applies.
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False,
)

# Set the LLM and embedding model in the LlamaIndex settings.
Settings.llm = llm
# NOTE(review): LlamaIndex documents the attribute as `Settings.embed_model`;
# `Settings.embedding` looks like a plain attribute with no effect. Left as-is
# because nothing visible in this file reads it -- confirm before changing.
Settings.embedding = embedding_model
95
+ #================================Creating Langgraph agent======================#
96
+
97
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agentic RAG workflow."""

    # --- Query side ---
    query: str                      # The user's question as received
    expanded_query: str             # Retrieval-oriented rewrite produced by expand_query
    query_feedback: str             # Suggestions from refine_query, read by expand_query

    # --- Retrieval and answer ---
    context: List[Dict[str, Any]]   # Retrieved documents as {"content", "metadata"} dicts
    response: str                   # Latest generated answer
    feedback: str                   # Suggestions from refine_response, read by craft_response

    # --- Evaluation ---
    groundedness_score: float       # Score written by score_groundedness
    precision_score: float          # Score written by check_precision
    groundedness_check: bool        # Declared in the schema; no visible node reads or writes it

    # --- Loop control ---
    groundedness_loop_count: int    # Number of groundedness evaluations so far
    precision_loop_count: int       # Number of precision evaluations so far
    loop_max_iter: int              # Upper bound the routers compare the counters against
110
+
111
def expand_query(state):
    """
    Expands the user query to improve retrieval of nutrition disorder-related information.

    Args:
        state (Dict): The current state of the workflow. Must contain 'query';
            'query_feedback' is used when present and non-empty.

    Returns:
        Dict: The same state object with 'expanded_query' set to the LLM output.
    """
    print("State at the start of expand_query:", state)
    print("---------Expanding Query---------")
    system_message = '''You are a helpful research assistant that is well versed in Nutritional Disorders.
    Return an expanded user query based on the user's input query. The expanded query should be designed to improve retrieval of the most relevant information.
    Use the feedback if provided to craft the expanded query.
    '''

    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}"),
    ])

    chain = expand_prompt | llm | StrOutputParser()
    expanded_query = chain.invoke({
        "query": state['query'],
        # On the first pass there is no feedback yet; say so explicitly
        # instead of sending an empty (or missing) value to the model.
        "query_feedback": state.get("query_feedback") or "No feedback provided.",
    })
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state
139
+
140
+
141
# Open the persisted Chroma collection that backs document retrieval
vector_store = Chroma(
    embedding_function=embedding_model,
    persist_directory="./nutritional_db",
    collection_name="nutritional_hypotheticals",
)

# Expose the store as a similarity-search retriever returning the top 3 matches
_retriever_options = {
    'search_type': 'similarity',
    'search_kwargs': {'k': 3},
}
retriever = vector_store.as_retriever(**_retriever_options)
154
+
155
def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    Args:
        state (Dict): The current state of the workflow, containing 'query' and,
            normally, 'expanded_query'.

    Returns:
        Dict: The same state object with 'context' set to a list of
            {"content", "metadata"} dicts, one per retrieved document.
    """
    print("State at the start of retrieve_context:", state)

    # Prefer the expanded query; fall back to the original one when the
    # expansion is missing or empty so retrieval never runs on a blank string.
    query = state.get('expanded_query') or state['query']
    print("Query used for retrieval:", query)  # Debugging: Print the query

    # Retrieve documents from the vector store
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: Print the raw docs object

    # Keep both the text and the metadata (e.g., source, page number) of each document
    state['context'] = [
        {"content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]

    print("Extracted context with metadata:", state['context'])  # Debugging: Print the extracted context
    return state
185
+
186
+
187
+
188
def craft_response(state: Dict) -> Dict:
    """
    Generates a response using the retrieved context, focusing on nutrition disorders.

    Args:
        state (Dict): The current state of the workflow, containing 'query' and
            'context'; 'feedback' is used when present and non-empty.

    Returns:
        Dict: The same state object with 'response' set to the generated text.
    """
    print("State at the start of craft_response:", state)
    print("---------craft_response---------")
    system_message = '''You are an expert at condensing information. Your task is to extract relevant information for a given query and provide a grounded and highly precise response.'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}"),
    ])

    # Parse to a plain string: without the parser the chain returns a chat
    # message object, which then leaks into later prompts and the UI even
    # though the state schema declares `response` as str.
    chain = response_prompt | llm | StrOutputParser()
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join(doc["content"] for doc in state['context']),
        "feedback": state.get("feedback") or "No feedback provided.",
    })
    state['response'] = response
    print("intermediate response: ", response)

    return state
217
+
218
+
219
def score_groundedness(state: Dict) -> Dict:
    """
    Checks whether the response is grounded in the retrieved context.

    Args:
        state (Dict): The current state of the workflow, containing 'response',
            'context' and 'groundedness_loop_count'.

    Returns:
        Dict: The same state object with 'groundedness_score' set (clamped to
            0.0-1.0; 0.0 when the model output contains no number) and
            'groundedness_loop_count' incremented by one.
    """
    import re  # local import: the file's import header does not provide `re`

    print("State at the start of score_groundedness:", state)
    print("---------check_groundedness---------")

    # System message to guide the evaluation
    system_message = '''You are a groundedness evaluator. Your task is to assess how well the given response aligns with the provided context.
    - A grounded response is one that is accurate, directly supported by the context, and avoids speculation.
    - A response should not include information that cannot be verified or inferred from the context.

    Instructions:
    - Assign a score between 0.0 and 1.0, where:
        - 1.0: Fully grounded (entirely supported by the context).
        - 0.5: Partially grounded (some elements are supported, but others are speculative).
        - 0.0: Not grounded (contains speculative or unsupported information).
    - Provide only the numerical groundedness score as the output.'''

    # Define the prompt template for evaluating groundedness
    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:"),
    ])

    # Chain to compute groundedness score
    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join(doc["content"] for doc in state['context']),  # Combine document content
        "response": state['response'],  # Use the response from the state
    })

    # The model is asked for a bare number but may wrap it in text; a plain
    # float() would crash the whole workflow on e.g. "Score: 0.8".
    try:
        groundedness_score = float(raw_score.strip())
    except ValueError:
        found = re.search(r"[-+]?\d*\.?\d+", raw_score)
        # No number at all is treated as "not grounded" so the refinement loop continues.
        groundedness_score = float(found.group()) if found else 0.0
    groundedness_score = min(1.0, max(0.0, groundedness_score))

    state['groundedness_loop_count'] += 1
    print("######### Groundedness Loop Count Incremented ###########")
    state['groundedness_score'] = groundedness_score
    print("groundedness_score: ", state['groundedness_score'])

    return state
264
+
265
+
266
def check_precision(state: Dict) -> Dict:
    """
    Checks whether the response precisely addresses the user's query.

    Args:
        state (Dict): The current state of the workflow, containing 'query',
            'response' and 'precision_loop_count'.

    Returns:
        Dict: The same state object with 'precision_score' set (clamped to
            0.0-1.0; 0.0 when the model output contains no number) and
            'precision_loop_count' incremented by one.
    """
    import re  # local import: the file's import header does not provide `re`

    print("State at the start of check_precision:", state)
    print("---------check_precision---------")

    # System message for evaluating precision
    system_message = '''You are a precision evaluator. Your task is to assess how well the given response directly and fully addresses the user's query.

    Instructions:
    - A precise response is one that:
        - Directly answers the user’s query without unnecessary or unrelated information.
        - Fully addresses all aspects of the query.
        - Avoids vague or overly general statements.
    - Assign a precision score between 0.0 and 1.0:
        - 1.0: Fully precise (direct, complete, and relevant to the query).
        - 0.5: Partially precise (addresses the query but is incomplete or includes some irrelevant information).
        - 0.0: Not precise (fails to address the query or contains mostly irrelevant information).
    - Provide only the numerical precision score as the output.'''

    # Define the prompt template for evaluating precision
    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:"),
    ])

    # Chain to compute precision score
    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": state['response'],
    })

    # The model is asked for a bare number but may wrap it in text; a plain
    # float() would crash the whole workflow on e.g. "Score: 0.8".
    try:
        precision_score = float(raw_score.strip())
    except ValueError:
        found = re.search(r"[-+]?\d*\.?\d+", raw_score)
        # No number at all is treated as "not precise" so the refinement loop continues.
        precision_score = float(found.group()) if found else 0.0
    precision_score = min(1.0, max(0.0, precision_score))

    # Update the state with precision score
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
313
+
314
def refine_response(state: Dict) -> Dict:
    """
    Suggests improvements for the generated response.

    Args:
        state (Dict): The current state of the workflow, containing 'query' and 'response'.

    Returns:
        Dict: The same state object with 'feedback' set to the previous
            response followed by the model's improvement suggestions.
    """
    print("State at the start of refine_response:", state)
    print("---------refine_response---------")

    system_message = '''You are a constructive feedback evaluator. Your task is to analyze the provided response and identify potential gaps, ambiguities, or missing details. Your feedback should help improve the response for accuracy, clarity, and completeness.

    Instructions:
    - Do not rewrite the response.
    - Focus on identifying the following:
        - Are there any gaps in the information provided?
        - Is the response ambiguous or unclear in any part?
        - Are there any details missing that are relevant to fully addressing the context or query?
    - Provide actionable and constructive suggestions for improvement.
    - Avoid criticism without offering specific recommendations.

    Your output should be written as a list of feedback points, with each suggestion clearly and concisely stated.'''

    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?"),
    ])

    chain = refine_response_prompt | llm | StrOutputParser()

    # Run the LLM call on its own line so a failure is not buried inside an f-string
    suggestions = chain.invoke({'query': state['query'], 'response': state['response']})

    # Keep the previous response next to the suggestions so the next draft can see both
    feedback = f"Previous Response: {state['response']}\nSuggestions: {suggestions}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state
355
+
356
def refine_query(state: Dict) -> Dict:
    """
    Suggests improvements for the expanded query.

    Args:
        state (Dict): The current state of the workflow, containing 'query',
            'expanded_query' and 'precision_loop_count'.

    Returns:
        Dict: The same state object with 'query_feedback' set to the previous
            expanded query followed by the model's refinement suggestions.
    """
    print("State at the start of refine_query:", state)
    print("---------refine_query---------")

    system_message = '''You are a query refinement assistant. Your task is to analyze the original query and its expanded version to suggest specific improvements that can enhance search precision and relevance.

    Instructions:
    - Do not replace or rewrite the expanded query. Instead, provide structured suggestions for improvement.
    - Focus on identifying:
        - Missing details or specific keywords that could make the query more precise.
        - Scope refinements to narrow or broaden the query if needed.
        - Ambiguities or redundancies that can be clarified or removed.
    - Ensure your suggestions are actionable and presented in a clear, concise, and structured format.
    - Avoid general or vague feedback; provide specific recommendations.

    Your output should be a list of suggestions that can improve the expanded query without modifying it directly.'''

    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?"),
    ])

    chain = refine_query_prompt | llm | StrOutputParser()

    # Run the LLM call on its own line so a failure is not buried inside an f-string
    suggestions = chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})

    # Store refinement suggestions without modifying the original expanded query
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    # This node runs on the precision path, so report the precision counter
    print(f"Precision loop count: {state['precision_loop_count']}")
    state['query_feedback'] = query_feedback
    return state
396
+
397
+
398
def should_continue_groundedness(state) -> str:
    """
    Decides if groundedness is sufficient or needs improvement.

    Args:
        state (Dict): The current state of the workflow, containing
            'groundedness_score', 'groundedness_loop_count' and 'loop_max_iter'.

    Returns:
        str: "check_precision" when the score is at least 0.8,
            "max_iterations_reached" when the loop budget is used up,
            otherwise "refine_response".
    """
    print("State at the start of should_continue_groundedness:", state)
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])

    # Good enough: move on to the precision check
    if state['groundedness_score'] >= 0.8:  # Threshold for groundedness
        print("Moving to precision")
        return "check_precision"

    # Not good enough and out of attempts: stop refining
    if state['groundedness_loop_count'] >= state['loop_max_iter']:
        return "max_iterations_reached"

    print("---------Groundedness Score Threshold Not Met. Refining Response-----------")
    return "refine_response"
415
+
416
def should_continue_precision(state: Dict) -> str:
    """
    Decides if precision is sufficient or needs improvement.

    Args:
        state (Dict): The current state of the workflow, containing
            'precision_score', 'precision_loop_count' and 'loop_max_iter'.

    Returns:
        str: "pass" when the score is at least 0.8,
            "max_iterations_reached" when the loop budget is used up,
            otherwise "refine_query".
    """
    print("State at the start of should_continue_precision:", state)
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])

    # Good enough: complete the workflow
    if state['precision_score'] >= 0.8:  # Threshold for precision
        return "pass"

    # Not good enough and out of attempts: stop refining
    if state['precision_loop_count'] >= state['loop_max_iter']:  # Maximum allowed loops
        return "max_iterations_reached"

    print("---------Precision Score Threshold Not Met. Refining Query-----------")  # Debugging
    return "refine_query"
433
+
434
+
435
def max_iterations_reached(state: Dict) -> Dict:
    """
    Handles the case when the maximum number of iterations is reached.

    Args:
        state (Dict): The current state of the workflow.

    Returns:
        Dict: The same state object with 'response' replaced by a fixed
            message asking the user for more context.
    """
    print("---------max_iterations_reached---------")
    state['response'] = "I'm unable to refine the response further. Please provide more context or clarify your question."
    return state
442
+
443
+
444
+
445
+ from langgraph.graph import END, StateGraph, START
446
+
447
+ from langgraph.graph import StateGraph, START, END
448
+ from typing import Callable
449
+
450
+
451
def create_workflow() -> StateGraph:
    """Build (but do not compile) the state graph for the AI nutrition agent.

    Returns:
        StateGraph: A graph over `AgentState` with all nodes, linear edges and
            conditional routing registered.
    """
    graph = StateGraph(state_schema=AgentState)

    # Processing nodes, listed in the order a request normally flows through them
    node_handlers = (
        ("expand_query", expand_query),                      # Step 1: Expand the user query
        ("retrieve_context", retrieve_context),              # Step 2: Retrieve relevant documents
        ("craft_response", craft_response),                  # Step 3: Generate a response from retrieved data
        ("score_groundedness", score_groundedness),          # Step 4: Evaluate response grounding
        ("refine_response", refine_response),                # Step 5: Improve a weakly grounded response
        ("check_precision", check_precision),                # Step 6: Evaluate response precision
        ("refine_query", refine_query),                      # Step 7: Improve the query if precision is low
        ("max_iterations_reached", max_iterations_reached),  # Step 8: Handle max iterations gracefully
    )
    for node_name, handler in node_handlers:
        graph.add_node(node_name, handler)

    # Straight-line path: START -> expand -> retrieve -> craft -> score
    main_path = (START, "expand_query", "retrieve_context", "craft_response", "score_groundedness")
    for source, target in zip(main_path, main_path[1:]):
        graph.add_edge(source, target)

    # Routing after the groundedness score
    groundedness_routes = {
        "check_precision": "check_precision",                # Well grounded: go check precision
        "refine_response": "refine_response",                # Otherwise: refine the response
        "max_iterations_reached": "max_iterations_reached",  # Loop budget used up: exit
    }
    graph.add_conditional_edges("score_groundedness", should_continue_groundedness, groundedness_routes)

    # A refined response is fed back into response crafting
    graph.add_edge("refine_response", "craft_response")

    # Routing after the precision score
    precision_routes = {
        "pass": END,                                         # Precise: complete the workflow
        "refine_query": "refine_query",                      # Imprecise: refine the query
        "max_iterations_reached": "max_iterations_reached",  # Loop budget used up: exit
    }
    graph.add_conditional_edges("check_precision", should_continue_precision, precision_routes)

    # A refined query goes through expansion again; the max-iterations node ends the run
    graph.add_edge("refine_query", "expand_query")
    graph.add_edge("max_iterations_reached", END)

    return graph
500
+
501
+
502
+
503
+ #=========================== Defining the agentic rag tool ====================#
504
# Compile the workflow once at import time so every tool call reuses it
WORKFLOW_APP = create_workflow().compile()


# Define the tool
@tool
def agentic_rag(query: str) -> Dict[str, Any]:
    """
    Answers a nutrition-disorder question by running the agentic RAG workflow
    (query expansion, retrieval, response drafting, groundedness and precision checks).

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The final workflow state; the answer is under the 'response' key.
    """
    # Every field of the state schema gets an explicit starting value
    inputs = {
        "query": query,                # Current user query
        "expanded_query": "",          # Filled in by the query-expansion node
        "context": [],                 # Retrieved documents (initially empty)
        "response": "",                # Filled in by the response-crafting node
        "precision_score": 0.0,        # Filled in by the precision check
        "groundedness_score": 0.0,     # Filled in by the groundedness check
        "groundedness_loop_count": 0,  # Counter for groundedness loops
        "precision_loop_count": 0,     # Counter for precision loops
        "feedback": "",                # Response feedback (none yet)
        "query_feedback": "",          # Query feedback (none yet)
        "groundedness_check": False,   # Declared in the state schema; initialised for completeness
        "loop_max_iter": 3,            # Maximum number of iterations for each loop
    }

    return WORKFLOW_APP.invoke(inputs)
536
+
537
+ #================================ Guardrails ===========================#
538
llama_guard_client = Groq(api_key=llama_api_key)


# Function to classify user input with Llama Guard
def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
    """
    Sends the user input to Llama Guard and returns the model's reply.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used (default is "llama-guard-3-8b").

    Returns:
    - The stripped text of the first completion choice, or None when the
      request or the response handling raises any exception (the error is
      printed).
    """
    conversation = [{"role": "user", "content": user_input}]
    verdict = None
    try:
        # Both the API call and the response unpacking stay inside the try
        # so that either failure results in None.
        completion = llama_guard_client.chat.completions.create(
            messages=conversation,
            model=model,
        )
        verdict = completion.choices[0].message.content.strip()
    except Exception as exc:
        print(f"Error with Llama Guard: {exc}")
    return verdict
562
+
563
+
564
+ #============================= Adding Memory to the agent using mem0 ===============================#
565
+
566
+ # NutritionBot class
567
class NutritionBot:
    """Chatbot facade that wires memory, an LLM client and the agentic RAG tool together."""

    def __init__(self):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
        """
        # Local import: the file's import header does not provide `ChatOpenAI`
        from langchain_openai import ChatOpenAI

        # Memory client for customer interactions, using the key read from
        # the environment instead of a hard-coded placeholder.
        self.memory = MemoryClient(api_key=mem0_api_key)

        # Chat client using the same endpoint and key as the module-level LLM.
        # NOTE(review): the module-level LLM uses "gpt-4o-mini" on this
        # endpoint -- confirm that "gpt-4" is actually served there.
        self.client = ChatOpenAI(
            model_name="gpt-4",
            openai_api_base=endpoint,
            openai_api_key=api_key,
        )

        # Define tools available to the chatbot, including agentic_rag
        tools = [agentic_rag]

        # Define the system prompt to set the behavior of the chatbot
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience."""

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),              # System instructions
            ("human", "{input}"),                   # Placeholder for human input
            ("placeholder", "{agent_scratchpad}"),  # Placeholder for intermediate reasoning steps
        ])

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)

        # Wrap the agent in an executor to manage tool interactions and execution flow
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response.

        Args:
            user_id (str): Unique identifier for the customer (currently not
                used by this method).
            query (str): Customer's query.

        Returns:
            str: Chatbot's response, or an error message when processing fails.
        """
        try:
            # Call the tool through the runnable interface; invoking the
            # tool object like a plain function is deprecated.
            result = agentic_rag.invoke({"query": query})
            response = result.get("response", "I'm sorry, I couldn't generate a response.")
            # The workflow may hand back a chat message object instead of
            # text; unwrap it so callers always get a string.
            response = getattr(response, "content", response)
            return str(response)
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing the app
            return f"An error occurred while processing your query: {str(e)}"
620
+
621
+
622
+ #=====================User Interface using streamlit ===========================#
623
# Guard verdicts that are allowed through to the chatbot.
# NOTE(review): "unsafe S6"/"unsafe S7" are presumably Llama Guard categories
# deliberately tolerated for a medical assistant -- confirm against the model card.
_ALLOWED_GUARD_VERDICTS = ("safe", "unsafe S7", "unsafe S6")


def _show_assistant_message(text):
    """Render an assistant message and record it in the chat history."""
    with st.chat_message("assistant"):
        st.write(text)
    st.session_state.chat_history.append({"role": "assistant", "content": text})


def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.

    Shows a login form until a user name is entered, then a chat view in
    which each query is screened by Llama Guard before being answered.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Login form: Only if user is not logged in
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                })
                st.session_state.login_submitted = True  # Set flag to trigger rerun
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
        return

    # Display chat history
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # st.chat_input returns None until the user submits something, so the
    # value must be checked before calling string methods on it.
    user_query = st.chat_input("You: ")
    user_query = user_query.strip() if user_query else ""
    if not user_query:
        return

    if user_query.lower() == "exit":
        st.session_state.chat_history.append({"role": "user", "content": "exit"})
        with st.chat_message("user"):
            st.write("exit")
        _show_assistant_message(
            "Goodbye! Feel free to return if you have more questions about nutrition disorders."
        )
        st.session_state.user_id = None
        st.rerun()
        return

    st.session_state.chat_history.append({"role": "user", "content": user_query})
    with st.chat_message("user"):
        st.write(user_query)

    # Filter input using Llama Guard
    filtered_result = filter_input_with_llama_guard(user_query)
    if filtered_result is None:
        # The guard call failed; fail closed rather than forwarding unchecked input
        _show_assistant_message(
            "Sorry, I couldn't verify your input right now. Please try again in a moment."
        )
        return
    filtered_result = filtered_result.replace("\n", " ")  # Normalize the result

    # Check if input is safe based on allowed statuses
    if filtered_result not in _ALLOWED_GUARD_VERDICTS:
        _show_assistant_message(
            "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
        )
        return

    try:
        # Build the bot once per session; construction sets up external clients
        if 'chatbot' not in st.session_state:
            st.session_state.chatbot = NutritionBot()
        response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing the app
        response = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
    _show_assistant_message(response)


if __name__ == "__main__":
    nutrition_disorder_streamlit()