# Captured from a Hugging Face Spaces page (Space status at capture time: "Runtime error").
import os | |
import json | |
import requests | |
import gradio as gr | |
from pingpong import PingPong | |
from pingpong.pingpong import PPManager | |
from pingpong.pingpong import PromptFmt | |
from pingpong.pingpong import UIFmt | |
from pingpong.gradio import GradioChatUIFmt | |
class LLaMA2ChatPromptFmt(PromptFmt):
    """Prompt formatter that emits LLaMA-2 chat markup.

    Both methods are classmethods: ``build_prompts`` in this file uses the
    class object itself as the default ``fmt`` and calls ``fmt.ctx(...)`` /
    ``fmt.prompt(...)`` on it, which fails with a ``TypeError`` when the
    methods are plain functions taking ``cls``.
    """

    @classmethod
    def ctx(cls, context):
        """Return the ``<<SYS>>`` section for *context*.

        Returns an empty string when *context* is ``None`` or ``""``.
        """
        if context is None or context == "":
            return ""
        return f"""<<SYS>>
{context}
<</SYS>>
"""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return one ``[INST] ... [/INST]`` turn for *pingpong*.

        ``pingpong.ping`` and ``pingpong.pong`` are each sliced to at most
        *truncate_size* characters (``None`` keeps them whole); a ``None``
        pong is rendered as an empty string.
        """
        ping = pingpong.ping[:truncate_size]
        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
        return f"""[INST] {ping} [/INST] {pong}"""
class LLaMA2ChatPPManager(PPManager):
    """PPManager that renders its conversation as one LLaMA-2 prompt string."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
        """Concatenate the context section and the selected exchanges.

        ``to_idx`` of ``-1``, or any value at or beyond the number of stored
        exchanges, means "up to the end". The result starts with
        ``fmt.ctx(self.ctx)`` followed by ``fmt.prompt(...)`` for every
        exchange in ``self.pingpongs[from_idx:to_idx]``.
        """
        total = len(self.pingpongs)
        stop = total if (to_idx == -1 or to_idx >= total) else to_idx

        prompt_text = fmt.ctx(self.ctx)
        for exchange in self.pingpongs[from_idx:stop]:
            prompt_text = prompt_text + fmt.prompt(exchange, truncate_size=truncate_size)
        return prompt_text
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """LLaMA-2 prompt manager that can also render its exchanges for a UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for each selected exchange, as a list.

        ``to_idx`` of ``-1``, or any value at or beyond the number of stored
        exchanges, means "up to the end".
        """
        total = len(self.pingpongs)
        stop = total if (to_idx == -1 or to_idx >= total) else to_idx
        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:stop]]
# Hugging Face API token, read from the HF_TOKEN environment variable
# (None when the variable is not set).
TOKEN = os.environ.get('HF_TOKEN')

# Model repository id used for text generation requests.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'

# Custom CSS passed to gr.Blocks(css=STYLES); the class names here are the
# ones referenced through `elem_classes` in the UI definition.
STYLES = """
.small-big {
font-size: 12pt !important;
}
.small-big-textarea > label > textarea {
font-size: 12pt !important;
}
.highlighted-text {
background: yellow;
overflow-wrap: break-word;
}
.no-gap {
gap: 0px !important;
}
.group-border {
padding: 10px;
border-width: 1px;
border-radius: 10px;
border-color: gray;
border-style: dashed;
}
.control-label-font {
font-size: 13pt !important;
}
.control-button {
background: none !important;
border-color: #69ade2 !important;
border-width: 2px !important;
color: #69ade2 !important;
}
.center {
text-align: center;
}
.right {
text-align: right;
}
.no-label {
padding: 0px !important;
}
.no-label > label > span {
display: none;
}
.no-label-chatbot {
border: none !important;
box-shadow: none !important;
height: 520px !important;
}
.no-label-chatbot > div > div:nth-child(1) {
display: none;
}
.left-margin-30 {
padding-left: 30px !important;
}
.left {
text-align: left !important;
}
.alt-button {
color: gray !important;
border-width: 1px !important;
background: none !important;
border-color: gray !important;
text-align: justify !important;
}
.white-text {
color: #000 !important;
}
"""
def get_new_ppm(ping):
    """Create a story-writing prompt manager seeded with a single request.

    A new ``LLaMA2ChatPPManager`` receives the story-writer system context
    and one ``PingPong`` whose ping is *ping* and whose pong is ``''``.
    """
    story_writer_ctx = """\
You are a helpful, respectful and honest writing helper. Always write stories that suites to query.
You DO NOT give explanation but just stories. For instance, do not say such as "Sure! Here's a short paragraph to start a short story:"""

    manager = LLaMA2ChatPPManager()
    manager.ctx = story_writer_ctx
    manager.add_pingpong(PingPong(ping, ''))
    return manager
