"""
Chat with a model through a command-line interface.

Usage:
python3 -m fastchat.serve.cli --model-path ~/model_weights/llama-7b
"""
|
import argparse |
|
import os |
|
import re |
|
|
|
from prompt_toolkit import PromptSession |
|
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory |
|
from prompt_toolkit.completion import WordCompleter |
|
from prompt_toolkit.history import InMemoryHistory |
|
from rich.console import Console |
|
from rich.markdown import Markdown |
|
from rich.live import Live |
|
|
|
from fastchat.serve.inference import ChatIO, question_loop, answer_loop |
|
|
|
|
|
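

# For reference: ChatIO (imported above) is assumed to be an abstract interface
# whose shape can be read off the methods overridden below; a minimal sketch:
#
#     class ChatIO(abc.ABC):
#         @abc.abstractmethod
#         def prompt_for_input(self, role: str) -> str: ...
#
#         @abc.abstractmethod
#         def prompt_for_output(self, role: str): ...
#
#         @abc.abstractmethod
#         def stream_output(self, output_stream, skip_echo_len: int): ...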
class SimpleChatIO(ChatIO):
    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        # The stream yields the full text generated so far; print only the
        # words that are new since the previous chunk, holding back the last
        # word because it may still be partial.
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            outputs = outputs.split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)
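

# RichChatIO implements the same interface with nicer ergonomics:
# prompt_toolkit supplies history, auto-suggestions, and command completion at
# the input prompt, while rich renders the streamed reply as live-updating
# Markdown.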
class RichChatIO(ChatIO):
    def __init__(self):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        # Completion candidates for the special chat commands.
        self._completer = WordCompleter(
            words=["!exit", "!reset"], pattern=re.compile("$")
        )
        self._console = Console()

    def prompt_for_input(self, role) -> str:
        self._console.print(f"[bold]{role}:")
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        self._console.print(f"[bold]{role}:")

    def stream_output(self, output_stream, skip_echo_len: int):
        """Stream output from a role."""
        # Live repaints the rendered Markdown in place as new tokens arrive;
        # a modest refresh rate keeps flicker manageable.
        with Live(console=self._console, refresh_per_second=4) as live:
            for outputs in output_stream:
                accumulated_text = outputs[skip_echo_len:]
                if not accumulated_text:
                    continue
                # Chatbot output treats a single "\n" as a line break, but
                # standard Markdown collapses it into a space. Append two
                # trailing spaces to each line so rich renders a hard break;
                # code-fence markers are excluded because the extra spaces
                # would break the fence and its syntax highlighting.
                lines = []
                for line in accumulated_text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                live.update(markdown)
        self._console.print()
        return outputs[skip_echo_len:]
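

# question_loop and answer_loop come from this repo's fork of
# fastchat.serve.inference. Judging from the call sites in main() below, both
# take the same model-loading and sampling arguments plus a ChatIO, and differ
# only in the trailing data-file paths they read and write.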
def main(args):
    if args.gpus:
        if args.num_gpus and len(args.gpus.split(",")) < int(args.num_gpus):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        # Restrict which devices CUDA sees before the model is loaded.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    if args.style == "simple":
        chatio = SimpleChatIO()
    elif args.style == "rich":
        chatio = RichChatIO()
    else:
        raise ValueError(f"Invalid style for console: {args.style}")

    try:
        if args.answer_flag:
            print("answer loop")
            answer_loop(
                args.model_path,
                args.device,
                args.num_gpus,
                args.max_gpu_memory,
                args.load_8bit,
                args.conv_template,
                args.temperature,
                args.max_new_tokens,
                chatio,
                args.debug,
                args.question_path,
                args.caption_path,
                args.data_info_path,
                args.answer_path,
            )
        else:
            print("question loop")
            if args.caption_path and os.path.exists(args.caption_path):
                print(f"{args.caption_path} already exists")
            question_loop(
                args.model_path,
                args.device,
                args.num_gpus,
                args.max_gpu_memory,
                args.load_8bit,
                args.conv_template,
                args.temperature,
                args.max_new_tokens,
                chatio,
                args.debug,
                args.question_path,
                args.caption_path,
            )
    except KeyboardInterrupt:
        print("exit...")
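

# Example invocations (the data-file names below are placeholders):
#   python3 -m fastchat.serve.cli --model-path ~/model_weights/llama-7b \
#       --question-path questions.json --caption-path caption.json
# Add --answer-flag (plus --data-info-path / --answer-path) to run the answer
# loop instead of the question loop.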
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="facebook/opt-350m",
        help="The path to the weights",
    )
    parser.add_argument(
        "--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda"
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="A single GPU like 1 or multiple GPUs like 0,2",
    )
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per GPU. Use a string like '13GiB'",
    )
    parser.add_argument(
        "--load-8bit", action="store_true", help="Use 8-bit quantization."
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich"],
        help="Display style.",
    )
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--question-path", type=str, default=None)
    parser.add_argument("--caption-path", type=str, default=None)
    # `type=bool` would treat any non-empty string (even "False") as True, so
    # expose this as a store_true flag instead.
    parser.add_argument("--answer-flag", action="store_true")
    parser.add_argument(
        "--data-info-path", type=str, default="../Test_frameqa_question-balanced.csv"
    )
    parser.add_argument("--answer-path", type=str, default="data_processed.json")

    args = parser.parse_args()
    main(args)