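# Gradio chat app that wires an OpenAI LLM to LangChain tools (web search via SerpAPI,
# news headlines via News API, and PAL for math) behind a simple chatbot UI.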
import gradio as gr
import os

from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain import PromptTemplate, HuggingFaceHub, LLMChain, ConversationChain
from langchain.llms import OpenAI
from langchain.chains.conversation.memory import ConversationBufferMemory

from threading import Lock

import openai
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError

from typing import List, Optional, Tuple
TOOLS_DEFAULT_LIST = ['serpapi', 'news-api', 'pal-math']
MAX_TOKENS = 512

PROMPT_TEMPLATE = PromptTemplate(
    input_variables=["original_words"],
    template="Restate the following: \n{original_words}\n",
)

BUG_FOUND_MSG = "Congratulations, you've found a bug in this application!"
AUTH_ERR_MSG = "Please paste your OpenAI key."

news_api_key = os.environ["NEWS_API_KEY"]
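# Assumed environment: NEWS_API_KEY must be set before launch (used by the 'news-api' tool),
# and the 'serpapi' tool is expected to read SERPAPI_API_KEY from the environment.
# The OpenAI key itself is pasted into the UI at runtime rather than configured here.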
def run_chain(chain, inp, capture_hidden_text):
    """Run the agent chain on the user's input, converting known API errors into readable messages."""
    output = ""
    # capture_hidden_text is accepted for interface compatibility but unused in this simplified version,
    # so hidden_text is always None.
    hidden_text = None
    try:
        output = chain.run(input=inp)
    except AuthenticationError:
        output = AUTH_ERR_MSG
    except RateLimitError as rle:
        output = "\n\nRateLimitError: " + str(rle)
    except ValueError as ve:
        output = "\n\nValueError: " + str(ve)
    except InvalidRequestError as ire:
        output = "\n\nInvalidRequestError: " + str(ire)
    except Exception as e:
        output = "\n\n" + BUG_FOUND_MSG + ":\n\n" + str(e)
    return output, hidden_text
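# Illustrative call (assuming a chain built by load_chain below):
#   answer, _ = run_chain(chain, "How many people live in Canada?", capture_hidden_text=False)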
def transform_text(desc, express_chain):
    """Optionally restate the agent's answer; currently returns it with extra blank lines for readability."""
    # Prompt is prepared for express_chain but not sent in this simplified version.
    formatted_prompt = PROMPT_TEMPLATE.format(
        original_words=desc
    )

    generated_text = desc

    # Double up newlines so the answer renders as separate paragraphs in the chat window.
    generated_text = generated_text.replace("\n", "\n\n")
    return generated_text
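# Note: express_chain is accepted but never invoked above. A hedged sketch of actually
# restating the answer through the LLMChain (untested here) would be:
#   generated_text = express_chain.run(desc)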
class ChatWrapper:
    """Callable that runs the agent chain, serialized with a lock so concurrent Gradio requests don't interleave."""

    def __init__(self):
        self.lock = Lock()

    def __call__(
            self, api_key: str, inp: str, history: Optional[List[Tuple[str, str]]],
            chain: Optional[ConversationChain], express_chain: Optional[LLMChain]):
        """Execute the chat functionality."""
        self.lock.acquire()
        try:
            history = history or []
            # If chain is None, that is because no API key was provided.
            output = "Please paste your OpenAI key to use this application."
            hidden_text = output

            if chain and chain != "":
                # Set OpenAI key
                openai.api_key = api_key
                output, hidden_text = run_chain(chain, inp, capture_hidden_text=False)

            output = transform_text(output, express_chain)

            text_to_display = output
            history.append((inp, text_to_display))
        except Exception as e:
            raise e
        finally:
            self.lock.release()
        return history, history


chat = ChatWrapper()
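# The Gradio callbacks below pass (api_key, message, history, chain, express_chain) positionally,
# matching ChatWrapper.__call__; the single lock means requests are handled one at a time.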
def load_chain(tools_list, llm):
    chain = None
    express_chain = None
    print("\ntools_list", tools_list)
    tool_names = tools_list
    tools = load_tools(tool_names, llm=llm, news_api_key=news_api_key)

    memory = ConversationBufferMemory(memory_key="chat_history")

    chain = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True, memory=memory)
    express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
    return chain, express_chain
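# The "zero-shot-react-description" agent decides at each step whether to call serpapi (web search),
# news-api (headlines), or pal-math (program-aided math) based on each tool's description,
# then reasons over the tool's output before producing a final answer.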
def set_openai_api_key(api_key):
    """Set the api key and return chain.

    If no api_key, then None is returned.
    """
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
        llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
        chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
        # Blank the key once the chain holds it so it doesn't linger in the process environment.
        os.environ["OPENAI_API_KEY"] = ""
        return chain, express_chain, llm
    return None, None, None
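# The key is written to OPENAI_API_KEY only while the chain is constructed, then cleared;
# at request time ChatWrapper sets openai.api_key from the textbox value instead.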
with gr.Blocks() as app:
    llm_state = gr.State()
    history_state = gr.State()
    chain_state = gr.State()
    express_chain_state = gr.State()

    with gr.Row():
        with gr.Column():
            gr.HTML(
                """<b><center>GPT + Google</center></b>""")
            openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
                                                show_label=False, lines=1, type='password')

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(label="What's on your mind??",
                             placeholder="What's the answer to life, the universe, and everything?",
                             lines=1)
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=["How many people live in Canada?",
                  "What is 2 to the 30th power?",
                  "If x+y=10 and x-y=4, what are x and y?",
                  "How much did it rain in SF today?",
                  "Get me information about the movie 'Avatar'",
                  "What are the top tech headlines in the US?",
                  "On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses - "
                  "if I remove all the pairs of sunglasses from the desk, how many purple items remain on it?"],
        inputs=message
    )

    message.submit(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
                                 express_chain_state],
                   outputs=[chatbot, history_state])

    submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
                               express_chain_state],
                 outputs=[chatbot, history_state])

    openai_api_key_textbox.change(set_openai_api_key,
                                  inputs=[openai_api_key_textbox],
                                  outputs=[chain_state, express_chain_state, llm_state])

app.launch(debug=True)
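# To run locally (a sketch, assuming the Gradio 3.x, pre-1.0 langchain, and openai 0.x packages
# this script was written against):
#   pip install gradio langchain openai google-search-results
#   python app.py
# Then paste an OpenAI API key into the textbox to build the agent chain.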