import json
import os
import uuid
import pandas as pd
import streamlit as st
import argparse
import traceback
from typing import Dict
import requests
from utils.utils import load_data_split
from nsql.database import NeuralDB
from nsql.nsql_exec import NSQLExecutor
from nsql.nsql_exec_python import NPythonExecutor
from generation.generator import Generator
import time

ROOT_DIR = os.path.join(os.path.dirname(__file__), "./")
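# Example tables shown in the demo: page title -> (index into the validation split, default question).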
EXAMPLE_TABLES = {
    "Estonia men's national volleyball team": (558, "what are the total number of players from france?"),
    "Highest mountain peaks of California": (5, "which is the lowest mountain?"),
    "2010–11 UAB Blazers men's basketball team": (1, "how many players come from alabama?"),
    "1999 European Tour": (209, "how many consecutive times was south africa the host country?"),
    "Nissan SR20DET": (438, "which car is the only one with more than 230 hp?"),
}


@st.cache
def load_data():
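    """Load the validation split used by the demo (cached across Streamlit reruns)."""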
    return load_data_split("missing_squall", "validation")


@st.cache
def get_key():
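    """Fetch an API key from the springboard key server; only the demo machine is allowed to request keys."""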
    # print the public IP of the demo machine
    ip = requests.get('https://checkip.amazonaws.com').text.strip()
    print(ip)

    URL = "http://54.242.37.195:20217/api/predict"
    # The springboard machine we built to protect the key (port 20217 is the birthday of Tianbao's girlfriend);
    # only the demo machine is granted access to the keys.

    one_key = requests.post(url=URL, json={"data": "Hi, binder server. Give me a key!"}).json()['data'][0]
    return one_key


def read_markdown(path):
    with open(path, "r") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)


def generate_binder_program(_args, _generator, _data_item):
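    """Build the few-shot prompt plus the item-specific prompt, shrink n_shots until the
    prompt fits the Codex token budget, then return the first program from one generation pass."""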
    n_shots = _args.n_shots
    few_shot_prompt = _generator.build_few_shot_prompt_from_file(
        file_path=_args.prompt_file,
        n_shots=n_shots
    )
    generate_prompt = _generator.build_generate_prompt(
        data_item=_data_item,
        generate_type=(_args.generate_type,)
    )
    prompt = few_shot_prompt + "\n\n" + generate_prompt

    # Ensure the prompt fits within Codex's max input tokens by shrinking n_shots
    max_prompt_tokens = _args.max_api_total_tokens - _args.max_generation_tokens
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=os.path.join(ROOT_DIR, "utils", "gpt2"))
    while len(tokenizer.tokenize(prompt)) >= max_prompt_tokens:  # TODO: Add shrink rows
        n_shots -= 1
        assert n_shots >= 0
        few_shot_prompt = _generator.build_few_shot_prompt_from_file(
            file_path=_args.prompt_file,
            n_shots=n_shots
        )
        prompt = few_shot_prompt + "\n\n" + generate_prompt

    response_dict = _generator.generate_one_pass(
        prompts=[("0", prompt)],  # the "0" is the place taker, take effect only when there are multi threads
        verbose=_args.verbose
    )
    print(response_dict)
    return response_dict["0"][0][0]


# Set up
parser = argparse.ArgumentParser()

parser.add_argument('--prompt_file', type=str, default='templates/prompts/prompt_wikitq_v3.txt')
# Binder program generation options
parser.add_argument('--prompt_style', type=str, default='create_table_select_3_full_table',
                    choices=['create_table_select_3_full_table',
                             'create_table_select_full_table',
                             'create_table_select_3',
                             'create_table',
                             'create_table_select_3_full_table_w_all_passage_image',
                             'create_table_select_3_full_table_w_gold_passage_image',
                             'no_table'])
parser.add_argument('--generate_type', type=str, default='nsql',
                    choices=['nsql', 'sql', 'answer', 'npython', 'python'])
parser.add_argument('--n_shots', type=int, default=14)
parser.add_argument('--seed', type=int, default=42)

# Codex options
# todo: Allow adjusting Codex parameters
parser.add_argument('--engine', type=str, default="code-davinci-002")
parser.add_argument('--max_generation_tokens', type=int, default=512)
parser.add_argument('--max_api_total_tokens', type=int, default=8001)
parser.add_argument('--temperature', type=float, default=0.)
parser.add_argument('--sampling_n', type=int, default=1)
parser.add_argument('--top_p', type=float, default=1.0)
parser.add_argument('--stop_tokens', type=str, default='\n\n',
                    help='Separate multiple stop tokens with ||')
parser.add_argument('--qa_retrieve_pool_file', type=str, default='templates/qa_retrieve_pool.json')

# debug options
parser.add_argument('-v', '--verbose', action='store_false')
args = parser.parse_args()
keys = [get_key()]

# The title
st.markdown("# Binder Playground")

# Summary about Binder
read_markdown('resources/summary.md')

# Introduction of Binder
# todo: Write Binder introduction here
# read_markdown('resources/introduction.md')
st.image('resources/intro.png')

# Upload tables/Switch tables

st.markdown('### Try Binder!')
col1, _ = st.columns(2)
with col1:
    selected_table_title = st.selectbox(
        "Select an example table",
        (
            "Estonia men's national volleyball team",
            "Highest mountain peaks of California",
            "2010–11 UAB Blazers men's basketball team",
            "1999 European Tour",
            "Nissan SR20DET",
        )
    )

# Here we just use our own example tables (user-uploaded tables are not supported yet).
data_items = load_data()
data_item = data_items[EXAMPLE_TABLES[selected_table_title][0]]
table = data_item['table']
header, rows, title = table['header'], table['rows'], table['page_title']
db = NeuralDB(
    [{"title": title, "table": table}])  # todo: try to cache this db instead of re-creating it again and again.
df = db.get_table_df()
st.markdown("Title: {}".format(title))
st.dataframe(df)

# Let user input the question
question = st.text_input(
    "Ask a question about the table:",
    value=EXAMPLE_TABLES[selected_table_title][1]
)
with col1:
    # todo: Why does selecting a language refresh the whole page?
    selected_language = st.selectbox(
        "Select a programming language",
        ("SQL", "Python"),
    )
if selected_language == 'SQL':
    args.prompt_file = 'templates/prompts/prompt_wikitq_v3.txt'
    args.generate_type = 'nsql'
elif selected_language == 'Python':
    args.prompt_file = 'templates/prompts/prompt_wikitq_python_simplified_v4.txt'
    args.generate_type = 'npython'
else:
    raise ValueError(f'{selected_language} language is not supported.')
button = st.button("Generate program")
if not button:
    st.stop()

# Generate Binder Program
generator = Generator(args, keys=keys)
with st.spinner("Generating program ..."):
    binder_program = generate_binder_program(args, generator,
                                             {"question": question, "table": db.get_table_df(), "title": title})


# Do execution
st.markdown("#### Binder program")
if selected_language == 'SQL':
    with st.container():
        st.write(binder_program)
    executor = NSQLExecutor(args, keys=keys)
elif selected_language == 'Python':
    st.code(binder_program, language='python')
    executor = NPythonExecutor(args, keys=keys)
    db = db.get_table_df()
else:
    raise ValueError(f'{selected_language} language is not supported.')
try:
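    # Each run gets a unique uuid stamp; the executor writes its intermediate steps and
    # per-step results to tmp_for_vis/ under that stamp so they can be replayed below.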
    stamp = '{}'.format(uuid.uuid4())
    os.makedirs('tmp_for_vis/', exist_ok=True)
    with st.spinner("Executing program ..."):
        exec_answer = executor.nsql_exec(stamp, binder_program, db)
    # todo: Make it more pretty!
    # todo: Do we need vis for Python?
    if selected_language == 'SQL':
        with open("tmp_for_vis/{}_tmp_for_vis_steps.txt".format(stamp), "r") as f:
            steps = json.load(f)
        st.markdown("#### Steps & Intermediate results")
        for i, step in enumerate(steps):
            st.markdown(step)
            st.text("↓")
            with st.spinner('...'):
                time.sleep(1)
            with open("tmp_for_vis/{}_result_step_{}.txt".format(stamp, i), "r") as f:
                result_in_this_step = json.load(f)
            if isinstance(result_in_this_step, dict):
                st.dataframe(pd.DataFrame(result_in_this_step["rows"], columns=result_in_this_step["header"]))
            else:
                st.markdown(result_in_this_step)
            st.text("↓")
    elif selected_language == 'Python':
        pass
    if isinstance(exec_answer, list) and len(exec_answer) == 1:
        exec_answer = exec_answer[0]
    st.markdown(f'Execution answer: {exec_answer}')
except Exception as e:
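    # Log the full traceback on the server side; the page simply shows no answer.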
    traceback.print_exc()