File size: 8,664 Bytes
673ee85
da6bfd7
 
 
 
 
 
d045ffb
 
da6bfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d045ffb
da6bfd7
 
 
 
40a578e
da6bfd7
 
 
775c97a
da6bfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebd5d39
 
 
da6bfd7
ebd5d39
 
da6bfd7
 
 
40a578e
da6bfd7
 
 
775c97a
da6bfd7
 
 
 
 
 
 
 
 
 
 
 
 
373f61d
da6bfd7
 
 
 
 
 
 
 
 
 
ebd5d39
 
 
da6bfd7
ebd5d39
 
da6bfd7
 
 
 
 
 
 
 
373f61d
da6bfd7
 
 
 
 
 
 
 
 
 
 
 
 
373f61d
da6bfd7
 
 
 
 
 
 
373f61d
673ee85
 
 
 
4b963fb
 
 
 
 
673ee85
ebd5d39
 
 
 
 
 
 
 
 
 
 
da6bfd7
29a138d
 
da6bfd7
 
 
 
 
 
 
 
 
ebd5d39
673ee85
 
da6bfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b963fb
 
da6bfd7
 
 
 
 
 
 
 
 
 
 
 
4b963fb
 
 
 
 
 
 
 
 
da6bfd7
4b963fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import gradio as gr 
import os 
import json
import requests


# Hugging Face access token read from the environment; attached as a
# Bearer token on every request to the inference endpoints below.
HF_TOKEN = os.getenv('HF_TOKEN')
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}

# Endpoint URLs for the two Zephyr models, supplied via environment
# variables (presumably Hugging Face Inference API URLs — they are used
# directly as the target of requests.post below).
zephyr_7b_beta = os.getenv('zephyr_7b_beta')
zephyr_7b_alpha = os.getenv('zephyr_7b_alpha')

def build_input_prompt(message, chatbot):
    """
    Construct the Zephyr chat-template prompt from prior turns and the new message.

    Parameters
    ----------
    message : str
        The new user message to append at the end of the prompt.
    chatbot : list
        Prior (user, assistant) interaction pairs.

    Returns
    -------
    str
        Prompt of the form
        ``<|system|>..</s> <|user|> q1 </s> <|assistant|> a1 ... <|user|> message </s> <|assistant|>``.
    """
    # Collect fragments and join once at the end — repeated `+=` on a string
    # inside the loop is quadratic in the number of turns.
    parts = ["<|system|>\n</s>\n<|user|>\n"]
    for interaction in chatbot:
        parts.append(str(interaction[0]) + "</s>\n<|assistant|>\n" + str(interaction[1]) + "\n</s>\n<|user|>\n")
    parts.append(str(message) + "</s>\n<|assistant|>")
    return "".join(parts)


def post_request_beta(payload):
    """
    POST *payload* to the Zephyr-7b-Beta endpoint and return the decoded JSON body.

    Raises
    ------
    requests.HTTPError
        When the endpoint answers with a non-2xx status code.
    """
    resp = requests.post(zephyr_7b_beta, headers=HEADERS, json=payload)
    resp.raise_for_status()
    return resp.json()


def post_request_alpha(payload):
    """
    POST *payload* to the Zephyr-7b-Alpha endpoint and return the decoded JSON body.

    Raises
    ------
    requests.HTTPError
        When the endpoint answers with a non-2xx status code.
    """
    resp = requests.post(zephyr_7b_alpha, headers=HEADERS, json=payload)
    resp.raise_for_status()
    return resp.json()


