JLW's picture
Print message if there isn't an API key pasted
fa42e0a
raw
history blame
4.68 kB
import datetime
import html
import os

import gradio as gr
import openai
from langchain import OpenAI
from langchain.chains import PALChain
# Instruction prepended to the user's question on the "GPT Only" path
# (see calc_gpt_only); it asks the model for the final answer alone.
gpt_only_prompt = (
    "Calculate the following, "
    "giving only the final answer:\n"
)

# Module-level placeholder. Nothing in this file reads it; openai_create's
# `prompt` parameter shadows it.
prompt = str()
def set_openai_api_key(api_key, openai_api_key, pal_chain):
    """Build a PAL chain for a newly pasted OpenAI API key.

    Wired to the key textbox's ``change`` event. Both return values are
    written back into the session's ``gr.State`` objects, so this must
    always return a 2-tuple.

    Args:
        api_key: Current contents of the key textbox (may be empty or None).
        openai_api_key: Key currently held in session state (None initially).
        pal_chain: PALChain currently held in session state (None initially).

    Returns:
        ``(openai_api_key, pal_chain)``. When *api_key* is empty or blank,
        the incoming session state is returned unchanged.
    """
    # Surrounding whitespace from a paste would make the key invalid.
    api_key = (api_key or "").strip()
    if not api_key:
        return openai_api_key, pal_chain

    # Hand the key to the LLM directly instead of temporarily exporting it
    # through os.environ. The environment is process-wide and shared by
    # every concurrent Gradio session.
    llm = OpenAI(
        model_name='code-davinci-002',
        temperature=0,
        max_tokens=512,
        openai_api_key=api_key,
    )
    pal_chain = PALChain.from_math_prompt(llm, verbose=True)
    return api_key, pal_chain
def openai_create(prompt, openai_api_key):
    """Send *prompt* to the ``text-davinci-003`` completion endpoint.

    Args:
        prompt: Full prompt text to complete.
        openai_api_key: The calling session's OpenAI API key.

    Returns:
        The text of the first completion choice.
    """
    print("prompt: " + prompt)
    # The key is passed with the request rather than through os.environ.
    # The environment is process-wide state shared by every Gradio session.
    # The pre-1.0 openai client also takes its default key from
    # `openai.api_key` instead of re-reading the environment on each call.
    # temperature=0.0 makes the output as deterministic as possible; it does
    # not by itself prevent wrong answers.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        temperature=0.0,
        max_tokens=300,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        api_key=openai_api_key,
    )
    return response.choices[0].text
def calc_gpt_only(math_problem, openai_api_key):
    """Answer *math_problem* by asking GPT directly (no generated program).

    Args:
        math_problem: The question typed by the user.
        openai_api_key: Key from session state; None/empty until one is pasted.

    Returns:
        An HTML ``<pre>`` snippet for the answer Markdown component, or a
        request to paste a key when none is available.
    """
    if not openai_api_key:
        return "<pre>Please paste your OpenAI API key</pre>"

    answer = openai_create(gpt_only_prompt + math_problem + "\n", openai_api_key)
    print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
    print("calc_gpt_only math problem: " + math_problem)
    print("calc_gpt_only answer: " + answer)
    # Escape so answers such as "x < 5" or "a & b" display literally instead
    # of being parsed as markup. Strip the leading newlines completions
    # usually start with.
    return "<pre>" + html.escape(answer.strip()) + "</pre>"
def calc_gpt_pal(math_problem, pal_chain):
    """Answer *math_problem* with a PAL chain (GPT writes and runs a program).

    Args:
        math_problem: The question typed by the user.
        pal_chain: Chain from session state; None until a key is pasted.

    Returns:
        An HTML ``<pre>`` snippet for the answer Markdown component, or a
        request to paste a key when no chain has been built yet.
    """
    if not pal_chain:
        return "<pre>Please paste your OpenAI API key</pre>"

    # str() guards the string concatenations below in case the chain yields
    # a non-string result.
    answer = str(pal_chain.run(math_problem))
    print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
    print("calc_gpt_pal math problem: " + math_problem)
    print("calc_gpt_pal answer: " + answer)
    # Escape so characters like "<" and "&" render literally inside <pre>.
    return "<pre>" + html.escape(answer.strip()) + "</pre>"
# Top-level Gradio UI: key entry, question box, two answer buttons, examples,
# and per-session state holding the pasted key and the PAL chain built from it.
block = gr.Blocks(css=".gradio-container {background-color: lightgray}")

with block:
    with gr.Row():
        title = gr.Markdown("""<h3><center>Comparing GPT math techniques</center></h3>""")
        openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
                                            show_label=False, lines=1, type='password')

    answer_html = gr.Markdown()

    request = gr.Textbox(label="Math question:",
                         placeholder="Ex: What is the sum of the first 10 prime numbers?")
    with gr.Row():
        gpt_only = gr.Button(value="GPT Only", variant="secondary").style(full_width=False)
        gpt_pal = gr.Button(value="GPT w/PAL", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=["42 times 81",
                  "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
                  "What is the sum of the first 21 Fibonacci numbers?",
                  "Jane quit her job on Mar 20, 2020. 176 days have passed since then. What is the date tomorrow in "
                  "MM/DD/YYYY?",
                  # U+2212 minus signs; previously mis-encoded as "βˆ’".
                  "If y = 8−5x+4x2, what is the value of y when x = −2?",
                  "A line parallel to y = 4x + 6 passes through (5, 10). What is the y-coordinate of the point where "
                  "this line crosses the y-axis?"],
        inputs=request
    )

    gr.HTML("""
        This simple app demonstrates a couple of techniques for using GPT-3 to solve math problems.
        The first technique is to simply ask GPT-3 to solve the problem. The second technique is to use
        GPT-3 to interpret the problem and then create/run a Python program to solve it. The program is
        generated using the <a href='https://langchain.readthedocs.io/en/latest/examples/chains/pal.html'>
        PALChain from the LangChain</a> library.
        See <a href='https://reasonwithpal.com/'>PAL: Program-aided Language Models</a>.""")

    # Parrot + link emoji; previously mis-encoded as "πŸ¦œοΈπŸ”—".
    gr.HTML("<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>")

    # Per-session state: the pasted key and the PALChain built for it.
    openai_api_key_state = gr.State()
    pal_chain_state = gr.State()

    gpt_only.click(calc_gpt_only, inputs=[request, openai_api_key_state], outputs=[answer_html])
    gpt_pal.click(calc_gpt_pal, inputs=[request, pal_chain_state], outputs=[answer_html])

    # Rebuild the key/chain state whenever the key textbox changes.
    openai_api_key_textbox.change(set_openai_api_key,
                                  inputs=[openai_api_key_textbox, openai_api_key_state, pal_chain_state],
                                  outputs=[openai_api_key_state, pal_chain_state])

block.launch()