# sql-chatbot / pages / 2_📊_Chart_Demo.py
# Author: cboettig — commit 58f1141 ("update")
# This example does not use a langchain agent.
# The langchain sql chain has knowledge of the database, but doesn't interact
# with it beyond initialization.  The output of the sql chain is parsed
# separately and passed to `duckdb.sql()` by streamlit.
import os

os.environ["WEBSOCKET_TIMEOUT_MS"] = "300000"  # no effect

import streamlit as st
import geopandas as gpd
import pandas as pd
from shapely import wkb

st.set_page_config(page_title="Explore US Protected Areas", page_icon="🦜", layout="wide")
st.title("Explore US Protected Areas")

## Database connection, reading directly from remote parquet files
from sqlalchemy import create_engine
from langchain.sql_database import SQLDatabase

# Local on-disk DuckDB file shared by both the sqlalchemy and ibis connections.
db_uri = "duckdb:///my.duckdb"
# Remote parquet sources (PAD-US v3 protected-areas data).
stats = "https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/pad-stats.parquet"
groups = "https://huggingface.co/datasets/boettiger-lab/pad-us-3/resolve/main/pad-groupings.parquet"

engine = create_engine(db_uri)  # connect_args={'read_only': True})
con = engine.connect()
con.execute("install spatial; load spatial;")
# con.execute(f"create or replace view stats as select * from read_parquet('{stats}');").fetchall()
# Materialize the (small) groupings table locally so the LLM's SQL can query it.
con.execute(f"create or replace table groups as select * from read_parquet('{groups}');").fetchall()
# Langchain wrapper giving the SQL chain knowledge of the database schema.
db = SQLDatabase(engine, view_support=True)
@st.cache_data
def query_database(response):
    """Execute a raw SQL string on the shared DuckDB connection.

    Results are cached by Streamlit (keyed on the SQL text) and truncated
    to the first 25 rows for display.
    """
    frame = pd.DataFrame(con.execute(response).fetchall())
    return frame.head(25)
import ibis
from ibis import _
import ibis.selectors as s
import altair as alt

# Second, ibis-based handle onto the same DuckDB file opened above via sqlalchemy.
ibis_con = ibis.duckdb.connect("my.duckdb")
# Rebind `stats` from the parquet URL string to a lazy ibis table over that file.
stats = ibis_con.read_parquet(stats)
# Approximate land area of the continental (lower 48) US, in square meters,
# used as the denominator for percent-protected figures.
us_lower_48_area_m2 = 7.8e+12
def summary_table(stats, query, column):
    """Join protection stats onto the LLM-labelled groups and aggregate.

    `query` is the SQL produced by the language model over the `groups`
    table (any trailing `;` is stripped so it can be embedded by ibis);
    `column` names the label column to group by.  All metrics except the
    counts are area-weighted means.  Returns a pandas DataFrame.
    """
    #z = con.execute(query).fetchall()
    labelled = ibis_con.table("groups").sql(query.replace(";", ""))

    # Join on the shared row index and drop the duplicated join columns.
    joined = (
        stats
        .inner_join(labelled, "row_n")
        .select(~s.contains("_right"))
        .rename(area="area_square_meters")
    )

    summary = joined.group_by(_[column]).aggregate(
        percent_protected=100 * _.area.sum() / us_lower_48_area_m2,
        hectares=_.area.sum() / 10000,
        n=_.area.count(),
        richness=(_.richness * _.area).sum() / _.area.sum(),
        rsr=(_.rsr * _.area).sum() / _.area.sum(),
        carbon_lost=(_.deforest_carbon * _.area).sum() / _.area.sum(),
        crop_expansion=(_.crop_expansion * _.area).sum() / _.area.sum(),
        human_impact=(_.human_impact * _.area).sum() / _.area.sum(),
    )

    return summary.mutate(percent_protected=_.percent_protected.round(1)).to_pandas()
def area_plot(df, column):
    """Render a donut chart of `percent_protected` per label, with the
    percentage values drawn as text around the outside."""
    base = alt.Chart(df).encode(
        alt.Theta("percent_protected:Q").stack(True),
        alt.Color(column + ":N").legend(None),
    )
    donut = base.mark_arc(innerRadius=40, outerRadius=80)
    callouts = base.mark_text(radius=120, size=20).encode(
        text="percent_protected:Q",
    )
    return st.altair_chart(donut + callouts)
def bar_chart(df, x, y):
    """Build (but do not render) a bar chart of `y` by `x`, colored by `x`."""
    return (
        alt.Chart(df)
        .mark_bar()
        .encode(x=x, y=y, color=alt.Color(x).legend(None))
        .properties(width="container", height=300)
    )
## ChatGPT Connection
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
# from langchain_community.llms import ChatOllama

