Spaces:
Runtime error
Runtime error
File size: 8,606 Bytes
f6f97d8 7faa846 f6f97d8 7faa846 f6f97d8 7faa846 f6f97d8 7faa846 f6f97d8 7faa846 f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import json
import os
import uuid
import pandas as pd
import streamlit as st
import argparse
import traceback
from typing import Dict
import requests
from utils.utils import load_data_split
from nsql.database import NeuralDB
from nsql.nsql_exec import NSQLExecutor
from nsql.nsql_exec_python import NPythonExecutor
from generation.generator import Generator
import time
# Directory containing this script (with a trailing "./" component); joined
# with "utils"/"gpt2" later to locate the tokenizer files.
ROOT_DIR = os.path.join(os.path.dirname(__file__), "./")

# Example tables offered in the demo, keyed by page title.
# Each value is (index into the list returned by load_data(), default question).
# NOTE: the keys are also the options shown in the table selectbox, so they
# must stay byte-for-byte identical to whatever the UI displays.
EXAMPLE_TABLES = {
    "Estonia men's national volleyball team": (
        558,
        "what are the total number of players from france?",
    ),
    "Highest mountain peaks of California": (
        5,
        "which is the lowest mountain?",
    ),
    "2010β11 UAB Blazers men's basketball team": (
        1,
        "how many players come from alabama?",
    ),
    "1999 European Tour": (
        209,
        "how many consecutive times was south africa the host country?",
    ),
    "Nissan SR20DET": (
        438,
        "which car is the only one with more than 230 hp?",
    ),
}
@st.cache
def load_data():
    """Return the validation split of the ``missing_squall`` dataset.

    Wrapped in ``st.cache`` so repeated Streamlit script reruns can reuse the
    loaded data instead of reading it again.
    """
    dataset_name, split_name = "missing_squall", "validation"
    return load_data_split(dataset_name, split_name)
@st.cache
def get_key():
    """Fetch one API key from the key-distribution ("springboard") server.

    The springboard server only hands keys to the demo machine, so this
    machine's public IP is printed first to make access problems easier to
    diagnose.

    Returns:
        The value found at ``response.json()['data'][0]`` of the key server's
        reply.

    Raises:
        requests.RequestException: if the key server cannot be reached, does
            not answer within the timeout, or replies with an HTTP error status.
    """
    # The public IP is printed purely for diagnostics; failing to look it up
    # must not prevent fetching the key.
    try:
        ip = requests.get('https://checkip.amazonaws.com', timeout=10).text.strip()
        print(ip)
    except requests.RequestException as err:
        print(f'Could not determine public IP: {err}')

    URL = "http://54.242.37.195:20217/api/predict"
    # The springboard machine was built to protect the key; only the demo
    # machine is allowed to obtain keys from it.
    # A timeout keeps the app from hanging forever if the server stalls.
    response = requests.post(
        url=URL,
        json={"data": "Hi, binder server. Give me a key!"},
        timeout=30,
    )
    # Report HTTP failures directly instead of as an opaque KeyError below.
    response.raise_for_status()
    one_key = response.json()['data'][0]
    return one_key
