"""Gradio web UI for chatting with hosted models, plus fine-tuning job handlers.

Configuration is read from environment variables at import time (see the
settings block below). Importing this module also fetches the model list from
the middleware, builds the chat UI and launches the Gradio server.
"""
import argparse
import datetime
import json
import logging
import os
import random
import re
import time
import uuid
from collections import defaultdict

import gradio as gr
import requests
import websocket
from websocket import WebSocketConnectionClosedException

from fastchat.conversation import SeparatorStyle
from fastchat.constants import (
    LOGDIR,
    WORKER_API_TIMEOUT,
    ErrorCode,
    MODERATION_MSG,
    CONVERSATION_LIMIT_MSG,
    SERVER_ERROR_MSG,
    INACTIVE_MSG,
    INPUT_CHAR_LEN_LIMIT,
    CONVERSATION_TURN_LIMIT,
    SESSION_EXPIRATION_TIME,
)
from fastchat.model.model_adapter import get_conversation_template
from fastchat.model.model_registry import model_info
from fastchat.serve.api_provider import (
    anthropic_api_stream_iter,
    openai_api_stream_iter,
    palm_api_stream_iter,
    init_palm_chat,
)
from fastchat.utils import (
    build_logger,
    violates_moderation,
    get_window_url_params_js,
    parse_gradio_auth_creds,
)


logger = build_logger("gradio_web_server", "gradio_web_server.log")

# Reusable component updates returned by the event handlers below.
no_change_dropdown = gr.Dropdown.update()
no_change_slider = gr.Slider.update()
no_change_textbox = gr.Textbox.update()
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)


def get_internet_ip(timeout=5):
    """Return this host's public IPv4 address as reported by a Sohu lookup page.

    Args:
        timeout: Seconds to wait for the lookup service before giving up.

    Returns:
        The first dotted-quad string found in the response body, or ``None``
        if the body contains none.

    Raises:
        requests.exceptions.RequestException: On network failure or timeout.
    """
    r = requests.get("http://txt.go.sohu.com/ip/soip", timeout=timeout)
    # The dots must be escaped: an unescaped "." matches any character.
    ip = re.findall(r"\d+\.\d+\.\d+\.\d+", r.text)
    return ip[0] if ip else None


# ---- Configuration read from environment variables ----
enable_moderation = os.environ.get("enable_moderation", default="False") == "True"
concurrency_count = int(os.environ.get("concurrency_count", default="10"))
# "reload" re-fetches the model list on every page load; "once" keeps the
# list fetched at startup (build_demo rejects any other value).
model_list_mode = os.environ.get("model_list_mode", default="reload")
midware_url = os.environ.get("midware_url", default="")
preset_token = os.environ.get("preset_token", default="")
worker_addr = os.environ.get("worker_addr", default="")
# Maximum number of fine-tuning jobs allowed in "running" state at once.
allow_running = int(os.environ.get("allow_running", default="1"))
ft_list_job_url = os.environ.get("ft_list_job_url", default="")
ft_submit_job_url = os.environ.get("ft_submit_job_url", default="")
ft_remove_job_url = os.environ.get("ft_remove_job_url", default="")
ft_console_log_url = os.environ.get("ft_console_log_url", default="")

# Preview samples returned by ft_dataset_preview_click, keyed by dataset name.
dataset_sample = {
    "english": {
        "train": ["abcdef"],
        "valid": ["zxcvbn"],
    },
}

# UI dataset name -> dataset name sent to the middleware on job submission.
dataset_to_midware_name = {
    "english": "english",
    "cat": "cat",
    "dog": "dog",
    "bird": "bird",
}

# Hyper-parameter keys copied from each job's "parameter" dict for display.
hps_keys = [
    "epochs",
    "train_batch_size",
    "eval_batch_size",
    "gradient_accumulation_steps",
    "learning_rate",
    "weight_decay",
    "model_max_length",
]

# Headers sent with every request to the model worker.
headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": preset_token}

learn_more_md = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/LICENSE) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""

# Client IP -> session expiry timestamp (epoch seconds). load_demo refreshes
# it; IPs never seen default to 0 and are therefore treated as expired.
ip_expiration_dict = defaultdict(lambda: 0)


