import json
import gradio as gr
import os
import requests
from huggingface_hub import AsyncInferenceClient
# Endpoint configuration, read from the environment once at import time.
HF_TOKEN = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')
# Build the Authorization header only when a token is configured; formatting
# an unset token would otherwise send the literal value "Bearer None".
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Async client used by the streaming chat handler. The token is passed
# explicitly so streaming requests carry the same credentials as the
# `requests` calls that use `headers`; token=None is the constructor default.
client = AsyncInferenceClient(api_url, token=HF_TOKEN)
# --- Static text and styling handed to the Gradio chat interfaces ---------

title = "Excel Bot"

# Persona text for the assistant. It is defined here but not referenced by
# any other code visible in this file.
system_message = (
    "\nYou are a helpful, respectful and honest Excel formula assistant. "
    "Always answer as helpfully as possible, while being safe."
)

# Markdown description passed to the chat interfaces.
description = (
    "\n"
    "This is an Excel Assistant AI.\n"
    "Note: Derivate work of [Llama-2-70b-chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) by Meta.\n"
)

# CSS rule that hides elements with the `toast-wrap` class.
css = ".toast-wrap { display: none !important } "

# Example prompts; each inner list is one example row.
examples = [
    ["Write an Excel formula to sum number in a row."],
    ["Write an Excel formula to generate a random number."],
]
# Note: We have removed default system prompt as requested by the paper authors [Dated: 13/Oct/2023]
# Prompting style for Llama2 without using system prompt
# [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]
# Stream text - stream tokens with InferenceClient from TGI
async def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
    """Stream a reply to ``message``, yielding the text accumulated so far.

    Builds a Llama-2 style ``[INST] ... [/INST]`` prompt from the optional
    system prompt, the previous chat turns and the new message, then streams
    tokens from the module-level ``client``.

    Args:
        message: Latest user message (converted with ``str``).
        chatbot: Iterable of previous turns; item ``[0]`` of each turn is the
            user text and item ``[1]`` the model answer.
        system_prompt: Optional system prompt; an empty string omits the
            system section entirely.
        temperature: Sampling temperature; raised to 0.01 when lower.
        max_new_tokens: Token budget for the reply; cast to ``int`` and raised
            to 1 when lower.
        top_p: Nucleus-sampling probability mass (converted with ``float``).
        repetition_penalty: Forwarded unchanged to the generation call.

    Yields:
        str: The reply text received so far, growing with every token.
    """
    # Llama-2 chat expects the system prompt wrapped in <<SYS>> ... <</SYS>>
    # inside the first [INST] section.
    if system_prompt != "":
        prompt_parts = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "]
    else:
        prompt_parts = ["[INST] "]

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # The UI slider can deliver 0 or a float; keep the budget a positive int.
    # NOTE(review): assumes the backend rejects a zero token budget - confirm.
    max_new_tokens = max(int(max_new_tokens), 1)

    # Replay the conversation: user text [/INST] answer [INST] ...
    for interaction in chatbot:
        prompt_parts.append(f"{interaction[0]} [/INST] {interaction[1]} [INST] ")
    prompt_parts.append(f"{message} [/INST] ")
    input_prompt = "".join(prompt_parts)

    partial_message = ""
    token_stream = await client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    )
    async for token in token_stream:
        partial_message += token
        yield partial_message
