chansung committed on
Commit
3fdb865
·
1 Parent(s): ee5de69

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +471 -0
app.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ import gradio as gr
5
+
6
+ from pingpong import PingPong
7
+ from pingpong.pingpong import PPManager
8
+ from pingpong.pingpong import PromptFmt
9
+ from pingpong.pingpong import UIFmt
10
+ from pingpong.gradio import GradioChatUIFmt
11
+
12
class LLaMA2ChatPromptFmt(PromptFmt):
    """Formats system context and exchanges using the LLaMA-2 chat template."""

    @classmethod
    def ctx(cls, context):
        """Wrap a non-empty system context in <<SYS>> ... <</SYS>> markers."""
        if context is None or context == "":
            return ""
        return f"""<<SYS>>
{context}
<</SYS>>

"""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Render one exchange as ``[INST] ping [/INST] pong``, truncated."""
        ping = pingpong.ping[:truncate_size]
        raw_pong = pingpong.pong
        pong = raw_pong[:truncate_size] if raw_pong is not None else ""
        return f"""[INST] {ping} [/INST] {pong}"""
29
+
30
class LLaMA2ChatPPManager(PPManager):
    """Ping-pong manager that renders history with the LLaMA-2 chat template."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
        """Build the full prompt string for ``pingpongs[from_idx:to_idx]``.

        Args:
            from_idx: first exchange to include.
            to_idx: one-past-last exchange; -1 (or out of range) means "to the end".
            fmt: prompt formatter class providing ``ctx`` and ``prompt``.
            truncate_size: per-message character cap; None means no truncation.

        Returns:
            System-context header followed by every rendered exchange.
        """
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = fmt.ctx(self.ctx)

        # The original used enumerate() but never read the index; iterate directly.
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results
41
+
42
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """LLaMA-2 manager that can also render exchanges as Gradio chatbot rows."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return the UI representation of ``pingpongs[from_idx:to_idx]``."""
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)
        # One UI entry per exchange, in order.
        return [fmt.ui(pingpong) for pingpong in self.pingpongs[from_idx:to_idx]]
53
+
54
# Hugging Face Inference API credentials and target model.
# HF_TOKEN must be set in the environment; TOKEN is None otherwise and
# gen_text() will raise ValueError.
TOKEN = os.getenv('HF_TOKEN')
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
56
+
57
# Custom CSS injected into gr.Blocks(css=STYLES) below; the class names are
# attached to components via their `elem_classes` arguments.
# NOTE(review): `.white-text` actually sets black (#000) despite its name —
# confirm the intended color before renaming or changing it.
STYLES = """
.small-big {
    font-size: 12pt !important;
}

.small-big-textarea > label > textarea {
    font-size: 12pt !important;
}

.highlighted-text {
    background: yellow;
    overflow-wrap: break-word;
}

.no-gap {
    gap: 0px !important;
}

.group-border {
    padding: 10px;
    border-width: 1px;
    border-radius: 10px;
    border-color: gray;
    border-style: dashed;
}

.control-label-font {
    font-size: 13pt !important;
}

.control-button {
    background: none !important;
    border-color: #69ade2 !important;
    border-width: 2px !important;
    color: #69ade2 !important;
}

.center {
    text-align: center;
}

.right {
    text-align: right;
}

.no-label {
    padding: 0px !important;
}

.no-label > label > span {
    display: none;
}

.no-label-chatbot {
    border: none !important;
    box-shadow: none !important;
    height: 520px !important;
}

.no-label-chatbot > div > div:nth-child(1) {
    display: none;
}

.left-margin-30 {
    padding-left: 30px !important;
}

.left {
    text-align: left !important;
}

.alt-button {
    color: gray !important;
    border-width: 1px !important;
    background: none !important;
    border-color: gray !important;
    text-align: justify !important;
}

.wrap {
    display: contents !important;
}

.white-text {
    color: #000 !important;
}
"""
144
+
145
def get_new_ppm(ping):
    """Create a story-writing prompt manager seeded with the given ping."""
    system_ctx = """\
You are a helpful, respectful and honest writing helper. Always write stories that suites to query.

You DO NOT give explanation but just stories. For instance, do not say such as "Sure! Here's a short paragraph to start a short story:"""

    manager = LLaMA2ChatPPManager()
    manager.ctx = system_ctx
    manager.add_pingpong(PingPong(ping, ''))
    return manager
154
+
155
def get_new_ppm_for_chat():
    """Create an empty Gradio-aware ping-pong manager for the chat tab."""
    return GradioLLaMA2ChatPPManager()