def is_legal_char(c):
    """Return True if ``c`` is alphanumeric, a CJK ideograph or punctuation.

    Accepted punctuation is the ASCII set plus common full-width/CJK marks.
    """
    if c.isalnum():
        return True
    if "\u4e00" <= c <= "\u9fff":  # CJK Unified Ideographs block
        return True
    # Full-width / CJK punctuation. These characters must stay full-width:
    # normalising them to ASCII (e.g. U+FF02 -> '"') breaks this literal.
    if c in "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.":
        return True
    if c in '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~':
        return True
    return False


def str_filter(s):
    """Drop up to two trailing characters of ``s`` that fail ``is_legal_char``.

    Presumably meant to hide partially streamed or garbled characters at the
    tail of streaming output — TODO confirm intent.
    """
    for _ in range(2):
        if len(s) > 0 and (not is_legal_char(s[-1])):
            s = s[:-1]
    return s


def str_not_int(s):
    """Return True if ``int(s)`` raises ``ValueError``."""
    try:
        int(s)
        return False
    except ValueError:
        return True


def str_not_float(s):
    """Return True if ``float(s)`` raises ``ValueError``."""
    try:
        float(s)
        return False
    except ValueError:
        return True


class State:
    """Per-session chat state: conversation, id, selected model, skip flag."""

    def __init__(self, model_name):
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        # Set by add_text when the input is rejected, so the following
        # bot_response call does nothing.
        self.skip_next = False
        self.model_name = model_name

        if model_name == "palm-2":
            # According to release note, "chat-bison@001" is PaLM 2 for chat.
            # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
            self.palm_chat = init_palm_chat("chat-bison@001")

    def to_gradio_chatbot(self):
        """Return the conversation in the format expected by ``gr.Chatbot``."""
        return self.conv.to_gradio_chatbot()

    def dict(self):
        """Return the conversation dict extended with conv_id and model_name."""
        base = self.conv.dict()
        base.update(
            {
                "conv_id": self.conv_id,
                "model_name": self.model_name,
            }
        )
        return base


def get_conv_log_filename():
    """Return today's (local time) log path: ``LOGDIR/YYYY-MM-DD-conv.json``."""
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_model_list(midware_url):
    """Fetch the available model names from the middleware.

    Args:
        midware_url: Endpoint queried with the preset token header.

    Returns:
        The ``"data"`` list from the JSON response, ordered by the priority
        table below (unlisted models get priority 100). Returns
        ``["CANNOT GET MODEL"]`` when the request fails, the body is not
        valid JSON or lacks ``"data"``, or the token is reported invalid.
    """
    setted_model_order = {
        "vicuna-7b-v1.5-16k": 10,
        "vicuna-13b-v1.5": 90,
    }
    try:
        ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=5)
        payload = ret.json()  # parse once
        if "code" in payload and "invalid" in payload["code"]:
            gr.Warning("Invalid preset token.")
            models = ["CANNOT GET MODEL"]
        else:
            models = payload["data"]
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError):
        # ValueError: body is not JSON; KeyError/TypeError: unexpected shape.
        # This runs at import time, so failures must not crash startup.
        models = ["CANNOT GET MODEL"]

    models = sorted(models, key=lambda x: setted_model_order.get(x, 100))
    logger.info(f"Models: {models}")
    return models


def load_demo_single(models, url_params):
    """Build the component updates that reveal the chat UI on page load.

    The model given by ``url_params["model"]`` is preselected if it is in
    ``models``; otherwise the first model (or "") is selected.
    """
    selected_model = models[0] if len(models) > 0 else ""
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            selected_model = model

    dropdown_update = gr.Dropdown.update(
        choices=models, value=selected_model, visible=True
    )

    state = None
    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


def load_demo(url_params, request: gr.Request):
    """Page-load handler: start the client's session and populate the UI.

    In "reload" mode the global model list is re-fetched first.
    """
    global models

    ip = request.client.host
    logger.info(f"load_demo. ip: {ip}. params: {url_params}")
    ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME

    if model_list_mode == "reload":
        models = get_model_list(midware_url)

    return load_demo_single(models, url_params)


