import json
import os

import gradio as gr
import requests
from huggingface_hub import AsyncInferenceClient

HF_TOKEN = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
client = AsyncInferenceClient(api_url)

system_message = "\nYou are a helpful, respectful and honest Excel formula assistant. Always answer as helpfully as possible, while being safe."
title = "Excel Bot"
description = """
This is an Excel Assistant AI. Note: Derivative work of [Llama-2-70b-chat](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) by Meta.
"""
css = """.toast-wrap { display: none !important } """
examples = [
    ['Write an Excel formula to sum numbers in a row.'],
    ["Write an Excel formula to generate a random number."],
]

# Note: the default system prompt has been removed, as requested by the paper authors [dated 13/Oct/2023].
# Prompting style for Llama 2 without a system prompt:
# [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST]


# Streaming: yield tokens one by one with the InferenceClient from TGI
async def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    # Assemble the Llama 2 chat prompt; an optional system prompt goes inside <<SYS>> tags
    if system_prompt != "":
        input_prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "[INST] "

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Replay the chat history in the [INST] ... [/INST] format
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " [INST] "

    input_prompt = input_prompt + str(message) + " [/INST] "

    partial_message = ""
    async for token in await client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    ):
        partial_message = partial_message + token
        yield partial_message


# No streaming: produce the whole response in one request to the TGI inference endpoint
def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    if system_prompt != "":
        input_prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "[INST] "

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " [INST] "

    input_prompt = input_prompt + str(message) + " [/INST] "
    print(f"input_prompt - {input_prompt}")

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code == 200:  # the request was successful
        try:
            json_obj = response.json()
            if 'generated_text' in json_obj[0] and len(json_obj[0]['generated_text']) > 0:
                return json_obj[0]['generated_text']
            elif 'error' in json_obj[0]:
                return json_obj[0]['error'] + ' Please refresh and try again with a smaller input prompt.'
            else:
                print(f"Unexpected response: {json_obj[0]}")
        except json.JSONDecodeError:
            print(f"Failed to decode response as JSON: {response.text}")
    else:
        print(f"Request failed with status code {response.status_code}")


def vote(data: gr.LikeData):
    if data.liked:
        print("You upvoted this response: " + data.value)
    else:
        print("You downvoted this response: " + data.value)
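
# Illustration only (not used by the app): a standalone sketch of the prompt
# string that `predict` and `predict_batch` assemble inline above. The helper
# name `build_prompt` and this factored-out form are our own addition, shown
# to make the Llama 2 chat template easier to inspect in isolation.
def build_prompt(message, history, system_prompt=""):
    # Optional system prompt goes inside <<SYS>> tags of the first [INST] block
    prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n " if system_prompt else "[INST] "
    # Each past (user, assistant) turn is replayed in [INST] ... [/INST] form
    for user_msg, bot_msg in history:
        prompt += f"{user_msg} [/INST] {bot_msg} [INST] "
    # The new user message ends the prompt, leaving the model to answer
    return prompt + f"{message} [/INST] "

# e.g. build_prompt("And for a column?", [("Sum a row?", "=SUM(A1:Z1)")])
# -> "[INST] Sum a row? [/INST] =SUM(A1:Z1) [INST] And for a column? [/INST] "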
label="Optional system prompt"), gr.Slider( label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs", ), gr.Slider( label="Max new tokens", value=256, minimum=0, maximum=4096, step=64, interactive=True, info="The maximum numbers of new tokens", ), gr.Slider( label="Top-p (nucleus sampling)", value=0.6, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens", ), gr.Slider( label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens", ) ] chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'),bubble_full_width = False) chatbot_batch = gr.Chatbot(avatar_images=('user1.png', 'bot1.png'),bubble_full_width = False) chat_interface_stream = gr.ChatInterface(predict, title=title, description=description, textbox=gr.Textbox(), chatbot=chatbot_stream, css=css, examples=examples, #cache_examples=True, additional_inputs=additional_inputs,) chat_interface_batch=gr.ChatInterface(predict_batch, title=title, description=description, textbox=gr.Textbox(), chatbot=chatbot_batch, css=css, examples=examples, #cache_examples=True, additional_inputs=additional_inputs,) # Gradio Demo with gr.Blocks() as demo: with gr.Tab("Streaming"): # streaming chatbot chatbot_stream.like(vote, None, None) chat_interface_stream.render() # with gr.Tab("Batch"): # # non-streaming chatbot # chatbot_batch.like(vote, None, None) # chat_interface_batch.render() demo.queue(max_size=100).launch()