# streamlit-demo / sql.py
import streamlit as st
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
# +
# Set up LangChain SQL access to the GBIF Parquet data via a DuckDB view
parquet = "s3://us-west-2.opendata.source.coop/cboettig/gbif/2024-10-01/**"
db = SQLDatabase.from_uri("duckdb:///tmp.db", view_support=True)
db.run(f"create or replace view mydata as select * from read_parquet('{parquet}');")
llm = ChatOpenAI(model="llama3", temperature=0, api_key=st.secrets["LITELLM_KEY"],
                 base_url="https://llm.nrp-nautilus.io")
# -
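# +
# Optional sanity check (an added sketch, not part of the original demo): confirm the
# remote Parquet view resolves. Reading s3:// paths relies on DuckDB's httpfs extension,
# which recent DuckDB releases typically install and load automatically; this call will
# raise if the view or the remote access is misconfigured.
db.run("SELECT * FROM mydata LIMIT 1;")
# -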
from langchain_core.prompts import PromptTemplate
template = '''
You are a {dialect} expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Never query for all columns from a table; query only the columns needed to answer the question.
Wrap each column name in double quotes (") to denote them as delimited identifiers.
Use only the column names you can see in the tables below, and be careful not to query columns that do not exist.
Also pay attention to which column is in which table.
Use the today() function to get the current date if the question involves "today".
Respond with only the raw SQL query to run; do not repeat the question or add any explanation.
Only use the following tables:
{table_info}
Question: {input}
'''
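# "top_k" is supplied as a partial variable even though this template does not
# interpolate it: create_sql_query_chain expects input, table_info, and top_k to be
# covered by the prompt's variables.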
prompt = PromptTemplate.from_template(template, partial_variables={"dialect": "duckdb", "top_k": 10})
chain = create_sql_query_chain(llm, db, prompt)
# +
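# Inspection helpers (left commented out): the SQL dialect, the tables the chain can
# see, and the fully assembled prompt that will be sent to the LLM.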
#print(db.dialect)
#print(db.get_usable_table_names())
#chain.get_prompts()[0].pretty_print()
# -
response = chain.invoke({"question": "Count the number of mammal occurrences in each h0 grouping"})
response
# %%time
x = db.run(response)
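# +
# A minimal Streamlit wiring sketch (an addition, not part of the original demo):
# reuse the chain above to turn a free-text question into SQL and show the result.
# The widget label and layout here are assumptions.
question = st.text_input("Ask a question about the GBIF occurrence data")
if question:
    sql = chain.invoke({"question": question})
    st.code(sql, language="sql")   # show the generated SQL
    st.write(db.run(sql))          # run it against DuckDB and display the result
# -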