|
""" |
|
The gradio demo server for chatting with a single model. |
|
""" |
|
|
|
import argparse |
|
from collections import defaultdict |
|
import datetime |
|
import hashlib |
|
import json |
|
import os |
|
import random |
|
import time |
|
import uuid |
|
|
|
import gradio as gr |
|
import requests |
|
|
|
from src.constants import ( |
|
LOGDIR, |
|
WORKER_API_TIMEOUT, |
|
ErrorCode, |
|
MODERATION_MSG, |
|
CONVERSATION_LIMIT_MSG, |
|
RATE_LIMIT_MSG, |
|
SERVER_ERROR_MSG, |
|
INPUT_CHAR_LEN_LIMIT, |
|
CONVERSATION_TURN_LIMIT, |
|
SESSION_EXPIRATION_TIME, |
|
) |
|
from src.model.model_adapter import ( |
|
get_conversation_template, |
|
) |
|
from src.model.model_registry import get_model_info, model_info |
|
from src.serve.api_provider import get_api_provider_stream_iter |
|
from src.serve.remote_logger import get_remote_logger |
|
from src.utils import ( |
|
build_logger, |
|
get_window_url_params_js, |
|
get_window_url_params_with_tos_js, |
|
moderation_filter, |
|
parse_gradio_auth_creds, |
|
load_image, |
|
) |
|
|
|
# Module-wide logger used by every handler in this file.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers sent with requests to a model worker (see model_worker_stream_iter).
headers = {"User-Agent": "FastChat Client"}

# Shared button objects that event handlers return to set the state of the
# five action buttons (upvote, downvote, flag, regenerate, clear).
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True, visible=True)
disable_btn = gr.Button(interactive=False)
invisible_btn = gr.Button(interactive=False, visible=False)

# Process-wide settings; overwritten by set_global_vars().
controller_url = None
enable_moderation = False
use_remote_storage = False
|
|
|
acknowledgment_md = """ |
|
### Terms of Service |
|
|
|
Placeholder |
|
### Acknowledgment |
|
Placeholder |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model name -> endpoint description for API-based models. Replaced by
# get_model_list() when an endpoint registration file is given, and read by
# bot_response() to decide how a model is served.
api_endpoint_info = {}
|
|
|
|
|
class State:
    """Per-session chat state for a single model.

    Holds the conversation template for ``model_name`` plus the bookkeeping
    flags the event handlers in this file read and write.
    """

    def __init__(self, model_name, is_vision=False):
        """Create a fresh conversation for ``model_name``.

        Args:
            model_name: Name passed to ``get_conversation_template``.
            is_vision: Whether this session is a vision (image) session.
        """
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        # When True, the next bot_response() call yields the current chat
        # unchanged and resets the flag.
        self.skip_next = False
        self.model_name = model_name
        self.oai_thread_id = None
        self.is_vision = is_vision

        # Set by _prepare_text_with_image() when an image is flagged.
        self.has_csam_image = False

        # Regeneration is turned off for models whose name contains "browsing".
        self.regen_support = "browsing" not in model_name
        self.init_system_prompt(self.conv)

    def init_system_prompt(self, conv):
        """Replace ``{{currentDateTime}}`` in the system prompt with today's date.

        Only templates exposing ``get_system_message`` are rewritten. Any other
        template (one with just a plain ``system`` attribute, or with neither)
        is left untouched instead of raising ``AttributeError``.

        Args:
            conv: The conversation template to update in place.
        """
        if not hasattr(conv, "get_system_message"):
            # The original guard was ``elif (conv, "system")``, a non-empty
            # tuple that is always truthy, so templates without ``system``
            # crashed. Such templates were never rewritten anyway.
            return

        system_prompt = conv.get_system_message()
        if not system_prompt:
            return

        current_date = datetime.datetime.now().strftime("%Y-%m-%d")
        system_prompt = system_prompt.replace("{{currentDateTime}}", current_date)
        conv.set_system_message(system_prompt)

    def to_gradio_chatbot(self):
        """Return the conversation in the form the Gradio chatbot displays."""
        return self.conv.to_gradio_chatbot()

    def dict(self):
        """Return the conversation as a dict extended with session metadata.

        ``conv_id`` and ``model_name`` are always added; ``has_csam_image`` is
        added only for vision sessions.
        """
        base = self.conv.dict()
        base.update(
            {
                "conv_id": self.conv_id,
                "model_name": self.model_name,
            }
        )

        if self.is_vision:
            base.update({"has_csam_image": self.has_csam_image})
        return base
|
|
|
|
|
def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_):
    """Store the three process-wide settings in this module's globals."""
    global controller_url, enable_moderation, use_remote_storage
    controller_url, enable_moderation, use_remote_storage = (
        controller_url_,
        enable_moderation_,
        use_remote_storage_,
    )
