File size: 5,921 Bytes
b667e96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
import openai
from langchain import PromptTemplate, HuggingFaceHub, LLMChain, ConversationChain
from langchain.agents import initialize_agent
from langchain.agents import load_tools
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError

# Tool names handed to langchain's load_tools when the agent is built.
TOOLS_DEFAULT_LIST = ["serpapi", "news-api", "pal-math"]

# Value passed as max_tokens when the OpenAI LLM is constructed.
MAX_TOKENS = 512

# Prompt for the "express" chain: asks the model to restate its input.
PROMPT_TEMPLATE = PromptTemplate(
    template="Restate the following: \n{original_words}\n",
    input_variables=["original_words"],
)

# Text that run_chain returns in place of model output when a call fails.
AUTH_ERR_MSG = "Please paste your OpenAI key."
BUG_FOUND_MSG = "Congratulations, you've found a bug in this application!"

# Required at import time: a missing NEWS_API_KEY aborts startup with KeyError.
news_api_key = os.environ["NEWS_API_KEY"]

def run_chain(chain, inp, capture_hidden_text):
    """Run ``chain`` on ``inp``, converting known failures into chat text.

    Args:
        chain: Object exposing ``run(input=...)`` (the agent from load_chain).
        inp: The user's message.
        capture_hidden_text: Accepted for interface compatibility; unused.

    Returns:
        ``(output, hidden_text)`` where ``output`` is the chain's result or an
        error message standing in for it, and ``hidden_text`` is always None.
    """
    try:
        return chain.run(input=inp), None
    except AuthenticationError:
        return AUTH_ERR_MSG, None
    except RateLimitError as err:
        return "\n\nRateLimitError: " + str(err), None
    except ValueError as err:
        return "\n\nValueError: " + str(err), None
    except InvalidRequestError as err:
        return "\n\nInvalidRequestError: " + str(err), None
    except Exception as err:
        # Catch-all: report any other failure in the chat rather than
        # propagating it to the UI handler.
        return "\n\n" + BUG_FOUND_MSG + ":\n\n" + str(err), None

def transform_text(desc, express_chain):
    """Prepare chain output for display in the chat window.

    Args:
        desc: Text to display.
        express_chain: Accepted for interface compatibility; currently unused
            (``desc`` is not restated through it).

    Returns:
        ``desc`` with every newline replaced by two newlines.
    """
    # NOTE(review): presumably the chatbot renders Markdown, where a single
    # newline does not break the line; doubling makes each one a paragraph
    # break — confirm against the Gradio version in use.
    return desc.replace("\n", "\n\n")

class ChatWrapper:
    """Callable chat handler that runs each turn while holding a lock."""

    def __init__(self):
        # Held for the whole of __call__, which assigns the module-global
        # ``openai.api_key`` and then runs the chain.
        self.lock = Lock()

    def __call__(
            self, api_key: str, inp: str, history: Optional[List[Tuple[str, str]]],
            chain: Optional[ConversationChain], express_chain: Optional[LLMChain]):
        """Execute one chat turn and append it to the transcript.

        Args:
            api_key: OpenAI key from the key textbox; assigned to
                ``openai.api_key`` before the chain runs.
            inp: The user's message.
            history: List of ``(user_message, reply)`` pairs so far, or None
                on the first turn.
            chain: Agent built by set_openai_api_key; None until a key has
                been entered.
            express_chain: Passed through to transform_text.

        Returns:
            ``(history, history)``: the updated transcript, once for the
            chatbot display and once for the history state.
        """
        # ``with`` releases the lock on every exit path, including exceptions,
        # which then propagate unchanged.
        with self.lock:
            history = history or []
            # If chain is None, that is because no API key was provided.
            output = "Please paste your OpenAI key to use this application."

            if chain:
                # openai.api_key is process-global; it is (re)set under the
                # lock, presumably so this turn's requests use this caller's
                # key — confirm against how the chain's LLM reads it.
                openai.api_key = api_key
                output, _hidden_text = run_chain(chain, inp, capture_hidden_text=False)
                print('output1', output)

            output = transform_text(output, express_chain)
            print('output2', output)
            history.append((inp, output))

        return history, history


