# Source snapshot metadata (web-page residue converted to a comment):
# file size 5,614 bytes, commit 7c923c2.
import streamlit as st
import os
from langgraph.graph import MessagesState, StateGraph, START, END
from typing_extensions import TypedDict
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
from typing import Annotated
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langchain_groq import ChatGroq
# Define the state.
# Graph state schema (passed to StateGraph below): a single "messages" list
# whose node updates are combined by the add_messages reducer named in the
# Annotated metadata, rather than overwriting the stored list.
# NOTE(review): this redefines (shadows) the MessagesState imported from
# langgraph.graph at the top of the file; the two are presumably equivalent —
# verify and keep only one.
class MessagesState(TypedDict):
    """State shared between graph nodes: the running conversation."""

    # Conversation turns; updates are merged via add_messages.
    messages: Annotated[list[AnyMessage], add_messages]
# Graph factory
def create_chat_graph(system_prompt, model_name, api_key=None):
    """Build a single-node LangGraph chat pipeline backed by a Groq model.

    The graph runs START -> "assistant" -> END. On every turn the assistant
    node prepends ``system_prompt`` to the conversation held in the state and
    asks the LLM for one reply. The graph is compiled with a ``MemorySaver``
    checkpointer, so callers that pass the same
    ``{"configurable": {"thread_id": ...}}`` config to each ``invoke`` get the
    earlier turns of that thread back in ``state["messages"]``.

    Args:
        system_prompt: Text of the system message sent ahead of every request.
        model_name: Groq model identifier passed to ``ChatGroq(model=...)``.
        api_key: Optional Groq API key. When omitted or empty (the default),
            ``ChatGroq`` is constructed exactly as ``ChatGroq(model=model_name)``
            and resolves the key itself (this script relies on the
            ``GROQ_API_KEY`` environment variable for that). Passing it lets a
            caller use a per-session key without writing it into the
            process-wide environment.

    Returns:
        The compiled graph. Each call creates its own fresh ``MemorySaver``,
        so a new graph starts with no remembered conversation.
    """
    # Forward api_key only when supplied, so the default call is unchanged.
    llm_kwargs = {"model": model_name}
    if api_key:
        llm_kwargs["api_key"] = api_key
    llm = ChatGroq(**llm_kwargs)

    # Built once and reused on every turn; it is never written to the state.
    system_message = SystemMessage(content=system_prompt)

    def assistant(state: MessagesState):
        """Graph node: answer the conversation so far with a single LLM reply."""
        # Prepended per call rather than stored, so the checkpointed history
        # contains only the user/assistant turns.
        messages = [system_message] + state["messages"]
        response = llm.invoke(messages)
        # Merged into the stored list by the add_messages reducer declared on
        # MessagesState.messages.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_edge(START, "assistant")
    builder.add_edge("assistant", END)

    # In-memory checkpointer owned by this graph; conversations are keyed by
    # the caller-supplied thread_id.
    return builder.compile(checkpointer=MemorySaver())
# Page-level setup; set_page_config is kept as the first st.* call executed.
st.set_page_config(
    page_title="Conversational AI Assistant",
    page_icon="💬",
)
st.title("AI Chatbot with Memory")

# Everything placed in the sidebar from here on is user-tunable configuration.
st.sidebar.header("Configuration")
# API key input. The key is only requested when GROQ_API_KEY is not already
# set in the environment (st.secrets is the suggested source in production).
# NOTE(review): os.environ belongs to the whole server process. Streamlit
# serves browser sessions from that process, so a key typed here presumably
# becomes the key for every other visitor as well. Also, once set, this input
# never reappears, so a mistyped key cannot be corrected. Keeping the key in
# st.session_state and handing it to ChatGroq explicitly would avoid both — TODO.
if "GROQ_API_KEY" not in os.environ:
    api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")
    # Pasted keys often carry a stray space/newline that would make every
    # request fail authentication; whitespace-only input counts as empty.
    api_key = (api_key or "").strip()
    if api_key:
        os.environ["GROQ_API_KEY"] = api_key
    else:
        st.sidebar.warning("Please enter your Groq API key to continue.")
# Model picker.
# NOTE(review): Groq retires model IDs over time — confirm these are still served.
model_options = [
    "llama3-70b-8192",
    "mixtral-8x7b-32768",
    "gemma-7b-it",
]
selected_model = st.sidebar.selectbox("Select Model:", model_options)

