Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -230,7 +230,7 @@ def craft_response(state: Dict) -> Dict:
|
|
230 |
|
231 |
|
232 |
|
233 |
-
def
|
234 |
"""
|
235 |
Checks whether the response is grounded in the retrieved context.
|
236 |
|
@@ -255,7 +255,7 @@ def score_groundedness(state: Dict) -> Dict:
|
|
255 |
chain = groundedness_prompt | llm | StrOutputParser()
|
256 |
groundedness_score = float(chain.invoke({
|
257 |
"context": "\n".join([doc["content"] for doc in state['context']]),
|
258 |
-
"response": state['response'] #
|
259 |
}))
|
260 |
print("groundedness_score: ", groundedness_score)
|
261 |
state['groundedness_loop_count'] += 1
|
@@ -265,7 +265,6 @@ def score_groundedness(state: Dict) -> Dict:
|
|
265 |
return state
|
266 |
|
267 |
|
268 |
-
|
269 |
def check_precision(state: Dict) -> Dict:
|
270 |
"""
|
271 |
Checks whether the response precisely addresses the user’s query.
|
@@ -288,14 +287,14 @@ def check_precision(state: Dict) -> Dict:
|
|
288 |
("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
|
289 |
])
|
290 |
|
291 |
-
chain = precision_prompt | llm | StrOutputParser()
|
292 |
precision_score = float(chain.invoke({
|
293 |
"query": state['query'],
|
294 |
-
"response":state['response']
|
295 |
}))
|
296 |
state['precision_score'] = precision_score
|
297 |
print("precision_score:", precision_score)
|
298 |
-
state['precision_loop_count'] +=1
|
299 |
print("#########Precision Incremented###########")
|
300 |
return state
|
301 |
|
@@ -324,7 +323,7 @@ def refine_response(state: Dict) -> Dict:
|
|
324 |
"What improvements can be made to enhance accuracy and completeness?")
|
325 |
])
|
326 |
|
327 |
-
chain = refine_response_prompt | llm| StrOutputParser()
|
328 |
|
329 |
# Store response suggestions in a structured format
|
330 |
feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
|
@@ -372,7 +371,7 @@ def should_continue_groundedness(state):
|
|
372 |
"""Decides if groundedness is sufficient or needs improvement."""
|
373 |
print("---------should_continue_groundedness---------")
|
374 |
print("groundedness loop count: ", state['groundedness_loop_count'])
|
375 |
-
if state['groundedness_score'] >= 0.7: #
|
376 |
print("Moving to precision")
|
377 |
return "check_precision"
|
378 |
else:
|
@@ -398,13 +397,9 @@ def should_continue_precision(state: Dict) -> str:
|
|
398 |
|
399 |
|
400 |
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
print("---------max_iterations_reached---------")
|
405 |
-
"""Handles the case when the maximum number of iterations is reached."""
|
406 |
-
response = "I'm unable to refine the response further. Please provide more context or clarify your question."
|
407 |
-
state['response'] = response
|
408 |
return state
|
409 |
|
410 |
|
@@ -413,32 +408,33 @@ from langgraph.graph import END, StateGraph, START
|
|
413 |
|
414 |
def create_workflow() -> StateGraph:
|
415 |
"""Creates the updated workflow for the AI nutrition agent."""
|
416 |
-
workflow = StateGraph(AgentState)
|
417 |
|
418 |
# Add processing nodes
|
419 |
-
workflow.add_node("expand_query", expand_query) # Step 1: Expand user query.
|
420 |
-
workflow.add_node("retrieve_context", retrieve_context)
|
421 |
-
workflow.add_node("craft_response", craft_response)
|
422 |
-
workflow.add_node("
|
423 |
-
workflow.add_node("refine_response", refine_response)
|
424 |
-
workflow.add_node("check_precision", check_precision)
|
425 |
-
workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision.
|
426 |
-
workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations.
|
427 |
|
428 |
# Main flow edges
|
429 |
workflow.add_edge(START, "expand_query")
|
430 |
workflow.add_edge("expand_query", "retrieve_context")
|
431 |
workflow.add_edge("retrieve_context", "craft_response")
|
432 |
-
workflow.add_edge("craft_response", "
|
|
|
433 |
|
434 |
# Conditional edges based on groundedness check
|
435 |
workflow.add_conditional_edges(
|
436 |
-
"
|
437 |
should_continue_groundedness, # Use the conditional function
|
438 |
{
|
439 |
-
"check_precision": check_precision, # If well-grounded, proceed to precision check.
|
440 |
-
"refine_response": refine_response, # If not, refine the response.
|
441 |
-
"max_iterations_reached": max_iterations_reached # If max loops reached, exit.
|
442 |
}
|
443 |
)
|
444 |
|
@@ -456,7 +452,6 @@ def create_workflow() -> StateGraph:
|
|
456 |
)
|
457 |
|
458 |
workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
|
459 |
-
|
460 |
workflow.add_edge("max_iterations_reached", END)
|
461 |
|
462 |
return workflow
|
@@ -479,19 +474,33 @@ def agentic_rag(query: str):
|
|
479 |
"""
|
480 |
# Initialize state with necessary parameters
|
481 |
inputs = {
|
482 |
-
"query": query,
|
483 |
-
"expanded_query": "",
|
484 |
-
"context": [],
|
485 |
-
"response": "",
|
486 |
-
"precision_score": 0.0,
|
487 |
-
"groundedness_score": 0.0,
|
488 |
-
"groundedness_loop_count": 0,
|
489 |
-
"precision_loop_count": 0,
|
490 |
-
"feedback": "",
|
491 |
-
"query_feedback": "",
|
492 |
-
"loop_max_iter": 3 #
|
493 |
}
|
494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
output = WORKFLOW_APP.invoke(inputs)
|
496 |
|
497 |
return output
|
@@ -506,7 +515,7 @@ def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12
|
|
506 |
|
507 |
Parameters:
|
508 |
- user_input: The input provided by the user.
|
509 |
-
- model: The Llama Guard model to be used for filtering (default is "llama-guard-
|
510 |
|
511 |
Returns:
|
512 |
- The filtered and safe input.
|
@@ -538,12 +547,14 @@ class NutritionBot:
|
|
538 |
# Initialize the OpenAI client using the provided credentials
|
539 |
self.client = ChatOpenAI(
|
540 |
model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
temperature=0 # Controls randomness in responses; 0 ensures deterministic results
|
545 |
)
|
546 |
|
|
|
|
|
547 |
# Define tools available to the chatbot, such as web search
|
548 |
tools = [agentic_rag]
|
549 |
|
|
|
230 |
|
231 |
|
232 |
|
233 |
+
def score_groundness(state: Dict) -> Dict:
|
234 |
"""
|
235 |
Checks whether the response is grounded in the retrieved context.
|
236 |
|
|
|
255 |
chain = groundedness_prompt | llm | StrOutputParser()
|
256 |
groundedness_score = float(chain.invoke({
|
257 |
"context": "\n".join([doc["content"] for doc in state['context']]),
|
258 |
+
"response": state['response'] #
|
259 |
}))
|
260 |
print("groundedness_score: ", groundedness_score)
|
261 |
state['groundedness_loop_count'] += 1
|
|
|
265 |
return state
|
266 |
|
267 |
|
|
|
268 |
def check_precision(state: Dict) -> Dict:
|
269 |
"""
|
270 |
Checks whether the response precisely addresses the user’s query.
|
|
|
287 |
("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
|
288 |
])
|
289 |
|
290 |
+
chain = precision_prompt | llm | StrOutputParser()
|
291 |
precision_score = float(chain.invoke({
|
292 |
"query": state['query'],
|
293 |
+
"response": state['response']
|
294 |
}))
|
295 |
state['precision_score'] = precision_score
|
296 |
print("precision_score:", precision_score)
|
297 |
+
state['precision_loop_count'] += 1
|
298 |
print("#########Precision Incremented###########")
|
299 |
return state
|
300 |
|
|
|
323 |
"What improvements can be made to enhance accuracy and completeness?")
|
324 |
])
|
325 |
|
326 |
+
chain = refine_response_prompt | llm | StrOutputParser()
|
327 |
|
328 |
# Store response suggestions in a structured format
|
329 |
feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
|
|
|
371 |
"""Decides if groundedness is sufficient or needs improvement."""
|
372 |
print("---------should_continue_groundedness---------")
|
373 |
print("groundedness loop count: ", state['groundedness_loop_count'])
|
374 |
+
if state['groundedness_score'] >= 0.7: # Threshold for groundedness
|
375 |
print("Moving to precision")
|
376 |
return "check_precision"
|
377 |
else:
|
|
|
397 |
|
398 |
|
399 |
|
400 |
+
def max_iterations_reached(state: AgentState) -> AgentState:
|
401 |
+
"""Handles the case where max iterations are reached."""
|
402 |
+
state['response'] = "We need more context to provide an accurate answer."
|
|
|
|
|
|
|
|
|
403 |
return state
|
404 |
|
405 |
|
|
|
408 |
|
409 |
def create_workflow() -> StateGraph:
|
410 |
"""Creates the updated workflow for the AI nutrition agent."""
|
411 |
+
workflow = StateGraph(AgentState)
|
412 |
|
413 |
# Add processing nodes
|
414 |
+
workflow.add_node("expand_query", expand_query) # Step 1: Expand user query.
|
415 |
+
workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents.
|
416 |
+
workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data.
|
417 |
+
workflow.add_node("score_groundness", score_groundness) # Step 4: Evaluate response grounding.
|
418 |
+
workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded.
|
419 |
+
workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision.
|
420 |
+
workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision.
|
421 |
+
workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations.
|
422 |
|
423 |
# Main flow edges
|
424 |
workflow.add_edge(START, "expand_query")
|
425 |
workflow.add_edge("expand_query", "retrieve_context")
|
426 |
workflow.add_edge("retrieve_context", "craft_response")
|
427 |
+
workflow.add_edge("craft_response", "score_groundness")
|
428 |
+
|
429 |
|
430 |
# Conditional edges based on groundedness check
|
431 |
workflow.add_conditional_edges(
|
432 |
+
"score_groundness",
|
433 |
should_continue_groundedness, # Use the conditional function
|
434 |
{
|
435 |
+
"check_precision": "check_precision", # If well-grounded, proceed to precision check.
|
436 |
+
"refine_response": "refine_response", # If not, refine the response.
|
437 |
+
"max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
|
438 |
}
|
439 |
)
|
440 |
|
|
|
452 |
)
|
453 |
|
454 |
workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
|
|
|
455 |
workflow.add_edge("max_iterations_reached", END)
|
456 |
|
457 |
return workflow
|
|
|
474 |
"""
|
475 |
# Initialize state with necessary parameters
|
476 |
inputs = {
|
477 |
+
"query": query,
|
478 |
+
"expanded_query": "",
|
479 |
+
"context": [],
|
480 |
+
"response": "",
|
481 |
+
"precision_score": 0.0,
|
482 |
+
"groundedness_score": 0.0,
|
483 |
+
"groundedness_loop_count": 0,
|
484 |
+
"precision_loop_count": 0,
|
485 |
+
"feedback": "",
|
486 |
+
"query_feedback": "",
|
487 |
+
"loop_max_iter": 3 # Set a reasonable maximum number of iterations
|
488 |
}
|
489 |
|
490 |
+
# Initialize the Chroma vector store for retrieving documents with the correct collection
|
491 |
+
vector_store = Chroma(
|
492 |
+
collection_name="semantic_chunks", # Use the collection with semantic chunks
|
493 |
+
persist_directory="./research_db",
|
494 |
+
embedding_function=embedding_model
|
495 |
+
)
|
496 |
+
|
497 |
+
# Create a retriever from the vector store
|
498 |
+
retriever = vector_store.as_retriever(
|
499 |
+
search_type='similarity',
|
500 |
+
search_kwargs={'k': 3}
|
501 |
+
)
|
502 |
+
|
503 |
+
|
504 |
output = WORKFLOW_APP.invoke(inputs)
|
505 |
|
506 |
return output
|
|
|
515 |
|
516 |
Parameters:
|
517 |
- user_input: The input provided by the user.
|
518 |
+
- model: The Llama Guard model to be used for filtering (default is "meta-llama/llama-guard-4-12b").
|
519 |
|
520 |
Returns:
|
521 |
- The filtered and safe input.
|
|
|
547 |
# Initialize the OpenAI client using the provided credentials
|
548 |
self.client = ChatOpenAI(
|
549 |
model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
|
550 |
+
openai_api_key=config.get("API_KEY"), # API key for authentication
|
551 |
+
#openai_endpoint = config.get("OPENAI_API_BASE"),
|
552 |
+
openai_api_base = config.get("OPENAI_API_BASE"),
|
553 |
temperature=0 # Controls randomness in responses; 0 ensures deterministic results
|
554 |
)
|
555 |
|
556 |
+
|
557 |
+
|
558 |
# Define tools available to the chatbot, such as web search
|
559 |
tools = [agentic_rag]
|
560 |
|