# vicuna-chat / app.py
# Hugging Face Space by celestialli; commit 05d618f ("Update app.py").
import argparse
from collections import defaultdict
import datetime
import json
import os
import random
import time
import uuid
import websocket
from websocket import WebSocketConnectionClosedException
import gradio as gr
import requests
import logging
import re
from fastchat.conversation import SeparatorStyle
from fastchat.constants import (
LOGDIR,
WORKER_API_TIMEOUT,
ErrorCode,
MODERATION_MSG,
CONVERSATION_LIMIT_MSG,
SERVER_ERROR_MSG,
INACTIVE_MSG,
INPUT_CHAR_LEN_LIMIT,
CONVERSATION_TURN_LIMIT,
SESSION_EXPIRATION_TIME,
)
from fastchat.model.model_adapter import get_conversation_template
from fastchat.model.model_registry import model_info
from fastchat.serve.api_provider import (
anthropic_api_stream_iter,
openai_api_stream_iter,
palm_api_stream_iter,
init_palm_chat,
)
from fastchat.utils import (
build_logger,
violates_moderation,
get_window_url_params_js,
parse_gradio_auth_creds,
)
# Module logger created by fastchat's build_logger(logger_name, log_filename).
logger = build_logger("gradio_web_server", "gradio_web_server.log")
# Reusable Gradio update objects returned by the event handlers below.
# The no_change_* values carry no properties, leaving the target component as-is;
# enable_btn / disable_btn toggle a button's `interactive` flag.
no_change_dropdown = gr.Dropdown.update()
no_change_slider = gr.Slider.update()
no_change_textbox = gr.Textbox.update()
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
def get_internet_ip():
    """Return this host's public IPv4 address as reported by an external echo service.

    Makes one HTTP GET to Sohu's plain-text IP endpoint and returns the first
    dotted-quad found in the response body.

    Returns:
        The address as a string, or None when the body contains no address.

    Raises:
        requests.exceptions.RequestException: on network failure or timeout
            (the module-level caller catches ``Exception`` and just prints it).
    """
    # Bounded timeout: this runs at import time and must not hang startup forever.
    r = requests.get("http://txt.go.sohu.com/ip/soip", timeout=5)
    # Escaped dots: an unescaped "." matches any character, so e.g. "1a2b3c4"
    # would previously have been accepted as an address.
    ip = re.findall(r"\d+\.\d+\.\d+\.\d+", r.text)
    # re.findall always returns a list (never None); empty means "not found".
    if ip:
        return ip[0]
    return None
# Runtime configuration, read once from environment variables at import time.
# Moderation is enabled only when the variable is exactly the string "True".
enable_moderation = os.environ.get('enable_moderation', default='False') == "True"
# Queue worker count passed to demo.queue() at launch.
concurrency_count = int(os.environ.get('concurrency_count', default='10'))
# "reload" re-fetches the model list on every page load; "once" uses the startup list.
model_list_mode = os.environ.get('model_list_mode', default='reload')
# Middleware endpoint that returns the available model names.
midware_url = os.environ.get('midware_url', default='')
# Access token sent as the PRIVATE-TOKEN header on middleware requests.
preset_token = os.environ.get('preset_token', default='')
# Model-worker endpoint used for models without a dedicated API provider;
# an empty value makes bot_response report a server error.
worker_addr = os.environ.get('worker_addr', default='')
# Maximum number of fine-tuning jobs allowed to run at the same time.
allow_running = int(os.environ.get('allow_running', default='1'))
# Fine-tuning middleware endpoints (list, submit, remove, console-log websocket).
ft_list_job_url = os.environ.get('ft_list_job_url', default='')
ft_submit_job_url = os.environ.get('ft_submit_job_url', default='')
ft_remove_job_url = os.environ.get('ft_remove_job_url', default='')
ft_console_log_url = os.environ.get('ft_console_log_url', default='')
# Sample records shown by ft_dataset_preview_click, keyed by dataset name.
dataset_sample = {
    "english": {
        "train": ["abcdef"],
        "valid": ["zxcvbn"]
    },
}
# UI dataset name -> dataset name sent to the fine-tuning middleware (ft_submit_click).
dataset_to_midware_name = {
    "english": "english",
    "cat": "cat",
    "dog": "dog",
    "bird": "bird"
}
# Hyperparameter keys copied from each job's "parameter" dict into the jobs table.
hps_keys = ["epochs", "train_batch_size", "eval_batch_size", "gradient_accumulation_steps", "learning_rate", "weight_decay", "model_max_length"]
# Headers for model-worker requests (model_worker_stream_iter); carries the preset token.
headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": preset_token}
# License notice rendered as markdown under the chat UI.
learn_more_md = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/LICENSE) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""
# Client IP -> session expiry time (epoch seconds). load_demo sets it on page
# load; add_text rejects input once it has passed. Unknown IPs default to 0,
# i.e. already expired.
ip_expiration_dict = defaultdict(lambda: 0)
# ASCII punctuation (identical to string.punctuation).
_ASCII_PUNCTUATION = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
# Chinese / full-width punctuation. The full-width ASCII forms U+FF01..U+FF5E
# are ASCII 0x21..0x7E shifted by 0xFEE0, and U+FF5F..U+FF64 are the
# half-width CJK brackets/stops. Both are built from code points so they cannot
# be folded back to ASCII by an editor. That folding is what broke the previous
# literal: the full-width quote became an ASCII `"` and ended the string early.
_CJK_PUNCTUATION = (
    "".join(chr(ord(ch) + 0xFEE0) for ch in _ASCII_PUNCTUATION)
    + "".join(chr(cp) for cp in range(0xFF5F, 0xFF65))
    + "、。〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
)