def read_markdown(path):
    """Render the Markdown file at ``path`` into the Streamlit page.

    HTML embedded in the file is rendered too (``unsafe_allow_html=True``), so
    only pass trusted, repository-local files.

    Args:
        path: filesystem path of the Markdown file to display.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's default
    # locale encoding, which may not handle non-ASCII Markdown content.
    with open(path, "r", encoding="utf-8") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)
def generate_binder_program(_args, _generator, _data_item):
    """Generate one Binder program for a single question/table item.

    Builds a few-shot prompt from ``_args.prompt_file`` followed by the
    generation prompt for ``_data_item``. The number of few-shot examples is
    reduced until the tokenized prompt is shorter than
    ``_args.max_api_total_tokens - _args.max_generation_tokens``. The
    generator is then queried once.

    Args:
        _args: parsed CLI namespace. Reads ``n_shots``, ``prompt_file``,
            ``generate_type``, ``max_api_total_tokens``,
            ``max_generation_tokens`` and ``verbose``.
        _generator: object providing ``build_few_shot_prompt_from_file``,
            ``build_generate_prompt`` and ``generate_one_pass``.
        _data_item: item passed through to ``build_generate_prompt``. The
            caller in this file passes a dict with ``question``, ``table``
            and ``title`` keys.

    Returns:
        ``response_dict["0"][0][0]`` from ``generate_one_pass``. This is
        presumably the program text of the first completion; confirm against
        ``Generator.generate_one_pass``.

    Raises:
        ValueError: if the prompt is still too long with zero few-shot examples.
    """
    n_shots = _args.n_shots
    few_shot_prompt = _generator.build_few_shot_prompt_from_file(
        file_path=_args.prompt_file,
        n_shots=n_shots
    )
    generate_prompt = _generator.build_generate_prompt(
        data_item=_data_item,
        generate_type=(_args.generate_type,)
    )
    prompt = few_shot_prompt + "\n\n" + generate_prompt

    # Make sure the prompt fits the model's max input tokens by shrinking n_shots.
    max_prompt_tokens = _args.max_api_total_tokens - _args.max_generation_tokens
    # Imported here rather than at module level, presumably to keep the heavy
    # transformers dependency off the page-load path.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=os.path.join(ROOT_DIR, "utils", "gpt2"))
    while len(tokenizer.tokenize(prompt)) >= max_prompt_tokens:  # TODO: Add shrink rows
        n_shots -= 1
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let n_shots go negative here.
        if n_shots < 0:
            raise ValueError(
                f"Prompt does not fit within {max_prompt_tokens} tokens "
                f"even with zero few-shot examples."
            )
        few_shot_prompt = _generator.build_few_shot_prompt_from_file(
            file_path=_args.prompt_file,
            n_shots=n_shots
        )
        prompt = few_shot_prompt + "\n\n" + generate_prompt

    response_dict = _generator.generate_one_pass(
        prompts=[("0", prompt)],  # "0" is a placeholder id; it only matters with multiple threads
        verbose=_args.verbose
    )
    print(response_dict)
    return response_dict["0"][0][0]
# Set up: command-line options for prompt construction and generation.
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_file", type=str, default="templates/prompts/prompt_wikitq_v3.txt")

# Binder program generation options
prompt_style_choices = [
    "create_table_select_3_full_table",
    "create_table_select_full_table",
    "create_table_select_3",
    "create_table",
    "create_table_select_3_full_table_w_all_passage_image",
    "create_table_select_3_full_table_w_gold_passage_image",
    "no_table",
]
parser.add_argument("--prompt_style", type=str,
                    default=prompt_style_choices[0], choices=prompt_style_choices)
generate_type_choices = ["nsql", "sql", "answer", "npython", "python"]
parser.add_argument("--generate_type", type=str,
                    default="nsql", choices=generate_type_choices)
parser.add_argument("--n_shots", type=int, default=14)
parser.add_argument("--seed", type=int, default=42)

# Codex options
# todo: Allow adjusting Codex parameters
for option_flag, option_type, option_default in (
    ("--engine", str, "code-davinci-002"),
    ("--max_generation_tokens", int, 512),
    ("--max_api_total_tokens", int, 8001),
    ("--temperature", float, 0.0),
    ("--sampling_n", int, 1),
    ("--top_p", float, 1.0),
):
    parser.add_argument(option_flag, type=option_type, default=option_default)
parser.add_argument("--stop_tokens", type=str, default="\n\n",
                    help="Split stop tokens by ||")
parser.add_argument("--qa_retrieve_pool_file", type=str, default="templates/qa_retrieve_pool.json")

# debug options
# NOTE(review): with action='store_false', args.verbose defaults to True and
# passing -v/--verbose turns it OFF, the opposite of what the flag name
# suggests. Confirm whether this inversion is intended.
parser.add_argument("-v", "--verbose", action="store_false")

args = parser.parse_args()
# API key(s) handed to the generator and executors; get_key is wrapped in st.cache.
keys = [get_key()]

# The title
st.markdown("# Binder Playground")
# Summary about Binder
read_markdown('resources/summary.md')
# Introduction of Binder
# todo: Write Binder introduction here
# read_markdown('resources/introduction.md')
st.image('resources/intro.png')

# Upload tables/Switch tables
st.markdown('### Try Binder!')
col1, _ = st.columns(2)
with col1:
    # Options are derived from EXAMPLE_TABLES so the displayed titles can never
    # drift out of sync with the dictionary lookup below.
    selected_table_title = st.selectbox(
        "Select an example table",
        tuple(EXAMPLE_TABLES),
    )

# Tables come from the 'missing_squall' validation split returned by load_data().
data_items = load_data()
table_index, default_question = EXAMPLE_TABLES[selected_table_title]
data_item = data_items[table_index]
table = data_item['table']
title = table['page_title']
db = NeuralDB(
    [{"title": title, "table": table}])  # todo: try to cache this db instead of re-creating it again and again.
df = db.get_table_df()
st.markdown("Title: {}".format(title))
st.dataframe(df)

# Let user input the question, pre-filled with the example's default question.
question = st.text_input(
    "Ask a question about the table:",
    value=default_question,
)
# Let the user choose the program language. Each choice maps to
# (prompt template file, generate_type) written onto args.
with col1:
    # todo: Why selecting language will flush the page?
    selected_language = st.selectbox(
        "Select a programming language",
        ("SQL", "Python"),
    )
language_settings = {
    'SQL': ('templates/prompts/prompt_wikitq_v3.txt', 'nsql'),
    'Python': ('templates/prompts/prompt_wikitq_python_simplified_v4.txt', 'npython'),
}
if selected_language not in language_settings:
    raise ValueError(f'{selected_language} language is not supported.')
args.prompt_file, args.generate_type = language_settings[selected_language]

# Halt the script run here until the user clicks the button.
if not st.button("Generate program"):
    st.stop()

# Generate Binder Program
generator = Generator(args, keys=keys)
with st.spinner("Generating program ..."):
    request_item = {"question": question, "table": db.get_table_df(), "title": title}
    binder_program = generate_binder_program(args, generator, request_item)
# Do execution
st.markdown("#### Binder program")
if selected_language == 'SQL':
with st.container():
st.write(binder_program)
executor = NSQLExecutor(args, keys=keys)
elif selected_language == 'Python':
st.code(binder_program, language='python')
executor = NPythonExecutor(args, keys=keys)
db = db.get_table_df()
else:
raise ValueError(f'{selected_language} language is not supported.')
try:
stamp = '{}'.format(uuid.uuid4())
os.makedirs('tmp_for_vis/', exist_ok=True)
with st.spinner("Executing program ..."):
exec_answer = executor.nsql_exec(stamp, binder_program, db)
# todo: Make it more pretty!
# todo: Do we need vis for Python?
if selected_language == 'SQL':
with open("tmp_for_vis/{}_tmp_for_vis_steps.txt".format(stamp), "r") as f:
steps = json.load(f)
st.markdown("#### Steps & Intermediate results")
for i, step in enumerate(steps):
st.markdown(step)
st.text("β")
with st.spinner('...'):
time.sleep(1)
with open("tmp_for_vis/{}_result_step_{}.txt".format(stamp, i), "r") as f:
result_in_this_step = json.load(f)
if isinstance(result_in_this_step, Dict):
st.dataframe(pd.DataFrame(pd.DataFrame(result_in_this_step["rows"], columns=result_in_this_step["header"])))
else:
st.markdown(result_in_this_step)
st.text("β")
elif selected_language == 'Python':
pass
if isinstance(exec_answer, list) and len(exec_answer) == 1:
exec_answer = exec_answer[0]
st.markdown(f'Execution answer: {exec_answer}')
except Exception as e:
traceback.print_exc()
|