from strings import TITLE, ABSTRACT, BOTTOM_LINE
from strings import DEFAULT_EXAMPLES
from strings import SPECIAL_STRS
from styles import PARENT_BLOCK_CSS
from constants import num_of_characters_to_keep

import time
import gradio as gr

from model import load_model
from gen import get_output_batch, StreamModel
from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process

generation_config = get_generation_config(
    "./generation_config_default.yaml"
)

model, tokenizer = load_model(
    base="decapoda-research/llama-13b-hf",
    finetuned="chansung/alpaca-lora-13b"
)
stream_model = StreamModel(model, tokenizer)

def chat_stream(
    context,
    instruction,
    state_chatbot,
):
    if len(context) > 1000 or len(instruction) > 300:
        raise gr.Error("context or prompt is too long!")

    bot_summarized_response = ''
    # user input should be appropriately formatted (don't be confused by the function name)
    instruction_display = common_post_process(instruction)
    instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

    # when the accumulated conversation gets too long, ask the model to summarize it
    # and carry the summary forward as part of the context
    if conv_length > num_of_characters_to_keep:
        instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context, partial=True)[0]

        state_chatbot = state_chatbot + [
            (
                None,
                "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) the conversation is getting too long, so let's summarize it..."
            )
        ]
        yield (state_chatbot, state_chatbot, context)

        bot_summarized_response = get_output_batch(
            model, tokenizer, [instruction_prompt], generation_config
        )[0]
        bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()

        state_chatbot[-1] = (
            None,
            "✅ summarization is done and set as context"
        )
        print(f"bot_summarized_response: {bot_summarized_response}")
        yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

    instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]

    bot_response = stream_model(
        instruction_prompt,
        max_tokens=256,
        temperature=1,
        top_p=0.9
    )

    instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
    state_chatbot = state_chatbot + [(instruction_display, None)]
    yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
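    # Stream partial generations to the UI. Text after a stray "#" is buffered until it is
    # clear whether the model has started emitting a new "### Instruction:" section, in which
    # case the response is cut off just before that marker.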
    prev_index = 0
    agg_tokens = ""
    cutoff_idx = 0
    for tokens in bot_response:
        tokens = tokens.strip()
        cur_token = tokens[prev_index:]

        if "#" in cur_token and agg_tokens == "":
            cutoff_idx = tokens.find("#")
            agg_tokens = tokens[cutoff_idx:]

        if agg_tokens != "":
            if len(agg_tokens) < len("### Instruction:"):
                agg_tokens = agg_tokens + cur_token
            elif len(agg_tokens) >= len("### Instruction:"):
                if tokens.find("### Instruction:") > -1:
                    processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

                    state_chatbot[-1] = (
                        instruction_display,
                        processed_response
                    )
                    yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
                    break
                else:
                    agg_tokens = ""
                    cutoff_idx = 0

        if agg_tokens == "":
            processed_response, to_exit = post_process_stream(tokens)
            state_chatbot[-1] = (instruction_display, processed_response)
            yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())

            if to_exit:
                break

        prev_index = len(tokens)

    yield (
        state_chatbot,
        state_chatbot,
        f"{context} {bot_summarized_response}".strip()
    )

def chat_batch(
    contexts,
    instructions,
    state_chatbots,
):
    state_results = []
    ctx_results = []

    # generate_prompt returns (prompt, conv_length); only the prompt text is needed here
    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)[0]
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]

    bot_responses = get_output_batch(
        model, tokenizer, instruct_prompts, generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
        new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
        ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
        state_results.append(new_state_chatbot)

    return (state_results, state_results, ctx_results)

def reset_textbox():
    return gr.Textbox.update(value='')

def reset_everything(
    context_txtbox,
    instruction_txtbox,
    state_chatbot
):
    state_chatbot = []

    return (
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=''),
        gr.Textbox.update(value=''),
    )

with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

        with gr.Accordion("Context Setting", open=False):
            context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
            hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)

        chatbot = gr.Chatbot(elem_id='chatbot', label="Alpaca-LoRA")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")

        with gr.Row():
            cancel_btn = gr.Button(value="Cancel")
            reset_btn = gr.Button(value="Reset")

        with gr.Accordion("Helper Buttons", open=False):
            gr.Markdown("`Continue` lets the AI complete its previous incomplete answer. `Summarize` lets the AI summarize the conversation so far.")
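            # hidden textboxes hold the special instruction strings that the
            # Continue / Summarize buttons pass to chat_stream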
            continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
            summarize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)

            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

        gr.Markdown("#### Examples")
        for _, (category, examples) in enumerate(DEFAULT_EXAMPLES.items()):
            with gr.Accordion(category, open=False):
                if category == "Identity":
                    for item in examples:
                        with gr.Accordion(item["title"], open=False):
                            gr.Examples(
                                examples=item["examples"],
                                inputs=[
                                    hidden_txtbox, context_txtbox, instruction_txtbox
                                ],
                                label=None
                            )
                else:
                    for item in examples:
                        with gr.Accordion(item["title"], open=False):
                            gr.Examples(
                                examples=item["examples"],
                                inputs=[
                                    hidden_txtbox, instruction_txtbox
                                ],
                                label=None
                            )

        gr.Markdown(f"{BOTTOM_LINE}")

    send_event = instruction_txtbox.submit(
        chat_stream,
        [context_txtbox, instruction_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    reset_event = instruction_txtbox.submit(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    continue_event = continue_btn.click(
        chat_stream,
        [context_txtbox, continue_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    reset_continue_event = continue_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    summarize_event = summarize_btn.click(
        chat_stream,
        [context_txtbox, summarize_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox],
    )
    summarize_reset_event = summarize_btn.click(
        reset_textbox,
        [],
        [instruction_txtbox],
    )

    cancel_btn.click(
        None, None, None,
        cancels=[
            send_event, continue_event, summarize_event
        ]
    )
    reset_btn.click(
        reset_everything,
        [context_txtbox, instruction_txtbox, state_chatbot],
        [state_chatbot, chatbot, context_txtbox, instruction_txtbox],
        cancels=[
            send_event, continue_event, summarize_event
        ]
    )

demo.queue(
    concurrency_count=1,
    max_size=100,
).launch(
    max_threads=5,
    server_name="0.0.0.0",
)