# Models offered in the sidebar selector.
models = {"chatgpt3.5": ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=st.secrets["OPENAI_API_KEY"])}
# Other models tried during development; defined but not exposed in the UI.
other_models = {
    "chatgpt4": ChatOpenAI(model="gpt-4", temperature=0, api_key=st.secrets["OPENAI_API_KEY"]),
    "duckdb-nsql": Ollama(model="duckdb-nsql", temperature=0),
    "command-r-plus": Ollama(model="command-r-plus", temperature=0),
    "mixtral:8x22b": Ollama(model="mixtral:8x22b", temperature=0),
    "wizardlm2:8x22b": Ollama(model="wizardlm2:8x22b", temperature=0),
    "sqlcoder": Ollama(model="sqlcoder", temperature=0),
    "zephyr": Ollama(model="zephyr", temperature=0),
    "gemma:7b": Ollama(model="gemma:7b", temperature=0),
    "codegemma": Ollama(model="codegemma", temperature=0),
    "llama2": Ollama(model="llama2", temperature=0),
}

with st.sidebar:
    # LLM choice and the name of the label column the generated SQL creates.
    choice = st.radio("Select an LLM:", models)
    llm = models[choice]
    column = st.text_input("grouping column", "labels")
## A SQL Chain
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)

main = st.container()

## Does not preserve history
with main:
    # Streamlit "magic": a bare string literal is rendered as markdown.
    '''
The US [recently announced](https://www.conservation.gov/pages/america-the-beautiful-initiative) the first-ever national goal to conserve at least 30 percent of our lands and waters by the year 2030.
But which 30%?
Protected areas span a range of "GAP" areas [indicating the degree of protection](https://www.protectedlands.net/uses-of-pad-us/#conservation-of-biodiversity-2). Protected areas include not only owned or "fee"-based parcels such as National Parks and Monuments,
but also "easements" (see [feature classes](https://www.protectedlands.net/pad-us-technical-how-tos/#feature-classes-in-pad-us-2))
- GAP 1: Managed for biodiversity with natural disturbance events allowed (for example, Wilderness, Research Natural Areas, some National Parks, some State or NGO Nature Preserves)
- GAP 2: Managed for biodiversity with management that may interfere with natural processes (for example, suppress wildfire or flood)
- GAP 3: Permanent protection, but the land is subject to multiple uses (forestry, farming, intensive recreation, etc.
- GAP 4: No known institutional mandates to prevent conversion of natural habitat types
Use the chat tool below to specify your own groupings of the data and see how they compare.
##### Try these example queries:
- gap 1, 2, 3 are labelled 'conserved lands' and gap 4 is labeled 'other'
- exclude gap 4, include only Federal manager types, labelled by manager_name
- label gap 1, 2 as "permanently protected", label gap 3 as "additional conserved area", and gap 4 as other
- label gap 1, 2 areas in category ="Easements" as "protected easements", gap 1,2 category="Fee" as "protected areas", gap 3 easements as "mixed use easements", gap 3 Fee as "mixed use lands". exclude gap 4.
'''

# Wrapped around the user's request before it is sent to the SQL chain.
prefix = "construct a select query that creates a column called 'labels' that only contains rows that meet the following criteria:"
suffix = ". Do not use LIMIT. Always return all columns. Do not try to select specific columns."

st.markdown("Specify how data should be labelled, as in the examples above:")
chatbox = st.container()
with chatbox:
    if prompt := st.chat_input(key="chain"):
        st.chat_message("user").write(prompt)
        with st.chat_message("assistant"):
            # Ask the LLM for SQL, show it, then run it ourselves via ibis/duckdb.
            response = chain.invoke({"question": prefix + prompt + suffix})
            st.write(response)
            df = summary_table(stats, response, column)
            with st.container():
                col1, col2, col3 = st.columns(3)
                with col1:
                    total_percent = df.percent_protected.sum().round(1)
                    # Bare f-string rendered by Streamlit magic.
                    f"{total_percent}% Continental US Covered"
                    area_plot(df, column)
                with col2:
                    "Species Richness"
                    st.altair_chart(bar_chart(df, column, "richness"), use_container_width=True)
                with col3:
                    "Range-Size Rarity"
                    st.altair_chart(bar_chart(df, column, "rsr"), use_container_width=True)
            with st.container():
                col1b, col2b, col3b = st.columns(3)
                with col1b:
                    "Carbon Lost ('02-'22)"
                    st.altair_chart(bar_chart(df, column, "carbon_lost"), use_container_width=True)
                with col2b:
                    "Crop expansion"
                    st.altair_chart(bar_chart(df, column, "crop_expansion"), use_container_width=True)
                with col3b:
                    "Human Impact"
                    st.altair_chart(bar_chart(df, column, "human_impact"), use_container_width=True)
            st.divider()
            # Full summary table beneath the charts.
            st.dataframe(df)

#st.divider()
#with st.container():
#    st.text("Database schema (top 3 rows)")
#    tbl = tbl = query_database("select * from groups limit 3")
#    st.dataframe(tbl)

st.divider()
'''
Experimental prototype.
- Author: [Carl Boettiger](https://carlboettiger.info)
- For data sources and processing, see: https://beta.source.coop/repositories/cboettig/pad-us-3/description/
'''