import streamlit as st
import os
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from typing import Annotated
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langchain_groq import ChatGroq

# Define the state
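# add_messages is LangGraph's reducer: messages returned by a node are appended to this list instead of overwriting it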
class MessagesState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

# Create graph function
def create_chat_graph(system_prompt, model_name):
    # Initialize LLM
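    # ChatGroq reads the GROQ_API_KEY environment variable, which is set from the sidebar input below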
    llm = ChatGroq(model=model_name)
    
    # Create system message
    system_message = SystemMessage(content=system_prompt)
    
    # Define the assistant function
    def assistant(state: MessagesState):
        # Get all messages including the system message
        messages = [system_message] + state["messages"]
        # Generate response
        response = llm.invoke(messages)
        # Return only the new AI message; the add_messages reducer appends it to the stored history
        return {"messages": [response]}
    
    # Initialize the graph builder
    builder = StateGraph(MessagesState)
    
    # Add the assistant node
    builder.add_node("assistant", assistant)
    
    # Define edges
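    # Single linear flow: START -> assistant -> END (one LLM call per user turn)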
    builder.add_edge(START, "assistant")
    builder.add_edge("assistant", END)
    
    # Create an in-memory checkpointer; each thread_id keeps its own conversation history (not persisted across restarts)
    memory = MemorySaver()
    
    # Compile the graph with memory
    graph = builder.compile(checkpointer=memory)
    
    return graph

# Set up Streamlit page
st.set_page_config(page_title="Conversational AI Assistant", page_icon="💬")
st.title("AI Chatbot with Memory")

# Sidebar configuration
st.sidebar.header("Configuration")

# API key input (in production, prefer st.secrets to a text field)
if "GROQ_API_KEY" not in os.environ:
    api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")
    if api_key:
        os.environ["GROQ_API_KEY"] = api_key
    else:
        st.sidebar.warning("Please enter your Groq API key to continue.")

# Model selection
model_options = ["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
selected_model = st.sidebar.selectbox("Select Model:", model_options)

# System prompt
default_prompt = "You are a helpful and friendly assistant. Maintain a conversational tone and remember previous interactions with the user."
system_prompt = st.sidebar.text_area("System Prompt:", value=default_prompt, height=150)

# Session ID for this conversation
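# Used as the LangGraph thread_id so the checkpointer keeps this conversation's history separate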
if "session_id" not in st.session_state:
    import uuid
    st.session_state.session_id = str(uuid.uuid4())

# Initialize or get chat history
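# This list only drives the on-screen transcript; LangGraph keeps its own message history in the checkpointer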
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize the graph on first run or when the user presses Reset Conversation (model/prompt changes take effect after a reset)
if "chat_graph" not in st.session_state or st.sidebar.button("Reset Conversation"):
    if "GROQ_API_KEY" in os.environ:
        with st.spinner("Initializing chatbot..."):
            st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
            st.session_state.messages = []  # Clear messages on reset
            st.success("Chatbot initialized!")
    else:
        st.sidebar.error("API key required to initialize chatbot.")

# Display chat history (messages are stored as {"role": ..., "content": ...} dicts)
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Input for new message
if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
    user_input = st.chat_input("Type your message here...")
    
    if user_input:
        # Display user message
        with st.chat_message("user"):
            st.write(user_input)
        
        # Add to history
        st.session_state.messages.append({"role": "user", "content": user_input})
        
        # Get response from the chatbot
        with st.spinner("Thinking..."):
            # Call the graph with the user's message
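            # Only the new message is passed in; MemorySaver restores earlier turns for this thread_id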
            config = {"configurable": {"thread_id": st.session_state.session_id}}
            user_message = [HumanMessage(content=user_input)]
            result = st.session_state.chat_graph.invoke({"messages": user_message}, config)
            
            # Extract response
            response = result["messages"][-1].content
        
        # Display assistant response
        with st.chat_message("assistant"):
            st.write(response)
        
        # Add to history
        st.session_state.messages.append({"role": "assistant", "content": response})

# Add some additional info in the sidebar
st.sidebar.markdown("---")
st.sidebar.subheader("About")
st.sidebar.info(
    """
    This chatbot uses LangGraph to maintain conversation context and
    Groq language models (via LangChain's ChatGroq) to generate responses.
    Each conversation has a unique session ID that keeps its history separate.
    """
)

# Download chat history
if st.session_state.messages:
    from datetime import datetime

    # Convert the chat history to a plain-text transcript
    chat_export = "\n".join([f"{m['role']}: {m['content']}" for m in st.session_state.messages])

    # Offer the file directly; wrapping this in st.sidebar.button would require two
    # clicks and the download button would vanish on the next rerun
    st.sidebar.download_button(
        label="Download Chat History",
        data=chat_export,
        file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
        mime="text/plain"
    )