Phoenix21 commited on
Commit
7c923c2
·
verified ·
1 Parent(s): 5552508

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from langgraph.graph import MessagesState, StateGraph, START, END
4
+ from typing_extensions import TypedDict
5
+ from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
6
+ from typing import Annotated
7
+ from langgraph.graph.message import add_messages
8
+ from langgraph.checkpoint.memory import MemorySaver
9
+ from langchain_groq import ChatGroq
10
+
11
+ # Define the state
12
+ class MessagesState(TypedDict):
13
+ messages: Annotated[list[AnyMessage], add_messages]
14
+
15
+ # Create graph function
16
+ def create_chat_graph(system_prompt, model_name):
17
+ # Initialize LLM
18
+ llm = ChatGroq(model=model_name)
19
+
20
+ # Create system message
21
+ system_message = SystemMessage(content=system_prompt)
22
+
23
+ # Define the assistant function
24
+ def assistant(state: MessagesState):
25
+ # Get all messages including the system message
26
+ messages = [system_message] + state["messages"]
27
+ # Generate response
28
+ response = llm.invoke(messages)
29
+ # Return the response
30
+ return {"messages": [response]}
31
+
32
+ # Initialize the graph builder
33
+ builder = StateGraph(MessagesState)
34
+
35
+ # Add the assistant node
36
+ builder.add_node("assistant", assistant)
37
+
38
+ # Define edges
39
+ builder.add_edge(START, "assistant")
40
+ builder.add_edge("assistant", END)
41
+
42
+ # Create memory saver for persistence
43
+ memory = MemorySaver()
44
+
45
+ # Compile the graph with memory
46
+ graph = builder.compile(checkpointer=memory)
47
+
48
+ return graph
49
+
50
+ # Set up Streamlit page
51
+ st.set_page_config(page_title="Conversational AI Assistant", page_icon="💬")
52
+ st.title("AI Chatbot with Memory")
53
+
54
+ # Sidebar configuration
55
+ st.sidebar.header("Configuration")
56
+
57
+ # API Key input (using st.secrets in production)
58
+ if "GROQ_API_KEY" not in os.environ:
59
+ api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")
60
+ if api_key:
61
+ os.environ["GROQ_API_KEY"] = api_key
62
+ else:
63
+ st.sidebar.warning("Please enter your Groq API key to continue.")
64
+
65
+ # Model selection
66
+ model_options = ["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
67
+ selected_model = st.sidebar.selectbox("Select Model:", model_options)
68
+
69
+ # System prompt
70
+ default_prompt = "You are a helpful and friendly assistant. Maintain a conversational tone and remember previous interactions with the user."
71
+ system_prompt = st.sidebar.text_area("System Prompt:", value=default_prompt, height=150)
72
+
73
+ # Session ID for this conversation
74
+ if "session_id" not in st.session_state:
75
+ import uuid
76
+ st.session_state.session_id = str(uuid.uuid4())
77
+
78
+ # Initialize or get chat history
79
+ if "messages" not in st.session_state:
80
+ st.session_state.messages = []
81
+
82
+ # Initialize the graph on first run or when config changes
83
+ if "chat_graph" not in st.session_state or st.sidebar.button("Reset Conversation"):
84
+ if "GROQ_API_KEY" in os.environ:
85
+ with st.spinner("Initializing chatbot..."):
86
+ st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
87
+ st.session_state.messages = [] # Clear messages on reset
88
+ st.success("Chatbot initialized!")
89
+ else:
90
+ st.sidebar.error("API key required to initialize chatbot.")
91
+
92
+ # Display chat history
93
+ for message in st.session_state.messages:
94
+ if isinstance(message, dict): # Handle dict format
95
+ role = message.get("role", "")
96
+ content = message.get("content", "")
97
+ else: # Handle direct string format
98
+ role = "user" if message.startswith("User: ") else "assistant"
99
+ content = message.replace("User: ", "").replace("Assistant: ", "")
100
+
101
+ with st.chat_message(role):
102
+ st.write(content)
103
+
104
+ # Input for new message
105
+ if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
106
+ user_input = st.chat_input("Type your message here...")
107
+
108
+ if user_input:
109
+ # Display user message
110
+ with st.chat_message("user"):
111
+ st.write(user_input)
112
+
113
+ # Add to history
114
+ st.session_state.messages.append({"role": "user", "content": user_input})
115
+
116
+ # Get response from the chatbot
117
+ with st.spinner("Thinking..."):
118
+ # Call the graph with the user's message
119
+ config = {"configurable": {"thread_id": st.session_state.session_id}}
120
+ user_message = [HumanMessage(content=user_input)]
121
+ result = st.session_state.chat_graph.invoke({"messages": user_message}, config)
122
+
123
+ # Extract response
124
+ response = result["messages"][-1].content
125
+
126
+ # Display assistant response
127
+ with st.chat_message("assistant"):
128
+ st.write(response)
129
+
130
+ # Add to history
131
+ st.session_state.messages.append({"role": "assistant", "content": response})
132
+
133
+ # Add some additional info in the sidebar
134
+ st.sidebar.markdown("---")
135
+ st.sidebar.subheader("About")
136
+ st.sidebar.info(
137
+ """
138
+ This chatbot uses LangGraph for maintaining conversation context and
139
+ ChatGroq's language models for generating responses. Each conversation
140
+ has a unique session ID to maintain history.
141
+ """
142
+ )
143
+
144
+ # Download chat history
145
+ if st.sidebar.button("Download Chat History"):
146
+ import json
147
+ from datetime import datetime
148
+
149
+ # Convert chat history to downloadable format
150
+ chat_export = "\n".join([f"{m['role']}: {m['content']}" for m in st.session_state.messages])
151
+
152
+ # Create download button
153
+ st.sidebar.download_button(
154
+ label="Download as Text",
155
+ data=chat_export,
156
+ file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
157
+ mime="text/plain"
158
+ )