def get_new_ppm_for_chat():
    """Return a fresh, empty ``GradioLLaMA2ChatPPManager`` for the chat tab."""
    return GradioLLaMA2ChatPPManager()
def gen_text(prompt, hf_model='meta-llama/Llama-2-70b-chat-hf', hf_token=None, parameters=None, timeout=120):
    """Generate text for *prompt* through the Hugging Face Inference API.

    Args:
        prompt: Full prompt string sent as the ``inputs`` field.
        hf_model: Model repository id appended to the inference endpoint URL.
        hf_token: Bearer token for the ``Authorization`` header (required).
        parameters: Generation parameters; when ``None`` a sampling
            configuration (512 new tokens, temperature 1.0, top_k 50,
            repetition penalty 1.2, no prompt echo) is used.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``generated_text`` field of the first item in the JSON response.

    Raises:
        ValueError: If no token is given, the request fails, the HTTP status
            is not 200, or the response body does not have the expected
            shape. Every failure is reported as ``ValueError`` because that
            is the only exception type the callers in this file handle.
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': False,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    try:
        # A timeout keeps a stalled connection from blocking the UI callback forever.
        r = requests.post(
            url,
            headers=headers,
            data=json.dumps(data),
            timeout=timeout
        )
    except requests.RequestException as exc:
        raise ValueError(f"Request to the inference API failed: {exc}") from exc

    # Check the numeric status: the reason phrase is free text and may be
    # empty or differ from 'OK' even on a successful response.
    if r.status_code != 200:
        raise ValueError(f"Response other than 200 (got {r.status_code})")

    try:
        return json.loads(r.content.decode("utf-8"))[0]['generated_text']
    except (ValueError, LookupError, TypeError) as exc:
        # Covers invalid JSON / bad encoding (ValueError) and payloads that
        # are not a non-empty list of dicts with 'generated_text'.
        raise ValueError("Unexpected response format from the inference API") from exc
def select(editor, evt: gr.SelectData):
    """Report the current text selection.

    Returns a three-item list: the event's ``value`` followed by the first
    two entries of the event's ``index``. *editor* is accepted but not read.
    """
    span = evt.index
    return [evt.value, span[0], span[1]]
def get_gen_txt(editor, prompt, only_gen_text=False):
    """Generate story text for the current editor content.

    When the editor is blank the model is asked for an opening paragraph;
    otherwise *prompt* is sent followed by a separator line and the editor
    text.

    Args:
        editor: Current editor text (``None`` is treated as empty).
        prompt: Instruction placed before the editor text.
        only_gen_text: When true, return just the generated text instead of
            the editor text with the generated text appended.

    Returns:
        On success, the generated text followed by a blank line, prefixed
        with the editor text unless *only_gen_text* is set. On failure, the
        unchanged editor text, or ``""`` when *only_gen_text* is set so a
        caller does not mistake the whole story for newly generated text.
    """
    editor = editor or ''

    if editor.strip() == '':
        ppm = get_new_ppm('Write a short paragraph to start a short story for me')
    else:
        ppm = get_new_ppm(f"""{prompt}