chat = ChatWrapper()

def load_chain(tools_list, llm):
    """Build the tool-using agent and the restate chain around ``llm``.

    Args:
        tools_list: Tool names passed to ``load_tools``.
        llm: Language model passed to ``load_tools``, ``initialize_agent`` and
            the restate ``LLMChain``.

    Returns:
        ``(chain, express_chain)``: the agent returned by ``initialize_agent``
        and an ``LLMChain`` over ``PROMPT_TEMPLATE``.
    """
    print("\ntools_list", tools_list)
    tools = load_tools(tools_list, llm=llm, news_api_key=news_api_key)

    # NOTE(review): the zero-shot-react-description agent's prompt may not
    # include "chat_history", in which case this memory is recorded but never
    # shown to the model — confirm against the installed langchain version.
    memory = ConversationBufferMemory(memory_key="chat_history")

    chain = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True, memory=memory)
    # NOTE(review): express_chain is returned, but transform_text in this file
    # never invokes it.
    express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
    return chain, express_chain


def set_openai_api_key(api_key):
    """Build the chains for ``api_key`` and return them.

    Registered on the key textbox's ``change`` event, so it runs whenever the
    textbox contents change.

    Args:
        api_key: Key entered by the user; may be empty or None.

    Returns:
        ``(chain, express_chain, llm)``, or ``(None, None, None)`` when no key
        was supplied.
    """
    if not api_key:
        # Nothing to build; the chat handler then asks the user for a key.
        return None, None, None

    # The key is exposed through the environment only while the LLM and
    # chains are constructed (presumably OpenAI() reads OPENAI_API_KEY there
    # — confirm). It is then blanked so it does not linger in the shared
    # process environment. ``finally`` also blanks it if construction fails.
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
        chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
    finally:
        os.environ["OPENAI_API_KEY"] = ""
    return chain, express_chain, llm

with gr.Blocks() as app:
    # State holders: the LLM, the chat transcript and the two chains that
    # set_openai_api_key builds.
    llm_state = gr.State()
    history_state = gr.State()
    chain_state = gr.State()
    express_chain_state = gr.State()

    # Header row: title and the API-key box.
    with gr.Row():
        with gr.Column():
            gr.HTML("<b><center>GPT + Google</center></b>")

        openai_api_key_textbox = gr.Textbox(
            placeholder="Paste your OpenAI API key (sk-...)",
            show_label=False,
            lines=1,
            type='password',
        )

    # Transcript.
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()

    # Message entry and send button.
    with gr.Row():
        message = gr.Textbox(
            label="What's on your mind??",
            placeholder="What's the answer to life, the universe, and everything?",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "How many people live in Canada?",
            "What is 2 to the 30th power?",
            "If x+y=10 and x-y=4, what are x and y?",
            "How much did it rain in SF today?",
            "Get me information about the movie 'Avatar'",
            "What are the top tech headlines in the US?",
            "On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses - "
            "if I remove all the pairs of sunglasses from the desk, how many purple items remain on it?",
        ],
        inputs=message,
    )

    # Pressing Enter in the message box and clicking Send run the same handler
    # with the same inputs and outputs.
    chat_inputs = [openai_api_key_textbox, message, history_state, chain_state, express_chain_state]
    chat_outputs = [chatbot, history_state]
    message.submit(chat, inputs=chat_inputs, outputs=chat_outputs)
    submit.click(chat, inputs=chat_inputs, outputs=chat_outputs)

    # Rebuild the chains whenever the key textbox changes.
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[chain_state, express_chain_state, llm_state],
    )

app.launch(debug=True)