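# Gradio chat demo that streams responses from Google's Gemma models through the
# Hugging Face Inference API. The omnibus html-image Space is used as a helper
# client for rendering chat screenshots.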
import gradio as gr
from gradio_client import Client
from huggingface_hub import InferenceClient

ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")

models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]
clients = [
    InferenceClient(models[0]),
    InferenceClient(models[1]),
    InferenceClient(models[2]),
    InferenceClient(models[3]),
]
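
# Toggle to print prompt and token-count diagnostics to the logs.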
VERBOSE = False

def load_models():
    # Relabel the chatbot with the active model name on app load.
    return gr.update(label=models[0])
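
# Build a Gemma-style prompt: wrap each prior exchange in <start_of_turn> tags,
# then append the new user message.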
def format_prompt(message, history):
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
            prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
    if VERBOSE:
        print(prompt)
    # Wrap the new message in the same turn tags and open a model turn for the reply.
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt
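
# Stream a chat reply: measure the recent memory against a rough size limit,
# then stream tokens from the Inference API.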
def chat_inf(prompt, history, memory, temp, tokens, top_p, rep_p, chat_mem):
    hist_len = 0
    client = clients[0]
    if not history:
        history = []
        hist_len = 0
    if not memory:
        memory = []
        mem_len = 0
    # gr.Number delivers a float; cast it so it can be used as a slice bound.
    chat_mem = int(chat_mem)
    if memory:
        for ea in memory[0 - chat_mem:]:
            hist_len += len(str(ea))
    in_len = len(prompt) + hist_len
    if (in_len + tokens) > 8000:
        history.append(
            (
                prompt,
                "Wait, that's too many tokens; please reduce the 'Chat Memory' value or the 'Max new tokens' value.",
            )
        )
        yield history, memory
    else:
        generate_kwargs = dict(
            temperature=temp,
            max_new_tokens=tokens,
            top_p=top_p,
            repetition_penalty=rep_p,
            do_sample=True,
        )
        formatted_prompt = format_prompt(prompt, memory[0 - chat_mem:])
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=True,
        )
        output = ""
        for response in stream:
            output += response.token.text
            # Include earlier turns so they stay visible while the reply streams in.
            yield history + [(prompt, output)], memory
        history.append((prompt, output))
        memory.append((prompt, output))
        yield history, memory
    if VERBOSE:
        print("\n######### HIST " + str(in_len))
        print("\n######### TOKENS " + str(tokens))
def get_screenshot(
    chat: list,
    height=5000,
    width=600,
    chatblock=[],
    theme="light",
    wait=3000,
    header=True,
):
    # Select which file from the Space's result list to return; a different
    # index is used when specific chat blocks are requested.
    tog = 0
    if chatblock:
        tog = 3
    result = ss_client.predict(
        str(chat),
        height,
        width,
        chatblock,
        header,
        theme,
        wait,
        api_name="/run_script",
    )
    out = f"https://omnibus-html-image-current-tab.hf.space/file={result[tog]}"
    return out

def clear_fn():
    return None, None, None, None
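
# UI: chatbot display plus prompt box and generation controls.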
with gr.Blocks() as app:
    memory = gr.State()
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                btn = gr.Button("Chat")
            with gr.Column(scale=1):
                with gr.Group():
                    temp = gr.Slider(
                        label="Temperature",
                        step=0.01,
                        minimum=0.01,
                        maximum=1.0,
                        value=0.49,
                    )
                    tokens = gr.Slider(
                        label="Max new tokens",
                        value=1600,
                        minimum=0,
                        maximum=8000,
                        step=64,
                        interactive=True,
                        visible=True,
                        info="Maximum number of new tokens to generate",
                    )
                    top_p = gr.Slider(
                        label="Top-P",
                        step=0.01,
                        minimum=0.01,
                        maximum=1.0,
                        value=0.49,
                    )
                    rep_p = gr.Slider(
                        label="Repetition Penalty",
                        step=0.01,
                        minimum=0.1,
                        maximum=2.0,
                        value=0.99,
                    )
                    chat_mem = gr.Number(
                        label="Chat Memory",
                        info="Number of previous exchanges to retain",
                        value=4,
                    )
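
    # Wire events: apply the model label on load, then stream chat_inf on submit and click.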
    app.load(load_models, None, [chat_b])  # route the label update returned by load_models to the chatbot
    chat_sub = inp.submit().then(
        chat_inf, [inp, chat_b, memory, temp, tokens, top_p, rep_p, chat_mem], [chat_b, memory]
    )
    go = btn.click().then(
        chat_inf,