# Editable system prompt, pre-filled with default_prompt.
default_prompt = (
    "You are a helpful and friendly assistant. Maintain a conversational tone "
    "and remember previous interactions with the user."
)
system_prompt = st.sidebar.text_area(
    "System Prompt:",
    value=default_prompt,
    height=150,
)
# Per-browser-session values that must survive Streamlit reruns.
# session_id doubles as the LangGraph thread_id, letting the graph's
# checkpointer find this conversation's earlier turns.
if "session_id" not in st.session_state:
    import uuid  # only needed the first time a session is seen

    st.session_state["session_id"] = str(uuid.uuid4())

# Display transcript: the chat code appends {"role": ..., "content": ...}
# dicts here (the history loop also tolerates plain strings).
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Build (or rebuild) the chat graph on first run, when "Reset Conversation" is
# clicked, or when the model / system prompt changes — both are baked into the
# graph when it is created, so editing them otherwise has no effect.
# The button is evaluated unconditionally so it is always rendered; placing it
# on the right of an `or` skipped it whenever no graph existed yet.
reset_clicked = st.sidebar.button("Reset Conversation")
graph_config = (selected_model, system_prompt)
needs_build = (
    reset_clicked
    or "chat_graph" not in st.session_state
    or st.session_state.get("chat_graph_config") != graph_config
)
if needs_build:
    if "GROQ_API_KEY" in os.environ:
        with st.spinner("Initializing chatbot..."):
            st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
            st.session_state.chat_graph_config = graph_config
            # A new graph comes with a new, empty checkpointer, so clear the
            # on-screen transcript to match what the model now remembers.
            st.session_state.messages = []
            st.success("Chatbot initialized!")
    else:
        st.sidebar.error("API key required to initialize chatbot.")
# Replay the stored transcript so it stays visible across Streamlit reruns.
for message in st.session_state.messages:
    if isinstance(message, dict):
        # The format this script appends: {"role": ..., "content": ...}.
        role = message.get("role", "")
        content = message.get("content", "")
    else:
        # Plain-string entries of the form "User: ..." / "Assistant: ...".
        # Only the leading tag is removed; str.replace would also have deleted
        # those tags wherever they appeared inside the message text.
        is_user = message.startswith("User: ")
        role = "user" if is_user else "assistant"
        content = message.removeprefix("User: " if is_user else "Assistant: ")
    with st.chat_message(role):
        st.write(content)
# Chat input is only offered once a graph exists and an API key is available.
if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
    user_input = st.chat_input("Type your message here...")
    if user_input:
        # Echo the user's turn right away and record it for later reruns.
        with st.chat_message("user"):
            st.write(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        # Only the new message is sent: the graph's checkpointer, keyed by this
        # session's thread_id, supplies the earlier turns.
        config = {"configurable": {"thread_id": st.session_state.session_id}}
        user_message = [HumanMessage(content=user_input)]
        try:
            with st.spinner("Thinking..."):
                result = st.session_state.chat_graph.invoke({"messages": user_message}, config)
        except Exception as exc:  # UI boundary: report the failure instead of crashing the page
            st.error(f"Sorry, the model call failed: {exc}")
        else:
            # The assistant node returns a single reply, which ends the list.
            response = result["messages"][-1].content
            with st.chat_message("assistant"):
                st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
# Static "About" section at the bottom of the sidebar.
about_text = """
This chatbot uses LangGraph for maintaining conversation context and
ChatGroq's language models for generating responses. Each conversation
has a unique session ID to maintain history.
"""
st.sidebar.markdown("---")
st.sidebar.subheader("About")
st.sidebar.info(about_text)
# Offer the transcript as a plain-text download. The download button is
# rendered directly: nested under st.sidebar.button it only existed on the one
# rerun after that click, so exporting took two clicks and the button vanished
# again on the next interaction.
from datetime import datetime


def _format_chat_line(message):
    """Render one stored transcript entry as a "role: content" line.

    Dict entries use their "role"/"content" keys; any other entry (such as the
    plain strings the history display also accepts) is written as-is instead of
    raising KeyError.
    """
    if isinstance(message, dict):
        return f"{message.get('role', '')}: {message.get('content', '')}"
    return str(message)


chat_export = "\n".join(_format_chat_line(m) for m in st.session_state.messages)
st.sidebar.download_button(
    label="Download Chat History",
    data=chat_export,
    file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    mime="text/plain",
    # Nothing to export until at least one message exists.
    disabled=not st.session_state.messages,
)