File size: 6,477 Bytes
35a5820 e218107 35a5820 1d08ff2 1cab0e8 35a5820 e218107 35a5820 1cab0e8 35a5820 e218107 35a5820 e218107 35a5820 e218107 35a5820 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import re
import gradio as gr
import openai
# Module-level configuration of the legacy openai client: the endpoint URL and
# key come from the environment (each is None when the variable is unset).
openai.api_base = os.getenv("OPENAI_API_BASE")
openai.api_key = os.getenv("OPENAI_API_KEY")

# Default system prompt: pre-fills the "System Message" box and is the
# fallback whenever that box is left blank.
BASE_SYSTEM_MESSAGE = """I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning.
I am an assistant who thinks through their answers step-by-step to be sure I always get the right answer.
I think more clearly if I write out my thought process in a scratchpad manner first; therefore, I always explain background context, assumptions, and step-by-step thinking BEFORE trying to answer or solve anything."""
def make_prediction(prompt, max_tokens=None, temperature=None, top_p=None, top_k=None, repetition_penalty=None):
    """Stream completion text for *prompt* from the configured endpoint.

    All sampling arguments are forwarded unchanged to
    ``openai.Completion.create``. ``top_k`` and ``repetition_penalty`` are not
    standard OpenAI parameters; presumably the serving backend accepts them
    (TODO confirm). Generation stops at ``</s>`` or ``<|im_end|>``.

    Yields:
        The ``text`` of the first choice of each streamed chunk, in order.
    """
    response_stream = openai.Completion.create(
        model="openaccess-ai-collective/jackalope-7b",
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        stream=True,
        stop=["</s>", "<|im_end|>"],
    )
    for event in response_stream:
        first_choice = event["choices"][0]
        yield first_choice["text"]
def clear_chat(chat_history_state, chat_message):
    """Return a fresh, empty conversation state and a blank message box.

    Both arguments are ignored; they are accepted only because the Gradio
    event passes the current values of its input components.
    """
    empty_history = []
    blank_message = ''
    return empty_history, blank_message
def user(message, history):
    """Append *message* as a new user turn with an empty assistant reply.

    A falsy *history* (``None`` or empty) is replaced by a new list; otherwise
    the given list is extended in place.

    Returns:
        ``("", history)``: the empty string clears the message textbox, and the
        history now ends with ``[message, ""]``.
    """
    conversation = history if history else []
    conversation.append([message, ""])
    cleared_textbox = ""
    return cleared_textbox, conversation
def pop_last(history):
    """Prepare the conversation for "Regenerate" by clearing the last reply.

    The final ``[user_message, assistant_reply]`` turn is replaced with a new
    ``[user_message, ""]`` pair so a fresh reply can be generated for the same
    user message.

    Args:
        history: list of ``[user, assistant]`` pairs, or ``None`` when no
            conversation has started yet (the initial ``gr.State()`` value).

    Returns:
        The same list, modified in place, or an empty list when there is no
        turn to regenerate.
    """
    if not history:
        # Regenerate clicked before any message: nothing to pop. Without this
        # guard, None.pop() / [].pop() would raise.
        return []
    user_message = history.pop()[0]
    history.append([user_message, ""])
    return history
def _build_prompt(history, system_message):
    """Render the system prompt and all turns as a ChatML-style prompt string."""
    sys_prompt = system_message.strip() or BASE_SYSTEM_MESSAGE
    turns = [
        "<|im_start|> user\n" + item[0] + "<|im_end|>\n"
        + "<|im_start|> assistant\n" + item[1] + "<|im_end|>"
        for item in history
    ]
    prompt = "<|im_start|> system\n" + sys_prompt + "<|im_end|>\n" + "\n".join(turns)
    # Leave the final assistant turn open so the model continues it. Remove
    # exactly one trailing tag: str.rstrip("<|im_end|>") would strip any of
    # those *characters* and could also eat the tail of the reply text.
    end_tag = "<|im_end|>"
    if prompt.endswith(end_tag):
        prompt = prompt[: -len(end_tag)]
    # Drop the trailing newline after "assistant"; some models output a
    # zero-width space if trailing whitespace is left in the prompt.
    return prompt.rstrip()