def is_legal_char(c):
    """Return True if ``c`` is a character allowed at the end of model output.

    Legal characters are alphanumerics (any script), CJK unified ideographs
    (U+4E00..U+9FFF), Chinese/full-width punctuation and ASCII punctuation.
    Anything else (whitespace, U+FFFD replacement chars, cursor glyphs) is
    illegal and gets trimmed by ``str_filter``.
    """
    if c.isalnum():
        return True
    if "\u4e00" <= c <= "\u9fff":
        return True
    return c in _CJK_PUNCTUATION or c in _ASCII_PUNCTUATION
def str_filter(s):
    """Trim at most two trailing characters rejected by ``is_legal_char``.

    Stops at the first legal trailing character, so legal text is never cut.
    """
    budget = 2
    while budget > 0 and len(s) > 0 and not is_legal_char(s[-1]):
        s = s[:-1]
        budget -= 1
    return s
def str_not_int(s):
    """Return True when ``int(s)`` raises ValueError, i.e. ``s`` is not an integer literal."""
    parsed = True
    try:
        int(s)
    except ValueError:
        parsed = False
    return not parsed
def str_not_float(s):
    """Return True when ``float(s)`` raises ValueError, i.e. ``s`` is not a float literal."""
    try:
        float(s)
    except ValueError:
        return True
    else:
        return False
class State:
    """Per-session chat state held in a ``gr.State`` component.

    Attributes:
        conv: conversation template returned by ``get_conversation_template``.
        conv_id: random hex id for this conversation.
        skip_next: when True, the next ``bot_response`` call does nothing.
        model_name: name of the selected model.
        palm_chat: PaLM chat session; only set when ``model_name == "palm-2"``.
    """

    def __init__(self, model_name):
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        self.skip_next = False
        self.model_name = model_name
        if model_name != "palm-2":
            return
        # According to release note, "chat-bison@001" is PaLM 2 for chat.
        # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
        self.palm_chat = init_palm_chat("chat-bison@001")

    def to_gradio_chatbot(self):
        """Return the conversation in Gradio chatbot format."""
        history = self.conv.to_gradio_chatbot()
        return history

    def dict(self):
        """Return the conversation's dict extended with ``conv_id`` and ``model_name``."""
        snapshot = self.conv.dict()
        snapshot["conv_id"] = self.conv_id
        snapshot["model_name"] = self.model_name
        return snapshot
def get_conv_log_filename():
    """Return the path of today's conversation log: ``LOGDIR/YYYY-MM-DD-conv.json``."""
    today = datetime.datetime.now().date()
    return os.path.join(LOGDIR, f"{today.isoformat()}-conv.json")
