|
import os |
|
import time |
|
import argparse |
|
|
|
from dotenv import load_dotenv |
|
from distutils.util import strtobool |
|
from memory_profiler import memory_usage |
|
from tqdm import tqdm |
|
|
|
from llama2_wrapper import LLAMA2_WRAPPER |
|
|
|
|
|
def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
    """Run one timed generation and measure speed plus peak memory.

    Args:
        llama2_wrapper: Initialized LLAMA2_WRAPPER instance used for generation.
        prompt_example: Prompt string fed to the model.
        DEFAULT_SYSTEM_PROMPT: System prompt forwarded to the wrapper.
        DEFAULT_MAX_NEW_TOKENS: Generation length cap forwarded to the wrapper.

    Returns:
        Tuple of (generation_time_seconds, tokens_per_second,
        peak_memory_mib, model_response).
    """

    def generation():
        # llama2_wrapper.run yields responses incrementally; we drain the
        # generator and keep the last yield as the final response.
        generator = llama2_wrapper.run(
            prompt_example,
            [],
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            1,  # NOTE(review): positional sampling params (presumably
            0.95,  # temperature, top_p, top_k) — confirm against
            50,  # LLAMA2_WRAPPER.run's signature.
        )
        model_response = None
        try:
            # Keep the first yield: the original discarded it into an unused
            # variable, so a single-yield generator returned None and
            # get_token_length(None) would fail.
            model_response = next(generator)
        except StopIteration:
            pass
        for model_response in generator:
            pass
        return llama2_wrapper.get_token_length(model_response), model_response

    tic = time.perf_counter()
    # memory_usage executes generation() and reports its peak RSS in MiB.
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    toc = time.perf_counter()

    generation_time = toc - tic
    tokens_per_second = output_token_length / generation_time

    return generation_time, tokens_per_second, mem_usage, model_response
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--iter", type=int, default=5, help="Number of iterations") |
|
parser.add_argument("--model_path", type=str, default="", help="model path") |
|
parser.add_argument( |
|
"--backend_type", |
|
type=str, |
|
default="", |
|
help="Backend options: llama.cpp, gptq, transformers", |
|
) |
|
parser.add_argument( |
|
"--load_in_8bit", |
|
type=bool, |
|
default=False, |
|
help="Whether to use bitsandbytes 8 bit.", |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
load_dotenv() |
|
|
|
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "") |
|
MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048)) |
|
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024)) |
|
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000)) |
|
|
|
MODEL_PATH = os.getenv("MODEL_PATH") |
|
assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}" |
|
BACKEND_TYPE = os.getenv("BACKEND_TYPE") |
|
assert BACKEND_TYPE is not None, f"BACKEND_TYPE is required, got: {BACKEND_TYPE}" |
|
|
|
LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True"))) |
|
|
|
if args.model_path != "": |
|
MODEL_PATH = args.model_path |
|
if args.backend_type != "": |
|
BACKEND_TYPE = args.backend_type |
|
if args.load_in_8bit: |
|
LOAD_IN_8BIT = True |
|
|
|
|
|
init_tic = time.perf_counter() |
|
llama2_wrapper = LLAMA2_WRAPPER( |
|
model_path=MODEL_PATH, |
|
backend_type=BACKEND_TYPE, |
|
max_tokens=MAX_INPUT_TOKEN_LENGTH, |
|
load_in_8bit=LOAD_IN_8BIT, |
|
|
|
) |
|
|
|
init_toc = time.perf_counter() |
|
initialization_time = init_toc - init_tic |
|
|
|
total_time = 0 |
|
total_tokens_per_second = 0 |
|
total_memory_gen = 0 |
|
|
|
prompt_example = ( |
|
"Can you explain briefly to me what is the Python programming language?" |
|
) |
|
|
|
|
|
print("Performing cold run...") |
|
run_iteration( |
|
llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS |
|
) |
|
|
|
|
|
print(f"Performing {args.iter} timed runs...") |
|
for i in tqdm(range(args.iter)): |
|
try: |
|
gen_time, tokens_per_sec, mem_gen, model_response = run_iteration( |
|
llama2_wrapper, |
|
prompt_example, |
|
DEFAULT_SYSTEM_PROMPT, |
|
DEFAULT_MAX_NEW_TOKENS, |
|
) |
|
total_time += gen_time |
|
total_tokens_per_second += tokens_per_sec |
|
total_memory_gen += mem_gen |
|
except: |
|
break |
|
avg_time = total_time / (i + 1) |
|
avg_tokens_per_second = total_tokens_per_second / (i + 1) |
|
avg_memory_gen = total_memory_gen / (i + 1) |
|
|
|
print(f"Last model response: {model_response}") |
|
print(f"Initialization time: {initialization_time:0.4f} seconds.") |
|
print( |
|
f"Average generation time over {(i + 1)} iterations: {avg_time:0.4f} seconds." |
|
) |
|
print( |
|
f"Average speed over {(i + 1)} iterations: {avg_tokens_per_second:0.4f} tokens/sec." |
|
) |
|
print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB") |
|
|
|
|
|
# Entry point: run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":

    main()
|
|