|
|
|
|
|
def get_conv_log_filename(is_vision=False, has_csam_image=False):
    """Return today's conversation log path inside ``LOGDIR``.

    The file name is ``YYYY-MM-DD-conv.json``. Vision sessions get a
    ``vision-tmp-`` prefix, or ``vision-csam-`` when an image was flagged.
    """
    today = datetime.datetime.now()
    base_name = f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json"

    if not is_vision:
        prefix = ""
    elif has_csam_image:
        prefix = "vision-csam-"
    else:
        prefix = "vision-tmp-"

    return os.path.join(LOGDIR, f"{prefix}{base_name}")
|
|
|
|
|
def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
    """Collect the models that can be offered in the UI.

    Args:
        controller_url: Base URL of the controller; when falsy no worker
            models are queried.
        register_api_endpoint_file: Path of a JSON file mapping model names to
            endpoint descriptions; when falsy no API models are added.
        vision_arena: Select multimodal models (True) or language models (False).

    Returns:
        ``(visible_models, models)``: ``models`` is every selected model,
        ``visible_models`` is the same list without entries marked
        ``anony_only`` in the endpoint file. Both are ordered by their
        position in ``model_info``; models not listed there sort by name.
    """
    global api_endpoint_info

    # Models served by workers registered with the controller.
    if controller_url:
        ret = requests.post(controller_url + "/refresh_all_workers")
        assert ret.status_code == 200

        if vision_arena:
            ret = requests.post(controller_url + "/list_multimodal_models")
        else:
            ret = requests.post(controller_url + "/list_language_models")
        models = ret.json()["models"]
    else:
        models = []

    # Models reachable through registered API endpoints.
    if register_api_endpoint_file:
        # Use a context manager so the file handle is closed promptly.
        with open(register_api_endpoint_file) as fin:
            api_endpoint_info = json.load(fin)
        for mdl, mdl_dict in api_endpoint_info.items():
            mdl_vision = mdl_dict.get("vision-arena", False)
            mdl_text = mdl_dict.get("text-arena", True)
            if mdl_vision if vision_arena else mdl_text:
                models.append(mdl)

    # Drop duplicates, then hide models that are reserved for anonymous use.
    # A missing "anony_only" key means the model is visible.
    models = list(set(models))
    visible_models = [
        mdl
        for mdl in models
        if not api_endpoint_info.get(mdl, {}).get("anony_only", False)
    ]

    # Known models sort first, in model_info order ("___NNN" sorts before
    # ordinary names); unknown models fall back to their own name.
    priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
    models.sort(key=lambda x: priority.get(x, x))
    visible_models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"All models: {models}")
    logger.info(f"Visible models: {visible_models}")
    return visible_models, models
|
|
|
|
|
def load_demo_single(models, url_params):
    """Return the initial ``(state, dropdown)`` pair for a page load.

    The dropdown preselects the ``model`` URL parameter when it names a known
    model, otherwise the first model (or "" when the list is empty). The
    state is always ``None``.
    """
    chosen = models[0] if len(models) > 0 else ""
    if "model" in url_params and url_params["model"] in models:
        chosen = url_params["model"]

    return None, gr.Dropdown(choices=models, value=chosen, visible=True)
|
|
|
|
|
def load_demo(url_params, request: gr.Request):
    """Page-load handler: optionally refresh the model list, then build the UI state."""
    global models

    client_ip = get_ip(request)
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")

    # In "reload" mode the model list is fetched again on every page load.
    if args.model_list_mode == "reload":
        refreshed = get_model_list(
            controller_url, args.register_api_endpoint_file, vision_arena=False
        )
        models, all_models = refreshed

    return load_demo_single(models, url_params)
