Tonic commited on
Commit
9ff18cc
·
1 Parent(s): bdec39a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
+
4
+ # Load the model and tokenizer
5
+ model_name = "01-ai/Yi-34B-200K"
6
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+
9
+ def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.3, top_p=0.9, top_k=50):
10
+ prompt = get_prompt(message, chat_history, system_prompt)
11
+
12
+ # Encode the prompt to tensor
13
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
14
+
15
+ # Generate a response using the model with adjusted parameters
16
+ response_ids = model.generate(
17
+ input_ids,
18
+ max_length=max_new_tokens + input_ids.shape[1],
19
+ temperature=temperature, # Controls randomness. Lower values make text more deterministic.
20
+ top_p=top_p, # Nucleus sampling: higher values allow more diversity.
21
+ top_k=top_k, # Top-k sampling: limits the number of top tokens considered.
22
+ pad_token_id=tokenizer.eos_token_id
23
+ )
24
+
25
+ # Decode the response
26
+ response = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
27
+ return response
28
+
29
+ def get_prompt(message, chat_history, system_prompt):
30
+ texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
31
+
32
+ do_strip = False
33
+ for user_input, response in chat_history:
34
+ user_input = user_input.strip() if do_strip else user_input
35
+ do_strip = True
36
+ texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
37
+ message = message.strip() if do_strip else message
38
+ texts.append(f"{message} [/INST]")
39
+ return ''.join(texts)
40
+
41
+ DEFAULT_SYSTEM_PROMPT = """
42
+ You are Yi. You are an AI assistant, you are moderately-polite and give only true information.
43
+ You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
44
+ If you think there might not be a correct answer, you say so. Since you are autoregressive,
45
+ each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context,
46
+ assumptions, and step-by-step thinking BEFORE you try to answer a question.
47
+ """
48
+ MAX_MAX_NEW_TOKENS = 200000
49
+ DEFAULT_MAX_NEW_TOKENS = 100000
50
+ MAX_INPUT_TOKEN_LENGTH = 100000
51
+
52
+ DESCRIPTION = "# [Yi-6B](https://huggingface.co/01-ai/Yi-6B)"
53
+
54
+ def clear_and_save_textbox(message): return '', message
55
+
56
+ def display_input(message, history=[]):
57
+ history.append((message, ''))
58
+ return history
59
+
60
+ def delete_prev_fn(history=[]):
61
+ try:
62
+ message, _ = history.pop()
63
+ except IndexError:
64
+ message = ''
65
+ return history, message or ''
66
+
67
+ def generate(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
68
+ if max_new_tokens > MAX_MAX_NEW_TOKENS:
69
+ raise ValueError
70
+
71
+ history = history_with_input[:-1]
72
+ generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
73
+ try:
74
+ first_response = next(generator)
75
+ yield history + [(message, first_response)]
76
+ except StopIteration:
77
+ yield history + [(message, '')]
78
+ for response in generator:
79
+ yield history + [(message, response)]
80
+
81
+ def process_example(message):
82
+ generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
83
+ for x in generator:
84
+ pass
85
+ return '', x
86
+
87
+ def check_input_token_length(message, chat_history, system_prompt):
88
+ input_token_length = len(message) + len(chat_history)
89
+ if input_token_length > MAX_INPUT_TOKEN_LENGTH:
90
+ raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
91
+
92
+ with gr.Blocks(theme='ParityError/Anime') as demo:
93
+ gr.Markdown(DESCRIPTION)
94
+
95
+
96
+
97
+ with gr.Group():
98
+ chatbot = gr.Chatbot(label='Yi-6B')
99
+ with gr.Row():
100
+ textbox = gr.Textbox(
101
+ container=False,
102
+ show_label=False,
103
+ placeholder='Hi, Yi',
104
+ scale=10
105
+ )
106
+ submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
107
+
108
+ with gr.Row():
109
+ retry_button = gr.Button('Retry', variant='secondary')
110
+ undo_button = gr.Button('Undo', variant='secondary')
111
+ clear_button = gr.Button('Clear', variant='secondary')
112
+
113
+ saved_input = gr.State()
114
+
115
+ with gr.Accordion(label='Advanced options', open=False):
116
+ system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
117
+ max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
118
+ temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
119
+ top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
120
+ top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)
121
+
122
+ textbox.submit(
123
+ fn=clear_and_save_textbox,
124
+ inputs=textbox,
125
+ outputs=[textbox, saved_input],
126
+ api_name=False,
127
+ queue=False,
128
+ ).then(
129
+ fn=display_input,
130
+ inputs=[saved_input, chatbot],
131
+ outputs=chatbot,
132
+ api_name=False,
133
+ queue=False,
134
+ ).then(
135
+ fn=check_input_token_length,
136
+ inputs=[saved_input, chatbot, system_prompt],
137
+ api_name=False,
138
+ queue=False,
139
+ ).success(
140
+ fn=generate,
141
+ inputs=[
142
+ saved_input,
143
+ chatbot,
144
+ system_prompt,
145
+ max_new_tokens,
146
+ temperature,
147
+ top_p,
148
+ top_k,
149
+ ],
150
+ outputs=chatbot,
151
+ api_name=False,
152
+ )
153
+
154
+ button_event_preprocess = submit_button.click(
155
+ fn=clear_and_save_textbox,
156
+ inputs=textbox,
157
+ outputs=[textbox, saved_input],
158
+ api_name=False,
159
+ queue=False,
160
+ ).then(
161
+ fn=display_input,
162
+ inputs=[saved_input, chatbot],
163
+ outputs=chatbot,
164
+ api_name=False,
165
+ queue=False,
166
+ ).then(
167
+ fn=check_input_token_length,
168
+ inputs=[saved_input, chatbot, system_prompt],
169
+ api_name=False,
170
+ queue=False,
171
+ ).success(
172
+ fn=generate,
173
+ inputs=[
174
+ saved_input,
175
+ chatbot,
176
+ system_prompt,
177
+ max_new_tokens,
178
+ temperature,
179
+ top_p,
180
+ top_k,
181
+ ],
182
+ outputs=chatbot,
183
+ api_name=False,
184
+ )
185
+
186
+ retry_button.click(
187
+ fn=delete_prev_fn,
188
+ inputs=chatbot,
189
+ outputs=[chatbot, saved_input],
190
+ api_name=False,
191
+ queue=False,
192
+ ).then(
193
+ fn=display_input,
194
+ inputs=[saved_input, chatbot],
195
+ outputs=chatbot,
196
+ api_name=False,
197
+ queue=False,
198
+ ).then(
199
+ fn=generate,
200
+ inputs=[
201
+ saved_input,
202
+ chatbot,
203
+ system_prompt,
204
+ max_new_tokens,
205
+ temperature,
206
+ top_p,
207
+ top_k,
208
+ ],
209
+ outputs=chatbot,
210
+ api_name=False,
211
+ )
212
+
213
+ undo_button.click(
214
+ fn=delete_prev_fn,
215
+ inputs=chatbot,
216
+ outputs=[chatbot, saved_input],
217
+ api_name=False,
218
+ queue=False,
219
+ ).then(
220
+ fn=lambda x: x,
221
+ inputs=[saved_input],
222
+ outputs=textbox,
223
+ api_name=False,
224
+ queue=False,
225
+ )
226
+
227
+ clear_button.click(
228
+ fn=lambda: ([], ''),
229
+ outputs=[chatbot, saved_input],
230
+ queue=False,
231
+ api_name=False,
232
+ )
233
+
234
+ demo.queue(max_size=32).launch(show_api=False)