def get_model_list(midware_url):
    """Fetch the available model names from the middleware, preferred ones first.

    Sends ``GET midware_url`` with the preset token (5 s timeout) and expects a
    JSON object whose ``"data"`` field is the list of model names. Names listed
    in ``model_order`` sort by their rank; all others rank 100 and keep server
    order, because ``sorted`` is stable.

    Returns:
        The sorted model list, or ``["CANNOT GET MODEL"]`` when the request
        fails, the token is rejected, or the response has an unexpected shape.
    """
    model_order = {
        "vicuna-7b-v1.5-16k": 10,
        "vicuna-13b-v1.5": 90,
    }
    models = ["CANNOT GET MODEL"]
    try:
        ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=5)
        payload = ret.json()  # parse once instead of once per field access
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers non-JSON bodies on requests versions whose
        # JSONDecodeError is not a RequestException subclass.
        logger.warning(f"Get model list fail: {e}")
    else:
        code = payload.get("code") if isinstance(payload, dict) else None
        if isinstance(code, str) and "invalid" in code:
            gr.Warning("Invalid preset token.")
        elif isinstance(payload, dict) and isinstance(payload.get("data"), list):
            models = payload["data"]
        else:
            # Previously a KeyError/TypeError here crashed the app at startup.
            logger.warning(f"Unexpected model list response: {payload}")
    models = sorted(models, key=lambda x: model_order.get(x, 100))
    logger.info(f"Models: {models}")
    return models
def load_demo_single(models, url_params):
    """Build the page-load updates for the single-model chat UI.

    Preselects ``url_params["model"]`` when it names an available model,
    otherwise the first model (or "" when there are none). Reveals the chat
    components that start hidden. The session state starts as None.
    """
    fallback_model = models[0] if len(models) > 0 else ""
    use_requested = "model" in url_params and url_params["model"] in models
    selected_model = url_params["model"] if use_requested else fallback_model
    return (
        None,
        gr.Dropdown.update(choices=models, value=selected_model, visible=True),
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )
def load_demo(url_params, request: gr.Request):
    """Page-load handler: start the client's session window and populate the UI.

    Records the expiry time for the client's IP. In "reload" mode it also
    refreshes the global model list before building the UI updates.
    """
    global models
    client_ip = request.client.host
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")
    expires_at = time.time() + SESSION_EXPIRATION_TIME
    ip_expiration_dict[client_ip] = expires_at
    if model_list_mode == "reload":
        models = get_model_list(midware_url)
    return load_demo_single(models, url_params)
