# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

# %% auto 0
__all__ = ['HF_TOKEN', 'ENDPOINT_URL', 'title', 'description', 'get_model_endpoint_params', 'query_chat_api', 'format_history',
           'inference_chat']

# %% app.ipynb 0
import gradio as gr
import requests
import json
import os
from pathlib import Path
from dotenv import load_dotenv
# %% app.ipynb 1
if Path(".env").is_file():
    load_dotenv(".env")

HF_TOKEN = os.getenv("HF_TOKEN")
ENDPOINT_URL = os.getenv("ENDPOINT_URL")
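
# A sample .env file for local development might look like this (both values
# are placeholders, not real credentials):
#
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#   ENDPOINT_URL=https://your-endpoint.endpoints.huggingface.cloud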

# %% app.ipynb 2
def get_model_endpoint_params(model_id):
    # Joi models are served from a dedicated Inference Endpoint; all other
    # models go through the hosted Inference API and need an auth header.
    if "joi" in model_id:
        headers = None
        max_new_tokens_supported = True
        return ENDPOINT_URL, headers, max_new_tokens_supported
    else:
        max_new_tokens_supported = False
        headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
        return f"https://api-inference.huggingface.co/models/{model_id}", headers, max_new_tokens_supported

# %% app.ipynb 3
def query_chat_api(
    model_id,
    inputs,
    temperature,
    top_p,
):
    endpoint, headers, max_new_tokens_supported = get_model_endpoint_params(model_id)

    payload = {
        "inputs": inputs,
        "parameters": {
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
        },
    }
    if max_new_tokens_supported:
        payload["parameters"]["max_new_tokens"] = 100
        payload["parameters"]["repetition_penalty"] = 1.03
        payload["parameters"]["stop"] = ["User:"]
    else:
        payload["parameters"]["max_length"] = 512

    response = requests.post(endpoint, json=payload, headers=headers)

    if response.status_code == 200:
        return response.json()
    else:
        return "Error: " + response.text

# %% app.ipynb 5
def format_history(history, human="Human", bot="Assistant"):
    # Turns alternate speakers: even indices are the human, odd indices the bot.
    history_input = ""
    for idx, text in enumerate(history):
        if idx % 2 == 0:
            history_input += f"{human}: {text}\n\n"
        else:
            history_input += f"{bot}: {text}\n\n"
    history_input = history_input.rstrip("\n")
    return history_input
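
# For example:
#
#   format_history(["Hi", "Hello! How can I help?", "What is 2+2?"])
#   # -> "Human: Hi\n\nAssistant: Hello! How can I help?\n\nHuman: What is 2+2?"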

# %% app.ipynb 7
def inference_chat(
    model_id,
    text_input,
    temperature,
    top_p,
    history=None,
):
    # Avoid a mutable default argument; Gradio passes the session state in.
    if history is None:
        history = []

    if "joi" in model_id:
        prompt_filename = "openassistant_joi.json"
        history_input = format_history(history, human="User", bot="Joi")
    else:
        prompt_filename = "anthropic_hhh_single.json"
        history_input = format_history(history, human="Human", bot="Assistant")

    with open(f"prompt_templates/{prompt_filename}", "r") as f:
        prompt_template = json.load(f)

    inputs = prompt_template["prompt"].format(human_input=text_input, history=history_input)
    history.append(text_input)

    print(f"History: {history}")
    print(f"Inputs: {inputs}")

    output = query_chat_api(model_id, inputs, temperature, top_p)
    print(output)
    if isinstance(output, list):
        output = output[0]

    # str.removesuffix trims the trailing stop marker without eating extra
    # characters the way str.rstrip would (requires Python 3.9+).
    if "joi" in model_id:
        output = output["generated_text"].removesuffix("\n\nUser:")
    else:
        output = output["generated_text"].removesuffix(" Human:")

    history.append(" " + output)

    chat = [
        (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
    ]  # convert to a list of (user, bot) tuples

    return {chatbot: chat, state: history}
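
# Each prompt template is a JSON file with a "prompt" key whose value contains
# {history} and {human_input} placeholders. A hypothetical minimal template:
#
#   {"prompt": "{history}\n\nHuman: {human_input}\n\nAssistant:"}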

# %% app.ipynb 25
title = """<h1 align="center">Chatty Language Models</h1>"""
description = """Pretrained language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:

```
Human: <utterance>
Assistant: <utterance>
Human: <utterance>
Assistant: <utterance>
...
```

In this app, you can explore the outputs of several language models conditioned on different conversational prompts. The models are trained on different datasets and have different objectives, so they will have different personalities and strengths.
"""

# %% app.ipynb 27
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=["google/flan-t5-xl", "Rallio67/joi_20B_instruct_alpha"],
                value="google/flan-t5-xl",
                label="Model",
                interactive=True,
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.5,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )

        with gr.Column(scale=1.8):
            with gr.Row():
                chatbot = gr.Chatbot(
                    label="Chat Output",
                )

            with gr.Row():
                chat_input = gr.Textbox(lines=1, label="Chat Input")
                chat_input.submit(
                    inference_chat,
                    [
                        model_id,
                        chat_input,
                        temperature,
                        top_p,
                        state,
                    ],
                    [chatbot, state],
                )

            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", [], []),
                    [],
                    [chat_input, chatbot, state],
                    queue=False,
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        model_id,
                        chat_input,
                        temperature,
                        top_p,
                        state,
                    ],
                    [chatbot, state],
                )

iface.launch()