158
+
159
def gen_text(prompt, hf_model='meta-llama/Llama-2-70b-chat-hf', hf_token=None, parameters=None):
    """Call the Hugging Face Inference API and return the generated text.

    Args:
        prompt: fully rendered LLaMA-2 chat prompt string.
        hf_model: model repo id on the Hub.
        hf_token: Hugging Face API token; required.
        parameters: generation parameters dict; defaults applied when None.

    Returns:
        The ``generated_text`` string of the first result.

    Raises:
        ValueError: when no token is given or the API does not return 200.
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': False,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data)
    )

    # Fix: the original compared r.reason != 'OK'. The reason phrase is
    # server-dependent (and empty under HTTP/2), so a successful response
    # could be rejected; the status code is the reliable signal.
    if r.status_code != 200:
        raise ValueError("Response other than 200")

    # requests decodes the JSON body for us; equivalent to the original
    # json.loads(r.content.decode("utf-8")).
    return r.json()[0]['generated_text']
198
+
199
def select(editor, evt: gr.SelectData):
    """Record the current editor selection: its text and [start, end] offsets."""
    selection = evt.value
    start_pos = evt.index[0]
    end_pos = evt.index[1]
    return [selection, start_pos, end_pos]
205
+
206
def get_gen_txt(editor, prompt):
    """Ask the model to continue the story; return editor text plus new text.

    On API failure the editor text is returned unchanged (best-effort).
    """
    if editor.strip() == '':
        # Empty editor: bootstrap a brand-new story.
        request = 'Write a short paragraph to start a short story for me'
    else:
        request = f"""{prompt}
--------------------------------
{editor}"""

    ppm = get_new_ppm(request)

    try:
        generated = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
    except ValueError as err:
        print(f"something went wrong - {err}")
        return editor
    return editor + generated + "\n\n"
220
+
221
def gen_txt(editor):
    """'generate text' handler.

    Returns [new editor text, reset alt counter, re-enabled alt button,
    three hidden alternative rows].
    """
    continued = get_gen_txt(editor, "Write the next paragraph based on the following stories so far.")
    hidden_rows = [gr.update(visible=False) for _ in range(3)]
    return [continued, 0, gr.update(interactive=True), *hidden_rows]
230
+
231
def gen_txt_with_prompt(editor, prompt):
    """Same as gen_txt but uses a caller-supplied prompt instead of the default."""
    continued = get_gen_txt(editor, prompt)
    hidden_rows = [gr.update(visible=False) for _ in range(3)]
    return [continued, 0, gr.update(interactive=True), *hidden_rows]
240
+
241
def chat_gen(editor, chat_txt, chatbot, ppm, regen=False):
    # Answer (or re-answer, when regen=True) the latest chat message, grounding
    # the assistant on the story text currently in the editor.
    #
    # Args:
    #   editor:   full story text; injected into the system context below.
    #   chat_txt: newly submitted user message (ignored when regen=True).
    #   chatbot:  unused here; present to mirror the Gradio handler signature.
    #   ppm:      conversation-history manager (GradioLLaMA2ChatPPManager).
    #   regen:    when True, pop the last exchange and answer its ping again.
    #
    # Returns [cleared textbox value, chatbot UI rows, updated ppm].
    ppm.ctx = f"""\
You are a helpful, respectful and honest assistant.

you must consider multi-turn conversations.

Answer to questions based on the written stories so far as below
----------------
{editor}
"""
    if regen:
        # Reuse the previous ping; its old pong is discarded by the pop.
        last_pingpong = ppm.pop_pingpong()
        chat_txt = last_pingpong.ping
    ppm.add_pingpong(PingPong(chat_txt, ''))

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        ppm.add_pong(txt)
    except ValueError as e:
        # Best-effort: on API failure the exchange remains with an empty pong.
        print(f"something went wrong - {e}")

    return [
        "",
        ppm.build_uis(),
        ppm
    ]
267
+
268
def chat(editor, chat_txt, chatbot, ppm):
    """Handle a newly submitted chat message by delegating to chat_gen."""
    return chat_gen(editor, chat_txt, chatbot, ppm)
270
+
271
def regen_chat(editor, chat_txt, chatbot, ppm):
    """Re-answer the most recent chat message by delegating to chat_gen."""
    return chat_gen(editor, chat_txt, chatbot, ppm, regen=True)
273
+
274
+
275
def get_new_ppm_for_range():
    """Create a fresh manager whose context suits replace/rewrite-selection."""
    manager = LLaMA2ChatPPManager()
    manager.ctx = """\
You are a helpful, respectful and honest writing helper. Always write text that suites to query.