def regenerate(state, request: gr.Request):
    """Clear the last assistant reply so bot_response can generate it again.

    Returns (state, chatbot history, empty textbox, two disabled buttons).
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.conv.update_last_message(None)
    chat_history = state.to_gradio_chatbot()
    return state, chat_history, "", disable_btn, disable_btn
def clear_history(request: gr.Request):
    """Reset the session: no state, empty chat and textbox, buttons disabled."""
    logger.info(f"clear_history. ip: {request.client.host}")
    return None, [], "", disable_btn, disable_btn
def add_text(state, model_selector, text, request: gr.Request):
    """Append the user's message to the conversation, or reject it.

    Creates the session State on first use. The message is rejected when it is
    empty, the client's session has expired, moderation flags it (if enabled),
    or the turn limit is reached. A rejection sets ``skip_next`` so the chained
    bot_response call does nothing, and shows the matching notice in the
    textbox. Accepted text is truncated to INPUT_CHAR_LEN_LIMIT characters and
    followed by an empty assistant slot.

    Returns (state, chatbot history, textbox value, two button updates).
    """
    client_ip = request.client.host
    logger.info(f"add_text. ip: {client_ip}. len: {len(text)}")
    if state is None:
        state = State(model_selector)

    def reject(message):
        # Skip the next generation and leave the buttons untouched.
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), message, no_change_btn, no_change_btn)

    if len(text) <= 0:
        return reject("")
    if ip_expiration_dict[client_ip] < time.time():
        logger.info(f"inactive. ip: {client_ip}. text: {text}")
        return reject(INACTIVE_MSG)
    if enable_moderation and violates_moderation(text):
        logger.info(f"violate moderation. ip: {client_ip}. text: {text}")
        return reject(MODERATION_MSG)

    conv = state.conv
    turns_so_far = (len(conv.messages) - conv.offset) // 2
    if turns_so_far >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {client_ip}. text: {text}")
        return reject(CONVERSATION_LIMIT_MSG)

    truncated = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    conv.append_message(conv.roles[0], truncated)
    conv.append_message(conv.roles[1], None)
    return (state, state.to_gradio_chatbot(), "", disable_btn, disable_btn)
def post_process_code(code):
    r"""Unescape ``\_`` to ``_`` inside fenced code blocks of markdown text.

    The text is split on newline-prefixed triple backticks. Only when the fences
    are balanced (an odd number of segments) are the odd-indexed segments, the
    code blocks, unescaped. All other text is returned unchanged.
    """
    fence = "\n```"
    if fence not in code:
        return code
    segments = code.split(fence)
    if len(segments) % 2 == 0:
        return fence.join(segments)
    return fence.join(
        segment.replace("\\_", "_") if index % 2 else segment
        for index, segment in enumerate(segments)
    )
def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
):
    """Stream generation results from the model worker at ``worker_addr``.

    POSTs the generation parameters as JSON with the module-level ``headers``
    and yields one decoded JSON object per NUL-delimited chunk of the streamed
    response.

    The response is used as a context manager, so the HTTP connection is closed
    when the stream finishes. It is also closed when the consumer abandons this
    generator early, for example when bot_response returns on an error chunk and
    the generator is closed.
    """
    # Make requests
    gen_params = {
        "model_name": model_name,
        "question": prompt,
        # NOTE(review): the `temperature` argument is ignored and a near-zero
        # value is always sent (near-greedy decoding). The UI slider and the
        # conversation log both report the user's value; confirm this is intended.
        "temperature": 1e-6,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    logger.info(f"==== request ====\n{gen_params}")
    # Stream output
    with requests.post(
        worker_addr,
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode())
def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's reply to the last user turn in ``state``.

    Gradio generator handler chained after add_text / regenerate. Every yield is
    ``(state, chatbot_history, regenerate_btn_update, clear_btn_update)``. The
    buttons stay disabled while streaming and are re-enabled when the reply ends
    or fails. If add_text set ``skip_next``, nothing is generated and the flag
    is reset.

    Routing: OpenAI, Anthropic and PaLM model names use their API stream
    iterators; every other model goes to the configured ``worker_addr``. After a
    successful reply, one JSON line describing the exchange is appended to the
    daily conversation log.
    """
    logger.info(f"bot_response. ip: {request.client.host}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2
        return

    conv, model_name = state.conv, state.model_name
    if model_name in ("gpt-3.5-turbo", "gpt-4"):
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_name in ("claude-2", "claude-instant-1"):
        prompt = conv.get_prompt()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_name == "palm-2":
        # The PaLM chat session keeps its own history; send only the latest user message.
        stream_iter = palm_api_stream_iter(
            state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
        )
    else:
        logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
        # No available worker
        if worker_addr == "":
            conv.update_last_message(SERVER_ERROR_MSG)
            yield (state, state.to_gradio_chatbot(), enable_btn, enable_btn)
            return
        # Construct prompt.
        # We need to call it here, so it will not be affected by "▌".
        prompt = conv.get_prompt()
        # T5-family models get a repetition penalty; all others use none.
        repetition_penalty = 1.2 if "t5" in model_name else 1.0
        stream_iter = model_worker_stream_iter(
            conv,
            model_name,
            worker_addr,
            prompt,
            temperature,
            repetition_penalty,
            top_p,
            max_new_tokens,
        )

    conv.update_last_message("▌")
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2

    # Initialised before the loop: if the stream ends without producing any
    # data, the log line below previously raised NameError on `output`.
    output = ""
    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                if data.get("finish_reason", None) == "length":
                    gr.Warning("Answer interrupted because the setting of [Max output tokens], try set a larger value.")
                output = data["text"].strip()
                if "vicuna" in model_name:
                    output = post_process_code(output)
                # Drop trailing partial/garbage characters from the streamed chunk.
                output = str_filter(output)
                conv.update_last_message(output + "▌")
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + (
                    enable_btn,
                    enable_btn,
                )
                return
            time.sleep(0.015)
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            enable_btn,
            enable_btn,
        )
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            enable_btn,
            enable_btn,
        )
        return

    # Delete "▌"
    conv.update_last_message(conv.messages[-1][-1][:-1])
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2

    finish_tstamp = time.time()
    logger.info(f"{output}")
    # json.dumps escapes non-ASCII by default; the explicit encoding just makes
    # the file encoding independent of the platform locale.
    with open(get_conv_log_filename(), "a", encoding="utf-8") as fout:
        record = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(record) + "\n")
