"""Gradio playground for CroissantLLMChat, served from a Hugging Face endpoint."""
import os

import gradio as gr
from text_generation import Client

# HF-hosted endpoint for testing purposes (requires an HF API token)
API_TOKEN = os.environ.get("API_TOKEN", None)
CURRENT_CLIENT = Client(
    "https://afrts4trc759c6eq.us-east-1.aws.endpoints.huggingface.cloud/generate_stream",
    timeout=120,
    headers={
        "Accept": "application/json",
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json",
    },
)

# Prompt defaults. Each can be overridden through an environment variable and
# is also editable per session from the "Prompt" accordion in the UI.
DEFAULT_HEADER = os.environ.get("HEADER", "")
DEFAULT_USER_NAME = os.environ.get("USER_NAME", "user")
DEFAULT_ASSISTANT_NAME = os.environ.get("ASSISTANT_NAME", "assistant")
DEFAULT_SEPARATOR = os.environ.get("SEPARATOR", "<|im_end|>")

# One user/assistant exchange using <|im_start|> role markers. `response` is
# left empty for the final turn, which the model is asked to complete.
PROMPT_TEMPLATE = "<|im_start|>{user_name}\n{query}{separator}\n<|im_start|>{assistant_name}\n{response}"

repo = None


def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten `chatbot` turns plus a new `inputs` message into one prompt string.

    User turns are prefixed with `user_name` and assistant turns with
    `sep + assistant_name` when not already present. Not referenced elsewhere
    in this file; `generate` assembles its prompt via `_build_prompt`.
    """
    past = []
    for data in chatbot:
        user_data, model_data = data

        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    return total_inputs


def has_no_history(chatbot, history):
    """Return True when both the displayed chat and the stored history are empty."""
    return not chatbot and not history


def _build_prompt(header, past_pairs, query, user_name, assistant_name, separator):
    """Return `header`, every completed (user, assistant) turn, then the open turn for `query`."""

    def turn(user_text, model_text):
        return PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_text,
            assistant_name=assistant_name,
            response=model_text,
            separator=separator,
        )

    # Completed turns are closed with the separator, like the user half of each turn.
    completed = "".join(turn(user_text, model_text) + separator + "\n" for user_text, model_text in past_pairs)
    return header + completed + turn(query, "")


def _history_to_chat(history):
    """Pair the flat [user, assistant, user, assistant, ...] history into stripped Chatbot tuples."""
    return [(user_text.strip(), model_text.strip()) for user_text, model_text in zip(history[::2], history[1::2])]


def _sanitize_top_p(top_p):
    """Map the slider's top_p onto a value the generation client accepts.

    NOTE(review): the text_generation client appears to require 0 < top_p < 1;
    confirm for the pinned version. 1.0 (the slider maximum) means "no nucleus
    filtering", which leaving the parameter unset (None) also requests; values
    at or below 0 are clamped up to 1e-2.
    """
    top_p = float(top_p)
    if top_p >= 1.0:
        return None
    return max(top_p, 1e-2)


def _sanitize_repetition_penalty(repetition_penalty):
    """Return the penalty as a float, or None (no penalty) when it is not strictly positive.

    NOTE(review): the slider allows 0.0, which the generation client presumably
    rejects (penalties must be strictly positive) — confirm.
    """
    penalty = float(repetition_penalty)
    return penalty if penalty > 0.0 else None


def _stop_sequences(separator):
    """Stop on the default separator and on the session's separator when it differs."""
    return [seq for seq in dict.fromkeys((DEFAULT_SEPARATOR, separator)) if seq] or None