# No Stream - batch produce tokens using TGI inference endpoint
def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
    """Return a complete reply to ``message`` from one blocking HTTP request.

    Builds the same Llama-2 style ``[INST] ... [/INST]`` prompt as the
    streaming handler and POSTs it to ``api_url`` with the module-level
    ``headers``.

    Args:
        message: Latest user message (converted with ``str``).
        chatbot: Iterable of previous turns; item ``[0]`` of each turn is the
            user text and item ``[1]`` the model answer.
        system_prompt: Optional system prompt; an empty string omits the
            system section entirely.
        temperature: Sampling temperature; raised to 0.01 when lower.
        max_new_tokens: Token budget for the reply; cast to ``int`` and raised
            to 1 when lower.
        top_p: Nucleus-sampling probability mass (converted with ``float``).
        repetition_penalty: Forwarded unchanged in the request parameters.

    Returns:
        str: The generated text, or a human-readable error message when the
        request fails or the response has an unexpected shape.
    """
    # Llama-2 chat expects the system prompt wrapped in <<SYS>> ... <</SYS>>
    # inside the first [INST] section.
    if system_prompt != "":
        prompt_parts = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "]
    else:
        prompt_parts = ["[INST] "]

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # The UI slider can deliver 0 or a float; keep the budget a positive int.
    # NOTE(review): assumes the backend rejects a zero token budget - confirm.
    max_new_tokens = max(int(max_new_tokens), 1)

    # Replay the conversation: user text [/INST] answer [INST] ...
    for interaction in chatbot:
        prompt_parts.append(f"{interaction[0]} [/INST] {interaction[1]} [INST] ")
    prompt_parts.append(f"{message} [/INST] ")
    input_prompt = "".join(prompt_parts)
    print(f"input_prompt - {input_prompt}")

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return f"Request failed with status code {response.status_code}. Please try again."

    # Response.json() may raise either the stdlib or the simplejson decode
    # error depending on what is installed; both derive from ValueError.
    try:
        json_obj = response.json()
    except ValueError:
        print(f"Failed to decode response as JSON: {response.text}")
        return "The server returned a response that could not be decoded. Please try again."

    # The expected payload is a non-empty list whose first item holds the
    # result; fall back to the payload itself so a bare dict is still handled.
    result = json_obj[0] if isinstance(json_obj, list) and json_obj else json_obj
    if isinstance(result, dict):
        generated_text = result.get('generated_text')
        if generated_text:
            return generated_text
        if 'error' in result:
            return f"{result['error']} Please refresh and try again with smaller input prompt"

    print(f"Unexpected response: {result}")
    return "The server returned an unexpected response. Please try again."
def vote(data: gr.LikeData):
    """Print whether a chat response was upvoted or downvoted.

    Args:
        data: Like event payload; ``data.liked`` selects the wording and
            ``data.value`` is echoed in the message.
    """
    # Use f-string formatting rather than ``+`` so a non-string ``value``
    # is printed instead of raising TypeError.
    if data.liked:
        print(f"You upvoted this response: {data.value}")
    else:
        print(f"You downvoted this response: {data.value}")
# Extra controls shown with the chat box. Their order matches the handler
# parameters that follow (message, chatbot): system_prompt, temperature,
# max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    # Free-text system prompt, empty by default.
    gr.Textbox(value="", label="Optional system prompt"),
    # Sampling temperature.
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.9,
        label="Temperature",
        info="Higher values produce more diverse outputs",
        interactive=True,
    ),
    # Token budget for the reply.
    gr.Slider(
        minimum=0,
        maximum=4096,
        step=64,
        value=256,
        label="Max new tokens",
        info="The maximum numbers of new tokens",
        interactive=True,
    ),
    # Nucleus-sampling threshold.
    gr.Slider(
        minimum=0.0,
        maximum=1,
        step=0.05,
        value=0.6,
        label="Top-p (nucleus sampling)",
        info="Higher values sample more low-probability tokens",
        interactive=True,
    ),
    # Repetition penalty.
    gr.Slider(
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.2,
        label="Repetition penalty",
        info="Penalize repeated tokens",
        interactive=True,
    ),
]
# Chat display widgets, one per interface, each with its own avatar images.
chatbot_stream = gr.Chatbot(
    bubble_full_width=False,
    avatar_images=('user.png', 'bot2.png'),
)
chatbot_batch = gr.Chatbot(
    bubble_full_width=False,
    avatar_images=('user1.png', 'bot1.png'),
)

# Streaming interface: backed by the async generator `predict`.
# (cache_examples=True is left disabled.)
chat_interface_stream = gr.ChatInterface(
    fn=predict,
    chatbot=chatbot_stream,
    textbox=gr.Textbox(),
    title=title,
    description=description,
    examples=examples,
    additional_inputs=additional_inputs,
    css=css,
)

# Non-streaming interface: backed by the blocking `predict_batch`.
# (cache_examples=True is left disabled.)
chat_interface_batch = gr.ChatInterface(
    fn=predict_batch,
    chatbot=chatbot_batch,
    textbox=gr.Textbox(),
    title=title,
    description=description,
    examples=examples,
    additional_inputs=additional_inputs,
    css=css,
)
# Gradio demo: a Blocks page with a single "Streaming" tab.
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        # Register the like/dislike callback on the streaming chatbot, then
        # render the streaming chat interface inside this tab.
        chatbot_stream.like(fn=vote, inputs=None, outputs=None)
        chat_interface_stream.render()

    # The non-streaming tab is currently disabled:
    # with gr.Tab("Batch"):
    #     chatbot_batch.like(vote, None, None)
    #     chat_interface_batch.render()

# Enable the request queue (max_size=100), then start the app.
demo.queue(max_size=100)
demo.launch()