nagesh5 committed on
Commit
9247687
·
verified ·
1 Parent(s): 2e15307

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +56 -45
app.py CHANGED
@@ -230,7 +230,7 @@ def craft_response(state: Dict) -> Dict:
230
 
231
 
232
 
233
- def score_groundedness(state: Dict) -> Dict:
234
  """
235
  Checks whether the response is grounded in the retrieved context.
236
 
@@ -255,7 +255,7 @@ def score_groundedness(state: Dict) -> Dict:
255
  chain = groundedness_prompt | llm | StrOutputParser()
256
  groundedness_score = float(chain.invoke({
257
  "context": "\n".join([doc["content"] for doc in state['context']]),
258
- "response": state['response'] # Complete the code to define the response
259
  }))
260
  print("groundedness_score: ", groundedness_score)
261
  state['groundedness_loop_count'] += 1
@@ -265,7 +265,6 @@ def score_groundedness(state: Dict) -> Dict:
265
  return state
266
 
267
 
268
-
269
  def check_precision(state: Dict) -> Dict:
270
  """
271
  Checks whether the response precisely addresses the user’s query.
@@ -288,14 +287,14 @@ def check_precision(state: Dict) -> Dict:
288
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
289
  ])
290
 
291
- chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
292
  precision_score = float(chain.invoke({
293
  "query": state['query'],
294
- "response":state['response'] # Complete the code to access the response from the state
295
  }))
296
  state['precision_score'] = precision_score
297
  print("precision_score:", precision_score)
298
- state['precision_loop_count'] +=1
299
  print("#########Precision Incremented###########")
300
  return state
301
 
@@ -324,7 +323,7 @@ def refine_response(state: Dict) -> Dict:
324
  "What improvements can be made to enhance accuracy and completeness?")
325
  ])
326
 
327
- chain = refine_response_prompt | llm| StrOutputParser()
328
 
329
  # Store response suggestions in a structured format
330
  feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
@@ -372,7 +371,7 @@ def should_continue_groundedness(state):
372
  """Decides if groundedness is sufficient or needs improvement."""
373
  print("---------should_continue_groundedness---------")
374
  print("groundedness loop count: ", state['groundedness_loop_count'])
375
- if state['groundedness_score'] >= 0.7: # Complete the code to define the threshold for groundedness
376
  print("Moving to precision")
377
  return "check_precision"
378
  else:
@@ -398,13 +397,9 @@ def should_continue_precision(state: Dict) -> str:
398
 
399
 
400
 
401
-
402
- def max_iterations_reached(state: Dict) -> Dict:
403
- """Handles the case when the maximum number of iterations is reached."""
404
- print("---------max_iterations_reached---------")
405
- """Handles the case when the maximum number of iterations is reached."""
406
- response = "I'm unable to refine the response further. Please provide more context or clarify your question."
407
- state['response'] = response
408
  return state
409
 
410
 
@@ -413,32 +408,33 @@ from langgraph.graph import END, StateGraph, START
413
 
414
  def create_workflow() -> StateGraph:
415
  """Creates the updated workflow for the AI nutrition agent."""
416
- workflow = StateGraph(AgentState) # Complete the code to define the initial state of the agent
417
 
418
  # Add processing nodes
419
- workflow.add_node("expand_query", expand_query) # Step 1: Expand user query. Complete with the function to expand the query
420
- workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
421
- workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
422
- workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
423
- workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
424
- workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision. Complete with the function to check precision
425
- workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
426
- workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations. Complete with the function to handle max iterations
427
 
428
  # Main flow edges
429
  workflow.add_edge(START, "expand_query")
430
  workflow.add_edge("expand_query", "retrieve_context")
431
  workflow.add_edge("retrieve_context", "craft_response")
432
- workflow.add_edge("craft_response", "score_groundedness")
 
433
 
434
  # Conditional edges based on groundedness check