def predict_beta(message, chatbot=None, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    """
    Send `message` plus the running chat history to Zephyr-7b-Beta and
    append the model's reply to the history.

    Parameters
    ----------
    message : str
        The new user message.
    chatbot : list | None
        Prior (user, assistant) pairs; a fresh list is used when None.
    temperature, max_new_tokens, top_p, repetition_penalty
        Generation parameters forwarded to the inference endpoint.

    Returns
    -------
    tuple[str, list]
        An empty string (clears the input textbox) and the updated history.

    Raises
    ------
    gr.Error
        On HTTP failure, undecodable JSON, or an error payload from the API.
    """
    # BUGFIX: the default used to be the mutable literal `[]`, which is
    # created once and shared across calls — and this function mutates it
    # via `chatbot.append`, so turns would leak between unrelated calls.
    if chatbot is None:
        chatbot = []

    temperature = float(temperature)
    top_p = float(top_p)

    input_prompt = build_input_prompt(message, chatbot)

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    try:
        response_data = post_request_beta(data)
        # The Inference API returns a list; the first element carries the result.
        json_obj = response_data[0]

        if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
            bot_message = json_obj['generated_text']
            chatbot.append((message, bot_message))
            return "", chatbot
        elif 'error' in json_obj:
            raise gr.Error(json_obj['error'] + ' Please refresh and try again with smaller input prompt')
        else:
            warning_msg = f"Unexpected response: {json_obj}"
            raise gr.Error(warning_msg)
    except requests.HTTPError as e:
        error_msg = f"Request failed with status code {e.response.status_code}"
        raise gr.Error(error_msg)
    except json.JSONDecodeError as e:
        error_msg = f"Failed to decode response as JSON: {str(e)}"
        raise gr.Error(error_msg)


def predict_alpha(message, chatbot=None, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
    """
    Send `message` plus the running chat history to Zephyr-7b-Alpha and
    append the model's reply to the history.

    Parameters
    ----------
    message : str
        The new user message.
    chatbot : list | None
        Prior (user, assistant) pairs; a fresh list is used when None.
    temperature, max_new_tokens, top_p, repetition_penalty
        Generation parameters forwarded to the inference endpoint.

    Returns
    -------
    tuple[str, list]
        An empty string (clears the input textbox) and the updated history.

    Raises
    ------
    gr.Error
        On HTTP failure, undecodable JSON, or an error payload from the API.
    """
    # BUGFIX: the default used to be the mutable literal `[]`, which is
    # created once and shared across calls — and this function mutates it
    # via `chatbot.append`, so turns would leak between unrelated calls.
    if chatbot is None:
        chatbot = []

    temperature = float(temperature)
    top_p = float(top_p)

    input_prompt = build_input_prompt(message, chatbot)

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    try:
        response_data = post_request_alpha(data)
        # The Inference API returns a list; the first element carries the result.
        json_obj = response_data[0]

        if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
            bot_message = json_obj['generated_text']
            chatbot.append((message, bot_message))
            return "", chatbot
        elif 'error' in json_obj:
            raise gr.Error(json_obj['error'] + ' Please refresh and try again with smaller input prompt')
        else:
            warning_msg = f"Unexpected response: {json_obj}"
            raise gr.Error(warning_msg)
    except requests.HTTPError as e:
        error_msg = f"Request failed with status code {e.response.status_code}"
        raise gr.Error(error_msg)
    except json.JSONDecodeError as e:
        error_msg = f"Failed to decode response as JSON: {str(e)}"
        raise gr.Error(error_msg)


def retry_fun_beta(chat_history_beta):
    """
    Re-run the last user message through Zephyr-7b-Beta.

    Drops the most recent (user, assistant) pair from the history and asks
    the model again with the same user message, returning the refreshed
    history.
    """
    # Guard: an empty or falsy history has nothing to retry.
    if not chat_history_beta:
        raise gr.Error("Chat history is empty or invalid.")

    # Remove the last interaction and keep its user message for the re-ask.
    last_message = chat_history_beta.pop()[0]
    _, refreshed_history = predict_beta(last_message, chat_history_beta)
    return refreshed_history


def retry_fun_alpha(chat_history_alpha):
    """
    Re-run the last user message through Zephyr-7b-Alpha.

    Drops the most recent (user, assistant) pair from the history and asks
    the model again with the same user message, returning the refreshed
    history.
    """
    # Guard: an empty or falsy history has nothing to retry.
    if not chat_history_alpha:
        raise gr.Error("Chat history is empty or invalid.")

    # Remove the last interaction and keep its user message for the re-ask.
    last_message = chat_history_alpha.pop()[0]
    _, refreshed_history = predict_alpha(last_message, chat_history_alpha)
    return refreshed_history


# --- Static UI copy shown in the demo (title, intro markdown, footer) ---
title = "🌀Zephyr Playground🎮"
description = """
Welcome to the Zephyr Playground! This interactive space lets you experience the prowess of two distinct Zephyr models – [Zephyr-7b-Alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) and [Zephyr-7b-Beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) – side by side. These models are products of fine-tuning the Mistral models.

- 🔎 Dive deep into the nuances and performance of these models by comparing their responses in real-time.
- 📖 For a comprehensive understanding of the Zephyr models, delve into their [technical report](https://arxiv.org/abs/2310.16944) and experiment with the [official Zephyr demo](https://huggingfaceh4-zephyr-chat.hf.space/).
- 🛠 If you wish to explore more chat models or set up your own interactive demo, visit the [Hugging Face's chat playground](https://huggingface.co/spaces/HuggingFaceH4/chat-playground).
"""
footnote = """Note: All rights, including licensing and acceptable use policies, related to the Zephyr models, can be found on their respective model pages on Hugging Face.
"""

# CSS override making the Gradio container fill the whole viewport
# (removes Gradio's default max-width/padding/margins).
css = """
.gradio-container {
    width: 100vw !important;
    min-height: 100vh !important;
    padding:0 !important;
    margin:0 !important;
    max-width: none !important;
}
"""

# Create chatbot components — one panel per model so responses can be
# compared side by side.
chat_beta = gr.Chatbot(label="zephyr-7b-beta", layout='panel')
chat_alpha = gr.Chatbot(label="zephyr-7b-alpha", layout='panel')

# Create input and button components (rendered later inside the Blocks layout)
textbox = gr.Textbox(container=False,
                     placeholder='Enter text and click the Submit button or press Enter')
submit = gr.Button('Submit', variant='primary',)
retry = gr.Button('🔄Retry', variant='secondary')
undo = gr.Button('↩️Undo', variant='secondary')

# Layout the components using Gradio Blocks API
with gr.Blocks(css=css) as demo:
  gr.HTML(f'<h1><center> {title} </center></h1>')
  gr.Markdown(description)
  # Both chat panels share one row for the side-by-side comparison.
  with gr.Row():
    chat_beta.render()
    chat_alpha.render()
  with gr.Group():
    with gr.Row(equal_height=True):
      with gr.Column(scale=5):
        textbox.render()
      with gr.Column(scale=1):
        submit.render()
    with gr.Row():
      retry.render()
      undo.render()
      # Clear wipes the textbox and both chat histories at once.
      clear = gr.ClearButton(value='🗑️Clear',
                            components=[textbox,
                                        chat_beta, 
                                        chat_alpha])

  gr.Markdown(footnote)

  # Assign events to components.
  # Enter in the textbox and the Submit button both fan out to BOTH models,
  # so each message is answered by beta and alpha in parallel panels.
  textbox.submit(predict_beta, [textbox, chat_beta], [textbox, chat_beta])
  textbox.submit(predict_alpha, [textbox, chat_alpha], [textbox, chat_alpha])
  submit.click(predict_beta, [textbox, chat_beta], [textbox, chat_beta])
  submit.click(predict_alpha, [textbox, chat_alpha], [textbox, chat_alpha])

  # Undo drops the last (user, assistant) pair from each history.
  undo.click(lambda x:x[:-1], [chat_beta], [chat_beta])
  undo.click(lambda x:x[:-1], [chat_alpha], [chat_alpha])

  # Retry re-asks the last user message on each model.
  retry.click(retry_fun_beta, [chat_beta], [chat_beta])
  retry.click(retry_fun_alpha, [chat_alpha], [chat_alpha])

  gr.Examples([
              ['Hi! Who are you?'],
              ['What is a meme?'],
              ['Explain the plot of Cinderella in a sentence.'],
              ['Assuming I am a huge alien species with the ability to consume helicopters, how long would it take me to eat one?'],
              ],
              textbox)
              
  
# Launch the demo
demo.launch(debug=True)