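# Gradio chat demo serving Alpaca-LoRA 13B with streaming responses.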
from strings import TITLE, ABSTRACT, BOTTOM_LINE
from strings import DEFAULT_EXAMPLES
from strings import SPECIAL_STRS
from styles import PARENT_BLOCK_CSS
from constants import num_of_characters_to_keep

import time
import gradio as gr

from model import load_model
from gen import get_output_batch, StreamModel
from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process
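# Load the default generation config, the base LLaMA weights plus the
# Alpaca-LoRA adapter, and wrap both in StreamModel for token-by-token streaming.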
generation_config = get_generation_config(
    "./generation_config_default.yaml"
)

model, tokenizer = load_model(
    base="decapoda-research/llama-13b-hf",
    finetuned="chansung/alpaca-lora-13b"
)

stream_model = StreamModel(model, tokenizer)
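# Streaming chat path: if the conversation has grown too long, it is first
# compressed into the context via a summarization round-trip, then partial bot
# responses are yielded as tokens arrive.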
def chat_stream(
    context,
    instruction,
    state_chatbot,
):
    if len(context) > 1000 or len(instruction) > 300:
        raise gr.Error("Context or prompt is too long!")

    bot_summarized_response = ''
    # format the user's input for display (the post-processing helper is reused here despite its name)
    instruction_display = common_post_process(instruction)
    instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

    if conv_length > num_of_characters_to_keep:
        instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context, partial=True)[0]
        state_chatbot = state_chatbot + [
            (
                None,
                "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) the conversation is getting too long, so let's summarize it..."
            )
        ]
        yield (state_chatbot, state_chatbot, context)

        # one-off batch generation that summarizes the conversation so far
        bot_summarized_response = get_output_batch(
            model, tokenizer, [instruction_prompt], generation_config
        )[0]
        bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()

        state_chatbot[-1] = (
            None,
            "✅ summarization is done and set as context"
        )
        print(f"bot_summarized_response: {bot_summarized_response}")

        yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

    instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]
    bot_response = stream_model(
        instruction_prompt,
        max_tokens=256,
        temperature=1,
        top_p=0.9
    )

    instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
    state_chatbot = state_chatbot + [(instruction_display, None)]
    yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
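    # Stream tokens into the last chat turn. The model sometimes begins a new
    # "### Instruction:" block after its answer; buffer anything following a
    # "#" until it is long enough to tell, and truncate the response there.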
    prev_index = 0
    agg_tokens = ""
    cutoff_idx = 0
    for tokens in bot_response:
        tokens = tokens.strip()
        cur_token = tokens[prev_index:]

        if "#" in cur_token and agg_tokens == "":
            cutoff_idx = tokens.find("#")
            agg_tokens = tokens[cutoff_idx:]

        if agg_tokens != "":
            if len(agg_tokens) < len("### Instruction:"):
                agg_tokens = agg_tokens + cur_token
            elif tokens.find("### Instruction:") > -1:
                processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

                state_chatbot[-1] = (
                    instruction_display,
                    processed_response
                )
                yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
                break
            else:
                agg_tokens = ""
                cutoff_idx = 0

        if agg_tokens == "":
            processed_response, to_exit = post_process_stream(tokens)
            state_chatbot[-1] = (instruction_display, processed_response)
            yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())

            if to_exit:
                break

        prev_index = len(tokens)

    yield (
        state_chatbot,
        state_chatbot,
        f"{context} {bot_summarized_response}".strip()
    )
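# Batch (non-streaming) counterpart of chat_stream: builds one prompt per
# request, generates all responses in a single batch, then post-processes them.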
def chat_batch(
    contexts,
    instructions,
    state_chatbots,
):
    state_results = []
    ctx_results = []

    instruct_prompts = [
        # generate_prompt returns (prompt, length); keep only the prompt text
        generate_prompt(instruct, histories, ctx)[0]
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]

    bot_responses = get_output_batch(
        model, tokenizer, instruct_prompts, generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
        new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
        ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
        state_results.append(new_state_chatbot)

    return (state_results, state_results, ctx_results)
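# Small callbacks for clearing the instruction box and resetting the whole chat.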
def reset_textbox():
    return gr.Textbox.update(value='')

def reset_everything(
    context_txtbox,
    instruction_txtbox,
    state_chatbot
):
    state_chatbot = []

    return (
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=''),
        gr.Textbox.update(value=''),
    )
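# UI layout: context accordion, chat window, instruction box, cancel/reset row,
# hidden textboxes carrying the "continue"/"summarize" special instructions,
# and categorized example accordions.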
with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

        with gr.Accordion("Context Setting", open=False):
            context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
            hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)

        chatbot = gr.Chatbot(elem_id='chatbot', label="Alpaca-LoRA")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")

        with gr.Row():
            cancel_btn = gr.Button(value="Cancel")
            reset_btn = gr.Button(value="Reset")

        with gr.Accordion("Helper Buttons", open=False):
            gr.Markdown("`Continue` lets the AI complete its previous, unfinished answer. `Summarize` lets the AI summarize the conversation so far.")
            continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
            summarize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)

            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

        gr.Markdown("#### Examples")
        for category, examples in DEFAULT_EXAMPLES.items():
            with gr.Accordion(category, open=False):
                if category == "Identity":
                    for item in examples:
                        with gr.Accordion(item["title"], open=False):
                            gr.Examples(
                                examples=item["examples"],
                                inputs=[
                                    hidden_txtbox, context_txtbox, instruction_txtbox
                                ],
                                label=None
                            )
                else:
                    for item in examples:
                        with gr.Accordion(item["title"], open=False):
                            gr.Examples(
                                examples=item["examples"],
                                inputs=[
                                    hidden_txtbox, instruction_txtbox
                                ],
                                label=None
                            )

        gr.Markdown(BOTTOM_LINE)
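    # Event wiring: submitting the instruction (or clicking a helper button)
    # streams a reply through chat_stream; a parallel event clears the input box.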
    send_event = instruction_txtbox.submit(
        chat_stream,
        [context_txtbox, instruction_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    reset_event = instruction_txtbox.submit(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    continue_event = continue_btn.click(
        chat_stream,
        [context_txtbox, continue_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    reset_continue_event = continue_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )
    summarize_event = summarize_btn.click(
        chat_stream,
        [context_txtbox, summarize_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    summarize_reset_event = summarize_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )
    cancel_btn.click(
        None, None, None,
        cancels=[
            send_event, continue_event, summarize_event
        ]
    )
    reset_btn.click(
        reset_everything,
        [context_txtbox, instruction_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox, instruction_txtbox],
        cancels=[
            send_event, continue_event, summarize_event
        ]
    )
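# Serialize generations with a single-worker queue (one request at a time) and
# expose the app on all network interfaces.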
demo.queue(
    concurrency_count=1,
    max_size=100,
).launch(
    max_threads=5,
    server_name="0.0.0.0",
)