|
|
|
|
|
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append a vote record to today's conversation log and the remote logger.

    Args:
        state: Session state; its ``dict()`` is stored in the record.
        vote_type: Label stored under ``"type"`` (e.g. "upvote").
        model_selector: Selected model name. Names containing "llava" are
            logged to the vision log file.
        request: Gradio request, used to record the client IP.
    """
    # Let get_conv_log_filename build the vision file name. The previous
    # ``replace("2024", "vision-tmp-2024")`` only worked during 2024 and
    # could also rewrite a "2024" inside the log directory path.
    filename = get_conv_log_filename(is_vision="llava" in model_selector)

    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(filename, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)
|
|
|
|
|
def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote, then clear the textbox and disable the three vote buttons."""
    client_ip = get_ip(request)
    logger.info(f"upvote. ip: {client_ip}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote, then clear the textbox and disable the three vote buttons."""
    client_ip = get_ip(request)
    logger.info(f"downvote. ip: {client_ip}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag, then clear the textbox and disable the three vote buttons."""
    client_ip = get_ip(request)
    logger.info(f"flag. ip: {client_ip}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
|
|
|
|
|
def regenerate(state, request: gr.Request):
    """Prepare the session for regenerating the last reply.

    When the model supports regeneration the last message is cleared and the
    buttons are disabled; otherwise the following bot_response() call is told
    to skip and the buttons are left as they are.
    """
    client_ip = get_ip(request)
    logger.info(f"regenerate. ip: {client_ip}")

    if state.regen_support:
        state.conv.update_last_message(None)
        buttons = (disable_btn,) * 5
    else:
        state.skip_next = True
        buttons = (no_change_btn,) * 5

    return (state, state.to_gradio_chatbot(), "", None) + buttons
|
|
|
|
|
def clear_history(request: gr.Request):
    """Reset the session: no state, empty chat, empty textbox, no image, buttons disabled."""
    client_ip = get_ip(request)
    logger.info(f"clear_history. ip: {client_ip}")
    cleared_outputs = (None, [], "", None)
    return cleared_outputs + (disable_btn,) * 5
|
|
|
|
|
def get_ip(request: gr.Request):
    """Return the client IP address for a request.

    Preference order: the ``cf-connecting-ip`` header, then the first entry
    of ``x-forwarded-for``, then the direct peer address.

    Args:
        request: Object exposing ``headers`` (mapping) and ``client.host``.

    Returns:
        The client address as a string.
    """
    if "cf-connecting-ip" in request.headers:
        return request.headers["cf-connecting-ip"]

    if "x-forwarded-for" in request.headers:
        # The header may be a comma-separated chain "client, proxy1, proxy2";
        # only the first entry identifies the client.
        return request.headers["x-forwarded-for"].split(",")[0].strip()

    return request.client.host
|
|
|
|
|
|
|
def report_csam_image(state, image):
    """Placeholder hook for reporting a flagged image; currently does nothing."""
    return None
|
|
|
|
|
def _prepare_text_with_image(state, text, images, csam_flag): |
|
if images is not None and len(images) > 0: |
|
image = images[0] |
|
|
|
if len(state.conv.get_images()) > 0: |
|
|
|
state.conv = get_conversation_template(state.model_name) |
|
|
|
if hasattr(state.conv, "convert_image_to_base64"): |
|
image = state.conv.convert_image_to_base64( |
|
image |
|
) |
|
else: |
|
from src.conversation import convert_image_to_base64 |
|
image = convert_image_to_base64(image, None) |
|
|
|
if csam_flag: |
|
state.has_csam_image = True |
|
report_csam_image(state, image) |
|
|
|
text = text, [image] |
|
|
|
return text |
|
|
|
|
|
def add_text(state, model_selector, text, image, request: gr.Request):
    """Handle a submitted prompt and queue it for bot_response().

    Creates the session state on first use, ignores empty input, replaces
    flagged input with ``MODERATION_MSG``, enforces the turn limit, then
    appends the (truncated) user message and an empty assistant slot.

    Returns:
        ``(state, chatbot, textbox_value, image_value)`` followed by five
        button updates.
    """
    client_ip = get_ip(request)
    logger.info(f"add_text. ip: {client_ip}. len: {len(text)}; text: {text}")

    state = State(model_selector) if state is None else state

    if len(text) <= 0:
        state.skip_next = True
        unchanged = (no_change_btn,) * 5
        return (state, state.to_gradio_chatbot(), "", None) + unchanged

    # Moderate the tail of the existing prompt together with the new input.
    recent_context = state.conv.get_prompt()[-2000:] + "\nuser: " + text
    if moderation_filter(recent_context, [state.model_name]):
        logger.info(f"violate moderation. ip: {client_ip}. text: {text}")
        # The flagged input is replaced by the moderation notice.
        text = MODERATION_MSG

    turns_so_far = (len(state.conv.messages) - state.conv.offset) // 2
    if turns_so_far >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {client_ip}. text: {text}")
        state.skip_next = True
        unchanged = (no_change_btn,) * 5
        return (
            state,
            state.to_gradio_chatbot(),
            CONVERSATION_LIMIT_MSG,
            None,
        ) + unchanged

    user_message = _prepare_text_with_image(
        state, text[:INPUT_CHAR_LEN_LIMIT], image, csam_flag=False
    )
    user_role, assistant_role = state.conv.roles[0], state.conv.roles[1]
    state.conv.append_message(user_role, user_message)
    state.conv.append_message(assistant_role, None)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
|
|
|
|
def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
    images,
):
    """Stream generation results from a model worker.

    Posts the generation parameters to ``<worker_addr>/worker_generate_stream``
    and yields each NUL-delimited chunk of the response decoded from JSON.

    Args:
        conv: Conversation template; supplies ``stop_str`` and ``stop_token_ids``.
        model_name: Model to generate with.
        worker_addr: Base URL of the worker.
        prompt: Full prompt text.
        temperature, repetition_penalty, top_p, max_new_tokens: Sampling settings.
        images: Images to send with the request; may be empty or None.

    Yields:
        One decoded JSON object per non-empty chunk.
    """
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }

    # Logged before images are attached, so image payloads stay out of the log.
    logger.info(f"==== request ====\n{gen_params}")

    if images:
        gen_params["images"] = images

    # The context manager releases the streaming connection even when the
    # consumer stops iterating early or an error is raised mid-stream.
    with requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode())