You DO NOT give explanation but just stories. DO NOT say such as 'Sure! Here's a short paragraph to start a short story:' or 'Sure, here is a revised version of ....:'
"""
    return manager
283
+
284
+
285
def replace_sel(editor, replace_type, selected_text, sel_index_from, sel_index_to):
    """Replace the selected editor span with model-generated text.

    Args:
        editor: full story text.
        replace_type: 'word' | 'sentense' | 'phrase' | 'paragraph' (dropdown value).
        selected_text: text currently selected in the editor.
        sel_index_from / sel_index_to: character offsets of the selection.

    Returns [updated editor text, cleared selection display, 0, 0].
    On API failure the editor is returned unchanged.
    """
    ppm = get_new_ppm_for_range()

    ping = f"""replace {selected_text} in a single {replace_type} based on the story below
----------------
{editor}
"""

    ppm.add_pingpong(PingPong(ping, ''))

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        ppm.add_pong(txt)
    except ValueError as e:
        # Fix: the original fell through to the return below with `txt`
        # unbound, raising NameError; return the editor untouched instead.
        print(f"something went wrong - {e}")
        return [editor, "", 0, 0]

    return [
        f"{editor[:sel_index_from]} {txt} {editor[sel_index_to:]}",
        "",
        0,
        0
    ]
307
+
308
def gen_alt(editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3):
    """'generate alternatives' handler: fill the next free alternative slot.

    Fixes: the generated text was bound to a local named ``gen_txt``,
    shadowing the module-level ``gen_txt`` handler and left unbound when all
    three slots were already filled. It is now ``alt_txt`` with a None
    default so every branch below is safe.

    Returns updates for [num_enabled_alts, gen_alt_btn,
    first_alt, alt_btn1, second_alt, alt_btn2, third_alt, alt_btn3].
    """
    alt_txt = None
    if num_enabled_alts < 3:
        alt_txt = get_gen_txt(editor, "Write the next paragraph based on the following stories so far.")

    return [
        # Visible-alternative count, capped at three.
        min(num_enabled_alts + 1, 3),
        # Disable the button once the third alternative is being shown.
        gr.update(interactive=num_enabled_alts < 2),
        gr.update(visible=num_enabled_alts >= 0),
        gr.update(value=alt_txt if num_enabled_alts == 0 else alt_btn1),
        gr.update(visible=num_enabled_alts >= 1),
        gr.update(value=alt_txt if num_enabled_alts == 1 else alt_btn2),
        gr.update(visible=num_enabled_alts >= 2),
        gr.update(value=alt_txt if num_enabled_alts == 2 else alt_btn3),
    ]
322
+
323
def fill_with_gen(alt_txt, editor):
    """Append the chosen alternative's text to the editor and hide all rows."""
    hidden_rows = [gr.update(visible=False) for _ in range(3)]
    return [editor + alt_txt, 0, gr.update(interactive=True), *hidden_rows]