def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Stream the assistant's reply into the last turn of *history*.

    Generator used as a Gradio event handler. The last turn's assistant text
    is extended in place with each chunk returned by ``make_prediction``.

    Args:
        history: list of ``[user, assistant]`` pairs (or ``None``); the last
            pair's assistant text is the one being generated.
        system_message: system prompt; blank falls back to
            ``BASE_SYSTEM_MESSAGE``.
        max_tokens, temperature, top_p, top_k, repetition_penalty: sampling
            settings forwarded to ``make_prediction``.

    Yields:
        ``(history, history, "")`` after each streamed chunk: the updated
        conversation for the chatbot and the state, plus an empty string for
        the message textbox.
    """
    history = history or []
    if not history:
        # Nothing to reply to (e.g. "Regenerate" on an empty conversation);
        # history[-1] below would raise IndexError.
        yield history, history, ""
        return

    prompt = _build_prompt(history, system_message)

    # With temperature 0, force top_p=1 and top_k=-1; -1 presumably means
    # "top-k disabled" on the serving backend (TODO confirm).
    if temperature == 0:
        top_p = 1
        top_k = -1

    prediction = make_prediction(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    for text in prediction:
        history[-1][1] += text
        # stream the response
        yield history, history, ""
start_message = BASE_SYSTEM_MESSAGE

# Make the page fill the viewport and let the chatbot grow / be resized.
CSS = """
.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto; resize: vertical; }
"""

with gr.Blocks(css=CSS) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
## This PREVIEW demo is an un-quantized GPU chatbot of [Jackalope 7B](https://huggingface.co/openaccess-ai-collective/jackalope-7b)
- Completed model drops on Wednesday October 11th.
- Brought to you by your friends at Open Access AI Collective, Alignment Lab AI, and OpenChat!
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
""")
    with gr.Row():
        gr.Markdown("# 🐰🦌 Jackalope 7B Playground Space! 🐰🦌")
    with gr.Row():
        system_msg = gr.Textbox(
            start_message, label="System Message", interactive=True, visible=True, placeholder="System prompt. Provide instructions which you want the model to remember.", lines=5)
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot").style(height=400)
    with gr.Row():
        message = gr.Textbox(
            label="What do you want to chat about?",
            placeholder="Ask me anything.",
            lines=3,
        )
    with gr.Row():
        submit = gr.Button(value="Send message", variant="primary").style(full_width=True)
        clear = gr.Button(value="New topic", variant="secondary").style(full_width=False)
        stop = gr.Button(value="Stop", variant="secondary").style(full_width=False)
        regenerate = gr.Button(value="Regenerate", variant="secondary").style(full_width=False)
    with gr.Accordion("Show Model Parameters", open=False):
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 2500, label="Max Tokens", step=20, value=500)
                temperature = gr.Slider(0.0, 2.0, label="Temperature", step=0.1, value=0.4)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(1, 100, label="Top K", step=1, value=40)
                repetition_penalty = gr.Slider(1.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

    # Conversation as a list of [user, assistant] pairs; starts as None.
    chat_history_state = gr.State()

    # "New topic": reset the state and the textbox, and separately blank the chatbot display.
    clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

    # "Send message": record the user turn, then stream the reply into it.
    submit_click_event = submit.click(
        fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
    ).then(
        fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, chat_history_state, message], queue=True
    )

    # "Regenerate": clear the last reply, then stream a new one.
    regenerate_click_event = regenerate.click(
        fn=pop_last, inputs=[chat_history_state], outputs=[chat_history_state], queue=True
    ).then(
        fn=chat, inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, chat_history_state, message], queue=True
    )

    # "Stop": cancel any in-flight generation started by either button.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event, regenerate_click_event], queue=False)

demo.queue(max_size=128, concurrency_count=48).launch(debug=True, server_name="0.0.0.0", server_port=7860)