File size: 5,286 Bytes
2ac93d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from utils import SyncStreamingLLMCallbackHandler
from anyio.from_thread import start_blocking_portal
from queue import Queue
import prompt, tiktoken

def num_token(string: str) -> int:
    """Return the number of cl100k_base tokens in *string*.

    cl100k_base is the tokenizer used by gpt-3.5-turbo, so this measures
    how much of the model's context window a prompt will consume.
    """
    enc = tiktoken.get_encoding('cl100k_base')
    return len(enc.encode(string))

def send_message(history, temp, top_p, fp, pp):
    """Generate the assistant's reply for the last turn in *history*, streaming it.

    Generator used as a Gradio event handler: it repeatedly yields *history*
    with the final entry's bot slot progressively filled in, so the chatbot
    UI updates token by token.

    Args:
        history: Gradio chatbot history — a list of [user, bot] pairs; the
            last pair is expected to be [message, None] as appended by
            append_user_message. Assumes history is non-empty (history[-1]
            would raise IndexError otherwise) — callers guarantee this via
            the .then() chaining below.
        temp: sampling temperature for the answering model.
        top_p, fp, pp: nucleus-sampling, frequency-penalty and
            presence-penalty values forwarded via model_kwargs.
    """
    # Hand-off channel between the background LLM task (producer) and this
    # generator (consumer). job_done is a unique sentinel marking completion.
    q = Queue()
    job_done = object()
    def task():
        # Streaming model that produces the final answer; tokens are pushed
        # onto q by the callback handler as they arrive.
        llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            streaming=True,
            temperature=temp,
            model_kwargs = {
                'top_p':top_p,
                'frequency_penalty':fp,
                'presence_penalty':pp
            }
        )
        # Deterministic (temperature=0) model used only to rewrite the chat
        # history into a search query.
        query_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            temperature = 0
        )
        # Step 1: derive a retrieval query from the conversation so far.
        query_template = prompt.search_query_gradio(history)
        search_query = query_llm([HumanMessage(content=query_template)]).content
        print('query: ', search_query)
        # Step 2: build the final answering prompt (presumably retrieval
        # results are embedded by question_answer_gradio — defined in the
        # external `prompt` module, not visible here).
        answer_template = prompt.question_answer_gradio(history, search_query)
        print('final temp: \n\n', answer_template, f'\nprompt total: {num_token(answer_template)}')
        # Step 3: stream the answer; the callback pushes each token onto q.
        llm([HumanMessage(content=answer_template)], callbacks=[SyncStreamingLLMCallbackHandler(q)])
        # Signal the consumer loop below that streaming has finished.
        q.put(job_done)
        return

    # Run task() on a background event loop thread while this (sync)
    # generator drains the queue.
    with start_blocking_portal() as portal:
        portal.start_task_soon(task)

        content = ""
        while True:
            # NOTE(review): raises queue.Empty if no token arrives within
            # 10s (e.g. slow first response or API error in task) — confirm
            # this is the intended failure mode.
            next_token = q.get(True, timeout=10)
            if next_token is job_done:
                # Inside a generator this ends iteration; Gradio keeps the
                # last yielded history as the final UI state.
                return history

            content += next_token
            # Fill the bot slot of the newest [user, bot] pair in place and
            # push the partial answer to the UI.
            latest = history[-1]
            latest[1] = content
            yield history


def append_user_message(message, history):
    """Append the user's message to the chat history and clear the input box.

    Args:
        message: Text the user typed into the textbox.
        history: Gradio chatbot history, a list of [user, bot] pairs;
            mutated in place.

    Returns:
        ("", history): the empty string clears the input textbox, and the
        history now ends with [message, None] — the None bot slot is filled
        in later by send_message as tokens stream in.
    """
    # (Removed a leftover debug print(history) that fired on every submit.)
    history.append([message, None])
    return "", history

