import json
import gradio as gr
import os
import requests
from huggingface_hub import AsyncInferenceClient

HF_TOKEN = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
client = AsyncInferenceClient(api_url)
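# Deployment assumption (not defined in this file): HF_TOKEN and API_URL are
# provided via environment variables (e.g. Space secrets), and API_URL points
# to a running TGI endpoint serving a Llama-2-style chat model.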
system_message = "\nYou are a helpful, respectful and honest Excel formula assistant. Always answer as helpfully as possible, while being safe."

title = "Excel Bot"
description = """
This is an Excel Assistant AI.
Note: Derivative work of [Llama-2-70b-chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) by Meta.
"""
css = """.toast-wrap { display: none !important } """
examples = [
    ['Write an Excel formula to sum numbers in a row.'],
    ["Write an Excel formula to generate a random number."],
]
# Note: We have removed the default system prompt as requested by the paper authors [Dated: 13/Oct/2023]
# Prompt format for Llama 2 without a system prompt:
# <s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
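# Illustrative example of an assembled prompt with one prior turn (the Excel
# content below is made up; the format follows the template above):
# "<s>[INST] Sum row 1 [/INST] =SUM(1:1) </s><s>[INST] Now average it [/INST] "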
# Stream text - stream tokens with AsyncInferenceClient from TGI
async def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Rebuild the multi-turn Llama 2 prompt from the chat history.
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
    input_prompt = input_prompt + str(message) + " [/INST] "

    # Stream tokens to the UI as they arrive from the endpoint.
    partial_message = ""
    async for token in await client.text_generation(prompt=input_prompt,
                                                    max_new_tokens=max_new_tokens,
                                                    stream=True,
                                                    best_of=1,
                                                    temperature=temperature,
                                                    top_p=top_p,
                                                    do_sample=True,
                                                    repetition_penalty=repetition_penalty):
        partial_message = partial_message + token
        yield partial_message
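# Optional local sanity check for the streaming path (assumption: run manually
# with RUN_STREAM_SMOKE_TEST=1, not in the deployed Space; the env var name is
# made up for this check).
if os.getenv("RUN_STREAM_SMOKE_TEST"):
    import asyncio

    async def _stream_smoke_test():
        async for partial in predict("Write an Excel formula to sum A1:A10.", chatbot=[]):
            print(partial)

    asyncio.run(_stream_smoke_test())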
# No streaming - produce the whole response in one request to the TGI inference endpoint
def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Rebuild the multi-turn Llama 2 prompt from the chat history.
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
    input_prompt = input_prompt + str(message) + " [/INST] "
    print(f"input_prompt - {input_prompt}")

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code == 200:  # check if the request was successful
        try:
            json_obj = response.json()
            if 'generated_text' in json_obj[0] and len(json_obj[0]['generated_text']) > 0:
                return json_obj[0]['generated_text']
            elif 'error' in json_obj[0]:
                return json_obj[0]['error'] + ' Please refresh and try again with a smaller input prompt.'
            else:
                print(f"Unexpected response: {json_obj[0]}")
                return "Unexpected response from the endpoint. Please try again."
        except json.JSONDecodeError:
            print(f"Failed to decode response as JSON: {response.text}")
            return "Failed to decode the endpoint response. Please try again."
    else:
        print(f"Request failed with status code {response.status_code}")
        return f"Request failed with status code {response.status_code}. Please try again."
def vote(data: gr.LikeData):
    if data.liked:
        print(f"You upvoted this response: {data.value}")
    else:
        print(f"You downvoted this response: {data.value}")
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
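# Note: gr.ChatInterface passes these additional inputs to the prediction
# function positionally, after (message, chatbot), in the order above:
# system_prompt, temperature, max_new_tokens, top_p, repetition_penalty.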
chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)
chatbot_batch = gr.Chatbot(avatar_images=('user1.png', 'bot1.png'), bubble_full_width=False)

chat_interface_stream = gr.ChatInterface(predict,
                                         title=title,
                                         description=description,
                                         textbox=gr.Textbox(),
                                         chatbot=chatbot_stream,
                                         css=css,
                                         examples=examples,
                                         # cache_examples=True,
                                         additional_inputs=additional_inputs,)
chat_interface_batch = gr.ChatInterface(predict_batch,
                                        title=title,
                                        description=description,
                                        textbox=gr.Textbox(),
                                        chatbot=chatbot_batch,
                                        css=css,
                                        examples=examples,
                                        # cache_examples=True,
                                        additional_inputs=additional_inputs,)
# Gradio Demo
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        # streaming chatbot
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()

    # with gr.Tab("Batch"):
    #     # non-streaming chatbot
    #     chatbot_batch.like(vote, None, None)
    #     chat_interface_batch.render()

demo.queue(max_size=100).launch()