# Extra CSS passed to gr.Blocks(css=...) in build_demo.
# NOTE(review): none of the element ids targeted here are assigned by the UI
# code visible in this file; possibly leftover styling, so confirm before removing.
block_css = """
#dialog_notice_markdown {
font-size: 104%
}
#dialog_notice_markdown th {
display: none;
}
#dialog_notice_markdown td {
padding-top: 6px;
padding-bottom: 6px;
}
#leaderboard_markdown {
font-size: 104%
}
#leaderboard_markdown td {
padding-top: 6px;
padding-bottom: 6px;
}
#leaderboard_dataframe td {
line-height: 0.1em;
}
"""
def get_model_description_md(models):
    """Render a three-column markdown table describing ``models``.

    Registered models (present in ``model_info``) show their link and
    description, and are de-duplicated by ``simple_name``. Unregistered names
    get a placeholder and are de-duplicated by name. Previously an unregistered
    name was recorded in ``visited`` but never checked, so it could repeat.
    Cells fill rows of three; a partially filled last row is left open.
    """
    model_description_md = """
| | | |
| ---- | ---- | ---- |
"""
    ct = 0
    visited = set()
    for name in models:
        if name in model_info:
            minfo = model_info[name]
            key = minfo.simple_name
            one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
        else:
            key = name
            one_model_md = (
                f"[{name}](): Add the description at fastchat/model/model_registry.py"
            )
        if key in visited:
            continue
        visited.add(key)
        if ct % 3 == 0:
            model_description_md += "|"
        model_description_md += f" {one_model_md} |"
        if ct % 3 == 2:
            model_description_md += "\n"
        ct += 1
    return model_description_md
