"""
Chat with a model with command line interface.

Usage:
python3 -m fastchat.serve.cli --model-path ~/model_weights/llama-7b

Two batch drivers from ``fastchat.serve.inference`` can be run:
``question_loop`` (default), which receives ``--question-path`` and
``--caption-path``; and ``answer_loop`` (selected with ``--answer-flag``),
which additionally receives ``--data-info-path`` and ``--answer-path``.
"""
import argparse
import os
import re

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from rich.console import Console
from rich.markdown import Markdown
from rich.live import Live

from fastchat.serve.inference import ChatIO, question_loop, answer_loop


class SimpleChatIO(ChatIO):
    """Plain ``input``/``print`` console front-end."""

    def prompt_for_input(self, role: str) -> str:
        """Read one line of user input, prompting with ``role``."""
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        """Print the ``role`` prefix that precedes a streamed reply."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int) -> str:
        """Print a streamed reply word by word and return the final text.

        Each item of ``output_stream`` is treated as the whole text produced
        so far; its first ``skip_echo_len`` characters are dropped. The last
        word of each chunk is held back until a later chunk (or the end of
        the stream), since it may still be growing.

        Returns:
            The final reply with words joined by single spaces, or ``""`` if
            the stream yielded nothing.
        """
        pre = 0
        # Bound before the loop so an empty stream returns "" instead of
        # raising UnboundLocalError below.
        outputs = []
        for chunk in output_stream:
            outputs = chunk[skip_echo_len:].strip().split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)


class RichChatIO(ChatIO):
    """Console front-end using prompt_toolkit for input and rich for output."""

    def __init__(self):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        # Completion candidates for the chat commands.
        self._completer = WordCompleter(
            words=["!exit", "!reset"], pattern=re.compile("$")
        )
        self._console = Console()

    def prompt_for_input(self, role: str) -> str:
        """Show a bold ``role`` header and read one line of input."""
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        """Show a bold ``role`` header before a streamed reply."""
        self._console.print(f"[bold]{role}:")

    def stream_output(self, output_stream, skip_echo_len: int) -> str:
        """Render a streamed reply as live-updating Markdown and return it.

        Each item of ``output_stream`` is treated as the whole text produced
        so far; its first ``skip_echo_len`` characters are dropped.

        Returns:
            The text of the last item after slicing, or ``""`` if the stream
            yielded nothing.
        """
        # TODO(suquark): the console flickers when there is a code block
        # above it. We need to cut off "live" when a code block is done.
        accumulated_text = ""  # returned as-is for an empty stream
        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=4) as live:
            for outputs in output_stream:
                accumulated_text = outputs[skip_echo_len:]
                if not accumulated_text:
                    continue
                # NOTE: this is a workaround for rendering "unstandard markdown"
                # in rich. The chatbot's output treats "\n" as a new line for
                # better compatibility with real-world text, but standard
                # Markdown treats a single "\n" in normal text as a space.
                # Our workaround is adding two spaces (a Markdown hard line
                # break) at the end of each line. This introduces trailing
                # spaces inside code blocks, but the console does not
                # generally care about trailing spaces.
                lines = []
                for line in accumulated_text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as
                        # it would break the syntax highlighting.
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                live.update(Markdown("".join(lines)))
        self._console.print()
        return accumulated_text


def main(args):
    """Configure devices and the console front-end, then run a batch loop.

    Args:
        args: Parsed command-line namespace (see ``_build_parser``).

    Raises:
        ValueError: if ``--gpus`` lists fewer devices than ``--num-gpus``, or
            ``--style`` is not a known console style.
    """
    if args.gpus:
        if args.num_gpus and len(args.gpus.split(",")) < int(args.num_gpus):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    if args.style == "simple":
        chatio = SimpleChatIO()
    elif args.style == "rich":
        chatio = RichChatIO()
    else:
        raise ValueError(f"Invalid style for console: {args.style}")

    try:
        if args.answer_flag:
            print("answer loop")
            answer_loop(
                args.model_path,
                args.device,
                args.num_gpus,
                args.max_gpu_memory,
                args.load_8bit,
                args.conv_template,
                args.temperature,
                args.max_new_tokens,
                chatio,
                args.debug,
                args.question_path,
                args.caption_path,
                args.data_info_path,
                args.answer_path,
            )
        else:
            print("question loop")
            # --caption-path defaults to None, and os.path.exists(None)
            # raises TypeError (os.stat rejects None), so guard first.
            if args.caption_path and os.path.exists(args.caption_path):
                # Informational only: question_loop still runs below.
                print("caption.json already exists")
            question_loop(
                args.model_path,
                args.device,
                args.num_gpus,
                args.max_gpu_memory,
                args.load_8bit,
                args.conv_template,
                args.temperature,
                args.max_new_tokens,
                chatio,
                args.debug,
                args.question_path,
                args.caption_path,
            )
    except KeyboardInterrupt:
        print("exit...")


_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
# "" is accepted as False to match the old bool("") behaviour.
_FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off"})


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean such as "true"/"false" or "1"/"0".

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="facebook/opt-350m",
        help="The path to the weights",
    )
    parser.add_argument(
        "--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda"
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="A single GPU like 1 or multiple GPUs like 0,2",
    )
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per gpu. Use a string like '13Gib'",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--conv-template",
        type=str,
        default=None,
        help="Conversation prompt template.",
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich"],
        help="Display style.",
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--question-path", type=str, default=None)
    parser.add_argument("--caption-path", type=str, default=None)
    # Not type=bool: bool("False") is True, so "--answer-flag False" used to
    # enable the flag. Accept a bare flag or an explicit boolean value so
    # existing "--answer-flag True" invocations keep working.
    parser.add_argument(
        "--answer-flag",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Run answer_loop instead of question_loop.",
    )
    parser.add_argument(
        "--data-info-path",
        type=str,
        default="../Test_frameqa_question-balanced.csv",
    )
    parser.add_argument("--answer-path", type=str, default="data_processed.json")
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()
    main(args)