# --- Standard library ---
import json
import os
import random
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict

# --- Third-party clients and utilities ---
import chromadb
import numpy as np
import streamlit as st
from dotenv import load_dotenv
from groq import Groq
from mem0 import MemoryClient
from pydantic import BaseModel

# --- LangChain ---
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# --- LlamaIndex / LlamaParse ---
from llama_parse import LlamaParse
from llama_index.core import Settings, SimpleDirectoryReader

# --- LangGraph ---
from langgraph.graph import StateGraph, START, END

# Load variables from a local .env file (if present) before reading them.
load_dotenv()

api_key = os.environ['api_key']
endpoint = os.environ['OPENAI_API_BASE']
model_name = os.environ['CHATGPT_MODEL']
emb_key = os.environ['EMB_MODEL_KEY']
emb_endpoint = os.environ['EMB_DEPLOYMENT']
llama_api_key = os.environ['LLAMA_GUARD_API_KEY']
mem0_api_key = os.environ['mem0_api_key']

# Chroma-native embedding function (useful when working with the raw chromadb client).
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base=endpoint,
    api_key=api_key,
    model_name='text-embedding-ada-002'
)

# LangChain embeddings used by the Chroma vector store below.
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002'
)
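
# Optional sanity check (a sketch, not part of the pipeline): ada-002 vectors are
# 1536-dimensional, so a quick probe catches endpoint/key misconfiguration early.
#
#   vec = embedding_model.embed_query("test")
#   assert len(vec) == 1536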

# Chat model shared by every graph node and the support bot.
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False
)

# llama_index global defaults. The attribute is `embed_model` (not `embedding`);
# note that llama_index normally expects its own model wrappers here, so this
# wiring relies on llama_index resolving the LangChain objects.
Settings.llm = llm
Settings.embed_model = embedding_model

class AgentState(TypedDict):
    query: str                      # original user question
    expanded_query: str             # retrieval-optimized rewrite of the query
    context: List[Dict[str, Any]]   # retrieved chunks with their metadata
    response: str                   # current draft answer
    precision_score: float          # how directly the answer addresses the query
    groundedness_score: float       # how well the answer is supported by context
    groundedness_loop_count: int    # groundedness refinement iterations so far
    precision_loop_count: int       # precision refinement iterations so far
    feedback: str                   # critique used to refine the response
    query_feedback: str             # critique used to refine the expanded query
    groundedness_check: bool        # flag reserved for groundedness gating
    loop_max_iter: int              # cap on each refinement loop
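
# This state drives a self-corrective RAG loop:
#   expand_query -> retrieve_context -> craft_response -> score_groundedness,
# then either refine_response (and re-draft) or check_precision, which in turn
# ends, refines the query, or bails out via max_iterations_reached.
# See agentic_rag() below for how an initial state is seeded.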

def expand_query(state):
    """
    Expands the user query to improve retrieval of nutrition disorder-related information.

    Args:
        state (Dict): The current state of the workflow, containing the user query.

    Returns:
        Dict: The updated state with the expanded query.
    """
    print("State at the start of expand_query:", state)
    print("---------Expanding Query---------")
    system_message = '''You are a helpful research assistant that is well versed in Nutritional Disorders.
Return an expanded user query based on the user's input query. The expanded query should be designed to improve retrieval of the most relevant information.
Use the feedback if provided to craft the expanded query.
'''

    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    chain = expand_prompt | llm | StrOutputParser()
    expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state

# Vector store persisted on disk. This assumes ./nutritional_db was built
# beforehand; a hypothetical ingestion sketch follows below.
vector_store = Chroma(
    collection_name="nutritional_hypotheticals",
    persist_directory="./nutritional_db",
    embedding_function=embedding_model
)

# Plain top-k similarity search over the hypotheticals collection.
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 3}
)
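
# Hypothetical ingestion sketch (the PDF directory and chunk sizes are placeholders,
# not part of this app) using the loaders/splitters imported above:
#
#   raw_docs = PyPDFDirectoryLoader("./nutrition_pdfs").load()
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=1000, chunk_overlap=100
#   ).split_documents(raw_docs)
#   Chroma.from_documents(
#       chunks, embedding_model,
#       collection_name="nutritional_hypotheticals",
#       persist_directory="./nutritional_db",
#   )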

def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with the retrieved context.
    """
    print("State at the start of retrieve_context:", state)
    query = state['expanded_query']
    print("Query used for retrieval:", query)

    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)

    # Keep both the text and its metadata so downstream nodes can cite sources.
    state['context'] = [
        {
            "content": doc.page_content,
            "metadata": doc.metadata
        }
        for doc in docs
    ]

    print("Extracted context with metadata:", state['context'])
    return state

def craft_response(state: Dict) -> Dict:
    """
    Generates a response using the retrieved context, focusing on nutrition disorders.

    Args:
        state (Dict): The current state of the workflow, containing the query and retrieved context.

    Returns:
        Dict: The updated state with the generated response.
    """
    print("State at the start of craft_response:", state)
    print("---------craft_response---------")
    system_message = '''You are an expert at condensing information. Your task is to extract relevant information for a given query and provide a grounded and highly precise response.'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])

    # Parse to a plain string so state['response'] is text rather than an AIMessage.
    chain = response_prompt | llm | StrOutputParser()
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "feedback": state["feedback"] if state["feedback"] else "No feedback provided."
    })
    state['response'] = response
    print("intermediate response: ", response)

    return state

def score_groundedness(state: Dict) -> Dict:
    """
    Checks whether the response is grounded in the retrieved context.

    Args:
        state (Dict): The current state of the workflow, containing the response and context.

    Returns:
        Dict: The updated state with the groundedness score.
    """
    print("State at the start of score_groundedness:", state)
    print("---------check_groundedness---------")

    system_message = '''You are a groundedness evaluator. Your task is to assess how well the given response aligns with the provided context.
- A grounded response is one that is accurate, directly supported by the context, and avoids speculation.
- A response should not include information that cannot be verified or inferred from the context.

Instructions:
- Assign a score between 0.0 and 1.0, where:
  - 1.0: Fully grounded (entirely supported by the context).
  - 0.5: Partially grounded (some elements are supported, but others are speculative).
  - 0.0: Not grounded (contains speculative or unsupported information).
- Provide only the numerical groundedness score as the output.'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    # The prompt asks for a bare number; strip whitespace before parsing.
    groundedness_score = float(chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    }).strip())

    state['groundedness_score'] = groundedness_score
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("######### Groundedness Loop Count Incremented ###########")

    return state

def check_precision(state: Dict) -> Dict:
    """
    Checks whether the response precisely addresses the user's query.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with the precision score.
    """
    print("State at the start of check_precision:", state)
    print("---------check_precision---------")

    system_message = '''You are a precision evaluator. Your task is to assess how well the given response directly and fully addresses the user's query.

Instructions:
- A precise response is one that:
  - Directly answers the user's query without unnecessary or unrelated information.
  - Fully addresses all aspects of the query.
  - Avoids vague or overly general statements.
- Assign a precision score between 0.0 and 1.0:
  - 1.0: Fully precise (direct, complete, and relevant to the query).
  - 0.5: Partially precise (addresses the query but is incomplete or includes some irrelevant information).
  - 0.0: Not precise (fails to address the query or contains mostly irrelevant information).
- Provide only the numerical precision score as the output.'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    # As above, parse the model's bare numeric output.
    precision_score = float(chain.invoke({
        "query": state['query'],
        "response": state['response']
    }).strip())

    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state

def refine_response(state: Dict) -> Dict:
    """
    Suggests improvements for the generated response.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with response refinement suggestions.
    """
    print("State at the start of refine_response:", state)
    print("---------refine_response---------")

    system_message = '''You are a constructive feedback evaluator. Your task is to analyze the provided response and identify potential gaps, ambiguities, or missing details. Your feedback should help improve the response for accuracy, clarity, and completeness.

Instructions:
- Do not rewrite the response.
- Focus on identifying the following:
  - Are there any gaps in the information provided?
  - Is the response ambiguous or unclear in any part?
  - Are there any details missing that are relevant to fully addressing the context or query?
- Provide actionable and constructive suggestions for improvement.
- Avoid criticism without offering specific recommendations.

Your output should be written as a list of feedback points, with each suggestion clearly and concisely stated.'''

    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    chain = refine_response_prompt | llm | StrOutputParser()

    # Carry the previous draft along with the critique so craft_response can use both.
    feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state

def refine_query(state: Dict) -> Dict:
    """
    Suggests improvements for the expanded query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with query refinement suggestions.
    """
    print("State at the start of refine_query:", state)
    print("---------refine_query---------")

    system_message = '''You are a query refinement assistant. Your task is to analyze the original query and its expanded version to suggest specific improvements that can enhance search precision and relevance.

Instructions:
- Do not replace or rewrite the expanded query. Instead, provide structured suggestions for improvement.
- Focus on identifying:
  - Missing details or specific keywords that could make the query more precise.
  - Scope refinements to narrow or broaden the query if needed.
  - Ambiguities or redundancies that can be clarified or removed.
- Ensure your suggestions are actionable and presented in a clear, concise, and structured format.
- Avoid general or vague feedback; provide specific recommendations.

Your output should be a list of suggestions that can improve the expanded query without modifying it directly.'''

    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    chain = refine_query_prompt | llm | StrOutputParser()

    # Carry the previous expansion along with the critique so expand_query can use both.
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state

def should_continue_groundedness(state):
    """Decides if groundedness is sufficient or needs improvement."""
    print("State at the start of should_continue_groundedness:", state)
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])

    if state['groundedness_score'] >= 0.8:
        print("Moving to precision")
        return "check_precision"
    elif state['groundedness_loop_count'] >= state['loop_max_iter']:
        return "max_iterations_reached"
    else:
        print("---------Groundedness Score Threshold Not Met. Refining Response-----------")
        return "refine_response"

def should_continue_precision(state: Dict) -> str:
    """Decides if precision is sufficient or needs improvement."""
    print("State at the start of should_continue_precision:", state)
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])

    if state['precision_score'] >= 0.8:
        return "pass"
    elif state['precision_loop_count'] >= state['loop_max_iter']:
        return "max_iterations_reached"
    else:
        print("---------Precision Score Threshold Not Met. Refining Query-----------")
        return "refine_query"

def max_iterations_reached(state: Dict) -> Dict:
    """Handles the case when the maximum number of iterations is reached."""
    print("---------max_iterations_reached---------")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = response
    return state

def create_workflow() -> StateGraph:
    """Creates the self-corrective RAG workflow for the AI nutrition agent."""
    workflow = StateGraph(state_schema=AgentState)

    # Nodes
    workflow.add_node("expand_query", expand_query)
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("craft_response", craft_response)
    workflow.add_node("score_groundedness", score_groundedness)
    workflow.add_node("refine_response", refine_response)
    workflow.add_node("check_precision", check_precision)
    workflow.add_node("refine_query", refine_query)
    workflow.add_node("max_iterations_reached", max_iterations_reached)

    # Happy path
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Groundedness gate: continue to precision, refine the response, or bail out.
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached"
        }
    )

    workflow.add_edge("refine_response", "craft_response")

    # Precision gate: finish, refine the query, or bail out.
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached"
        }
    )

    workflow.add_edge("refine_query", "expand_query")
    workflow.add_edge("max_iterations_reached", END)

    return workflow

# Compile the graph once at import time; all nodes close over the shared llm/retriever.
WORKFLOW_APP = create_workflow().compile()
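
# Minimal smoke test for the compiled graph (a sketch; assumes ./nutritional_db is
# populated and the endpoints above are reachable). agentic_rag below wraps this
# seeding boilerplate; a direct call looks like:
#
#   final_state = WORKFLOW_APP.invoke({
#       "query": "What are early signs of scurvy?",
#       "expanded_query": "", "context": [], "response": "",
#       "precision_score": 0.0, "groundedness_score": 0.0,
#       "groundedness_loop_count": 0, "precision_loop_count": 0,
#       "feedback": "", "query_feedback": "",
#       "groundedness_check": False, "loop_max_iter": 3,
#   })
#   print(final_state["response"])
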
@tool
def agentic_rag(query: str):
    """
    Runs the self-corrective RAG workflow for a single user query.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The final workflow state, including the generated response.
    """
    # Seed the AgentState with defaults; loop_max_iter caps both refinement loops.
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "groundedness_check": False,
        "loop_max_iter": 3
    }

    output = WORKFLOW_APP.invoke(inputs)
    return output
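
# The @tool wrapper also lets a tool-calling agent (see create_tool_calling_agent
# imported above) invoke this workflow; called directly it behaves like:
#
#   result = agentic_rag.invoke({"query": "What causes iron-deficiency anemia?"})
#   print(result["response"])
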
# Groq client used solely for Llama Guard content moderation.
llama_guard_client = Groq(api_key=llama_api_key)

def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
    """
    Classifies user input with Llama Guard so unsafe queries can be screened out.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").

    Returns:
    - Llama Guard's verdict string (e.g. "safe", or "unsafe" followed by a
      category code), or None if the call fails.
    """
    try:
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error with Llama Guard: {e}")
        return None
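
# Illustrative verdicts: "safe", or "unsafe\nS6" (S6 is Llama Guard 3's
# specialized-advice category, which medical/nutrition questions can trigger).
# The Streamlit handler below normalizes the newline to a space before matching,
# which is why it compares against strings like "unsafe S6".
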
class NutritionBot:
    def __init__(self, api_key: str, api_base: str):
        """
        Initialize the NutritionBot class, setting up memory and the LLM client.

        Args:
            api_key (str): The OpenAI API key for authenticating requests.
            api_base (str): The custom OpenAI API base endpoint.
        """
        print(f"Using custom OpenAI API base: {api_base}")

        # Mem0 client for long-term conversation memory. The key comes from the
        # environment variable read at the top of this script.
        self.memory = MemoryClient(api_key=mem0_api_key)
        print("Memory client initialized.")

        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",
            api_key=api_key,
            openai_api_base=api_base,
            temperature=0.7,
            verbose=True
        )
        print("OpenAI client initialized with custom API base and model gpt-4o-mini.")

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions.
        """
        print("Entering get_relevant_history function...")
        try:
            history = self.memory.search(
                query=query,
                user_id=user_id,
                limit=3
            )
            print("Relevant history retrieved:", history)
            return history
        except Exception as e:
            print(f"Error retrieving history: {e}")
            traceback.print_exc()
            return []

    def query_model(self, prompt: str) -> str:
        """
        Query the OpenAI model directly using the prompt.

        Args:
            prompt (str): The input prompt for the model.

        Returns:
            str: The assistant's response.
        """
        print("Querying the OpenAI model...")
        try:
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]
            response = self.client.invoke(messages)
            print("Raw response from OpenAI API:", response)

            content = response.content
            print("Extracted response content:", content)
            return content
        except Exception as e:
            print(f"Error querying the model: {e}")
            traceback.print_exc()
            return "I'm sorry, I couldn't process your request. Please try again later."

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, incorporating past interactions for context.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """
        print("Entering handle_customer_query function...")

        relevant_history = self.get_relevant_history(user_id, query)

        # Mem0 search results expose the stored text under the "memory" key,
        # so build the context from that rather than assuming query/response fields.
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"- {memory.get('memory', '')}\n"
        context += "---\n"

        prompt = f"""
        Context:
        {context}

        Current customer query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """
        print("Final prompt being sent to the model:")
        print(prompt)

        # Retry with exponential backoff plus jitter to ride out transient API errors.
        max_retries = 3
        for attempt in range(max_retries):
            try:
                print(f"Querying model (attempt {attempt + 1})...")
                response_content = self.query_model(prompt)
                if not response_content:
                    raise ValueError("Model returned an empty response.")

                self.store_customer_interaction(
                    user_id=user_id,
                    message=query,
                    response=response_content,
                    metadata={"type": "support_query"}
                )
                return response_content

            except Exception as e:
                print(f"Error querying the model (attempt {attempt + 1}): {e}")
                traceback.print_exc()
                if attempt < max_retries - 1:
                    wait_time = (2 ** attempt) + random.uniform(0, 1)
                    print(f"Retrying in {wait_time:.2f} seconds...")
                    time.sleep(wait_time)
                else:
                    return "I'm sorry, I couldn't process your request. Please try again later."

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        print("Entering store_customer_interaction function...")

        if metadata is None:
            metadata = {}

        # Timestamp the interaction so memories can be ordered later.
        metadata["timestamp"] = datetime.now().isoformat()

        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        try:
            self.memory.add(
                conversation,
                user_id=user_id,
                output_format="v1.1",
                metadata=metadata
            )
            print("Interaction stored successfully.")
        except Exception as e:
            print(f"Error storing interaction: {e}")
            traceback.print_exc()
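
# Minimal usage sketch for the class above (assumes the env vars at the top are set):
#
#   bot = NutritionBot(api_key=api_key, api_base=endpoint)
#   print(bot.handle_customer_query("alice", "Which foods help with iron-deficiency anemia?"))
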
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Session state: chat transcript and the logged-in user.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    if st.session_state.user_id is None:
        # Simple login gate: ask for a name before starting the chat.
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                })
                st.session_state.login_submitted = True
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        # Replay the transcript so far.
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # st.chat_input returns None until the user submits something.
        user_query = st.chat_input("You: ")
        if user_query:
            user_query = user_query.strip()
            if user_query.lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Screen the input with Llama Guard; a failed call yields None.
            filtered_result = filter_input_with_llama_guard(user_query)
            filtered_result = (filtered_result or "").replace("\n", " ")

            # Allow safe inputs plus the S6 (specialized advice) and S7 (privacy)
            # verdicts, which legitimate nutrition questions can trigger.
            if filtered_result in ["safe", "unsafe S7", "unsafe S6"]:
                try:
                    if 'chatbot' not in st.session_state:
                        st.session_state.chatbot = NutritionBot(api_key=api_key, api_base=endpoint)
                    response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)

                    with st.chat_message("assistant"):
                        st.write(response)
                    st.session_state.chat_history.append({"role": "assistant", "content": response})
                except Exception as e:
                    error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
                    with st.chat_message("assistant"):
                        st.write(error_msg)
                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
                with st.chat_message("assistant"):
                    st.write(inappropriate_msg)
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})

if __name__ == "__main__":
    nutrition_disorder_streamlit()