def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    max_new_tokens,
    repetition_penalty,
    header,
    user_name,
    assistant_name,
    separator,
):
    """Stream the assistant's reply to `user_message`.

    The prompt is built from `header`, the turns already shown in `chatbot` and
    the new message, then tokens are streamed from CURRENT_CLIENT. `history` is
    the flat per-session list [user, assistant, user, assistant, ...] and is
    mutated in place; it always holds complete pairs.

    Yields:
        (chat pairs for the Chatbot, history, user_message, "") — the last
        element clears the input textbox.
    """
    # Don't query the model for an empty message; just clear the input box.
    if not user_message or not user_message.strip():
        yield chatbot, history, user_message, ""
        return

    # Chatbot entries may hold None for a missing side of a turn.
    past_pairs = [(user_text or "", (model_text or "").rstrip()) for user_text, model_text in (chatbot or [])]
    prompt = _build_prompt(header, past_pairs, user_message, user_name, assistant_name, separator)

    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max(int(max_new_tokens), 1),  # slider may send 0 or a float
        top_p=_sanitize_top_p(top_p),
        top_k=40,
        repetition_penalty=_sanitize_repetition_penalty(repetition_penalty),
        do_sample=True,
        truncate=1024,
        stop_sequences=_stop_sequences(separator),
    )

    # Reserve both slots up front so `history` stays pair-aligned even when the
    # stream yields no visible token (or only special ones).
    history.append(user_message)
    history.append("")
    yield _history_to_chat(history), history, user_message, ""

    output = ""
    for response in CURRENT_CLIENT.generate_stream(prompt, **generate_kwargs):
        # Special tokens are not shown to the user.
        if response.token.special:
            continue
        output += response.token.text
        history[-1] = output
        yield _history_to_chat(history), history, user_message, ""


def clear_chat():
    """Reset both the displayed chat and the stored history."""
    return [], []


title = """

CroissantLLMChat Playground 🥐

"""

custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}

#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""

with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # Demo platform for 🥐 CroissantLLMChat

                The model is of small size (1.3B), about 130 times smaller than GPT3.
                As such, it logically exhibits reduced understanding, reasoning and knowledge capacities.
                For industrial uses, we recommend finetuning the model, but we trained a Chat version
                to allow for experimenting and to showcase its capabilities for its size.

                We recommend testing it for open-ended writing tasks, tips, translations, etc.
                The model can hallucinate and generate incorrect or even toxic content.

                The demo is linked to an endpoint that shuts down automatically after 15 minutes.
                If an error message appears, wait about 5 minutes and try again once the server is back up!
                """
            )

    with gr.Row():
        with gr.Group():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.5,
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.9,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=1024,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )

            with gr.Accordion(label="Prompt", open=False, elem_id="prompt-accordion"):
                header = gr.Textbox(
                    label="Header instructions",
                    value=DEFAULT_HEADER,
                    interactive=True,
                    info="Instructions given to the assistant at the beginning of the prompt",
                )
                user_name = gr.Textbox(
                    label="User name",
                    value=DEFAULT_USER_NAME,
                    interactive=True,
                    info="Name to be given to the user in the prompt",
                )
                assistant_name = gr.Textbox(
                    label="Assistant name",
                    value=DEFAULT_ASSISTANT_NAME,
                    interactive=True,
                    info="Name to be given to the assistant in the prompt",
                )
                separator = gr.Textbox(
                    label="Separator",
                    value=DEFAULT_SEPARATOR,
                    interactive=True,
                    info="Character to be used when the speaker changes in the prompt",
                )

    # Per-session state: flat [user, assistant, ...] history and the last message sent.
    history = gr.State([])
    last_user_message = gr.State("")

    generate_inputs = [
        user_message,
        chatbot,
        history,
        temperature,
        top_p,
        max_new_tokens,
        repetition_penalty,
        header,
        user_name,
        assistant_name,
        separator,
    ]
    generate_outputs = [chatbot, history, last_user_message, user_message]

    user_message.submit(generate, inputs=generate_inputs, outputs=generate_outputs)
    send_button.click(generate, inputs=generate_inputs, outputs=generate_outputs)
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])

demo.queue().launch()