# comsec / gradio_demo.py
# Commit 2ac93d3 by briankchan: "Add demo"
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from utils import SyncStreamingLLMCallbackHandler
from anyio.from_thread import start_blocking_portal
from queue import Queue
import prompt, tiktoken
def num_token(string: str) -> int:
    """Count the tokens in *string* under tiktoken's ``cl100k_base`` encoding.

    Used to report the size of the final answer prompt before it is sent.
    """
    token_ids = tiktoken.get_encoding('cl100k_base').encode(string)
    return len(token_ids)
def send_message(history, temp, top_p, fp, pp):
    """Stream an LLM answer for the latest user turn into the chat history.

    Runs two model calls in a background task started on an anyio blocking
    portal:

    1. A non-streaming, temperature-0 call that turns the conversation
       (``prompt.search_query_gradio(history)``) into a search query.
    2. A streaming call on ``prompt.question_answer_gradio(history, query)``.
       Its tokens are handed to this generator through a queue, presumably
       one string per token via ``SyncStreamingLLMCallbackHandler``.

    Args:
        history: Gradio chatbot value, a list of ``[user, bot]`` pairs. The
            bot slot of the last pair is filled in place as tokens arrive.
        temp: sampling temperature for the answering model.
        top_p: nucleus-sampling ``top_p`` for the answering model.
        fp: frequency penalty for the answering model.
        pp: presence penalty for the answering model.

    Yields:
        ``history`` after every streamed token, so Gradio re-renders the chat.

    Raises:
        queue.Empty: if no token or end-of-stream signal arrives within 10 s.
        Exception: any error raised by the background task (e.g. an API
            failure). It is re-raised here instead of being silently dropped.
    """
    q = Queue()
    job_done = object()  # unique sentinel marking the end of the stream

    def task():
        # Enqueue the sentinel even on failure, so the consumer loop below
        # stops immediately instead of waiting out the queue timeout.
        try:
            llm = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                streaming=True,
                temperature=temp,
                model_kwargs={
                    'top_p': top_p,
                    'frequency_penalty': fp,
                    'presence_penalty': pp,
                },
            )
            query_llm = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0,
            )
            query_template = prompt.search_query_gradio(history)
            search_query = query_llm([HumanMessage(content=query_template)]).content
            print('query: ', search_query)
            answer_template = prompt.question_answer_gradio(history, search_query)
            print('final temp: \n\n', answer_template, f'\nprompt total: {num_token(answer_template)}')
            llm([HumanMessage(content=answer_template)], callbacks=[SyncStreamingLLMCallbackHandler(q)])
        finally:
            q.put(job_done)

    with start_blocking_portal() as portal:
        # The returned future holds task()'s result or its exception. The
        # portal does not propagate ordinary exceptions by itself.
        future = portal.start_task_soon(task)
        content = ""
        while True:
            next_token = q.get(True, timeout=10)
            if next_token is job_done:
                break
            content += next_token
            history[-1][1] = content
            yield history
        # Blocks until task() has fully finished; re-raises its exception, if any.
        future.result()
    return history
def append_user_message(message, history):
    """Start a new chat turn holding *message* and an empty bot reply.

    The turn is appended to *history* in place as ``[message, None]``. The
    bot slot is filled in later by ``send_message``.

    Returns:
        A pair ``("", history)``: the empty string clears the input textbox,
        and the same (mutated) history list updates the chatbot.
    """
    print(history)
    new_turn = [message, None]
    history.append(new_turn)
    cleared_input = ""
    return cleared_input, history
# Custom CSS for the Blocks app:
# - `.contain` is laid out as a full-height vertical flexbox.
# - `#chatbot` and `#query_ref` are allowed to grow to fill the free space.
# - User messages in `#chatbot` are right-aligned.
# `#query_ref` targets the "Search Query" panel, which is currently
# commented out of the layout.
css = """
.contain { display: flex; flex-direction: column; height: 100%;}
#chatbot { flex-grow: 1;}
#chatbot .user {text-align: right;}
#query_ref { flex-grow: 1; }
"""
# Page layout: a left sidebar (title, logos, sampling controls) next to the
# main chat area. Both Enter and the Send button first record the user's
# message, then stream the model's reply into the chatbot.
with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="orange")) as demo:
    with gr.Row(elem_id="whole_page"):
        with gr.Column(elem_id="sidebar", scale=0.2, min_width=0):
            gr.Markdown('## City of Lake Elsinore Proposal Demo', show_label=False)
            with gr.Row(elem_id="logo_section1"):
                with gr.Column(elem_id="logo_col1", scale=0.5, min_width=0):
                    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><img src="./file=icon_1-modified.png" alt="ThinkCol" width="75" height="87" /></div>""")
                with gr.Column(elem_id="logo_col2", scale=0.5, min_width=0):
                    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><img src="./file=icon_2-modified.png" alt="ThinkCol" width="75" height="87" /></div>""")
            gr.Markdown('### Chatbot Settings', show_label=False)
            # These sliders feed send_message's temperature / top_p /
            # frequency_penalty / presence_penalty for the answering model.
            temp = gr.Slider(maximum=1, label='Temperature')
            top_p = gr.Slider(maximum=1, label='Top_p', value=1, interactive=True, visible=True)
            fp = gr.Slider(maximum=2, label='Frequency Penalty')
            pp = gr.Slider(maximum=2, label='Presence Penalty')
        with gr.Column(elem_id="main_chat_interface", scale=0.5):
            chatbot = gr.Chatbot([], elem_id="chatbot")
            with gr.Row(elem_id="box_split"):
                with gr.Column(elem_id="enter_box", scale=0.85):
                    txt = gr.Textbox(
                        elem_id='input',
                        show_label=False,
                        placeholder="Enter text and press enter",
                    ).style(container=False)
                with gr.Column(elem_id="send_box", scale=0.15, min_width=0):
                    btn = gr.Button('Send', elem_id='send', variant='primary')
            # Enter and Send share one chain. First, add the user turn and
            # clear the textbox. Then stream the reply into the chatbot.
            for trigger in (txt.submit, btn.click):
                user_turn_added = trigger(append_user_message, [txt, chatbot], [txt, chatbot])
                user_turn_added.then(send_message, [chatbot, temp, top_p, fp, pp], [chatbot])
        # Disabled side panel with "Search Query" and "Prompt" tabs (styled by
        # the `#query_ref` CSS rule). Kept for possible re-enabling.
        # with gr.Column(elem_id="main_chat_interface", scale=0.3):
        #     with gr.Tab("Search Query"):
        #         query_ref = gr.TextArea(
        #             value='',
        #             interactive=False,
        #             elem_id='query_ref',
        #             show_label=False,
        #         ).style(container=False)
        #     with gr.Tab("Prompt"):
        #         prompt_ref = gr.TextArea(
        #             value='',
        #             interactive=False,
        #             elem_id='prompt_ref',
        #             show_label=False,
        #         ).style(container=False)
        # print(query_ref, prompt_ref)
# Turn on Gradio's request queue at import time. send_message is a generator
# that streams partial results, and Gradio 3.x requires queuing for that.
demo.queue()


def _main():
    """Start the Gradio web server for the demo."""
    demo.launch()


if __name__ == "__main__":
    _main()