"""
Command-line examples
You need a text file whose contents are tokenized to build the prompt. By default the script reads "sample.txt"; pass --text_file "path/to/text.txt" to use a different file.
The attached "sample.txt", containing one of Deci's blog posts, can be used as a prompt.
# Run this and record tokens per second (652 tokens per second on A10 for DeciLM-6b)
python hf_benchmark_example.py --model Deci/DeciLM-6b-instruct
# Run this and record tokens per second (136 tokens per second on A10 for meta-llama/Llama-2-7b-hf), CUDA OOM above batch size 8
python hf_benchmark_example.py --model meta-llama/Llama-2-7b-hf --batch_size 8
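# Optional variation (assumed invocation; the flags are defined in parse_args below): 4-bit loading with a custom prompt file
python hf_benchmark_example.py --model Deci/DeciLM-6b-instruct --load_in_4bit --batch_size 32 --text_file path/to/text.txt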
"""

import json

import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser


def parse_args():
    parser = ArgumentParser()

    parser.add_argument(
        "--model",
        required=True,
        help="Model to evaluate, provide a repo name in Hugging Face hub or a local path",
    )
    parser.add_argument(
        "--temperature",
        default=0.2,
        type=float
    )
    parser.add_argument(
        "--top_p",
        default=0.95,
        type=float
    )
    parser.add_argument(
        "--top_k",
        default=0,
        type=int
    )

    parser.add_argument(
        "--revision",
        default=None,
        help="Model revision to use",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=6,
        help="Model revision to use",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for evaluation on each worker, can be larger for HumanEval",
    
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=512,
        help="Number of tokens taken from the text file to form the prompt",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum length of generated sequence (prompt+generation)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="Model precision, from: fp32, fp16 or bf16",
    )
    parser.add_argument(
        "--text_file",
        type=str,
        default="sample.txt",
        help="Text file used to generate tokens for the prompt",
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Load model in 8bit",
    )
    parser.add_argument(
        "--load_in_4bit",
        action="store_true",
        help="Load model in 4bit",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()


    # Map the --precision flag to the corresponding torch dtype
    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16"
        )
    if args.load_in_8bit:
        print("Loading model in 8bit")
        # the model needs to fit in one GPU
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            revision=args.revision,
            load_in_8bit=args.load_in_8bit,
            trust_remote_code=True,
            use_auth_token=True,
            device_map={"": 'cuda'},
        )
    elif args.load_in_4bit:
        print("Loading model in 4bit")
        # the model needs to fit in one GPU
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            revision=args.revision,
            load_in_4bit=args.load_in_4bit,
            trust_remote_code=True,
            use_auth_token=True,
            device_map={"": 'cuda'},
        )
    else:
        print(f"Loading model in {args.precision}")
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            revision=args.revision,
            torch_dtype=dict_precisions[args.precision],
            trust_remote_code=True,
            use_auth_token=True,
        )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        revision=args.revision,
        trust_remote_code=True,
        use_auth_token=True,
    )

    # CUDA events give GPU-side timings; elapsed_time() returns milliseconds
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    if not (args.load_in_8bit or args.load_in_4bit):
        # quantized models are already placed on the GPU via device_map
        model.cuda()
    model.eval()
    
    with open(args.text_file, "r") as f:
        prompt = f.read()

    # Tokenize the prompt and truncate it to exactly prompt_length tokens
    prompt = torch.tensor(tokenizer.encode(prompt))[:args.prompt_length].cuda()

    results = {
        'prefill': [],
        'gen': [],
        'max_new_tokens': args.max_new_tokens,
        'prompt_length': args.prompt_length,
        'model': args.model,
        'batch_size': args.batch_size,
    }
    inputs = prompt.repeat(args.batch_size, 1)
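    # All rows of `inputs` are the same truncated prompt, so every sequence has the same
    # length and throughput is simply batch_size * max_new_tokens / elapsed_seconds.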

    # Warmup: run a few 1-token generations so CUDA kernels and caches are initialized
    # before any timing is recorded
    print('start warmup')
    for _ in range(10):
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs,
                max_new_tokens=1,
                do_sample=False,
            )
    print('finish warmup')
    torch.cuda.synchronize()
            
    # Prefill benchmark: generating a single new token is dominated by the forward
    # pass over the full prompt, so this approximates prefill latency
    for prefill_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs,
                max_new_tokens=1,
                do_sample=False,
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000  # milliseconds -> seconds
        results['prefill'].append(t)
        print(f'{args.batch_size} prefill iter {prefill_iter} took: {t}')

    
    # Generation benchmark: time a full max_new_tokens generation and report throughput
    for gen_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                input_ids=inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000  # milliseconds -> seconds
        results['gen'].append(t)

        print(f'{args.batch_size} total generation iter {gen_iter} took: {t}')
        print(f'{args.batch_size * args.max_new_tokens / t} tokens per second')
    # Save raw timings so throughput can be recomputed offline
    model_str = args.model.split('/')[-1]
    with open(f'timing_{model_str}_{args.batch_size}.json', 'w') as f:
        json.dump(results, f)
    

if __name__ == "__main__":
    main()