|
|
|
|
|
import os |
|
from typing import Iterator |
|
|
|
import gradio as gr |
|
import torch |
|
|
|
from model import run |
|
from settings import (ALLOW_CHANGING_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, |
|
DEFAULT_SYSTEM_PROMPT, MAX_MAX_NEW_TOKENS) |
|
|
|
# Markdown shown at the top of the page. When no CUDA device is available a
# warning paragraph is appended to the title.
DESCRIPTION = '# Llama-2 7B chat'
DESCRIPTION += (
    ''
    if torch.cuda.is_available()
    else '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
)
|
|
|
|
|
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Return an empty string for the textbox together with the given text.

    Args:
        message: The text that was submitted.

    Returns:
        A ``('', message)`` pair: the first item is the new (empty) textbox
        content, the second is the text to keep for later steps.
    """
    cleared_textbox = ''
    return cleared_textbox, message
|
|
|
|
|
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Add the user's message to the chat history with an empty reply.

    Args:
        message: The user message to show.
        history: Existing chat turns as ``(user message, reply)`` pairs.
            The list is modified in place.

    Returns:
        The same ``history`` list, now ending with ``(message, '')``.
    """
    pending_turn = (message, '')
    history.append(pending_turn)
    return history
|
|
|
|
|
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent turn from the chat history.

    Args:
        history: Chat turns as ``(user message, reply)`` pairs. The list is
            modified in place.

    Returns:
        The shortened ``history`` and the user message of the removed turn.
        The message is ``''`` when the history was already empty or when the
        removed message was falsy (for example ``None``).
    """
    try:
        user_message, _reply = history.pop()
    except IndexError:
        # Empty history: there is no turn to remove.
        return history, ''
    return history, user_message or ''
|
|
|
|
|
def fn(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    top_p: float,
    temperature: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream the chat history while the reply to ``message`` is produced.

    NOTE: ``top_p`` comes before ``temperature`` in this signature, whereas
    ``run`` is called with ``temperature`` first. Positional callers must
    follow this function's order.

    Args:
        message: The user message to answer.
        history_with_input: Chat turns as ``(user message, reply)`` pairs.
            The last entry is dropped before calling ``run`` and is replaced
            by ``(message, <response>)`` in every yielded list.
        system_prompt: Forwarded to ``run``.
        max_new_tokens: Forwarded to ``run``; must not exceed
            ``MAX_MAX_NEW_TOKENS``.
        top_p: Forwarded to ``run``.
        temperature: Forwarded to ``run``.
        top_k: Forwarded to ``run``.

    Yields:
        The history without its last entry, followed by
        ``(message, response)`` for each response produced by ``run``. If
        ``run`` produces nothing, a single list ending in ``(message, '')``
        is yielded.

    Raises:
        ValueError: If ``max_new_tokens`` exceeds ``MAX_MAX_NEW_TOKENS``.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(
            f'max_new_tokens ({max_new_tokens}) must not exceed '
            f'{MAX_MAX_NEW_TOKENS}')

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens,
                    temperature, top_p, top_k)
    # The '' default covers a generator that yields nothing, so one update
    # with an empty reply is still emitted in that case.
    first_response = next(generator, '')
    yield history + [(message, first_response)]
    for response in generator:
        yield history + [(message, response)]
|
|
|
|
|
# Build the chat UI and wire its events, then start the app.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button',
                       visible=os.getenv('SHOW_DUPLICATE_BUTTON') == '1')

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the last submitted message between the steps of an event chain.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6,
                                   interactive=ALLOW_CHANGING_SYSTEM_PROMPT)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=5.0,
            step=0.1,
            value=0.8,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=50,
            step=1,
            value=50,
        )

    # Inputs are handed to ``fn`` positionally, so this list must follow its
    # signature exactly: (message, history_with_input, system_prompt,
    # max_new_tokens, top_p, temperature, top_k). One shared list keeps the
    # submit, button and retry chains from drifting apart.
    fn_inputs = [
        saved_input,
        chatbot,
        system_prompt,
        max_new_tokens,
        top_p,
        temperature,
        top_k,
    ]

    # Enter key in the textbox: clear the box, show the message, generate.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=fn,
        inputs=fn_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: same chain as pressing Enter in the textbox.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=fn,
        inputs=fn_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Retry: drop the last turn, show its message again and regenerate.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=fn,
        inputs=fn_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last turn and put its message back into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset both the chat history and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=20).launch()
|
|