import gradio as gr
import os
import json
import requests

HF_TOKEN = os.getenv('HF_TOKEN')
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
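
# The two env vars below are expected to hold the Inference API endpoint URLs for
# each model (set as Space secrets), e.g.
# "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta".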
zephyr_7b_beta = os.getenv('zephyr_7b_beta')
zephyr_7b_alpha = os.getenv('zephyr_7b_alpha')

def build_input_prompt(message, chatbot):
    """
    Constructs the input prompt string from the chatbot interactions and the
    current message, following the Zephyr chat template of <|system|>, <|user|>,
    and <|assistant|> turns separated by </s>.
    """
    input_prompt = "<|system|>\n</s>\n<|user|>\n"
    for interaction in chatbot:
        input_prompt += str(interaction[0]) + "</s>\n<|assistant|>\n" + str(interaction[1]) + "\n</s>\n<|user|>\n"
    input_prompt += str(message) + "</s>\n<|assistant|>"
    return input_prompt
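
# For example, build_input_prompt("Hi", []) returns:
# "<|system|>\n</s>\n<|user|>\nHi</s>\n<|assistant|>"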

def post_request_beta(payload):
    """
    Sends a POST request to the predefined Zephyr-7b-Beta URL and returns the JSON response.
    """
    response = requests.post(zephyr_7b_beta, headers=HEADERS, json=payload)
    response.raise_for_status()  # Raises an HTTPError if the request returned an unsuccessful status code
    return response.json()

def post_request_alpha(payload):
    """
    Sends a POST request to the predefined Zephyr-7b-Alpha URL and returns the JSON response.
    """
    response = requests.post(zephyr_7b_alpha, headers=HEADERS, json=payload)
    response.raise_for_status()  # Raises an HTTPError if the request returned an unsuccessful status code
    return response.json()
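
# Illustrative call, with a payload shaped like the one predict_beta builds below:
# post_request_beta({"inputs": "<|user|>\nHi</s>\n<|assistant|>", "parameters": {"max_new_tokens": 64}})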

def predict_beta(message, chatbot=None, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    chatbot = chatbot if chatbot is not None else []  # avoid sharing a mutable default list across calls
    temperature = float(temperature)
    top_p = float(top_p)
    input_prompt = build_input_prompt(message, chatbot)
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
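    # The Inference API typically returns a list like [{"generated_text": "..."}];
    # error payloads carry an "error" key instead, which is surfaced to the user below.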
    try:
        response_data = post_request_beta(data)
        json_obj = response_data[0]
        if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
            bot_message = json_obj['generated_text']
            chatbot.append((message, bot_message))
            return "", chatbot
        elif 'error' in json_obj:
            raise gr.Error(json_obj['error'] + ' Please refresh and try again with a smaller input prompt.')
        else:
            warning_msg = f"Unexpected response: {json_obj}"
            raise gr.Error(warning_msg)
    except requests.HTTPError as e:
        error_msg = f"Request failed with status code {e.response.status_code}"
        raise gr.Error(error_msg)
    except json.JSONDecodeError as e:
        error_msg = f"Failed to decode response as JSON: {str(e)}"
        raise gr.Error(error_msg)

def predict_alpha(message, chatbot=None, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    chatbot = chatbot if chatbot is not None else []  # avoid sharing a mutable default list across calls
    temperature = float(temperature)
    top_p = float(top_p)
    input_prompt = build_input_prompt(message, chatbot)
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
    try:
        response_data = post_request_alpha(data)
        json_obj = response_data[0]
        if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
            bot_message = json_obj['generated_text']
            chatbot.append((message, bot_message))
            return "", chatbot
        elif 'error' in json_obj:
            raise gr.Error(json_obj['error'] + ' Please refresh and try again with a smaller input prompt.')
        else:
            warning_msg = f"Unexpected response: {json_obj}"
            raise gr.Error(warning_msg)
    except requests.HTTPError as e:
        error_msg = f"Request failed with status code {e.response.status_code}"
        raise gr.Error(error_msg)
    except json.JSONDecodeError as e:
        error_msg = f"Failed to decode response as JSON: {str(e)}"
        raise gr.Error(error_msg)

def retry_fun_beta(chat_history_beta):
    """
    Retries the prediction for the last message in the chat history.
    Removes the last interaction and gets a new prediction for the same message from Zephyr-7b-Beta.
    """
    if not chat_history_beta:
        raise gr.Error("Chat history is empty or invalid.")
    message = chat_history_beta[-1][0]
    chat_history_beta.pop()
    _, updated_chat_history_beta = predict_beta(message, chat_history_beta)
    return updated_chat_history_beta

def retry_fun_alpha(chat_history_alpha):
    """
    Retries the prediction for the last message in the chat history.
    Removes the last interaction and gets a new prediction for the same message from Zephyr-7b-Alpha.
    """
    if not chat_history_alpha:
        raise gr.Error("Chat history is empty or invalid.")
    message = chat_history_alpha[-1][0]
    chat_history_alpha.pop()
    _, updated_chat_history_alpha = predict_alpha(message, chat_history_alpha)
    return updated_chat_history_alpha

title = "🌀Zephyr Playground🎮"
description = """
Welcome to the Zephyr Playground! This interactive Space lets you experience the prowess of two distinct Zephyr models – [Zephyr-7b-Alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) and [Zephyr-7b-Beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) – side by side. Both models are fine-tuned versions of Mistral-7B.

- 🔎 Dive deep into the nuances and performance of these models by comparing their responses in real time.
- 📖 For a comprehensive understanding of the Zephyr models, read their [technical report](https://arxiv.org/abs/2310.16944) and experiment with the [official Zephyr demo](https://huggingfaceh4-zephyr-chat.hf.space/).
- 🛠 To explore more chat models or set up your own interactive demo, visit the [Hugging Face chat playground](https://huggingface.co/spaces/HuggingFaceH4/chat-playground).
"""
footnote = """Note: Licensing and acceptable-use policies for the Zephyr models can be found on their respective model pages on Hugging Face.
"""

css = """
.gradio-container {
  width: 100vw !important;
  min-height: 100vh !important;
  padding: 0 !important;
  margin: 0 !important;
  max-width: none !important;
}
"""

# Create chatbot components
chat_beta = gr.Chatbot(label="zephyr-7b-beta")
chat_alpha = gr.Chatbot(label="zephyr-7b-alpha")

# Create input and button components
textbox = gr.Textbox(container=False,
                     placeholder='Enter text and click the Submit button or press Enter')
submit = gr.Button('Submit', variant='primary')
retry = gr.Button('🔄Retry', variant='secondary')
undo = gr.Button('↩️Undo', variant='secondary')
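
# Components instantiated outside a gr.Blocks context are not displayed until
# they are attached to the layout via their .render() calls inside the Blocks below.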

# Layout the components using the Gradio Blocks API
with gr.Blocks(css=css) as demo:
    gr.HTML(f'<h1><center> {title} </center></h1>')
    gr.Markdown(description)
    with gr.Row():
        chat_beta.render()
        chat_alpha.render()
    with gr.Group():
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                textbox.render()
            with gr.Column(scale=1):
                submit.render()
        with gr.Row():
            retry.render()
            undo.render()
            clear = gr.ClearButton(value='🗑️Clear',
                                   components=[textbox,
                                               chat_beta,
                                               chat_alpha])
    gr.Markdown(footnote)

    # Assign events to components
    textbox.submit(predict_beta, [textbox, chat_beta], [textbox, chat_beta])
    textbox.submit(predict_alpha, [textbox, chat_alpha], [textbox, chat_alpha])
    submit.click(predict_beta, [textbox, chat_beta], [textbox, chat_beta])
    submit.click(predict_alpha, [textbox, chat_alpha], [textbox, chat_alpha])
    undo.click(lambda x: x[:-1], [chat_beta], [chat_beta])
    undo.click(lambda x: x[:-1], [chat_alpha], [chat_alpha])
    retry.click(retry_fun_beta, [chat_beta], [chat_beta])
    retry.click(retry_fun_alpha, [chat_alpha], [chat_alpha])
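
    # Note that each submit/click/retry/undo fires two listeners, one per model,
    # so both chatbots respond to the same input independently.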
    gr.Examples([
        ['Hi! Who are you?'],
        ['What is a meme?'],
        ['Explain the plot of Cinderella in a sentence.'],
        ['Assuming I am a huge alien species with the ability to consume helicopters, how long would it take me to eat one?'],
    ], textbox)

# Launch the demo
demo.launch(debug=True)