"""
Usage:
python3 -m fastchat.serve.huggingface_api --model-path ~/model_weights/vicuna-7b/
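
A second invocation, illustrative only, combining flags that the argparse block
below already defines:
python3 -m fastchat.serve.huggingface_api --model-path ~/model_weights/vicuna-7b/ --device cpu --message "Tell me about Vicuna."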
""" |
|
import argparse |
|
import json |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
from fastchat.conversation import get_default_conv_template, compute_skip_echo_len |
|
from fastchat.serve.inference import load_model |
|
|
|
|
|
@torch.inference_mode() |
|
def main(args): |
|
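    # load_model (from fastchat.serve.inference) handles device placement,
    # multi-GPU sharding, and optional 8-bit quantization for the checkpoint.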
    model, tokenizer = load_model(
        args.model_path,
        args.device,
        args.num_gpus,
        args.max_gpu_memory,
        args.load_8bit,
        debug=args.debug,
    )

    msg = args.message
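
    # Build the prompt with the model's default conversation template.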
    conv = get_default_conv_template(args.model_path).copy()
    conv.append_message(conv.roles[0], msg)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])
    output_ids = model.generate(
        torch.as_tensor(inputs.input_ids).to(args.device),
        do_sample=True,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
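    # The decoded text still contains the echoed prompt; compute how many
    # characters to skip so only the model's reply is printed.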
    skip_echo_len = compute_skip_echo_len(args.model_path, conv, prompt)
    outputs = outputs[skip_echo_len:]

    print(f"{conv.roles[0]}: {msg}")
    print(f"{conv.roles[1]}: {outputs}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="facebook/opt-350m",
        help="The path to the weights",
    )
    parser.add_argument(
        "--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda"
    )
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per GPU. Use a string like '13GiB'.",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--message", type=str, default="Hello! Who are you?")
    args = parser.parse_args()

    main(args)