# Flex-column layout so the chatbot (and the optional query-reference pane)
# stretch to fill the available page height.
# (The .contain rule was previously duplicated verbatim; deduplicated — the
# rendered CSS is behaviorally identical.)
css = """
.contain { display: flex; flex-direction: column; height: 100%;}
#chatbot { flex-grow: 1;}
#chatbot .user {text-align: right;}
#query_ref { flex-grow: 1; }
"""


# Build the Gradio UI. NOTE(review): float `scale` values and
# `.style(container=False)` are Gradio 3.x APIs — this file is pinned to an
# older gradio; both were removed/changed in 4.x.
with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="orange")) as demo:

    with gr.Row(elem_id="whole_page"):
        # Left sidebar: title, logos, and model sampling controls.
        with gr.Column(elem_id="sidebar", scale=0.2, min_width=0):
            # TODO temperature?
            gr.Markdown('## City of Lake Elsinore Proposal Demo', show_label=False)
            with gr.Row(elem_id="logo_section1"):
                with gr.Column(elem_id="logo_col1", scale=0.5, min_width=0):
                    # Logos served via Gradio's ./file= static-file route.
                    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><img src="./file=icon_1-modified.png" alt="ThinkCol" width="75" height="87" /></a></div>""")
                with gr.Column(elem_id="logo_col2", scale=0.5, min_width=0):
                    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><img src="./file=icon_2-modified.png" alt="ThinkCol" width="75" height="87" /></a></div>""")
            gr.Markdown('### Chatbot Settings', show_label=False)
            # Sliders feed directly into send_message's sampling parameters.
            temp = gr.Slider(maximum=1, label='Temperature')
            # NOTE(review): visible='primary' is a truthy string — visible
            # expects a bool; presumably True was intended. Confirm and fix.
            top_p = gr.Slider(maximum=1, label='Top_p', value=1, interactive=True, visible='primary')
            fp = gr.Slider(maximum=2, label='Frequency Penalty')
            pp = gr.Slider(maximum=2, label='Presence Penalty')

        # Main chat pane: chatbot transcript plus input row.
        with gr.Column(elem_id="main_chat_interface", scale=0.5):
            chatbot = gr.Chatbot([], elem_id="chatbot")
            with gr.Row(elem_id="box_split"):
                with gr.Column(elem_id="enter_box",scale=0.85):
                    txt = gr.Textbox(
                        elem_id='input',
                        show_label=False,
                        placeholder="Enter text and press enter, or upload an image"
                    ).style(container=False)
                    # Enter key: append the user turn, then stream the reply.
                    txt.submit(append_user_message, [txt, chatbot], [txt, chatbot]) \
                        .then(send_message, [chatbot, temp, top_p, fp, pp], [chatbot])

                with gr.Column(elem_id="send_box",scale=0.15, min_width=0):
                    # Send button mirrors the Enter-key behavior above.
                    btn = gr.Button('Send', elem_id='send', variant='primary')
                    btn.click(append_user_message, [txt, chatbot], [txt, chatbot]) \
                        .then(send_message, [chatbot, temp, top_p, fp, pp], [chatbot])


        # Disabled debug pane showing the derived search query and the final
        # prompt; kept for future re-enablement.
        # with gr.Column(elem_id="main_chat_interface", scale=0.3):
        #     with gr.Tab("Search Query"):
        #         query_ref = gr.TextArea(
        #                 value='',
        #                 interactive=False,
        #                 elem_id='query_ref',
        #                 show_label=False,
        #             ).style(container=False)
        #     with gr.Tab("Prompt"):
        #         prompt_ref = gr.TextArea(
        #                 value='',
        #                 interactive=False,
        #                 elem_id='prompt_ref',
        #                 show_label=False,
        #             ).style(container=False)
        #     print(query_ref, prompt_ref)

# Queueing is required for generator handlers like send_message to stream.
demo.queue()
if __name__ == "__main__":
    demo.launch()