"""Benchmarking script to test the throughput of serving workers."""
import argparse
import json
import threading
import time

import requests

from fastchat.conversation import get_conv_template
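
# Note: `args` is the module-level argparse.Namespace assigned in the
# __main__ block at the bottom of this file; main() reads it directly.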
def main():
    # Resolve the controller address up front: send_request below also needs
    # it when --test-dispatch is set, even if --worker-address was given.
    controller_addr = args.controller_address
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        # Refresh the controller's worker list, then query the registered
        # models and the address of a worker serving the requested model.
        ret = requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": args.model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        print(f"No worker is serving {args.model_name}; exiting.")
        return
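
    # Every thread sends the same long-generation prompt, built with the
    # Vicuna conversation template so the model sees the format it expects.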
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
    prompt_template = conv.get_prompt()
    prompts = [prompt_template for _ in range(args.n_thread)]

    headers = {"User-Agent": "fastchat Client"}
    ploads = [
        {
            "model": args.model_name,
            "prompt": prompts[i],
            "max_new_tokens": args.max_new_tokens,
            "temperature": 0.0,
        }
        for i in range(len(prompts))
    ]
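
    # send_request posts one payload and records, in results[i], how many new
    # words the worker generated. The worker answers with a stream of
    # null-delimited JSON chunks.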
    def send_request(results, i):
        if args.test_dispatch:
            # Re-query the controller per thread to exercise its dispatch logic.
            ret = requests.post(
                controller_addr + "/get_worker_address", json={"model": args.model_name}
            )
            thread_worker_addr = ret.json()["address"]
        else:
            thread_worker_addr = worker_addr
        print(f"thread {i} goes to {thread_worker_addr}")
        response = requests.post(
            thread_worker_addr + "/worker_generate_stream",
            headers=headers,
            json=ploads[i],
            stream=False,
        )
        k = list(
            response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")
        )

        # The stream ends with a delimiter, so k[-1] is empty and k[-2] holds
        # the final chunk with the full text and the error code.
        final_chunk = json.loads(k[-2].decode("utf-8"))
        response_new_words = final_chunk["text"]
        error_code = final_chunk["error_code"]
        if error_code != 0:
            print(f"thread {i} finished with error code {error_code}")

        # Subtract the prompt's word count so only newly generated words
        # are counted.
        results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))
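
    # Launch one thread per payload and time the whole batch.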
    tik = time.time()
    threads = []
    results = [None] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = sum(results)
    time_seconds = time.time() - tik
    print(
        f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
        f"throughput: {n_words / time_seconds} words/s."
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="vicuna")
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--n-thread", type=int, default=8)
    parser.add_argument("--test-dispatch", action="store_true")
    args = parser.parse_args()

    main()
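
# Example usage (the filename below is illustrative; assumes a FastChat
# controller at http://localhost:21001 with at least one worker serving
# the "vicuna" model):
#
#   python3 test_throughput.py --model-name vicuna --n-thread 8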