def regenerate(state, request: gr.Request):
    """Clear the last message so the chained bot_response produces it again."""
    logger.info(f"regenerate. ip: {request.client.host}")
    state.conv.update_last_message(None)
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2


def clear_history(request: gr.Request):
    """Reset the session state, chat history and input box."""
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 2


def add_text(state, model_selector, text, request: gr.Request):
    """Validate user input and append it to the conversation.

    Creates the State on first use. Empty input, expired sessions, moderation
    hits (when enabled) and conversations at the turn limit set
    ``state.skip_next`` so the following bot_response is a no-op. Accepted
    text is truncated to INPUT_CHAR_LEN_LIMIT and appended together with an
    empty assistant slot.

    Returns:
        ``(state, chatbot, textbox, regenerate_btn, clear_btn)`` updates.
    """
    ip = request.client.host
    logger.info(f"add_text. ip: {ip}. len: {len(text)}")

    if state is None:
        state = State(model_selector)

    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 2

    if ip_expiration_dict[ip] < time.time():
        logger.info(f"inactive. ip: {request.client.host}. text: {text}")
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), INACTIVE_MSG) + (no_change_btn,) * 2

    if enable_moderation:
        flagged = violates_moderation(text)
        if flagged:
            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), MODERATION_MSG) + (
                no_change_btn,
            ) * 2

    conv = state.conv
    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {request.client.host}. text: {text}")
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
            no_change_btn,
        ) * 2

    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2


def post_process_code(code):
    """Un-escape backslash-escaped underscores inside fenced code blocks.

    Only applied when the fences are balanced (the split on a newline followed
    by a triple backtick yields an odd number of pieces); odd-indexed pieces
    are the fenced contents.
    """
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code


def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
):
    """POST a generation request to a model worker and yield decoded chunks.

    The response is streamed and split on NUL bytes; each non-empty chunk is
    decoded as JSON and yielded as a dict.
    """
    # Make requests
    gen_params = {
        "model_name": model_name,
        "question": prompt,
        # NOTE(review): the `temperature` argument is ignored and a near-zero
        # value is always sent, so the UI slider has no effect on worker-served
        # models while the conversation log records the slider value.
        # Confirm whether this is intentional.
        "temperature": 1e-6,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # Stream output. Using the response as a context manager releases the
    # streaming connection even if the consumer stops iterating early.
    with requests.post(
        worker_addr,
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                yield data


def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's reply into the conversation.

    Generator yielding ``(state, chatbot, regenerate_btn, clear_btn)`` updates.
    The backend is chosen by model name (OpenAI, Anthropic, PaLM, or the
    configured model worker). On success, one JSON line describing the
    exchange is appended to today's conversation log.
    """
    logger.info(f"bot_response. ip: {request.client.host}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2
        return

    conv, model_name = state.conv, state.model_name
    if model_name in ("gpt-3.5-turbo", "gpt-4"):
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_name in ("claude-2", "claude-instant-1"):
        prompt = conv.get_prompt()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_name == "palm-2":
        stream_iter = palm_api_stream_iter(
            state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
        )
    else:
        # Get worker address
        logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

        # No available worker
        if worker_addr == "":
            conv.update_last_message(SERVER_ERROR_MSG)
            yield (
                state,
                state.to_gradio_chatbot(),
                enable_btn,
                enable_btn,
            )
            return

        # Construct prompt.
        # We need to call it here, so it will not be affected by "▌".
        prompt = conv.get_prompt()

        # Set repetition_penalty
        if "t5" in model_name:
            repetition_penalty = 1.2
        else:
            repetition_penalty = 1.0

        stream_iter = model_worker_stream_iter(
            conv,
            model_name,
            worker_addr,
            prompt,
            temperature,
            repetition_penalty,
            top_p,
            max_new_tokens,
        )

    conv.update_last_message("▌")
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2

    # Initialised so the final log line cannot raise NameError when the
    # stream yields no chunks at all.
    output = ""
    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                finish_reason = data.get("finish_reason", None)
                if finish_reason is not None and finish_reason == "length":
                    gr.Warning("Answer interrupted because the setting of [Max output tokens], try set a larger value.")
                output = data["text"].strip()
                if "vicuna" in model_name:
                    output = post_process_code(output)
                output = str_filter(output)
                conv.update_last_message(output + "▌")
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + (
                    enable_btn,
                    enable_btn,
                )
                return
            time.sleep(0.015)
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            enable_btn,
            enable_btn,
        )
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            enable_btn,
            enable_btn,
        )
        return

    # Delete "▌"
    conv.update_last_message(conv.messages[-1][-1][:-1])
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


block_css = """
#dialog_notice_markdown {
    font-size: 104%
}
#dialog_notice_markdown th {
    display: none;
}
#dialog_notice_markdown td {
    padding-top: 6px;
    padding-bottom: 6px;
}
#leaderboard_markdown {
    font-size: 104%
}
#leaderboard_markdown td {
    padding-top: 6px;
    padding-bottom: 6px;
}
#leaderboard_dataframe td {
    line-height: 0.1em;
}
"""


def get_model_description_md(models):
    """Render a 3-column markdown table linking each model to its description.

    Models registered in ``model_info`` are de-duplicated by ``simple_name``;
    unregistered models get a placeholder entry.
    """
    model_description_md = """
