import os
import json

import requests
import gradio as gr
from transformers import AutoTokenizer

DESCRIPTION = """
# Demo: Breeze-7B-Instruct-v0.1

Breeze-7B is a language model family that builds on top of [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), specifically intended for Traditional Chinese use.

[Breeze-7B-Base](https://huggingface.co/MediaTek-Research/Breeze-7B-Base-v0.1) is the base model for the Breeze-7B series. It is suitable if you have substantial fine-tuning data to tune it for your specific use case.

[Breeze-7B-Instruct](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-v0.1) derives from the base model Breeze-7B-Base, so the resulting model can be used as-is for commonly seen tasks.

[Breeze-7B-Instruct-64k](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-64k-v0.1) is a slightly modified version of Breeze-7B-Instruct that enables a 64k-token context length. Roughly speaking, that is equivalent to 88k Traditional Chinese characters.

The current release version of Breeze-7B is v0.1.

*A project by the members (in alphabetical order): Chan-Jan Hsu 許湛然, Chang-Le Liu 劉昶樂, Feng-Ting Liao 廖峰挺, Po-Chun Hsu 許博竣, Yi-Chang Chen 陳宜昌, and the supervisor Da-Shan Shiu 許大山.*

**免責聲明: Breeze-7B-Instruct 和 Breeze-7B-Instruct-64k 並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。**
"""
# (The bolded Chinese line above is a disclaimer: Breeze-7B-Instruct and
# Breeze-7B-Instruct-64k are not safety-aligned for Q&A, so the model's
# responses do not represent MediaTek Research's position.)

LICENSE = """
"""

DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."

API_URL = os.environ.get("API_URL")
TOKEN = os.environ.get("TOKEN")

HEADERS = {
    "accept": "application/json",
    "Authorization": f"Bearer {TOKEN}",
    "Content-Type": "application/json",
}

MODEL_NAME = "breeze-7b-instruct-v01"
PRESENCE_PENALTY = 0
FREQUENCY_PENALTY = 0

model_name = "MediaTek-Research/Breeze-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
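
# Illustrative sketch (not part of the original app): a quick way to inspect
# the prompt string that apply_chat_template() renders for Breeze. The
# rendered prompt is expected to begin with the "<s>" BOS token, which bot()
# below strips with message[3:]. preview_prompt is a hypothetical helper.
def preview_prompt():
    example = [
        {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
        {"role": "user", "content": "你好"},
    ]
    # tokenize=False returns the raw prompt string instead of token ids.
    return tokenizer.apply_chat_template(example, tokenize=False)
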
def refusal_condition(query):
    # Stop asking these questions! Refuse queries that mention both
    # Taiwan and mainland China, since such prompts are politically sensitive.
    is_including_tw = (
        '台灣' in query or '台湾' in query
        or 'taiwan' in query.lower() or 'tw' in query.lower()
        or '中華民國' in query or '中华民国' in query
    )
    is_including_cn = (
        '中國' in query or '中国' in query
        or 'china' in query.lower() or 'cn' in query.lower()
        or '大陸' in query or '內地' in query
        or '大陆' in query or '内地' in query
        or '中華人民共和國' in query or '中华人民共和国' in query
    )
    if is_including_tw and is_including_cn:
        return True
    return False


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
        )
        submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=32,
            maximum=1024,
            step=1,
            value=512,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            value=0.01,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.01,
            maximum=1.0,
            step=0.01,
            value=0.01,
        )

    def user(user_message, history):
        # Append the user's turn with a placeholder assistant reply.
        return "", history + [[user_message, None]]

    def bot(history, max_new_tokens, temperature, top_p, system_prompt):
        chat_data = []
        system_prompt = system_prompt.strip()
        if system_prompt:
            chat_data.append({"role": "system", "content": system_prompt})
        for user_msg, assistant_msg in history:
            if user_msg is not None:
                chat_data.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                chat_data.append({"role": "assistant", "content": assistant_msg})

        message = tokenizer.apply_chat_template(chat_data, tokenize=False)
        message = message[3:]  # remove the leading "<s>" BOS token

        if refusal_condition(history[-1][0]):
            # "[Safety refusal activated]" / "Please clear the chat and start over"
            history = [['[安全拒答啟動]', '請清除再開啟對話']]
            yield history
        else:
            data = {
                "model": MODEL_NAME,
                "prompt": str(message),
                "temperature": float(temperature) + 0.01,  # keep strictly above 0
                "n": 1,
                "max_tokens": int(max_new_tokens),
                "stop": "",
                "top_p": float(top_p),
                "logprobs": 0,
                "echo": False,
                "presence_penalty": PRESENCE_PENALTY,
                "frequency_penalty": FREQUENCY_PENALTY,
                "stream": True,
            }
            with requests.post(API_URL, headers=HEADERS, data=json.dumps(data), stream=True) as r:
                for response in r.iter_lines():
                    if len(response) > 0:
                        text = response.decode()
                        if text != "data: [DONE]":
                            if text.startswith("data: "):
                                # Drop the "data:" prefix; json.loads tolerates
                                # the remaining leading space.
                                text = text[5:]
                            delta = json.loads(text)["choices"][0]["text"]
                            if history[-1][1] is None:
                                history[-1][1] = delta
                            else:
                                history[-1][1] += delta
                            yield history

            # Strip a trailing "</s>" EOS token (4 characters) if the backend
            # echoed it at the end of the stream.
            if history[-1][1] and history[-1][1].endswith('</s>'):
                history[-1][1] = history[-1][1][:-4]
                yield history

        print('== Record ==\nQuery: {query}\nResponse: {response}'.format(
            query=repr(message), response=repr(history[-1][1])))

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=bot,
        inputs=[
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            system_prompt,
        ],
        outputs=chatbot,
    )
    submit_button.click(
        user, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            system_prompt,
        ],
        outputs=chatbot,
    )

    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        # Pop the most recent exchange and hand the user message back for reuse.
        try:
            message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''
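
    # Note on the Retry flow (explanatory comment, added for clarity):
    # retry_button first runs delete_prev_fn to pop the last exchange and stash
    # the user message in saved_input, then display_input re-appends that
    # message with an empty reply, and finally bot() regenerates the response.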
    def display_input(message: str, history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        # Re-append the saved user message as a mutable [user, assistant] pair
        # (a list rather than a tuple, so bot() can assign to history[-1][1]).
        history.append([message, None])
        return history

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            system_prompt,
        ],
        outputs=chatbot,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=msg,
        api_name=False,
        queue=False,
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown(LICENSE)

demo.queue(concurrency_count=1, max_size=16)
demo.launch()
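
# Usage sketch (assumptions: the endpoint and token values are hypothetical,
# and the script is saved as app.py):
#   export API_URL="https://example.com/v1/completions"
#   export TOKEN="your-api-token"
#   python app.py
# API_URL must point at an OpenAI-compatible completions endpoint that
# supports streaming, since bot() consumes "data: ..." server-sent-event lines.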