# app.py — Streamlit chat UI backed by a LangGraph + Groq chatbot.
# Originally committed by Phoenix21 ("Create app.py", commit 7c923c2).
import streamlit as st
import os
from langgraph.graph import MessagesState, StateGraph, START, END
from typing_extensions import TypedDict
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
from typing import Annotated
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langchain_groq import ChatGroq
# Graph state schema: one "messages" channel shared by every node.
# NOTE(review): this class shadows the ``MessagesState`` imported from
# ``langgraph.graph`` at the top of the file; because it is defined later,
# this local definition is the one every following reference resolves to.
class MessagesState(TypedDict):
    """State carried through the chat graph.

    ``messages`` is annotated with the ``add_messages`` reducer, so message
    updates returned by a node are merged into the existing list rather
    than replacing it.
    """

    messages: Annotated[list[AnyMessage], add_messages]
# Factory for the compiled, checkpointed chat graph.
def create_chat_graph(system_prompt, model_name):
    """Build a single-node LangGraph chat pipeline backed by a Groq model.

    The graph runs START -> "assistant" -> END. The "assistant" node sends a
    fixed system message (built once from ``system_prompt``) followed by the
    conversation held in the graph state, and returns the model's reply as a
    state update.

    Args:
        system_prompt: Text of the SystemMessage sent on every turn.
        model_name: Groq model identifier passed to ``ChatGroq(model=...)``.

    Returns:
        The compiled graph, using an in-memory ``MemorySaver`` checkpointer,
        so calls that share a ``thread_id`` continue the same conversation.
    """
    chat_model = ChatGroq(model=model_name)
    preamble = SystemMessage(content=system_prompt)

    def call_model(state: MessagesState):
        # The system message is prepended per call and is not part of the
        # returned update, so it never lands in the checkpointed history.
        prompt = [preamble] + state["messages"]
        reply = chat_model.invoke(prompt)
        return {"messages": [reply]}

    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", call_model)
    for source, target in ((START, "assistant"), ("assistant", END)):
        graph_builder.add_edge(source, target)

    return graph_builder.compile(checkpointer=MemorySaver())
# --- Page setup --------------------------------------------------------------
# The icon literal had been mis-encoded; this is the intended speech-balloon
# emoji (U+1F4AC).
st.set_page_config(page_title="Conversational AI Assistant", page_icon="💬")
st.title("AI Chatbot with Memory")

# --- Sidebar configuration ---------------------------------------------------
st.sidebar.header("Configuration")

# API key input. In production, prefer st.secrets or a pre-set environment
# variable so this prompt is never shown.
# NOTE(review): os.environ is process-wide and shared by every session this
# Streamlit server handles, so a key typed here by one visitor becomes the
# key used by all later visitors (and hides their prompt). Consider keeping
# the key in st.session_state and passing it to ChatGroq explicitly.
if "GROQ_API_KEY" not in os.environ:
    api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")
    if api_key:
        os.environ["GROQ_API_KEY"] = api_key
    else:
        st.sidebar.warning("Please enter your Groq API key to continue.")

# Model selection.
# NOTE(review): Groq retires models over time; confirm these identifiers are
# still served before relying on them.
model_options = ["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
selected_model = st.sidebar.selectbox("Select Model:", model_options)

# System prompt used whenever the chat graph is (re)built.
default_prompt = "You are a helpful and friendly assistant. Maintain a conversational tone and remember previous interactions with the user."
system_prompt = st.sidebar.text_area("System Prompt:", value=default_prompt, height=150)
# --- Per-session state -------------------------------------------------------
# One id per Streamlit session; it is used as the LangGraph thread_id.
if "session_id" not in st.session_state:
    import uuid

    st.session_state.session_id = str(uuid.uuid4())

# Transcript rendered in the main pane (entries are {"role", "content"} dicts).
if "messages" not in st.session_state:
    st.session_state.messages = []

# Build the chat graph on first run, or rebuild it when the user resets.
# The button is rendered before the check on purpose: placed on the right of
# an `or`, short-circuiting would hide it whenever no graph exists yet.
# NOTE: changes to the model or system prompt only take effect on the next
# reset; otherwise the existing graph is reused.
reset_requested = st.sidebar.button("Reset Conversation")
if reset_requested or "chat_graph" not in st.session_state:
    if "GROQ_API_KEY" in os.environ:
        with st.spinner("Initializing chatbot..."):
            st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
            st.session_state.messages = []  # Clear messages on reset
        st.success("Chatbot initialized!")
    else:
        st.sidebar.error("API key required to initialize chatbot.")
# --- Transcript --------------------------------------------------------------
# Re-render the whole conversation on every Streamlit rerun. Entries are
# {"role", "content"} dicts; plain "User: ..." / "Assistant: ..." strings are
# also accepted (presumably a legacy format — nothing in this file appends
# them; verify before removing).
for message in st.session_state.messages:
    if isinstance(message, dict):
        role = message.get("role", "")
        content = message.get("content", "")
    elif message.startswith("User: "):
        # Strip only the leading speaker label; str.replace would also delete
        # "User: " / "Assistant: " appearing inside the message text.
        role, content = "user", message.removeprefix("User: ")
    else:
        role, content = "assistant", message.removeprefix("Assistant: ")
    with st.chat_message(role):
        st.write(content)
# --- New turn ----------------------------------------------------------------
# The input box only appears once a graph exists and an API key is available.
if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
    user_input = st.chat_input("Type your message here...")
    if user_input:
        # Echo the user's message immediately and record it in the transcript.
        with st.chat_message("user"):
            st.write(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        # Only the new message is sent: the graph's checkpointer already holds
        # earlier turns for this thread_id, and add_messages appends to them.
        config = {"configurable": {"thread_id": st.session_state.session_id}}
        try:
            with st.spinner("Thinking..."):
                result = st.session_state.chat_graph.invoke(
                    {"messages": [HumanMessage(content=user_input)]}, config
                )
        except Exception as exc:
            # UI boundary: a bad (user-typed) key, a retired model, or a network
            # error is reported in the page instead of as a raw traceback.
            st.error(f"The assistant could not respond: {exc}")
        else:
            # The last message in the returned state is the model's reply.
            response = result["messages"][-1].content
            with st.chat_message("assistant"):
                st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
# --- Sidebar: about + transcript export ---------------------------------------
st.sidebar.markdown("---")
st.sidebar.subheader("About")
st.sidebar.info(
"""
This chatbot uses LangGraph for maintaining conversation context and
ChatGroq's language models for generating responses. Each conversation
has a unique session ID to maintain history.
"""
)

# Offer the transcript as a plain-text download. The download button is
# rendered directly rather than inside `if st.sidebar.button(...)`: every
# click triggers a rerun on which that outer button reads False, so the
# nested version needed an extra click and vanished after use.
from datetime import datetime

# String entries (accepted by the transcript renderer) are already
# "Role: text" lines, so they are exported verbatim instead of raising.
chat_export = "\n".join(
    entry if isinstance(entry, str) else f"{entry['role']}: {entry['content']}"
    for entry in st.session_state.messages
)
st.sidebar.download_button(
    label="Download as Text",
    data=chat_export,
    file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    mime="text/plain",
    disabled=not st.session_state.messages,  # nothing to export yet
)