|
|
|
|
|
def is_limit_reached(model_name, ip, monitor_url="http://localhost:9090"):
    """Ask the monitor service whether a user hit the rate limit for a model.

    Args:
        model_name: Model being requested.
        ip: Client identifier sent as ``user_id``.
        monitor_url: Base URL of the monitor service.

    Returns:
        The decoded JSON reply, or ``None`` when the monitor cannot be reached
        or its reply cannot be parsed (rate limiting is best-effort).
    """
    try:
        # Pass values through ``params`` so they are URL-encoded; raw
        # interpolation broke on names or addresses containing "&", "#",
        # spaces or commas.
        ret = requests.get(
            f"{monitor_url}/is_limit_reached",
            params={"model": model_name, "user_id": ip},
            timeout=1,
        )
        return ret.json()
    except Exception as e:
        logger.info(f"monitor error: {e}")
        return None
|
|
|
|
|
def bot_response(
    state,
    temperature,
    top_p,
    max_new_tokens,
    request: gr.Request,
    apply_rate_limit=False,
    use_recommended_config=False,
):
    """Produce the assistant reply for the pending turn and stream UI updates.

    This is a generator: every yielded value is ``(state, chatbot)`` followed
    by five button updates. After a successful reply the turn is appended to
    today's conversation log and sent to the remote logger.

    Args:
        state: Session state holding the conversation and model name.
        temperature, top_p, max_new_tokens: Sampling settings from the UI;
            converted to float/float/int.
        request: Gradio request, used for the client IP.
        apply_rate_limit: When True, consult is_limit_reached() first.
        use_recommended_config: When True and the model has an API endpoint
            entry with ``recommended_config``, its values override the
            sampling settings.
    """
    ip = get_ip(request)
    logger.info(f"bot_response. ip: {ip}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # The preceding handler asked to skip generation; reset the flag and
        # leave the chat and buttons unchanged.
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if apply_rate_limit:
        ret = is_limit_reached(state.model_name, ip)
        # A None reply (monitor unavailable) is treated as "not limited".
        if ret is not None and ret["is_limit_reached"]:
            error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
            logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
            state.conv.update_last_message(error_msg)
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            return

    conv, model_name = state.conv, state.model_name
    model_api_dict = (
        api_endpoint_info[model_name] if model_name in api_endpoint_info else None
    )
    images = conv.get_images()
    logger.info(f"model_name: {model_name}; model_api_dict: {model_api_dict}; msg: {len(conv.messages)}")
    if model_api_dict is None:
        # No API endpoint registered: run local LLaVA inference. The whole
        # reply is computed up front and wrapped as a one-item "stream".
        # Note this happens outside the try block below, so inference errors
        # propagate to the caller.
        if model_name == "llava-original":
            from src.model.model_llava import inference_by_prompt_and_images
            logger.info(f"prompt for llava-original: {conv.get_prompt()}; images: {len(images)}")
            output_text = inference_by_prompt_and_images(conv.get_prompt(), images)[0]
        else:
            from src.model.model_llava import inference_by_prompt_and_images_fire
            logger.info(f"prompt for llava-fire: {conv.get_prompt()}; images: {len(images)}")
            output_text = inference_by_prompt_and_images_fire(conv.get_prompt(), images)[0]
        stream_iter = [{
            "error_code": 0,
            "text": output_text
        }]
    else:
        if use_recommended_config:
            recommended_config = model_api_dict.get("recommended_config", None)
            if recommended_config is not None:
                temperature = recommended_config.get("temperature", temperature)
                top_p = recommended_config.get("top_p", top_p)
                max_new_tokens = recommended_config.get(
                    "max_new_tokens", max_new_tokens
                )

        # NOTE(review): the API-endpoint branch does not call the endpoint
        # here; it uses a fixed one-item stream whose text is "hello".
        stream_iter = [{
            "error_code": 0,
            "text": "hello"
        }]

    # Blinking-cursor markup appended to the reply while it is in progress.
    html_code = ' <span class="cursor"></span> '

    conv.update_last_message(html_code)
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Default so the final read of data["text"] works for an empty stream.
        data = {"text": ""}
        for i, data in enumerate(stream_iter):
            if data["error_code"] == 0:
                output = data["text"].strip()

                conv.update_last_message(output + html_code)
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
            else:
                # Error chunk: show it and re-enable only regenerate/clear.
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + (
                    disable_btn,
                    disable_btn,
                    disable_btn,
                    enable_btn,
                    enable_btn,
                )
                return
        # Final update: drop the cursor markup and enable every button.
        output = data["text"].strip()
        conv.update_last_message(output)
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    finish_tstamp = time.time()
    logger.info(f"{output}")

    conv.save_new_images(
        has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage
    )

    filename = get_conv_log_filename(
        is_vision=state.is_vision, has_csam_image=state.has_csam_image
    )

    # One JSON record per line, appended to the day's log file.
    with open(filename, "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)
|
|
|
|
|
block_css = """ |
|
#notice_markdown .prose { |
|
font-size: 110% !important; |
|
} |
|
#notice_markdown th { |
|
display: none; |
|
} |
|
#notice_markdown td { |
|
padding-top: 6px; |
|
padding-bottom: 6px; |
|
} |
|
#arena_leaderboard_dataframe table { |
|
font-size: 110%; |
|
} |
|
#full_leaderboard_dataframe table { |
|
font-size: 110%; |
|
} |
|
#model_description_markdown { |
|
font-size: 110% !important; |
|
} |
|
#leaderboard_markdown .prose { |
|
font-size: 110% !important; |
|
} |
|
#leaderboard_markdown td { |
|
padding-top: 6px; |
|
padding-bottom: 6px; |
|
} |
|
#leaderboard_dataframe td { |
|
line-height: 0.1em; |
|
} |
|
#about_markdown .prose { |
|
font-size: 110% !important; |
|
} |
|
#ack_markdown .prose { |
|
font-size: 110% !important; |
|
} |
|
#chatbot .prose { |
|
font-size: 105% !important; |
|
} |
|
.sponsor-image-about img { |
|
margin: 0 20px; |
|
margin-top: 20px; |
|
height: 40px; |
|
max-height: 100%; |
|
width: auto; |
|
float: left; |
|
} |
|
|
|
.chatbot h1, h2, h3 { |
|
margin-top: 8px; /* Adjust the value as needed */ |
|
margin-bottom: 0px; /* Adjust the value as needed */ |
|
padding-bottom: 0px; |
|
} |
|
|
|
.chatbot h1 { |
|
font-size: 130%; |
|
} |
|
.chatbot h2 { |
|
font-size: 120%; |
|
} |
|
.chatbot h3 { |
|
font-size: 110%; |
|
} |
|
.chatbot p:not(:first-child) { |
|
margin-top: 8px; |
|
} |
|
|
|
.typing { |
|
display: inline-block; |
|
} |
|
|
|
.cursor { |
|
display: inline-block; |
|
width: 7px; |
|
height: 1em; |
|
background-color: black; |
|
vertical-align: middle; |
|
animation: blink 1s infinite; |
|
} |
|
|
|
.dark .cursor { |
|
display: inline-block; |
|
width: 7px; |
|
height: 1em; |
|
background-color: white; |
|
vertical-align: middle; |
|
animation: blink 1s infinite; |
|
} |
|
|
|
@keyframes blink { |
|
0%, 50% { opacity: 1; } |
|
50.1%, 100% { opacity: 0; } |
|
} |
|
|
|
.app { |
|
max-width: 100% !important; |
|
padding: 20px !important; |
|
} |
|
|
|
a { |
|
color: #1976D2; /* Your current link color, a shade of blue */ |
|
text-decoration: none; /* Removes underline from links */ |
|
} |
|
a:hover { |
|
color: #63A4FF; /* This can be any color you choose for hover */ |
|
text-decoration: underline; /* Adds underline on hover */ |
|
} |
|
""" |
|
|
|
|
|
def get_model_description_md(models):
    """Render the model descriptions as a three-column Markdown table.

    Each distinct ``simple_name`` (as reported by ``get_model_info``) appears
    once, in first-seen order, as ``[name](link): description``.
    """
    table_md = """
| | | |
| ---- | ---- | ---- |
"""
    seen_names = set()
    cells = []
    for model_name in models:
        info = get_model_info(model_name)
        if info.simple_name in seen_names:
            continue
        seen_names.add(info.simple_name)
        cells.append(f"[{info.simple_name}]({info.link}): {info.description}")

    # Emit the cells three per row; only a complete row ends with a newline.
    for row_start in range(0, len(cells), 3):
        row_cells = cells[row_start : row_start + 3]
        table_md += "|" + "".join(f" {cell} |" for cell in row_cells)
        if len(row_cells) == 3:
            table_md += "\n"
    return table_md
|
|
|
|
|
def build_about():
    """Render the static "About Us" page as a Markdown component."""
    about_text = """
# About Us
Placeholder
## Arena Core Team
Placeholder
## Past Members
Placeholder
## Learn more
Placeholder

## Contact Us
Placeholder

## Acknowledgment
Placeholder
"""
    gr.Markdown(value=about_text, elem_id="about_markdown")
|
|
|
|
|
def build_single_model_ui(models, add_promotion_links=False):
    """Build the single-model chat UI and wire up its event handlers.

    Must be called inside an active ``gr.Blocks`` context (see build_demo).

    Args:
        models: Model names offered in the dropdown; the first is preselected.
        add_promotion_links: When True, adds the promotion links to the page
            header and the acknowledgment section at the bottom.

    Returns:
        ``[state, model_selector]`` — the session state component and the
        model dropdown, which the caller connects to the page-load handler.
    """
    promotion = (
        """
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)

## 🤖 Choose any model to chat
"""
        if add_promotion_links
        else ""
    )

    notice_markdown = f"""
# 🏔️ Chat with Open Large Language Models
{promotion}
"""

    # Per-session State object; None until the first message is sent.
    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    # Model picker, collapsible model descriptions, and the chat window.
    with gr.Group(elem_id="share-region-named"):
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False,
            )
        with gr.Row():
            with gr.Accordion(
                f"🔍 Expand to see the descriptions of {len(models)} models",
                open=False,
            ):
                model_description_md = get_model_description_md(models)
                gr.Markdown(model_description_md, elem_id="model_description_markdown")

        chatbot = gr.Chatbot(
            elem_id="chatbot",
            label="Scroll down and start chatting",
            height=550,
            show_copy_button=True,
        )
    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    # Action buttons start disabled; handlers enable them after a reply.
    with gr.Row() as button_row:
        upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
        downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
        flag_btn = gr.Button(value="⚠️  Flag", interactive=False)
        regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️  Clear history", interactive=False)

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    if add_promotion_links:
        gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # Event wiring. There is no image upload widget in this UI, so the image
    # slot the handlers expect is a State component that always holds None.
    imagebox = gr.State(None)
    # Order matters: handlers return five button updates in this order.
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    upvote_btn.click(
        upvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    downvote_btn.click(
        downvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    flag_btn.click(
        flag_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    regenerate_btn.click(
        regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)

    # Switching models starts a fresh conversation.
    model_selector.change(
        clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
    )

    # Pressing ENTER and clicking "Send" run the same two-step chain:
    # add_text() queues the message, then bot_response() streams the reply.
    textbox.submit(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return [state, model_selector]
|
|
|
|
|
def build_demo(models):
    """Assemble the Gradio app: the chat UI plus the page-load hook.

    Raises:
        ValueError: If ``args.model_list_mode`` is neither "once" nor "reload".
    """
    blocks = gr.Blocks(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Default(),
        css=block_css,
    )
    with blocks as demo:
        url_params = gr.JSON(visible=False)
        state, model_selector = build_single_model_ui(models)

        if args.model_list_mode not in ("once", "reload"):
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

        # Pick the page-load script, with or without the terms-of-use prompt.
        load_js = (
            get_window_url_params_with_tos_js
            if args.show_terms_of_use
            else get_window_url_params_js
        )

        demo.load(load_demo, [url_params], [state, model_selector], js=load_js)

    return demo
|
|
|
|
|
# Script entry point: parse the command line, load the model list, build the
# UI and launch the Gradio server.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to generate a public, shareable link",
    )
    parser.add_argument(
        "--controller-url",
        type=str,
        default="http://localhost:21001",
        help="The address of the controller",
    )
    parser.add_argument(
        "--concurrency-count",
        type=int,
        default=10,
        help="The concurrency count of the gradio queue",
    )
    parser.add_argument(
        "--model-list-mode",
        type=str,
        default="once",
        choices=["once", "reload"],
        help="Whether to load the model list once or reload the model list every time",
    )
    parser.add_argument(
        "--moderate",
        action="store_true",
        help="Enable content moderation to block unsafe inputs",
    )
    parser.add_argument(
        "--show-terms-of-use",
        action="store_true",
        help="Shows term of use before loading the demo",
    )
    parser.add_argument(
        "--register-api-endpoint-file",
        type=str,
        help="Register API-based model endpoints from a JSON file",
    )
    parser.add_argument(
        "--gradio-auth-path",
        type=str,
        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
    )
    parser.add_argument(
        "--gradio-root-path",
        type=str,
        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
    )
    parser.add_argument(
        "--use-remote-storage",
        action="store_true",
        default=False,
        help="Uploads image files to google cloud storage if set to true",
    )
    # ``args`` is a module-level name here; load_demo() and build_demo() read it.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Publish the settings as module globals and fetch the initial model list.
    set_global_vars(args.controller_url, args.moderate, args.use_remote_storage)
    models, all_models = get_model_list(
        args.controller_url, args.register_api_endpoint_file, vision_arena=False
    )

    # Optional authentication, read from the credentials file.
    auth = None
    if args.gradio_auth_path is not None:
        auth = parse_gradio_auth_creds(args.gradio_auth_path)

    # Build the app and start serving.
    demo = build_demo(models)
    demo.queue(
        default_concurrency_limit=args.concurrency_count,
        status_update_rate=10,
        api_open=False,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
        auth=auth,
        root_path=args.gradio_root_path,
    )
|
|