from model.fastchat.conversation import (Conversation, SeparatorStyle,
compute_skip_echo_len)
from model.fastchat.serve.inference import ChatIO, generate_stream, load_model
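
# This module wires a FastChat Vicuna model into a video question-answering
# chat flow: SimpleChatIO streams generated text to the console,
# VicunaChatBot wraps model loading and generation, and VicunaHandler
# adapts the chatbot to a gradio frontend.
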
class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
return input(f"{role}: ")
def prompt_for_output(self, role: str):
print(f"{role}: ", end="", flush=True)
    def stream_output(self, output_stream, skip_echo_len: int):
        # Stream the reply word by word: skip the echoed prompt prefix,
        # print only the words that are complete so far, then flush the
        # remainder once the stream ends.
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            outputs = outputs.split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)
class VicunaChatBot:
def __init__(
self,
model_path: str,
device: str,
num_gpus: str,
max_gpu_memory: str,
load_8bit: bool,
        chatio: ChatIO,
debug: bool,
):
self.model_path = model_path
self.device = device
        self.chatio = chatio
self.debug = debug
self.model, self.tokenizer = load_model(self.model_path, device,
num_gpus, max_gpu_memory,
load_8bit, debug)
def chat(self, inp: str, temperature: float, max_new_tokens: int,
conv: Conversation):
""" Vicuna as a chatbot. """
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
skip_echo_len = compute_skip_echo_len(self.model_path, conv, prompt)
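        # Use the conversation separator as the stop string only for
        # single-separator styles (SINGLE / BAIZE); otherwise rely on the
        # model's EOS token.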
stop_str = (
conv.sep if conv.sep_style
in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None)
params = {
"model": self.model_path,
"prompt": prompt,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"stop": stop_str,
}
        if self.debug:
            print(prompt)
self.chatio.prompt_for_output(conv.roles[1])
        output_stream = generate_stream(self.model, self.tokenizer, params,
                                        self.device)
outputs = self.chatio.stream_output(output_stream, skip_echo_len)
# NOTE: strip is important to align with the training data.
conv.messages[-1][-1] = outputs.strip()
return outputs, conv
class VicunaHandler:
""" VicunaHandler is a class that handles the communication between the
frontend and the backend. """
def __init__(self, config):
self.config = config
self.chat_io = SimpleChatIO()
self.chatbot = VicunaChatBot(
self.config['model_path'],
self.config['device'],
self.config['num_gpus'],
self.config['max_gpu_memory'],
self.config['load_8bit'],
self.chat_io,
self.config['debug'],
)
def chat(self):
""" Chat with the Vicuna. """
pass
def gr_chatbot_init(self, caption: str):
""" Initialise the chatbot for gradio. """
template = self._construct_conversation(caption)
print("Chatbot initialised.")
return template.copy(), template.copy()
def gr_chat(self, inp, conv: Conversation):
""" Chat using gradio as the frontend. """
return self.chatbot.chat(inp, self.config['temperature'],
self.config['max_new_tokens'], conv)
def _construct_conversation(self, prompt):
""" Construct a conversation template.
Args:
prompt: the prompt for the conversation.
"""
user_message = "The following text described what you have " +\
"seen, found, heard and notice from a consecutive video." +\
" Some of the texts may not be accurate. " +\
"Try to conclude what happens in the video, " +\
"then answer my question based on your conclusion.\n" +\
"<video begin>\n" + prompt + "<video end>\n" +\
"Example: Is this a Video?"
user_message = user_message.strip()
print(user_message)
return Conversation(
            system="A chat between a curious user and an artificial "
            "intelligence assistant answering questions on videos. "
            "The assistant answers the questions based on the given "
            "video captions and speech in time order.",
roles=("USER", "ASSISTANT"),
messages=(("USER", user_message), ("ASSISTANT", "yes")),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
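
# A minimal usage sketch (not part of the original module): the config keys
# mirror those read in VicunaHandler.__init__ and gr_chat; the model path,
# caption and question below are hypothetical placeholders.
if __name__ == "__main__":
    config = {
        "model_path": "lmsys/vicuna-7b-v1.1",  # hypothetical checkpoint
        "device": "cuda",
        "num_gpus": "1",
        "max_gpu_memory": "13GiB",
        "load_8bit": False,
        "debug": False,
        "temperature": 0.7,
        "max_new_tokens": 512,
    }
    handler = VicunaHandler(config)
    # gr_chatbot_init returns two copies of the conversation template
    # (e.g. one for gradio state and one for display).
    conv, _ = handler.gr_chatbot_init("A person walks a dog in a park.")
    answer, conv = handler.gr_chat("What animal appears in the video?", conv)
    print(answer)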