"""
Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022" 

Workers are listed in format of `model-path`@`host`@`port` 

The key mechanism behind this script is:
    1. execute a shell command to launch the controller/worker/openai-api-server;
    2. check the log of the controller/worker/openai-api-server to ensure that the service launched properly.
Note that a few non-critical `fastchat.serve` command options are not currently supported.
"""
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import subprocess
import re
import argparse

# Directory that receives one "<name>.log" file per launched service; the
# shell templates further down in this file write to and poll these files.
LOGDIR = "./logs/"

# exist_ok=True replaces the exists()-then-makedirs() pair, which could raise
# FileExistsError if another process created the directory in between.
os.makedirs(LOGDIR, exist_ok=True)

def _register_options(arg_parser, option_table):
    """Add each (flag, kwargs) pair to *arg_parser*, in order.

    Returns the option names without their leading "--", in table order, so
    the per-service name lists below cannot drift from the declared options.
    """
    for flag, spec in option_table:
        arg_parser.add_argument(flag, **spec)
    return [flag[2:] for flag, _ in option_table]


parser = argparse.ArgumentParser()

# ------ multi worker ------
# One or more worker specs. argparse returns the default unchanged, so the
# parsed value is a plain str when the option is omitted and a list when it
# is given on the command line.
parser.add_argument(
    "--model-path-address",
    nargs="+",
    type=str,
    default="THUDM/chatglm2-6b@localhost@20002",
    help="model path, host, and port, formatted as model-path@host@port",
)

# ------ controller ------
controller_args = _register_options(
    parser,
    [
        ("--controller-host", dict(type=str, default="localhost")),
        ("--controller-port", dict(type=int, default=21001)),
        (
            "--dispatch-method",
            dict(
                type=str,
                choices=["lottery", "shortest_queue"],
                default="shortest_queue",
            ),
        ),
    ],
)

# ------ worker ------
# NOTE: --worker-address and --controller-address are not exposed as options
# here (they were commented out in the original declaration list).
# "controller-address" is still listed among the forwarded names below.
worker_args = _register_options(
    parser,
    [
        ("--worker-host", dict(type=str, default="localhost")),
        ("--worker-port", dict(type=int, default=21002)),
        (
            "--model-path",
            dict(
                type=str,
                default="lmsys/vicuna-7b-v1.5",
                help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
            ),
        ),
        (
            "--revision",
            dict(
                type=str,
                default="main",
                help="Hugging Face Hub model revision identifier",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                choices=["cpu", "cuda", "mps", "xpu", "npu"],
                default="cuda",
                help="The device type",
            ),
        ),
        (
            "--gpus",
            dict(
                type=str,
                default="0",
                help="A single GPU like 1 or multiple GPUs like 0,2",
            ),
        ),
        ("--num-gpus", dict(type=int, default=1)),
        (
            "--max-gpu-memory",
            dict(
                type=str,
                help="The maximum memory per gpu. Use a string like '13Gib'",
            ),
        ),
        ("--load-8bit", dict(action="store_true", help="Use 8-bit quantization")),
        (
            "--cpu-offloading",
            dict(
                action="store_true",
                help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
            ),
        ),
        (
            "--gptq-ckpt",
            dict(
                type=str,
                default=None,
                help="Load quantized model. The path to the local GPTQ checkpoint.",
            ),
        ),
        (
            "--gptq-wbits",
            dict(
                type=int,
                default=16,
                choices=[2, 3, 4, 8, 16],
                help="#bits to use for quantization",
            ),
        ),
        (
            "--gptq-groupsize",
            dict(
                type=int,
                default=-1,
                help="Groupsize to use for quantization; default uses full row.",
            ),
        ),
        (
            "--gptq-act-order",
            dict(
                action="store_true",
                help="Whether to apply the activation order GPTQ heuristic",
            ),
        ),
        (
            "--model-names",
            dict(
                type=lambda s: s.split(","),
                help="Optional display comma separated names",
            ),
        ),
        (
            "--limit-worker-concurrency",
            dict(
                type=int,
                default=5,
                help="Limit the model concurrency to prevent OOM.",
            ),
        ),
        ("--stream-interval", dict(type=int, default=2)),
        ("--no-register", dict(action="store_true")),
    ],
) + ["controller-address"]

# ------ openai server ------
# NOTE: --allowed-origins, --allowed-methods and --allowed-headers are not
# exposed as options here (they were commented out in the original
# declaration list).
server_args = _register_options(
    parser,
    [
        ("--server-host", dict(type=str, default="localhost", help="host name")),
        ("--server-port", dict(type=int, default=8001, help="port number")),
        (
            "--allow-credentials",
            dict(action="store_true", help="allow credentials"),
        ),
        (
            "--api-keys",
            dict(
                type=lambda s: s.split(","),
                help="Optional list of comma separated API keys",
            ),
        ),
    ],
) + ["controller-address"]

# Parse sys.argv at import time; the launch functions read this module-level
# namespace.
args = parser.parse_args()

# Rebuild the namespace with one derived entry. The key is deliberately spelled
# with a hyphen (so it is only reachable through vars()/_get_kwargs(), not as
# an attribute): string_args matches options by their hyphenated names.
args = argparse.Namespace(
    **{
        **vars(args),
        "controller-address": f"http://{args.controller_host}:{args.controller_port}",
    }
)

if args.gpus:
    # --num-gpus may not ask for more devices than --gpus lists.
    if args.num_gpus > len(args.gpus.split(",")):
        raise ValueError(
            f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
        )
    # Set in this process's environment before any service is launched.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# Launch template. Placeholders:
#   {0} fastchat.serve module name (controller, model_worker, openai_api_server)
#   {1} command-line options for that module
#   {2} log directory
#   {3} log file name, without the ".log" suffix
# The command is started with nohup and "&", and both stdout and stderr are
# redirected into the log file.
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"

# Readiness-check template. Placeholders:
#   {0} log directory
#   {1} log file name, without the ".log" suffix
#   {2} label printed in the progress messages
# Polls the log once per second until it contains "Uvicorn running on".
# `grep -q ... 2>/dev/null` treats a log file that does not exist yet as
# "not ready"; the previous `[ $(grep -c ...) -eq 0 ]` form produced an empty
# operand in that case, so the test errored out and the loop fell through to
# the "running" message. `sleep 1` (no unit suffix) is the portable spelling.
# NOTE: the loop has no timeout; if the service exits without ever logging
# the marker line, this check waits forever.
base_check_sh = """while ! grep -q "Uvicorn running on" {0}/{1}.log 2>/dev/null;do
                        sleep 1;
                        echo "wait {2} running"
                done
                echo '{2} running' """


def string_args(args, args_list):
    """Render the selected parsed options as a command-line fragment.

    Args:
        args: Parsed options (an ``argparse.Namespace``); its entries are read
            through ``_get_kwargs()``.
        args_list: Hyphenated option names, without the leading ``--``, that
            should be forwarded. Entries of *args* not named here are skipped.

    Returns:
        A string made of ``" --name value "`` and ``" --flag "`` chunks, or
        ``""`` when nothing is forwarded. Options whose name contains "host"
        or "port" are emitted under the last hyphen-separated part of their
        name (``worker-port`` becomes ``--port``). Falsy values (``None``,
        ``False``, ``0``, empty strings and containers) are omitted.
    """
    chunks = []
    for key, value in args._get_kwargs():
        key = key.replace("_", "-")
        if key not in args_list:
            continue

        # Drop the service prefix from host/port options.
        if re.search("port|host", key):
            key = key.split("-")[-1]

        if not value:
            # Unset options and disabled store_true flags are not forwarded.
            continue

        if value is True:
            # store_true flag: bare switch, no value.
            chunks.append(f" --{key} ")
        elif isinstance(value, (list, tuple, set)):
            # The list-valued options declared in this script (--model-names,
            # --api-keys) are parsed from ONE comma-separated token, so they
            # are re-emitted that way; a space-joined value would turn every
            # element after the first into a stray positional argument.
            # Presumably the target fastchat.serve commands parse these
            # options the same way -- verify if new list options are added.
            joined = ",".join(str(element) for element in value)
            chunks.append(f" --{key} {joined} ")
        else:
            chunks.append(f" --{key} {value} ")

    return "".join(chunks)


def launch_worker(item):
    """Start one model worker and block until its log reports it is serving.

    Args:
        item: Worker spec formatted as ``model-path@host@port``. The string is
            split on its last two ``@`` characters, so the model path itself
            may contain ``@``.

    Raises:
        ValueError: If *item* does not contain the two ``@`` separators.
        subprocess.CalledProcessError: If the launch or the readiness-check
            shell command exits with a non-zero status.

    Side effects:
        Overwrites ``model_path``, ``worker_host`` and ``worker_port`` on the
        module-level ``args`` before the worker options are rendered.
    """
    # Split from the right: host and port never contain "@", a path might.
    parts = item.rsplit("@", 2)
    if len(parts) != 3:
        raise ValueError(
            f"Invalid worker spec {item!r}: expected model-path@host@port"
        )

    # Log file stem: last path component of the spec (either separator style),
    # with "-", "@" and "." replaced by "_".
    log_name = (
        item.split("/")[-1]
        .split("\\")[-1]
        .replace("-", "_")
        .replace("@", "_")
        .replace(".", "_")
    )

    args.model_path, args.worker_host, args.worker_port = parts
    print("*" * 80)
    worker_str_args = string_args(args, worker_args)
    print(worker_str_args)
    worker_sh = base_launch_sh.format(
        "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
    )
    worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
    # Launch in the background, then poll the worker's log until it is up.
    subprocess.run(worker_sh, shell=True, check=True)
    subprocess.run(worker_check_sh, shell=True, check=True)


def _start_and_wait(module, options, log_stem):
    """Launch one fastchat.serve module, then run the log-polling check for it."""
    start_cmd = base_launch_sh.format(module, options, LOGDIR, log_stem)
    wait_cmd = base_check_sh.format(LOGDIR, log_stem, module)
    subprocess.run(start_cmd, shell=True, check=True)
    subprocess.run(wait_cmd, shell=True, check=True)


def launch_all():
    """Bring up the controller, every requested worker, then the OpenAI API server.

    Each service is started and checked before the next one is launched.
    Raises subprocess.CalledProcessError if any shell command exits non-zero.
    """
    _start_and_wait("controller", string_args(args, controller_args), "controller")

    # --model-path-address is a plain str when the default is used and a list
    # when given on the command line; only the list form prints a progress line.
    addresses = args.model_path_address
    if isinstance(addresses, str):
        launch_worker(addresses)
    else:
        for idx, item in enumerate(addresses):
            print(f"loading {idx}th model:{item}")
            launch_worker(item)

    _start_and_wait(
        "openai_api_server", string_args(args, server_args), "openai_api_server"
    )


if __name__ == "__main__":
    launch_all()