332
+
333
# ---------------------------------------------------------------------------
# Gradio UI wiring. Left column: the story editor. Right column: two tabs —
# "Control" (generate / alternatives / replace selection) and "Chatting".
# ---------------------------------------------------------------------------
with gr.Blocks(css=STYLES) as demo:

    # Number of alternative-continuation rows currently visible (0..3).
    num_enabled_alts = gr.State(0)

    # Character offsets of the current selection inside the editor.
    sel_index_from = gr.State(0)
    sel_index_to = gr.State(0)

    # Per-session conversation history for the "Chatting" tab.
    chat_history = gr.State(get_new_ppm_for_chat())

    gr.Markdown("# Co-writing with AI", elem_classes=['center'])
    gr.Markdown(
        "This application is designed for you to collaborate with LLM to co-write stories. It is inspired by [Wordcraft project](https://wordcraft-writers-workshop.appspot.com/) from Google's PAIR and Magenta teams. "
        "This application built on [Gradio](https://www.gradio.app), and the underlying text generation is powered by [Hugging Face Inference API](https://huggingface.co/inference-api). The text generation model might"
        "be changed over time, but [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) is selected for now.",
        elem_classes=['center', 'small-big'])

    with gr.Row():
        with gr.Column(scale=2):
            editor = gr.Textbox(lines=32, max_lines=32, elem_classes=['no-label', 'small-big-textarea'])
            word_counter = gr.Markdown("0 words", elem_classes=['right'])

        with gr.Column(scale=1):
            with gr.Tab("Control"):
                with gr.Column(elem_classes=['group-border']):
                    with gr.Row():
                        gen_btn = gr.Button("generate text", elem_classes=['control-label-font', 'control-button'])
                        gen_alt_btn = gr.Button("generate alternatives", elem_classes=['control-label-font', 'control-button'])

                    # Hidden rows revealed one-by-one as alternatives arrive.
                    with gr.Column():
                        with gr.Row(visible=False) as first_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn1 = gr.Button("Alternative 1", elem_classes=['alt-button'], scale=5)

                        with gr.Row(visible=False) as second_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn2 = gr.Button("Alternative 2", elem_classes=['alt-button'], scale=5)

                        with gr.Row(visible=False) as third_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn3 = gr.Button("Alternative 3", elem_classes=['alt-button'], scale=5)

                with gr.Row(elem_classes=['group-border']):
                    with gr.Column():
                        gr.Markdown("'Write the next paragraph based on the following stories so far.' is the default prompt when clicking `generate text`, and the text so far will always be attached to the end. By giving your own prompt, only the default prompt will be replaced.")

                    with gr.Column(elem_classes=['no-gap']):
                        gen_with_prompt_btn = gr.Button("generate text with custom prompt", elem_classes=['control-label-font', 'control-button'])
                        prompt = gr.Textbox(placeholder="enter prompt: ", elem_classes=['no-label'])

                with gr.Column(elem_classes=['group-border']):
                    with gr.Row():
                        selected_text = gr.Markdown("Selected text will be displayed in this area", elem_classes=['highlighted-text'])

                    with gr.Row():
                        with gr.Column(elem_classes=['no-gap']):
                            replace_sel_btn = gr.Button("replace selection", elem_classes=['control-label-font', 'control-button'])
                            # NOTE(review): 'sentense' is misspelled, but it is a
                            # runtime choice value fed into prompts — fix it only
                            # together with anything that consumes the value.
                            replace_type = gr.Dropdown(choices=['word', 'sentense', 'phrase', 'paragraph'], value='sentense', interactive=True, elem_classes=['no-label'])

                    with gr.Row():
                        with gr.Column(elem_classes=['no-gap']):
                            # NOTE(review): rewrite_sel_btn never gets a .click()
                            # handler below — this feature appears unfinished.
                            rewrite_sel_btn = gr.Button("rewrite selection", elem_classes=['control-label-font', 'control-button'])
                            rewrite_prompt = gr.Textbox(placeholder="Rewrite the text: ", elem_classes=['no-label'])

            with gr.Tab("Chatting"):
                chatbot = gr.Chatbot([], elem_classes=['no-label-chatbot'])
                chat_txt = gr.Textbox(placeholder="enter question", elem_classes=['no-label'])

                with gr.Row():
                    # NOTE(review): clear_btn has no handler wired below either.
                    clear_btn = gr.Button("clear", elem_classes=['control-label-font', 'control-button'])
                    regen_btn = gr.Button("regenerate", elem_classes=['control-label-font', 'control-button'])

    # Client-side word counter; also clears the selection display.
    # NOTE(review): `_js` is the pre-Gradio-4 keyword (later `js`), and `\s`
    # in a non-raw Python string relies on unknown escapes passing through —
    # confirm both on a Gradio/Python upgrade.
    editor.change(
        fn=None,
        inputs=[editor],
        outputs=[word_counter, selected_text],
        _js="(e) => [e.split(/\s+/).length, '']"
    )

    # Track text selections so replace-selection knows the span to swap.
    editor.select(
        fn=select,
        inputs=[editor],
        outputs=[selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    gen_btn.click(
        fn=gen_txt,
        inputs=[editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    gen_alt_btn.click(
        fn=gen_alt,
        inputs=[editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3],
        outputs=[num_enabled_alts, gen_alt_btn, first_alt, alt_btn1, second_alt, alt_btn2, third_alt, alt_btn3],
    )

    # Clicking an alternative appends its text to the editor and hides all
    # alternative rows again.
    alt_btn1.click(
        fn=fill_with_gen,
        inputs=[alt_btn1, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )
    alt_btn2.click(
        fn=fill_with_gen,
        inputs=[alt_btn2, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )
    alt_btn3.click(
        fn=fill_with_gen,
        inputs=[alt_btn3, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    gen_with_prompt_btn.click(
        gen_txt_with_prompt,
        inputs=[editor, prompt],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    replace_sel_btn.click(
        fn=replace_sel,
        inputs=[editor, replace_type, selected_text, sel_index_from, sel_index_to],
        outputs=[editor, selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    chat_txt.submit(
        fn=chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

    regen_btn.click(
        fn=regen_chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

demo.launch()