tianyang commited on
Commit
51e2020
·
1 Parent(s): 058af24
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+
5
+ import torch
6
+ import gradio as gr
7
+ import logging
8
+
9
+ from utils.inference import load_tokenizer_and_model, decode, \
10
+ get_prompt_with_history, is_stop_word_or_prefix
11
+
12
+ from utils.gradio import reset_textbox, cancel_outputing, transfer_input, \
13
+ delete_last_conversation, reset_state, convert_to_markdown
14
+
15
+
16
+
17
+ # set variables
18
+ model = "lemur-7B"
19
+
20
+
21
+ print("Loading model...")
22
+
23
+ import time
24
+
25
+ start = time.time()
26
+
27
+ tokenizer, model, device = load_tokenizer_and_model(model, load_8bit=True)
28
+
29
+ print("Model loaded in {} seconds.".format(time.time() - start))
30
+
31
+
32
+ def predict(
33
+ text,
34
+ chatbot,
35
+ history,
36
+ top_p,
37
+ temperature,
38
+ max_length_tokens,
39
+ max_context_length_tokens,
40
+ ):
41
+ if text == "":
42
+ yield chatbot, history, "Empty context."
43
+ return
44
+
45
+ inputs = get_prompt_with_history(
46
+ text, history, tokenizer, max_length=max_context_length_tokens
47
+ )
48
+ if inputs is None:
49
+ yield chatbot, history, "Input too long."
50
+ return
51
+ else:
52
+ prompt, inputs = inputs
53
+
54
+ input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
55
+ torch.cuda.empty_cache()
56
+
57
+ with torch.no_grad():
58
+ for x in decode(
59
+ input_ids,
60
+ model,
61
+ tokenizer,
62
+ stop_words=["[Human]", "[AI]"],
63
+ max_length=max_length_tokens,
64
+ temperature=temperature,
65
+ top_p=top_p,
66
+ ):
67
+ if is_stop_word_or_prefix(x, ["[Human]", "[AI]"]) is False:
68
+ if "[Human]" in x:
69
+ x = x[: x.index("[Human]")].strip()
70
+ if "[AI]" in x:
71
+ x = x[: x.index("[AI]")].strip()
72
+ x = x.strip(" ")
73
+ a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
74
+ [text, convert_to_markdown(x)]
75
+ ], history + [[text, x]]
76
+ yield a, b, "Generating..."
77
+
78
+ torch.cuda.empty_cache()
79
+ print(prompt)
80
+ print(x)
81
+ print("=" * 80)
82
+ try:
83
+ yield a, b, "Generate: Success"
84
+ except:
85
+ pass
86
+
87
+ def retry(
88
+ text,
89
+ chatbot,
90
+ history,
91
+ top_p,
92
+ temperature,
93
+ max_length_tokens,
94
+ max_context_length_tokens,
95
+ ):
96
+ logging.info("Retry...")
97
+ if len(history) == 0:
98
+ yield chatbot, history, "Empty context."
99
+ return
100
+ chatbot.pop()
101
+ inputs = history.pop()[0]
102
+ for x in predict(
103
+ inputs,
104
+ chatbot,
105
+ history,
106
+ top_p,
107
+ temperature,
108
+ max_length_tokens,
109
+ max_context_length_tokens,
110
+ ):
111
+ yield x
112
+
113
+
114
+ with gr.Blocks(
115
+ theme=gr.themes.Soft(),
116
+ css=".disclaimer {font-variant-caps: all-small-caps;}"
117
+ ) as demo:
118
+ history = gr.State([])
119
+ user_question = gr.State("")
120
+ with gr.Row():
121
+ gr.HTML("<h1>Lemur 🦥</h1>")
122
+ status_display = gr.Markdown("Success", elem_id="status_display")
123
+
124
+ with gr.Row(scale=1).style(equal_height=True):
125
+ with gr.Column(scale=5):
126
+ with gr.Row(scale=1):
127
+ chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height=800)
128
+ with gr.Row(scale=1):
129
+ with gr.Column(scale=12):
130
+ user_input = gr.Textbox(
131
+ show_label=False, placeholder="Enter text"
132
+ ).style(container=False)
133
+ with gr.Column(min_width=70, scale=1):
134
+ submitBtn = gr.Button("📤 Send")
135
+ with gr.Column(min_width=70, scale=1):
136
+ cancelBtn = gr.Button("⏸️ Stop")
137
+
138
+ with gr.Row(scale=1):
139
+ emptyBtn = gr.Button(
140
+ "🧹 New Conversation",
141
+ )
142
+ retryBtn = gr.Button("🔄 Regenerate")
143
+ delLastBtn = gr.Button("🗑️ Remove Last Turn")
144
+ with gr.Column():
145
+ with gr.Column(min_width=50, scale=1):
146
+ with gr.Tab(label="Parameter Setting"):
147
+ gr.Markdown("# Parameters")
148
+ top_p = gr.Slider(
149
+ minimum=-0,
150
+ maximum=1.0,
151
+ value=0.95,
152
+ step=0.05,
153
+ interactive=True,
154
+ label="Top-p",
155
+ )
156
+ temperature = gr.Slider(
157
+ minimum=0.1,
158
+ maximum=2.0,
159
+ value=1,
160
+ step=0.1,
161
+ interactive=True,
162
+ label="Temperature",
163
+ )
164
+ max_length_tokens = gr.Slider(
165
+ minimum=0,
166
+ maximum=512,
167
+ value=512,
168
+ step=8,
169
+ interactive=True,
170
+ label="Max Generation Tokens",
171
+ )
172
+ max_context_length_tokens = gr.Slider(
173
+ minimum=0,
174
+ maximum=4096,
175
+ value=2048,
176
+ step=128,
177
+ interactive=True,
178
+ label="Max History Tokens",
179
+ )
180
+
181
+ predict_args = dict(
182
+ fn=predict,
183
+ inputs=[
184
+ user_question,
185
+ chatbot,
186
+ history,
187
+ top_p,
188
+ temperature,
189
+ max_length_tokens,
190
+ max_context_length_tokens,
191
+ ],
192
+ outputs=[chatbot, history, status_display],
193
+ show_progress=True,
194
+ )
195
+ retry_args = dict(
196
+ fn=retry,
197
+ inputs=[
198
+ user_input,
199
+ chatbot,
200
+ history,
201
+ top_p,
202
+ temperature,
203
+ max_length_tokens,
204
+ max_context_length_tokens,
205
+ ],
206
+ outputs=[chatbot, history, status_display],
207
+ show_progress=True,
208
+ )
209
+
210
+ reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])
211
+
212
+ # Chatbot
213
+
214
+ transfer_input_args = dict(
215
+ fn=transfer_input,
216
+ inputs=[user_input],
217
+ outputs=[user_question, user_input, submitBtn, cancelBtn],
218
+ show_progress=True,
219
+ )
220
+
221
+ submit_event = user_input.submit(**transfer_input_args).then(**predict_args)
222
+
223
+ submit_click_event = submitBtn.click(**transfer_input_args).then(**predict_args)
224
+
225
+ emptyBtn.click(
226
+ reset_state,
227
+ outputs=[chatbot, history, status_display],
228
+ show_progress=True,
229
+ )
230
+ emptyBtn.click(**reset_args)
231
+
232
+ retry_click_event = retryBtn.click(**retry_args)
233
+
234
+ cancelBtn.click(
235
+ fn=cancel_outputing,
236
+ inputs=[],
237
+ outputs=[status_display],
238
+ cancels=[submit_event, submit_click_event]
239
+ )
240
+
241
+ delLastBtn.click(
242
+ delete_last_conversation,
243
+ [chatbot, history],
244
+ [chatbot, history, status_display],
245
+ show_progress=True,
246
+ )
247
+
248
+ demo.title = "Lemur"
249
+ demo.queue(max_size=128, concurrency_count=2)
250
+ demo.launch()
lemur-7B/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "llama-7B",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "bos_token_id": 1,
7
+ "eos_token_id": 2,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 4096,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 11008,
12
+ "max_position_embeddings": 2048,
13
+ "model_type": "llama",
14
+ "num_attention_heads": 32,
15
+ "num_hidden_layers": 32,
16
+ "pad_token_id": 0,
17
+ "rms_norm_eps": 1e-06,
18
+ "tie_word_embeddings": false,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.30.1",
21
+ "use_cache": true,
22
+ "vocab_size": 32000
23
+ }
lemur-7B/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.30.1"
7
+ }
lemur-7B/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8800f80fe257ad94942049beaa2dc86703571a8696bcaf0f03f57c021a2ec6ec
3
+ size 524332500
lemur-7B/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
lemur-7B/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
lemur-7B/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
lemur-7B/tokenizer_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "__type": "AddedToken",
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "clean_up_tokenization_spaces": false,
11
+ "eos_token": {
12
+ "__type": "AddedToken",
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "model_max_length": 1000000000000000019884624838656,
20
+ "pad_token": null,
21
+ "sp_model_kwargs": {},
22
+ "tokenizer_class": "LlamaTokenizer",
23
+ "unk_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<unk>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
utils/gradio.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils.inference import shared_state
3
+ import re
4
+
5
+ def convert_to_markdown(text):
6
+ text = text.replace("$", "&#36;")
7
+
8
+ def replace_leading_tabs_and_spaces(line):
9
+ new_line = []
10
+
11
+ for char in line:
12
+ if char == "\t":
13
+ new_line.append("&#9;")
14
+ elif char == " ":
15
+ new_line.append("&nbsp;")
16
+ else:
17
+ break
18
+ return "".join(new_line) + line[len(new_line) :]
19
+
20
+ markdown_text = ""
21
+ lines = text.split("\n")
22
+ in_code_block = False
23
+
24
+ for line in lines:
25
+ if in_code_block is False and line.startswith("```"):
26
+ in_code_block = True
27
+ markdown_text += "```\n"
28
+ elif in_code_block is True and line.startswith("```"):
29
+ in_code_block = False
30
+ markdown_text += "```\n"
31
+ elif in_code_block:
32
+ markdown_text += f"{line}\n"
33
+ else:
34
+ line = replace_leading_tabs_and_spaces(line)
35
+ line = re.sub(r"^(#)", r"\\\1", line)
36
+ markdown_text += f"{line} \n"
37
+
38
+ return markdown_text
39
+
40
+ def reset_textbox():
41
+ return gr.update(value=""), ""
42
+
43
+ def cancel_outputing():
44
+ shared_state.interrupt()
45
+ textbox = reset_textbox()
46
+ return "Stop Done"
47
+
48
+ def reset_state():
49
+ return [], [], "Reset Done"
50
+
51
+ def transfer_input(inputs):
52
+ textbox = reset_textbox()
53
+ return (
54
+ inputs,
55
+ gr.update(value=""),
56
+ gr.Button.update(visible=True),
57
+ gr.Button.update(visible=True)
58
+ )
59
+
60
+ def delete_last_conversation(chatbot, history):
61
+ if len(chatbot) > 0:
62
+ chatbot.pop()
63
+
64
+ if len(history) > 0:
65
+ history.pop()
66
+
67
+ return (
68
+ chatbot,
69
+ history,
70
+ "Delete Done",
71
+ )
utils/inference.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from peft import PeftModel
4
+ from typing import Iterator
5
+ from variables import SYSTEM, HUMAN, AI
6
+
7
+
8
+ def load_tokenizer_and_model(base_model, load_8bit=True):
9
+
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ else:
13
+ device = "cpu"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
16
+ model = AutoModelForCausalLM.from_pretrained(base_model, load_8bit=load_8bit)
17
+
18
+ return tokenizer, model, device
19
+
20
+ class State:
21
+ interrupted = False
22
+
23
+ def interrupt(self):
24
+ self.interrupted = True
25
+
26
+ def recover(self):
27
+ self.interrupted = False
28
+
29
+ shared_state = State()
30
+
31
+ def decode(
32
+ input_ids: torch.Tensor,
33
+ model: PeftModel,
34
+ tokenizer: AutoTokenizer,
35
+ stop_words: list,
36
+ max_length: int,
37
+ temperature: float = 1.0,
38
+ top_p: float = 1.0,
39
+ ) -> Iterator[str]:
40
+ generated_tokens = []
41
+ past_key_values = None
42
+
43
+ for _ in range(max_length):
44
+ with torch.no_grad():
45
+ if past_key_values is None:
46
+ outputs = model(input_ids)
47
+ else:
48
+ outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
49
+ logits = outputs.logits[:, -1, :]
50
+ past_key_values = outputs.past_key_values
51
+
52
+ # apply temperature
53
+ logits /= temperature
54
+
55
+ probs = torch.softmax(logits, dim=-1)
56
+ # apply top_p
57
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
58
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
59
+ mask = probs_sum - probs_sort > top_p
60
+ probs_sort[mask] = 0.0
61
+
62
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
63
+ next_token = torch.multinomial(probs_sort, num_samples=1)
64
+ next_token = torch.gather(probs_idx, -1, next_token)
65
+
66
+ input_ids = torch.cat((input_ids, next_token), dim=-1)
67
+
68
+ generated_tokens.append(next_token[0].item())
69
+ text = tokenizer.decode(generated_tokens)
70
+
71
+ yield text
72
+ if any([x in text for x in stop_words]):
73
+ return
74
+
75
+
76
+ def get_prompt_with_history(text, history, tokenizer, max_length=2048):
77
+ prompt = SYSTEM
78
+ history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
79
+ history.append(f"\n{HUMAN} {text}\n{AI}")
80
+ history_text = ""
81
+ flag = False
82
+ for x in history[::-1]:
83
+ if (
84
+ tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(
85
+ -1
86
+ )
87
+ <= max_length
88
+ ):
89
+ history_text = x + history_text
90
+ flag = True
91
+ else:
92
+ break
93
+ if flag:
94
+ return prompt + history_text, tokenizer(
95
+ prompt + history_text, return_tensors="pt"
96
+ )
97
+ else:
98
+ return None
99
+
100
+ def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
101
+ for stop_word in stop_words:
102
+ if s.endswith(stop_word):
103
+ return True
104
+ for i in range(1, len(stop_word)):
105
+ if s.endswith(stop_word[:i]):
106
+ return True
107
+ return False
variables.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
2
+ HUMAN = "[Human]:"
3
+ AI = "[AI]:"
4
+ NAME = "Lemur"
5
+ ORGANIZATION = "UC San Diego (UCSD)"