import gradio as gr
import torch
import html
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

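# Load the tokenizer and model once at startup; inference runs on GPU in fp16.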
model_name_or_path = 'TencentARC/LLaMA-Pro-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

model.half().cuda()

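# Render one {"role", "content"} message in the model's chat format; an
# assistant message with content None opens an empty assistant turn for the
# model to complete.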
def convert_message(message):
    message_text = ""
    if message["content"] is None and message["role"] == "assistant":
        message_text += "<|assistant|>\n"
    elif message["role"] == "system":
        message_text += "<|system|>\n" + message["content"].strip() + "\n"
    elif message["role"] == "user":
        message_text += "<|user|>\n" + message["content"].strip() + "\n"
    elif message["role"] == "assistant":
        message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
    else:
        raise ValueError("Invalid role: {}".format(message["role"]))

    # Chat history coming from the UI may be HTML-escaped and use <br> for
    # line breaks; undo both before feeding the text to the model.
    message_text = html.unescape(message_text)
    message_text = message_text.replace("<br>", "\n")
    return message_text

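# Rebuild the prompt from Gradio's chat history (a list of [user, assistant]
# pairs), adding turns newest-first so the most recent context survives the
# max_input_length token budget.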
def convert_history(chat_history, max_input_length=1024):
    history_text = ""
    idx = len(chat_history) - 1

    # Walk backwards through the history, prepending turns until the prompt
    # would exceed the token budget.
    while len(tokenizer(history_text).input_ids) < max_input_length and idx >= 0:
        user_message, chatbot_message = chat_history[idx]
        user_message = convert_message({"role": "user", "content": user_message})
        chatbot_message = convert_message({"role": "assistant", "content": chatbot_message})
        history_text = user_message + chatbot_message + history_text
        idx -= 1

    if history_text == "":
        history_text = "<|assistant|>\n"
    return history_text

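# Generate a streamed completion for an already-formatted prompt.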
@torch.inference_mode()
def instruct(instruction, max_token_output=1024):
    # skip_special_tokens keeps markers such as </s> out of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(instruction, return_tensors='pt', truncation=False)
    inputs["input_ids"] = inputs["input_ids"].cuda()
    inputs["attention_mask"] = inputs["attention_mask"].cuda()
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_token_output)
    # generate() blocks, so run it in a worker thread and hand back the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer

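# Gradio UI: a single-turn QA tab and a multi-turn chatbot tab, both streaming.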
with gr.Blocks() as demo:

    with gr.Tab("QA Demo"):
        with gr.Row():
            instruction = gr.Textbox(label="Input")
            output = gr.Textbox(label="Output")
        greet_btn = gr.Button("Submit")

        def yield_instruct(instruction):
            # Wrap the raw question in a single-turn prompt and stream the
            # growing answer back into the output textbox.
            instruction = "<|user|>\n" + instruction + "\n<|assistant|>\n"
            output = ""
            for token in instruct(instruction):
                output += token
                yield output

        greet_btn.click(fn=yield_instruct, inputs=[instruction], outputs=output, api_name="greet")

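    # Chatbot tab: multi-turn history lives in the gr.Chatbot component and
    # replies stream in token by token.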
with gr.Tab("Chatbot"): |
|
chatbot = gr.Chatbot([], elem_id="chatbot") |
|
msg = gr.Textbox() |
|
clear = gr.Button("Clear") |
|
|
|
def user(user_message, history): |
|
return "", history + [[user_message, None]] |
|
|
|
def bot(history): |
|
prompt = convert_history(history) |
|
streaming_out = instruct(prompt) |
|
history[-1][1] = "" |
|
for new_token in streaming_out: |
|
history[-1][1] += new_token |
|
yield history |
|
|
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( |
|
bot, chatbot, chatbot |
|
) |
|
|
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
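# queue() is required for the streaming (generator) callbacks above;
# share=True serves the demo through a temporary public link.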
if __name__ == "__main__":
    demo.queue().launch(share=True)