cboettig commited on
Commit
d86085f
·
1 Parent(s): 9d1df78

here we go

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -10,59 +10,63 @@ from sqlalchemy import create_engine
10
  import sqlite3
11
  import os
12
  from langchain_openai import ChatOpenAI
13
-
14
-
15
  os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
16
- openai_api_key = st.secrets["OPENAI_API_KEY"]
17
-
18
  st.set_page_config(page_title="Chat with Protected Areas Database", page_icon="🦜")
19
  st.title("🦜 LangChain: Chat with Protected Areas Database")
20
-
21
- INJECTION_WARNING = """
22
- Experimental!
23
- """
24
- LOCALDB = "duckdb:///pad.duckdb"
25
 
26
  # User inputs
27
  radio_opt = ["US Protected Areas v3"]
28
  selected_opt = st.sidebar.radio(label="Choose suitable option", options=radio_opt)
29
- if radio_opt.index(selected_opt) == 1:
30
- st.sidebar.warning(INJECTION_WARNING, icon="⚠️")
31
- db_uri = st.sidebar.text_input(
32
- label="Database URI", placeholder="duckdb:///pad.duckdb"
33
- )
34
- else:
35
- db_uri = LOCALDB
36
-
37
- #st.sidebar.text_input(label="OpenAI API Key", type="password")
38
-
39
 
40
  # Setup agent
41
- llm = OpenAI(openai_api_key=openai_api_key, model="gpt-3.5-turbo", temperature=0, streaming=True)
42
-
43
  @st.cache_resource(ttl="2h")
44
  def configure_db(db_uri):
45
  return SQLDatabase.from_uri(database_uri=db_uri, view_support=True)
46
-
47
  db = configure_db(db_uri)
48
 
49
  llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
50
  agent = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
51
 
52
- if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
53
- st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- for msg in st.session_state.messages:
56
- st.chat_message(msg["role"]).write(msg["content"])
57
 
58
- user_query = st.chat_input(placeholder="Ask me anything!")
59
 
60
- if user_query:
61
- st.session_state.messages.append({"role": "user", "content": user_query})
62
- st.chat_message("user").write(user_query)
 
63
 
64
- with st.chat_message("assistant"):
65
- st_cb = StreamlitCallbackHandler(st.container())
66
- response = agent.run(user_query, callbacks=[st_cb])
67
- st.session_state.messages.append({"role": "assistant", "content": response})
68
- st.write(response)
 
10
  import sqlite3
11
  import os
12
  from langchain_openai import ChatOpenAI
 
 
13
  os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
 
 
14
  st.set_page_config(page_title="Chat with Protected Areas Database", page_icon="🦜")
15
  st.title("🦜 LangChain: Chat with Protected Areas Database")
16
+ db_uri = "duckdb:///pad.duckdb"
 
 
 
 
17
 
18
  # User inputs
19
  radio_opt = ["US Protected Areas v3"]
20
  selected_opt = st.sidebar.radio(label="Choose suitable option", options=radio_opt)
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Setup agent
 
 
23
  @st.cache_resource(ttl="2h")
24
  def configure_db(db_uri):
25
  return SQLDatabase.from_uri(database_uri=db_uri, view_support=True)
 
26
  db = configure_db(db_uri)
27
 
28
  llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
29
  agent = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
30
 
31
+ def handle_user_input(user_query):
32
+ with history:
33
+ st.session_state.messages.append({"role": "user", "content": user_query})
34
+ st.chat_message("user").write(user_query)
35
+
36
+ with st.chat_message("assistant"):
37
+ st_cb = StreamlitCallbackHandler(st.container())
38
+ response = agent.run(user_query, callbacks=[st_cb])
39
+ st.session_state.messages.append({"role": "assistant", "content": response})
40
+ st.write(response)
41
+
42
+
43
+ if "messages" not in st.session_state:
44
+ st.session_state["messages"] = []
45
+
46
+ main = st.container()
47
+ with main:
48
+ history = st.container(height=800)
49
+ #with history:
50
+ # for msg in st.session_state.messages:
51
+ # st.chat_message(msg["role"]).write(msg["content"])
52
+ if user_query := st.chat_input(placeholder="Ask me about US Protected areas!"):
53
+ handle_user_input(user_query)
54
+
55
+ st.markdown("\n") #add some space for iphone users
56
+
57
+
58
+ #if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
59
+ # st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
60
 
 
 
61
 
 
62
 
63
+ EXAMPLE_PROMPTS = ["What is the total area in each GAP_Sts category in the fee table?",
64
+ "List the name of each table in the database",
65
+ "How much BLM land (BLM is a Mang_Name in the fee table) is in each GAP_Sts category?",
66
+ "Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. The full name of each agency is given in the agency_name table. Which federal agencies, by full name, manage the greatest area of GAP_Sts 1 or 2 land?"]
67
 
68
+ with st.sidebar:
69
+ with st.container():
70
+ st.title("Examples")
71
+ for prompt in EXAMPLE_PROMPTS:
72
+ st.button(prompt, args=(prompt,), on_click=handle_user_input)