| | | |
| ---- | ---- | ---- |
"""
    ct = 0
    visited = set()
    for name in models:
        if name in model_info:
            minfo = model_info[name]
            if minfo.simple_name in visited:
                continue
            visited.add(minfo.simple_name)
            one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
        else:
            visited.add(name)
            one_model_md = (
                f"[{name}](): Add the description at fastchat/model/model_registry.py"
            )

        if ct % 3 == 0:
            model_description_md += "|"
        model_description_md += f" {one_model_md} |"
        if ct % 3 == 2:
            model_description_md += "\n"
        ct += 1
    return model_description_md


def build_single_model_ui(models, add_promotion_links=False):
    """Build the single-model chat tab and register its event listeners.

    Must be called inside a ``gr.Blocks`` context. ``add_promotion_links`` is
    accepted for interface compatibility but unused here.

    Returns:
        ``(state, model_selector, chatbot, textbox, send_btn, button_row,
        parameter_row)`` — the components load_demo updates on page load.
    """
    with gr.Column():
        with gr.Tab("🧠 模型对话 Dialog"):
            state = gr.State()

            with gr.Row(elem_id="model_selector_row"):
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if len(models) > 0 else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Scroll down and start chatting",
                visible=False,
                height=550,
            )
            with gr.Row():
                with gr.Column(scale=20):
                    textbox = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press ENTER",
                        visible=False,
                        container=False,
                    )
                with gr.Column(scale=1, min_width=50):
                    send_btn = gr.Button(value="Send", visible=False)

            with gr.Row(visible=False) as button_row:
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

            gr.Examples(
                examples=["如何变得富有?", "你能用Python写一段快速排序吗?", "How to be rich?", "Can you write a quicksort code in Python?"],
                inputs=textbox,
            )

            with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )
                max_output_tokens = gr.Slider(
                    minimum=16,
                    maximum=1024,
                    value=512,
                    step=64,
                    interactive=True,
                    label="Max output tokens",
                )

            gr.Markdown(learn_more_md)

    # Register listeners
    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row


def ft_get_job_data():
    """Fetch the fine-tuning job list from the middleware.

    Returns:
        ``(rows, running)``. Each row is
        ``[jobName, username, created_at, model, dataset, status, hps_json]``;
        rows are ordered by status (descending) and, within equal status, by
        ``created_at`` (descending, string order). ``running`` counts jobs
        whose status is "running" (case-insensitive). ``([], 0)`` is returned
        when the request fails, the body is not valid JSON, or the token is
        reported invalid.
    """
    running = 0
    res_lst = []
    try:
        r = requests.get(ft_list_job_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=8)
        jobs = r.json()  # parse once
    except (requests.exceptions.RequestException, ValueError):
        logger.info("Get job list fail")
        return res_lst, running

    if "code" in jobs and "invalid" in jobs["code"]:
        gr.Warning("Invalid preset token.")
        return res_lst, running

    for d in jobs:
        if isinstance(d["status"], str) and d["status"].lower() == "running":
            running += 1
        hps = {key: d["parameter"][key] for key in hps_keys if key in d["parameter"]}
        res_lst.append([d["jobName"], d["username"], d["created_at"], d["model"], d["dataset"], d["status"], json.dumps(hps)])

    # Two stable sorts: newest first, then grouped by status, so jobs with the
    # same status keep their newest-first order.
    res_lst = sorted(res_lst, key=lambda x: x[2], reverse=True)
    res_lst = sorted(res_lst, key=lambda x: x[5], reverse=True)
    return res_lst, running


def ft_refresh_click():
    """Return the refreshed ``(rows, running)`` job data."""
    return ft_get_job_data()


def ft_cease_click(ft_console):
    """Append a "ceased by user" marker to the console text.

    This only returns the new text; it does not itself stop any stream
    (presumably the listener cancels it — not visible here).
    """
    output = ft_console + "\n" + "** Streaming output ceased by user **"
    return output


def console_generator(addr, sleep_time):
    """Yield the accumulated console text received from a websocket.

    Args:
        addr: Websocket URL, connected with the preset token header.
        sleep_time: Seconds to sleep after each received message.

    Yields:
        The concatenation of all messages received so far. Stops when the
        server closes the connection.
    """
    total_str = ""
    ws = websocket.WebSocket()
    ws.connect(addr, header={"PRIVATE-TOKEN": preset_token})
    # try/finally also closes the socket when the consumer abandons the
    # generator early (GeneratorExit), not only on server-side close.
    try:
        while True:
            try:
                new_str = ws.recv()
            except WebSocketConnectionClosedException:
                break
            total_str = total_str + new_str
            time.sleep(sleep_time)
            yield total_str
    finally:
        ws.close()


def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token,
                    ft_epochs, ft_train_batch_size, ft_eval_batch_size,
                    ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay,
                    ft_model_max_length):
    """Validate and submit a fine-tuning job, then stream its console output.

    Generator yielding ``(job_rows, running_count, console_text)`` updates.
    Rejections (empty username, non-numeric hyper-parameters, too many running
    jobs) and connection failures yield a single refreshed job list.
    """
    # This is a generator (it streams console output below), so every exit
    # path must *yield* its update: a `return value` inside a generator is
    # discarded and the UI would never be refreshed.
    if ft_user_name == "":
        gr.Warning("Submit fail, empty username.")
        res_lst, running = ft_get_job_data()
        yield res_lst, running, no_change_textbox
        return

    if (
        str_not_int(ft_train_batch_size)
        or str_not_int(ft_eval_batch_size)
        or str_not_int(ft_gradient_accumulation_steps)
        or str_not_float(ft_learning_rate)
        or str_not_float(ft_weight_decay)
        or str_not_int(ft_model_max_length)
    ):
        gr.Warning("Submit fail, check the types. [learning rate] and [weight decay] should be float, others HPs should be int.")
        res_lst, running = ft_get_job_data()
        yield res_lst, running, no_change_textbox
        return

    if ft_latest_running_cnt >= int(allow_running):
        gr.Warning(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
        res_lst, running = ft_get_job_data()
        yield res_lst, running, no_change_textbox
        return

    midware_header = {"FINETUNE-SECRET": ft_token, "PRIVATE-TOKEN": preset_token}
    hps_json = {
        "epochs": str(ft_epochs),
        "train_batch_size": str(ft_train_batch_size),
        "eval_batch_size": str(ft_eval_batch_size),
        "gradient_accumulation_steps": str(ft_gradient_accumulation_steps),
        "learning_rate": str(ft_learning_rate),
        "weight_decay": str(ft_weight_decay),
        "model_max_length": str(ft_model_max_length),
    }
    json_data = {
        "dataset": dataset_to_midware_name[ft_dataset_name],
        "model": ft_model,
        "parameter": hps_json,
        "username": ft_user_name,
    }
    try:
        r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header, timeout=120)
        job_name = r.json()["jobName"]
        gr.Info(f"Job {job_name} submit success.")
        res_lst, running = ft_get_job_data()
        total_str = ""
        for s in console_generator(ft_console_log_url + job_name, 1):
            total_str = s
            yield res_lst, running, s
        # Refresh the job list once the console stream has ended.
        res_lst, running = ft_get_job_data()
        yield res_lst, running, total_str
    except requests.exceptions.RequestException:
        gr.Warning("Connection Failure.")
        res_lst, running = ft_get_job_data()
        yield res_lst, running, ""


def ft_show_click(ft_selected_row_data):
    """Stream the console output of the selected job (row[0] is the job name)."""
    # `yield from` forwards close() to console_generator, so abandoning this
    # stream also closes the websocket.
    yield from console_generator(ft_console_log_url + ft_selected_row_data[0], 0.2)


def ft_remove_click(ft_selected_row_data, ft_token):
    """Ask the middleware to remove the selected job if it is running.

    Returns:
        The refreshed ``(rows, running)`` job data.
    """
    status = ft_selected_row_data[5]
    if isinstance(status, str) and status.lower() == "running":
        try:
            r = requests.delete(
                ft_remove_job_url + ft_selected_row_data[0],
                headers={"FINETUNE-SECRET": ft_token, "PRIVATE-TOKEN": preset_token},
                timeout=30,  # avoid hanging the handler on an unresponsive server
            )
        except requests.exceptions.RequestException:
            gr.Warning("Connection Failure.")
        else:
            if r.status_code == 200:
                gr.Info("Remove success.")
            else:
                gr.Warning(f"Remove fail. {r.status_code} {r.reason}.")
    else:
        gr.Warning("Remove fail. Can only remove a running job.")
    return ft_get_job_data()


def ft_jobs_info_select(ft_jobs_info, evt: gr.SelectData):
    """Handle a click on the job table.

    Returns a 10-item list: the selected row, then model, dataset and the
    seven hyper-parameters. When the clicked column is model (3), dataset (4)
    or hyper-parameters (6), the form is filled from the row; otherwise
    "no change" updates are returned for the form fields.
    """
    selected_row = ft_jobs_info[evt.index[0]]
    if evt.index[1] in (3, 4, 6):
        try:
            hps = json.loads(selected_row[6])
        except json.decoder.JSONDecodeError:
            hps = dict()
        return [selected_row, selected_row[3], selected_row[4], hps.get("epochs", ""),
                hps.get("train_batch_size", ""), hps.get("eval_batch_size", ""),
                hps.get("gradient_accumulation_steps", ""), hps.get("learning_rate", ""),
                hps.get("weight_decay", ""), hps.get("model_max_length", "")]
    else:
        return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider,
                no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox,
                no_change_textbox, no_change_textbox]


def ft_dataset_preview_click(ft_dataset_name):
    """Show the sample data for the chosen dataset ({} if none is defined)."""
    value = dataset_sample.get(ft_dataset_name, {})
    return gr.JSON.update(value=value, visible=True)


def ft_hide_dataset_click():
    """Hide the dataset preview component."""
    return gr.JSON.update(visible=False)


def build_demo(models):
    """Build the top-level Blocks app containing the single-model chat UI.

    The fine-tuning handlers defined above are not registered here.

    Raises:
        ValueError: If ``model_list_mode`` is not "once" or "reload".
    """
    with gr.Blocks(
        title="Vicuna Test",
        theme=gr.themes.Base(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)

        (
            state,
            model_selector,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui(models)

        if model_list_mode not in ["once", "reload"]:
            raise ValueError(f"Unknown model list mode: {model_list_mode}")

        demo.load(
            load_demo,
            [url_params],
            [
                state,
                model_selector,
                chatbot,
                textbox,
                send_btn,
                button_row,
                parameter_row,
            ],
            _js=get_window_url_params_js,
        )

    return demo


try:
    print("Internet IP:", get_internet_ip())
except Exception as e:
    print(f"Get Internet IP error: {e}")

models = get_model_list(midware_url)

# Launch the demo
demo = build_demo(models)
demo.queue(
    concurrency_count=concurrency_count, status_update_rate=10, api_open=False
).launch(
    max_threads=200,
)