File size: 4,949 Bytes
44be70c
9d1df78
 
 
 
 
2071015
9d1df78
 
 
 
 
 
e0fc76b
0539eb3
 
27b4ed0
2071015
27b4ed0
 
 
2071015
 
 
 
 
 
 
 
 
 
 
 
27b4ed0
 
 
0539eb3
 
 
9d1df78
 
 
 
 
 
 
 
d86085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0fc76b
d86085f
 
 
 
 
 
 
 
 
9d1df78
 
d86085f
 
 
 
9d1df78
d86085f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import streamlit as st
from pathlib import Path
from langchain.llms.openai import OpenAI
from langchain.agents import create_sql_agent
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from sqlalchemy import create_engine
import sqlite3
import os
from langchain_openai import ChatOpenAI
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
st.set_page_config(page_title="Protected Areas Database Chat", page_icon="🦜", layout="wide")
st.title("🦜 Protected Areas Database Chat")

# DuckDB database file (path relative to the working directory) that the agent queries.
db_uri = "duckdb:///pad.duckdb"

# NOTE(review): the tables/views are expected to already exist in pad.duckdb.
# Presumably they were created as views over the PAD-US 3 parquet files under
#   https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/
# with one statement per file, e.g.
#   create or replace view fee as select * from '<base>/pad-fee.parquet'
# View -> file: agency_name -> pad-agency-name, agency_type -> pad-agency-type,
# category -> pad-category, designation_type -> pad-desgination-type (spelling
# as in the original URL), easement -> pad-easement, fee -> pad-fee,
# marine -> pad-marine, iucn -> pad-iucn, public_access -> pad-public-access,
# state_name -> pad-state-name, combined -> pad-combined.


@st.cache_resource
def _load_database(uri):
    """Return ``(engine, SQLDatabase)`` for *uri*, built once and then reused.

    Streamlit re-executes this whole script on every interaction; caching the
    resource keeps one SQLAlchemy engine and one reflected schema instead of
    rebuilding both on each rerun. ``view_support=True`` makes views visible
    to the agent in addition to tables.
    """
    engine = create_engine(uri)
    return engine, SQLDatabase(engine, view_support=True)


engine, db = _load_database(db_uri)




# User inputs: sidebar selector listing the available data source(s).
# NOTE(review): selected_opt is not read anywhere else in this file yet.
radio_opt = ["US Protected Areas v3"]
selected_opt = st.sidebar.radio(options=radio_opt, label="Choose suitable option")

# OpenAI chat model (temperature=0 to minimise sampling randomness) wrapped in
# an "openai-tools" SQL agent that queries `db`.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
agent = create_sql_agent(
    llm,
    db=db,
    agent_type="openai-tools",
    verbose=True,
)

def handle_user_input(user_query):
    """Record *user_query*, run the SQL agent on it, and render the exchange.

    The question and the agent's final answer are both appended to
    ``st.session_state.messages`` and written as chat messages inside the
    module-level ``history`` container, which must already exist when this is
    called. The agent's intermediate steps are rendered inside the assistant
    message through a ``StreamlitCallbackHandler``.
    """
    with history:
        st.session_state.messages.append({"role": "user", "content": user_query})
        st.chat_message("user").write(user_query)

        with st.chat_message("assistant"):
            st_cb = StreamlitCallbackHandler(st.container())
            # `Chain.run` is deprecated; `invoke` takes the agent's input dict
            # and returns a dict whose "output" entry is the final answer.
            result = agent.invoke({"input": user_query}, config={"callbacks": [st_cb]})
            response = result["output"]
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.write(response)


# Chat transcript kept in session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

EXAMPLE_PROMPTS = ["What is the total area in each GAP_Sts category in the fee table?",
                   "List the name of each table in the database",
                   "How much BLM land (BLM is a Mang_Name in the fee table) is in each GAP_Sts category?",
                   "Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. The full name of each agency is given in the agency_name table. Which federal agencies, by full name, manage the greatest area of GAP_Sts 1 or 2 land?"]

# Render the example buttons before the main area and remember which one (if
# any) was clicked, so it can be handled in the script body of this run. An
# `on_click` callback would run before the script body and therefore see the
# `history` container left over from the previous run, not this run's.
clicked_example = None
with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            if st.button(prompt):
                clicked_example = prompt

main = st.container()
with main:
    history = st.container(height=400)
    # Replay earlier turns so the conversation stays visible across reruns.
    with history:
        for msg in st.session_state.messages:
            st.chat_message(msg["role"]).write(msg["content"])

    # A typed question takes precedence over a clicked example prompt.
    user_query = st.chat_input(placeholder="Ask me about US Protected areas!") or clicked_example
    if user_query:
        handle_user_input(user_query)

    st.markdown("\n")  # add some space for iphone users