|
import streamlit as st |
|
import os |
|
from langgraph.graph import MessagesState, StateGraph, START, END |
|
from typing_extensions import TypedDict |
|
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage |
|
from typing import Annotated |
|
from langgraph.graph.message import add_messages |
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langchain_groq import ChatGroq |
|
|
|
|
|
class MessagesState(TypedDict):
    """State carried through the chat graph: the running list of messages.

    NOTE(review): this local class shadows the ``MessagesState`` imported from
    ``langgraph.graph`` at the top of the file; confirm whether the local
    definition is still needed.
    """

    # The ``add_messages`` annotation is LangGraph's reducer: updates returned
    # by a node are merged into this list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
|
def create_chat_graph(system_prompt, model_name):
    """Build and compile a single-node LangGraph chat graph.

    The graph runs START -> "assistant" -> END. The "assistant" node sends the
    system prompt followed by the stored conversation to the Groq chat model
    and returns the model's reply as a state update.

    Args:
        system_prompt: Text of the system instruction sent on every model call.
        model_name: Groq model identifier passed to ``ChatGroq(model=...)``.

    Returns:
        The compiled graph, checkpointed by a fresh in-memory ``MemorySaver``
        created for this graph only (LangGraph keys the checkpoints by the
        ``thread_id`` given in the invoke config).
    """
    chat_model = ChatGroq(model=model_name)
    # Built once; it is prepended on every call rather than stored in state,
    # so it never appears in the checkpointed history.
    sys_msg = SystemMessage(content=system_prompt)

    def _call_model(state: MessagesState):
        # Only the reply is returned; the state's reducer merges it into the
        # existing message list.
        reply = chat_model.invoke([sys_msg, *state["messages"]])
        return {"messages": [reply]}

    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", _call_model)
    workflow.add_edge(START, "assistant")
    workflow.add_edge("assistant", END)

    return workflow.compile(checkpointer=MemorySaver())
|
|
|
|
|
# Page metadata is set first, before any other element is drawn.
# The icon was stored as mojibake ("π¬", the UTF-8 bytes of U+1F4AC read as a
# legacy 8-bit code page); restore the intended speech-balloon emoji.
st.set_page_config(page_title="Conversational AI Assistant", page_icon="💬")
st.title("AI Chatbot with Memory")
|
|
|
|
|
st.sidebar.header("Configuration")

# Ask for the Groq key only when the environment does not already provide it.
# NOTE(review): assigning into os.environ makes the key process-wide, so every
# session served by this Streamlit process sees (and uses) it. Consider keeping
# it in st.session_state and passing it to ChatGroq explicitly instead.
if "GROQ_API_KEY" not in os.environ:
    api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")
    if not api_key:
        st.sidebar.warning("Please enter your Groq API key to continue.")
    else:
        os.environ["GROQ_API_KEY"] = api_key
|
|
|
|
|
# Groq model identifiers offered in the sidebar; the selectbox starts on the
# first entry. NOTE(review): confirm these models are still served by Groq.
model_options = ["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
selected_model = st.sidebar.selectbox("Select Model:", model_options)

# Editable system prompt. Both this and the model choice are read only when the
# chat graph is (re)built.
default_prompt = (
    "You are a helpful and friendly assistant. Maintain a conversational tone "
    "and remember previous interactions with the user."
)
system_prompt = st.sidebar.text_area("System Prompt:", value=default_prompt, height=150)
|
|
|
|
|
# Seed per-session values once; st.session_state persists across reruns.
if "session_id" not in st.session_state:
    import uuid

    # Random id later passed to the graph as its thread_id for this session.
    st.session_state["session_id"] = f"{uuid.uuid4()}"

if "messages" not in st.session_state:
    # Display transcript: a list of {"role": ..., "content": ...} dicts.
    st.session_state["messages"] = list()
|
|
|
|
|
# Evaluate the Reset button on every run so it is always drawn in the sidebar.
# Placing it after `"chat_graph" not in ... or` would short-circuit and hide the
# button whenever no graph exists yet (first run, or no API key).
reset_requested = st.sidebar.button("Reset Conversation")

# Changes to the model or system prompt take effect only on (re)initialization.
if reset_requested or "chat_graph" not in st.session_state:
    if "GROQ_API_KEY" in os.environ:
        with st.spinner("Initializing chatbot..."):
            # create_chat_graph builds a new MemorySaver, so the model-side
            # history is discarded together with the displayed transcript.
            st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
            st.session_state.messages = []
        st.success("Chatbot initialized!")
    else:
        st.sidebar.error("API key required to initialize chatbot.")
|
|
|
|
|
# Re-render the stored transcript; Streamlit redraws the page on every rerun.
for message in st.session_state.messages:
    if isinstance(message, dict):
        role = message.get("role", "")
        content = message.get("content", "")
    else:
        # Non-dict entries are presumably strings tagged "User: ..." or
        # "Assistant: ..." (this file itself only appends dicts).
        role = "user" if message.startswith("User: ") else "assistant"
        # Strip only the leading speaker tag. str.replace would also delete
        # "User: "/"Assistant: " wherever it appears inside the message text.
        tag = "User: " if role == "user" else "Assistant: "
        content = message.removeprefix(tag)

    with st.chat_message(role):
        st.write(content)
|
|
|
|
|
# Chat input is offered only once a graph exists and an API key is available.
if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
    user_input = st.chat_input("Type your message here...")

    if user_input:
        with st.chat_message("user"):
            st.write(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        # The graph's checkpointer holds prior turns under this thread_id, so
        # only the new human message is sent.
        config = {"configurable": {"thread_id": st.session_state.session_id}}
        try:
            with st.spinner("Thinking..."):
                result = st.session_state.chat_graph.invoke(
                    {"messages": [HumanMessage(content=user_input)]}, config
                )
        except Exception as exc:  # UI boundary: show API/network failures instead of crashing
            st.error(f"Failed to get a response from the model: {exc}")
        else:
            # The graph is START -> assistant -> END, so the last message in the
            # returned state is the model's reply to this turn.
            response = result["messages"][-1].content

            with st.chat_message("assistant"):
                st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
|
|
# Static "About" section at the bottom of the sidebar.
st.sidebar.markdown("---")
st.sidebar.subheader("About")
about_text = """
This chatbot uses LangGraph for maintaining conversation context and
ChatGroq's language models for generating responses. Each conversation
has a unique session ID to maintain history.
"""
st.sidebar.info(about_text)
|
|
|
|
|
# Offer the displayed transcript as a plain-text file. The download button is
# rendered directly rather than nested under st.button: a button is True only
# on the rerun right after its click, so the nested form needed two clicks and
# the download button disappeared on the next interaction.
from datetime import datetime

chat_export = "\n".join(
    f"{m['role']}: {m['content']}" for m in st.session_state.messages
)

st.sidebar.download_button(
    label="Download Chat History",
    data=chat_export,
    # Timestamp of the rerun that rendered the button.
    file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    mime="text/plain",
    # Nothing to export until at least one message exists.
    disabled=not st.session_state.messages,
)