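"""Gradio demo for chatting with an instruction-tuned causal LM, optionally
with a saved LoRA adapter applied on top of the base model.

Example launches (script name and adapter path are illustrative):

    python app.py --model_path_or_id NousResearch/Llama-2-7b-hf
    python app.py --lora_path ./llama-7b-qlora-adapter
"""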
import argparse
from threading import Thread

import gradio as gr
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

parser = argparse.ArgumentParser()
parser.add_argument("--model_path_or_id",
                    type=str,
                    default="NousResearch/Llama-2-7b-hf",
                    required=False,
                    help="Model ID or path to saved model")
parser.add_argument("--lora_path",
                    type=str,
                    default=None,
                    required=False,
                    help="Path to the saved LoRA adapter")
args = parser.parse_args()
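
# Load in 4-bit so a 7B model fits on a single consumer GPU. Note that recent
# transformers releases prefer quantization_config=BitsAndBytesConfig(
# load_in_4bit=True) over the bare load_in_4bit flag used below.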
if args.lora_path:
    # Load the base LLM together with the PEFT (LoRA) adapter; the base
    # model is resolved from the adapter's config.
    model = AutoPeftModelForCausalLM.from_pretrained(
        args.lora_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.lora_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path_or_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path_or_id)

with gr.Blocks() as demo:
    gr.HTML(f"""
        <h2> Instruction Chat Bot Demo </h2>
        <h3> Model ID : {args.model_path_or_id} </h3>
        <h3> Peft Adapter : {args.lora_path} </h3>
        """)
    chat_history = gr.Chatbot(label="Instruction Bot")
    msg = gr.Textbox(label="Instruction")
    with gr.Accordion(label="Generation Parameters", open=False):
        # The format string must contain an {instruction} placeholder,
        # which is filled with the user's message before tokenization.
        prompt_format = gr.Textbox(
            label="Formatting prompt",
            value="{instruction}",
            lines=8)
        with gr.Row():
            max_new_tokens = gr.Number(minimum=25, maximum=500, value=100, label="Max New Tokens")
            # Keep the temperature strictly positive: generate() raises an
            # error for temperature == 0 when do_sample=True.
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
    clear = gr.ClearButton([msg, chat_history])
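
    # Two-step submit chain: `user` appends the new message to the chat
    # history, then `bot` streams the model's reply into that last entry.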
    def user(user_message, history):
        # Clear the textbox and append the message with a pending reply.
        return "", history + [[user_message, None]]

    def bot(chat_history, prompt_format, max_new_tokens, temperature):
        # Format the instruction using the format string with key
        # {instruction}
        formatted_inst = prompt_format.format(
            instruction=chat_history[-1][0]
        )
        # Tokenize the input and move it to the GPU
        input_ids = tokenizer(
            formatted_inst,
            return_tensors="pt",
            truncation=True).input_ids.cuda()
        # Streaming tokens out of generate() requires running generation in
        # a separate thread. skip_special_tokens keeps markers such as </s>
        # out of the rendered chat.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=temperature,
            use_cache=True,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        chat_history[-1][1] = ""
        # Append each streamed chunk to the last reply and yield the updated
        # history so the UI refreshes incrementally.
        for new_text in streamer:
            chat_history[-1][1] += new_text
            yield chat_history

    msg.submit(user, [msg, chat_history], [msg, chat_history], queue=False).then(
        bot, [chat_history, prompt_format, max_new_tokens, temperature], chat_history
    )
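
# queue() must be enabled so the generator callback above can stream
# partial results to the browser.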
demo.queue()
demo.launch()