import os
from typing import Annotated

import streamlit as st
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict


# Graph state: the running conversation. The ``add_messages`` annotation makes
# LangGraph merge messages returned by a node into this list instead of
# overwriting it.
# NOTE: this definition shadows the ``MessagesState`` imported from
# langgraph.graph above; this local class is the one used below.
class MessagesState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]


def create_chat_graph(system_prompt, model_name, api_key=None):
    """Build and compile a single-node LangGraph chat graph backed by Groq.

    Args:
        system_prompt: Text sent as a SystemMessage ahead of the stored
            history on every LLM call. It is never written into graph state.
        model_name: Groq model identifier passed to ``ChatGroq``.
        api_key: Optional explicit Groq API key. When None (the default) it is
            not passed, so ``ChatGroq`` resolves the key as it did before
            (from the ``GROQ_API_KEY`` environment variable).

    Returns:
        The compiled graph, checkpointed by a fresh ``MemorySaver`` created in
        this call, so conversation history lives only as long as this object.
    """
    llm_kwargs = {"model": model_name}
    if api_key is not None:
        llm_kwargs["api_key"] = api_key
    llm = ChatGroq(**llm_kwargs)
    system_message = SystemMessage(content=system_prompt)

    def assistant(state: MessagesState):
        """Graph node: send the system prompt plus history to the LLM, emit its reply."""
        # The system prompt is prepended per call rather than stored in state,
        # so it is never duplicated inside the checkpointed history.
        response = llm.invoke([system_message] + state["messages"])
        # One-element list; the add_messages reducer merges it into history.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_edge(START, "assistant")
    builder.add_edge("assistant", END)
    # In-process checkpointer: callers pick a conversation through
    # config["configurable"]["thread_id"] when invoking the graph.
    return builder.compile(checkpointer=MemorySaver())


# Page setup (set_page_config must be the first Streamlit command).
st.set_page_config(page_title="Conversational AI Assistant", page_icon="💬")
st.title("AI Chatbot with Memory")

# Sidebar configuration
st.sidebar.header("Configuration")

# API key: ask for one only when the environment does not already provide it.
# NOTE(review): writing the key into os.environ makes it process-wide, so on a
# shared Streamlit server later sessions would silently reuse the first
# user's key. Keeping it in st.session_state and passing it through
# create_chat_graph(api_key=...) would avoid that; the rest of this script
# currently checks os.environ for the key.
if "GROQ_API_KEY" not in os.environ:
    # Strip so pasted keys with stray whitespace/newlines still authenticate,
    # and whitespace-only input is treated as "no key".
    api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password").strip()
    if api_key:
        os.environ["GROQ_API_KEY"] = api_key
    else:
        st.sidebar.warning("Please enter your Groq API key to continue.")

# Model selection
# NOTE(review): Groq retires models periodically; confirm these IDs are still
# listed in Groq's models documentation before relying on them.
model_options = ["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
selected_model = st.sidebar.selectbox("Select Model:", model_options)
import uuid
from datetime import datetime

# System prompt
default_prompt = "You are a helpful and friendly assistant. Maintain a conversational tone and remember previous interactions with the user."
system_prompt = st.sidebar.text_area("System Prompt:", value=default_prompt, height=150)

# One conversation thread per browser session; used as the LangGraph thread_id.
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())

# Visible chat transcript, kept alongside the graph's checkpointed state.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Build (or rebuild) the graph on:
#   * the first run of this session (no graph yet),
#   * a click on "Reset Conversation",
#   * a change of system prompt or model. create_chat_graph captures both when
#     the graph is built, so without a rebuild sidebar edits would be ignored.
# The button is evaluated up front: placed after the "no graph yet" test in an
# ``or`` it was short-circuited and never rendered on those runs.
reset_requested = st.sidebar.button("Reset Conversation")
current_config = (system_prompt, selected_model)
config_changed = st.session_state.get("chat_config") != current_config

if reset_requested or config_changed or "chat_graph" not in st.session_state:
    if "GROQ_API_KEY" in os.environ:
        with st.spinner("Initializing chatbot..."):
            st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
        st.session_state.chat_config = current_config
        # A new graph comes with a new, empty checkpointer, so clear the visible
        # transcript too; otherwise it would show turns the model cannot see.
        st.session_state.messages = []
        st.success("Chatbot initialized!")
    else:
        st.sidebar.error("API key required to initialize chatbot.")

# Display chat history
for message in st.session_state.messages:
    if isinstance(message, dict):
        role = message.get("role", "")
        content = message.get("content", "")
    else:
        # Legacy "User: ..." / "Assistant: ..." strings: strip only the leading
        # speaker tag (str.replace would also remove it mid-message).
        text = str(message)
        if text.startswith("User: "):
            role, content = "user", text.removeprefix("User: ")
        else:
            role, content = "assistant", text.removeprefix("Assistant: ")
    with st.chat_message(role):
        st.write(content)

# Input for new message
if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
    user_input = st.chat_input("Type your message here...")
    if user_input:
        with st.chat_message("user"):
            st.write(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        # The graph keeps the history itself; only the new turn is sent.
        config = {"configurable": {"thread_id": st.session_state.session_id}}
        try:
            with st.spinner("Thinking..."):
                result = st.session_state.chat_graph.invoke(
                    {"messages": [HumanMessage(content=user_input)]}, config
                )
        except Exception as exc:
            # UI boundary: network/auth/model errors from the LLM call are shown
            # to the user instead of crashing the page with a traceback.
            st.error(f"Failed to get a response from the model: {exc}")
        else:
            response = result["messages"][-1].content
            with st.chat_message("assistant"):
                st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})

# Add some additional info in the sidebar
st.sidebar.markdown("---")
st.sidebar.subheader("About")
st.sidebar.info(
    """
    This chatbot uses LangGraph for maintaining conversation context and 
    ChatGroq's language models for generating responses.
    Each conversation has a unique session ID to maintain history.
    """
)

# Download chat history as plain text. Rendered directly: nested under
# st.button it only appeared on the single rerun after that button's click.
chat_export = "\n".join(
    f"{m.get('role', '')}: {m.get('content', '')}" if isinstance(m, dict) else str(m)
    for m in st.session_state.messages
)
st.sidebar.download_button(
    label="Download Chat History",
    data=chat_export,
    file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    mime="text/plain",
    disabled=not st.session_state.messages,
)