def build_single_model_ui(models, add_promotion_links=False):
    """Build the single-model chat tab and register its event handlers.

    Must be called inside an open ``gr.Blocks`` context (build_demo does this).
    Components are created in statement order, which determines the page layout.

    Args:
        models: model names offered in the dropdown; the first one (or "" when
            the list is empty) is preselected.
        add_promotion_links: accepted but not used in this body.

    Returns:
        (state, model_selector, chatbot, textbox, send_btn, button_row,
        parameter_row): the components that load_demo updates on page load.
    """
    with gr.Column():
        with gr.Tab("🧠 模型对话 Dialog"):
            # Per-session State; starts empty and add_text creates it on the first message.
            state = gr.State()
            with gr.Row(elem_id="model_selector_row"):
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if len(models) > 0 else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )
            # Chat widgets start hidden; load_demo_single makes them visible on page load.
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Scroll down and start chatting",
                visible=False,
                height=550,
            )
            with gr.Row():
                with gr.Column(scale=20):
                    textbox = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press ENTER",
                        visible=False,
                        container=False,
                    )
                with gr.Column(scale=1, min_width=50):
                    send_btn = gr.Button(value="Send", visible=False)
            # Buttons start disabled; bot_response re-enables them when a reply ends.
            with gr.Row(visible=False) as button_row:
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
            # Clickable sample prompts (Chinese and English) that fill the textbox.
            gr.Examples(
                examples=["如何变得富有?", "你能用Python写一段快速排序吗?", "How to be rich?", "Can you write a quicksort code in Python?"],
                inputs=textbox,
            )
            # Generation controls passed as inputs to bot_response.
            with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )
                max_output_tokens = gr.Slider(
                    minimum=16,
                    maximum=1024,
                    value=512,
                    step=64,
                    interactive=True,
                    label="Max output tokens",
                )
            gr.Markdown(learn_more_md)
    # Register listeners
    btn_list = [regenerate_btn, clear_btn]
    # Regenerate: blank the last assistant reply, then stream a new one.
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    # Clearing history and switching models both reset the session state to None.
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
    model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
    # ENTER and the Send button share one pipeline: add_text appends the user
    # turn, then bot_response streams the reply.
    textbox.submit(
        add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
def ft_get_job_data():
    """Fetch fine-tuning jobs from the middleware and format them for the jobs table.

    Returns:
        ``(res_lst, running)``. ``res_lst`` has one row per job:
        ``[jobName, username, created_at, model, dataset, status, hps_json]``,
        where ``hps_json`` is a JSON string of the job's ``hps_keys`` parameters.
        Rows are sorted by status, then created_at, both descending.
        ``running`` counts jobs whose status is "running" (case-insensitive).
        On a request/parse failure or a rejected token, the list is empty and
        the count is 0. Malformed job entries are skipped instead of aborting.
    """
    running = 0
    res_lst = []
    try:
        r = requests.get(ft_list_job_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=8)
        jobs = r.json()  # parse once instead of once per access
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: non-JSON body on requests versions where JSONDecodeError
        # is not a RequestException.
        logger.info(f"Get job list fail: {e}")
        return res_lst, running

    if isinstance(jobs, dict):
        # A dict is an error envelope, not a job list; previously iterating it
        # crashed on d['status'] with a TypeError.
        code = jobs.get("code")
        if isinstance(code, str) and "invalid" in code:
            gr.Warning("Invalid preset token.")
        else:
            logger.info(f"Get job list fail, unexpected response: {jobs}")
        return res_lst, running
    if not isinstance(jobs, list):
        logger.info(f"Get job list fail, unexpected response: {jobs}")
        return res_lst, running

    for job in jobs:
        if not isinstance(job, dict):
            logger.info(f"Skip malformed job entry: {job}")
            continue
        try:
            params = job.get("parameter") or {}
            hps = {key: params[key] for key in hps_keys if key in params}
            row = [job["jobName"], job["username"], job["created_at"], job["model"],
                   job["dataset"], job["status"], json.dumps(hps)]
        except (KeyError, TypeError) as e:
            logger.info(f"Skip malformed job entry: {job} ({e})")
            continue
        status = row[5]
        if isinstance(status, str) and status.lower() == "running":
            running += 1
        res_lst.append(row)

    # Same order as the former two stable sorts (created_at desc, then status
    # desc): status is the primary key, newest first within a status.
    res_lst.sort(key=lambda row: (row[5], row[2]), reverse=True)
    return res_lst, running
def ft_refresh_click():
    """Reload the jobs table: returns (rows, running-job count) from ft_get_job_data."""
    rows, running_count = ft_get_job_data()
    return rows, running_count
def ft_cease_click(ft_console):
    """Return the console text with a line noting the user stopped the stream."""
    notice = "** Streaming output ceased by user **"
    return "\n".join((ft_console, notice))
def console_generator(addr, sleep_time):
    """Stream a job's console log over a websocket, yielding the accumulated text.

    Connects to ``addr`` with the preset token header. After each received
    message it sleeps ``sleep_time`` seconds and yields everything received so
    far. Iteration ends when the server closes the connection.

    The socket is closed in ``finally``. It is therefore released even when the
    consumer stops iterating early (the generator is closed), or when ``recv``
    raises an error other than a normal close. Previously it stayed open in
    those cases.
    """
    total_str = ""
    ws = websocket.WebSocket()
    try:
        ws.connect(addr, header={"PRIVATE-TOKEN": preset_token})
        while True:
            try:
                new_str = ws.recv()
            except WebSocketConnectionClosedException:
                break
            # recv() may return bytes (binary frames); decode so the str
            # concatenation below cannot raise TypeError.
            if isinstance(new_str, bytes):
                new_str = new_str.decode("utf-8", errors="replace")
            total_str += new_str
            time.sleep(sleep_time)
            yield total_str
    finally:
        ws.close()
def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length):
    """Validate the fine-tuning form, submit the job, then stream its console log.

    Generator used as a Gradio handler. Each yield is
    ``(job table rows, running-job count, console text or textbox update)``.

    Because this is a generator, a ``return value`` statement never reaches
    Gradio. The old validation-failure paths returned their outputs, so they
    were silently dropped and the UI did not refresh. Every exit path now yields
    its outputs and then returns.
    """

    def refreshed(console_value):
        # Fresh job table + running count, paired with the console output.
        res_lst, running = ft_get_job_data()
        return res_lst, running, console_value

    if ft_user_name == "":
        gr.Warning("Submit fail, empty username.")
        yield refreshed(no_change_textbox)
        return
    int_hps = (ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_model_max_length)
    float_hps = (ft_learning_rate, ft_weight_decay)
    if any(str_not_int(v) for v in int_hps) or any(str_not_float(v) for v in float_hps):
        gr.Warning("Submit fail, check the types. [learning rate] and [weight decay] should be float, others HPs should be int.")
        yield refreshed(no_change_textbox)
        return
    if ft_latest_running_cnt >= int(allow_running):
        gr.Warning(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
        yield refreshed(no_change_textbox)
        return
    if ft_dataset_name not in dataset_to_midware_name:
        # Previously an unknown name raised KeyError.
        gr.Warning(f"Submit fail, unknown dataset {ft_dataset_name}.")
        yield refreshed(no_change_textbox)
        return

    midware_header = {"FINETUNE-SECRET": ft_token, "PRIVATE-TOKEN": preset_token}
    hps_json = {
        "epochs": str(ft_epochs),
        "train_batch_size": str(ft_train_batch_size),
        "eval_batch_size": str(ft_eval_batch_size),
        "gradient_accumulation_steps": str(ft_gradient_accumulation_steps),
        "learning_rate": str(ft_learning_rate),
        "weight_decay": str(ft_weight_decay),
        "model_max_length": str(ft_model_max_length)
    }
    json_data = {
        "dataset": dataset_to_midware_name[ft_dataset_name],
        "model": ft_model,
        "parameter": hps_json,
        "username": ft_user_name
    }
    try:
        r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header, timeout=120)
    except requests.exceptions.RequestException:
        gr.Warning("Connection Failure.")
        yield refreshed("")
        return
    try:
        job_name = str(r.json()["jobName"])
    except (ValueError, KeyError, TypeError):
        # Not the expected {"jobName": ...} JSON (e.g. the submission was rejected).
        gr.Warning(f"Submit fail. {r.status_code} {r.reason}.")
        yield refreshed(no_change_textbox)
        return

    gr.Info(f"Job {job_name} submit success.")
    res_lst, running = ft_get_job_data()
    total_str = ""
    try:
        for total_str in console_generator(ft_console_log_url + job_name, 1):
            yield res_lst, running, total_str
    except (websocket.WebSocketException, OSError) as e:
        # The job was submitted; only the live log is unavailable.
        gr.Warning(f"Console log unavailable: {e}")
    yield refreshed(total_str)
