Spaces:
Sleeping
Sleeping
here we go
Browse files
app.py
CHANGED
@@ -10,59 +10,63 @@ from sqlalchemy import create_engine
|
|
10 |
import sqlite3
|
11 |
import os
|
12 |
from langchain_openai import ChatOpenAI
|
13 |
-
|
14 |
-
|
15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
16 |
-
openai_api_key = st.secrets["OPENAI_API_KEY"]
|
17 |
-
|
18 |
st.set_page_config(page_title="Chat with Protected Areas Database", page_icon="🦜")
|
19 |
st.title("🦜 LangChain: Chat with Protected Areas Database")
|
20 |
-
|
21 |
-
INJECTION_WARNING = """
|
22 |
-
Experimental!
|
23 |
-
"""
|
24 |
-
LOCALDB = "duckdb:///pad.duckdb"
|
25 |
|
26 |
# User inputs
|
27 |
radio_opt = ["US Protected Areas v3"]
|
28 |
selected_opt = st.sidebar.radio(label="Choose suitable option", options=radio_opt)
|
29 |
-
if radio_opt.index(selected_opt) == 1:
|
30 |
-
st.sidebar.warning(INJECTION_WARNING, icon="⚠️")
|
31 |
-
db_uri = st.sidebar.text_input(
|
32 |
-
label="Database URI", placeholder="duckdb:///pad.duckdb"
|
33 |
-
)
|
34 |
-
else:
|
35 |
-
db_uri = LOCALDB
|
36 |
-
|
37 |
-
#st.sidebar.text_input(label="OpenAI API Key", type="password")
|
38 |
-
|
39 |
|
40 |
# Setup agent
|
41 |
-
llm = OpenAI(openai_api_key=openai_api_key, model="gpt-3.5-turbo", temperature=0, streaming=True)
|
42 |
-
|
43 |
@st.cache_resource(ttl="2h")
|
44 |
def configure_db(db_uri):
|
45 |
return SQLDatabase.from_uri(database_uri=db_uri, view_support=True)
|
46 |
-
|
47 |
db = configure_db(db_uri)
|
48 |
|
49 |
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
50 |
agent = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
for msg in st.session_state.messages:
|
56 |
-
st.chat_message(msg["role"]).write(msg["content"])
|
57 |
|
58 |
-
user_query = st.chat_input(placeholder="Ask me anything!")
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
10 |
import sqlite3
|
11 |
import os
|
12 |
from langchain_openai import ChatOpenAI
|
|
|
|
|
13 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
|
|
|
|
14 |
st.set_page_config(page_title="Chat with Protected Areas Database", page_icon="🦜")
|
15 |
st.title("🦜 LangChain: Chat with Protected Areas Database")
|
16 |
+
db_uri = "duckdb:///pad.duckdb"
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# User inputs
|
19 |
radio_opt = ["US Protected Areas v3"]
|
20 |
selected_opt = st.sidebar.radio(label="Choose suitable option", options=radio_opt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Setup agent
|
|
|
|
|
23 |
@st.cache_resource(ttl="2h")
|
24 |
def configure_db(db_uri):
|
25 |
return SQLDatabase.from_uri(database_uri=db_uri, view_support=True)
|
|
|
26 |
db = configure_db(db_uri)
|
27 |
|
28 |
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
29 |
agent = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
|
30 |
|
31 |
+
def handle_user_input(user_query):
|
32 |
+
with history:
|
33 |
+
st.session_state.messages.append({"role": "user", "content": user_query})
|
34 |
+
st.chat_message("user").write(user_query)
|
35 |
+
|
36 |
+
with st.chat_message("assistant"):
|
37 |
+
st_cb = StreamlitCallbackHandler(st.container())
|
38 |
+
response = agent.run(user_query, callbacks=[st_cb])
|
39 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
40 |
+
st.write(response)
|
41 |
+
|
42 |
+
|
43 |
+
if "messages" not in st.session_state:
|
44 |
+
st.session_state["messages"] = []
|
45 |
+
|
46 |
+
main = st.container()
|
47 |
+
with main:
|
48 |
+
history = st.container(height=800)
|
49 |
+
#with history:
|
50 |
+
# for msg in st.session_state.messages:
|
51 |
+
# st.chat_message(msg["role"]).write(msg["content"])
|
52 |
+
if user_query := st.chat_input(placeholder="Ask me about US Protected areas!"):
|
53 |
+
handle_user_input(user_query)
|
54 |
+
|
55 |
+
st.markdown("\n") #add some space for iphone users
|
56 |
+
|
57 |
+
|
58 |
+
#if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
|
59 |
+
# st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
60 |
|
|
|
|
|
61 |
|
|
|
62 |
|
63 |
+
EXAMPLE_PROMPTS = ["What is the total area in each GAP_Sts category in the fee table?",
|
64 |
+
"List the name of each table in the database",
|
65 |
+
"How much BLM land (BLM is a Mang_Name in the fee table) is in each GAP_Sts category?",
|
66 |
+
"Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. The full name of each agency is given in the agency_name table. Which federal agencies, by full name, manage the greatest area of GAP_Sts 1 or 2 land?"]
|
67 |
|
68 |
+
with st.sidebar:
|
69 |
+
with st.container():
|
70 |
+
st.title("Examples")
|
71 |
+
for prompt in EXAMPLE_PROMPTS:
|
72 |
+
st.button(prompt, args=(prompt,), on_click=handle_user_input)
|