--------------------------------
{editor}""")

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
    except ValueError as e:
        print(f"something went wrong - {e}")
        return "" if only_gen_text else editor

    if only_gen_text:
        return txt + "\n\n"
    return editor + txt + "\n\n"
def gen_txt(editor):
    """Continue the story with the default prompt and reset the alternatives UI.

    Returns seven values: the new editor text, ``0``, an update making a
    component interactive, three updates hiding components, and a final
    update making a component interactive.
    """
    continued = get_gen_txt(editor, "Write the next paragraph based on the following stories so far.")
    hidden = [gr.update(visible=False) for _ in range(3)]
    return [continued, 0, gr.update(interactive=True), *hidden, gr.update(interactive=True)]
def gen_txt_with_prompt(editor, prompt):
    """Continue the story using a user-supplied prompt.

    Returns exactly six values, one per output wired to the
    "generate text with custom prompt" button in this file: the new editor
    text, ``0`` for the alternatives counter, an update re-enabling the
    alternatives button, and three updates hiding the alternative rows.
    (The previous version returned a seventh value with no matching output.)
    """
    return [
        get_gen_txt(editor, prompt),
        0,
        gr.update(interactive=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False)
    ]
def chat_gen(editor, chat_txt, chatbot, ppm, regen=False):
    """Run one chat turn (or regenerate the last one) about the story.

    Args:
        editor: Current story text, embedded in the system context.
        chat_txt: The user's message; ignored when *regen* is true.
        chatbot: Current chatbot value (accepted for wiring, not read).
        ppm: Conversation manager holding the chat history; it is mutated.
        regen: When true, drop the last exchange and ask its ping again.

    Returns:
        ``["", ppm.build_uis(), ppm]`` - an empty string for the message
        box, the rendered conversation, and the manager itself.
    """
    ppm.ctx = f"""\
You are a helpful, respectful and honest assistant.
you must consider multi-turn conversations.
Answer to questions based on the written stories so far as below
----------------
{editor}
"""

    if regen:
        if not ppm.pingpongs:
            # Nothing has been asked yet, so there is no exchange to pop and
            # re-ask; leave the conversation untouched.
            return [
                "",
                ppm.build_uis(),
                ppm
            ]
        last_pingpong = ppm.pop_pingpong()
        chat_txt = last_pingpong.ping

    ppm.add_pingpong(PingPong(chat_txt, ''))

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        ppm.add_pong(txt)
    except ValueError as e:
        # Best effort: the exchange stays in the history with an empty answer.
        print(f"something went wrong - {e}")

    return [
        "",
        ppm.build_uis(),
        ppm
    ]
def chat(editor, chat_txt, chatbot, ppm):
    """Handle a newly submitted chat message (``chat_gen`` with ``regen=False``)."""
    result = chat_gen(editor, chat_txt, chatbot, ppm, regen=False)
    return result
def regen_chat(editor, chat_txt, chatbot, ppm):
    """Regenerate the last chat answer (``chat_gen`` with ``regen=True``)."""
    result = chat_gen(editor, chat_txt, chatbot, ppm, regen=True)
    return result
def get_new_ppm_for_range():
    """Return a ``LLaMA2ChatPPManager`` carrying the rewrite system context.

    No exchange is added; the caller supplies the request.
    """
    rewrite_ctx = """\
You are a helpful, respectful and honest writing helper. Always write text that suites to query.
You DO NOT give explanation but just stories. DO NOT say such as 'Sure! Here's a short paragraph to start a short story:' or 'Sure, here is a revised version of ....:'
"""

    manager = LLaMA2ChatPPManager()
    manager.ctx = rewrite_ctx
    return manager
def replace_sel(editor, replace_type, selected_text, sel_index_from, sel_index_to):
    """Replace the selected span of the editor text with generated text.

    Args:
        editor: Current editor text.
        replace_type: Unit named in the request (taken from the dropdown).
        selected_text: The text currently selected.
        sel_index_from: Start offset of the selection in *editor*.
        sel_index_to: End offset of the selection in *editor*.

    Returns:
        Five values, one per output wired to the "replace selection" button
        in this file: the new editor text, the selected-text display, the
        two selection offsets, and a value for the progress textbox.
        On success the selection state is cleared (``""``, ``0``, ``0``).
        When generation fails, the editor text and the selection are
        returned unchanged.
    """
    ppm = get_new_ppm_for_range()
    ping = f"""replace {selected_text} in a single {replace_type} based on the story below