def ft_show_click(ft_selected_row_data):
    """Stream the console log of the selected job (job name in column 0).

    Delegates to console_generator with a 0.2 s pause after each message.
    """
    job_name = ft_selected_row_data[0]
    yield from console_generator(ft_console_log_url + job_name, 0.2)
def ft_remove_click(ft_selected_row_data, ft_token):
    """Ask the middleware to remove the selected job, then refresh the jobs table.

    Only jobs whose status (column 5) is "running" (case-insensitive) can be
    removed; the job name is column 0. Success or failure is reported with
    gr.Info / gr.Warning.

    Returns:
        ``(rows, running)`` from ft_get_job_data.
    """
    # Guard against "nothing selected" instead of crashing on indexing.
    if ft_selected_row_data is None or len(ft_selected_row_data) < 6:
        gr.Warning("Remove fail. Please select a job first.")
        return ft_get_job_data()
    job_name, status = ft_selected_row_data[0], ft_selected_row_data[5]
    if not (isinstance(status, str) and status.lower() == "running"):
        gr.Warning("Remove fail. Can only remove a running job.")
        return ft_get_job_data()
    try:
        # Bounded timeout: previously the request could block the handler indefinitely.
        r = requests.delete(
            ft_remove_job_url + job_name,
            headers={'FINETUNE-SECRET': ft_token, "PRIVATE-TOKEN": preset_token},
            timeout=60,
        )
    except requests.exceptions.RequestException:
        gr.Warning("Connection Failure.")
        return ft_get_job_data()
    if r.status_code == 200:
        gr.Info("Remove success.")
    else:
        gr.Warning(f"Remove fail. {r.status_code} {r.reason}.")
    return ft_get_job_data()
