gracexu committed
Commit 579a9aa · 1 Parent(s): b97bda2

add app.py

Files changed (1)
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
import argparse

import gradio as gr
from peft import AutoPeftModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread


parser = argparse.ArgumentParser()
parser.add_argument("--model_path_or_id",
                    type=str,
                    default="NousResearch/Llama-2-7b-hf",
                    required=False,
                    help="Model ID or path to saved model")

parser.add_argument("--lora_path",
                    type=str,
                    default=None,
                    required=False,
                    help="Path to the saved LoRA adapter")

args = parser.parse_args()

if args.lora_path:
    # Load the base LLM together with the PEFT (LoRA) adapter
    model = AutoPeftModelForCausalLM.from_pretrained(
        args.lora_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.lora_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path_or_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path_or_id)

with gr.Blocks() as demo:

    gr.HTML(f"""
    <h2> Instruction Chat Bot Demo </h2>
    <h3> Model ID : {args.model_path_or_id} </h3>
    <h3> Peft Adapter : {args.lora_path} </h3>
    """)

    chat_history = gr.Chatbot(label="Instruction Bot")
    msg = gr.Textbox(label="Instruction")
    with gr.Accordion(label="Generation Parameters", open=False):
        prompt_format = gr.Textbox(
            label="Formatting prompt",
            value="{instruction}",
            lines=8)
        with gr.Row():
            max_new_tokens = gr.Number(minimum=25, maximum=500, value=100, label="Max New Tokens")
            temperature = gr.Slider(minimum=0, maximum=1.0, value=0.7, label="Temperature")

    clear = gr.ClearButton([msg, chat_history])

    def user(user_message, history):
        # Append the new instruction to the running history and clear the textbox
        return "", history + [[user_message, None]]

    def bot(chat_history, prompt_format, max_new_tokens, temperature):

        # Format the instruction using the format string with key {instruction}
        formatted_inst = prompt_format.format(
            instruction=chat_history[-1][0]
        )

        # Tokenize the input and move it to the GPU
        input_ids = tokenizer(
            formatted_inst,
            return_tensors="pt",
            truncation=True).input_ids.cuda()

        # Streaming tokens out of generate() requires running generation
        # in a separate thread
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=int(max_new_tokens),  # gr.Number returns a float
            do_sample=True,
            top_p=0.9,
            temperature=temperature,
            use_cache=True,
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the partial response into the last chat turn
        chat_history[-1][1] = ""
        for new_text in streamer:
            chat_history[-1][1] += new_text
            yield chat_history

    msg.submit(user, [msg, chat_history], [msg, chat_history], queue=False).then(
        bot, [chat_history, prompt_format, max_new_tokens, temperature], chat_history
    )

demo.queue()
demo.launch()
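
A note on usage: the "Formatting prompt" textbox defaults to {instruction}, so the raw user message is passed straight to the model. For an instruction-tuned base model or LoRA adapter, a template that matches the fine-tuning data usually gives better answers. As an illustration only (the correct template depends on how the adapter was trained; the Alpaca-style wording below is an assumption, not part of this commit), a string like the following could be pasted into that textbox:

# Hypothetical Alpaca-style template (assumption; adjust to the adapter's training format).
# bot() applies it exactly as prompt_format.format(instruction=user_message).
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Response:\n"
)
print(ALPACA_TEMPLATE.format(instruction="Summarize the plot of Hamlet."))

The app itself can be started with "python app.py" to chat with the default NousResearch/Llama-2-7b-hf checkpoint, or with "python app.py --lora_path <adapter_dir>" to load a saved adapter; a CUDA-capable GPU is required because the tokenized inputs are moved to the GPU with .cuda().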