FIRE / src /serve /test_throughput.py
zhangbofei
feat: change to fstchat
6dc0c9c
raw
history blame
3.98 kB
"""Benchmarking script to test the throughput of serving workers."""
import argparse
import json
import requests
import threading
import time
from fastchat.conversation import get_conv_template
def main():
    """Fire concurrent generation requests at a worker and report throughput.

    Configuration is read from the module-level ``args`` namespace:
    ``worker_address``, ``controller_address``, ``model_name``,
    ``max_new_tokens``, ``n_thread`` and ``test_dispatch``.

    If no worker address is given, the controller is asked to refresh its
    workers, list its models, and hand out a worker address for the model.
    ``n_thread`` threads then each POST the same prompt to the worker's
    ``/worker_generate_stream`` endpoint. The elapsed time and an
    approximate words-per-second figure are printed.

    Returns:
        None. Returns early (printing nothing further) when the controller
        reports an empty worker address.
    """
    # Bind the controller address unconditionally: ``send_request`` needs it
    # when --test-dispatch is set, even if --worker-address was supplied.
    controller_addr = args.controller_address

    if args.worker_address:
        worker_addr = args.worker_address
    else:
        ret = requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": args.model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        return

    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
    prompt_template = conv.get_prompt()
    # Every thread sends the same prompt.
    prompts = [prompt_template] * args.n_thread

    headers = {"User-Agent": "fastchat Client"}
    ploads = [
        {
            "model": args.model_name,
            "prompt": prompt,
            "max_new_tokens": args.max_new_tokens,
            "temperature": 0.0,
        }
        for prompt in prompts
    ]

    def send_request(results, i):
        """Send request ``i`` and store its generated-word count in ``results[i]``.

        ``results[i]`` is left at its initial value when the response is
        empty or carries a non-zero ``error_code``.
        """
        if args.test_dispatch:
            # Ask the controller for a worker per request to exercise dispatch.
            ret = requests.post(
                controller_addr + "/get_worker_address", json={"model": args.model_name}
            )
            thread_worker_addr = ret.json()["address"]
        else:
            thread_worker_addr = worker_addr
        print(f"thread {i} goes to {thread_worker_addr}")

        response = requests.post(
            thread_worker_addr + "/worker_generate_stream",
            headers=headers,
            json=ploads[i],
            stream=False,
        )
        # The body is a sequence of NUL-delimited JSON chunks. Dropping empty
        # pieces makes "last chunk" independent of whether the body ends with
        # a trailing delimiter.
        chunks = [
            chunk
            for chunk in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
            )
            if chunk
        ]
        if not chunks:
            print(f"thread {i} received an empty response")
            return

        final = json.loads(chunks[-1].decode("utf-8"))
        error_code = final.get("error_code", 0)
        if error_code != 0:
            # Do not count an error message as generated words.
            print(f"thread {i} failed with error code {error_code}")
            return

        response_new_words = final["text"]
        # Subtract the prompt's word count; presumably the returned text
        # echoes the prompt -- TODO confirm against the worker.
        results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    # Use N threads to prompt the backend.
    tik = time.time()
    threads = []
    # Start from 0 so a thread that fails contributes nothing instead of
    # leaving a None that would break the sum below.
    results = [0] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")

    n_words = sum(results)
    time_seconds = time.time() - tik
    print(
        f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
        f"throughput: {n_words / time_seconds} words/s."
    )
if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them into
    # the module-level ``args`` that ``main`` reads, then run the benchmark.
    parser = argparse.ArgumentParser()

    # (flag, keyword arguments for ``add_argument``), in registration order.
    option_specs = (
        ("--controller-address", {"type": str, "default": "http://localhost:21001"}),
        ("--worker-address", {"type": str}),
        ("--model-name", {"type": str, "default": "vicuna"}),
        ("--max-new-tokens", {"type": int, "default": 2048}),
        ("--n-thread", {"type": int, "default": 8}),
        ("--test-dispatch", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main()