def ft_jobs_info_select(ft_jobs_info, evt: gr.SelectData):
    """Handle a click on the jobs table.

    Always returns the clicked row first. If the clicked column is model (3),
    dataset (4) or hyperparameters (6), the form is filled from that row:
    model, dataset, then the seven hyperparameters parsed from the JSON in
    column 6 ("" for any missing key). Otherwise the form fields are left
    unchanged.

    Returns:
        A 10-element list matching the handler's outputs.
    """
    selected_row = ft_jobs_info[evt.index[0]]
    if evt.index[1] not in (3, 4, 6):
        return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider] + [no_change_textbox] * 6
    try:
        hps = json.loads(selected_row[6])
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError; TypeError covers a non-string cell
        # such as None, which previously crashed.
        hps = {}
    if not isinstance(hps, dict):
        hps = {}
    hp_order = ("epochs", "train_batch_size", "eval_batch_size", "gradient_accumulation_steps",
                "learning_rate", "weight_decay", "model_max_length")
    return [selected_row, selected_row[3], selected_row[4]] + [hps.get(key, '') for key in hp_order]
def ft_dataset_preview_click(ft_dataset_name):
    """Show the JSON preview of the chosen dataset's sample records ({} if none)."""
    preview = dataset_sample[ft_dataset_name] if ft_dataset_name in dataset_sample else {}
    return gr.JSON.update(value=preview, visible=True)
def ft_hide_dataset_click():
    """Hide the dataset preview JSON component."""
    hidden_preview = gr.JSON.update(visible=False)
    return hidden_preview
def build_demo(models):
    """Assemble the Gradio app around the single-model chat UI.

    Args:
        models: model names used for the initial dropdown choices.

    Returns:
        The ``gr.Blocks`` app; the caller queues and launches it.

    Raises:
        ValueError: if ``model_list_mode`` is neither "once" nor "reload"
            (checked after the UI components have been created).
    """
    with gr.Blocks(
        title="Vicuna Test",
        theme=gr.themes.Base(),
        css = block_css
    ) as demo:
        # Hidden holder for the page's URL parameters; filled on load via the _js hook below.
        url_params = gr.JSON(visible=False)
        (
            state,
            model_selector,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui(models)
        if model_list_mode not in ["once", "reload"]:
            raise ValueError(f"Unknown model list mode: {model_list_mode}")
        # On each page load, run get_window_url_params_js in the browser
        # (presumably it collects the URL query params; confirm), pass the
        # result to load_demo, and apply its updates to these components.
        demo.load(
            load_demo,
            [url_params],
            [
                state,
                model_selector,
                chatbot,
                textbox,
                send_btn,
                button_row,
                parameter_row,
            ],
            _js=get_window_url_params_js,
        )
    return demo
# Script entry: report the public IP (best effort), load the model list, then
# build, queue and serve the UI.
try:
    internet_ip = get_internet_ip()
except Exception as e:
    print(f"Get Internet IP error: {e}")
else:
    print("Internet IP:", internet_ip)
models = get_model_list(midware_url)
# Launch the demo
demo = build_demo(models)
queued_demo = demo.queue(
    concurrency_count=concurrency_count,
    status_update_rate=10,
    api_open=False,
)
queued_demo.launch(max_threads=200)