435
  workflow.add_conditional_edges(
436
- "score_groundedness",
437
  should_continue_groundedness, # Use the conditional function
438
  {
439
- "check_precision": check_precision, # If well-grounded, proceed to precision check.
440
- "refine_response": refine_response, # If not, refine the response.
441
- "max_iterations_reached": max_iterations_reached # If max loops reached, exit.
442
  }
443
  )
444
 
@@ -456,7 +452,6 @@ def create_workflow() -> StateGraph:
456
  )
457
 
458
  workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
459
-
460
  workflow.add_edge("max_iterations_reached", END)
461
 
462
  return workflow
@@ -479,19 +474,33 @@ def agentic_rag(query: str):
479
  """
480
  # Initialize state with necessary parameters
481
  inputs = {
482
- "query": query, # Current user query
483
- "expanded_query": "", # Complete the code to define the expanded version of the query
484
- "context": [], # Retrieved documents (initially empty)
485
- "response": "", # Complete the code to define the AI-generated response
486
- "precision_score": 0.0, # Complete the code to define the precision score of the response
487
- "groundedness_score": 0.0, # Complete the code to define the groundedness score of the response
488
- "groundedness_loop_count": 0, # Complete the code to define the counter for groundedness loops
489
- "precision_loop_count": 0, # Complete the code to define the counter for precision loops
490
- "feedback": "", # Complete the code to define the feedback
491
- "query_feedback": "", # Complete the code to define the query feedback
492
- "loop_max_iter": 3 # Complete the code to define the maximum number of iterations for loops
493
  }
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  output = WORKFLOW_APP.invoke(inputs)
496
 
497
  return output
@@ -506,7 +515,7 @@ def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12
506
 
507
  Parameters:
508
  - user_input: The input provided by the user.
509
- - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").
510
 
511
  Returns:
512
  - The filtered and safe input.
@@ -538,12 +547,14 @@ class NutritionBot:
538
  # Initialize the OpenAI client using the provided credentials
539
  self.client = ChatOpenAI(
540
  model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
541
- api_key=config.get("API_KEY"), # API key for authentication
542
- endpoint = config.get("OPENAI_API_BASE"),
543
- #openai_api_base = config.get("OPENAI_API_BASE"),
544
  temperature=0 # Controls randomness in responses; 0 ensures deterministic results
545
  )
546
 
 
 
547
  # Define tools available to the chatbot, such as web search
548
  tools = [agentic_rag]
549
 
 
230
 
231
 
232
 
233
+ def score_groundness(state: Dict) -> Dict:
234
  """
235
  Checks whether the response is grounded in the retrieved context.
236
 
 
255
  chain = groundedness_prompt | llm | StrOutputParser()
256
  groundedness_score = float(chain.invoke({
257
  "context": "\n".join([doc["content"] for doc in state['context']]),
258
+ "response": state['response'] #
259
  }))
260
  print("groundedness_score: ", groundedness_score)
261
  state['groundedness_loop_count'] += 1
 
265
  return state
266
 
267
 
 
268
  def check_precision(state: Dict) -> Dict:
269
  """
270
  Checks whether the response precisely addresses the user’s query.
 
287
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
288
  ])
289
 
290
+ chain = precision_prompt | llm | StrOutputParser()
291
  precision_score = float(chain.invoke({
292
  "query": state['query'],
293
+ "response": state['response']
294
  }))
295
  state['precision_score'] = precision_score
296
  print("precision_score:", precision_score)
297
+ state['precision_loop_count'] += 1
298
  print("#########Precision Incremented###########")
299
  return state
300
 
 
323
  "What improvements can be made to enhance accuracy and completeness?")
324
  ])
325
 
326
+ chain = refine_response_prompt | llm | StrOutputParser()
327
 
328
  # Store response suggestions in a structured format
329
  feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
 
371
  """Decides if groundedness is sufficient or needs improvement."""
372
  print("---------should_continue_groundedness---------")
373
  print("groundedness loop count: ", state['groundedness_loop_count'])
374
+ if state['groundedness_score'] >= 0.7: # Threshold for groundedness
375
  print("Moving to precision")
376
  return "check_precision"
377
  else:
 
397
 
398
 
399
 
400
+ def max_iterations_reached(state: AgentState) -> AgentState:
401
+ """Handles the case where max iterations are reached."""
402
+ state['response'] = "We need more context to provide an accurate answer."
 
 
 
 
403
  return state
404
 
405
 
 
408
 
409
  def create_workflow() -> StateGraph:
410
  """Creates the updated workflow for the AI nutrition agent."""
411
+ workflow = StateGraph(AgentState)
412
 
413
  # Add processing nodes
414
+ workflow.add_node("expand_query", expand_query) # Step 1: Expand user query.
415
+ workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents.
416
+ workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data.
417
+ workflow.add_node("score_groundness", score_groundness) # Step 4: Evaluate response grounding.
418
+ workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded.
419
+ workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision.
420
+ workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision.
421
+ workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations.
422
 
423
  # Main flow edges
424
  workflow.add_edge(START, "expand_query")
425
  workflow.add_edge("expand_query", "retrieve_context")
426
  workflow.add_edge("retrieve_context", "craft_response")
427
+ workflow.add_edge("craft_response", "score_groundness")
428
+
429
 
430
  # Conditional edges based on groundedness check
431
  workflow.add_conditional_edges(
432
+ "score_groundness",
433
  should_continue_groundedness, # Use the conditional function
434
  {
435
+ "check_precision": "check_precision", # If well-grounded, proceed to precision check.
436
+ "refine_response": "refine_response", # If not, refine the response.
437
+ "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
438
  }
439
  )
440
 
 
452
  )
453
 
454
  workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
 
455
  workflow.add_edge("max_iterations_reached", END)
456
 
457
  return workflow
 
474
  """
475
  # Initialize state with necessary parameters
476
  inputs = {
477
+ "query": query,
478
+ "expanded_query": "",
479
+ "context": [],
480
+ "response": "",
481
+ "precision_score": 0.0,
482
+ "groundedness_score": 0.0,
483
+ "groundedness_loop_count": 0,
484
+ "precision_loop_count": 0,
485
+ "feedback": "",
486
+ "query_feedback": "",
487
+ "loop_max_iter": 3 # Set a reasonable maximum number of iterations
488
  }
489
 
490
+ # Initialize the Chroma vector store for retrieving documents with the correct collection
491
+ vector_store = Chroma(
492
+ collection_name="semantic_chunks", # Use the collection with semantic chunks
493
+ persist_directory="./research_db",
494
+ embedding_function=embedding_model
495
+ )
496
+
497
+ # Create a retriever from the vector store
498
+ retriever = vector_store.as_retriever(
499
+ search_type='similarity',
500
+ search_kwargs={'k': 3}
501
+ )
502
+
503
+
504
  output = WORKFLOW_APP.invoke(inputs)
505
 
506
  return output
 
515
 
516
  Parameters:
517
  - user_input: The input provided by the user.
518
+ - model: The Llama Guard model to be used for filtering (default is "meta-llama/llama-guard-4-12b").
519
 
520
  Returns:
521
  - The filtered and safe input.
 
547
  # Initialize the OpenAI client using the provided credentials
548
  self.client = ChatOpenAI(
549
  model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
550
+ openai_api_key=config.get("API_KEY"), # API key for authentication
551
+ #openai_endpoint = config.get("OPENAI_API_BASE"),
552
+ openai_api_base = config.get("OPENAI_API_BASE"),
553
  temperature=0 # Controls randomness in responses; 0 ensures deterministic results
554
  )
555
 
556
+
557
+
558
  # Define tools available to the chatbot, such as web search
559
  tools = [agentic_rag]
560