Spaces:
Running
Running
import streamlit as st | |
from pathlib import Path | |
from langchain.llms.openai import OpenAI | |
from langchain.agents import create_sql_agent | |
from langchain.sql_database import SQLDatabase | |
from langchain.agents.agent_types import AgentType | |
from langchain_community.callbacks import StreamlitCallbackHandler | |
from langchain.agents.agent_toolkits import SQLDatabaseToolkit | |
from sqlalchemy import create_engine | |
import sqlite3 | |
import os | |
from langchain_openai import ChatOpenAI | |
# Copy the OpenAI key from Streamlit secrets into the environment; the
# ChatOpenAI client below is created without an explicit key, so it has to
# find it there.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# NOTE(review): the icon string looks mis-encoded (emoji bytes decoded with the
# wrong codec); confirm the intended character before changing it.
PAGE_ICON = "π¦"
PAGE_TITLE = "Protected Areas Database Chat"

st.set_page_config(page_title=PAGE_TITLE, page_icon=PAGE_ICON, layout="wide")
st.title(f"{PAGE_ICON} {PAGE_TITLE}")
# Local DuckDB file the agent queries (presumably holding the PAD-US 3 tables
# referenced by the example prompts, e.g. fee, combined, agency_name).
#
# If the file needs rebuilding, each table can be created as a DuckDB view over
# the published parquet exports, for example:
#   CREATE OR REPLACE VIEW fee AS SELECT * FROM
#     'https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/parquet/pad-fee.parquet'
# Other files under the same prefix: pad-agency-name, pad-agency-type,
# pad-category, pad-desgination-type (sic), pad-easement, pad-marine, pad-iucn,
# pad-public-access, pad-state-name, pad-combined.
db_uri = "duckdb:///pad.duckdb"


@st.cache_resource
def _load_database(uri):
    """Build the SQLAlchemy engine and LangChain ``SQLDatabase`` for ``uri`` once.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the resource avoids creating a fresh engine and re-reflecting the
    schema (including views, via ``view_support=True``) on each rerun.

    NOTE(review): the cached object is shared by all sessions of this app
    process; that is fine as long as the agent only issues read queries.

    Args:
        uri: SQLAlchemy database URL.

    Returns:
        A ``SQLDatabase`` wrapping the engine.
    """
    engine = create_engine(uri)
    return SQLDatabase(engine, view_support=True)


db = _load_database(db_uri)
# --- Sidebar data-source picker ---
# Only one dataset is offered for now (selected_opt is not read elsewhere in
# this script).
radio_opt = ["US Protected Areas v3"]
selected_opt = st.sidebar.radio(
    label="Choose suitable option",
    options=radio_opt,
)

# --- Agent ---
# temperature=0 keeps the model's SQL generation as deterministic as possible.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent = create_sql_agent(
    llm,
    db=db,
    agent_type="openai-tools",
    verbose=True,
)
def handle_user_input(user_query):
    """Send ``user_query`` to the SQL agent and record the exchange.

    Called directly after ``st.chat_input`` and also as the ``on_click``
    callback of the sidebar example buttons. Output is rendered into the
    module-level ``history`` container. The question and the agent's final
    answer are appended to ``st.session_state.messages``. The agent's
    intermediate steps are shown live through ``StreamlitCallbackHandler`` but
    are not stored, so they disappear once the history is re-rendered.

    Args:
        user_query: The natural-language question to ask the agent.
    """
    with history:
        st.session_state.messages.append({"role": "user", "content": user_query})
        # Echo the question immediately. On the chat_input path the history
        # loop has already run, so without this the question would not appear
        # until a later rerun.
        st.chat_message("user").write(user_query)
        with st.chat_message("assistant"):
            st_cb = StreamlitCallbackHandler(st.container())
            response = agent.run(user_query, callbacks=[st_cb])
            # The callback handler only renders the agent's trace, not its
            # final answer, so write the answer explicitly.
            st.write(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
# Chat transcript that survives reruns: a list of {"role": ..., "content": ...}.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

main = st.container()
with main:
    # Fixed-height, scrollable area that replays every stored question and
    # answer. The agent's intermediate "thinking" is not stored, so it is not
    # replayed.
    history = st.container(height=400)
    with history:
        for message in st.session_state.messages:
            role, content = message["role"], message["content"]
            st.chat_message(role).write(content)

    # NOTE(review): source indentation was lost; the chat input and spacer are
    # assumed to sit inside `main`, directly below the history box.
    user_query = st.chat_input(placeholder="Ask me about US Protected areas!")
    if user_query:
        handle_user_input(user_query)

    st.markdown("\n")  # extra bottom space for iPhone users
# Canned questions shown as sidebar buttons.
EXAMPLE_PROMPTS = [
    "What is the total area in each GAP_Sts category in the fee table?",
    "List the name of each table in the database",
    "How much BLM land (BLM is a Mang_Name in the fee table) is in each GAP_Sts category?",
    "Federal agencies are identified as 'FED' in the Mang_Type column in the 'combined' data table. The Mang_Name column indicates the different agencies. The full name of each agency is given in the agency_name table. Which federal agencies, by full name, manage the greatest area of GAP_Sts 1 or 2 land?",
]

# Entering both contexts in one statement is equivalent to nesting them:
# the container is created inside the sidebar.
with st.sidebar, st.container():
    st.title("Examples")
    for example in EXAMPLE_PROMPTS:
        # A click runs the prompt through the agent as an on_click callback,
        # i.e. before the script body reruns.
        st.button(example, on_click=handle_user_input, args=(example,))