ChatVID / model /fastchat /serve /cli_caption.py
Yiqin's picture
init
6ef31de
raw
history blame
7.58 kB
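"""
Run a chat model from the command line over a question set.
By default this runs ``question_loop``; pass ``--answer-flag True`` to run
``answer_loop`` instead.
Usage:
python3 -m fastchat.serve.cli_caption --model-path ~/model_weights/llama-7b --question-path QUESTIONS --caption-path CAPTIONS
"""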
import argparse
import os
import re
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from rich.console import Console
from rich.markdown import Markdown
from rich.live import Live
from fastchat.serve.inference import ChatIO, question_loop, answer_loop
class SimpleChatIO(ChatIO):
    """Plain-text chat I/O built on ``input()`` and ``print()``."""

    def prompt_for_input(self, role: str) -> str:
        """Read one line of user input after printing a ``"{role}: "`` prompt."""
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        """Print the ``"{role}: "`` prefix for the reply, without a newline."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        """Print a streamed reply word by word and return the final text.

        Each item of ``output_stream`` is re-sliced and re-split on every
        iteration, so each item appears to be the full text generated so far
        (TODO confirm against the producer in fastchat.serve.inference).

        Args:
            output_stream: Iterable of strings.
            skip_echo_len: Number of leading characters dropped from every
                item (presumably the echoed prompt).

        Returns:
            The words of the last item joined by single spaces, or ``""``
            when the stream yields nothing.
        """
        pre = 0
        # Initialised so an empty stream no longer hits UnboundLocalError.
        words = []
        for outputs in output_stream:
            words = outputs[skip_echo_len:].strip().split(" ")
            # Hold back the last word: it may still grow in the next chunk.
            now = len(words) - 1
            if now > pre:
                print(" ".join(words[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(words[pre:]), flush=True)
        return " ".join(words)
class RichChatIO(ChatIO):
    """Chat I/O using prompt_toolkit for input and rich Markdown for replies."""

    def __init__(self):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        # Completion candidates for "!exit" / "!reset" (presumably handled by
        # the caller's chat loop; that handling is not visible here).
        self._completer = WordCompleter(
            words=["!exit", "!reset"], pattern=re.compile("$")
        )
        self._console = Console()

    def prompt_for_input(self, role) -> str:
        """Print a bold ``role:`` header, then read one line of user input."""
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        """Print a bold ``role:`` header before the reply is streamed."""
        self._console.print(f"[bold]{role}:")

    @staticmethod
    def _with_hard_line_breaks(text: str) -> str:
        """Make every newline in ``text`` render as a line break in Markdown.

        Chatbot output uses a single newline as a real line break, but
        standard Markdown treats a single newline inside a paragraph as a
        space. The workaround appends two trailing spaces to each line, which
        Markdown renders as a hard line break. Lines starting with a code-fence
        marker get a plain newline, because trailing spaces there would break
        syntax highlighting. Lines inside code blocks still gain trailing
        spaces, which is harmless for console output.
        """
        parts = []
        for line in text.splitlines():
            parts.append(line)
            parts.append("\n" if line.startswith("```") else "  \n")
        return "".join(parts)

    def stream_output(self, output_stream, skip_echo_len: int):
        """Render a streamed reply as live-updating Markdown and return it.

        Args:
            output_stream: Iterable of strings; each is re-rendered in full,
                so each appears to be the whole text generated so far.
            skip_echo_len: Number of leading characters dropped from every
                item (presumably the echoed prompt).

        Returns:
            The last item with its first ``skip_echo_len`` characters removed,
            or ``""`` when the stream yields nothing.
        """
        # TODO(suquark): the console flickers when there is a code block
        # above it. We need to cut off "live" when a code block is done.
        # Initialised so an empty stream no longer hits UnboundLocalError.
        text = ""
        with Live(console=self._console, refresh_per_second=4) as live:
            for outputs in output_stream:
                text = outputs[skip_echo_len:]
                if not text:
                    continue
                live.update(Markdown(self._with_hard_line_breaks(text)))
        self._console.print()
        return text
def main(args):
    """Run ``question_loop`` or ``answer_loop`` according to parsed CLI args.

    Args:
        args: Namespace produced by this module's argument parser.

    Raises:
        ValueError: If ``--gpus`` lists fewer devices than ``--num-gpus``,
            or if ``--style`` is neither "simple" nor "rich".
    """
    if args.gpus:
        if args.num_gpus and len(args.gpus.split(",")) < int(args.num_gpus):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        # Set before either loop runs so the device restriction applies to it.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    if args.style == "simple":
        chatio = SimpleChatIO()
    elif args.style == "rich":
        chatio = RichChatIO()
    else:
        raise ValueError(f"Invalid style for console: {args.style}")

    # Positional arguments shared by both loops, in the order they expect.
    common_args = (
        args.model_path,
        args.device,
        args.num_gpus,
        args.max_gpu_memory,
        args.load_8bit,
        args.conv_template,
        args.temperature,
        args.max_new_tokens,
        chatio,
        args.debug,
        args.question_path,
        args.caption_path,
    )

    try:
        if args.answer_flag:
            print("answer loop")
            answer_loop(*common_args, args.data_info_path, args.answer_path)
        else:
            print("question loop")
            # Informational only: the question loop still runs afterwards.
            # Guarded on caption_path because its default is None, and
            # os.path.exists(None) raises TypeError.
            if args.caption_path and os.path.exists(args.caption_path):
                print("caption.json already exists")
            question_loop(*common_args)
    except KeyboardInterrupt:
        print("exit...")
def parse_bool_flag(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--answer-flag False`` used to enable the flag.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" stays False, matching the old bool("") behaviour.
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser whose namespace ``main`` consumes."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="facebook/opt-350m",
        help="The path to the weights",
    )
    parser.add_argument(
        "--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda"
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="A single GPU like 1 or multiple GPUs like 0,2",
    )
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per gpu. Use a string like '13Gib'",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich"],
        help="Display style.",
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--question-path", type=str, default=None)
    parser.add_argument("--caption-path", type=str, default=None)
    # Still accepts an explicit value ("--answer-flag True") for existing
    # invocations; a bare "--answer-flag" now also means True.
    parser.add_argument(
        "--answer-flag",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help="Run the answer loop instead of the question loop "
        "(True/False; a bare --answer-flag means True).",
    )
    parser.add_argument(
        "--data-info-path", type=str, default="../Test_frameqa_question-balanced.csv"
    )
    parser.add_argument("--answer-path", type=str, default="data_processed.json")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)