----------------
{editor}
"""
    ppm.add_pingpong(PingPong(ping, ''))

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
    except ValueError as e:
        print(f"something went wrong - {e}")
        # No replacement text exists, so keep everything as it was instead
        # of failing on an unbound variable below.
        return [
            editor,
            selected_text,
            sel_index_from,
            sel_index_to,
            " "
        ]

    ppm.add_pong(txt)

    return [
        f"{editor[:sel_index_from]} {txt} {editor[sel_index_to:]}",
        "",
        0,
        0,
        " "
    ]
def gen_alt(editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3):
    """Generate one more alternative continuation and reveal its button.

    While fewer than three alternatives are enabled, a new candidate is
    generated and stored in the slot whose index equals
    ``num_enabled_alts``; the other slots keep their current values.

    Returns nine values: the new count (capped at 3), an update for the
    generate-alternatives button (interactive only while the incoming count
    is below 2), then a visibility update and a value update for each of
    the three slots, and finally ``" "``.
    """
    candidate = None
    if num_enabled_alts < 3:
        candidate = get_gen_txt(editor, "Write the next paragraph based on the following stories so far.", only_gen_text=True)

    updates = [
        min(num_enabled_alts + 1, 3),
        gr.update(interactive=not num_enabled_alts >= 2),
    ]
    for slot, current_value in enumerate((alt_btn1, alt_btn2, alt_btn3)):
        updates.append(gr.update(visible=num_enabled_alts >= slot))
        updates.append(gr.update(value=candidate if num_enabled_alts == slot else current_value))
    updates.append(" ")
    return updates
def fill_with_gen(alt_txt, editor):
    """Append the chosen alternative to the editor text and reset the alternatives UI.

    Returns six values: ``editor + alt_txt``, ``0``, an update making a
    component interactive, and three updates hiding components.
    """
    merged = editor + alt_txt
    hidden = [gr.update(visible=False) for _ in range(3)]
    return [merged, 0, gr.update(interactive=True)] + hidden
# UI definition and event wiring. NOTE(review): the nesting of the layout
# containers below was reconstructed from context - confirm against the
# intended layout.
with gr.Blocks(css=STYLES) as demo:
    # Per-session state: number of visible alternatives, the current
    # selection offsets in the editor, and the chat conversation manager.
    num_enabled_alts = gr.State(0)
    sel_index_from = gr.State(0)
    sel_index_to = gr.State(0)
    chat_history = gr.State(get_new_ppm_for_chat())

    gr.Markdown("# Co-writing with AI", elem_classes=['center'])
    gr.Markdown(
        "This application is designed for you to collaborate with LLM to co-write stories. It is inspired by [Wordcraft project](https://wordcraft-writers-workshop.appspot.com/) from Google's PAIR and Magenta teams. "
        "This application built on [Gradio](https://www.gradio.app), and the underlying text generation is powered by [Hugging Face Inference API](https://huggingface.co/inference-api). The text generation model might "
        "be changed over time, but [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) is selected for now.",
        elem_classes=['center', 'small-big'])

    progress_bar = gr.Textbox(elem_classes=['no-label'])

    with gr.Row():
        with gr.Column(scale=2):
            editor = gr.Textbox(lines=32, max_lines=32, elem_classes=['no-label', 'small-big-textarea'])
            word_counter = gr.Markdown("0 words", elem_classes=['right'])

        with gr.Column(scale=1):
            with gr.Tab("Control"):
                with gr.Column(elem_classes=['group-border']):
                    with gr.Column():
                        gr.Markdown("`generate text` button generate continued text and attach it to the end. on the other hand, `generate alternatives` button generate alternate texts up to 3 and let you choose one of them.")
                        with gr.Row():
                            gen_btn = gr.Button("generate text", elem_classes=['control-label-font', 'control-button'])
                            gen_alt_btn = gr.Button("generate alternatives", elem_classes=['control-label-font', 'control-button'])

                    with gr.Column():
                        with gr.Row(visible=False) as first_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn1 = gr.Button("Alternative 1", elem_classes=['alt-button'], scale=8)

                        with gr.Row(visible=False) as second_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn2 = gr.Button("Alternative 2", elem_classes=['alt-button'], scale=8)

                        with gr.Row(visible=False) as third_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn3 = gr.Button("Alternative 3", elem_classes=['alt-button'], scale=8)

                with gr.Row(elem_classes=['group-border']):
                    with gr.Column():
                        gr.Markdown("'Write the next paragraph based on the following stories so far.' is the default prompt when clicking `generate text`, and the text so far will always be attached to the end. By giving your own prompt, only the default prompt will be replaced.")
                        with gr.Column(elem_classes=['no-gap']):
                            gen_with_prompt_btn = gr.Button("generate text with custom prompt", elem_classes=['control-label-font', 'control-button'])
                            prompt = gr.Textbox(placeholder="enter prompt: ", elem_classes=['no-label'])

                with gr.Column(elem_classes=['group-border']):
                    with gr.Row():
                        selected_text = gr.Markdown("Selected text will be displayed in this area", elem_classes=['highlighted-text'])

                    with gr.Row():
                        with gr.Column(elem_classes=['no-gap']):
                            replace_sel_btn = gr.Button("replace selection", elem_classes=['control-label-font', 'control-button'])
                            replace_type = gr.Dropdown(choices=['word', 'sentense', 'phrase', 'paragraph'], value='sentense', interactive=True, elem_classes=['no-label'])

                    with gr.Row():
                        with gr.Column(elem_classes=['no-gap']):
                            # TODO: no event handler is attached to these two
                            # rewrite controls anywhere in this file.
                            rewrite_sel_btn = gr.Button("rewrite selection", elem_classes=['control-label-font', 'control-button'])
                            rewrite_prompt = gr.Textbox(placeholder="Rewrite the text: ", elem_classes=['no-label'])

            with gr.Tab("Chatting"):
                chatbot = gr.Chatbot([], elem_classes=['no-label-chatbot'])
                chat_txt = gr.Textbox(placeholder="enter question", elem_classes=['no-label'])

                with gr.Row():
                    clear_btn = gr.Button("clear", elem_classes=['control-label-font', 'control-button'])
                    regen_btn = gr.Button("regenerate", elem_classes=['control-label-font', 'control-button'])

    # Client-side only: refresh the word counter and clear the selection
    # display. Raw string so `\s` reaches JavaScript without an invalid
    # Python escape; blank text counts as 0 and the label keeps the
    # "N words" format of the initial value.
    editor.change(
        fn=None,
        inputs=[editor],
        outputs=[word_counter, selected_text],
        _js=r"(e) => [(e.trim() === '' ? 0 : e.trim().split(/\s+/).length) + ' words', '']"
    )

    editor.select(
        fn=select,
        inputs=[editor],
        outputs=[selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    # Disable both generate buttons while generating; gen_txt re-enables them.
    gen_btn.click(
        lambda: (gr.update(interactive=False), gr.update(interactive=False)),
        inputs=None,
        outputs=[gen_btn, gen_alt_btn]
    ).then(
        fn=gen_txt,
        inputs=[editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt, gen_btn]
    )

    gen_alt_btn.click(
        lambda: gr.update(interactive=False),
        inputs=None,
        outputs=[gen_alt_btn]
    ).then(
        fn=gen_alt,
        inputs=[editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3],
        outputs=[num_enabled_alts, gen_alt_btn, first_alt, alt_btn1, second_alt, alt_btn2, third_alt, alt_btn3, progress_bar],
    )

    alt_btn1.click(
        fn=fill_with_gen,
        inputs=[alt_btn1, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    alt_btn2.click(
        fn=fill_with_gen,
        inputs=[alt_btn2, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    alt_btn3.click(
        fn=fill_with_gen,
        inputs=[alt_btn3, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    gen_with_prompt_btn.click(
        gen_txt_with_prompt,
        inputs=[editor, prompt],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    replace_sel_btn.click(
        fn=replace_sel,
        inputs=[editor, replace_type, selected_text, sel_index_from, sel_index_to],
        outputs=[editor, selected_text, sel_index_from, sel_index_to, progress_bar],
        show_progress='minimal'
    )

    chat_txt.submit(
        fn=chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

    regen_btn.click(
        fn=regen_chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

    # The "clear" button previously had no handler: reset the message box,
    # the chatbot display and the stored conversation.
    clear_btn.click(
        fn=lambda: ("", [], get_new_ppm_for_chat()),
        inputs=None,
        outputs=[chat_txt, chatbot, chat_history]
    )

demo.launch()